目录

  • 1.数据集相关操作
    • 1.1标签最长字符个数统计
    • 1.2char和id的映射字典构建
    • 1.3数据集图像尺寸分析
  • 2.将transformer引入OCR
    • 2.1准备工作
    • 2.2数据集创建
  • 3.模型构建
  • 4.模型训练
  • 5.贪心解码
  • 6.总结

本文是跟着datawhale组队学习学的,原文在这:动手学CV-Pytorch 6.2_使用transformer实现OCR字符识别

1.数据集相关操作

#! pip install opencv-python
import os
import cv2
# 数据集根目录,请将数据下载到此位置
base_data_dir = './ICDAR_2015'
# 训练数据集和验证数据集所在路径
train_img_dir = os.path.join(base_data_dir, 'train')
valid_img_dir = os.path.join(base_data_dir, 'valid')
# 训练集和验证集标签文件路径
train_lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
valid_lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
# 中间文件存储路径,存储标签字符与其id的映射关系
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')

1.1标签最长字符个数统计

def statistics_max_len_label(lbl_path):"""统计标签文件中最长的label所包含的字符数lbl_path:txt标签文件路径"""max_len = -1with open(lbl_path, 'r',encoding = 'utf-8') as reader:for line in reader:items = line.rstrip().split(',')#img_name = item[0] #提取图像名称lbl_str = items[1].strip()[1:-1]#提取标签,去除标签中的引号lbl_len = len(lbl_str)max_len = max_len if max_len>lbl_len else lbl_lenreturn max_len
train_max_label_len = statistics_max_len_label(train_lbl_path) # 训练集最长label
valid_max_label_len = statistics_max_len_label(valid_lbl_path) # 验证集最长label
max_label_len = max(train_max_label_len, valid_max_label_len) # 全数据集最长label
print(f"数据集中包含字符最多的label长度为{max_label_len}")
数据集中包含字符最多的label长度为21
def statistics_label_cnt(lbl_path, lbl_cnt_map):"""统计标签文件中label都包含了哪些字符以及各自出现的次数lbl_path:标签所处路径lbl_cnt_map:记录标签中字符出现次数的字典"""with open(lbl_path, 'r',encoding = 'utf-8') as reader:for line in reader:items = line.rstrip().split(',')#img_name = item[0] #提取图像名称lbl_str = items[1].strip()[1:-1]#提取标签,去除标签中的引号for lbl in lbl_str:if lbl not in lbl_cnt_map.keys():lbl_cnt_map[lbl] = 1else:lbl_cnt_map[lbl] +=1lbl_cnt_map = dict() # 用于存储字符出现次数的字典
statistics_label_cnt(train_lbl_path, lbl_cnt_map) # 训练集中字符出现次数统计
print("训练集中label中出现的字符:")
print(lbl_cnt_map)
statistics_label_cnt(valid_lbl_path, lbl_cnt_map) # 训练集和验证集中字符出现次数统计
print("训练集+验证集label中出现的字符:")
print(lbl_cnt_map)
训练集中label中出现的字符:
{'[': 2, '0': 182, '6': 38, ']': 2, '2': 119, '-': 68, '3': 50, 'C': 593, 'a': 843, 'r': 655, 'p': 197, 'k': 96, 'E': 1421, 'X': 110, 'I': 861, 'T': 896, 'R': 836, 'f': 133, 'u': 293, 's': 557, 'i': 651, 'o': 659, 'n': 605, 'l': 408, 'e': 1055, 'v': 123, 'A': 1189, 'U': 319, 'O': 965, 'N': 785, 'c': 318, 't': 563, 'm': 202, 'W': 179, 'H': 391, 'Y': 229, 'P': 389, 'F': 259, 'G': 345, '?': 5, 'S': 1161, 'b': 88, 'h': 299, ' ': 50, 'g': 171, 'L': 745, 'M': 367, 'D': 383, 'd': 257, '$': 46, '5': 77, '4': 44, '.': 95, 'w': 97, 'B': 331, '1': 184, '7': 43, '8': 44, 'V': 158, 'y': 161, 'K': 163, '!': 51, '9': 66, 'z': 12, ';': 3, '#': 16, 'j': 15, "'": 51, 'J': 72, ':': 19, 'x': 27, '%': 28, '/': 24, 'q': 3, 'Q': 19, '(': 6, ')': 5, '\\': 8, '"': 8, '´': 3, 'Z': 29, '&': 9, 'É': 1, '@': 4, '=': 1, '+': 1}
训练集+验证集label中出现的字符:
{'[': 2, '0': 232, '6': 44, ']': 2, '2': 139, '-': 87, '3': 69, 'C': 893, 'a': 1200, 'r': 935, 'p': 317, 'k': 137, 'E': 2213, 'X': 181, 'I': 1241, 'T': 1315, 'R': 1262, 'f': 203, 'u': 415, 's': 793, 'i': 924, 'o': 954, 'n': 880, 'l': 555, 'e': 1534, 'v': 169, 'A': 1827, 'U': 467, 'O': 1440, 'N': 1158, 'c': 442, 't': 829, 'm': 278, 'W': 288, 'H': 593, 'Y': 341, 'P': 582, 'F': 402, 'G': 521, '?': 7, 'S': 1748, 'b': 129, 'h': 417, ' ': 82, 'g': 260, 'L': 1120, 'M': 536, 'D': 548, 'd': 367, '$': 57, '5': 100, '4': 53, '.': 132, 'w': 136, 'B': 468, '1': 228, '7': 60, '8': 51, 'V': 224, 'y': 231, 'K': 253, '!': 65, '9': 76, 'z': 14, ';': 3, '#': 24, 'j': 19, "'": 70, 'J': 100, ':': 24, 'x': 38, '%': 42, '/': 29, 'q': 3, 'Q': 28, '(': 7, ')': 5, '\\': 8, '"': 8, '´': 3, 'Z': 36, '&': 15, 'É': 2, '@': 9, '=': 1, '+': 2, 'é': 1}

上方代码中,lbl_cnt_map 为字符出现次数的统计字典,后面还会用于建立字符及其id映射关系。从数据集统计结果来看,测试集含有训练集没有出现过的字符,例如测试集中包含1个’é’未曾在训练集出现。这种情况数量不多,应该问题不大,所以此处未对数据集进行额外处理(但是有意识的进行这种训练集和测试集是否存在diff的检查是必要的)。

1.2char和id的映射字典构建

# 构造label中 字符--id 之间的映射
print("构造label中 字符--id之间的映射:")lbl2id_map = dict()
lbl2id_map['☯'] = 0 # padding标识符
lbl2id_map['■'] = 1 # 句子起始符
lbl2id_map['□'] = 2 # 句子结束符
#生成其余字符的id映射关系
cur_id = 3
for lbl in lbl_cnt_map.keys():lbl2id_map[lbl] = cur_idcur_id += 1#保存 字符--id 之间的映射 到txt文件with open(lbl2id_map_path, 'w', encoding='utf-8') as writer:for lbl in lbl2id_map.keys():cur_id = lbl2id_map[lbl]print (lbl, cur_id)line = lbl + '\t' + str(cur_id) + '\n'writer.write(line)
构造label中 字符--id之间的映射:
☯ 0
■ 1
□ 2
[ 3
0 4
6 5
] 6
2 7
- 8
3 9
C 10
a 11
r 12
p 13
k 14
E 15
X 16
I 17
T 18
R 19
f 20
u 21
s 22
i 23
o 24
n 25
l 26
e 27
v 28
A 29
U 30
O 31
N 32
c 33
t 34
m 35
W 36
H 37
Y 38
P 39
F 40
G 41
? 42
S 43
b 44
h 4546
g 47
L 48
M 49
D 50
d 51
$ 52
5 53
4 54
. 55
w 56
B 57
1 58
7 59
8 60
V 61
y 62
K 63
! 64
9 65
z 66
; 67
# 68
j 69
' 70
J 71
: 72
x 73
% 74
/ 75
q 76
Q 77
( 78
) 79
\ 80
" 81
´ 82
Z 83
& 84
É 85
@ 86
= 87
+ 88
é 89
def load_lbl2id_map(lbl2id_map_path):"""读取 字符-id 映射关系记录的txt文件,并返回 lbl->id 和 id->lbl 映射字典lbl2id_map_path : 字符-id 映射关系记录的txt文件路径"""lbl2id_map = dict()id2lbl_map = dict()with open(lbl2id_map_path, 'r',encoding = 'utf-8') as reader:for line in reader:items = line.rstrip().split('\t')label = items[0]cur_id = int(items[1])lbl2id_map[label] = cur_idid2lbl_map[cur_id] = labelreturn lbl2id_map, id2lbl_map

1.3数据集图像尺寸分析

# 分析数据集图片尺寸
print("分析数据集图片尺寸:")
# 初始化参数
min_h = 1e10
min_w = 1e10
max_h = -1
max_w = -1
min_ratio = 1e10
max_ratio = 0
# 遍历数据集计算尺寸信息
for img_name in os.listdir(train_img_dir):img_path = os.path.join(train_img_dir,img_name)img = cv2.imread(img_path)h, w = img.shape[:2]ratio = w / h #高宽比min_h = min_h if min_h <= h else h # 最小图片高度max_h = max_h if max_h >= h else h # 最大图片高度min_w = min_w if min_w <= w else w # 最小图片宽度max_w = max_w if max_w >= w else w # 最大图片宽度min_ratio = min_ratio if min_ratio <= ratio else ratio # 最小图片高宽比max_ratio = max_ratio if max_ratio >= ratio else ratio # 最大图片高宽比# 输出信息
print('min_h:', min_h)
print('max_h:', max_h)
print('min_w:', min_w)
print('max_w:', max_w)
print('min_ratio:', min_ratio)
print('max_ratio:', max_ratio)
分析数据集图片尺寸:
min_h: 9
max_h: 295
min_w: 16
max_w: 628
min_ratio: 0.6666666666666666
max_ratio: 8.619047619047619

2.将transformer引入OCR

2.1准备工作

import os
import time
import copy
import numpy as np
from PIL import Image
# torch相关包
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms
# 导入工具类包
# from analysis_recognition_dataset import load_lbl2id_map, statistics_max_len_label
from transformer import *
from train_utils import *
base_data_dir = './ICDAR_2015/' # 数据集根目录,请将数据下载到此位置
device = torch.device('cuda') # 'cpu'或者'cuda'
nrof_epochs = 1500 # 迭代次数,1500,根据需求进行修正
batch_size = 16 # 批量大小,32,根据需求进行修正
model_save_path = './log/ex1_ocr_model.pth'
# 读取label-id映射关系记录文件
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')
lbl2id_map, id2lbl_map = load_lbl2id_map(lbl2id_map_path)
# 统计数据集中出现的所有的label中包含字符最多的有多少字符,数据集构造gt(ground truth)信息需要用到
train_lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
valid_lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
train_max_label_len = statistics_max_len_label(train_lbl_path)
valid_max_label_len = statistics_max_len_label(valid_lbl_path)
# 数据集中字符数最多的一个case作为制作的gt的sequence_len
sequence_len = max(train_max_label_len, valid_max_label_len)

2.2数据集创建

class Recognition_Dataset(object):def __init__(self, dataset_root_dir, lbl2id_map, sequence_len, max_ratio, phase='train', pad=0):if phase == 'train':self.img_dir = os.path.join(base_data_dir, 'train')self.lbl_path = os.path.join(base_data_dir, 'train_gt.txt')else:self.img_dir = os.path.join(base_data_dir, 'valid')self.lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')self.lbl2id_map = lbl2id_mapself.pad = pad   # padding标识符的id,默认0self.sequence_len = sequence_len    # 序列长度self.max_ratio = max_ratio * 3      # 将宽拉长3倍self.imgs_list = []self.lbls_list = []with open(self.lbl_path, 'r',encoding = 'utf-8') as reader:for line in reader:items = line.rstrip().split(',')img_name = items[0]lbl_str = items[1].strip()[1:-1]self.imgs_list.append(img_name)self.lbls_list.append(lbl_str)# 定义随机颜色变换self.color_trans = transforms.ColorJitter(0.1, 0.1, 0.1)# 定义 Normalizeself.trans_Normalize = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),]) def __getitem__(self, index):""" 获取对应index的图像和ground truth label,并视情况进行数据增强"""img_name = self.imgs_list[index]img_path = os.path.join(self.img_dir, img_name)lbl_str = self.lbls_list[index]# ----------------# 图片预处理# ----------------# load imageimg = Image.open(img_path).convert('RGB')# 对图片进行大致等比例的缩放# 将高缩放到32,宽大致等比例缩放,但要被32整除w, h = img.sizeratio = round((w / h) * 3)   # 将宽拉长3倍,然后四舍五入if ratio == 0:ratio = 1if ratio > self.max_ratio:ratio = self.max_ratioh_new = 32w_new = h_new * ratioimg_resize = img.resize((w_new, h_new), Image.BILINEAR)# 对图片右半边进行padding,使得宽/高比例固定=self.max_ratioimg_padd = Image.new('RGB', (32*self.max_ratio, 32), (0,0,0))img_padd.paste(img_resize, (0, 0))# 随机颜色变换img_input = self.color_trans(img_padd)# Normalizeimg_input = self.trans_Normalize(img_input)# ----------------# label处理# ----------------# 构造encoder的maskencode_mask = [1] * ratio + [0] * (self.max_ratio - ratio)encode_mask = torch.tensor(encode_mask)encode_mask = (encode_mask != 0).unsqueeze(0)# 构造ground truth labelgt = []gt.append(1)    # 先添加句子起始符for lbl in lbl_str:gt.append(self.lbl2id_map[lbl])gt.append(2)for i in range(len(lbl_str), self.sequence_len):   # 除去起始符终止符,lbl长度为sequence_len,剩下的paddinggt.append(0)# 截断为预设的最大序列长度gt = gt[:self.sequence_len]# decoder的输入decode_in = gt[:-1]decode_in = torch.tensor(decode_in)# decoder的输出decode_out = gt[1:]decode_out = torch.tensor(decode_out)# decoder的maskdecode_mask = self.make_std_mask(decode_in, self.pad)# 有效tokens数ntokens = (decode_out != self.pad).data.sum()return img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens @staticmethoddef make_std_mask(tgt, pad):"""Create a mask to hide padding and future words.padd 和 future words 均在mask中用0表示"""tgt_mask = (tgt != pad)tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))tgt_mask = tgt_mask.squeeze(0)   # subsequent返回值的shape是(1, N, N)return tgt_maskdef __len__(self):return len(self.imgs_list)

以上是构建Dataset的所有细节,进而我们可以构建出DataLoader供训练使用

# 构造 dataloader
max_ratio = 8    # 图片预处理时 宽/高的最大值,不超过就保比例resize,超过会强行压缩
train_dataset = Recognition_Dataset(base_data_dir, lbl2id_map, sequence_len, max_ratio, 'train', pad=0)
valid_dataset = Recognition_Dataset(base_data_dir, lbl2id_map, sequence_len, max_ratio, 'valid', pad=0)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=0)
valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size,shuffle=False,num_workers=0)

3.模型构建

代码通过 make_ocr_modelOCR_EncoderDecoder 类完成模型结构搭建。

make_ocr_model 这个函数看起,该函数首先调用了pytorch中预训练的Resnet-18作为backbone以提取图像特征,此处也可以根据自己需要调整为其他的网络,但需要重点关注的是网络的下采样倍数,以及最后一层特征图的channel_num,相关模块的参数需要同步调整。之后调用了OCR_EncoderDecoder 类完成transformer的搭建。最后对模型参数进行初始化。

OCR_EncoderDecoder 类中,该类相当于是一个transformer各基础组件的拼装线,包括 encoder和 decoder 等,其初始参数是已存在的基本组件,其基本组件代码都在transformer.py文件中,本文不过多赘述。

来回顾一下,图片经过backbone后,如何构造为Transformer的输入:
图片经过backbone后将输出一个维度为 [batch_size, 512, 1, 24] 的特征图,在不关注batch_size的前提下,每一张图像都会得到如下所示具有512个通道的1×24的特征图,如图中红色框标注所示,将不同通道相同位置的特征值拼接组成一个新的向量,并作为一个时间步的输入,此时变构造出了维度为[batch_size, 24, 512] 的输入,满足Transformer的输入要求。
下面看完整的构造模型部分的代码:

# Model Architecture
class OCR_EncoderDecoder(nn.Module):"""A standard Encoder-Decoder architecture. Base for this and many other models."""def __init__(self, encoder, decoder, src_embed, src_position, tgt_embed, generator):super(OCR_EncoderDecoder, self).__init__()self.encoder = encoderself.decoder = decoderself.src_embed = src_embed    # input embedding module(input embedding + positional encode)self.src_position = src_positionself.tgt_embed = tgt_embed    # ouput embedding moduleself.generator = generator    # output generation moduledef forward(self, src, tgt, src_mask, tgt_mask):"Take in and process masked src and target sequences."memory = self.encode(src, src_mask)res = self.decode(memory, src_mask, tgt, tgt_mask)return resdef encode(self, src, src_mask):# feature extractsrc_embedds = self.src_embed(src)# 将src_embedds由shape(bs, model_dim, 1, max_ratio) 处理为transformer期望的输入shape(bs, 时间步, model_dim)src_embedds = src_embedds.squeeze(-2)src_embedds = src_embedds.permute(0, 2, 1)# position encodesrc_embedds = self.src_position(src_embedds)return self.encoder(src_embedds, src_mask)def decode(self, memory, src_mask, tgt, tgt_mask):target_embedds = self.tgt_embed(tgt)return self.decoder(target_embedds, memory, src_mask, tgt_mask)def make_ocr_model(tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):"""构建模型params:tgt_vocab: 输出的词典大小(82)N: 编码器和解码器堆叠基础模块的个数d_model: 模型中embedding的size,默认512d_ff: FeedForward Layer层中embedding的size,默认2048h: MultiHeadAttention中多头的个数,必须被d_model整除dropout:"""c = copy.deepcopybackbone = models.resnet18(pretrained=True)backbone = nn.Sequential(*list(backbone.children())[:-2])    # 去掉最后两个层 (global average pooling and fc layer)attn = MultiHeadedAttention(h, d_model)ff = PositionwiseFeedForward(d_model, d_ff, dropout)position = PositionalEncoding(d_model, dropout)model = OCR_EncoderDecoder(Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),backbone,c(position),nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),Generator(d_model, tgt_vocab))# Initialize parameters with Glorot / fan_avg.for child in model.children():if child is backbone:# 将backbone的权重设为不计算梯度for param in child.parameters():param.requires_grad = False# 预训练好的backbone不进行随机初始化,其余模块进行随机初始化continuefor p in child.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)return model

构建transformer模型:

# build model
# use transformer as ocr recognize model
tgt_vocab = len(lbl2id_map.keys())
d_model = 512
ocr_model = make_ocr_model(tgt_vocab, N=5, d_model=d_model, d_ff=2048, h=8, dropout=0.1)
ocr_model.to(device)

4.模型训练

模型训练之前,还需要定义模型评判准则、迭代优化器等。本实验在训练时,使用了标签平滑(label smoothing)、网络训练热身(warmup)等策略,以上策略的调用函数均在train_utils.py 文件中,此处不涉及以上两种方法的原理及代码实现。

label smoothing可以将原始的硬标签转化为软标签,从而增加模型的容错率,提升模型泛化能力。代码中 LabelSmoothing() 函数实现了label smoothing,同时内部使用了相对熵函数计算了预测值与真实值之间的损失。

warmup策略能够有效控制模型训练过程中的优化器学习率,自动化的实现模型学习率由小增大再逐渐下降的控制,帮助模型在训练时更加稳定,实现损失的快速收敛。代码中 NoamOpt() 函数实现了warmup控制,采用的Adam优化器,实现学习率随迭代次数的自动调整。

# train prepare
criterion = LabelSmoothing(size=tgt_vocab, padding_idx=0, smoothing=0.0)
#optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ocr_model.parameters()),
#                            lr=0,
#                            betas=(0.9, 0.98),
#                            eps=1e-9)
optimizer = torch.optim.Adam(ocr_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
model_opt = NoamOpt(d_model, 1, 400, optimizer)

SimpleLossCompute() 类实现了transformer输出结果的loss计算。在使用该类直接计算时,类需要接收(x, y, norm) 三个参数, x 为decoder输出的结果, y 为标签数据, norm 为loss的归一化系数,用batch中所有有效token数即可。由此可见,此处才正完成transformer所有网络的构建,实现数据计算
流的流通。

class SimpleLossCompute:"A simple loss compute and train function."def __init__(self, generator, criterion, opt=None):self.generator = generatorself.criterion = criterionself.opt = optdef __call__(self, x, y, norm):"""norm: loss的归一化系数,用batch中所有有效token数即可"""x = self.generator(x)x_ = x.contiguous().view(-1, x.size(-1))y_ = y.contiguous().view(-1)loss = self.criterion(x_, y_)loss /= normloss.backward()if self.opt is not None:self.opt.step()self.opt.optimizer.zero_grad()#return loss.data[0] * norm  # TODOreturn loss.item() * norm

模型训练过程的代码如下所示,每训练10个epoch便进行一次验证,单个epoch的计算过程封装在run_epoch() 函数中。

def run_epoch(data_loader, model, loss_compute, device=None):"Standard Training and Logging Function"start = time.time()total_tokens = 0total_loss = 0tokens = 0for i, batch in enumerate(data_loader):#if device == "cuda":#    batch.to_device(device)img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batchimg_input = img_input.to(device)                        encode_mask = encode_mask.to(device)                                decode_in = decode_in.to(device)                                decode_out = decode_out.to(device)                    decode_mask = decode_mask.to(device)ntokens = torch.sum(ntokens).to(device)out = model.forward(img_input, decode_in, encode_mask, decode_mask)loss = loss_compute(out, decode_out, ntokens)total_loss += losstotal_tokens += ntokenstokens += ntokensif i % 50 == 1:elapsed = time.time() - startprint("Epoch Step: %d Loss: %f Tokens per Sec: %f" %(i, loss / ntokens, tokens / elapsed))start = time.time()tokens = 0return total_loss / total_tokens
for epoch in range(nrof_epochs):print(f"\nepoch{epoch}")print("train...")ocr_model.train()loss_compute = SimpleLossCompute(ocr_model.generator, criterion, model_opt)train_mean_loss = run_epoch(train_loader, ocr_model, loss_compute, device)if epoch % 10 == 0:print("valid...")ocr_model.eval()valid_loss_compute = SimpleLossCompute(ocr_model.generator, criterion, None)valid_mean_loss = run_epoch(valid_loader, ocr_model, valid_loss_compute, device)print(f"valid loss:{valid_mean_loss}")
epoch 0
train...
Epoch Step: 1 Loss: 4.756953 Tokens per Sec: 74.231010
Epoch Step: 51 Loss: 3.345229 Tokens per Sec: 227.936249
Epoch Step: 101 Loss: 3.164185 Tokens per Sec: 217.443558
Epoch Step: 151 Loss: 2.884049 Tokens per Sec: 198.306046
Epoch Step: 201 Loss: 2.918671 Tokens per Sec: 204.400925
Epoch Step: 251 Loss: 3.152167 Tokens per Sec: 205.873077
valid...
Epoch Step: 1 Loss: 2.701383 Tokens per Sec: 275.244385
Epoch Step: 51 Loss: 2.951396 Tokens per Sec: 240.986679
Epoch Step: 101 Loss: 2.714810 Tokens per Sec: 261.232330
valid loss: 2.839085102081299epoch 1
train...
Epoch Step: 1 Loss: 3.549314 Tokens per Sec: 193.934494
Epoch Step: 51 Loss: 2.953091 Tokens per Sec: 198.670242
Epoch Step: 101 Loss: 2.828863 Tokens per Sec: 214.964783
Epoch Step: 151 Loss: 2.756577 Tokens per Sec: 208.429001

5.贪心解码

我们使用最简单的贪心解码直接进行OCR结果预测。因为模型每一次只会产生一个输出,我们选择输出的概率分布中的最高概率对应的字符为本次预测的结果,然后预测下一个字符,这就是所谓的贪心解码,见代码中 greedy_decode() 函数。
实验中分别将每一张图像作为模型的输入,逐张进行贪心解码统计正确率,并最终给出了训练集和验证集各自的预测准确率。

# 训练结束,使用贪心的解码方式推理训练集和验证集,统计正确率
ocr_model.eval()
print("\n------------------------------------------------")
print("greedy decode trainset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(train_loader):img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batchimg_input = img_input.to(device)                        encode_mask = encode_mask.to(device)                                bs = img_input.shape[0]for i in range(bs):cur_img_input = img_input[i].unsqueeze(0)cur_encode_mask = encode_mask[i].unsqueeze(0)cur_decode_out = decode_out[i]pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)pred_result = pred_result.cpu()is_correct = judge_is_correct(pred_result, cur_decode_out)total_correct_num += is_correcttotal_img_num += 1if not is_correct:# 预测错误的case进行打印print("----")print(cur_decode_out)print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of trainset:{total_correct_rate}%")print("\n------------------------------------------------")
print("greedy decode validset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(valid_loader):img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batchimg_input = img_input.to(device)                        encode_mask = encode_mask.to(device)                                bs = img_input.shape[0]for i in range(bs):cur_img_input = img_input[i].unsqueeze(0)cur_encode_mask = encode_mask[i].unsqueeze(0)cur_decode_out = decode_out[i]pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)pred_result = pred_result.cpu()is_correct = judge_is_correct(pred_result, cur_decode_out)total_correct_num += is_correcttotal_img_num += 1if not is_correct:# 预测错误的case进行打印print("----")print(cur_decode_out)print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of validset:{total_correct_rate}%")

greedy_decode() 函数实现如下:

# greedy decode
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):memory = model.encode(src, src_mask)# ys代表目前已生成的序列,最初为仅包含一个起始符的序列,不断将预测结果追加到序列最后ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data).long()for i in range(max_len-1):out = model.decode(memory, src_mask, Variable(ys), Variable(subsequent_mask(ys.size(1)).type_as(src.data)))prob = model.generator(out[:, -1])_, next_word = torch.max(prob, dim = 1)next_word = next_word.data[0]next_word = torch.ones(1, 1).type_as(src.data).fill_(next_word).long()ys = torch.cat([ys, next_word], dim=1)next_word = int(next_word)if next_word == end_symbol:break#ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)ys = ys[0, 1:]return ysdef judge_is_correct(pred, label):# 判断模型预测结果和label是否一致pred_len = pred.shape[0]label = label[:pred_len]is_correct = 1 if label.equal(pred) else 0return is_correct
Epoch Step: 1 Loss: 5.315293 Tokens per Sec: 2073.354492
valid...
Epoch Step: 1 Loss: 3.870697 Tokens per Sec: 2173.835449
valid loss: 3.8293662071228027epoch 1
train...
Epoch Step: 1 Loss: 3.892932 Tokens per Sec: 2160.098633epoch 2
train...
Epoch Step: 1 Loss: 3.594534 Tokens per Sec: 2163.552490tensor([56, 56, 56, 55, 62, 56, 47, 12, 24, 21, 13, 55, 33, 24, 35, 55, 22, 47,2,  0,  0,  0])
tensor([56, 56, 56, 55, 62, 56, 47, 12, 24, 21, 13, 55, 33, 24, 35, 55, 33, 24,34, 55])tensor([15, 16, 10, 15, 39, 18,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,0,  0,  0,  0])
tensor([18, 31, 39,  2])total correct rate of validset: 95.78313253012048%

6.总结

本文首先介绍了所使用的ICDAR2015中的一个单词识别任务数据集,然后对数据的特点进行了简单分析,并构建了识别用的字符映射关系表。之后,重点介绍了将transformer引入来解决OCR任务的动机与思路,并结合代码详细介绍了细节,最后大致过了一些训练相关的逻辑和代码。

【深度学习】使用transformer实现OCR字符识别相关推荐

  1. 【深度学习】Transformer在语义分割上的应用探索

    [深度学习]Transformer在语义分割上的应用探索 文章目录 1 Segmenter 2 Swin-Unet:Unet形状的纯Transformer的医学图像分割 3 复旦大学提出SETR:基于 ...

  2. 用Transformer实现OCR字符识别!

    Datawhale干货 作者:安晟.袁明坤,Datawhale成员 在CV领域中,transformer除了分类还能做什么?本文将采用一个单词识别任务数据集,讲解如何使用transformer实现一个 ...

  3. 基于深度学习OpenCV与python进行字符识别

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 当我们在处理图像数据集时,总是会想有没有什么办法以简单的文本格式检 ...

  4. 【深度学习】Transformer 向轻量型迈进!微软与中科院提出两路并行的 Mobile-Former...

    作者丨happy 编辑丨极市平台 导读 本文创造性的将MobileNet与Transformer进行了两路并行设计,穿插着全局与特征的双向融合,同时利用卷积与Transformer两者的优势达到&qu ...

  5. 【深度学习】Transformer长大了,它的兄弟姐妹们呢?(含Transformers超细节知识点)...

    最近复旦放出了一篇各种Transformer的变体的综述(重心放在对Transformer结构(模块级别和架构级别)改良模型的介绍),打算在空闲时间把这篇文章梳理一下: 知乎:https://zhua ...

  6. 【深度学习】transformer 真的快要取代计算机视觉中的 CNN 吗?

    我相信你肯定已经在自然语言领域中听说过 transformer 这种结构,因为它在 2020 年的 GPT3 上引起了巨大轰动.Transformer 不仅仅可以用于NLP,在许多其他领域表现依然非常 ...

  7. opencv threshold_基于深度学习OpenCV与python进行字符识别

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 当我们在处理图像数据集时,总是会想有没有什么办法以简单的文本格式检 ...

  8. 【深度学习】Transformer温故知新

    这是之前学习paddle时候的笔记,对Transformer框架进行了拆解,附图解和代码,希望对大家有帮助  写在前面 最近在学习paddle相关内容,质量比较高的参考资料好像就paddle官方文档[ ...

  9. 李宏毅《深度学习》- Transformer

    一.Seq2seq 1. 简介 Transformer 就是一个 Seq2seq (Sequence-to-sequence) 的模型 输入一个序列,输出长度由模型决定.例如语音识别,输入的语音信号就 ...

最新文章

  1. vs工程移植报错:缺少MSVCP140D.dll ,CONCRT140D.dll ucrtbased.dll vcruntime140d.dll错误。
  2. 【Python】嫦娥探月数据(PDS)处理与可视化
  3. 联机分析的列式数据库 clickHouse
  4. 计算机网络项目——最小网元设计(前情提要和项目概述)
  5. React 深度学习:ReactFiberRoot
  6. 全网首发:JProfiler11运行时找不到库的解决办法
  7. php 继承 父类使用子类,PHP父类调用子类方法实例
  8. 肖忠付武汉大学计算机学院,丁立新(武汉大学计算机学院教授)_百度百科
  9. 实现python源代码加密
  10. 浏览器兼容性测试及常见问题
  11. edvac是商用计算机吗,EDVAC(eniac与edvac的区别)
  12. 四百左右音质好的蓝牙耳机有哪些?2023公认音质最好的蓝牙耳机排行
  13. 图像增强(拉普拉斯锐化增强)
  14. 如何还原sqlserver数据库或还原bak文件
  15. 二进制老鼠毒药c语言,老鼠试药  二进制问题
  16. android导航地图,地图导航-Android平台-开发指南-高德地图车机版 | 高德地图API
  17. 学建模的快速方法【快捷键】
  18. (附源码)计算机毕业设计ssm超市商品管理系统
  19. ICLR 2017精选论文
  20. 【Hexo搭建个人博客】(五)第三方主题(Next)的基本配置

热门文章

  1. 2022年双十一蓝牙耳机选哪款?盘点学生平价蓝牙耳机推荐
  2. python论文摘要_Python实现提取文章摘要的方法
  3. java 从已知日期计算干支纪日_干支纪日-干支纪日是如何计算的如何确定某一天的干支顺序? 爱问知识人...
  4. 亲情在前面幸福就在后面
  5. 扫雷--优化版实现(可以自动展开、标记雷、取消标记,并加入了标记出所有的雷直接获胜、自动清屏)
  6. 6.3数据粒度的转换
  7. 过去一年中国智慧物流行业发展得如何?十分钟让你知晓!
  8. 《机器学习基石》作业一
  9. jQuery表单控件操作
  10. 2018.10.23 第2周的第1次小组讨论