前言

最近在做OCR相关的任务,用到了阿里天池一个街景字符识别比赛的数据集,索性就分享一下相关方案,我采用YOLO5模型,最终在平台提交分数也做到了0.924,没有经过任何优化,可以看出YOLO5的效果还是非常不错的。

比赛地址链接:https://tianchi.aliyun.com/competition/entrance/531795/introduction?spm=5176.12281973.1005.7.3dd52448VtZc6t

下载YOLO5模型

YOLO5下载:https://github.com/ultralytics/yolov5

下载压缩包,然后放到自己文件夹进行解压。

在yolo5-master中打开命令行,键入以下命令安装相关包:

pip install -r requirements.txt

注意:安装包的时候可能会有各种各样的报错,特别是安装pycocotools的时候,不用慌,把报错复制粘贴到百度上面,都能解决!

获取数据集

YOLO已经准备好了,现在把比赛数据集拿出来,解析数据我就不自己写了,直接采用另外一位论坛上面老哥的代码,这里是训练集的处理,验证集也是一样的:

import os
import cv2
import json
train_json = json.load(open('mchar_train.json'))
for x in train_json:img=cv2.imread("images/train/"+x)width=img.shape[1]height=img.shape[0]train_label =list(map(int,train_json[x]['label']))train_height=list(map(int,train_json[x]['height']))train_left=list(map(int,train_json[x]['left']))train_width=list(map(int,train_json[x]['width']))train_top=list(map(int,train_json[x]['top']))loc_pic="labels/train/"+x.split('.')[0]+'.txt' pic=open(loc_pic,"w")for i in range(len(train_label)):pic_label=train_label[i]pic_x=(train_left[i]+train_width[i]/2)/widthpic_y=(train_top[i]+train_height[i]/2)/heightpic_width=train_width[i]/widthpic_height=train_height[i]/height            pic.write(str(pic_label)+" "+str(pic_x)+" "+str(pic_y)+" "+str(pic_width)+" "+str(pic_height))pic.write("\n")pic.close()

解析后的数据都是txt格式的,因为YOLO模型输入格式要求也是这样

数据处理好了以后,我们在yolo5-master中创建一个名为tianchi的文件夹,文件夹结构如下:

文件夹创建好后,把对应的数据拷贝进相应文件夹中就行了。

模型训练

我们把models文件夹中yolo5s.yaml文件复制一份到tianchi文件夹中,同时把data文件夹中coco128.yaml文件也复制一份到tianchi文件夹中,并且把yolo5s.yaml改名为street_yolo5s.yaml,把coco128.yaml改名为street_yolo.yaml,改好后如下图:

然后我们再将这两个文件中的内容进行修改,首先修改street_yolo5s.yaml:将nc的值改为10

然后修改street_yolo.yaml文件,只需要修改train和val的路径,还有nc和names就行了,然后把path那一行注释掉,修改后如下:

改完之后我们就可以进行模型训练了!!!
在yolo5-master中打开命令行,执行以下命令(这里我只设置了20个epoch作为示例,我自己是100个epoch训练后才是0.924):

python train.py --data tianchi/streat_yolo.yaml --cfg tianchi/street_yolo5s.yaml --epochs 100


训练会花费很多时间,我训练了21个小时,电脑太垃圾了(GTX1050)!

测试数据预测

将test图片数据放入images文件夹中:

然后执行如下命令即可:

python detect.py --weights runs/train/exp/weights/best.pt --source  tianchi/images/test/ --save-txt

预测完成后在runs/detect/exp中可以可以看到训练后的结果:

提交结果

预测出来的labels格式不是最终提交结果,我们要按照比赛要求的提交结果来,所以还要对结果进行一点处理:

import pandas as pd
import glob
import os
def get(elem):return elem[1]
label_path=glob.glob('labels/*.txt')
label_path.sort()
df_submit = pd.read_csv('mchar_sample_submit_A.csv')
df_submit.set_index('file_name')
for x in label_path:text=open(x,'r')result_list=[]for line in text.readlines():result_list.append((line.split(' ')[0],line.split(' ')[1]))result_list.sort(key=get)result=''for j in result_list:result+=j[0]label_path=x.split('\\')[-1].split('.')[0]+'.png'df_submit['file_code'][df_submit['file_name']==label_path]=resulttext.close()
df_submit.to_csv('content/submit.csv', index=None)

将submit.csv提交到天池平台上面:

写在最后

上面的成绩只是单纯的调用了模型,没有进行任何调优和融合,可以看出YOLO5的效果还是很好的,我们也可以采用不同的YOLO权重参数进行训练尝试,效果会更好。同时也感谢论坛各位大佬提供的各种想法和代码,我也是学到了很多,本人才疏学浅,如果有不对的地方希望指正!

阿里天池街景字符编码YOLO5方案相关推荐

  1. 天池-街景字符编码识别2-数据读取与数据扩增

    本此使用[定长字符识别]思路来构建模型 赛题地址 零基础入门CV赛事- 街景字符编码识别 关于更详细的数据预处理可=可以参考我的另一篇博文: 卷积神经网络性能优化(提高准确率) 2 数据读取与数据扩增 ...

  2. 天池-街景字符编码识别1-赛题理解

    赛题地址 零基础入门CV赛事- 街景字符编码识别 前期环境 运行环境及安装 运行环境 python3.7 pytorch1.3.1 有GPU 首先在Anaconda中创建一个专门用于本次练习赛的虚拟环 ...

  3. 天池-街景字符编码识别5-模型训练与验证

    模型集成 包括:集成学习方法.深度学习中的集成学习和结果后处理思路. 集成学习方法 在机器学习中的集成学习可以在一定程度上提高预测精度,常见的集成学习方法有Stacking.Bagging和Boost ...

  4. 天池-街景字符编码识别4-模型训练与验证

    4 模型训练与验证 构造验证集 在机器学习模型(特别是深度学习模型)的训练过程中,模型是非常容易过拟合的.深度学习模型在不断的训练过程中训练误差会逐渐降低,但测试误差的走势则不一定. 在模型的训练过程 ...

  5. 阿里天池比赛——街景字符编码识别

    文章目录 前言 一.街景字符编码识别 1. 目标 2. 数据集 3. 指标 总结 前言 之前参加阿里天池比赛,好久了,一直没有时间整理,现在临近毕业,趁论文外审期间,赶紧把东西整理了,5月底学校就要让 ...

  6. 天池学习赛——街景字符编码识别(得分上0.93)

    项目代码已上传至github需要的可以自行下载 目录 1 比赛介绍 2 解题思路 3 比赛数据集 4 模型训练 5 更改detect.py文件 6 上传文件 1 比赛介绍 项目链接:零基础入门CV - ...

  7. 天池大赛:街景字符编码识别——Part2:数据读取与数据扩增

    街景字符编码识别 更新流程↓ Task01:赛题理解 Task02:数据读取与数据扩增 Task03:字符识别模型 Task04:模型训练与验证 Task05:模型集成 底到镜一 比赛链接 Part2 ...

  8. 零基础入门CV赛事- 街景字符编码识别

    零基础入门CV赛事- 街景字符编码识别 Task01 学习目标 数据介绍 Task01任务内容 数据读取 解题思路 学习目标 熟悉计算机视觉赛事 完成典型的字符识别问题 掌握CV领域赛事的编程和解题思 ...

  9. 零基础入门CV赛事—街景字符编码识别—task2数据读取与扩增

    数据读取与扩增 上节学习了街景字符编码识别的解题思路,让我们对本赛题有了基本的idea,这节在定长字符编码的思路基础上学习读取数据和数据扩增. 图像数据读取 由于赛题数据是图像数据,赛题的任务是识别图 ...

最新文章

  1. springboot实现SSE服务端主动向客户端推送数据,java服务端向客户端推送数据,kotlin模拟客户端向服务端推送数据
  2. 6大最流行、最有用的自然语言处理库对比
  3. 操作系统:第二章 进程管理1 - 进程、线程
  4. 卸载 流程_一款适合于windows端的卸载神器 彻底清理残留软件
  5. 千万商家的智能决策引擎--AnalyticDB如何助力生意参谋双十一
  6. win10家庭版无法安装mysql_大师处置win10系统家庭版安装MySQL server 5.7.19失败的详细办法...
  7. vagrant 的安装与使用
  8. 让Cocos2dx中的TestCPP中的Box2dTest运行起来
  9. FA:萤火虫算法的改进及Python实现
  10. ping C语言实现
  11. C/C++编程:#pragma once用法总结
  12. isFinite() 如果参数是 NaN,正无穷大或者负无穷大,会返回 false,其他返回 true
  13. oracle 抽样_[转载]利用ORACLE实现数据抽样
  14. MacOS Monterey 12.4 (21F79) OC 0.8.0 / Cl 5146 / PE 三分区原版黑苹果镜像
  15. 七夕h5开发就找TOM小游戏
  16. 手机电路板文件_ORICO移动硬盘盒玩出新花样,变身手机备份宝
  17. 【数据库】什么是 PostgreSQL?开源数据库系统
  18. 2020互联网大厂中秋礼盒PK!阿里走情怀,腾讯更复古,最走心的是...(文末有福利)
  19. 海苔和紫菜有什么区别?
  20. 华南理工大学 数据库实验一 实验截图

热门文章

  1. JAVA 数组降序排列思路
  2. 极具发展潜力的20项油气勘探开发新技术
  3. Openvino 模型文件部署推理
  4. cmd批处理的/d/l/r/f
  5. EKS日志收集方案-PLG(Promtail+Loki+Grafana)
  6. python表示差值_Python算法之差值查找-Testfan打卡学测开0116
  7. poi设置单元格下拉下表
  8. IP协议详解【IP报文头部结构、IP分片、IP路由、IP转发】
  9. SpringBoot项目的Liunx服务器部署(一)
  10. 纽约大学理工学院:MULTIMEDIA SIGNAL COMPRESSION: SPEECH AND