最近一直在做手写体识别的工作,其中有个很重要的loss那就是ctc loss,之前在文档识别与分析课程中学习过,但是时间久远,早已忘得一干二净,现在重新整理记录下

本文大量引用了- CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇),只是用自己的语言理解了一下,原论文:Connectionist Temporal Classification: Labelling UnsegmSequence Data with Recurrent Neural Networ

解决的问题

套用知乎上的一句话,CTC Loss要解决的问题就是当label长度小于模型输出长度时,如何做损失函数。
一般做分类时,已有的softmax loss都是模型输出长度和label长度相同且严格对齐,而语音识别或者手写体识别中,无法预知一句话或者一张图应该输出多长的文字,这时做法有两种:seq2seq+attention机制,不限制输出长度,在最后加一个结束符号,让模型自动和gt label对齐;另一种是给定一个模型输出的最大长度,但是这些输出并没有对齐的label怎么办呢,这时就需要CTC loss了。

输出序列的扩展

所以,如果要计算?(?│?),可以累加其对应的全部输出序列o (也即映射到最终label的“路径”)的概率即可,如下图。

前向和后向计算

由于我们没有每个时刻输出对应的label,因此CTC使用最大似然进行训练(CTC 假设输出的概率是(相对于输入)条件独立的)
给定输入xxx,输出序列 ooo 的条件概率是:
p(π∣x)=∏yπtt,∀π∈L′Tp(\pi|x) = \prod y^t_{\pi_t}, \forall \pi \in L^{\prime T} p(π∣x)=∏yπt​t​,∀π∈L′T
πt\pi _tπt​ 是序列 ooo 中的一个元素,yyy为模型在所有时刻输出各个字符的概率,shape为T*C(T是时刻,提前已固定。C是字符类别数,所有字符+blank(不是空格,是空) ,yπtty^t_{\pi_t}yπt​t​ 是模型t时刻输出为πt\pi _tπt​的概率

我们模型的目标就是给定输入x,使得能映射到最终label的所有输出序列o的条件概率之和最大,该条件概率就是p(π∣x)p(\pi|x)p(π∣x),和模型的输出概率yyy直接关联

那么我们如何计算这些条件概率之和呢?首先想到的就是暴力算法,一一找到可以映射到最终label的所有输出序列,然后概率连乘最后相加,但是很耗时,有木有更快的做法?联系一下HMM模型中的前向和后向算法,它就是利用动态规划求某个序列出现的概率,和此处我们要计算某个输出序列的条件概率很相似
比如HMM模型中,我们要求红白红出现的概率,我们就可以利用动态规划的思想,因为红白红包含子问题红白的产生,红白包含子问题红的产生,参考引用的图片。
而这里我们以apple这个label都可以由哪些输出序列映射过去为例(T为8):
其中的一种 _ _ a p _ p l e

当然其他也可以如 a p p _ p p l e,但是考虑到我们最终对输出序列的处理(两个空字符之间的重复元素会去除,字符是从左到右的,且是依次的),我们的路径(状态转移)不是随便的,根据这样的规则,我们可以找到所有可以映射到apple的输出序列


很明显可以看到这和HMM很像,包含很多相同子问题,可以用动态规划做

定义在时刻t经过节点s的全部前缀子路径的概率总和为前向概率 αt(s)\alpha_t (s)αt​(s),如α3(4)\alpha_3 (4)α3​(4)为在时刻3所有经过第4个节点的全部前缀子路径的概率总和: α3(4)\alpha_3 (4)α3​(4) = p(_ap) + p(aap) + p(a_p) + p(app),该节点为p

类似的定义在时刻t经过节点s的全部后缀子路径的概率总和为前向概率 βt(s)\beta_t (s)βt​(s),如β6(8)\beta_6 (8)β6​(8)为在时刻6所有经过第8个节点的全部后缀子路径的概率总和: β3(4)\beta_3 (4)β3​(4) = p(lle) + p(l_e) + p(lee) + p(le_),该节点为l

总结

Focal CTC Loss





实现

参考论文 Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets


  • 语音识别:深入理解CTC Loss原理
  • CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇)
  • 隐马尔可夫(HMM)、前/后向算法、Viterbi算法 再次总结
  • 【Learning Notes】CTC 原理及实现
  • 统计学习方法-p178

CTC Loss和Focal CTC Loss相关推荐

  1. NTU商汤提出新 loss!Focal Frequency Loss 提升图像重建和图像合成的质量 ICCV2021

    点击下方"AI算法与图像处理",一起进步!重磅干货,第一时间送达 码字不易,给打工人点个赞吧. 今天分享一篇南洋理工大学&商汤科技的最新论文: Focal Frequenc ...

  2. 【图像分类损失】PolyLoss:一个优于 Cross-entropy loss和Focal loss的分类损失

    论文题目:<PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions> 论文地址:http ...

  3. 【BraTS】Brain Tumor Segmentation 脑部肿瘤分割4--UNet的损失函数(交叉熵损失Cross-Entropy,Dice Loss和Focal Loss)

    下面,是我与chatGPT3.5的一段对话,主要是从以下几个点展开的: UNet是什么? UNet输出的channel和pixel表示什么? UNet计算损失使用的什么函数? Cross-Entrop ...

  4. quality focal loss distribute focal loss 详解(paper, 代码)

    参见generalized focal loss paper 其中包含有Quality Focal Loss 和 Distribution Focal Loss. 目录 背景 Focal Loss Q ...

  5. [概念]医学图像分割中常用的Loss function(损失函数) + 从loss处理图像分割中类别极度不均衡

    目录 一.前言 二.损失函数 2.1 根据像素正确与否设计的loss function 2.1.1  Log Loss 2.1.2 WCE Loss 2.1.3 Focal Loss 2.2 根据评测 ...

  6. ACR Loss: Adaptive Coordinate-based Regression Loss for Face Alignment

    ACR Loss: Adaptive Coordinate-based Regression Loss for Face Alignment Introduction 背景介绍 相关方法 提出的方法 ...

  7. 【AI面试】L1 loss、L2 loss和Smooth L1 Loss,L1正则化和L2正则化

    损失函数是深度学习模型优化的一个灵魂基础,所以无论是很新的transform模型,还是比较早期的AlexNet,都不可避免的要涉及到损失函数的设计和应用. 所以,各种形形色色的损失函数(Loss)也就 ...

  8. 【AI面试】hard label与soft label,Label Smoothing Loss 和 Smooth L1 Loss

    往期文章: AI/CV面试,直达目录汇总 [AI面试]NMS 与 Soft NMS 的辨析 [AI面试]L1 loss.L2 loss和Smooth L1 Loss,L1正则化和L2正则化 在一次询问 ...

  9. NLP文本情感分析:测试集loss比训练集loss大很多,训练集效果好测试集效果差的原因

    NLP情感分析:测试集loss比训练集loss大很多 一.前言 二.原因 一.前言 最近在学习神经网络自然语言处理的相关知识,发现运行的之后测试集的loss比训练集的loss大很多,而accuracy ...

  10. 常用损失函数总结(L1 loss、L2 loss、Negative Log-Likelihood loss、Cross-Entropy loss、Hinge Embedding loss、Margi)

    常用损失函数总结(L1 loss.L2 loss.Negative Log-Likelihood loss.Cross-Entropy loss.Hinge Embedding loss.Margi) ...

最新文章

  1. Quartz.NET和Log4Net三种输出[转]
  2. php 调用变量方法名,php中引用(变量和函数名前加符号)用法
  3. 三运放差分放大电路分析_三运放差分放大电路
  4. sql server(常用)
  5. 玩转CSS选择器(一) 之 使用方法介绍
  6. python取列表前几个元素_Python下几种从一个序列中取出元素的方法
  7. 输入学生的个数,姓名,成绩,然后按照学生的成绩的降序来打印学生的姓名
  8. ftk学习记(窗口全屏设置篇)
  9. Spring AOP体系学习
  10. oracle的concat的用法
  11. NS方程由精确解求源项matlab代码
  12. OPENGL和DX的不同.
  13. [单调栈 扫描线] BZOJ 4826 [Hnoi2017]影魔
  14. python用逗号隔开输出_c语言提取逗号隔开的 python输出用逗号隔开的数字
  15. 大数据写入到Oracle数据库(批量插入数据)
  16. 利用autossh反向代理实现内网穿透
  17. SecurityException: Uid 0312 does not have permission content://com.android.providers...
  18. 广告联盟的几大防作弊技术
  19. ACM-ICPC 2018沈阳赛区网络预选赛
  20. 多测师肖sir_高级讲师_第2个月第33讲解jenkins

热门文章

  1. 免费申请国外免费域名超详细教程
  2. telink wiki使用简单说明
  3. MySQL 数据库命名规范.PDF
  4. css span 右端对齐_span右对齐
  5. 甲骨文裁员,N+6 赔偿……部分员工不满
  6. 打印机出现另存为xps_打印机打印文件时弹出另存为xps/pdf该怎办?
  7. 中国大学MOOC行为金融学及答案
  8. 基于视频的相似图片处理[均值哈希算法相似度、三直方图算法相似度]
  9. 面试官问:你的缺点是什么,这么回答漂亮!(真实案例)
  10. CAD卸载/完美解决安装失败/如何彻底卸载清除干净cad各种残留注册表和文件的方法