深度学习实战项目:速算题目批改

  • 前言
  • 一、摘要
  • 二、项目框架
  • 三、项目步骤
    • 1. 数据处理
      • 1.1 数据收集
      • 1.2 数据打标
      • 1.3 数据预处理
    • 2. 模型训练
      • 2.1 目标检测
        • 2.1.1 模型介绍
        • 2.1.2 模型训练
        • 2.1.3 训练结果
      • 2.2 文本识别
        • 2.2.1 模型介绍
        • 2.2.2 模型训练
        • 2.2.3 训练结果
    • 3. 模型推理
      • 3.1 YOLO模型接口
      • 3.2 CRNN模型接口
    • 4. 模型部署
      • 4.1 上传图像
        • 4.1.1 代码实现
        • 4.1.2 页面效果
      • 4.2 文本检测
        • 4.2.1 代码实现
        • 4.2.2 效果展示
      • 4.3 文本识别
        • 4.3.1 代码实现
        • 4.3.2 效果展示
      • 4.4 算式批改
        • 4.4.1 代码实现
        • 4.4.1 效果展示
      • 4.5 结果反馈
        • 4.5.1 推理时间
        • 4.5.2 用户答题情况
    • 5. 模型压缩
      • 5.1 核心代码
      • 5.2 压缩效果
    • 6. 项目优化
      • 6.1 书的部分拱起
      • 6.2 模型泛化能力
      • 6.3 模型压缩
      • 6.4 模型推广
  • 三、项目演示
  • 四、项目总结
  • 五、项目地址
  • 六、参考资料

前言

这个项目是笔者在《深度学习实践与应用》这门课的期末大作业,可以算得上是我的深度学习启蒙项目。当时这个项目花了我很多精力去完成的,最后也取得了不错的结果,收获满满。我第一次看到这个项目时,是有点恐惧的,担心自己无法完成这个项目,但当我真正上手去做这个项目的时候,才发现它没有想象中的那么难,我只需要把项目进行拆解,分解成若干个子任务,然后各个击破就可以了。所以有时候不要畏惧挑战,干就完事了哈哈哈。而做完这个项目之后,不瞒你说贼有成就感,也因为这个项目我对人工智能应用也更加感兴趣,未来我还会继续做更多有意思的项目,敬请期待吧!废话不多说,让我们直接开始吧!
友情提示:全文篇幅有点长,建议阅读时间30分钟,可以先收藏后慢慢食用哈

一、摘要

针对AI+教育行业的应用,以小学速算作业批改为原型,我们运用了OCR(光学字符识别)中经典的目标检测搭配文本识别来实现自动批改任务。首先,我们对数据进行人工标注,分为YOLO和CRNN的标记,分别为equation和题目字符内容;然后是目标检测和文本识别两个任务,分别使用的是训练好的YOLO和CRNN模型,得到了不错的识别效果,最终使用逆波兰式对用户的答案进行批改。其中,我们还对系统进行了优化,对于数据标记,我们利用了的图像处理的方法进行自动数据标注;对于用户上传的低质量图片(模糊/倾斜/含阴影),我们同样利用了一些图像处理的方法进行修正;对于YOLO模型参数过大的,我们使用了模型剪枝方法对其进行压缩。最后,我们的展望是能将模型适用于对存在拱起区域的图片,以及将模型剪枝搭配参数量化方法进一步压缩我们的模型,进一步提升模型的推理速度。

二、项目框架

三、项目步骤

本次项目可以分为以下几步:

  1. 数据处理
  2. 模型训练
  3. 模型推理
  4. 模型部署
  5. 模型压缩
  6. 项目优化

由于本次项目是应用为主,这里我就不会过多介绍具体的算法实现,而会偏向于介绍工程上的实现,包括数据集制作、模型训练心得、模型优化等等。

1. 数据处理

数据处理又可以分为数据收集、数据打标和数据预处理。下面我就来分别介绍下。

1.1 数据收集

有多少人工就有多少智能,深度学习项目成功与否很大程度是跟数据有关,所以一开始的重中之重就是收集数据。这里我们采取分工协作的方式来创建数据集,每个小组写两大本小学速算题目并将照片拍照收集起来,(这个过程蛮有意思,如果你发现某个大学生上课在做小学计算题,请不要嘲笑他哈哈)。最后我们通力合作收集了601张照片。

1.2 数据打标

数据收集好后,为了模型能训练准确识别算式的位置和内容,所以我们需要给图片都进行标注,这里我们用的是labelimg软件来进行数据打标,操作十分简单,只需要拖动鼠标划出矩形框,对于YOLO模型,标记统一设定为“equation”,如下图所示:


对于CRNN模型,标记设定为等式框中的字符串内容:

最后数据标注的结果会生成对应图片的XML文件,XML文件保存着对应图片中算式的位置坐标和标注结果信息。如下图所示:(关于XML文件的具体介绍,我后续会继续完善和补充,先占个位哈哈)

1.3 数据预处理

YOLO模型训练的数据就是整张照片,而CRNN模型训练的数据是一个个算式式子,因此我们利用脚本将YOLO模型训练数据进行处理。具体原理是根据YOLO数据打标获得的XML文件获取每张图片每个等式在图片的具体位置,然后利用CV2库将其裁剪为一个个小图片。核心代码如下:

# 裁剪,只适用标签文件为xml的情况,其他情况可相应地修改代码
for img_file in os.listdir(img_path):    # 遍历图片文件夹img_filename = os.path.join(img_path, img_file)  #将 图片路径与图片名进行拼接img_cv = cv2.imread(img_filename)  #读取图片img_name = (os.path.splitext(img_file)[0])  # 分割出图片名,如“000.png” 图片名为“000”xml_name = xml_path + '\\' + '%s.xml'%img_name  #利 用标签路径、图片名、xml后缀拼接出完整的标签路径名root = ET.parse(xml_name).getroot() # 利用ET读取xml文件for obj in root.iter('object'):  # 遍历所有目标框name = obj.find('name').text   # 获取目标框名称,即label名xmlbox = obj.find('bndbox')   # 找到框目标x0 = xmlbox.find('xmin').text  # 将框目标的四个顶点坐标取出y0 = xmlbox.find('ymin').textx1 = xmlbox.find('xmax').texty1 = xmlbox.find('ymax').textobj_img = img_cv[int(y0):int(y1), int(x0):int(x1)]  # cv2裁剪出目标框中的图片obj_img_name = obj_img_path + '\\' + '%s_%s'%(img_name, name) + '.jpg'  # 裁剪图片的名字cv2.imencode('.jpg', obj_img)[1].tofile(obj_img_name)  # 写入print("Finished.")

最后获得CRNN训练数据长这样:

同时也生成对应的txt文件,txt文件里包含图片的名字和图片的算式内容

2. 模型训练

2.1 目标检测

我参考的代码是这个:yolo3-pytorch
(占个坑哈,后面会补上YOLO模型训练的具体过程,敬请期待)

2.1.1 模型介绍

这里我采用的目标检测算法是YOLO算法 , 由于YOLO算法采用了残差网络这种跳层连接的方式,性能完全比ResNet-152和ResNet-101深层网络更好,无论是准确率还是计算效率都更佳。相比于RCNN系列的目标检测方法,YOLO的识别物体位置精准性较差,召回率低。

2.1.2 模型训练

这里我用的是
(1) 数据集
样本总量共601张,预处理前的图片平均尺寸为(1452.0, 1815.6),将样本划分为训练集:验证集:测试集=0.81:0.09:0.1
(2) 参数调整
输入图片放缩尺寸至416*416,通道数为3.
冻结阶段:epochs为20、batch_size为8、lr为1e-3。
解冻阶段:epochs为50,batch_size为4、lr为1e-4。
预测概率阈值为0.5(只有预测概率大于0.5的预测框才会保留)

2.1.3 训练结果

2.2 文本识别

我参考的代码是这个:使用pytorch训练自己的文字识别模型
关于文本识别的训练过程蛮有意思的可以分享下哈哈:当时我在使用大佬模型训练过程中一直出现问题,所以我通过B站蹲点联系上了大佬并加上了大佬的微信,大佬也很热情地帮我解决了问题。这告诉我们:办法总是比困难多的(B站真的是个学习的地方 )
(再占个坑哈,到时候会补上CRNN模型训练的具体过程,敬请期待)

2.2.1 模型介绍

文本识别我采用的是CRNN模型 ,文字识别可以认为是对序列的预测方法,所以采用了对序列预测的RNN网络。通过CNN将图片的特征提取出来后采用RNN对序列进行预测,最后通过一个CTC的翻译层得到最终结果。简单来说就是CNN+RNN+CTC的结构。CRNN可以直接从序列标签(例如单词)学习,不需要详细的标注,虽然其对较大形变的手写字体的的识别准确率欠佳,但在速算识别的应用场景下的识别率较为稳定。

2.2.2 模型训练

(1) 数据集
样本总量共3284张,预处理前的图片平均尺寸为(266,65),将样本划分为训练集:验证集:测试集=0.75:0.2:0.05
(2) 参数调整
输入图片放缩尺寸至262*32,通道数为3.
训练轮数epochs为30、batch_size为256、lr为1e-3。
Val_epoch为1,即每轮都验证一次。

2.2.3 训练结果

3. 模型推理

3.1 YOLO模型接口

我们根据以下两个YOLO接口,将YOLO模型嵌入我们的系统中。

  1. yolo.detect_image(image)
  2. GetBoxedPic(img, boxes)

接口1:yolo.detect_image(image)
目的:该接口用于调用YOLO3模型对图片进行预测并返回相应的预测结果信息。
输入:原图image(PIL的image对象)。
输出:等式检测图、所有等式框的坐标信息和对应的置信度。

其中,等式检测图是在输入的原图上进行了等式框的绘制;等式框的坐标信息将作为后面CRNN文本识别的输入;等式框的置信度将在模型性能分析报告中体现。
如下图所示,该Detecting函数用于实现将等式检测的部分展现至web上,首先,根据st.subheader函数显示副标题(Detected Image),根据st.write函数提示用户等待信息。然后创建一个yolo对象,并输入用户上传的图像,调用接1(yolo.detect_image)进行等式检测并返回结果信息,与此同时,st.progress函数将会显示识别的进度条,防止用户错认为网站卡顿。

# 进行yolo检测,呈现在web页面上
def Detecting(image):st.subheader("Detected Image")st.write("Just a second ...")yolo = YOLO()my_bar = st.progress(0)img = image.copy()start1 = time.time()r_image, boxes, top_conf = yolo.detect_image(image)end1 = time.time()# print(boxes)for percent_complete in range(100):my_bar.progress(percent_complete + 1)st.image(r_image, use_column_width=True)  # 展现检测结果# st.download_button(label="Download image", data=r_image, file_name='large_df.jpg', mime="image/jpg")st.subheader("Detection outcome Analysis")plt.scatter(np.arange(len(top_conf)), top_conf)plt.xlabel('detected rectangle')plt.ylabel('score')st.pyplot()# st.balloons()pics = GetBoxesPic(img, boxes)return boxes, start1, end1# st.image(pic, use_column_width=True)

接口2:GetBoxedPic(img, boxes)
目的:该接口用于存储检测出的所有等式图像,作为CRNN文本识别的输入图像。输入:原图img(PIL的image对象)、等式框坐标信息boxes。
输出:将所有截取的等式框图像保存。

其中,等式框的坐标信息需要转为整型,分别为top, left, bottom, right四个整型,代表等式框的上(y),左(x),下(y),右(x)坐标。在等式检测完毕后,返回(输出)了等式检测图像r_image、等式框坐标信息boxes、等式框检测置信度top_conf。我们使用st.image在web上显示出等式检测图像,并将等式框识别置信度以散点图的形式绘制在web上。除此之外,调用接口2用于存储boxes的信息,代码如下:

def get4pos(box, image):top, left, bottom, right = boxtop = max(0, np.floor(top).astype('int32'))left = max(0, np.floor(left).astype('int32'))bottom = min(image.size[1], np.floor(bottom).astype('int32'))right = min(image.size[0], np.floor(right).astype('int32'))return top, left, bottom, right# 返回yolo框出的区域,并将其等式图片存入对应文件夹中
def GetBoxesPic(image, boxes):pics = []shutil.rmtree('./yolo3/tmp_img')  # 清空操作os.mkdir('./yolo3/tmp_img')for i in range(len(boxes)):top, left, bottom, right = get4pos(boxes[i], image)pic = image.crop((left - 15, top, right + 40, bottom))pic.save('./yolo3/tmp_img/pic' + str(i).rjust(3, '0') + '.jpg')pics.append(pic)return pics

3.2 CRNN模型接口

  1. 接口1:parse_opt配置函数
  2. 接口2:main函数

接口1:parse_opt配置函数
目的:对CRNN模型提供参数以及路径,包括模型权重、预测图片路径、批量大小、结果存放路径等。
输出:返回模型的配置内容的对象。代码如下图所示:

def parse_opt():parser = argparse.ArgumentParser(description='detect')parser.add_argument('--weights', type=str, default='../crnn_master/weights/CPU.pt', help='权重的路径')parser.add_argument('--source', type=str, default='../YOLO/yolo3/tmp_img/', help='要用来推理图片的路径,可以是一张图片,也可以是一个目录')parser.add_argument('--batch_size', type=int, default=32, help='批次大小')parser.add_argument('--chinese', type=str, default='../crnn_master/data/formula.txt', help='字符集保存路径')parser.add_argument('--imgH', type=int, default=32)parser.add_argument('--nc', type=int, default=1)parser.add_argument('--nh', type=int, default=256)opt = parser.parse_args()return opt

接口2:main函数
目的:使用CRNN模型对输入的一系列等式图像进行识别。
输入:接口1返回的配置对象。
输出:CRNN的识别结果,即所有等式图像对应的等式字符串。

def main(opt):chinese = get_chinese(opt.chinese)converter = StrLabelConverter(chinese)nclass = len(chinese) + 1crnn = CRNN(opt.imgH, opt.nc, nclass, opt.nh)crnn.load_state_dict(torch.load(opt.weights))log_load_model(opt.weights)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#device = torch.device('cpu')log_device(device)crnn = crnn.to(device)equations = detect_(crnn, opt.source, device, converter)return equations

4. 模型部署

本次项目对训练好的YOLO3目标识别模型和CRNN文本检测模型,使用Streamlit轻量级机器学习部署工具呈现至web端,提供给用户良好的速算批改体验。这里简单介绍下Streamlit:它是第一个专门针对机器学习和数据科学团队的应用开发框架,它是开发自定义机器学习工具的最快的方法,它的目标是取代Flask在机器学习项目中的地位,可以帮助机器学习工程师快速开发用户交互工具。(

深度学习实战项目:速算题目批改相关推荐

  1. 最适合入门的100个深度学习实战项目

    前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...

  2. 【功能超全】基于OpenCV车牌识别停车场管理系统软件开发【含python源码+PyqtUI界面+功能详解】-车牌识别python 深度学习实战项目

    车牌识别基础功能演示 摘要:车牌识别系统(Vehicle License Plate Recognition,VLPR) 是指能够检测到受监控路面的车辆并自动提取车辆牌照信息(含汉字字符.英文字母.阿 ...

  3. Pytorch深度学习实战项目回顾

    1.前言 很久没有碰Pytorch了,准备以实战项目代码回顾的方式进行复习. 2.Pytorch安装 现在我又切回了ubuntu系统,里面没有Pytorch,所以顺便从Pytorch最新版安装开始讲起 ...

  4. [深度学习 - 实战项目] CRAFTCRNN_seq2seq图片文字提取

    图片文字提取项目 检测网络:CRAFT,基于字符区域感知的文本检测: CRAFT源码:https://github.com/clovaai/CRAFT-pytorch 识别网络:crnn+seq2se ...

  5. 深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了

    大家好,我是微学AI,今天给大家带来手写OCR识别的项目.手写的文稿在日常生活中较为常见,比如笔记.会议记录,合同签名.手写书信等,手写体的文字到处都有,所以针对手写体识别也是有较大的需求.目前手写体 ...

  6. 深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测

    文章目录 一.前期工作 导入库包 导入数据 主成分分析(PCA) 聚类分析(K-means) 二.神经网络模型建立 三.检验模型 大家好,我是微学AI,今天给大家带来一个利用卷积神经网络(pytorc ...

  7. 深度学习实战1-(keras框架)企业数据分析与预测

    大家好,我是微学AI,今天给大家带来深度学习框架keras的实战项目,用于基本的企业数据分析,预测企业净利润情况,并利用灰色预测函数GM11进行预测模型.我们拿到企业数据,这里参数抽象成x1-x9,y ...

  8. 深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

    文章目录 一.前期工作 1. 设置GPU 2. 导入预处理词库类 二.导入预处理词库类 三.参数设定 四.创建模型 五.训练模型函数 六.测试模型函数 七.训练模型与预测 今天给大家带来一个简单的中文 ...

  9. 深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星

    大家好,我是微学AI,我们在日常生活中,经常会写一些文稿,比如:会议纪要,周报,日报,汇报材料,这些文稿里我们会发现有时候出现拼写.语法.标点等错误:其中拼写错误的错别字占大部分. 经过初步统计:在微 ...

最新文章

  1. Java项目:零食商城系统(java+SSM+jsp+MySQL+EasyUI)
  2. lucene-solr本地调试方法
  3. 【专访】会会创始人李翔昊:重新颠覆职业社交,盲目抄袭只有死路一条
  4. oracle 带有变量的语句_Oracle 动态SQL语句(2)之含变量的WHERE语句与日期变量
  5. 监听网页微信扫码支付成功_网付扫码点餐新福利,消费者点餐可获微信支付金币奖励...
  6. 腾讯天衍实验室招聘科研实习生
  7. 如何从900万张图片中对600类照片进行分类,附代码
  8. Android笔记 采用httpclient提交数据到服务器demo
  9. 区块链工作笔记0001---以太坊流程简介
  10. BTC 5分钟内涨幅1.08%,现价23903.52usdt
  11. python logging默认情况下打印_python logging日志打印过程解析
  12. javascript类功能代码集
  13. Visual Studio 2008破解激活升级方法【转】
  14. J2EE架构师路线脑图
  15. 模拟电子电路技术基础 | 常用半导体器件
  16. 京东商城逆势融资B2C成投资热土
  17. ASP.Net Core Web Api在Windows服务器上部署
  18. 使用Graphics将字符串居中绘制到图片上
  19. 计算机关闭账号用户控制好吗,Win10系统彻底关闭用户帐户控制的方法
  20. Vue知识点总结(一)

热门文章

  1. markdown空格缩进以及HTML空格实体
  2. 5.4.1 jmeter组件—逻辑控制器-简单控制器、IF控制器、事务控制器、循环控制器、交替控制器
  3. 同步调制国内论文学习
  4. RationalDMIS 2020 AeroTech转盘使用说明
  5. 前向传播、反向传播、更新梯度
  6. Unity_VRTK 3.2.1_UI手柄射线检测点击事件的问题
  7. 超200家上市企业布局!从千余条备案信息看区块链产业
  8. “互联网+制造业”:六大特点三大问题
  9. 垃圾回收器的作用?垃圾回收器可以马上回收内存吗?
  10. 微信公众平台开发消息回复总结