需求:调研CNN+LSTM+CTC的实现

解决方案:参考 GitHub 上的实现

示例代码:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
tf CNN+LSTM+CTC 训练识别不定长数字字符图片@author: pengyuanjie
"""
from com.shenl.ocrTensorflowCnn.genIDCard import *import numpy as np
import time
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf''' TF_CPP_MIN_LOG_LEVEL指定Tensorflow的日志级别
- 0:显示所有日志(默认等级)
- 1:显示info、warning和error日志
- 2:显示warning和error信息
- 3:显示error日志信息
'''#定义一些常量
#图片大小,32 x 256(高*宽)
OUTPUT_SHAPE = (32,256)#训练最大轮次
num_epochs = 10000num_hidden = 64
num_layers = 1#生成一个图片对象
obj = gen_id_card()num_classes = obj.len + 1 + 1  # 10位数字 + blank + ctc blank#初始化学习速率
INITIAL_LEARNING_RATE = 1e-3
DECAY_STEPS = 5000
REPORT_STEPS = 100
LEARNING_RATE_DECAY_FACTOR = 0.9  # The learning rate decay factor
MOMENTUM = 0.9

DIGITS = '0123456789'
BATCHES = 10
BATCH_SIZE = 64
TRAIN_SIZE = BATCHES * BATCH_SIZE


def decode_sparse_tensor(sparse_tensor):
    """Decode a sparse (indices, values, shape) tuple into digit sequences.

    Entries of ``sparse_tensor[0]`` are grouped by their first component
    (the row); a new group starts whenever that component changes. Each
    group is then decoded with :func:`decode_a_seq`.

    Args:
        sparse_tensor: indexable whose element 0 is a sequence of
            ``[row, col]`` pairs and whose element 1 holds the values.

    Returns:
        A list with one decoded sequence (list of characters) per group.
        An input without any entries yields ``[[]]``.
    """
    groups = []
    current_row = 0
    current_group = []
    for offset, index_pair in enumerate(sparse_tensor[0]):
        row = index_pair[0]
        if row != current_row:
            # Row changed: close the running group and start a new one.
            groups.append(current_group)
            current_row = row
            current_group = []
        current_group.append(offset)
    groups.append(current_group)
    return [decode_a_seq(group, sparse_tensor) for group in groups]


def decode_a_seq(indexes, spars_tensor):
    """Map value offsets to their characters in ``DIGITS``.

    Args:
        indexes: offsets into ``spars_tensor[1]``.
        spars_tensor: sparse tuple; element 1 holds indices into ``DIGITS``.

    Returns:
        A list of single-character strings, one per offset.
    """
    return [DIGITS[spars_tensor[1][m]] for m in indexes]


def report_accuracy(decoded_list, test_targets):
    """Print a per-sample comparison and the exact-match accuracy.

    Args:
        decoded_list: sparse tuple holding the decoder output.
        test_targets: sparse tuple holding the ground-truth labels.

    Returns:
        The fraction of sequences decoded exactly, or ``None`` when the two
        inputs decode to a different number of sequences.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list),
              "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return None

    print("T/F: original(length) <-------> detectcted(length)")
    true_numer = 0
    for number, detect_number in zip(original_list, detected_list):
        hit = (number == detect_number)
        print(hit, number, "(", len(number), ") <-------> ",
              detect_number, "(", len(detect_number), ")")
        if hit:
            true_numer += 1

    accuracy = true_numer * 1.0 / len(original_list)
    print("Test Accuracy:", accuracy)
    return accuracy


# Convert a list of sequences into a sparse-matrix tuple.
def sparse_tuple_from(sequences, dtype=np.int32):
    """Create a sparse representation of a list of sequences.

    Args:
        sequences: a list of lists of type ``dtype`` where each element is
            a sequence.
        dtype: numpy dtype used for the returned values.

    Returns:
        A tuple ``(indices, values, shape)``: ``indices`` is an int64 array
        of ``[sequence_number, position]`` pairs, ``values`` is the
        flattened content, and ``shape`` is
        ``[len(sequences), longest_sequence_length]``. With no elements at
        all, ``indices`` has shape ``(0, 2)`` and the second shape entry
        is 0.
    """
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend((n, pos) for pos in range(len(seq)))
        values.extend(seq)

    # reshape keeps the (N, 2) layout even when there are no entries.
    indices = np.asarray(indices, dtype=np.int64).reshape(-1, 2)
    values = np.asarray(values, dtype=dtype)
    # max() of an empty array raises, so the no-entry case is handled here.
    max_len = indices[:, 1].max() + 1 if len(indices) else 0
    shape = np.asarray([len(sequences), max_len], dtype=np.int64)
    return indices, values, shape


def weight_variable(shape):
    """Return a tf.Variable drawn from a truncated normal (stddev 0.5)."""
    initial = tf.truncated_normal(shape, stddev=0.5)
    return tf.Variable(initial)


def bias_variable(shape):
    """Return a tf.Variable of the given shape filled with 0.1."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W, stride=(1, 1), padding='SAME'):
    """Apply tf.nn.conv2d with the given 2-D stride and padding."""
    return tf.nn.conv2d(x, W,
                        strides=[1, stride[0], stride[1], 1],
                        padding=padding)


def max_pool(x, ksize=(2, 2), stride=(2, 2)):
    """Apply 'SAME'-padded max pooling with the given window and stride."""
    return tf.nn.max_pool(x,
                          ksize=[1, ksize[0], ksize[1], 1],
                          strides=[1, stride[0], stride[1], 1],
                          padding='SAME')


def avg_pool(x, ksize=(2, 2), stride=(2, 2)):
    """Apply 'SAME'-padded average pooling with the given window and stride."""
    return tf.nn.avg_pool(x,
                          ksize=[1, ksize[0], ksize[1], 1],
                          strides=[1, stride[0], stride[1], 1],
                          padding='SAME')


# Generate one training batch.
def get_next_batch(batch_size=128):
    """Generate one batch of images with sparse labels and sequence lengths.

    Args:
        batch_size: number of samples to generate.

    Returns:
        A tuple ``(inputs, sparse_targets, seq_len)``:
        ``inputs`` is a float array of shape
        ``(batch_size, OUTPUT_SHAPE[1], OUTPUT_SHAPE[0])``;
        ``sparse_targets`` is the ``sparse_tuple_from`` encoding of the
        label texts; ``seq_len`` is an int32 array of length ``batch_size``
        whose entries all equal ``OUTPUT_SHAPE[1]``.
    """
    obj = gen_id_card()
    # (batch_size, 256, 32)
    inputs = np.zeros([batch_size, OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]])
    codes = []

    for i in range(batch_size):
        # Generate a variable-length string. Per the original author's note,
        # gen_image must NOT be called with True here.
        image, text, vec = obj.gen_image()
        # Transpose: (32*256,) => (32, 256) => (256, 32)
        inputs[i, :] = np.transpose(
            image.reshape((OUTPUT_SHAPE[0], OUTPUT_SHAPE[1])))
        codes.append(list(text))

    targets = [np.asarray(i) for i in codes]
    print(targets)
    sparse_targets = sparse_tuple_from(targets)
    # Shape (batch_size,), every value is 256. Built as int32 because it is
    # fed to an int32 placeholder (the original produced float64 here).
    seq_len = np.full(inputs.shape[0], OUTPUT_SHAPE[1], dtype=np.int32)
    return inputs, sparse_targets, seq_len


# Define the CNN that processes the image.
def convolutional_layers():
    """Build a three-layer CNN plus a fully connected layer.

    Not called by ``get_train_model`` in this file (the call there is
    commented out).

    Returns:
        A tuple ``(inputs, features)``: the input placeholder and the
        feature tensor reshaped to ``[batch, OUTPUT_SHAPE[1], 1]``.
    """
    # Input data, shape [batch_size, max_stepsize, num_features].
    inputs = tf.placeholder(tf.float32, [None, None, OUTPUT_SHAPE[0]])

    # First conv layer (original note: 32*256*1 => 16*128*48).
    W_conv1 = weight_variable([5, 5, 1, 48])
    b_conv1 = bias_variable([48])
    x_expanded = tf.expand_dims(inputs, 3)
    h_conv1 = tf.nn.relu(conv2d(x_expanded, W_conv1) + b_conv1)
    h_pool1 = max_pool(h_conv1, ksize=(2, 2), stride=(2, 2))

    # Second layer (original note: 16*128*48 => 16*64*64).
    W_conv2 = weight_variable([5, 5, 48, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool(h_conv2, ksize=(2, 1), stride=(2, 1))

    # Third layer (original note: 16*64*64 => 8*32*128).
    W_conv3 = weight_variable([5, 5, 64, 128])
    b_conv3 = bias_variable([128])
    h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)
    h_pool3 = max_pool(h_conv3, ksize=(2, 2), stride=(2, 2))

    # Fully connected layer.
    flat_size = 16 * 8 * OUTPUT_SHAPE[1]
    W_fc1 = weight_variable([flat_size, OUTPUT_SHAPE[1]])
    b_fc1 = bias_variable([OUTPUT_SHAPE[1]])
    conv_layer_flat = tf.reshape(h_pool3, [-1, flat_size])
    features = tf.nn.relu(tf.matmul(conv_layer_flat, W_fc1) + b_fc1)

    # (batchsize, 256) -> batchsize * outputshape * 1
    shape = tf.shape(features)
    features = tf.reshape(features, [shape[0], OUTPUT_SHAPE[1], 1])
    return inputs, features


def get_train_model():
    """Build the LSTM + linear projection graph used for CTC training.

    Returns:
        A tuple ``(logits, inputs, targets, seq_len, W, b)``. ``logits`` is
        transposed to ``[max_time, batch, num_classes]``; ``inputs``,
        ``targets`` (sparse) and ``seq_len`` are the feed placeholders;
        ``W`` and ``b`` are the projection variables.
    """
    inputs = tf.placeholder(tf.float32, [None, None, OUTPUT_SHAPE[0]])

    # Sparse placeholder required by ctc_loss.
    targets = tf.sparse_placeholder(tf.int32)

    # 1-D vector of sequence lengths, shape [batch_size].
    seq_len = tf.placeholder(tf.int32, [None])

    # LSTM network. Each layer gets its own cell object (the original used
    # `[cell] * num_layers`, which repeats a single cell), and the stacked
    # cell is what gets unrolled (the original passed the bare cell, so
    # num_layers had no effect).
    cells = [tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True)
             for _ in range(num_layers)]
    stack = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)

    batch_s = tf.shape(inputs)[0]

    # Flatten time and batch so one matmul projects every time step.
    outputs = tf.reshape(outputs, [-1, num_hidden])
    W = tf.Variable(
        tf.truncated_normal([num_hidden, num_classes], stddev=0.1), name="W")
    b = tf.Variable(tf.constant(0., shape=[num_classes]), name="b")

    logits = tf.matmul(outputs, W) + b
    logits = tf.reshape(logits, [batch_s, -1, num_classes])
    # Swap to time-major: [max_time, batch, num_classes].
    logits = tf.transpose(logits, (1, 0, 2))
    return logits, inputs, targets, seq_len, W, b


def train():
    """Build the training graph and run the training loop.

    Runs ``num_epochs`` epochs of ``BATCHES`` freshly generated batches,
    prints the cost per step, calls ``report_accuracy`` on a new batch every
    ``REPORT_STEPS`` global steps, and logs a summary on another freshly
    generated batch after each epoch.
    """
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                               global_step,
                                               DECAY_STEPS,
                                               LEARNING_RATE_DECAY_FACTOR,
                                               staircase=True)
    logits, inputs, targets, seq_len, W, b = get_train_model()

    loss = tf.nn.ctc_loss(labels=targets, inputs=logits,
                          sequence_length=seq_len)
    cost = tf.reduce_mean(loss)

    # Minimize the scalar batch mean; the original passed the per-example
    # `loss` vector, leaving `cost` unused by the optimizer.
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(cost, global_step=global_step)
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(
        logits, seq_len, merge_repeated=False)

    # Mean edit distance between decoded output and targets (an error
    # rate, despite the name).
    acc = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoded[0], tf.int32), targets))

    init = tf.global_variables_initializer()

    def do_report():
        """Decode a fresh batch and print its accuracy report."""
        test_inputs, test_targets, test_seq_len = get_next_batch(BATCH_SIZE)
        test_feed = {inputs: test_inputs,
                     targets: test_targets,
                     seq_len: test_seq_len}
        dd, _, _ = session.run([decoded[0], log_prob, acc], test_feed)
        report_accuracy(dd, test_targets)

    def do_batch():
        """Run one optimization step; return (batch cost, global step)."""
        train_inputs, train_targets, train_seq_len = get_next_batch(BATCH_SIZE)
        feed = {inputs: train_inputs,
                targets: train_targets,
                seq_len: train_seq_len}
        b_cost, steps, _ = session.run([cost, global_step, optimizer], feed)
        print(b_cost, steps)
        if steps > 0 and steps % REPORT_STEPS == 0:
            do_report()
            # Checkpointing is disabled:
            # saver.save(session, "ocr.model", global_step=steps)
        return b_cost, steps

    with tf.Session() as session:
        session.run(init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        for curr_epoch in range(num_epochs):
            print("Epoch.......", curr_epoch)
            # Time the whole epoch; the original reused the last batch's
            # start time for the epoch log line.
            epoch_start = time.time()
            # NOTE(review): train_ler is never updated, so it always logs 0.
            train_cost = train_ler = 0
            for batch in range(BATCHES):
                start = time.time()
                c, steps = do_batch()
                train_cost += c * BATCH_SIZE
                seconds = time.time() - start
                print("Step:", steps, ", batch seconds:", seconds)

            train_cost /= TRAIN_SIZE

            # "Validation" uses another freshly generated batch.
            train_inputs, train_targets, train_seq_len = get_next_batch(
                BATCH_SIZE)
            val_feed = {inputs: train_inputs,
                        targets: train_targets,
                        seq_len: train_seq_len}

            val_cost, val_ler, lr, steps = session.run(
                [cost, acc, learning_rate, global_step], feed_dict=val_feed)

            log = "Epoch {}/{}, steps = {}, train_cost = {:.3f}, train_ler = {:.3f}, val_cost = {:.3f}, val_ler = {:.3f}, time = {:.3f}s, learning_rate = {}"
            print(log.format(curr_epoch + 1, num_epochs, steps, train_cost,
                             train_ler, val_cost, val_ler,
                             time.time() - epoch_start, lr))


if __name__ == '__main__':
    # Smoke-test batch generation and decoding before training.
    inputs, sparse_targets, seq_len = get_next_batch(2)
    decode_sparse_tensor(sparse_targets)
    train()

CNN+LSTM+CTC相关推荐

  1. 实战:CNN+BLSTM+CTC的验证码识别从训练到部署 | 技术头条

    作者|_Coriander 转载自Jerry的算法和NLP(ID: gh_36eba310d433) 1.前言 本项目适用于Python3.6,GPU>=NVIDIA GTX1050Ti,原ma ...

  2. 实战 | CNN+BLSTM+CTC的验证码识别从训练到部署

    点击上方"Jerry的算法和NLP",选择"星标"公众号                      重磅干货,第一时间送达 项目传送门: https://git ...

  3. 【转】CNN+BLSTM+CTC的验证码识别从训练到部署

    [转]CNN+BLSTM+CTC的验证码识别从训练到部署 转载地址:https://www.jianshu.com/p/80ef04b16efc 项目地址:https://github.com/ker ...

  4. [验证码识别技术] 字符型验证码终结者-CNN+BLSTM+CTC

    验证码识别(少样本,高精度)项目地址:https://github.com/kerlomz/captcha_trainer 1. 前言 本项目适用于Python3.6,GPU>=NVIDIA G ...

  5. DL之CNN:利用CNN(keras, CTC loss, {image_ocr})算法实现OCR光学字符识别

    DL之CNN:利用CNN(keras, CTC loss)算法实现OCR光学字符识别 目录 输出结果 实现的全部代码 输出结果 更新-- 实现的全部代码 部分代码源自:GitHub https://r ...

  6. OCR算法:CNN+BLSTM+CTC架构(VS15)

    原文链接:OCR算法-CNN+BLSTM+CTC架构 由于作者使用了Boost1.57-Vc14,而1.57的VC14版本作者没有给出下载链接,因此需要自行编译,建议换掉作者的第三方库,使用其他的库, ...

  7. 一文搞定!手把手教你文字识别(识别篇:LSTM+CTC, CRNN, chineseocr方法)

    个人博客导航页(点击右侧链接即可打开个人博客):大牛带你入门技术栈 文字识别是AI的一个重要应用场景,文字识别过程一般由图像输入.预处理.文本检测.文本识别.结果输出等环节组成.   其中,文本检测. ...

  8. python3语音识别模块_语音识别(LSTM+CTC)

    序言:语音识别作为人工智能领域重要研究方向,近几年发展迅猛,其中RNN的贡献尤为突出.RNN设计的目的就是让神经网络可以处理序列化的数据.本文笔者将陪同小伙伴们一块儿踏上语音识别之梦幻旅途,相信此处风 ...

  9. 吴良超 融合 cnn+lstm

    吴良超 融合 cnn+lstm 链接 from keras.applications.vgg16 import VGG16 from keras.models import Sequential, M ...

最新文章

  1. 银行持续交付实战:一个单体系统足以撑起全球大项目
  2. java indexof int,int indexOf(String str, int fromIndex)
  3. Three.js中实现点击按钮添加删除旋转立方体
  4. python代码解读软件_5种带你轻松分析Python代码的软件库
  5. linux 4.9 内核 nptl,【linuxThread和NPTL】
  6. Knative 初体验:Serving Hello World
  7. 关于电商网站购物车功能如何与登录账号相关联的一点想法
  8. 计算机硬件基础英语ppt,计算机硬件技术基础,computer hardware technology elements,音标,读音,翻译,英文例句,英语词典...
  9. 设计一台模型计算机 实现下列指令系统,基本模型机的设计与实现1
  10. Hadoop高级培训课程大纲-开发者版
  11. VBA学习_3:对象、集合及对象的属性和方法
  12. 51Nod-1259-整数划分 V2
  13. UE4教程:虚幻4引擎(Unreal Engine 4)学习指南
  14. 图片base64解码转换
  15. 什么是大数据(转自知乎)
  16. 可决系数、相关系数、均方误差
  17. 药品计算机数据备份管理制度,GMP丨《药品记录与数据管理要求》(试行)解读
  18. 【小麦苗课堂】高可用培训(RAC+DG+OGG)--包括11g、12c、18c、19c等版本
  19. 传递Bitmap + 图片压缩处理 并保存 + 壁纸设置 总结
  20. 九段刀客 vue-router实现原理

热门文章

  1. JavaScript实现以数组形式返回斐波那契数列fibonacci算法(附完整源码)
  2. wxWidgets:wxSplashScreen 示例
  3. boost::mp11::mp_apply_q相关用法的测试程序
  4. bgi::detail::intersection_content用法的测试程序
  5. boost的chrono模块explore limits探索极限的测试程序
  6. boost::callable_traits的remove_varargs_t的测试程序
  7. VTK:小部件之HoverWidget
  8. OpenCV LATCH Matching描述符匹配算法的实例(附完整代码)
  9. Qt Creator移动平台
  10. C++horspool算法查找字符串是否包含子字符串(附完整源码)