4 模型训练与验证

构造验证集
在机器学习模型(特别是深度学习模型)的训练过程中,模型是非常容易过拟合的。深度学习模型在不断的训练过程中训练误差会逐渐降低,但测试误差的走势则不一定。

在模型的训练过程中,模型只能利用训练数据来进行训练,模型并不能接触到测试集上的样本。因此模型如果将训练集学的过好,模型就会记住训练样本的细节,导致模型在测试集的泛化效果较差,这种现象称为过拟合(Overfitting)。与过拟合相对应的是欠拟合(Underfitting),即模型在训练集上的拟合效果较差。
如图所示:随着模型复杂度和模型训练轮数的增加,CNN模型在训练集上的误差会降低,但在测试集上的误差会逐渐降低,然后逐渐升高,而我们为了追求的是模型在测试集上的精度越高越好。

导致模型过拟合的情况有很多种原因,其中最为常见的情况是模型复杂度(Model Complexity )太高,导致模型学习到了训练数据的方方面面,学习到了一些细枝末节的规律。

解决上述问题最好的解决方法:构建一个与测试集尽可能分布一致的样本集(可称为验证集),在训练过程中不断验证模型在验证集上的精度,并以此控制模型的训练。

在给定赛题后,赛题方会给定训练集和测试集两部分数据。参赛者需要在训练集上面构建模型,并在测试集上面验证模型的泛化能力。因此参赛者可以通过提交模型对测试集的预测结果,来验证自己模型的泛化能力。同时参赛方也会限制一些提交的次数限制,以此避免参赛选手“刷分”。

在一般情况下,参赛选手也可以自己在本地划分出一个验证集出来,进行本地验证。训练集、验证集和测试集分别有不同的作用:

  • 训练集(Train Set):模型用于训练和调整模型参数;
  • 验证集(Validation Set):用来验证模型精度和调整模型超参数;
  • 测试集(Test Set):验证模型的泛化能力。

因为训练集和验证集是分开的,所以模型在验证集上面的精度在一定程度上可以反映模型的泛化能力。在划分验证集的时候,需要注意验证集的分布应该与测试集尽量保持一致,不然模型在验证集上的精度就失去了指导意义。

既然验证集这么重要,那么如何划分本地验证集呢。在一些比赛中,赛题方会给定验证集;如果赛题方没有给定验证集,那么参赛选手就需要从训练集中拆分一部分得到验证集。验证集的划分有如下几种方式:

  • 留出法(Hold-Out)

直接将训练集划分成两部分,新的训练集和验证集。这种划分方式的优点是最为直接简单;缺点是只得到了一份验证集,有可能导致模型在验证集上过拟合。留出法应用场景是数据量比较大的情况。

  • 交叉验证法(Cross Validation,CV)

将训练集划分成K份,将其中的K-1份作为训练集,剩余的1份作为验证集,循环K训练。这种划分方式是所有的训练集都是验证集,最终模型验证精度是K份平均得到。这种方式的优点是验证集精度比较可靠,训练K次可以得到K个有多样性差异的模型;CV验证的缺点是需要训练K次,不适合数据量很大的情况。

  • 自助采样法(BootStrap)

通过有放回的采样方式得到新的训练集和验证集,每次的训练集和验证集都是有区别的。这种划分方式一般适用于数据量较小的情况。

在本次赛题中已经划分为验证集,因此选手可以直接使用训练集进行训练,并使用验证集进行验证精度(当然你也可以合并训练集和验证集,自行划分验证集)。

当然这些划分方法是从数据划分方式的角度来讲的,在现有的数据比赛中一般采用的划分方法是留出法和交叉验证法。如果数据量比较大,留出法还是比较合适的。当然任何的验证集的划分得到的验证集都是要保证训练集-验证集-测试集的分布是一致的,所以如果不管划分何种的划分方式都是需要注意的。

这里的分布一般指的是与标签相关的统计分布,比如在分类任务中“分布”指的是标签的类别分布,训练集-验证集-测试集的类别分布情况应该大体一致;如果标签是带有时序信息,则验证集和测试集的时间间隔应该保持一致。

4.3 模型训练与验证
在本节我们目标使用Pytorch来完成CNN的训练和验证过程,CNN网络结构与之前的章节中保持一致。我们需要完成的逻辑结构如下:

  • 构造训练集和验证集;
  • 每轮进行训练和验证,并根据最优验证集精度保存模型。
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=10, shuffle=True, num_workers=10,
)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=10, shuffle=False, num_workers=10,
)model = SVHN_Model1()
criterion = nn.CrossEntropyLoss (size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
for epoch in range(20):print('Epoch: ', epoch)train(train_loader, model, criterion, optimizer, epoch)val_loss = validate(val_loader, model, criterion)# 记录下验证集精度if val_loss < best_loss:best_loss = val_losstorch.save(model.state_dict(), './model.pt')

其中每个Epoch的训练代码如下:

def train(train_loader, model, criterion, optimizer, epoch):# 切换模型为训练模式model.train()for i, (input, target) in enumerate(train_loader):c0, c1, c2, c3, c4, c5 = model(data[0])loss = criterion(c0, data[1][:, 0]) + \criterion(c1, data[1][:, 1]) + \criterion(c2, data[1][:, 2]) + \criterion(c3, data[1][:, 3]) + \criterion(c4, data[1][:, 4]) + \criterion(c5, data[1][:, 5])loss /= 6optimizer.zero_grad()loss.backward()optimizer.step()

其中每个Epoch的验证代码如下:

def validate(val_loader, model, criterion):# 切换模型为预测模型model.eval()val_loss = []# 不记录模型梯度信息with torch.no_grad():for i, (input, target) in enumerate(val_loader):c0, c1, c2, c3, c4, c5 = model(data[0])loss = criterion(c0, data[1][:, 0]) + \criterion(c1, data[1][:, 1]) + \criterion(c2, data[1][:, 2]) + \criterion(c3, data[1][:, 3]) + \criterion(c4, data[1][:, 4]) + \criterion(c5, data[1][:, 5])loss /= 6val_loss.append(loss.item())return np.mean(val_loss)

4.4 模型保存与加载
在Pytorch中模型的保存和加载非常简单,比较常见的做法是保存和加载模型参数:

torch.save(model_object.state_dict(), 'model.pt')
model.load_state_dict(torch.load(' model.pt'))
  • 4.5 模型调参流程

深度学习原理少但实践性非常强,基本上很多的模型的验证只能通过训练来完成。同时深度学习有众多的网络结构和超参数,因此需要反复尝试。训练深度学习模型需要GPU的硬件支持,也需要较多的训练时间,如何有效的训练深度学习模型逐渐成为了一门学问。

深度学习有众多的训练技巧,比较推荐的阅读链接有:

  • http://lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html
  • http://karpathy.github.io/2019/04/25/recipe/

本节挑选了常见的一些技巧来讲解,并针对本次赛题进行具体分析。与传统的机器学习模型不同,深度学习模型的精度与模型的复杂度、数据量、正则化、数据扩增等因素直接相关。所以当深度学习模型处于不同的阶段(欠拟合、过拟合和完美拟合)的情况下,大家可以知道可以什么角度来继续优化模型。

在参加本次比赛的过程中,我建议大家以如下逻辑完成:

  • 1.初步构建简单的CNN模型,不用特别复杂,跑通训练、验证和预测的流程;
  • 2.简单CNN模型的损失会比较大,尝试增加模型复杂度,并观察验证集精度;
  • 3.在增加模型复杂度的同时增加数据扩增方法,直至验证集精度不变。

天池-街景字符编码识别4-模型训练与验证相关推荐

  1. 天池-街景字符编码识别5-模型训练与验证

    模型集成 包括:集成学习方法.深度学习中的集成学习和结果后处理思路. 集成学习方法 在机器学习中的集成学习可以在一定程度上提高预测精度,常见的集成学习方法有Stacking.Bagging和Boost ...

  2. 天池-街景字符编码识别2-数据读取与数据扩增

    本此使用[定长字符识别]思路来构建模型 赛题地址 零基础入门CV赛事- 街景字符编码识别 关于更详细的数据预处理可=可以参考我的另一篇博文: 卷积神经网络性能优化(提高准确率) 2 数据读取与数据扩增 ...

  3. 天池-街景字符编码识别1-赛题理解

    赛题地址 零基础入门CV赛事- 街景字符编码识别 前期环境 运行环境及安装 运行环境 python3.7 pytorch1.3.1 有GPU 首先在Anaconda中创建一个专门用于本次练习赛的虚拟环 ...

  4. 街景字符编码识别之模型集成

    点赞再看,养成习惯!觉得不过瘾的童鞋,欢迎关注公众号<机器学习算法工程师>,有非常多大神的干货文章可供学习噢- 目录 前言 正文 集成学习方法 深度学习中的集成学习 结果后处理 结语 参考 ...

  5. 天池学习赛——街景字符编码识别(得分上0.93)

    项目代码已上传至github需要的可以自行下载 目录 1 比赛介绍 2 解题思路 3 比赛数据集 4 模型训练 5 更改detect.py文件 6 上传文件 1 比赛介绍 项目链接:零基础入门CV - ...

  6. 阿里天池比赛——街景字符编码识别

    文章目录 前言 一.街景字符编码识别 1. 目标 2. 数据集 3. 指标 总结 前言 之前参加阿里天池比赛,好久了,一直没有时间整理,现在临近毕业,趁论文外审期间,赶紧把东西整理了,5月底学校就要让 ...

  7. 天池大赛:街景字符编码识别——Part2:数据读取与数据扩增

    街景字符编码识别 更新流程↓ Task01:赛题理解 Task02:数据读取与数据扩增 Task03:字符识别模型 Task04:模型训练与验证 Task05:模型集成 底到镜一 比赛链接 Part2 ...

  8. 零基础入门CV赛事- 街景字符编码识别

    零基础入门CV赛事- 街景字符编码识别 Task01 学习目标 数据介绍 Task01任务内容 数据读取 解题思路 学习目标 熟悉计算机视觉赛事 完成典型的字符识别问题 掌握CV领域赛事的编程和解题思 ...

  9. 零基础入门CV赛事—街景字符编码识别—task2数据读取与扩增

    数据读取与扩增 上节学习了街景字符编码识别的解题思路,让我们对本赛题有了基本的idea,这节在定长字符编码的思路基础上学习读取数据和数据扩增. 图像数据读取 由于赛题数据是图像数据,赛题的任务是识别图 ...

最新文章

  1. 2022-2028年中国超声波探伤仪行业市场现状调研及发展前景分析报告
  2. python编程 语言-Python——最美丽的编程语言
  3. Android——Handler总结
  4. mpls标签分配原理——Vecloud
  5. 草履虫纳米机器人_激光驱动的机器人大军!Nature:机器人尺寸小于 0.1 毫米,4 英寸晶圆可容纳 100 万个...
  6. ITK:从文件读取转换
  7. python3读写excel文件_Python读写/追加excel文件Demo
  8. linux内核博大精深,Linux Kernel里的cpu_to_le32是干啥的?
  9. 【MyBatis框架】mapper配置文件-关于动态sql
  10. django 1.8 官方文档翻译: 9-1-1 国际化和本地化
  11. 用VC打开位图程序[转]
  12. go - struct
  13. 多区域OSPF基本配置
  14. 信息学奥赛GoC编程测试题题库
  15. PRML 1.5 决策论
  16. 手机话费充值页面HTMLcss3+html5模板
  17. Centos 7 开机一直转圈 提示failed to load SELinux policy freezing的解决方法
  18. Jmockit使用指南
  19. 如何调试上位机软件与串口进行通信
  20. 数据可视化:FineReport公司财务报表制作

热门文章

  1. 【Java】判断学生成绩等级
  2. 【Liunx】Linux 系统目录结构
  3. 【安卓开发】Webview简单使用
  4. C#LeetCode刷题之#112-路径总和​​​​​​​(Path Sum)
  5. 创新品牌体验团队_如何推动软件团队创新
  6. 101_Power Pivot DAX 累计至今,历史累计至今
  7. java lambda表达式详解_Java8新特性:Lambda表达式详解
  8. 数据结构(Java)-哈希表
  9. 记一次面试过程中的Python编程题
  10. Python教程:网络爬虫快速入门实战解析