改动点:

(1)把传统的卷积改造成深度可分离卷积;

(2)使用pytorch实现的ctc,不再使用百度开源的warpctc,主要原因是本人使用Windows来开发调试,编译warpctc貌似很麻烦;

crnn网络实现代码:

class BidirectionalLSTM(nn.Module):def __init__(self, nInput_size, nHidden,nOut):super(BidirectionalLSTM, self).__init__()self.lstm = nn.LSTM(nInput_size, nHidden, bidirectional=True)self.linear = nn.Linear(nHidden * 2, nOut)def forward(self, input):recurrent, (hidden,cell)= self.lstm(input)T, b, h = recurrent.size()t_rec = recurrent.view(T * b, h)output = self.linear(t_rec)  # [T * b, nOut]output = output.view(T, b, -1) #输出变换为[seq,batch,类别总数]return outputclass CNN(nn.Module):def __init__(self,imageHeight,nChannel):super(CNN,self).__init__()assert imageHeight % 32 == 0,'image Height has to be a multiple of 32'self.depth_conv0 = nn.Conv2d(in_channels=nChannel,out_channels=nChannel,kernel_size=3,stride=1,padding=1,groups=nChannel)self.point_conv0 = nn.Conv2d(in_channels=nChannel,out_channels=64,kernel_size=1,stride=1,padding=0,groups=1)self.relu0 = nn.ReLU(inplace=True)self.pool0 = nn.MaxPool2d(kernel_size=2,stride=2)self.depth_conv1 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1,groups=64)self.point_conv1 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=1,stride=1,padding=0,groups=1)self.relu1 = nn.ReLU(inplace=True)self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)self.depth_conv2 = nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,stride=1,padding=1,groups=128)self.point_conv2 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=1,stride=1,padding=0,groups=1)self.batchNorm2 = nn.BatchNorm2d(256)self.relu2 = nn.ReLU(inplace=True)self.depth_conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=256)self.point_conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1)self.relu3 = nn.ReLU(inplace=True)self.pool3 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,1),padding=(0,1))self.depth_conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=256)self.point_conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1)self.batchNorm4 = nn.BatchNorm2d(512)self.relu4 = nn.ReLU(inplace=True)self.depth_conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, groups=512)self.point_conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1)self.relu5 = nn.ReLU(inplace=True)self.pool5 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,1),padding=(0,1))#self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=2, stride=1, padding=0)self.depth_conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=2, stride=1, padding=0, groups=512)self.point_conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1)self.batchNorm6 = nn.BatchNorm2d(512)self.relu6= nn.ReLU(inplace=True)def forward(self,input):depth0 = self.depth_conv0(input)point0 = self.point_conv0(depth0)relu0 = self.relu0(point0)pool0 = self.pool0(relu0)# print(pool0.size())depth1 = self.depth_conv1(pool0)point1 = self.point_conv1(depth1)relu1 = self.relu1(point1)pool1 = self.pool1(relu1)#print(pool1.size())depth2 = self.depth_conv2(pool1)point2 = self.point_conv2(depth2)batchNormal2 = self.batchNorm2(point2)relu2 = self.relu2(batchNormal2)#print(relu2.size())depth3 = self.depth_conv3(relu2)point3 = self.point_conv3(depth3)relu3 = self.relu3(point3)pool3 = self.pool3(relu3)#print(pool3.size())depth4 = self.depth_conv4(pool3)point4 = self.point_conv4(depth4)batchNormal4 = self.batchNorm4(point4)relu4 = self.relu4(batchNormal4)#print(relu4.size())depth5 = self.depth_conv5(relu4)point5 = self.point_conv5(depth5)relu5 = self.relu5(point5)pool5 = self.pool5(relu5)#print(pool5.size())depth6 = self.depth_conv6(pool5)point6 = self.point_conv6(depth6)batchNormal6 = self.batchNorm6(point6)relu6 = self.relu6(batchNormal6)#print(relu6.size())return relu6class CRNN(nn.Module):def __init__(self,imgHeight, nChannel, nClass, nHidden):super(CRNN,self).__init__()self.cnn = nn.Sequential(CNN(imgHeight, nChannel))self.lstm = nn.Sequential(BidirectionalLSTM(512, nHidden, nHidden),BidirectionalLSTM(nHidden, nHidden, nClass),)def forward(self,input):conv = self.cnn(input)# pytorch框架输出结构为BCHWbatch,channel,height,width = conv.size()assert  height==1,"the output height must be 1."# 将height==1的维度去掉-->BCWconv = conv.squeeze(dim=2)# 调整各个维度的位置(B,C,W)->(W,B,C),对应lstm的输入(seq,batch,input_size)conv = conv.permute(2,0,1)output = self.lstm(conv)return  output

训练网络代码:

import os
import torch
import cv2
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from crnn_new import crnn
import time# 调整图像大小和归一化操作
class resizeAndNormalize():def __init__(self,size,interpolation=cv2.INTER_LINEAR):# 注意对于opencv,size的格式是(w,h)self.size = sizeself.interpolation = interpolation# ToTensor属于类  """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.self.toTensor = transforms.ToTensor()def __call__(self, image):# (x,y) 对于opencv来说,图像宽对应x轴,高对应y轴image = cv2.resize(image,self.size,interpolation=self.interpolation)#转为tensor的数据结构image = self.toTensor(image)#对图像进行归一化操作image = image.sub_(0.5).div_(0.5)return imageclass CRNNDataSet(Dataset):def __init__(self,imageRoot,labelRoot):self.image_root = imageRootself.image_dict = self.readfile(labelRoot)self.image_name = [fileName for fileName,_ in self.image_dict.items()]def __getitem__(self, index):image_path = os.path.join(self.image_root,self.image_name[index])keys = self.image_dict.get(self.image_name[index])label = [int(x) for x in keys]image = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE)# if image is None:#     return None,None(height,width) = image.shape#由于crnn网络输入图像的高为32,故需要resize原始图像的heightsize_height = 32ratio = 32/float(height)size_width = int(ratio * width)transform = resizeAndNormalize((size_width,size_height))#图像预处理image = transform(image)#标签格式转换为IntTensorlabel = torch.IntTensor(label)return image,labeldef __len__(self):return len(self.image_name)def readfile(self,fileName):res = []with open(fileName, 'r') as f:lines = f.readlines()for line in lines:res.append(line.strip())dic = {}total = 0for line in res:part = line.split(' ')#由于会存在训练过程中取图像的时候图像不存在导致异常,所以在初始化的时候就判断图像是否存在if  not os.path.exists(os.path.join(self.image_root, part[0])):print(os.path.join(self.image_root, part[0]))total += 1else:dic[part[0]] = part[1:]print(total)return dictrainData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data.txt")trainLoader = DataLoader(dataset=trainData,batch_size=30,shuffle=True,num_workers=0)valData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_t.txt")valLoader = DataLoader(dataset=valData,batch_size=1,shuffle=True,num_workers=1)def decode(preds):pred = []for i in range(len(preds)):if preds[i] != 5989 and ((i == 5989) or (i != 5989 and preds[i] != preds[i-1])):pred.append(int(preds[i]))return preddef val(model, loss_function, max_iteration,use_gpu=True):# 将模式切换为验证评估模式model.eval()k = 0totalloss = 0correct_num = 0total_num = 0val_iter = iter(valLoader)max_iter = min(max_iteration,len(valLoader))for i in range(max_iter):k = k + 1data,label = val_iter.next()labels = torch.IntTensor([])for j in range(label.size(0)):labels = torch.cat((labels,label[j]),0)if torch.cuda.is_available() and use_gpu:data = data.cuda()output = model(data)input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))loss = loss_function(output,labels,input_lengths,target_lengths) /  label.size(0)totalloss += float(loss)pred_label = output.max(2)[1]pred_label = pred_label.transpose(1,0).contiguous().view(-1)pred = decode(pred_label)total_num += len(pred)for x,y in zip(pred,labels):if int(x) == int(y):correct_num += 1accuracy = correct_num / float(total_num) * 100test_loss = totalloss / kprint('Test loss : %.3f , accuary : %.3f%%' % (test_loss, accuracy))def train():use_gpu = Truelearning_rate = 0.0005weight_decay = 1e-4max_epoch = 10modelpath = 'F:\crnn_model\pytorch-crnn.pth'char_set = open('../train/char_std_5990.txt','r',encoding='utf-8').readlines()char_set = ''.join([ch.strip('\n') for ch in char_set[1:]] +['卍'])n_class = len(char_set)model = crnn.CRNN(imgHeight=32,nChannel=1,nClass=n_class,nHidden=256)if torch.cuda.is_available() and use_gpu:model.cuda()loss_func = torch.nn.CTCLoss(blank=n_class-1)optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate,weight_decay=weight_decay)if os.path.exists(modelpath):print("load model from %s" % modelpath)model.load_state_dict(torch.load(modelpath))print("done!")lossTotal = 0.0k = 0printInterval = 100valinterval = 1000start_time = time.time()for epoch in range(max_epoch):for i,(data,label) in enumerate(trainLoader):k = k + 1#开启训练模式model.train()labels = torch.IntTensor([])for j in range(label.size(0)):labels = torch.cat((labels,label[j]),0)if torch.cuda.is_available and use_gpu:data = data.cuda()loss_func = loss_func.cuda()labels = labels.cuda()output = model(data)#log_probs = output#example 建议使用这样,貌似直接把output送进去loss fun也没发现什么问题log_probs = output.log_softmax(2).detach().requires_grad_()targets = labelsinput_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))#forward(self, log_probs, targets, input_lengths, target_lengths)loss = loss_func(log_probs,targets,input_lengths,target_lengths) / label.size(0)lossTotal += float(loss)if k % printInterval == 0:print("[%d/%d] [%d/%d] loss:%f" % (epoch, max_epoch, i + 1, len(trainLoader), lossTotal/printInterval))lossTotal = 0.0torch.save(model.state_dict(), 'F:\crnn_model\pytorch-crnn.pth')optimizer.zero_grad()loss.backward()optimizer.step()if k % valinterval == 0:val(model,loss_func)end_time = time.time()print("takes {}s".format((end_time - start_time)))if __name__ == '__main__':train()

测试代码:

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
import torch
from config import opt
from crnn import crnn
from PIL import Image
from torchvision import transformsclass resizeNormalize(object):def __init__(self, size, interpolation=Image.BILINEAR):self.size = sizeself.interpolation = interpolationself.toTensor = transforms.ToTensor()def __call__(self, img):img = img.resize(self.size, self.interpolation)img = self.toTensor(img)img.sub_(0.5).div_(0.5)return imgdef decode(preds,char_set):pred_text = ''for i in range(len(preds)):if preds[i] != 5989 and ((i == 5989) or (i != 5989 and preds[i] != preds[i-1])):pred_text += char_set[int(preds[i])-1]return pred_text# test if crnn workif __name__ == '__main__':imagepath = './12.jpg'img_h = opt.img_huse_gpu = opt.use_gpumodelpath = 'F:\crnn_model\pytorch-crnn-Copy68.pth'#modelpath = '../train/models/pytorch-crnn.pth'# modelpath = opt.modelpathchar_set = open('char_std_5990.txt', 'r', encoding='utf-8').readlines()char_set = ''.join([ch.strip('\n') for ch in char_set[1:]] + ['卍'])n_class = len(char_set)print(n_class)from crnn_new import crnnmodel = crnn.CRNN(img_h, 1, n_class, 256)if os.path.exists(modelpath):print('Load model from "%s" ...' % modelpath)model.load_state_dict(torch.load(modelpath))print('Done!')if torch.cuda.is_available and use_gpu:model.cuda()image = Image.open(imagepath).convert('L')(w,h) = image.sizesize_h = 32ratio = size_h / float(h)size_w = int(w * ratio)# keep the ratiotransform = resizeNormalize((size_w, size_h))image = transform(image)image = image.unsqueeze(0)if torch.cuda.is_available and use_gpu:image = image.cuda()model.eval()preds = model(image)preds = preds.max(2)preds = preds[1]preds = preds.squeeze()pred_text = decode(preds,char_set)print('predict == >',pred_text)

使用pytorch实现crnn相关推荐

  1. RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构

    一.基础知识: 下图是一个循环神经网络实现语言模型的示例,可以看出其是基于当前的输入与过去的输入序列,预测序列的下一个字符. 序列特点就是某一步的输出不仅依赖于这一步的输入,还依赖于其他步的输入或输出 ...

  2. Pytorch使用CRNN CTCLoss实现OCR系统

    卷积递归神经网络 此项目使用CNN + RNN + CTCLoss实现OCR系统,灵感来自CRNN网络. 一.用法 python ./train.py --help 二.演示 1.使用TestData ...

  3. CRNN——pytorch + wrap_ctc编译,实现pytorch版CRNN

    文章目录 简介 CTC网络的输入 CTC网络的计算过程 CTC网络的输出 pytorch安装 warp-CTC安装 Bug解决 References 简介 CTC可以生成一个损失函数,用于在序列数据上 ...

  4. python实现文字识别软件_文字识别(OCR)CRNN(基于pytorch、python3) 实现不定长中文字符识别...

    文字识别(OCR)CRNN(基于pytorch.python3) 实现不定长中文字符识别 发布时间:2018-09-26 19:40, 浏览次数:1265 , 标签: OCR CRNN pytorch ...

  5. 文字识别(OCR)CRNN(基于pytorch、python3) 实现不定长中文字符识别

     最近开源了一个之前做的人脸关键点检测的算法,欢迎star GitHub - Sierkinhane/TAB: Think about boundary: Fusing multi-level bou ...

  6. pytorch crnn训练

    一.代码下载 代码 这个代码主要是针对中文的检测. 二.环境配置 WIN 10 or Ubuntu 16.04 PyTorch 1.2.0 (may fix ctc loss) with cuda 1 ...

  7. python字符识别_crnn(基于pytorch、python3) 实现不定长中文字符识别

    在六七月份参加了一个比赛,做的项目是提取图片中的文字信息,首先是接触了一些文本检测算法(如CTPN,East),后研究了文本识别算法(我认为较好的是CRNN).代码实现是参考算法提出者的pytorch ...

  8. CRNN:端到端不定长文字识别算法

    点击上方"AI搞事情"关注我们 ❝ 论文:<An End-to-End Trainable Neural Network for Image-based Sequence R ...

  9. PyTorch Tutorial

    目录 图像.视觉.CNN相关实现 对抗生成网络.生成模型.GAN相关实现 机器翻译.问答系统.NLP相关实现 先进视觉推理系统 深度强化学习相关实现 通用神经网络高级应用 图像.视觉.CNN相关实现 ...

  10. 【OCR技术系列之八】端到端不定长文本识别CRNN代码实现

    CRNN是OCR领域非常经典且被广泛使用的识别算法,其理论基础可以参考我上一篇文章,本文将着重讲解CRNN代码实现过程以及识别效果. 数据处理 利用图像处理技术我们手工大批量生成文字图像,一共360万 ...

最新文章

  1. 地图构建两篇顶级论文解析
  2. [BZOJ3262]陌上花开
  3. 三组关键词,拆解2021年赤子城的中期业绩报告
  4. 域控 批量导入 用户_kerberos域用户提权分析
  5. Ubuntu 12.04 eclipse 安装 svn插件
  6. php+5.3.15下载,Rapid PHP2018
  7. 梦醒了,一切都结束了
  8. 软件构建中的设计(一)
  9. javaSE基础篇之char
  10. MSYS2 安装和配置
  11. exeScope软件修改exe或dll文件资源-20150818
  12. foxmail邮件备份到服务器上,foxmail发送邮件自动保存到邮件服务器的方法
  13. 解决Windows系统无法复制粘贴问题
  14. php解析psd图层,PSD解析工具实现(七)
  15. keil格式化代码方法
  16. (附源码)计算机毕业设计ssm高校第二课堂管理系统
  17. Python数据分析案例篇(一)泰坦尼克号数据分析
  18. 梯度下降学习率的设定策略
  19. 通俗易懂数仓建模:范式建模与维度建模
  20. 计算机科学领域专业,计算机科学与技术专业主要包括哪些领域?

热门文章

  1. android仿qq编辑图片,仿QQ图片编辑器 – ImageEditor
  2. 5G虚拟专网七大典型行业案例!(附下载)
  3. 31省“5G应用”主攻方向+责任单位一览!
  4. 匹配滤波器为何使得输出SNR最大?
  5. 概率论实验 04 - | 基于Matlab的匹配滤波器
  6. 如何获得微信小游戏源码
  7. [Java]批量生成二维码
  8. 爬取cloudmusic歌单
  9. STM32 JLINK接口定义 JTAG/SWD
  10. PPAPI插件与浏览器的通信