项目信息Program Info

项目要求

基于MindSpore的实现在线手写汉字识别,主要包括手写汉字检测和手写汉字识别,能较准确的对标准字体的手写文字进行识别,识别后通过人工干预对文本进行适当修正。需要有一定的创新特性,代码达到合入社区标准及规范。

项目方案

在自然场景下进行手写汉字识别主要分为两个步骤:手写汉字检测和手写汉字识别。如下图所示,其中,手写汉字检测主要目标是从自然场景的输入图像中寻找到手写汉字区域,并将手写汉字区域从原始图像中分离出来;手写汉字识别的主要目标是从分离出来的图像中准确地识别出该手写汉字的含义。

对于手写汉字检测,考虑采用CTPN算法,CTPN是在ECCV 2016中论文 Detecting Text in Natural Image with Connectionist Text Proposal Network (https://arxiv.org/abs/1609.03605)

中提出的一种文字检测算法。CTPN是在Faster RCNN的基础上结合CNN与LSTM深度网络,能有效的检测出复杂场景的横向分布的文字。对于手写汉字识别考虑使用CNN+RNN+CTC(CRNN+CTC)方法进行识别。CNN用于提取图像特征,RNN使用的是双向的LSTM网络(BiLSTM),用于在卷积特征的基础上继续提取文字序列特征。使用CTCLoss可以解决输出和label长度不一致的问 题,而不用手动去严格对齐。

进度安排

按照时间进度安排,圆满完成任务,通过CTPN网络实现文本检测,CRNN+CTCLoss实现文本识别,并将两者结合, 实现端到端的汉字识别。最后并将代码合并到社区。

准备工作

由于之前深度学习框架只接触Pytorch和Tensorflow,而项目需要使用MindSpore进行搭建,因此我首先对其进行了解。

MindSpore简介

MindSpore是一种适用于端边云场景的新型开源深度学习训练/推理框架。MindSpore提供了友好的设计和高效的执行,旨在提升数据科学家和算法工程师的开发体验,并为Ascend AI处理器提供原生支持,以及软硬件协同优化。具有易开发、高效执行、全场景覆盖三大特性,其中易开发表现为API友好、调试难度低,高效执行包括计算效率、数据预处理效率和分布式训练效率,全场景则指框架同时支持云、边缘以及端侧场景。

同时,MindSpore作为全球AI开源社区,致力于进一步开发和丰富AI软硬件应用生态。

此外对于使用其他深度学习框架的学习者,官方文档中也给出了Pytorch和Tensorflow与Mindspore的算子映射表

(https://www.mindspore.cn/docs/migration_guide/zh-CN/master/api_mapping.html ),

我们可以极其便捷的进行算法迁移。

数据集处理

由于我们任务分为文本检测和文本识别两个部分,因此我们需要数据集必须能同时满足两种任务需求,经过考察后我们选择由中科院自动化所模式识别国家重点实验室搭建的

(CASIA-HWDB)汉字识别数据集

(http://www.nlpr.ia.ac.cn/databases/handwriting/Home.html)手写样本由1020名作者在纸上书写,主要包括独立的字符和连续汉字。离线数据集由6个子数据集组成,3个为手写的独立字符(DB1.0– 1.2),3个为手写汉字(DB2.0–2.2),独立字符中包含3.9M个样本,分为7356类,其中有7185个汉字和171个符号,手写文本共有5090页和1.35M个汉字。手写汉字样本如下图所示。

脱机文本的数据格式为 .dgrl需要对该文本进行解析,转换为训练需要的图片格式和label格式。此外,由于该数据集保存的是每一行的文本图片,为了进行文本检测任务我们需要将每一行拼接成一页的文本图片。

因此数据集处理过程可以分为两步:转换数据格式获得文本行数据,拼接文本行获得文本页数据

转换数据格式

.dgrl按照如下图所示进行存储,每一张图对应一个DGRL文件,大部分内容都有固定的长度,部分内容长度不固定 但是也能通过其他数据推导出来,我们可以通过访问文件特定位置的数据得到我们需要的内容:行文本标注,行图像。获取到重要的文本和图像信息即可。

其中通过文件头部分,可以得到文件头部长度和单个字符的长度,因为文件头部中任意长度的信息,可以通过文件头   部长度直接跳转到图像信息部分,通过单个字符长度可以读取文本信息。图像信息中,可以得到图像的高度、宽度和   行数量,均是固定长度。之后在行文本信息中获取行内字符数量,行文本标注,行图像高度、宽度和图像像素。

使用二进制方式打开文件进读取.dgrl 文件

f = open(dgrl, 'rb')

之后使用numpy进行依次读取,注意是一个一个Byte依次读取,需要指定读取的格式和数量

import numpy as np
np.fromfile(f, dtype='uint8', count=4)

一般 dtype 都选择 uint8,count 需要根据上图结构中的长度 Length 做相应变化。

要注意的一个地方是:行文本标注读取出来以后,是一个 int 列表,要把它还原成汉字,一个汉字占用两个字节(具体由 code_length 决定),使用 struct 将其还原:

struct.pack('I', i).decode('gbk', 'ignore')[0]

上面的 i 就是提取出来的汉字编码,解码格式为 gbk,有些行文本会有空格,解码可能会出错,使用 ignore 忽略。

得到的坐标为每个矩形框的坐标。保存文件为 .jpg的图像格式

详细代码如下所示

import struct
import os
import cv2 as cv
import numpy as np
def read_from_dgrl(dgrl):if not os.path.exists(dgrl):print('DGRL not exis!')returndir_name,base_name = os.path.split(dgrl)label_dir = dir_name+'_label'image_dir = dir_name+'_images'if not os.path.exists(label_dir):os.makedirs(label_dir)if not os.path.exists(image_dir):os.makedirs(image_dir)with open(dgrl, 'rb') as f:# 读取表头尺寸header_size = np.fromfile(f, dtype='uint8', count=4)header_size = sum([j<<(i*8) for i,j in enumerate(header_size)])# print(header_size)# 读取表头剩下内容,提取 code_lengthheader = np.fromfile(f, dtype='uint8', count=header_size-4)code_length = sum([j<<(i*8) for i,j in enumerate(header[-4:-2])])# print(code_length)# 读取图像尺寸信息,提取图像中行数量image_record = np.fromfile(f, dtype='uint8', count=12)height = sum([j<<(i*8) for i,j in enumerate(image_record[:4])])width = sum([j<<(i*8) for i,j in enumerate(image_record[4:8])])line_num = sum([j<<(i*8) for i,j in enumerate(image_record[8:])])print('图像尺寸:')print(height, width, line_num)# 读取每一行的信息for k in range(line_num):print(k+1)# 读取该行的字符数量char_num = np.fromfile(f, dtype='uint8', count=4)char_num = sum([j<<(i*8) for i,j in enumerate(char_num)])print('字符数量:', char_num)# 读取该行的标注信息label = np.fromfile(f, dtype='uint8', count=code_length*char_num)label = [label[i]<<(8*(i%code_length)) for i in range(code_length*char_num)]label = [sum(label[i*code_length:(i+1)*code_length]) for i in range(char_num)]label = [struct.pack('I', i).decode('gbk', 'ignore')[0] for i in label]print('合并前:', label)label = ''.join(label)label = ''.join(label.split(b'\x00'.decode()))  # 去掉不可见字符 \x00,这一步不加的话后面保存的内容会出现看不见的问题print('合并后:', label)# 读取该行的位置和尺寸pos_size = np.fromfile(f, dtype='uint8', count=16)y = sum([j<<(i*8) for i,j in enumerate(pos_size[:4])])x = sum([j<<(i*8) for i,j in enumerate(pos_size[4:8])])h = sum([j<<(i*8) for i,j in enumerate(pos_size[8:12])])w = sum([j<<(i*8) for i,j in enumerate(pos_size[12:])])# print(x, y, w, h)# 读取该行的图片bitmap = np.fromfile(f, dtype='uint8', count=h*w)bitmap = np.array(bitmap).reshape(h, w)# 保存信息label_file = os.path.join(label_dir, base_name.replace('.dgrl', '_'+str(k)+'.txt'))with open(label_file, 'w') as f1:f1.write(label)bitmap_file = os.path.join(image_dir, base_name.replace('.dgrl', '_'+str(k)+'.jpg'))cv.imwrite(bitmap_file, bitmap)

结果如下图所示

可以发现每张图仅为一行文本数据,这样无法进行文本识别,因此需要将文件进行拼接为整页的文本格式。

拼接文本行数据

根据上一步得到的数据可以发现,每一完整页的汉字前缀为 006-P16_*的形式,下划线后表示每行,如果需要拼接为整页只要将相同前缀按照顺序拼接即可。下面对于图片和label的拼接和生成进行分别说明。

首先对图片拼接进行说明。上一步得到的每个行图片height和width不同,在拼接时需要进行调整。对width而言,由于是从上到下拼接,width需要保持一致,因此,取每个行图片width的最大值,其他小于max_width的图片到扩充到最大值,均扩充为白色。对height而言,由于段首和段位的长度明显要小于段间的长度,如果都pad到行图片的前端或后端显然不合适,这时候做一个简单的判断,如果是开头就pad到行图片的前端,如果是结尾或段中就pad到行图片的后端。最后将pad成整页的图片在外围在pad上白色。

对于label进行生成,由行级别的bbox坐标和字符两个部分组成。先生成bbox的坐标再将每个行图片的label读取写入新的page level的label中。bbox 的坐标为每个矩形框四个点的坐标。最终生成的结果如下所示

手写汉字图片

标签

628,1000,2519,1000,2519,1085,628,1085,2006年8月,国际天文学联合会大会正式通过决议,将冥王星降级,

500,1085,2500,1085,2500,1175,500,1175,与其他类似的一些星体统一定义为“矮行星”。当时,天文学家认为冥

500,1175,2519,1175,2519,1269,500,1269,王星应该是矮行星中的“老大”。而最新的天文观测证实,冥王星的“老

500,1269,1010,1269,1010,1343,500,1343,大”头衔也将不保。

650,1343,2519,1343,2519,1445,650,1445,美国加利福尼亚理工学院天文学家迈克尔·布朗等人定于15日出版的

500,1445,2508,1445,2508,1557,500,1557,美国《科学》杂志上报告说,他们在研究矮行星厄里斯的卫星“迪丝诺美

500,1557,2516,1557,2516,1666,500,1666,亚”时,利用设在美国夏威夷的凯克大型望远镜和太空中的哈勃太空望

500,1666,2509,1666,2509,1766,500,1766,远镜,计算出了这颗卫星的运动轨迹,并借助这一信息,进一步计算得

500,1766,2513,1766,2513,1859,500,1859,到厄里斯的最新密度及轨道数据。结果发现,厄里斯的质量大约

500,1859,1720,1859,1720,1941,500,1941,比冥王星大27%,是目前已知最大的矮行星。

最终处理代码如下:

import numpy as np
import cv2
import os
from glob import glob
import re
from tqdm import tqdmdef get_char_nums(segments):nums = []chars = []for seg in segments:label_head = seg.split('.')[0]label_name = label_head + '.txt'with open(os.path.join(label_root,label_name), 'r', encoding='utf-8') as f:lines = f.readlines()nums.append(len(lines[0]))chars.append(lines[0])return nums, chars
def addZeros(s_):head, tail = s_.split('_')num = ''.join(re.findall(r'\d',tail))head_num = '0'*(4-len(num)) + numreturn head + '_' + head_num + '.jpg'def strsort(alist):alist.sort(key=lambda i:addZeros(i))return alist
def pad(img, headpad, padding):assert padding>=0if padding>0:logi_matrix = np.where(img > 255*0.95, np.ones_like(img), np.zeros_like(img))ids = np.where(np.sum(logi_matrix, 0) == img.shape[0])if ids[0].tolist() != []:pad_array = np.tile(img[:,ids[0].tolist()[-1],:], (1, padding)).reshape((img.shape[0],-1,3))else:pad_array = np.tile(np.ones_like(img[:, 0, :]) * 255, (1, padding)).reshape((img.shape[0], -1, 3))if headpad:return np.hstack((pad_array, img))else:return np.hstack((img, pad_array))else:return img
def pad_peripheral(img, pad_size):assert isinstance(pad_size,tuple)w, h = pad_sizeresult = cv2.copyMakeBorder(img, h, h, w, w, cv2.BORDER_CONSTANT, value=[255, 255, 255])return result
if __name__ == '__main__':label_roots = ['./labels']label_dets = ['./fulllabels']pages_roots = ['./images']pages_dets = ['./fullimages']for label_root, label_det, pages_root, pages_det in zip(label_roots, label_dets, pages_roots, pages_dets):os.makedirs(label_det, exist_ok=True)os.makedirs(pages_det, exist_ok=True)pages_for_set = os.listdir(pages_root)pages_set = set([pfs.split('_')[0] for pfs in pages_for_set])for ds in tqdm(pages_set):boxes = []pages = []seg_sorted = strsort([d for d in pages_for_set if ds in d])widths = [cv.imread(os.path.join(pages_root, d)).shape[1] for d in seg_sorted]heights = [cv.imread(os.path.join(pages_root, d)).shape[0] for d in seg_sorted]max_width = max(widths)seg_nums, chars = get_char_nums(seg_sorted)pad_size = (500, 1000)w, h = pad_sizelabel_name = ds + '.txt'with open(os.path.join(label_det, label_name), 'w') as f:for i, pg in enumerate(seg_sorted):headpad = True if i == 0 else True if seg_nums[i] - seg_nums[i - 1] > 5 else Falsepg_read = cv.imread(os.path.join(pages_root, pg))padding = max_width - pg_read.shape[1]page_new = pad(pg_read, headpad, padding)pages.append(page_new)if headpad:x1 = str(w + padding)x2 = str(w + max_width)y1 = str(h + sum(heights[:i + 1]) - heights[i])y2 = str(h + sum(heights[:i + 1]))box = np.array([int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)])else:x1 = str(w)x2 = str(w + max_width - padding)y1 = str(h + sum(heights[:i + 1]) - heights[i])y2 = str(h + sum(heights[:i + 1]))box = np.array([int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)])boxes.append(box.reshape((4, 2)))char = chars[i]f.writelines(x1 + ',' + y1 + ',' + x2 + ',' + y1 + ',' + x2 + ',' + y2 + ',' + x1 + ',' + y2 + ',' + char + '\n')pages_array = np.vstack(pages)pages_array = pad_peripheral(pages_array, pad_size)pages_name = ds + '.jpg'# cv.polylines(pages_array, [box.astype('int32') for box in boxes], True, (0, 0, 255))cv.imwrite(os.path.join(pages_det, pages_name), pages_array)

做完了以上准备工作,下面开始分别对手写汉字进行文本检测和文本识别的网络进行搭建和训练。

文本检测

对于手写汉字检测,考虑采用CTPN算法,CTPN是在ECCV 2016中论文Detecting Text in Natural Image with Connectionist Text Proposal Network

(https://arxiv.org/abs/1609.03605)中提出的一种文字检测算法。CTPN是在Faster RCNN的基础上结合CNN与LSTM深度网络,能有效的检测出复杂场景的横向分布的文字。CTPN算法只能检测出横向排列的文字,其结构与Faster R-CNN基本类似,但是加入了LSTM层,网络结构如下图所示。

图片来源:Detecting Text in Natural Image with Connectionist Text Proposal Network

(https://arxiv.org/abs/1609.03605)

代码

CTPN网络代码如下

class CTPN(nn.Cell):"""Define CTPN networkArgs:input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height forcaptcha images.batch_size(int): batch size of input data, default is 64hidden_size(int): the hidden size in LSTM layers, default is 512"""def __init__(self, config, is_training=True):super(CTPN, self).__init__()self.config = configself.is_training = is_trainingself.num_step = config.num_stepself.input_size = config.input_sizeself.batch_size = config.batch_sizeself.hidden_size = config.hidden_sizeself.vgg16_feature_extractor = VGG16FeatureExtraction()self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)self.reshape = P.Reshape()self.transpose = P.Transpose()self.cast = P.Cast()# rpn blockself.rpn_with_loss = RPN(config,self.batch_size,config.rpn_in_channels,config.rpn_feat_channels,config.num_anchors,config.rpn_cls_out_channels)self.anchor_generator = AnchorGenerator(config)self.featmap_size = config.feature_shapesself.anchor_list = self.get_anchors(self.featmap_size)self.proposal_generator_test = Proposal(config,config.test_batch_size,config.activate_num_classes,config.use_sigmoid_cls)self.proposal_generator_test.set_train_local(config, False)def construct(self, img_data, gt_bboxes, gt_labels, gt_valids, img_metas=None):x = self.vgg16_feature_extractor(img_data)x = self.conv(x)x = self.cast(x, mstype.float16)x = self.transpose(x, (0, 2, 1, 3))x = self.reshape(x, (-1, self.input_size, self.num_step))x = self.transpose(x, (2, 0, 1))x = self.rnn(x)rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,                                                                                       gt_valids)if self.training:return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_lossproposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)return proposal, proposal_maskdef get_anchors(self, featmap_size):anchors = self.anchor_generator.grid_anchors(featmap_size)return Tensor(anchors, mstype.float16)
class CTPN_Infer(nn.Cell):def __init__(self, config):super(CTPN_Infer, self).__init__()self.network = CTPN(config, is_training=False)self.network.set_train(False)def construct(self, img_data):output = self.network(img_data, None, None, None, None)return output

与一般目标检测框架不同,为了能够检测连续文本,加入LSTM结构。因为CNN学习的是感受野内的空间信息, LSTM学习的是序列特征。对于文本序列检测,显然既需要CNN抽象空间特征,也需要序列特征(毕竟文字是连续的)。

代码实现如下,代码在 ./ctpn.py

class BiLSTM(nn.Cell):"""Define a BiLSTM network which contains two LSTM layersArgs:input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height forcaptcha images.batch_size(int): batch size of input data, default is 64hidden_size(int): the hidden size in LSTM layers, default is 512"""def __init__(self, config, is_training=True):super(BiLSTM, self).__init__()self.is_training = is_trainingself.batch_size = config.batch_size * config.rnn_batch_sizeprint("batch size is {} ".format(self.batch_size))self.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_step = config.num_stepself.reshape = P.Reshape()self.cast = P.Cast()k = (1 / self.hidden_size) ** 0.5self.rnn1 = P.DynamicRNN(forget_bias=0.0)self.rnn_bw = P.DynamicRNN(forget_bias=0.0)self.w1 = Parameter(np.random.uniform(-k, k, \(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")self.w1_bw = Parameter(np.random.uniform(-k, k, \(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.reverse_seq = P.ReverseV2(axis=[0])self.concat = P.Concat()self.transpose = P.Transpose()self.concat1 = P.Concat(axis=2)self.dropout = nn.Dropout(0.7)self.use_dropout = config.use_dropoutself.reshape = P.Reshape()self.transpose = P.Transpose()def construct(self, x):if self.use_dropout:x = self.dropout(x)x = self.cast(x, mstype.float16)bw_x = self.reverse_seq(x)y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)y1_bw = self.reverse_seq(y1_bw)output = self.concat1((y1, y1_bw))return output

RPN与Faster-RCNN类似,便不再赘述

实验结果

最后finetune的loss结果如下

epoch: 100 step: 1467, rpn_loss: 0.02794, rpn_cls_loss: 0.01963, rpn_reg_loss: 0.01110

某一样本检测结果如图所示,能够通过检测框较为准确的框出文本

训练loss变化如下图所示

文本识别

对于手写汉字识别考虑使用CNN+RNN+CTC(CRNN+CTC)方法进行识别。CNN用于提取图像特征,RNN使用的是 双向的LSTM网络(BiLSTM),用于在卷积特征的基础上继续提取文字序列特征。使用CTCLoss可以解决输出和label 长度不一致的问题,而不用手动去严格对齐。

整个CRNN网络分为三个部分,网络结构如下图所示。

图片来源:CRNN文本识别论文

An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition

https://arxiv.org/pdf/1507.05717.pdf

代码

CRNN部分代码构建如下

"""crnn_ctc network define"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal
def _bn(channel):return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, moving_mean_init=0,moving_var_init=1)
class Conv(nn.Cell):def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, use_bn=False, pad_mode='same'):super(Conv, self).__init__()self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,padding=0, pad_mode=pad_mode, weight_init=TruncatedNormal(0.02))self.bn = _bn(out_channel)self.Relu = nn.ReLU()self.use_bn = use_bndef construct(self, x):out = self.conv(x)if self.use_bn:out = self.bn(out)out = self.Relu(out)return out
class VGG(nn.Cell):"""VGG Network structure"""def __init__(self, is_training=True):super(VGG, self).__init__()self.conv1 = Conv(3, 64, use_bn=True)self.conv2 = Conv(64, 128, use_bn=True)self.conv3 = Conv(128, 256, use_bn=True)self.conv4 = Conv(256, 256, use_bn=True)self.conv5 = Conv(256, 512, use_bn=True)self.conv6 = Conv(512, 512, use_bn=True)self.conv7 = Conv(512, 512, kernel_size=2, pad_mode='valid', use_bn=True)self.maxpool2d1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')self.maxpool2d2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), pad_mode='same')# self.maxpool2d2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(2, 1), pad_mode='same')self.bn1 = _bn(512)def construct(self, x):x = self.conv1(x)x = self.maxpool2d1(x)x = self.conv2(x)x = self.maxpool2d1(x)x = self.conv3(x)x = self.conv4(x)x = self.maxpool2d2(x)x = self.conv5(x)x = self.conv6(x)x = self.maxpool2d2(x)x = self.conv7(x)return x
class CRNN(nn.Cell):"""Define a CRNN network which contains Bidirectional LSTM layers and vgg layer.Args:input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height fortext images.batch_size(int): batch size of input data, default is 64hidden_size(int): the hidden size in LSTM layers, default is 512"""def __init__(self, config):super(CRNN, self).__init__()self.batch_size = config.batch_sizeself.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_classes = config.class_numself.reshape = P.Reshape()self.cast = P.Cast()k = (1 / self.hidden_size) ** 0.5self.rnn1 = P.DynamicRNN(forget_bias=0.0)self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)self.rnn2 = P.DynamicRNN(forget_bias=0.0)self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))self.w1 = Parameter(w1.astype(np.float32), name="w1")w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))self.w2 = Parameter(w2.astype(np.float32), name="w2")w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))self.w1_bw = Parameter(w1_bw.astype(np.float32), name="w1_bw")w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))self.w2_bw = Parameter(w2_bw.astype(np.float32), name="w2_bw")self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2")self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2_bw")self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))self.fc.to_float(mstype.float32)self.expand_dims = P.ExpandDims()self.concat = P.Concat()self.transpose = P.Transpose()self.squeeze = P.Squeeze(axis=0)self.vgg = VGG()self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)self.concat1 = P.Concat(axis=2)self.dropout = nn.Dropout(0.5)self.rnn_dropout = nn.Dropout(0.9)self.use_dropout = config.use_dropoutdef construct(self, x):x = self.vgg(x)shape1 = x.shapex = self.reshape(x, (self.batch_size, self.input_size, -1))x = self.transpose(x, (2, 0, 1))bw_x = self.reverse_seq1(x, self.seq_length)y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)y1_bw = self.reverse_seq2(y1_bw, self.seq_length)y1_out = self.concat1((y1, y1_bw))if self.use_dropout:y1_out = self.rnn_dropout(y1_out)y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)bw_y = self.reverse_seq3(y1_out, self.seq_length)y2_bw, _, _, _, _, _, _, _ = self.rnn2(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)y2_bw = self.reverse_seq4(y2_bw, self.seq_length)y2_out = self.concat1((y2, y2_bw))if self.use_dropout:y2_out = self.dropout(y2_out)output = ()for i in range(F.shape(y2_out)[0]):y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))y2_after_fc = self.expand_dims(y2_after_fc, 0)output += (y2_after_fc,)output = self.concat(output)return output
#         return output, shape1, x.shape, y1_out.shape, y2_out.shape, y2_after_fc.shape
def crnn(config, full_precision=False):"""Create a CRNN network with mixed_precision or full_precision"""net = CRNN(config)if not full_precision:net = net.to_float(mstype.float16)return net

由于使用CTCLoss,需要加入blank label,用于分隔文本字符,因此识别的文本类别数需要加1。实现代码如下:

'''
Date: 2021-09-05 14:53:34
LastEditors: xgy
LastEditTime: 2021-09-25 22:45:07
FilePath: \code\crnn_ctc\src\loss.py
'''"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as Pclass CTCLoss(_Loss):"""CTCLoss definitionArgs:max_sequence_length(int): max number of sequence length. For text images, the value is equal to image widthmax_label_length(int): max number of label length for each input.batch_size(int): batch size of input logits"""def __init__(self, max_sequence_length, max_label_length, batch_size):super(CTCLoss, self).__init__()self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),name="sequence_length")labels_indices = []for i in range(batch_size):for j in range(max_label_length):labels_indices.append([i, j])self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")self.reshape = P.Reshape()self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)def construct(self, logit, label):labels_values = self.reshape(label, (-1,))loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)return loss

实验结果

识别结果如图所示,能基本正确识别图像中的文本( :: 左边为label,右边为预测结果)

训练loss图像如下所示

评估准则分为两种,分别是字层次的精度和句子层次的精度,得到的精度为:

correct num: 8247 , total num: 10449
Accracy in word: 0.879359924968553
Accracy in sentence: 0.7892621303474017
result: {'CRNNAccuracy': 0.7892621303474017}

问题与解决

问题描述

在训练CRNN+CTCLoss框架进行文本识别时,发现多次出现Loss为0的情况,甚至在前100此迭代中都会出现为0的情 况,这显然是极其不合理的。

根据以往Debug经验和深度学习相关知识,整体解决思路主要分为3部分进行。

  • 查看官方文档、手册及论坛,检查对应函数和接口存在什么限制,检查自己编写代码过程中是否已经满足。

  • 查看Github和Google中别人是否存在类似问题,进行参考,看能否解决可能存在的问题。

  • 编写简单的案例进行尝试,考虑所有可能的情况,找出在什么条件下会出现类似的BUG。

  • 对自己代码由浅到深一步步进行检查,可以采用二分法逐步缩小问题范围。

一般只要按照如上思路进行,最终都能解决。

解决过程

我遵循以上思路,一步步排查问题所在。

查看官网手册中MindSporeCTCLoss

(https://www.mindspore.cn/doc/api_python/zh-CN/r1.2/mindspore/ops/mindspore.ops.CTCLoss.html?highlight=ctcloss#mindspore.ops.CTCLoss)

的说明,发现条件均满足。之后,搜索Google、Github和Mindspore论坛,发现并没有出现类似的情况。因此,我觉得自己先编写简单的案例,进行尝试。

由于数据过大,因此我先定义简单的张量进行尝试,测试什么情况下会出现loss为0的情况。

经过多组尝试,发现当 labels_values  异常时会导致loss取值为0,因为CTCLoss在merge_repeated=true的情况下,不可能出现1411这种情况,两个连续的11中间必定有一个blank label ,至少需要5个logits,例如14141才能有1411这种label,而例子中的输出max time为4,这种label是构造不出来的。

这种情况在手册中很难描述出来,我们必须深入理解CTCLoss的内涵才能发现。

找到了出现Loss为0的情况,则可以把问题定位在label 部分,因此我决定对代码的label进行深入检查。

我使用PYNative模型对代码进行调试,将batch_size设置为1,查看当Loss为0时,输入、输出和label各是什么,发现在label在句子中间部分出现了blank label的情况,但在实际情况是不可能的,说明是数据集label构造出了问题。

再检查数据集代码发现,字典集合中缺失了部分汉字,导致无法正确转为字符label,从而出错。

总结与反思

在实验过程中,会出现各种意想不到的问题,我们不应该惊慌失措,冷静下来,按照步骤一步步确定问题所在,最终一定能够成功Debug。其实很多时候的BUG都是由于自己平时的粗心和代码编写不规范导致的,只要自己养成良好的代码编写习惯,仔细分析问题,就能减少BUG出现频率。

致谢

这次活动促进了开源软件的发展和优秀开源软件社区建设,增加开源项目的活跃度,推进开源生态的发展;感谢开源之夏主办方为这次活动提供的平台与机会。大大提高了代码编写能力,真的受益匪浅!

MindSpore官方资料

GitHub : https://github.com/mindspore-ai/mindspore

Gitee : https : //gitee.com/mindspore/mindspore

官方QQ群 : 871543426

项目经验分享:基于昇思MindSpore实现手写汉字识别相关推荐

  1. 基于昇思MindSpore Quantum,实现量子虚时演化算法

    01.关于昇思MindSpore项目介绍 1.项目名称 基于昇思MindSpore Quantum,实现量子虚时演化算法 2.项目链接 https://summer-ospp.ac.cn/#/org/ ...

  2. 基于昇思MindSpore的同元软控AI系列工具箱正式发布,大幅度降低产品研发成本

    随着智能时代的到来,同元软控与华为携手合作,以昇思MindSpore为框架底座,打造了MWORKS AI工具箱,并于2023年1月8号正式对外发布.基于昇思MindSpore的MWORKS AI工具箱 ...

  3. 项目经验分享:基于昇思MindSpore,使用DFCNN和CTC损失函数的声学模型实现

    本期分享来自 MindSpore 社区的龙泳旭同学带来的项目经验:基于MindSpore,使用DFCNN和CTC损失函数的声学模型实现. 项目信息 项目名称 <基于MindSpore,使用DFC ...

  4. 致AI开发者,昇思MindSpore发来“成长”邀请

    撰文 / 张贺飞 编辑 / 沈洁 2020年,应届毕业的蒋倩成了一名算法工程师,因为工作的原因,蒋倩接触到了刚刚开源的昇思MindSpore. 和许多开发者一样,蒋倩对人工智能和开源社区充满了好奇心, ...

  5. 昇思MindSpore全场景AI框架 1.6版本,更高的开发效率,更好地服务开发者

    本文分享自华为云社区<昇思MindSpore全场景AI框架 1.6版本,更高的开发效率,更好地服务开发者>,作者: 技术火炬手. 全新的昇思MindSpore全场景AI框架1.6版本已发布 ...

  6. 智能基座昇腾高校行 | 昇思MindSpore携手清华大学共同培养新时代科技人才

    智能基座昇腾高校行 计算是人类永恒的需求,是智能世界的源动力,人工智能的发展需要数字化人才的助力.在新一轮科技革命背景下,人工智能扮演着重要角色,伴随着产业的快速升级,也对教育的人才培养提出了新要求. ...

  7. AI科学计算领域的再突破,昇思MindSpore做“基石”的决心有多强?

    过去的十多年,人工智能技术越来越深刻地影响了人类社会,越来越多成熟的人工智能产品逐渐渗透到每一个人的生活.就在大家享受着人工智能带来各种便利的同时,AI也不断影响着最前沿的科学研究领域.过去的数百年来 ...

  8. 昇思MindSpore AI框架在知名度与使用率市场份额上处于第一梯队

    2023年2月6日,行业研究机构Omdia(Informa tech集团旗下国际信息与通信技术研究机构)发布了<中国人工智能框架市场调研报告>,深入分析了中国人工智能框架市场的竞争格局,产 ...

  9. 如何加速大模型开发?技术方案拆解来了:昇思MindSpore技术一览

    随着人工智能爆火出圈,狂飙之势从22年底持续到23年初,与以往的技术突破不同的是,此次的大模型不仅被技术界关注,而且备受投资界.产业界和大众消费者的追捧,使它成为历史上最快月活过亿的现象级应用,继而引 ...

最新文章

  1. chrome 浏览器打开静态html 获取json文件失败 解决方法
  2. XP与Windows 7(Win7)等操作系统Ghost备份
  3. 关于Uri.Segments 属性的理解
  4. 禁止Win7系统自动安装驱动程序
  5. AttributeError: module ‘tensorflow‘ has no attribute ‘app‘
  6. 寻找不合群的数据(异常值)
  7. spring框架_一篇文章带你理解Spring框架
  8. 【数据结构与算法】之旋转图像的求解算法
  9. CSLA.Net 3.0.5 项目管理示例 业务集合基类(ProjectResources.cs,ProjectResource.cs)
  10. 模拟退火算法解决np_P和NP问题与解决方案| 演算法
  11. 前端学习(3213):setstate的一个使用
  12. 蓝牙基础知识进阶——Physical channel
  13. 城市公交网建设问题(信息学奥赛一本通-T1348)
  14. 嵌入式C语言面试题剖析100,嵌入式c语言面试题汇总超.docx
  15. 使用 Charles 对 Android 设备进行 Https 抓包
  16. 计算机室管理员考核细则,宿舍管理员量化考核细则
  17. 揭秘“菲住布渴”中运用的黑科技:除了check in、坐电梯、开门...全部刷脸之外,还有什么?
  18. 用户计算机脱域了如何处理,AD域计算机经常脱域
  19. Epic League 推出支持 Free to Earn 的 RPG 游戏 Dark Throne
  20. python turtle画猫_Turtle库画小猫咪

热门文章

  1. 使用map方式获取iris请求中的json请求数据
  2. 国内外差价悬殊,催火“代购一族”
  3. Linux是什么?大牛十年Linux心得文档给你答案
  4. HTML列表的简单使用以及在我们网页编程中的单位你了解多少??CSS中的字体样式你又了解多少,进来康康!!HTML、CSS(三)
  5. 全球虚拟运营商发展现状与探索
  6. Category In Objective-C
  7. NLP实战一 利用OpenAI Codex实现中文转python代码
  8. android高仿京东快报(垂直循环滚动新闻栏)
  9. 日志系统新贵 Loki,确实比笨重的ELK轻
  10. oracle sql 历史 监控,ORACLE 管理,SQL 篇--监控