CTC Loss (一)
论文:https://mediatum.ub.tum.de/doc/1292048/file.pdf
在文本识别模型CRNN中,一张包含单行文本的图片输入模型经过CNN、LSTM后输出大小的feature map, 假设T=25表示时间序列长度,m=26代表需要识别的字符集的大小(假设只识别小写英文字母),对每一个时间步接softmax后就得到识别结果的概率分布,对每一个时间步满足,但是在与label进行loss计算时需要先将图片中的每一个字符与label对齐,这就需要对单个字进行位置和语义标注,非常麻烦。而且由于字体样式和大小的关系,每列输出并不一定能和每个字符一一对应。ctc loss是一种专门针对这种场景不需要对齐的loss计算方法。接下来介绍ctc loss的具体计算方法
- 空白blank
表示预先定义的模型待识别字符集,因为输入图片中有的位置没有文字,引入空白blank字符,下文以 - 表示blank,LSTM的输出变成
- \(\beta\)变换
定义\(\beta\)变换,LSTM输出首先经过decode,然后经过\(\beta\)变换得到最终结果。\(\beta\)变换包括删除连续重复字符以及blank。例如,当T=12时,下列四个输出经过\(\beta\)变换都变成state。
给定输入,模型输出为的概率为
表示所有经过变换后是的路径
其中,对于任意一条路径有
注意这里中的,下标表示路径中的每一个时刻。而上面的下标表示不同的路径
ctc的训练目标是通过梯度调整模型权重,使得最大
在实际训练过程中,LSTM的输出特征图T的大小少为几十多则几百,如果遍历每一条路径,复杂度是指数级的,假如识别的是汉字,字符集长度为几千,序列长度上百,那要遍历种选择,速度太慢。实际CTC借用了HMM的"前向 - 后向"(forward - backward)算法来计算,具体过程如下
首先定义路径为在路径的头尾和每两个字符间插入blank
显然
定义所有经变换后结果是且在时刻结果为的路径集和为,求导
上式中第二项与无关,因此
就是恰好与概率相关的路径,即时刻都经过
上述的在时都经过(此处下标代表路径的时刻的字符),所有类似于经过变换后结果是且在的路径集和表示为
如图,蓝色路径和红色路径分别为上述的和,和可以表示为
和可以表示为
则
令
则
推广一下,所有经过 变换结果为且的路径可以写成如下形式
进一步推广, 所有经过 变换结果为且的路径可以写成如下形式
定义前向递推概率和
其中 表示路径的前个字符经过变换变成的前个字符,代表了时刻经过的所有路径的的概率和,即前向递推概率和。
当时,路径只能从或开始,所以有如下性质:
同理,定义后向递推概率和
其中表示后个字符经过变换为后半段子路径,表示时刻经过的所有路径的的概率和,即后向递推概率和。
当时,路径只能以或结束,所以有如下性质:
计算递推loss
和相乘有
当计算loss对ctc输入即LSTM输出中的某个值的梯度时,只需考虑所有经过的路径,因此可以得到
梯度如下
接下来只需计算出和即可
前面我们给出了的初始条件,即时,路径只能从或开始。
- 当时刻字符为时,可以由当前字符或前一个非空白字符得到。
- 当即当前字符不是且和前一个字符相同时,可以由当前字符或前一个字符得到,如下图所示
- 当时刻字符不是且时,可以由当前字符、前一个字符、前一个非空白字符得到,如下图所示
由此可以得到递推公式
根据初始条件和递推公式,便可以用动态规划计算出,代码如下
import numpy as npdef alpha_vanilla(y, labels): # labels是插入blank后的T, V = y.shape # T: time step, V: probsL = len(labels) # label lengthalpha = np.zeros([T, L])# initalpha[0, 0] = y[0, labels[0]]alpha[0, 1] = y[0, labels[1]]for t in range(1, T):for i in range(L):s = labels[i]a = alpha[t - 1, i]if i - 1 >= 0:a += alpha[t - 1, i - 1]if i - 2 >= 0 and s != 0 and s != labels[i - 2]:a += alpha[t - 1, i - 2]alpha[t, i] = a * y[t, s]return alpha
同理可得后向递推公式
def beta_vanilla(y, labels):T, V = y.shapeL = len(labels)beta = np.zeros([T, L])# initbeta[-1, -1] = y[-1, labels[-1]]beta[-1, -2] = y[-1, labels[-2]]for t in range(T - 2, -1, -1):for i in range(L):s = labels[i]a = beta[t + 1, i]if i + 1 < L:a += beta[t + 1, i + 1]if i + 2 < L and s != 0 and s != labels[i + 2]:a += beta[t + 1, i + 2]beta[t, i] = a * y[t, s]return beta
计算梯度
求导中,分子第一项是因为中分别包含一个项,其它项均为与无关的常数。
另外,中可能包含多个字符,因为计算的梯度要进行累加。例如,,即求输出中处的字符的梯度,这里的可能通过变换成中的第一个也可能变换成第二个。因此,最终的梯度计算结果为
其中,
一般我们优化似然函数的对数,梯度如下
其中,可直接求得
梯度计算代码如下
def gradient(y, labels):T, V = y.shapealpha = alpha_vanilla(y, labels)beta = beta_vanilla(y, labels)p = alpha[-1, -1] + alpha[-1, -2]grad = np.zeros([T, V])for t in range(T):for s in range(V):lab = [i for i, c in enumerate(labels) if c == s]for i in lab:grad[t, s] += alpha[t, i] * beta[t, i]grad[t, s] /= y[t, s] ** 2grad /= preturn grad
参考
一文读懂CRNN+CTC文字识别 - 知乎
【Learning Notes】CTC 原理及实现_MoussaTintin的博客-CSDN博客
Sequence Modeling with CTC
CTC Loss (一)相关推荐
- 【项目实践】中英文文字检测与识别项目(CTPN+CRNN+CTC Loss原理讲解)
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:opencv学堂 OCR--简介 文字识别也是图像领域一 ...
- DL之CNN:利用CNN(keras, CTC loss, {image_ocr})算法实现OCR光学字符识别
DL之CNN:利用CNN(keras, CTC loss)算法实现OCR光学字符识别 目录 输出结果 实现的全部代码 输出结果 更新-- 实现的全部代码 部分代码源自:GitHub https://r ...
- 语音识别:深入理解CTC Loss原理
最近看了百度的Deep Speech,看到语音识别使用的损失函数是CTC loss.便整理了一下有关于CTC loss的一些定义和推导.由于个人水平有限,如果文章有错误,还恳请各位指出,万分感谢~ ...
- 【OCR】CTC loss原理
1 CTC loss出现的背景 在图像文本识别.语言识别的应用中,所面临的一个问题是神经网络输出与ground truth的长度不一致,这样一来,loss就会很难计算,举个例子来讲,如果网络的输出是& ...
- 深入浅出CTC loss
前言 本片博客主要学习了CTC并在动态规划求CTC loss的理解上学习了这篇博客 由于在看的过程中,还是花了很长时间反复推敲作者的理解,因此在这边用更加简单的话来解释一下CTC loss 背 ...
- 语音识别 CTC Loss
(以下内容搬运自 PaddleSpeech) Derivative of CTC Loss 关于CTC的介绍已经有很多不错的教程了,但是完整的描述CTCLoss的前向和反向过程的很少,而且有些公式推导 ...
- 『OCR_recognition』CTC loss几种解码方式
文章目录 前言 一.贪心搜索 (greedy search) 1.1 原理解释 1.2 图示说明 1.3 代码实现 二.束搜索(Beam Search) 2.1 原理解释 2.2 图示说明 2.3 代 ...
- 分类回归loss函数汇总分析
2019独角兽企业重金招聘Python工程师标准>>> 版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/qq_14845119/ar ...
- 10万元奖金语音识别赛进行中!CTC 模型 Baseline 助你轻松上分
随着互联网.智能硬件的普及,智能音箱和语音助手已经深入人们的日常生活,家居场景下的语音识别技术已成为企业和研究机构竞相追逐的关键技术. 目前,由北京智源人工智能研究院.爱数智慧.biendata 共同 ...
最新文章
- 常见面试之机器学习算法思想简单梳理
- Python 技巧篇-如何避免python报错导致强制关闭窗口
- 洛谷 P1205 [USACO1.2]方块转换 Transformations
- boost::units模块实现异构单元片段
- 初识Linux .bash_profile, .bash_logout, and .bashrc 文件
- Ubuntu 12.04安装下载工具 UGet 1.8.0 及 aria2用法
- “发明在商业上获得成功”对专利法22条第三款有关创造性规定的影响
- Git笔记(8) 远程仓库的使用
- functools模块
- DL加速器与GPU的不同,一个用于推理,一个用于训练。
- 好久没更新了,更新一篇,关于ZEC的吧
- 邮件服务器pop3和imap,POP3服务器和IMAP服务器
- 白领控诉:被逼下乡5年,我们的幸福何处寻找
- 该死的clear 根本不释放内存,怎么才能释放泛型LIST的内存?
- 掌上飞车-艳云脚本云控系统
- sendcloud php,Sendcloud的x_smtpapi具体如何定义?
- 安装内网穿透Frps
- ADS129X芯片中文资料(二)——模拟功能部分介绍
- Linux CentOS 7修改分辨率
- php怎么让浏览器崩溃,让IE6浏览器崩溃