前言:理解了很久的CTC,每次都是点到即止,所以一直没有很明确,现在重新整理。

定义

CTC (Connectionist Temporal Classification)是一种loss function

对比

传统方法

 在传统的语音识别的模型中,我们对语音模型进行训练之前,往往都要将文本与语音进行严格的对齐操作。这样就有两点不太好:
 1. 严格对齐要花费人力、时间。
 2. 严格对齐之后,模型预测出的label只是局部分类的结果,而无法给出整个序列的输出结果,往往要对预测出的label做一些后处理才可以得到我们最终想要的结果。
  虽然现在已经有了一些比较成熟的开源对齐工具供大家使用,但是随着deep learning越来越火,有人就会想,能不能让我们的网络自己去学习对齐方式呢?因此CTC(Connectionist temporal classification)就应运而生啦。
  想一想,为什么CTC就不需要去对齐语音和文本呢?因为CTC它允许我们的神经网络在任意一个时间段预测label,只有一个要求:就是输出的序列顺序只要是正确的就ok啦~这样我们就不在需要让文本和语音严格对齐了,而且CTC输出的是整个序列标签,因此也不需要我们再去做一些后处理操作。
  对一段音频使用CTC和使用文本对齐的例子如下图所示:
  
  

主要区别

训练流程和传统的神经网络类似,构建loss function,然后根据BP算法进行训练,不同之处在于传统的神经网络的训练准则是针对每帧数据,即每帧数据的训练误差最小,而CTC的训练准则是基于序列(比如语音识别的一整句话)的,比如最大化p(z|x)p(z|x)p(z|x) ,序列化的概率求解比较复杂,因为一个输出序列可以对应很多的路径,所有引入前后向算法来简化计算。






算法细节

符号定义

概率计算

误差反传


参考文献

  • CTC学习笔记(二) 训练和公式推导

    • 很详细的公示推导
    • 前向后向算法计算序列概率,并最大化
    • 使用BPTT算法得到损失函数对神经网络参数的偏导.
  • tensorflowbook

    • 具体实现
    • 语音识别实例.
  • 语音识别:深入理解CTC Loss原理

    • 符号表示等非常详细
  • Sequence Modeling With CTC
    • 最好的教程!
    • 有动图,有对比
  • CS224S / LINGUIST285 - Spoken Language Processing
    • 语言处理的课程,非常好!
    • chapter 8讲的CTC
  • 百度贾磊CTC

实现代码

#coding=utf-8
import timeimport tensorflow as tf
import scipy.io.wavfile as wav
import numpy as npfrom six.moves import xrange as rangetry:from python_speech_features import mfcc
except ImportError:print("Failed to import python_speech_features.\n Try pip install python_speech_features.")raise ImportError# 常量
SPACE_TOKEN = '<space>'
SPACE_INDEX = 0
FIRST_INDEX = ord('a') - 1  # 0 is reserved to space# mfcc默认提取出来的一帧13个特征
num_features = 13
# 26个英文字母 + 1个空白 + 1个no label = 28 label个数
num_classes = ord('z') - ord('a') + 1 + 1 + 1# 迭代次数
num_epochs = 200
# lstm隐藏单元数
num_hidden = 40
# 2层lstm网络
num_layers = 1
# batch_size设置为1
batch_size = 1
# 初始学习率
initial_learning_rate = 0.01# 样本个数
num_examples = 1
# 一个epoch有多少个batch
num_batches_per_epoch = int(num_examples/batch_size)def sparse_tuple_from(sequences, dtype=np.int32):"""得到一个list的稀疏表示,为了直接将数据赋值给tensorflow的tf.sparse_placeholder稀疏矩阵Args:sequences: 序列的列表Returns:一个三元组,和tensorflow的tf.sparse_placeholder同结构"""indices = []values = []for n, seq in enumerate(sequences):indices.extend(zip([n]*len(seq), range(len(seq))))values.extend(seq)indices = np.asarray(indices, dtype=np.int64)values = np.asarray(values, dtype=dtype)shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1]+1], dtype=np.int64)return indices, values, shapedef get_audio_feature():'''获取wav文件提取mfcc特征之后的数据'''audio_filename = "audio.wav"#读取wav文件内容,fs为采样率, audio为数据fs, audio = wav.read(audio_filename)#提取mfcc特征inputs = mfcc(audio, samplerate=fs)# 对特征数据进行归一化,减去均值除以方差feature_inputs = np.asarray(inputs[np.newaxis, :])feature_inputs = (feature_inputs - np.mean(feature_inputs))/np.std(feature_inputs)#特征数据的序列长度feature_seq_len = [feature_inputs.shape[1]]return feature_inputs, feature_seq_lendef get_audio_label():'''将label文本转换成整数序列,然后再换成稀疏三元组'''target_filename = 'label.txt'with open(target_filename, 'r') as f:#原始文本为“she had your dark suit in greasy wash water all year”line = f.readlines()[0].strip()targets = line.replace(' ', '  ')# 放入list中,空格用''代替#['she', '', 'had', '', 'your', '', 'dark', '', 'suit', '', 'in', '', 'greasy', '', 'wash', '', 'water', '', 'all', '', 'year']targets = targets.split(' ')# 每个字母作为一个label,转换成如下:#['s' 'h' 'e' '<space>' 'h' 'a' 'd' '<space>' 'y' 'o' 'u' 'r' '<space>' 'd'# 'a' 'r' 'k' '<space>' 's' 'u' 'i' 't' '<space>' 'i' 'n' '<space>' 'g' 'r'# 'e' 'a' 's' 'y' '<space>' 'w' 'a' 's' 'h' '<space>' 'w' 'a' 't' 'e' 'r'#'<space>' 'a' 'l' 'l' '<space>' 'y' 'e' 'a' 'r']targets = np.hstack([SPACE_TOKEN if x == '' else list(x) for x in targets])# 将label转换成整数序列表示:# [19  8  5  0  8  1  4  0 25 15 21 18  0  4  1 18 11  0 19 21  9 20  0  9 14# 0  7 18  5  1 19 25  0 23  1 19  8  0 23  1 20  5 18  0  1 12 12  0 25  5# 1 18]targets = np.asarray([SPACE_INDEX if x == SPACE_TOKEN else ord(x) - FIRST_INDEXfor x in targets])# 将列表转换成稀疏三元组train_targets = sparse_tuple_from([targets])return train_targetsdef inference(inputs, seq_len):'''2层双向LSTM的网络结构定义Args:inputs: 输入数据,形状是[batch_size, 序列最大长度,一帧特征的个数13]序列最大长度是指,一个样本在转成特征矩阵之后保存在一个矩阵中,在n个样本组成的batch中,因为不同的样本的序列长度不一样,在组成的3维数据中,第2维的长度要足够容纳下所有的样本的特征序列长度。seq_len: batch里每个样本的有效的序列长度'''#定义一个向前计算的LSTM单元,40个隐藏单元cell_fw = tf.contrib.rnn.LSTMCell(num_hidden, initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1),state_is_tuple=True)# 组成一个有2个cell的listcells_fw = [cell_fw] * num_layers# 定义一个向后计算的LSTM单元,40个隐藏单元cell_bw = tf.contrib.rnn.LSTMCell(num_hidden, initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1),state_is_tuple=True)# 组成一个有2个cell的listcells_bw = [cell_bw] * num_layers# 将前面定义向前计算和向后计算的2个cell的list组成双向lstm网络# sequence_length为实际有效的长度,大小为batch_size,# 相当于表示batch中每个样本的实际有用的序列长度有多长。# 输出的outputs宽度是隐藏单元的个数,即num_hidden的大小outputs, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(cells_fw,cells_bw,inputs,dtype=tf.float32,sequence_length=seq_len)#获得输入数据的形状shape = tf.shape(inputs)batch_s, max_timesteps = shape[0], shape[1]# 将2层LSTM的输出转换成宽度为40的矩阵# 后面进行全连接计算outputs = tf.reshape(outputs, [-1, num_hidden])W = tf.Variable(tf.truncated_normal([num_hidden,num_classes],stddev=0.1))b = tf.Variable(tf.constant(0., shape=[num_classes]))# 进行全连接线性计算logits = tf.matmul(outputs, W) + b# 将全连接计算的结果,由宽度40变成宽度80,# 即最后的输入给CTC的数据宽度必须是26+2的宽度logits = tf.reshape(logits, [batch_s, -1, num_classes])# 转置,将第一维和第二维交换。# 变成序列的长度放第一维,batch_size放第二维。# 也是为了适应Tensorflow的CTC的输入格式logits = tf.transpose(logits, (1, 0, 2))return logitsdef main():# 输入特征数据,形状为:[batch_size, 序列长度,一帧特征数]inputs = tf.placeholder(tf.float32, [None, None, num_features])# 输入数据的label,定义成稀疏sparse_placeholder会生成稀疏的tensor:SparseTensor# 这个结构可以直接输入给ctc求losstargets = tf.sparse_placeholder(tf.int32)# 序列的长度,大小是[batch_size]大小# 表示的是batch中每个样本的有效序列长度是多少seq_len = tf.placeholder(tf.int32, [None])# 向前计算网络,定义网络结构,输入是特征数据,输出提供给ctc计算损失值。logits = inference(inputs, seq_len)# ctc计算损失# 参数targets必须是一个值为int32的稀疏tensor的结构:tf.SparseTensor# 参数logits是前面lstm网络的输出# 参数seq_len是这个batch的样本中,每个样本的序列长度。loss = tf.nn.ctc_loss(targets, logits, seq_len)# 计算损失的平均值cost = tf.reduce_mean(loss)# 采用冲量优化方法optimizer = tf.train.MomentumOptimizer(initial_learning_rate, 0.9).minimize(cost)# 还有另外一个ctc的函数:tf.contrib.ctc.ctc_beam_search_decoder# 本函数会得到更好的结果,但是效果比ctc_beam_search_decoder低# 返回的结果中,decode是ctc解码的结果,即输入的数据解码出结果序列是什么decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)# 采用计算编辑距离的方式计算,计算decode后结果的错误率。ler = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32),targets))config = tf.ConfigProto()config.gpu_options.allow_growth = Truewith tf.Session(config=config) as session:# 初始化变量tf.global_variables_initializer().run()for curr_epoch in range(num_epochs):train_cost = train_ler = 0start = time.time()for batch in range(num_batches_per_epoch):#获取训练数据,本例中只去一个样本的训练数据train_inputs, train_seq_len = get_audio_feature()# 获取这个样本的labeltrain_targets = get_audio_label()feed = {inputs: train_inputs,targets: train_targets,seq_len: train_seq_len}# 一次训练,更新参数batch_cost, _ = session.run([cost, optimizer], feed)# 计算累加的训练的损失值train_cost += batch_cost * batch_size# 计算训练集的错误率train_ler += session.run(ler, feed_dict=feed)*batch_sizetrain_cost /= num_examplestrain_ler /= num_examples# 打印每一轮迭代的损失值,错误率log = "Epoch {}/{}, train_cost = {:.3f}, train_ler = {:.3f}, time = {:.3f}"print(log.format(curr_epoch+1, num_epochs, train_cost, train_ler,time.time() - start))# 在进行了1200次训练之后,计算一次实际的测试,并且输出# 读取测试数据,这里读取的和训练数据的同一个样本test_inputs, test_seq_len = get_audio_feature()test_targets = get_audio_label()test_feed = {inputs: test_inputs,targets: test_targets,seq_len: test_seq_len}d = session.run(decoded[0], feed_dict=test_feed)# 将得到的测试语音经过ctc解码后的整数序列转换成字母str_decoded = ''.join([chr(x) for x in np.asarray(d[1]) + FIRST_INDEX])# 将no label转换成空str_decoded = str_decoded.replace(chr(ord('z') + 1), '')# 将空白转换成空格str_decoded = str_decoded.replace(chr(ord('a') - 1), ' ')# 打印最后的结果print('Decoded:\n%s' % str_decoded)if __name__ == "__main__":main()

CTC loss 理解相关推荐

  1. 语音识别:深入理解CTC Loss原理

      最近看了百度的Deep Speech,看到语音识别使用的损失函数是CTC loss.便整理了一下有关于CTC loss的一些定义和推导.由于个人水平有限,如果文章有错误,还恳请各位指出,万分感谢~ ...

  2. 【项目实践】中英文文字检测与识别项目(CTPN+CRNN+CTC Loss原理讲解)

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:opencv学堂 OCR--简介 文字识别也是图像领域一 ...

  3. 【OCR】CTC loss原理

    1 CTC loss出现的背景 在图像文本识别.语言识别的应用中,所面临的一个问题是神经网络输出与ground truth的长度不一致,这样一来,loss就会很难计算,举个例子来讲,如果网络的输出是& ...

  4. 深入浅出CTC loss

    前言   本片博客主要学习了CTC并在动态规划求CTC loss的理解上学习了这篇博客   由于在看的过程中,还是花了很长时间反复推敲作者的理解,因此在这边用更加简单的话来解释一下CTC loss 背 ...

  5. 语音识别 CTC Loss

    (以下内容搬运自 PaddleSpeech) Derivative of CTC Loss 关于CTC的介绍已经有很多不错的教程了,但是完整的描述CTCLoss的前向和反向过程的很少,而且有些公式推导 ...

  6. DL之CNN:利用CNN(keras, CTC loss, {image_ocr})算法实现OCR光学字符识别

    DL之CNN:利用CNN(keras, CTC loss)算法实现OCR光学字符识别 目录 输出结果 实现的全部代码 输出结果 更新-- 实现的全部代码 部分代码源自:GitHub https://r ...

  7. CTC Loss (一)

    论文:https://mediatum.ub.tum.de/doc/1292048/file.pdf 在文本识别模型CRNN中,一张包含单行文本的图片输入模型经过CNN.LSTM后输出大小的featu ...

  8. YOLO loss理解

    自己理解的YOLO loss  是 对于真实(label)有物体的格子,计算位置(坐标)损失,权重大一点.所有框都计算判别概率损失,无物体的格子 权重小一点.所有各自计算类别损失

  9. 『OCR_recognition』CTC loss几种解码方式

    文章目录 前言 一.贪心搜索 (greedy search) 1.1 原理解释 1.2 图示说明 1.3 代码实现 二.束搜索(Beam Search) 2.1 原理解释 2.2 图示说明 2.3 代 ...

  10. Equalization Loss理解-更新中

    Equalization Loss for Long-Tailed Object Recognition 一.前言 二.交叉熵回顾 2.1.Softmax Cross-Entropy Loss 2.2 ...

最新文章

  1. 可否使用串联LED(或者光敏LED)来制作光电检测板?
  2. 阿里巴巴向全社会开放黑科技:“泡在水里”的服务器
  3. 历城职专学前计算机专业,历城职专学前教育专业2020学年第一学期技能运动会拉开帷幕...
  4. roads 构筑极致用户体验_坚持用户思维 推动领克汽车逆势突围
  5. Vsphere auto deploy 简介
  6. 读这样的文章才能清楚什么是RIA
  7. C/C++ 错误处理
  8. C/C++[codeup 1967]数组逆置
  9. java软件的安装过程
  10. 二手房房价影响因素分析
  11. maf相关代码和命令
  12. IT战略规划之流程再造 —2013年中科院计算所培训中心系列公益讲座
  13. Unicode双向算法详解(bidi算法)(二)
  14. 计算机组成及原理ppt课件,计算机组成原理第五章课件.ppt
  15. 【转】-ECshop数据库表结构
  16. antdesignpro ProTable 搜索模式自定义搜索字段
  17. --Redis入坑--RedisPipelineException:Pipeline contained one or more invalid commands;WRONGTYPE ...
  18. 基于有源钳位三电平的有源电力滤波器(ANPC-APF)MATLAB仿真,包括自建的DSOGI锁相模块和PQ谐波检测模块
  19. 【Reference Reading】MRI引导中子捕获治疗通过上调低密度脂蛋白转运体使用双钆/硼剂靶向肿瘤细胞
  20. Tomcat服务器部署工件出错和无法访问网页异常解决

热门文章

  1. 废旧 Android 手机如何改造成 Linux 服务器
  2. java实现与图灵机器人聊天_调用图灵机器人API实现简单聊天
  3. win10、Ubuntu双系统删除Ubuntu的方法
  4. 2017博鳌亚洲青年论坛(香港)顺利召开 中国发展人工智能优势在哪?
  5. 在mac上使用nginx配置codeigniter框架一直显示404的问题的一种方法(重启)
  6. 商户开通微信支付详细流程文档
  7. mac怎么设置锁屏壁纸,锁屏壁纸和屏幕壁纸不同
  8. 小米手机不断自己重启问题解决
  9. Python爬虫之BeautifulSoup
  10. 学习PMbok对pmp考试的认知理解和itto输入输出的整理笔记