最近在做OCR相关的东西,关于OCR真的是有悠久了历史了,最开始用tesseract然而效果总是不理想,其中字符分割真的是个博大精深的问题,那么多年那么多算法,然而应用到实际总是有诸多问题。比如说非等间距字体的分割,汉字的分割,有光照阴影的图片的字体分割等等,针对特定的问题,特定的算法能有不错的效果,但也仅限于特定问题,很难有一些通用的结果。于是看了Xlvector的博客之后,发现可以端到端来实现OCR,他是基于mxnet的,于是我想把它转到tensorflow这个框架来,顺便还能熟悉一下这个框架。本文主要介绍实现思路,更加细节的实现方法见另一篇。

正文

生成数据

利用captcha来生成验证码,具体生成验证码的代码请见这里,共生成4-6位包含数字和英文大小写的训练图片128000张和测试图片400张。命名规则就是num_label.png,生成的图片如下图

关于生成数据,再多说一点,可以像Xlvector那样一边生成一边训练,这样样本是无穷的,效果更好。但是实际应用中有限样本的情况还是更多的。

载入数据

两种载入数据方式

pipeline

最开始想通过一个tf.train.string_input_producer来读入所有的文件名,然后以pipline的方式读入,但是由于标签的是不定长的,想通过正则来生成label,一开始是想用py_func来实现,后来发现传入string会有问题,所以最后还是选择生成tf.record文件,关于不定长问题,把比较短的标签在后面补零(0是blank的便签,就是说自己的类别中不能出现0这个类),然后读出每个batch后,再把0去掉。

一次性载入

我这里给一个目录,然后遍历里面所有的文件,等到训练的时候,每一个epoch循环把文件的index给手动shuffle一下,然后就可以每次截取出一个batch来用作输入了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class DataIterator:
def __init__(self, data_dir):
self.image_names = []
self.image = []
self.labels=[]
for root, sub_folder, file_list in os.walk(data_dir):
for file_path in file_list:
image_name = os.path.join(root,file_path)
self.image_names.append(image_name)
im = cv2.imread(image_name,0).astype(np.float32)/255.
im = cv2.resize(im,(image_width,image_height))
# transpose to (160*60) and the step shall be 160
# in this way, each row is a feature vector
im = im.swapaxes(0,1)
self.image.append(np.array(im))
#image is named as ./<folder>/00000_abcd.png
code = image_name.split('/')[2].split('_')[1].split('.')[0]
code = [SPACE_INDEX if code == SPACE_TOKEN else maps[c] for c in list(code)]
self.labels.append(code)
print(image_name,' ',code)
@property
def size(self):
return len(self.labels)
def input_index_generate_batch(self,index=None):
if index:
image_batch=[self.image[i] for i in index]
label_batch=[self.labels[i] for i in index]
else:
# get the whole data as input
image_batch=self.image
label_batch=self.labels
def get_input_lens(sequences):
lengths = np.asarray([len(s) for s in sequences], dtype=np.int64)
return sequences,lengths
batch_inputs,batch_seq_len = get_input_lens(np.array(image_batch))
batch_labels = sparse_tuple_from_label(label_batch)
return batch_inputs,batch_seq_len,batch_labels

需要注意的是tensorflow lstm输入格式的问题,其label tensor应该是稀疏矩阵,所以读取图片和label之后,还要进行一些处理,具体可以看代码
关于载入图片,发现12.8w张图一次读进内存,内存也就涨了5G,如果训练数据加大,还是加一个pipeline来读比较好。

网络结构

然后是网络结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
graph = tf.Graph()
with graph.as_default():
inputs = tf.placeholder(tf.float32, [None, None, num_features])
labels = tf.sparse_placeholder(tf.int32)
seq_len = tf.placeholder(tf.int32, [None])
# Stacking rnn cells
stack = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.LSTMCell(FLAGS.num_hidden,state_is_tuple=True) for i in range(FLAGS.num_layers)] , state_is_tuple=True)
# The second output is the last state and we will no use that
outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)
shape = tf.shape(inputs)
batch_s, max_timesteps = shape[0], shape[1]
# Reshaping to apply the same weights over the timesteps
outputs = tf.reshape(outputs, [-1, FLAGS.num_hidden])
# Truncated normal with mean 0 and stdev=0.1
W = tf.Variable(tf.truncated_normal([FLAGS.num_hidden,
num_classes],
stddev=0.1),name='W')
b = tf.Variable(tf.constant(0., shape=[num_classes],name='b'))
# Doing the affine projection
logits = tf.matmul(outputs, W) + b
# Reshaping back to the original shape
logits = tf.reshape(logits, [batch_s, -1, num_classes])
# Time major
logits = tf.transpose(logits, (1, 0, 2))
global_step = tf.Variable(0,trainable=False)
loss = tf.nn.ctc_loss(labels=labels,inputs=logits, sequence_length=seq_len)
cost = tf.reduce_mean(loss)
#optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
# momentum=FLAGS.momentum).minimize(cost,global_step=global_step)
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.initial_learning_rate,
beta1=FLAGS.beta1,beta2=FLAGS.beta2).minimize(loss,global_step=global_step)
# Option 2: tf.contrib.ctc.ctc_beam_search_decoder
# (it's slower but you'll get better results)
#decoded, log_prob = tf.nn.ctc_greedy_decoder(logits, seq_len,merge_repeated=False)
decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len,merge_repeated=False)
# Inaccuracy: label error rate
lerr = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labels))

这里我参考了stackoverflow的一篇帖子写的,根据tensorflow 1.0.1的版本做了微调,使用了Adam作为optimizer。
需要注意的是ctc_beam_search_decoder是非常耗时的,见下图

和greedy_decoder的区别是,greedy_decoder根据当前序列预测下一个字符,并且取概率最高的作为结果,再此基础上再进行下一次预测。而beam_search_decoder每次会保存取k个概率最高的结果,以此为基础再进行预测,并将下一个字符出现的概率与当前k个出现的概率相乘,这样就可以减缓贪心造成的丢失好解的情况,当k=1的时候,二者就一样了。

结果

—update—
稍微调一调,网络可以跑到85%以上。
把网络用在识别身份证号,试了73张网上爬的(不同分辨率下的)真实图片,错了一张,准确率在98%左右(不过毕竟身份证号比较简单)

大概14个epoch后,准确率过了50%,现在跑到了73%的正确率。

最后,代码托管在Github上。

后记

百度出了一个warpCTC可以加速CTC的计算,试用了一下CPU的版本发现好像没什么速度的提升,不知道是不是姿势不对,回头再试试GPU的版本。
对于更加细节的实现方法(输入输出的构造,以及warpCTC和内置ctc_loss的异同)放在了另一篇博客。

  • warpCTC的GPU版本试过之后发现速度差不多,但是能极大的减少CPU的占用
  • 对于不同的优化器,数据,同样的参数是不能普适的。往往之前的参数可以收敛,换个optimizer,数据,网络就不能收敛了。这个时候要微调参数。对于不同的优化器之间区别,文末有一篇神文可以参考

如果有发现问题,请前辈们一定要不吝赐教,在下方留言指出,或者在github上提出issue

tensorflow LSTM + CTC实现端到端OCR相关推荐

  1. tensorflow LSTM+CTC实现端到端的不定长数字串识别

    转载地址: https://www.jianshu.com/p/45828b18f133 上一篇文章tensorflow 实现端到端的OCR:二代身份证号识别实现了定长18位数字串的识别,并最终达到了 ...

  2. 1tensorflow 实现端到端的OCR:二代身份证号识别 + 2tensorflow LSTM+CTC实现端到端的不定长数字串识别

    1tensorflow 实现端到端的OCR:二代身份证号识别 链接地址:https://www.jianshu.com/p/803642d0d8f8 2tensorflow LSTM+CTC实现端到端 ...

  3. c++ 图片验证码识别_基于tensorflow 实现端到端的OCR:二代身份证号识别

    最近在研究OCR识别相关的东西,最终目标是能识别身份证上的所有中文汉字+数字,不过本文先设定一个小目标,先识别定长为18的身份证号,当然本文的思路也是可以复用来识别定长的验证码识别的. 本文实现思路主 ...

  4. rhel 8.2不识别unicode_基于tensorflow 实现端到端的OCR:二代身份证号识别

    最近在研究OCR识别相关的东西,最终目标是能识别身份证上的所有中文汉字+数字,不过本文先设定一个小目标,先识别定长为18的身份证号,当然本文的思路也是可以复用来识别定长的验证码识别的.本文实现思路主要 ...

  5. python 调c++生成的dll 中识别char *_基于tensorflow 实现端到端的OCR:二代身份证号识别...

    最近在研究OCR识别相关的东西,最终目标是能识别身份证上的所有中文汉字+数字,不过本文先设定一个小目标,先识别定长为18的身份证号,当然本文的思路也是可以复用来识别定长的验证码识别的. 本文实现思路主 ...

  6. 基于深度学习(端到端)的OCR文字识别

    版权声明:转载请说明来源,谢谢 https://blog.csdn.net/wsp_1138886114/article/details/83864582 </div><link r ...

  7. ECCV 2022 | 浙大快手提出CoText:基于对比学习和多信息表征的端到端视频OCR模型...

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 点击进入-> CV 微信技术交流群 转载自:CSIG文档图像分析与识别专委会 本文是对快手和浙大联合研 ...

  8. c语言cnn实现ocr字符,端到端的OCR:基于CNN的实现

    OCR是一个古老的问题.这里我们考虑一类特殊的OCR问题,就是验证码的识别.传统做验证码的识别,需要经过如下步骤: 1. 二值化 2. 字符分割 3. 字符识别 这里最难的就是分割.如果字符之间有粘连 ...

  9. 【OCR技术系列之八】端到端不定长文本识别CRNN代码实现

    CRNN是OCR领域非常经典且被广泛使用的识别算法,其理论基础可以参考我上一篇文章,本文将着重讲解CRNN代码实现过程以及识别效果. 数据处理 利用图像处理技术我们手工大批量生成文字图像,一共360万 ...

最新文章

  1. J2EE 13规范(4)-JSP
  2. 《JAVA与模式》之建造模式
  3. linux下基于jrtplib库的实时传送实现
  4. 以IP段作为监听地址
  5. leetocde1129. 颜色交替的最短路径(bfs)
  6. [密码学基础][每个信息安全博士生应该知道的52件事][Bristol Cryptography][第17篇]述和比较DES和AES的轮结构
  7. matlab中bwlabel意思,Matlab 里bwlabel 函数的具体含义
  8. BugkuCTF-WEB题矛盾
  9. NG Ng-template(模板元素)
  10. python难度大吗_python需要学多久?自学两年也很难达到企业标准
  11. HDOJ 2642 HDU 2642 Stars ACM 2642 IN HDU
  12. Javacript中(function(){})() 与 (function(){}()) 区别 {转}
  13. export default 和 export 的使用方式(六)
  14. TensorFlow入门:mnist数据集解析
  15. 鸢尾花数据集的各种玩法
  16. python怎么安装lxml库_lxml解析库的安装和使用
  17. 首款基于龙芯的域名系统服务器发布,首款基于龙芯CPU的国产域名服务器发布
  18. 【C++小游戏】推箱子代码+详解
  19. 程序员面试需要刷力扣算法题吗
  20. 敏捷方法 - 极限编程与工程实践

热门文章

  1. 电脑ip地址设置_关于电脑的远程开机(唤醒)
  2. java mapreduce编程_Hadoop实验——MapReduce编程(1)
  3. 7-3 逆序的三位数(C语言)
  4. 【c语言】蓝桥杯算法训练 斜率计算
  5. Xmanager连接Linux 9的方法
  6. Jenkins + sonarqube集成实现发布代码审计
  7. CocoaPod出现“target overrides the `OTHER_LDFLAGS`……的解决方案
  8. kill_mysql_sleep_thread
  9. Unirest 轻量级的HTTP开发库
  10. mysql主主复制、主从复制、半同步的实现