写在前面,最近两天在做ocr识别相关内容,趁有时间来记录一下。本文的代码是基于Pytorch框架mobilenetv3基础网络的CRNN+CTC网络实现

文字检测与识别介绍

文字识别也是图像领域一个常见问题。然而,对于自然场景图像,首先要定位图像中的文字位置,然后才能进行识别。

所以一般来说,从自然场景图片中进行文字识别,需要包括2个步骤:

  • 文字检测:解决的问题是哪里有文字,文字的范围有多少
  • 文字识别:对定位好的文字区域进行识别,主要解决的问题是每个文字是什么,将图像中的文字区域进转化为字符信息。

在自然场景图片中的文字检测算法比较好用的有PSENET、DBNET等。
本文的重点是如何对已经定位好的文字区域图片进行识别,即假设之前已经文字检测算法已经定位图中的“subway”区域(红框),接下来就是文字识别。

文字识别

1. 背景介绍

基于RNN的文字识别算法主要有两个框架,本文主要介绍CRNN模型:

  1. CNN+RNN+CTC【CRNN】
  2. CNN+Seq2Seq+Attention【ED】

注:两种算法在使用过程中的发现:ED模型对英文识别效果会稍微好一些,但是在推理阶段耗时也更大。对于中文识别来说,两者效果相差不大,但CRNN相对来说解码阶段简单了很多,所以在中文识别方面,CRNN使用的更多些

2. CRNN网络结构介绍

整个CRNN网络主要可以分为四个部分:

  1. Convlutional Layers

【输入:待识别图片(2, 32, 280),输出:图像卷积特征(2, 96, 1,70)】
这里用一个普通的CNN网络去提取图像特征,本文考虑到时间耗时和精度两方面因素,使用了Mobilenetv3,详细代码见下文

  1. Recurrent Layers & 3. Transcription Layers

【输入:图像卷积特征(2, 96, 1,70),输出:预测结果(70, 2, 90)】
这里用一个双向LSTM网络在卷积特征的基础上继续提取文字序列特征,对RNN的输出做softmax,来作为对应时序特征块的输出

  1. 计算loss or 解码

在训练阶段,将预测结果与gt做CTC loss;
在预测阶段,直接对预测结果进行解码

3. 基本设置

  • 任务背景: 数据集icdar2015,提取所有出现字符,加一个blank共90个字符
  • 图片大小:resize为32 * 280
  • 网络输出:T =70(输入LSTM的数据的时间步, CNN 部分输出序列长度) * 90(一共90个不同的字符, 有多少字符此处数字为多少)

4. 网络构造

  1. RCNN整体模型
# RCNN模型
class RecModel(nn.Module):def __init__(self, config):super(RecModel, self).__init__()self.algorithm = config['base']['algorithm']self.backbone = create_module(config['backbone']['function'])(config['base']['pretrained'],config['base']['is_gray'])self.head = create_module(config['head']['function'])(use_conv=config['base']['use_conv'],use_attention=config['base']['use_attention'],use_lstm=config['base']['use_lstm'],lstm_num=config['base']['lstm_num'],inchannel=config['base']['inchannel'],hiddenchannel=config['base']['hiddenchannel'],classes=config['base']['classes'])def forward(self, img):x = self.backbone(img)x = self.head(x)return x
  1. 以mobilenet为backbone的模型
class MobileNetV3_Small(nn.Module):def __init__(self, is_gray):super(MobileNetV3_Small, self).__init__()if(is_gray):self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)else:self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(16)self.hs1 = hswish()self.bneck = nn.Sequential(Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), (2,1)),Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 1),Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, (2,1)),Block(5, 24, 96, 40, hswish(), SeModule(40), 1),Block(5, 40, 240, 40, hswish(), SeModule(40), 1),Block(5, 40, 240, 40, hswish(), SeModule(40), 2),Block(5, 40, 120, 48, hswish(), SeModule(48), 1),Block(5, 48, 144, 48, hswish(), SeModule(48), (2,1)),Block(5, 48, 288, 96, hswish(), SeModule(96), 1),Block(5, 96, 576, 96, hswish(), SeModule(96), 1),Block(5, 96, 576, 96, hswish(), SeModule(96), 1),)self.init_params()def init_params(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):out = self.hs1(self.bn1(self.conv1(x)))out = self.bneck(out)return out
  1. 以双向LSTM及全连接为head的结构
class BLSTM(nn.Module):def __init__(self, nIn, nHidden, nOut):super(BLSTM, self).__init__()self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)self.embedding = nn.Linear(nHidden * 2, nOut)def forward(self, input):recurrent, _ = self.rnn(input)T, b, h = recurrent.size()t_rec = recurrent.view(T * b, h)output = self.embedding(t_rec)  # [T * b, nOut]output = output.view(T, b, -1)return outputclass CRNN_Head(nn.Module):def __init__(self,use_conv=False,use_attention=False,use_lstm=True,lstm_num=2,inchannel=512,hiddenchannel=128,classes=1000):super(CRNN_Head,self).__init__()self.use_lstm = use_lstmself.lstm_num = lstm_numself.use_conv = use_convif use_attention:self.attention = SeModule(inchannel)self.use_attention = use_attentionif(use_lstm):assert lstm_num>0 ,Exception('lstm_num need to more than 0 if use_lstm = True')for i in range(lstm_num):if(i==0):if(lstm_num==1):setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel, hiddenchannel,classes))else:setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel,hiddenchannel,hiddenchannel))elif(i==lstm_num-1):setattr(self, 'lstm_{}'.format(i + 1), BLSTM(hiddenchannel, hiddenchannel, classes))else:setattr(self, 'lstm_{}'.format(i + 1), BLSTM(hiddenchannel, hiddenchannel, hiddenchannel))elif(use_conv):self.out = nn.Conv2d(inchannel, classes, kernel_size=1, padding=0)else:self.out = nn.Linear(inchannel,classes)def forward(self, x):b, c, h, w = x.size()assert h == 1, "the height of conv must be 1"x = x.squeeze(2)x = x.permute(2, 0, 1)  # [w, b, c]if self.use_lstm:for i in range(self.lstm_num):x = getattr(self, 'lstm_{}'.format(i + 1))(x)else:x = self.out(x)return x
  1. CTC Loss

from warpctc_pytorch import CTCLoss as PytorchCTCLoss
import torch.nn as nn
from .basical_loss import focal_ctc_lossclass CTCLoss(nn.Module):def __init__(self,config):super(CTCLoss,self).__init__()self.criterion = PytorchCTCLoss()self.config = configdef forward(self,preds, labels, preds_size, labels_len):loss = self.criterion(preds, labels, preds_size, labels_len)if self.config['loss']['reduction']=='none':loss = focal_ctc_loss(loss)return loss/self.config['trainload']['batch_size']criterion = CTCLoss()
loss = criterion(preds, labels, preds_size, labels_len)
# preds:网络输出,(70, 2, 90),即(T(特征步长), Bs(批次数), C(字符总数))
# labels: [14, 23, 54, 54, 72, 83, 74, 26, 58, 6, 50],即同一批次内合并所以的标签的对应词库id
# preds_size:[70, 70],即网络输出的结果,步长
# labels_len:[8, 3],即同一批次内每张图片对应的标签个数
  1. 预测阶段的解码

网络输出为 (32, 90 ), 解码先取每个位置的最大概率的字符index, index转str时,如果两个相同的index连续,那么合并为一个

例: 假设输出为: [1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 3, 0, 7, 7, 7, 0, 3, 3] , 由于后边全为0,只取前18位. 0 对应的字符是 ‘-’, 对于相邻的非0字符, 看做一个字符, 因此该例子为 [1,0,0,1,0,0,2,0,3,0,7,0,3], 再将0对应的blank 去掉, 则为实际的字符index为 [1,1,2,3,7,3]


参考:一文读懂CRNN+CTC文字识别
参考:pytorch-crnn实践以及内置ctc_loss使用小结
有想要代码的小伙伴请留言~

深度学习之OCR识别相关推荐

  1. 票据识别android代码,深度学习开源ocr识别票据

    AI开发平台ModelArts ModelArts是面向开发者的一站式AI开发平台,为机器学习与深度学习提供海量数据预处理及半自动化标注.大规模分布式Training.自动化模型生成,及端-边-云模型 ...

  2. OCR技术系列之四】基于深度学习的文字识别(3755个汉字)(转)

    上一篇提到文字数据集的合成,现在我们手头上已经得到了3755个汉字(一级字库)的印刷体图像数据集,我们可以利用它们进行接下来的3755个汉字的识别系统的搭建.用深度学习做文字识别,用的网络当然是CNN ...

  3. 【深度学习】OCR文本识别

    OCR文字识别定义 OCR(optical character recognition)文字识别是指电子设备(例如扫描仪或数码相机)检查纸上打印的字符,然后用字符识别方法将形状翻译成计算机文字的过程: ...

  4. 【OCR技术系列之四】基于深度学习的文字识别(3755个汉字)

    上一篇提到文字数据集的合成,现在我们手头上已经得到了3755个汉字(一级字库)的印刷体图像数据集,我们可以利用它们进行接下来的3755个汉字的识别系统的搭建.用深度学习做文字识别,用的网络当然是CNN ...

  5. python深度文字识别_【OCR技术系列之四】基于深度学习的文字识别(3755个汉字)...

    上一篇提到文字数据集的合成,现在我们手头上已经得到了3755个汉字(一级字库)的印刷体图像数据集,我们可以利用它们进行接下来的3755个汉字的识别系统的搭建.用深度学习做文字识别,用的网络当然是CNN ...

  6. 三篇论文,纵览深度学习在表格识别中的最新应用

    本文从三篇表格识别领域的精选论文出发,深入分析了深度学习在表格识别任务中的应用. 表格识别是文档分析与识别领域的一个重要分支,其具体目标是从表格中获取和访问数据及其它有效信息.众所周知,本质上表格是信 ...

  7. 【AI in 美团】深度学习在OCR中的应用

    背景 计算机视觉是利用摄像机和电脑代替人眼,使得计算机拥有类似于人类的对目标进行检测.识别.理解.跟踪.判别决策的功能.以美团业务为例,在商家上单.团单展示.消费评价等多个环节都会涉及计算机视觉的应用 ...

  8. 【AI in 美团】 深度学习在OCR中的应用

    2019独角兽企业重金招聘Python工程师标准>>> 背景 计算机视觉是利用摄像机和电脑代替人眼,使得计算机拥有类似于人类的对目标进行检测.识别.理解.跟踪.判别决策的功能.以美团 ...

  9. 基于深度学习的OCR

    为了提升用户体验,O2O产品对OCR技术的需求已渗透到上单.支付.配送和用户评价等环节.OCR在美团业务中主要起着两方面作用.一方面是辅助录入,比如在移动支付环节通过对银行卡卡号的拍照识别,以实现自动 ...

最新文章

  1. linq调用mysql函数_如何为linq对象制作一个展平函数(Linq To Entities for mysql)?
  2. C语言中控制printf的打印颜色实例及vt100的控制符
  3. [ Linux ] 釋放記憶體指令(cache) - 轉載
  4. python3 面向对象详解_Python3面向对象
  5. 数据仓库工具箱:维度建模权威指南3
  6. 总结:8.9 模拟(枚举搜索)
  7. [置顶]       cocos2d-x 手游源码站
  8. Settings【学习笔记05】
  9. 如何解决linux下编译出现的multiple definition of错误
  10. rbw数字信号处理_数字中频概述 - 频谱分析
  11. 小白帽从病毒视角聊企业安全建设
  12. java程序的入口点_Java程序的入口点
  13. 龙ol服务器维护补偿boss,龙OL低级稀有BOSS刷新点
  14. Python 学习笔记 变量 xxx XXX
  15. JS 格林威治时间格式(GMT)与普通时间格式的互相转换
  16. Java基础项目 开发团队分配管理系统
  17. socks+proxychains网络代理
  18. 泰森多边形(Voronoi图)
  19. win10安装软件出现乱码怎么办
  20. 易语言获取指定文本模块封装源码

热门文章

  1. c语言大端存储,c语言 之大端小端存储问题
  2. 简欧设计 简约而不简单
  3. 当你在浏览器输入baidu.com并敲下回车发生了什么
  4. 有哪些很奇PA,但又比较少人知道的病毒?
  5. 现代信号处理——AR模型谱估计
  6. mysqlamanda备份
  7. 亚巴逊首页分类导航菜单触发区域控制原理窥视
  8. linux定时任务每小时_linux 后台运行,linux定时脚本任务,定时(每分钟),每小时...
  9. unity打包webgl报错及处理 IL2cpp/build/unityLinker.exe not run properly;IL2cpp.exe not run properly
  10. 有两种常见的情况充斥着SEO优化市场,让排名得不到稳定