目录

  • 赛题链接
  • 赛题背景
  • 数据集探索
    • 合并多个类别CSV数据集
  • 数据建模 (pytorch)

赛题链接

https://www.kaggle.com/competitions/quickdraw-doodle-recognition/overview/evaluation
数据集从上述链接中找

赛题背景

'Quick,Draw!'作为实验性游戏发布,以有趣的方式向公众宣传 AI 的工作原理。游戏提示用户绘制描绘特定类别的图像,例如“香蕉”、“桌子”等。游戏生成了超过 1B 幅图画,其中的一个子集被公开发布,作为本次比赛训练集的基础。该子集包含 5000 万张图纸,涵盖 340 个标签类别。

听起来很有趣,对吧?挑战在于:由于训练数据来自游戏本身,绘图可能不完整或可能与标签不匹配。您需要构建一个识别器,它可以有效地从这些嘈杂的数据中学习,并在来自不同分布的手动标记的测试集上表现良好。

您的任务是为现有的 Quick, Draw! 构建一个更好的分类器。数据集。通过在此数据集上推进模型,Kagglers 可以更广泛地改进模式识别解决方案。这将对手写识别及其在 OCR(光学字符识别)、ASR(自动语音识别)和 NLP(自然语言处理)等领域的稳健应用产生直接影响。

属于多分类问题

数据集探索

字段解释

Key Type Description
key_id 64 位无符号整数 所有图纸的唯一标识符。
word string 玩家绘制的类别。
recognized boolean 该词是否被游戏识别。
timestamp datetime 创建绘图时间
countrycode string 玩家所在位置的两个字母国家代码
drawing string 表示矢量绘图的 JSON 数组

example:

根据矢量绘图的JSON数组画图

def show_imale(n,owls,drawing):fig,axs = plt.subplots(nrows=n,ncols=n,sharex=True,sharey=True,figsize = (16.10))for i , drawing in enumerate(owls,drawing):ax = axs[i//n,i%n]for x,y in drawing:ax.plot(x,-np.array(y),lw=3)fig.savefig('owls.png',dpi=200)plt.show();

赛题建模思路

  1. 读取数据并转化为图像
  2. 构建分类模型
  3. 确定训练细节和数据扩增方法;
  4. 对测试集完成预测并完成模型集成

数据集的文件结构:

每一种类型的数据图片,都放在一个单独的csv中,下面要对整个数据集进行处理。

合并多个类别CSV数据集

import os, sys, codecs, glob
import numpy as np
import pandas as pd
import cv2
import timmfrom sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split# 读取单个csv文件
def read_df(path, nrows):print('Reading...', path)if nrows.isdigit():return pd.read_csv(path, nrows=int(nrows), parse_dates=['timestamp'])else:return pd.read_csv(path, parse_dates=['timestamp'])# 读取多个csv文件
def contcat_df(paths, nrows):dfs = []for path in paths:dfs.append(read_df(path, nrows))return pd.concat(dfs, axis=0, ignore_index=True)def main():if not os.path.exists('./data'):os.mkdir('./data')CLASSES_CSV = glob.glob('../input/train_simplified/*.csv')CLASSES = [x.split('/')[-1][:-4] for x in CLASSES_CSV]print('Reading data...')# 读取指定行数的csv文本,并进行拼接df = contcat_df(CLASSES_CSV, number)# 数据打乱df = df.reindex(np.random.permutation(df.index))lbl = LabelEncoder().fit(df['word'])df['word'] = lbl.transform(df['word'])if df.shape[0] * 0.05 < 120000:df_train, df_val = train_test_split(df, test_size=0.05)else:df_train, df_val = df.iloc[:-500000], df.iloc[-500000:]print('Train:', df_train.shape[0], 'Val', df_val.shape[0])print('Save data...')df_train.to_pickle(os.path.join('./data', 'train_' + str(number) + '.pkl'))df_val.to_pickle(os.path.join('./data', 'val_' + str(number) + '.pkl'))# python 1_save2df.py 50000
# python 1_save2df.py all
if __name__ == "__main__":number = str(sys.argv[1])main()

其中glob的作用如下注释所示

import glob#获取指定目录下的所有图片
print (glob.glob(r"/home/qiaoyunhao/*/*.png"),"\n")#加上r让字符串不转义#获取上级目录的所有.py文件
print (glob.glob(r'../*.py')) #相对路径

得到的结果如下所示:
32300个训练集,1700个测试集
这里我们是先采用少量数据集训练,试一下数据是否拟合,若拟合

数据建模 (pytorch)

导入所需库

import os, sys, codecs, glob
from PIL import Image, ImageDrawimport numpy as np
import pandas as pd
import cv2import torch
torch.backends.cudnn.benchmark = False
import timmimport torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Datasetimport logging
logging.basicConfig(level=logging.DEBUG, filename='example.log',format='%(asctime)s - %(filename)s[line:%(lineno)d]: %(message)s')

将绘图的轨迹转变为图片
这里用的是opencv,cv的处理速度大于pillow

def draw_cv2(raw_strokes, size=256, lw=6, time_color=True):BASE_SIZE = 299img = np.zeros((BASE_SIZE, BASE_SIZE), np.uint8)for t, stroke in enumerate(eval(raw_strokes)):str_len = len(stroke[0])for i in range(len(stroke[0]) - 1):# 数据集随机丢弃一些像素,属于数据集的drop out,防止过拟合if np.random.uniform() > 0.95:continuecolor = 255 - min(t, 10) * 13 if time_color else 255_ = cv2.line(img, (stroke[0][i] + 22, stroke[1][i]  + 22),(stroke[0][i + 1] + 22, stroke[1][i + 1] + 22), color, lw)if size != BASE_SIZE:return cv2.resize(img, (size, size))else:return img

计算topk准确率

def accuracy(output, target, topk=(1,)):with torch.no_grad():maxk = max(topk)batch_size = target.size(0)_, pred = output.topk(maxk, 1, True, True)pred = pred.t()correct = pred.eq(target.view(1, -1).expand_as(pred))# print(correct.shape)res = []for k in topk:# print(correct[:k].shape)correct_k = correct[:k].float().sum()res.append(correct_k.mul_(100.0 / batch_size))# print(res)return res

数据扩展

class QRDataset(Dataset):def __init__(self, img_drawing, img_label, img_size, transform=None):self.img_drawing = img_drawingself.img_label = img_labelself.img_size = img_sizeself.transform = transformdef __getitem__(self, index):img = np.zeros((self.img_size, self.img_size, 3))img[:, :, 0] = draw_cv2(self.img_drawing[index], self.img_size)img[:, :, 1] = img[:, :, 0]img[:, :, 2] = img[:, :, 0]img = Image.fromarray(np.uint8(img))if self.transform is not None:img = self.transform(img)label = torch.from_numpy(np.array([self.img_label[index]]))return img, labeldef __len__(self):return len(self.img_drawing)

载入模型

def get_resnet18():model = models.resnet18(True)model.avgpool = nn.AdaptiveAvgPool2d(1) # 匹配不固定的输入尺寸model.fc = nn.Linear(512, 340)return modeldef get_resnet34():model = models.resnet34(True)model.avgpool = nn.AdaptiveAvgPool2d(1)model.fc = nn.Linear(512, 340)return modeldef get_resnet50():model = models.resnet50(True)model.avgpool = nn.AdaptiveAvgPool2d(1)model.fc = nn.Linear(2048, 340)return modeldef get_resnet101():model = models.resnet101(True)model.avgpool = nn.AdaptiveAvgPool2d(1)model.fc = nn.Linear(2048, 340)

图片mixup操作

def mixup_data(x, y, alpha=1.0, use_cuda=True):'''Returns mixed inputs, pairs of targets, and lambda'''if alpha > 0:lam = np.random.beta(alpha, alpha)else:lam = 1batch_size = x.size()[0]if use_cuda:index = torch.randperm(batch_size).cuda()else:index = torch.randperm(batch_size)# x 是一个batch 一批的输入mixed_x = lam * x + (1 - lam) * x[index, :]y_a, y_b = y, y[index]return mixed_x, y_a, y_b, lamdef mixup_criterion(criterion, pred, y_a, y_b, lam):return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

主函数

  1. 数据扩展
def main():df_train = pd.read_pickle(os.path.join('./data', 'train_' + dataset + '.pkl'))# df_train = df_train.reindex(np.random.permutation(df_train.index))df_val = pd.read_pickle(os.path.join('./data', 'val_' + dataset + '.pkl'))train_loader = torch.utils.data.DataLoader(QRDataset(df_train['drawing'].values, df_train['word'].values, imgsize,transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),# transforms.RandomAffine(5, scale=[0.95, 1.05]),transforms.ToTensor(),# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])),batch_size=200, shuffle=True, num_workers=5,)val_loader = torch.utils.data.DataLoader(QRDataset(df_val['drawing'].values, df_val['word'].values, imgsize,transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ToTensor(),# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])),batch_size=200, shuffle=False, num_workers=5,)

载入模型

if modelname == 'resnet18':model = get_resnet18()elif modelname == 'resnet34':model = get_resnet34()elif modelname == 'resnet50':model = get_resnet50()elif modelname == 'resnet101':model = get_resnet101()else:model = timm.create_model(modelname, num_classes=340, pretrained=True, in_chans=3)

设置优化器等损失函数

# model = nn.DataParallel(model).cuda()# nvismodel.load_state_dict(torch.load('./resnet50_64_7_0.pt'))# model.load_state_dict(torch.load('./data/resnet18_64_16_110.pt'))model = model.cuda()loss_fn = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.01)# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 3, 5, 7, 8], gamma=0.1)scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) / 10, gamma=0.95)print('Train:', df_train.shape[0], 'Val', df_val.shape[0])print('Epoch/Batch\t\tTrain: loss/Top1/Top3\t\tTest: loss/Top1/Top3')

训练50次

for epoch in range(50):train_losss, train_acc1s, train_acc5s = [], [], []for i, data in enumerate(train_loader):scheduler.step()model = model.train()train_img, train_label = dataoptimizer.zero_grad()# TODO: data paraell# train_img = Variable(train_img).cuda(async=True)# train_label = Variable(train_label.view(-1)).cuda()train_img = Variable(train_img).cuda()train_label = Variable(train_label.view(-1)).cuda()# 加入mixupif np.random.randint(1, 10) >= 5:mixed_x, y_a, y_b, lam = mixup_data(train_img, train_label)output = model(mixed_x)train_loss = mixup_criterion(loss_fn, output, y_a, y_b, lam)else:output = model(train_img)train_loss = loss_fn(output, train_label)# output = model(train_img)# train_loss = loss_fn(output, train_label)train_loss.backward()optimizer.step()train_losss.append(train_loss.item())if i % 5 == 0:logging.info('{0}/{1}:\t{2}\t{3}.'.format(epoch, i, optimizer.param_groups[0]['lr'], train_losss[-1]))if i % int(10) == 0:val_losss, val_acc1s, val_acc5s = [], [], []with torch.no_grad():train_acc1, train_acc3 = accuracy(output, train_label, topk=(1, 3))train_acc1s.append(train_acc1.data.item())train_acc5s.append(train_acc3.item())for data in val_loader:val_images, val_labels = data# val_images = Variable(val_images).cuda(async=True)# val_labels = Variable(val_labels.view(-1)).cuda()val_images = Variable(val_images).cuda()val_labels = Variable(val_labels.view(-1)).cuda() output = model(val_images)val_loss = loss_fn(output, val_labels)val_acc1, val_acc3 = accuracy(output, val_labels, topk=(1, 3))val_losss.append(val_loss.item())val_acc1s.append(val_acc1.item())val_acc5s.append(val_acc3.item())logstr = '{0:2s}/{1:6s}\t\t{2:.4f}/{3:.4f}/{4:.4f}\t\t{5:.4f}/{6:.4f}/{7:.4f}'.format(str(epoch), str(i),np.mean(train_losss, 0), np.mean(train_acc1s, 0), np.mean(train_acc5s, 0),np.mean(val_losss, 0), np.mean(val_acc1s, 0), np.mean(val_acc5s, 0),)torch.save(model.state_dict(), './data/{0}_{1}_{2}_{3}.pt'.format(modelname, imgsize, epoch, i))print(logstr)

运行

# python 2_train.py 模型 数量 图片尺寸
# python 2_train.py resnet18 5000 64
if __name__ == "__main__":modelname = str(sys.argv[1]) # 模型名字dataset = str(sys.argv[2]) # 数据集规模imgsize = int(sys.argv[3]) # 图片的尺寸main()

kaggle竞赛 | 计算机视觉 | Doodle Recognition Challenge相关推荐

  1. Python视觉深度学习系列教程 第三卷 第9章 Kaggle竞赛:情绪识别

            第三卷 第九章 Kaggle竞赛:情绪识别 在本章中,我们将解决Kaggle的面部表情识别挑战.为了完成这项任务,我们将在训练数据上从头开始训练一个类似VGG的网络,同时考虑到我们的网 ...

  2. 图像分类:从13个Kaggle竞赛中总结技巧

    原文:https://neptune.ai/blog/image-classification-tips-and-tricks-from-13-kaggle-competitions 任何领域的成功都 ...

  3. 梳理十年Kaggle竞赛,看自然语言处理的变迁史

    自2010年创办以来,Kaggle作为著名的数据科学竞赛平台,一直都是机器学习领域发展趋势的风向标,许多重大突破都在该平台发生,数以千计的从业人员参与其中,每天在Kaggle论坛上都有着无数的讨论. ...

  4. Dataset之HiggsBoson:Higgs Boson(Kaggle竞赛)数据集的简介、下载、案例应用之详细攻略

    Dataset之HiggsBoson:Higgs Boson(Kaggle竞赛)数据集的简介.下载.案例应用之详细攻略 目录 Higgs Boson比赛简介 Higgs Boson数据集的下载 Hig ...

  5. 【数据竞赛】kaggle竞赛宝典-多分类相关指标优化​

    ↑↑↑关注后"星标"kaggle竞赛宝典 kaggle竞赛宝典 作者: 尘沙杰少.谢嘉嘉.DOTA.有夕 赛题理解,分析,规划之多分类相关指标优化 这是一个系列篇,后续我们会按照我 ...

  6. 从零开始拿到了Kaggle竞赛冠军

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 因 ...

  7. 关于Kaggle竞赛

    这次酝酿了很久想给大家讲一些关于Kaggle那点儿事,帮助对数据科学(Data Science)有兴趣的同学们更好的了解这个项目,最好能亲身参与进来,体会一下学校所学的东西和想要解决一个实际的问题所需 ...

  8. 试试kaggle竞赛:辨别猫狗

    在上一篇文章<深度学习中超大规模数据集的处理>中讲到采用HDF5文件处理大规模数据集.有朋友问到:HDF5文件是一次性读入内存中,然后通过键进行访问吗?答案当然不是,在前面的文章中也提到过 ...

  9. Kaggle竞赛实战-手写数字识别器实战

    算法实战--Kaggle竞赛实战 文章目录-微信公众号:AI研习图书馆 Kaggle竞赛实战系列 一.介绍 二.数据准备 2.1.数据加载 2.2.数据可视化 2.3.数据清洗 2.4.归一化 2.5 ...

最新文章

  1. 【HNOI2007】紧急疏散
  2. 客户端命令(docker)
  3. codeforces1167 E. Range Deleting(双指针)
  4. Java System类loadLibrary()方法与示例
  5. http://ilinuxkernel.com/?p=1328
  6. 随机森林原始论文_推荐一个神器画出论文中酷炫的机器学习图
  7. php openssl 处理pkcs8,【转载】OpenSSL命令---pkcs8
  8. linux升级libpng,在Linux中安装libpng-dev以解决pngquant构建失败的问题
  9. zemax设置 像方远心_ZEMAX|如何翻转整个光学系统
  10. 视频教程-Cisco CCNP路由实验专题讲解视频课程--路由重分发篇-思科认证
  11. SAP 库存报表查询
  12. 计算机wifi共享怎么设置,笔者教你win7如何设置wifi共享
  13. 利用计算机网络实现OA的功能,中小企业oa办公系统解决方案怎么做?
  14. css渐变背景色与切角
  15. php获取指定日期的节假日信息
  16. 组战队,赢iPhone啦!
  17. 【个人喜好诗词之一】雨巷
  18. e339 java_java-在Spring Mongo中从文档数组中删除项目
  19. 刘振飞BugFree管理系统的功能与使用(一)
  20. 揭秘:宜信科技中心如何支持公司史上最大规模全员远程办公|上篇

热门文章

  1. 《塔木德智慧全书》摘要(之二)
  2. VirtualBox 安装 Ubuntu16.04服务器版系统
  3. 卧槽,这竟然不是阿汤哥?这个「真的吓人」视频火爆全网
  4. GlusterFS技术概要分析(转自oschina)
  5. 动漫头发基础画法,正面动漫头发画法
  6. scanf 之 %2s 与 %2d
  7. MPAndroidChart使用(个人笔记)
  8. RabbitMQ-简单模式/工作模式(分发、应答、持久化、不公平分发、发布确认)
  9. 变量的作用功能、作用域和作用形态
  10. 学点编码知识又不会死:Unicode的流言终结者和编码大揭秘