Datawhale 零基础入门CV赛事-Task4 模型训练与验证
文章目录
- 1.构造验证集
- 2.模型训练与验证
1.构造验证集
在机器学习模型(特别是深度学习模型)的训练过程中,模型是非常容易过拟合的。深度学习模型在不断的训练过程中训练误差会逐渐降低,但测试误差的走势则不一定。
在模型的训练过程中,模型只能利用训练数据来进行训练,模型并不能接触到测试集上的样本。因此模型如果将训练集学的过好,模型就会记住训练样本的细节,导致模型在测试集的泛化效果较差,这种现象称为过拟合(Overfitting)。与过拟合相对应的是欠拟合(Underfitting),即模型在训练集上的拟合效果较差。
导致模型过拟合的情况有很多种原因,其中最为常见的情况是模型复杂度(Model Complexity )太高,导致模型学习到了训练数据的方方面面,学习到了一些细枝末节的规律。
解决上述问题最好的解决方法:构建一个与测试集尽可能分布一致的样本集(可称为验证集),在训练过程中不断验证模型在验证集上的精度,并以此控制模型的训练。
在给定赛题后,赛题方会给定训练集和测试集两部分数据。参赛者需要在训练集上面构建模型,并在测试集上面验证模型的泛化能力。因此参赛者可以通过提交模型对测试集的预测结果,来验证自己模型的泛化能力。同时参赛方也会限制一些提交的次数限制,以此避免参赛选手“刷分”。
在一般情况下,参赛选手也可以自己在本地划分出一个验证集出来,进行本地验证。训练集、验证集和测试集分别有不同的作用:
- 训练集(Train Set):模型用于训练和调整模型参数;
- 验证集(Validation Set):用来验证模型精度和调整模型超参数;
- 测试集(Test Set):验证模型的泛化能力。
因为训练集和验证集是分开的,所以模型在验证集上面的精度在一定程度上可以反映模型的泛化能力。在划分验证集的时候,需要注意验证集的分布应该与测试集尽量保持一致,不然模型在验证集上的精度就失去了指导意义。
既然验证集这么重要,那么如何划分本地验证集呢。在一些比赛中,赛题方会给定验证集;如果赛题方没有给定验证集,那么参赛选手就需要从训练集中拆分一部分得到验证集。验证集的划分有如下几种方式:
留出法(Hold-Out)
直接将训练集划分成两部分,新的训练集和验证集。这种划分方式的优点是最为直接简单;缺点是只得到了一份验证集,有可能导致模型在验证集上过拟合。留出法应用场景是数据量比较大的情况。交叉验证法(Cross Validation,CV)
将训练集划分成K份,将其中的K-1份作为训练集,剩余的1份作为验证集,循环K训练。这种划分方式是所有的训练集都是验证集,最终模型验证精度是K份平均得到。这种方式的优点是验证集精度比较可靠,训练K次可以得到K个有多样性差异的模型;CV验证的缺点是需要训练K次,不适合数据量很大的情况。自助采样法(BootStrap)
通过有放回的采样方式得到新的训练集和验证集,每次的训练集和验证集都是有区别的。这种划分方式一般适用于数据量较小的情况。
在本次赛题中已经划分为验证集,因此选手可以直接使用训练集进行训练,并使用验证集进行验证精度(当然你也可以合并训练集和验证集,自行划分验证集)。
当然这些划分方法是从数据划分方式的角度来讲的,在现有的数据比赛中一般采用的划分方法是留出法和交叉验证法。如果数据量比较大,留出法还是比较合适的。当然任何的验证集的划分得到的验证集都是要保证训练集-验证集-测试集的分布是一致的,所以如果不管划分何种的划分方式都是需要注意的。
这里的分布一般指的是与标签相关的统计分布,比如在分类任务中“分布”指的是标签的类别分布,训练集-验证集-测试集的类别分布情况应该大体一致;如果标签是带有时序信息,则验证集和测试集的时间间隔应该保持一致。
2.模型训练与验证
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=10, shuffle=True, num_workers=10,
)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=10, shuffle=False, num_workers=10,
)model = SVHN_Model1()
criterion = nn.CrossEntropyLoss (size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
for epoch in range(20):print('Epoch: ', epoch)train(train_loader, model, criterion, optimizer, epoch)val_loss = validate(val_loader, model, criterion)# 记录下验证集精度if val_loss < best_loss:best_loss = val_losstorch.save(model.state_dict(), './model.pt')
def train(train_loader, model, criterion, optimizer, epoch):# 切换模型为训练模式model.train()for i, (input, target) in enumerate(train_loader):c0, c1, c2, c3, c4, c5 = model(data[0])loss = criterion(c0, data[1][:, 0]) + \criterion(c1, data[1][:, 1]) + \criterion(c2, data[1][:, 2]) + \criterion(c3, data[1][:, 3]) + \criterion(c4, data[1][:, 4]) + \criterion(c5, data[1][:, 5])loss /= 6optimizer.zero_grad()loss.backward()optimizer.step()
def validate(val_loader, model, criterion):# 切换模型为预测模型model.eval()val_loss = []# 不记录模型梯度信息with torch.no_grad():for i, (input, target) in enumerate(val_loader):c0, c1, c2, c3, c4, c5 = model(data[0])loss = criterion(c0, data[1][:, 0]) + \criterion(c1, data[1][:, 1]) + \criterion(c2, data[1][:, 2]) + \criterion(c3, data[1][:, 3]) + \criterion(c4, data[1][:, 4]) + \criterion(c5, data[1][:, 5])loss /= 6val_loss.append(loss.item())return np.mean(val_loss)
Datawhale 零基础入门CV赛事-Task4 模型训练与验证相关推荐
- 阿里云天池竞赛-零基础入门CV赛事-Task4 模型训练与验证
在上一章节我们构建了一个简单的CNN进行训练,并可视化了训练过程中的误差损失和第一个字符预测准确率,但这些还远远不够.一个成熟合格的深度学习训练流程至少具备以下功能: 在训练集上进行训练,并在验证集上 ...
- Datawhale 零基础入门CV赛事-Task5 模型集成
这里写目录标题 1.集成学习方法 2.深度学习中的集成学习 Dropout TTA Snapshot 1.集成学习方法 在机器学习中的集成学习可以在一定程度上提高预测精度,常见的集成学习方法有Stac ...
- 零基础入门语义分割-Task5 模型训练与验证
使用Pytorch来完成CNN的训练和验证过程,逻辑结构如下: 构造训练集和验证集: 每轮进行训练和验证,并根据最优验证集精度保存模型. train_loader = torch.utils.data ...
- Datawhale 零基础入门CV赛事-Task3 字符识别模型
文章目录 1.CNN实现 2.Pytorch实现CNN 3.使用ImangeNet预训练模型 1.CNN实现 CNN基础 2.Pytorch实现CNN 构建一个简单的CNN模型和训练过程 import ...
- Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增
文章目录 数据读取 图像读取 1.pillow 2.opencv 数据读取 数据扩增 数据读取 导入需要的包以及文件路径 import json, glob import numpy as np fr ...
- 零基础入门CV赛事,理论结合实践
Datawhale干货 作者:阿水,Datawhale成员 本次分享的背景是,Datawhle联合天池发布的学习赛:零基础入门CV赛事之街景字符识别.本文以该比赛为例,对计算机视觉赛事中,赛事理解和B ...
- 零基础入门CV赛事- 街景字符编码识别
零基础入门CV赛事- 街景字符编码识别 Task01 学习目标 数据介绍 Task01任务内容 数据读取 解题思路 学习目标 熟悉计算机视觉赛事 完成典型的字符识别问题 掌握CV领域赛事的编程和解题思 ...
- Datawhale零基础入门NLP赛事 - Task5 基于深度学习的文本分类2
在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习基于深度学习的文本分类. 学习目标 学 ...
- 零基础入门CV赛事-Task1 赛题理解
文章目录 赛题介绍 解题思路 1. 定长字符识别 2.不定长字符识别 3. 检测再识别 赛题介绍 赛题以街道字符为为赛题数据(比赛地址),该数据来自收集的SVHN街道字符,训练集数据包括3W张照片,验 ...
最新文章
- PTA基础编程题目集-6-12 判断奇偶性
- I/O多路复用是什么?(I/O multiplexing)
- java编译不了testpad,java – Gradle编译但不运行TestNG测试
- Apache Cassandra和低延迟应用程序
- linux下,.ko,.o , .so , .a ,.la
- 【javascript】ajax 基础(转)
- 各大杀软免费救急光盘合集——这个可以收藏备用了
- 基于matlab的小波去噪方法研究,基于matlab的小波去噪分析毕业论文.doc
- 观周教授新冠报告而作
- Java最新面试题大全
- 2018年中国互联网企业百强榜单揭晓
- opencv cvtcolor函数中断异常
- android lame,Android 采用Lame编码器编码mp3文件
- cad直线和圆弧倒角不相切_数控加工中心如何使用任意角度倒角C和倒圆角R功能的编程...
- hive学习第五章:查询
- SwiftUI学习笔记之@State, @Binding
- php文件上传漏洞攻击与防御
- 动手深度学习PyTorch(十二)word2vec
- 构建 Darknet 分类器 (Tiny Darknet) 训练数据集 (color recognition 颜色识别/color classification 颜色分类)
- 【​观察】“数字广东”背后的力量 腾讯云创新政务服务新模式
热门文章
- 《TensorFlow技术解析与实战》——1.2 什么是深度学习
- mysql主从数据库不同步的2种解决方法(转)
- Matlab 数字滤波器设计大报告(数字信号处理课程设计)附代码
- 华为机试HJ23:删除字符串中出现次数最少的字符
- matlab移相变压器,18脉移相变压器+三相不可控桥式整流的MATLAB仿真
- java九九成表发_用EXCEL可多种办法生成99乘法表
- 视频的播放的用例设计点
- java 月度相减_java根据日期获取月龄,按照减法原理,先day相减,不够向month借;然后month相减,不够向year借;最后year相减。...
- 单链表的创建、测长、打印、插入、删除、排序及逆置
- 亲密关系沟通-【表达情绪】如何说出感受却不伤人