赛题地址

零基础入门CV赛事- 街景字符编码识别

前期环境

运行环境及安装
运行环境

  • python3.7

  • pytorch1.3.1

  • 有GPU

首先在Anaconda中创建一个专门用于本次练习赛的虚拟环境。

$conda create -n pytorch_gpu python=3.7

激活环境,并安装pytorch1.3.1

$source activate pytorch_gpu
$conda install pytorch=1.3.1 torchvision cudatoolkit=10.0

一键安装所需其它依赖库

$pip install jupyter tqdm opencv-python matplotlib pandas

预训练

首先导入常用的包

import os, sys, glob, shutil, json
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import cv2from PIL import Image
import numpy as npfrom tqdm import tqdm, tqdm_notebookimport torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = Trueimport torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

步骤1:定义好读取图像的Dataset

class SVHNDataset(Dataset):def __init__(self, img_path, img_label, transform=None):self.img_path = img_pathself.img_label = img_label if transform is not None:self.transform = transformelse:self.transform = Nonedef __getitem__(self, index):img = Image.open(self.img_path[index]).convert('RGB')if self.transform is not None:img = self.transform(img)# 设置最长的字符长度为5个lbl = np.array(self.img_label[index], dtype=np.int)lbl = list(lbl)  + (5 - len(lbl)) * [10]return img, torch.from_numpy(np.array(lbl[:5]))def __len__(self):return len(self.img_path)

步骤2:定义好训练数据和验证数据的Dataset

train_path = glob.glob('E:\python-project\deep-learning\cv-stree\mchar_val/*.png')
train_path.sort()
train_json = json.load(open('E:\python-project\deep-learning\cv-stree\train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))train_loader = torch.utils.data.DataLoader(SVHNDataset(train_path, train_label,transforms.Compose([transforms.Resize((64, 128)),transforms.RandomCrop((60, 120)),transforms.ColorJitter(0.3, 0.3, 0.2),transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=True, num_workers=10,
)val_path = glob.glob('E:\python-project\deep-learning\cv-stree\mchar_val/*.png')
val_path.sort()
val_json = json.load(open('E:\python-project\deep-learning\cv-stree\val.json'))
val_label = [val_json[x]['label'] for x in val_json]
print(len(val_path), len(val_label))val_loader = torch.utils.data.DataLoader(SVHNDataset(val_path, val_label,transforms.Compose([transforms.Resize((60, 120)),# transforms.ColorJitter(0.3, 0.3, 0.2),# transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=False, num_workers=10,
)

步骤3:定义好字符分类模型,使用renset18的模型作为特征提取模块

class SVHN_Model1(nn.Module):def __init__(self):super(SVHN_Model1, self).__init__()model_conv = models.resnet18(pretrained=True)model_conv.avgpool = nn.AdaptiveAvgPool2d(1)model_conv = nn.Sequential(*list(model_conv.children())[:-1])self.cnn = model_convself.fc1 = nn.Linear(512, 11)self.fc2 = nn.Linear(512, 11)self.fc3 = nn.Linear(512, 11)self.fc4 = nn.Linear(512, 11)self.fc5 = nn.Linear(512, 11)def forward(self, img):        feat = self.cnn(img)# print(feat.shape)feat = feat.view(feat.shape[0], -1)c1 = self.fc1(feat)c2 = self.fc2(feat)c3 = self.fc3(feat)c4 = self.fc4(feat)c5 = self.fc5(feat)return c1, c2, c3, c4, c5

步骤4:定义好训练、验证和预测模块

def train(train_loader, model, criterion, optimizer):# 切换模型为训练模式model.train()train_loss = []for i, (input, target) in enumerate(train_loader):if use_cuda:input = input.cuda()target = target.cuda()c0, c1, c2, c3, c4 = model(input)loss = criterion(c0, target[:, 0]) + \criterion(c1, target[:, 1]) + \criterion(c2, target[:, 2]) + \criterion(c3, target[:, 3]) + \criterion(c4, target[:, 4])# loss /= 6optimizer.zero_grad()loss.backward()optimizer.step()if i % 100 == 0:print(loss.item())train_loss.append(loss.item())return np.mean(train_loss)def validate(val_loader, model, criterion):# 切换模型为预测模型model.eval()val_loss = []# 不记录模型梯度信息with torch.no_grad():for i, (input, target) in enumerate(val_loader):if use_cuda:input = input.cuda()target = target.cuda()c0, c1, c2, c3, c4 = model(input)loss = criterion(c0, target[:, 0]) + \criterion(c1, target[:, 1]) + \criterion(c2, target[:, 2]) + \criterion(c3, target[:, 3]) + \criterion(c4, target[:, 4])# loss /= 6val_loss.append(loss.item())return np.mean(val_loss)def predict(test_loader, model, tta=10):model.eval()test_pred_tta = None# TTA 次数for _ in range(tta):test_pred = []with torch.no_grad():for i, (input, target) in enumerate(test_loader):if use_cuda:input = input.cuda()c0, c1, c2, c3, c4 = model(input)output = np.concatenate([c0.data.numpy(), c1.data.numpy(),c2.data.numpy(), c3.data.numpy(),c4.data.numpy()], axis=1)test_pred.append(output)test_pred = np.vstack(test_pred)if test_pred_tta is None:test_pred_tta = test_predelse:test_pred_tta += test_predreturn test_pred_tta

步骤5:迭代训练和验证模型

model = SVHN_Model1()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0use_cuda = False
if use_cuda:model = model.cuda()for epoch in range(2):train_loss = train(train_loader, model, criterion, optimizer, epoch)val_loss = validate(val_loader, model, criterion)val_label = [''.join(map(str, x)) for x in val_loader.dataset.img_label]val_predict_label = predict(val_loader, model, 1)val_predict_label = np.vstack([val_predict_label[:, :11].argmax(1),val_predict_label[:, 11:22].argmax(1),val_predict_label[:, 22:33].argmax(1),val_predict_label[:, 33:44].argmax(1),val_predict_label[:, 44:55].argmax(1),]).Tval_label_pred = []for x in val_predict_label:val_label_pred.append(''.join(map(str, x[x!=10])))val_char_acc = np.mean(np.array(val_label_pred) == np.array(val_label))print('Epoch: {0}, Train loss: {1} \t Val loss: {2}'.format(epoch, train_loss, val_loss))print(val_char_acc)# 记录下验证集精度if val_loss < best_loss:best_loss = val_losstorch.save(model.state_dict(), './model.pt')

步骤6:对测试集样本进行预测,生成提交文件

test_path = glob.glob('../input/test_a/*.png')
test_path.sort()
test_label = [[1]] * len(test_path)
print(len(val_path), len(val_label))test_loader = torch.utils.data.DataLoader(SVHNDataset(test_path, test_label,transforms.Compose([transforms.Resize((64, 128)),transforms.RandomCrop((60, 120)),# transforms.ColorJitter(0.3, 0.3, 0.2),# transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=False, num_workers=10,
)test_predict_label = predict(test_loader, model, 1)test_label = [''.join(map(str, x)) for x in test_loader.dataset.img_label]
test_predict_label = np.vstack([test_predict_label[:, :11].argmax(1),test_predict_label[:, 11:22].argmax(1),test_predict_label[:, 22:33].argmax(1),test_predict_label[:, 33:44].argmax(1),test_predict_label[:, 44:55].argmax(1),
]).Ttest_label_pred = []
for x in test_predict_label:test_label_pred.append(''.join(map(str, x[x!=10])))import pandas as pd
df_submit = pd.read_csv('../input/test_A_sample_submit.csv')
df_submit['file_code'] = test_label_pred
df_submit.to_csv('renset18.csv', index=None)

赛题理解

赛题数据
赛题以街道字符为为赛题数据,数据集报名后可见并可下载,该数据来自收集的SVHN街道字符,并进行了匿名采样处理。

训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。

数据标签
对于训练数据每张图片将给出对于的编码标签,和具体的字符框的位置(训练集、验证集都给出字符位置),可用于模型训练:

Field    Description
top     左上角坐标X
height  字符高度
left    左上角坐标Y
width   字符宽度
label   字符编码

字符的坐标具体如下所示:

在比赛数据(训练集和验证集)中,同一张图片中可能包括一个或者多个字符,因此在比赛数据的JSON标注中,会有两个字符的边框信息:

评测指标
选手提交结果与实际图片的编码进行对比,以编码整体识别准确率为评价指标。任何一个字符错误都为错误,最终评测指标结果越大越好,具体计算公式如下:
Score=编码识别正确的数量/测试集图片数量

读取数据
JSON中标签的读取方式:

import json
train_json = json.load(open('../input/train.json'))# 数据标注处理
def parse_json(d):arr = np.array([d['top'], d['height'], d['left'],  d['width'], d['label']])arr = arr.astype(int)return arrimg = cv2.imread('../input/train/000000.png')
arr = parse_json(train_json['000000.png'])plt.figure(figsize=(10, 10))
plt.subplot(1, arr.shape[1]+1, 1)
plt.imshow(img)
plt.xticks([]); plt.yticks([])for idx in range(arr.shape[1]):plt.subplot(1, arr.shape[1]+1, idx+2)plt.imshow(img[arr[0, idx]:arr[0, idx]+arr[1, idx],arr[2, idx]:arr[2, idx]+arr[3, idx]])plt.title(arr[4, idx])plt.xticks([]); plt.yticks([])


解题思路
赛题思路分析:赛题本质是分类问题,需要对图片的字符进行识别。但赛题给定的数据图片中不同图片中包含的字符数量不等,如下图所示。有的图片的字符个数为2,有的图片字符个数为3,有的图片字符个数为4。

字符属性 图片
字符:42 字符个数:2
字符:241 字符个数:3
字符:7358 字符个数:4

因此本次赛题的难点是需要对不定长的字符进行识别,与传统的图像分类任务有所不同。

  • 简单入门思路:定长字符识别

可以将赛题抽象为一个定长字符识别问题,在赛题数据集中大部分图像中字符个数为2-4个,最多的字符 个数为6个。
因此可以对于所有的图像都抽象为6个字符的识别问题,字符23填充为23XXXX,字符231填充为231XXX。

经过填充之后,原始的赛题可以简化了6个字符的分类问题。在每个字符的分类中会进行11个类别的分类,假如分类为填充字符,则表明该字符为空。

  • 专业字符识别思路:不定长字符识别

在字符识别研究中,有特定的方法来解决此种不定长的字符识别问题,比较典型的有CRNN字符识别模型。
在本次赛题中给定的图像数据都比较规整,可以视为一个单词或者一个句子。

  • 专业分类思路:检测再识别
    在赛题数据中已经给出了训练集、验证集中所有图片中字符的位置,因此可以首先将字符的位置进行识别,利用物体检测的思路完成。

此种思路需要参赛选手构建字符检测模型,对测试集中的字符进行识别。选手可以参考物体检测模型SSD或者YOLO来完成。

天池-街景字符编码识别1-赛题理解相关推荐

  1. 天池-街景字符编码识别2-数据读取与数据扩增

    本此使用[定长字符识别]思路来构建模型 赛题地址 零基础入门CV赛事- 街景字符编码识别 关于更详细的数据预处理可=可以参考我的另一篇博文: 卷积神经网络性能优化(提高准确率) 2 数据读取与数据扩增 ...

  2. 天池-街景字符编码识别4-模型训练与验证

    4 模型训练与验证 构造验证集 在机器学习模型(特别是深度学习模型)的训练过程中,模型是非常容易过拟合的.深度学习模型在不断的训练过程中训练误差会逐渐降低,但测试误差的走势则不一定. 在模型的训练过程 ...

  3. 天池-街景字符编码识别5-模型训练与验证

    模型集成 包括:集成学习方法.深度学习中的集成学习和结果后处理思路. 集成学习方法 在机器学习中的集成学习可以在一定程度上提高预测精度,常见的集成学习方法有Stacking.Bagging和Boost ...

  4. 天池学习赛——街景字符编码识别(得分上0.93)

    项目代码已上传至github需要的可以自行下载 目录 1 比赛介绍 2 解题思路 3 比赛数据集 4 模型训练 5 更改detect.py文件 6 上传文件 1 比赛介绍 项目链接:零基础入门CV - ...

  5. 天池大赛:街景字符编码识别——Part2:数据读取与数据扩增

    街景字符编码识别 更新流程↓ Task01:赛题理解 Task02:数据读取与数据扩增 Task03:字符识别模型 Task04:模型训练与验证 Task05:模型集成 底到镜一 比赛链接 Part2 ...

  6. 阿里天池比赛——街景字符编码识别

    文章目录 前言 一.街景字符编码识别 1. 目标 2. 数据集 3. 指标 总结 前言 之前参加阿里天池比赛,好久了,一直没有时间整理,现在临近毕业,趁论文外审期间,赶紧把东西整理了,5月底学校就要让 ...

  7. 零基础入门CV赛事—街景字符编码识别—task2数据读取与扩增

    数据读取与扩增 上节学习了街景字符编码识别的解题思路,让我们对本赛题有了基本的idea,这节在定长字符编码的思路基础上学习读取数据和数据扩增. 图像数据读取 由于赛题数据是图像数据,赛题的任务是识别图 ...

  8. 计算机视觉实践(街景字符编码识别)-Task2 数据读取与数据扩增

    计算机视觉实践(街景字符编码识别)-Task2 数据读取与数据扩增 2.1.数据读取与数据扩增 本章主要内容为数据读取.数据扩增方法和Pytorch读取赛题数据三个部分组成. 2.1 学习目标 学习P ...

  9. 零基础入门CV赛事- 街景字符编码识别

    零基础入门CV赛事- 街景字符编码识别 Task01 学习目标 数据介绍 Task01任务内容 数据读取 解题思路 学习目标 熟悉计算机视觉赛事 完成典型的字符识别问题 掌握CV领域赛事的编程和解题思 ...

最新文章

  1. UI_UITableView_搭建
  2. 如何解决头文件重复包含和宏的重复定义问题:用#ifndef 、#define、#endif
  3. 计算机病毒按破坏性分为哪两类,计算机导论复习要点.doc
  4. 【SSH】禁用root远程、修改ssh端口
  5. nicetool好工具_N个办公辅助好工具,无需下载,简单实用
  6. python基础-牛逼的三层循环,实现想在那里退出,就在那里退出。
  7. Perspective Mockups mac(PS透视模型动作插件)支持ps2021
  8. 携程高管解读Q3财报:海外市场Trip.com流量恢复到去年同期70%
  9. python PDF文件转JPG
  10. php包含大马执行,对于某个PHP大马的分析
  11. 3dsmax展uv_TexTools|3dmax展UV插件(TexTools for 3ds Max)下载v4.10免费版 - 欧普软件下载
  12. 面向AMD64的文件xxx与项目的目标平台x86不兼容
  13. python 字符串分割方法_Python字符串分割方法总结
  14. 汇报措辞:你懂得如何向领导汇报吗(审阅、审批、审阅、批示、查阅)?
  15. 如何在iOS手机上进行自动化测试
  16. python 爬取网页数据到csv
  17. 史上最全Unity3D游戏开发教程,从入门到精通(含学习路线图)
  18. java的像素与dpi_DPI与像素的关系
  19. hadoop2.6伪分布+pig0.15+zookeeper3.4.6安装
  20. 关于T—SQL与SQL企业管理器

热门文章

  1. ZooKeeper--Recipes和解决方案
  2. JDK动态代理底层剖析
  3. python 生成 和 加载 requirements.txt
  4. chrome91 后 SameSite by default cookies 不对外开放 解决方案
  5. 使用Java Swing实现简单计算器
  6. C#LeetCode刷题之#112-路径总和​​​​​​​(Path Sum)
  7. C#LeetCode刷题之#232-用栈实现队列​​​​​​​​​​​​​​(Implement Queue using Stacks)
  8. sql视图语句_SQL视图:Replace View语句的示例语法
  9. fitbit手表中文说明书_我如何分析FitBit中的数据以改善整体健康状况
  10. 不要解决:如何将JavaScript集合与目标相匹配