1. Repositories

https://github.com/meijieru/crnn.pytorch
The original Lua implementation: https://github.com/bgshih/crnn
warp_ctc_pytorch, which is also required: https://github.com/SeanNaren/warp-ctc

2. Environment setup

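Most ordinary setups should work. I used CUDA 10.0, torch 1.2.0 and Python 3.6; other versions should be fine as well.
Install whatever library turns out to be missing with pip install ***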

warp-CTC has to be compiled:

git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding
python setup.py install

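Doing exactly this worked for me without any errors.
To check whether the installation succeeded, start python and run
import warpctc_pytorch
If no error is raised, the installation succeeded.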
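If you'd rather run this check from a script than from the interactive interpreter, a minimal standard-library sketch could look like the following. `can_import` is a hypothetical helper name; it actually imports the module, so a build whose shared library cannot be loaded also shows up as a failure (that normally surfaces as an ImportError).

```python
import importlib


def can_import(module_name):
    """Return True if `module_name` imports cleanly, False on ImportError."""
    try:
        importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return True


if __name__ == "__main__":
    # warpctc_pytorch is the module installed by pytorch_binding/setup.py above
    if can_import("warpctc_pytorch"):
        print("warpctc_pytorch OK")
    else:
        print("warpctc_pytorch missing")
```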

3. Data preparation: building the LMDB

Put the images and their label files in one folder. Each text file has the same name as its image and contains the text shown in that image.
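Before converting, it can help to check that every image really has a same-named .txt file. A minimal sketch: the set of image extensions and the UTF-8 label encoding are assumptions, and `collect_pairs` is a hypothetical helper, not part of create_lmdb.py.

```python
import os

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}  # assumed image extensions


def collect_pairs(folder):
    """Return sorted (image_path, label) pairs; raise if a label file is missing."""
    pairs = []
    for name in sorted(os.listdir(folder)):
        stem, ext = os.path.splitext(name)
        if ext.lower() not in IMAGE_EXTS:
            continue  # skips the .txt files and anything else
        txt_path = os.path.join(folder, stem + ".txt")
        if not os.path.isfile(txt_path):
            raise FileNotFoundError("no label file for " + name)
        with open(txt_path, encoding="utf-8") as f:
            label = f.read().strip()  # the text shown in the image
        pairs.append((os.path.join(folder, name), label))
    return pairs
```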
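Then run the script https://github.com/wuzuowuyou/crnn_pytorch/blob/master/myfile/create_lmdb.py
Note that it has to be run with Python 2. Under Python 3 I got all sorts of encoding errors, while Python 2 ran without a single one. Python 2 also needs lmdb installed (pip2 install lmdb).
If it runs successfully, it generates these two files: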
./lmdb/data.mdb
./lmdb/lock.mdb
Put the lmdb folder under the data directory.
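LMDB stores raw bytes, so in Python 3 the encoding errors should disappear once every key and value is encoded explicitly. A sketch of preparing the key/value pairs in Python 3, assuming the key layout that crnn.pytorch's dataset.py reads (`image-%09d` and `label-%09d` counted from 1, plus `num-samples`). `build_cache` is hypothetical and only builds a dict; each entry would then be written with `txn.put(key, value)` inside `env.begin(write=True)` from the lmdb package.

```python
def read_file(path):
    """Read an image file as raw bytes."""
    with open(path, "rb") as f:
        return f.read()


def build_cache(pairs, read_bytes=read_file):
    """Map LMDB keys to values; every key and value is bytes, as LMDB requires."""
    cache = {}
    for index, (image_path, label) in enumerate(pairs, start=1):  # keys are 1-based
        cache[b"image-%09d" % index] = read_bytes(image_path)
        cache[b"label-%09d" % index] = label.encode("utf-8")
    cache[b"num-samples"] = str(len(pairs)).encode()
    return cache
```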

4. Training

python train.py --adadelta --trainRoot ./data/lmdb/ --valRoot ./data/lmdb/ --cuda

Note: if the labels contain upper-case letters, the alphabet has to be changed accordingly.
train.py, line 32:
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
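With upper-case letters the alphabet has 62 characters. CTC needs one extra blank class (as far as I can tell, utils.strLabelConverter reserves index 0 for it), so the network ends up with 63 outputs, which matches out_features=63 in the model printout further down. A label character that is missing from the alphabet makes label encoding fail with a KeyError, so a quick check before training can save a run. A minimal sketch; `num_classes` and `unknown_chars` are hypothetical helpers.

```python
import string

# Same string as the --alphabet default shown above
ALPHABET = string.digits + string.ascii_lowercase + string.ascii_uppercase


def num_classes(alphabet):
    # one output per character plus one CTC blank
    return len(alphabet) + 1


def unknown_chars(labels, alphabet):
    """Return the set of label characters that are not in the alphabet."""
    allowed = set(alphabet)
    return {ch for label in labels for ch in label if ch not in allowed}
```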

5. Fixing errors

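There were all kinds of errors.
5.1 trainRoot and valRoot: their capitalization has to be changed so it matches the option names defined in train.py.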
5.2 TypeError: Won't implicitly convert Unicode to bytes; use .encode()
As the error message suggests, add .encode():
txn.get('num-samples'.encode())
label_byte = txn.get(label_key.encode())
imgbuf = txn.get(img_key.encode())
5.3 KeyError while encoding the label text
text, _ = self.encode(text)
File "/home/crnn.pytorch/utils.py", line 45, in encode
for char in text
File "/home/crnn.pytorch/utils.py", line 45, in <listcomp>
for char in text
KeyError: 'b'
Fix: in dataset.py, line 61, replace
label = str(txn.get(label_key))
with
label_byte = txn.get(label_key.encode())
label = label_byte.decode()
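The reason the old line breaks: in Python 3, str() on a bytes object returns its repr, so the label becomes "b'...'" with the b prefix and the quotes included, and those extra characters then fail the alphabet lookup. A small sketch in which a plain dict's get stands in for txn.get (an assumption for illustration); `read_label` is a hypothetical helper.

```python
def read_label(txn_get, index):
    """Fetch label number `index` through an LMDB-style getter and return it as str."""
    label_key = "label-%09d" % index
    label_bytes = txn_get(label_key.encode())  # LMDB keys must be bytes in Python 3
    return label_bytes.decode("utf-8")  # decode the bytes; str() would keep the b'...'


store = {b"label-000000001": b"abc"}
print(str(store[b"label-000000001"]))  # b'abc'  (the repr leaks into the label)
print(read_label(store.get, 1))  # abc
```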

5.4 raise ValueError('sampler option is mutually exclusive with '
ValueError: sampler option is mutually exclusive with shuffle
In short, sampler and shuffle are mutually exclusive.
I added "and 0" so the custom sampler is never used:
if not opt.random_sample and 0:
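DataLoader raises this error whenever a sampler is passed together with shuffle=True, because the sampler already decides the order. Disabling the sampler, as above, is one fix; another (untested here) would be to keep the sampler and pass shuffle only when no sampler is set. The plain-Python sketch below illustrates the random-sequential idea that dataset.randomSequentialSampler appears to implement (contiguous index runs starting at random offsets) and a helper returning DataLoader-style keyword arguments that never enable both. Both function names are hypothetical and no torch is involved.

```python
import random


def random_sequential_indices(num_samples, batch_size, rng=random):
    """Contiguous runs of indices, each run starting at a random offset."""
    if batch_size > num_samples:
        raise ValueError("batch_size must not exceed num_samples")
    n_batch, tail = divmod(num_samples, batch_size)
    indices = []
    for run_len in [batch_size] * n_batch + ([tail] if tail else []):
        start = rng.randint(0, num_samples - batch_size)  # run stays in range
        indices.extend(range(start, start + run_len))
    return indices


def loader_options(random_sample, num_samples, batch_size):
    """DataLoader-style kwargs: never a sampler together with shuffle=True."""
    if not random_sample:
        sampler = random_sequential_indices(num_samples, batch_size)
        return {"sampler": sampler, "shuffle": False}
    return {"sampler": None, "shuffle": True}
```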

5.5 Validation also failed:
Start val
Traceback (most recent call last):
File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 219, in <module>
val(crnn, test_dataset, criterion)
File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 168, in val
preds = preds.squeeze(2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
I simply skipped validation by adding "and 0":
if i % opt.valInterval == 0 and 0:
    val(crnn, test_dataset, criterion)
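If you want to keep validation instead: in current PyTorch, preds.max(2) returns values and indices with dimension 2 removed (keepdim defaults to False), so preds is already 2-D and squeeze(2) has nothing to remove, which matches the "[-2, 1]" range in the message. Deleting the preds = preds.squeeze(2) line is probably enough (untested). What val does afterwards is greedy (best-path) CTC decoding: take the best class per time step, collapse repeats, drop blanks. A pure-Python sketch, assuming blank index 0 and class k mapping to alphabet[k - 1]; `greedy_ctc_decode` is hypothetical, not the repo's converter.decode.

```python
def greedy_ctc_decode(frame_scores, alphabet, blank=0):
    """Best-path CTC decoding of one sequence.

    frame_scores: one list of scores per time step, each of length len(alphabet) + 1.
    """
    # argmax over classes at every time step
    best = [max(range(len(s)), key=s.__getitem__) for s in frame_scores]
    chars = []
    prev = None
    for k in best:
        if k != blank and k != prev:  # skip blanks and repeated labels
            chars.append(alphabet[k - 1])  # class k corresponds to alphabet[k - 1]
        prev = k
    return "".join(chars)
```

For example, per-step best classes 1, 1, 0, 1, 2, 2 with alphabet "ab" decode to "aab": the blank between the two runs of class 1 keeps the repeated "a".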

With these errors fixed, training runs and prints the following:

    (relu6): ReLU(inplace=True)
  )
  (rnn): Sequential(
    (0): BidirectionalLSTM(
      (rnn): LSTM(512, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=256, bias=True)
    )
    (1): BidirectionalLSTM(
      (rnn): LSTM(256, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=63, bias=True)
    )
  )
)
[0/100000000][1/9] Loss: 8.430408
[0/100000000][2/9] Loss: 20.137066
[0/100000000][3/9] Loss: 25.239346
[0/100000000][4/9] Loss: 21.249365
[0/100000000][5/9] Loss: 20.604660
[0/100000000][6/9] Loss: 14.782236

6. Testing with demo.py

This line has to be changed to match the training setup (same alphabet, hence the same number of classes):
model = crnn.CRNN(32, 1, 37, 256)
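The third argument of crnn.CRNN is the number of classes. 37 corresponds to the 36-character default alphabet plus the blank; with the 62-character alphabet it has to be 63. To avoid guessing, the value can be read from the checkpoint: the last layer, rnn.1.embedding in the key list of the error below, is a Linear layer whose weight has shape (out_features, in_features), so its first dimension is the class count. A sketch on a plain dict of shapes; with a real checkpoint you would pass something like {k: tuple(v.shape) for k, v in torch.load(model_path).items()}. `infer_nclass` is hypothetical.

```python
def infer_nclass(shapes, suffix="rnn.1.embedding.weight"):
    """Return out_features of the final Linear layer, i.e. the class count."""
    for key, shape in shapes.items():
        if key.endswith(suffix):  # also matches 'module.rnn.1.embedding.weight'
            return shape[0]  # Linear weight shape is (out_features, in_features)
    raise KeyError(suffix + " not found in checkpoint")
```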

Error:
File "/data_2/project_2021/crnn/crnn.pytorch-master/demo_show.py", line 42, in <module>
model.load_state_dict(torch.load(model_path))
File "/data_1/Yang/software_install/Anaconda1105/envs/CenterNet_1.0_3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CRNN:
Missing key(s) in state_dict: "cnn.conv0.weight", "cnn.conv0.bias", "cnn.conv1.weight", "cnn.conv1.bias", "cnn.conv2.weight", "cnn.conv2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.bias", "cnn.batchnorm2.running_mean", "cnn.batchnorm2.running_var", "cnn.conv3.weight", "cnn.conv3.bias", "cnn.conv4.weight", "cnn.conv4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.bias", "cnn.batchnorm4.running_mean", "cnn.batchnorm4.running_var", "cnn.conv5.weight", "cnn.conv5.bias", "cnn.conv6.weight", "cnn.conv6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.bias", "cnn.batchnorm6.running_mean", "cnn.batchnorm6.running_var", "rnn.0.rnn.weight_ih_l0", "rnn.0.rnn.weight_hh_l0", "rnn.0.rnn.bias_ih_l0", "rnn.0.rnn.bias_hh_l0", "rnn.0.rnn.weight_ih_l0_reverse", "rnn.0.rnn.weight_hh_l0_reverse", "rnn.0.rnn.bias_ih_l0_reverse", "rnn.0.rnn.bias_hh_l0_reverse", "rnn.0.embedding.weight", "rnn.0.embedding.bias", "rnn.1.rnn.weight_ih_l0", "rnn.1.rnn.weight_hh_l0", "rnn.1.rnn.bias_ih_l0", "rnn.1.rnn.bias_hh_l0", "rnn.1.rnn.weight_ih_l0_reverse", "rnn.1.rnn.weight_hh_l0_reverse", "rnn.1.rnn.bias_ih_l0_reverse", "rnn.1.rnn.bias_hh_l0_reverse", "rnn.1.embedding.weight", "rnn.1.embedding.bias".
Unexpected key(s) in state_dict: "module.cnn.conv0.weight", "module.cnn.conv0.bias", "module.cnn.conv1.weight", "module.cnn.conv1.bias", "module.cnn.conv2.weight", "module.cnn.conv2.bias", "module.cnn.batchnorm2.weight", "module.cnn.batchnorm2.bias", "module.cnn.batchnorm2.running_mean", "module.cnn.batchnorm2.running_var", "module.cnn.batchnorm2.num_batches_tracked", "module.cnn.conv3.weight", "module.cnn.conv3.bias", "module.cnn.conv4.weight", "module.cnn.conv4.bias", "module.cnn.batchnorm4.weight", "module.cnn.batchnorm4.bias", "module.cnn.batchnorm4.running_mean", "module.cnn.batchnorm4.running_var", "module.cnn.batchnorm4.num_batches_tracked", "module.cnn.conv5.weight", "module.cnn.conv5.bias", "module.cnn.conv6.weight", "module.cnn.conv6.bias", "module.cnn.batchnorm6.weight", "module.cnn.batchnorm6.bias", "module.cnn.batchnorm6.running_mean", "module.cnn.batchnorm6.running_var", "module.cnn.batchnorm6.num_batches_tracked", "module.rnn.0.rnn.weight_ih_l0", "module.rnn.0.rnn.weight_hh_l0", "module.rnn.0.rnn.bias_ih_l0", "module.rnn.0.rnn.bias_hh_l0", "module.rnn.0.rnn.weight_ih_l0_reverse", "module.rnn.0.rnn.weight_hh_l0_reverse", "module.rnn.0.rnn.bias_ih_l0_reverse", "module.rnn.0.rnn.bias_hh_l0_reverse", "module.rnn.0.embedding.weight", "module.rnn.0.embedding.bias", "module.rnn.1.rnn.weight_ih_l0", "module.rnn.1.rnn.weight_hh_l0", "module.rnn.1.rnn.bias_ih_l0", "module.rnn.1.rnn.bias_hh_l0", "module.rnn.1.rnn.weight_ih_l0_reverse", "module.rnn.1.rnn.weight_hh_l0_reverse", "module.rnn.1.rnn.bias_ih_l0_reverse", "module.rnn.1.rnn.bias_hh_l0_reverse", "module.rnn.1.embedding.weight", "module.rnn.1.embedding.bias".

Process finished with exit code 1

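The cause is that the parameter names in the saved .pth file carry an extra "module." prefix (added by nn.DataParallel, which train.py wraps the model in when training on GPU). Removing the prefix fixes it.
Change the code as follows: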

import collections  # add to the imports at the top of demo.py
nclass = len(alphabet) + 1  # alphabet must be the same string used for training
model = crnn.CRNN(32, 1, nclass, 256)  # was: crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()
# for m in model.state_dict().keys():
#     print("==:: ", m)
load_model_ = torch.load(model_path)
# for k, v in load_model_.items():
#     print(k, "  ::shape", v.shape)
state_dict_rename = collections.OrderedDict()
for k, v in load_model_.items():
    name = k[7:] if k.startswith('module.') else k  # remove `module.`
    state_dict_rename[name] = v
print('loading pretrained model from %s' % model_path)
model.load_state_dict(state_dict_rename)

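After that, testing works.
Since there were so many changes, I uploaded the modified code to GitHub for anyone who needs it. It includes 10 test images with labels, enough to try the LMDB conversion: https://github.com/wuzuowuyou/crnn_pytorch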
