Datawhale干货

作者:阿水,北京航空航天大学,Datawhale成员

本文以世界人工智能创新大赛(AIWIN)手写体 OCR 识别竞赛为实践背景,给出了OCR实践的常见思路和流程。本项目使用PaddlePaddle 2.0动态图实现的CRNN文字识别模型,全文代码及思路如下。后台回复 211112 可获取完整代码。

代码地址:https://aistudio.baidu.com/aistudio/projectdetail/2612313

赛题背景

银行日常业务中涉及到各类凭证的识别录入,例如身份证录入、支票录入、对账单录入等。以往的录入方式主要是以人工录入为主,效率较低,人力成本较高。近几年来,OCR相关技术以其自动执行、人为干预较少等特点正逐步替代传统的人工录入方式。但OCR技术在实际应用中也存在一些问题,在各类凭证字段的识别中,手写体由于其字体差异性大、字数不固定、语义关联性较低、凭证背景干扰等原因,导致OCR识别率准确率不高,需要大量人工校正,对日常的银行录入业务造成了一定的影响。

比赛地址:http://ailab.aiwin.org.cn/competitions/65

赛题任务

本次赛题将提供手写体图像切片数据集,数据集从真实业务场景中,经过切片脱敏得到,参赛队伍通过识别技术,获得对应的识别结果。即:

  • 输入:手写体图像切片数据集

  • 输出:对应的识别结果

代码说明

本项目是PaddlePaddle 2.0动态图实现的CRNN文字识别模型,可支持长短不一的图片输入。CRNN是一种端到端的识别模式,不需要通过分割图片即可完成图片中全部的文字识别。CRNN的结构主要是CNN+RNN+CTC,它们分别的作用是:

  • 使用深度CNN,对输入图像提取特征,得到特征图;

  • 使用双向RNN(BLSTM)对特征序列进行预测,对序列中的每个特征向量进行学习,并输出预测标签(真实值)分布;

  • 使用 CTC Loss,把从循环层获取的一系列标签分布转换成最终的标签序列。

CRNN的结构如下,一张高为32的图片,宽度随意,一张图片经过多层卷积之后,高度就变成了1,经过paddle.squeeze()就去掉了高度,也就说从输入的图片BCHW经过卷积之后就成了BCW。然后把特征顺序从BCW改为WBC输入到RNN中,经过两次的RNN之后,模型的最终输入为(W, B, Class_num)。这恰好是CTCLoss函数的输入。

代码详情

使用环境:

  • PaddlePaddle 2.0.1

  • Python 3.7

!\rm -rf __MACOSX/ 测试集/ 训练集/ dataset/
!unzip 2021A_T1_Task1_数据集含训练集和测试集.zip > out.log

步骤1:生成额外的数据集

这一步可以跳过,如果想要获取更好的精度,可以自己添加。

import os
import time
from random import choice, randint, randrangefrom PIL import Image, ImageDraw, ImageFont# 验证码图片文字的字符集
characters = '拾伍佰正仟万捌贰整陆玖圆叁零角分肆柒亿壹元'def selectedCharacters(length):result = ''.join(choice(characters) for _ in range(length))return resultdef getColor():r = randint(0, 100)g = randint(0, 100)b = randint(0, 100)return (r, g, b)def main(size=(200, 100), characterNumber=6, bgcolor=(255, 255, 255)):# 创建空白图像和绘图对象imageTemp = Image.new('RGB', size, bgcolor)draw01 = ImageDraw.Draw(imageTemp)# 生成并计算随机字符串的宽度和高度text = selectedCharacters(characterNumber)print(text)font = ImageFont.truetype(font_path, 40)width, height = draw01.textsize(text, font)if width + 2 * characterNumber > size[0] or height > size[1]:print('尺寸不合法')return# 绘制随机字符串中的字符startX = 0widthEachCharater = width // characterNumberfor i in range(characterNumber):startX += widthEachCharater + 1position = (startX, (size[1] - height) // 2)draw01.text(xy=position, text=text[i], font=font, fill=getColor())# 对像素位置进行微调,实现扭曲的效果imageFinal = Image.new('RGB', size, bgcolor)pixelsFinal = imageFinal.load()pixelsTemp = imageTemp.load()for y in range(size[1]):offset = randint(-1, 0)for x in range(size[0]):newx = x + offsetif newx >= size[0]:newx = size[0] - 1elif newx < 0:newx = 0pixelsFinal[newx, y] = pixelsTemp[x, y]# 绘制随机颜色随机位置的干扰像素draw02 = ImageDraw.Draw(imageFinal)for i in range(int(size[0] * size[1] * 0.07)):draw02.point((randrange(0, size[0]), randrange(0, size[1])), fill=getColor())# 保存并显示图片imageFinal.save("dataset/images/%d_%s.jpg" % (round(time.time() * 1000), text))def create_list():images = os.listdir('dataset/images')f_train = open('dataset/train_list.txt', 'w', encoding='utf-8')f_test = open('dataset/test_list.txt', 'w', encoding='utf-8')for i, image in enumerate(images):image_path = os.path.join('dataset/images', image).replace('\\', '/')label = image.split('.')[0].split('_')[1]if i % 100 == 0:f_test.write('%s\t%s\n' % (image_path, label))else:f_train.write('%s\t%s\n' % (image_path, label))def creat_vocabulary():# 生成词汇表with open('dataset/train_list.txt', 'r', encoding='utf-8') as f:lines = f.readlines()v = set()for line in lines:_, label = line.replace('\n', '').split('\t')for c in label:v.add(c)vocabulary_path = 'dataset/vocabulary.txt'with open(vocabulary_path, 'w', encoding='utf-8') as f:f.write(' \n')for c in v:f.write(c + '\n')if __name__ == '__main__':if not os.path.exists('dataset/images'):os.makedirs('dataset/images')

步骤2:安装依赖环境

!pip install Levenshtein
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.16.0)
Requirement already satisfied: rapidfuzz<1.9,>=1.8.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Levenshtein) (1.8.2)

步骤3:读取数据集

import glob, codecs, json, os
import numpy as npdate_jpgs = glob.glob('./训练集/date/images/*.jpg')
amount_jpgs = glob.glob('./训练集/amount/images/*.jpg')lines = codecs.open('./训练集/date/gt.json', encoding='utf-8').readlines()
lines = ''.join(lines)
date_gt = json.loads(lines.replace(',\n}', '}'))lines = codecs.open('./训练集/amount/gt.json', encoding='utf-8').readlines()
lines = ''.join(lines)
amount_gt = json.loads(lines.replace(',\n}', '}'))
data_path = date_jpgs + amount_jpgs
date_gt.update(amount_gt)s = ''
for x in date_gt:s += date_gt[x]char_list = list(set(list(s)))
char_list = char_list

步骤4:构造训练集

!mkdir dataset
!mkdir dataset/images
!cp 训练集/date/images/*.jpg dataset/images
!cp 训练集/amount/images/*.jpg dataset/images
mkdir: cannot create directory ‘dataset’: File exists
mkdir: cannot create directory ‘dataset/images’: File exists
with open('dataset/vocabulary.txt', 'w') as up:for x in char_list:up.write(x + '\n')data_path = glob.glob('dataset/images/*.jpg')
np.random.shuffle(data_path)
with open('dataset/train_list.txt', 'w') as up:for x in data_path[:-100]:up.write(f'{x}\t{date_gt[os.path.basename(x)]}\n')with open('dataset/test_list.txt', 'w') as up:for x in data_path[-100:]:up.write(f'{x}\t{date_gt[os.path.basename(x)]}\n')

执行上面程序生成的图片会放在dataset/images目录下,生成的训练数据列表和测试数据列表分别放在dataset/train_list.txtdataset/test_list.txt,最后还有个数据词汇表dataset/vocabulary.txt

数据列表的格式如下,左边是图片的路径,右边是文字标签。

dataset/images/1617420021182_c1dw.jpg c1dw
dataset/images/1617420021204_uvht.jpg uvht
dataset/images/1617420021227_hb30.jpg hb30
dataset/images/1617420021266_4nkx.jpg 4nkx
dataset/images/1617420021296_80nv.jpg 80nv

以下是数据集词汇表的格式,一行一个字符,第一行是空格,不代表任何字符。

f
s
2
7
3
n
d
w

训练自定义数据,参考上面的格式即可。

步骤5:训练模型

不管你是自定义数据集还是使用上面生成的数据,只要文件路径正确,即可开始进行训练。该训练支持长度不一的图片输入,但是每一个batch的数据的数据长度还是要一样的,这种情况下,笔者就用了collate_fn()函数,该函数可以把数据最长的找出来,然后把其他的数据补0,加到相同的长度。同时该函数还要输出它其中每条数据标签的实际长度,因为损失函数需要输入标签的实际长度。

  • 在训练过程中,程序会使用VisualDL记录训练结果

import paddle
import numpy as np
import os
from datetime import datetime
from utils.model import Model
from utils.decoder import ctc_greedy_decoder, label_to_string, cer
from paddle.io import DataLoader
from utils.data import collate_fn
from utils.data import CustomDataset
from visualdl import LogWriter# 训练数据列表路径
train_data_list_path = 'dataset/train_list.txt'
# 测试数据列表路径
test_data_list_path = 'dataset/test_list.txt'
# 词汇表路径
voc_path = 'dataset/vocabulary.txt'
# 模型保存的路径
save_model = 'models/'
# 每一批数据大小
batch_size = 32
# 预训练模型路径
pretrained_model = None
# 训练轮数
num_epoch = 100
# 初始学习率大小
learning_rate = 1e-3
# 日志记录噐
writer = LogWriter(logdir='log')def train():# 获取训练数据train_dataset = CustomDataset(train_data_list_path, voc_path, img_height=32)train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)# 获取测试数据test_dataset = CustomDataset(test_data_list_path, voc_path, img_height=32, is_data_enhance=False)test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)# 获取模型model = Model(train_dataset.vocabulary, image_height=train_dataset.img_height, channel=1)paddle.summary(model, input_size=(batch_size, 1, train_dataset.img_height, 500))# 设置优化方法boundaries = [30, 100, 200]lr = [0.1 ** l * learning_rate for l in range(len(boundaries) + 1)]scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=boundaries, values=lr, verbose=False)optimizer = paddle.optimizer.Adam(parameters=model.parameters(),learning_rate=scheduler,weight_decay=paddle.regularizer.L2Decay(1e-4))# 获取损失函数ctc_loss = paddle.nn.CTCLoss()# 加载预训练模型if pretrained_model is not None:model.set_state_dict(paddle.load(os.path.join(pretrained_model, 'model.pdparams')))optimizer.set_state_dict(paddle.load(os.path.join(pretrained_model, 'optimizer.pdopt')))train_step = 0test_step = 0# 开始训练for epoch in range(num_epoch):for batch_id, (inputs, labels, input_lengths, label_lengths) in enumerate(train_loader()):out = model(inputs)# 计算损失input_lengths = paddle.full(shape=[batch_size], fill_value=out.shape[0], dtype='int64')loss = ctc_loss(out, labels, input_lengths, label_lengths)loss.backward()optimizer.step()optimizer.clear_grad()# 多卡训练只使用一个进程打印if batch_id % 100 == 0:print('[%s] Train epoch %d, batch %d, loss: %f' % (datetime.now(), epoch, batch_id, loss))writer.add_scalar('Train loss', loss, train_step)train_step += 1# 执行评估if epoch % 10 == 0:model.eval()cer = evaluate(model, test_loader, train_dataset.vocabulary)print('[%s] Test epoch %d, cer: %f' % (datetime.now(), epoch, cer))writer.add_scalar('Test cer', cer, test_step)test_step += 1model.train()# 记录学习率writer.add_scalar('Learning rate', scheduler.last_lr, epoch)scheduler.step()# 保存模型paddle.save(model.state_dict(), os.path.join(save_model, 'model.pdparams'))paddle.save(optimizer.state_dict(), os.path.join(save_model, 'optimizer.pdopt'))# 评估模型
def evaluate(model, test_loader, vocabulary):cer_result = []for batch_id, (inputs, labels, _, _) in enumerate(test_loader()):# 执行识别outs = model(inputs)outs = paddle.transpose(outs, perm=[1, 0, 2])outs = paddle.nn.functional.softmax(outs)# 解码获取识别结果labelss = []out_strings = []for out in outs:out_string = ctc_greedy_decoder(out, vocabulary)out_strings.append(out_string)for i, label in enumerate(labels):label_str = label_to_string(label, vocabulary)labelss.append(label_str)for out_string, label in zip(*(out_strings, labelss)):# 计算字错率c = cer(out_string, label) / float(len(label))cer_result.append(c)cer_result = float(np.mean(cer_result))return cer_resultif __name__ == '__main__':train()

步骤6:模型预测

训练结束之后,使用保存的模型进行预测。通过修改image_path指定需要预测的图片路径,解码方法,笔者使用了一个最简单的贪心策略。

import os
from PIL import Image
import numpy as np
import paddlefrom utils.model import Model
from utils.data import process
from utils.decoder import ctc_greedy_decoderwith open('dataset/vocabulary.txt', 'r', encoding='utf-8') as f:vocabulary = f.readlines()vocabulary = [v.replace('\n', '') for v in vocabulary]save_model = 'models/'
model = Model(vocabulary, image_height=32)
model.set_state_dict(paddle.load(os.path.join(save_model, 'model.pdparams')))
model.eval()def infer(path):data = process(path, img_height=32)data = data[np.newaxis, :]data = paddle.to_tensor(data, dtype='float32')# 执行识别out = model(data)out = paddle.transpose(out, perm=[1, 0, 2])out = paddle.nn.functional.softmax(out)[0]# 解码获取识别结果out_string = ctc_greedy_decoder(out, vocabulary)# print('预测结果:%s' % out_string)return out_stringif __name__ == '__main__':image_path = 'dataset/images/0_8bb194207a248698017a854d62c96104.jpg'display(Image.open(image_path))print(infer(image_path))

<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=123x33 at 0x7F4D525F08D0>
贰零贰零贰壹
from tqdm import tqdm, tqdm_notebookresult_dict = {}
for path in tqdm(glob.glob('./测试集/date/images/*.jpg')):text = infer(path)result_dict[os.path.basename(path)] = {'result': text,'confidence': 0.9}for path in tqdm(glob.glob('./测试集/amount/images/*.jpg')):text = infer(path)result_dict[os.path.basename(path)] = {'result': text,'confidence': 0.9}
with open('answer.json', 'w', encoding='utf-8') as up:json.dump(result_dict, up, ensure_ascii=False, indent=4)!zip answer.json.zip answer.jsonadding: answer.json (deflated 85%)

整理不易,三连

世界人工智能大赛OCR赛题方案!相关推荐

  1. 世界人工智能大赛方案汇总(nlp,cv)

    Datawhale学习 开源贡献者:阿水.致Great.姚程栋.卜首等 有同学希望通过今年的世界人工智能大赛来提升专业能力,同时增加履历,拿到这次赛事的直推offer.根据大家反馈遇到的问题,我们邀请 ...

  2. 世界人工智能大赛方案解析!

    在日常生活中新闻具备有多的信息,在世界人工智能大赛互联网舆情企业风险事件的识别和预警 比赛中参赛选手需要根据新闻识别主体和新闻类型. 比赛官网(报名可下载数据集):http://ailab.aiwin ...

  3. 2021全国职业技能大赛-网络安全赛题解析总结①(超详细)

    2021全国职业技能大赛-网络安全赛题解析总结(1) 模块A 基础设施设置与安全加固 有问题可以私聊博主 模块A 基础设施设置与安全加固 一.项目和任务描述: 假定你是某企业的网络安全工程师,对于企业 ...

  4. 2022全国职业技能大赛-网络安全赛题解析总结②(超详细)

    2022全国职业技能大赛-网络安全赛题解析总结(自己得思路) 模块A 基础设施设置与安全加固(20分) 模块B 网络安全事件响应.数字取证调查和应用安全(40分) 模块C CTF夺旗-攻击(本模块20 ...

  5. 2022全国职业技能大赛-网络安全赛题解析总结①(超详细)

    2022全国职业技能大赛-网络安全赛题解析总结(自己得思路) 模块A 基础设施设置与安全加固(20分) 模块B 网络安全事件响应.数字取证调查和应用安全(40分) 模块C CTF夺旗-攻击(20分) ...

  6. 2022全国职业技能大赛-网络安全赛题解析总结⑤(超详细)

    2022全国职业技能大赛-网络安全赛题解析总结(自己得思路) 模块A 基础设施设置与安全加固(20分) 模块B 网络安全事件响应.数字取证调查和应用安全(40分) 模块C CTF夺旗-攻击(20分) ...

  7. 2023年全国职业院校技能大赛软件测试赛题第1套

    2023年全国职业院校技能大赛 软件测试赛题第1套                                         赛项名称:            软件测试             ...

  8. 腾讯广告算法大赛2020赛题初探坑

    腾讯广告算法大赛2020赛题初探坑 写在前面 1.赛题和数据 2.评分标准 3.特征工程 3.1one-hot编码 3.2hash特征 3.3target encode 3.4embedding大法好 ...

  9. 2023年全国职业院校技能大赛 软件测试赛题第2套

    2023年全国职业院校技能大赛 软件测试赛题第2套 赛项名称: 软件测试 英文名称: Software Testing 赛项编号: GZ034 归属产业: 电子与信息大类 赛项组别: 高等职业教育 赛 ...

最新文章

  1. 【跃迁之路】【578天】程序员高效学习方法论探索系列(实验阶段335-2018.09.06)...
  2. 本地缓存需要高时效性怎么办_缓存在高并发场景下的常见问题
  3. C语言 enum和typedef enum的区别
  4. 【NLP】Prompt Learning-使用模板激发语言模型潜能
  5. Ubuntu18.04深度学习环境配置(简易方式)
  6. 剑指 Offer 01-----20
  7. 人脸识别技术原理与工程实践
  8. 随想录(关于dsp)
  9. mybatis 批量插入的两种方式
  10. Atitit 算法原理与导论 目录 1. Attilax总结的有用算法 按用途分类 1 1.1. 排序算法 字符串匹配(String Matching) 1 1.2. 加密算法 编码算法 序列
  11. 第一个Django项目----一小时写出账号密码管理系统
  12. 通达信经典指标组合图文详解
  13. 浏览器flash/html5视频播放如何倍速(Enounce MySpeed)
  14. 雷达传感器应用,微波雷达感应模块,物联网传感技术发展
  15. 论文编辑——插入公式编号并对齐、插入图表编号、正文引用各类编号
  16. 送给计算机老师平安夜贺卡,给老师的平安夜祝福语
  17. python爬取二手房信息_python爬虫爬取链家二手房信息
  18. 一起学英语第二季第五期
  19. 国内十大优质黄金期货交易平台排名榜单(最新版一览)
  20. 告别传统机房:3D 机房数据可视化实现智能化与VR技术的新碰撞

热门文章

  1. 5G支持下,人工智能除了AI换脸,还能干什么?
  2. 前端包管理器的领头大哥——npm
  3. 去信任外包虚荣地址生成
  4. 儿童保健管理系统技术方案
  5. Internal Server Error 错误 The server encountered an internal error or misconfiguration and was una...
  6. stack overflow -最好的编程技术论坛!
  7. 『Flutter开发实战』一小时掌握Dart语言
  8. pear php linux,linux下安装PEAR、Zend Debugger和Smarty
  9. 测试中使用SecureCRT的经验归纳
  10. 中国电信189邮箱手机推送功能评测