深度学习训练tricks整理1

环境:pytorch1.4.0 + Ubuntu16.04

参考:

数据增强策略(一)​mp.weixin.qq.com

https://zhuanlan.zhihu.com/p/104992391​zhuanlan.zhihu.com

深度神经网络模型训练中的 tricks(原理与代码汇总)​mp.weixin.qq.com

一、data_augmentation

基本的数据增强调用torchvision.transforms库中的就可以了,我整理一下其他的。

参考:

Pytorch 中的数据增强方式最全解释​cloud.tencent.com

1.1 单图操作(图像遮挡)

1.Cutout

对CNN 第一层的输入使用剪切方块Mask

论文参考:

Improved Regularization of Convolutional Neural Networks with Cutout​arxiv.org

代码链接:

https://github.com/uoguelph-mlrg/Cutout​github.com

Cutout示意图

2.Random Erasing

用随机值或训练集的平均像素值替换图像的区域

论文参考:

https://arxiv.org/abs/1708.04896​arxiv.org

代码参考:

https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py​github.com

Random Erasing示意图

3.Hide-and-Seek

图像分割成一个由 SxS 图像补丁组成的网格,根据概率设置随机隐藏一些补丁,从而让模型学习整个对象的样子,而不是单独一块,比如不单独依赖动物的脸做识别。

论文参考:

Hide-and-Seek: Forcing a Network to be Meticulous for Weakly-supervised Object and Action Localization​arxiv.org

代码参考:

https://github.com/kkanshul/Hide-and-Seek/blob/master/hide_patch.py​github.com

Hide-and-Seek示意图

4.GridMask

将图像的区域隐藏在网格中,作用也是为了让模型学习对象的整个组成部分

论文参考:

https://arxiv.org/pdf/2001.04086.pdf​arxiv.org

代码参考:

https://github.com/Jia-Research-Lab/GridMask/blob/master/imagenet_grid/utils/grid.py​github.com

GridMask示意图

1.2 多图组合

1.Mixup

通过线性叠加两张图片生成新的图片,对应label也进行线性叠加用以训练

论文参考:

https://arxiv.org/abs/1710.09412​arxiv.org

理解与代码参考:

目标检测中图像增强,mixup 如何操作?​www.zhihu.com

Mixup 示意图

2.Cutmix

将另一个图像中的剪切部分粘贴到当前图像来进行图像增强,图像的剪切迫使模型学会根据大量的特征进行预测。

论文参考:

https://arxiv.org/abs/1905.04899​arxiv.org

代码参考:

https://github.com/clovaai/CutMix-PyTorch/blob/master/train.py​github.com

代码理解:

模型训练技巧--CutMix_Guo_Python的博客-CSDN博客_cutmix loss​blog.csdn.net

Cutmix示意图

3.Mosaic data augmentation(用于检测)

Cutmix中组合了两张图像,而在 Mosaic中使用四张训练图像按一定比例组合成一张图像,使模型学会在更小的范围内识别对象。其次还有助于显著减少对batch-size的需求。

代码参考:

https://zhuanlan.zhihu.com/p/163356279​zhuanlan.zhihu.com

Mosaic data augmentation示意图

二、Label Smoothing

  1. label smoothing

参考论文:

https://arxiv.org/pdf/1812.01187.pdf​arxiv.org

参考理解:

SoftMax原理介绍 及其 LabelSmooth优化​blog.csdn.net

标签平滑Label Smoothing​blog.csdn.net

https://zhuanlan.zhihu.com/p/148487894​zhuanlan.zhihu.com

在多分类训练任务中,输入图片经过神经网络的计算,会得到当前输入图片对应于各个类别的置信度分数,这些分数会被softmax进行归一化处理,最终得到当前输入图片属于每个类别的概率,最终在训练网络时,最小化预测概率和标签真实概率的交叉熵,从而得到最优的预测概率分布.

网络会驱使自身往正确标签和错误标签差值大的方向学习,在训练数据不足以表征所以的样本特征的情况下,这就会导致网络过拟合。label smoothing的提出就是为了解决上述问题。最早是在Inception v2中被提出,是一种正则化的策略。其通过"软化"传统的one-hot类型标签,使得在计算损失值时能够有效抑制过拟合现象。

代码:

class LabelSmoothCEloss(nn.Module):def __init__(self):super().__init__()def forward(self,  pred,  label,  smoothing=0.1):pred = F.softmax(pred,  dim=1)one_hot_label = F.one_hot(label, pred.size(1)).float()smoothed_one_hot_label = (1.0 - smoothing)  *  one_hot_label + smoothing / pred.size(1)loss = (-torch.log(pred))  *  smoothed_one_hot_labelloss = loss.sum(axis=1,  keepdim=False)loss = loss.mean()return loss
----------------------------------------------------------------------------------------------
调用时criterion = nn.CrossEntropyLoss()
改为criterion = LabelSmoothCELoss()

三、学习率调整

warm up最早来自于这篇文章:https://arxiv.org/pdf/1706.02677.pdf 。根据这篇文章,我们一般只在前5个epoch使用warm up。consine learning rate来自于这篇文章:https://arxiv.org/pdf/1812.01187.pdf 。通常情况下,把warm up和consine learning rate一起使用会达到更好的效果。 代码实现:

class WarmUpLR(_LRScheduler):"""warmup_training learning rate schedulerArgs:optimizer: optimzier(e.g. SGD)total_iters: totoal_iters of warmup phase"""def __init__(self, optimizer, total_iters, last_epoch=-1):self.total_iters = total_iterssuper().__init__(optimizer, last_epoch)def get_lr(self):"""we will use the first m batches, and set the learningrate to base_lr * m / total_iters"""return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]# MultiStepLR without warm up
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)# warm_up_with_multistep_lr
warm_up_with_multistep_lr = lambda epoch: epoch / args.warm_up_epochs if epoch <= args.warm_up_epochs else 0.1**len([m for m in args.milestones if m <= epoch])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)# warm_up_with_cosine_lr
warm_up_with_cosine_lr = lambda epoch: epoch / args.warm_up_epochs if epoch <= args.warm_up_epochs else 0.5 * ( math.cos((epoch - args.warm_up_epochs) /(args.epochs - args.warm_up_epochs) * math.pi) + 1)
scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambda=warm_up_with_cosine_lr)

四、蒸馏(distillation)

4.1 传统蒸馏

论文参考:

https://arxiv.org/pdf/1503.02531.pdf​arxiv.org

理解参考:

深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network),在线蒸馏​blog.csdn.net

知识蒸馏(Distilling Knowledge )的核心思想​blog.csdn.net

传统蒸馏示意图
训练的过程采用以下的步骤:
先用硬标签训练大型复杂网络(Teacher Net);
采用值大的T,经训练好的 TN 进行前向传播获得软标签;
分别采用值大的 T 和 T=1 两种情况,让小型网络(Student Net)获得两种不同的输出,加权计算两种交叉熵损失,训练SN;
采用训练好的 SN 预测类别。

2. 新的蒸馏方式:通道蒸馏

论文参考:

Channel Distillation: Channel-Wise Attention for Knowledge Distillation​arxiv.org

代码参考:

https://github.com/zhouzaida/channel-distillation​github.com

通道蒸馏示意图

训练深度学习_深度学习训练tricks整理1相关推荐

  1. 深度学习 训练吃显卡_深度学习训练如何更快些?GPU性能的I/O优化你试过吗?...

    原本,有多少人已经准备好最新显卡,足够的硬盘空间,甚至请好年假,只为十天后去那个仰慕已久的赛博朋克世界里体验一番-- 结果他们又发了一张「黄色背景图」,告诉大家要跳票--再一次-- 好吧,你有了大量闲 ...

  2. 贝叶斯深度神经网络_深度学习为何胜过贝叶斯神经网络

    贝叶斯深度神经网络 Recently I came across an interesting Paper named, "Deep Ensembles: A Loss Landscape ...

  3. 深度强化学习_深度学习理论与应用第8课 | 深度强化学习

    本文是博雅大数据学院"深度学习理论与应用课程"第八章的内容整理.我们将部分课程视频.课件和讲授稿进行发布.在线学习完整内容请登录www.cookdata.cn 深度强化学习是一种将 ...

  4. 深度强化学习和强化学习_深度强化学习:从哪里开始

    深度强化学习和强化学习 by Jannes Klaas 简尼斯·克拉斯(Jannes Klaas) 深度强化学习:从哪里开始 (Deep reinforcement learning: where t ...

  5. 保证为正数 深度学习_深度学习:让数学课堂学习真正发生

    在21世纪核心素养中,深度学习能力是公民必须具备的生活和工作能力,发展深度学习是当代学习科学的重要举措,是深度加工知识信息.提高学习效率的有效途径.深度学习也称深层学习,是美国学者Ference Ma ...

  6. 判断过拟合 深度学习_深度学习—过拟合问题

    1.过拟合问题 欠拟合:根本原因是特征维度过少,模型过于简单,导致拟合的函数无法满足训练集,误差较大:  解决方法:增加特征维度,增加训练数据:  过拟合:根本原因是特征维度过多,模型假设过于复杂,参 ...

  7. 元学习 迁移学习_元学习就是您所需要的

    元学习 迁移学习 Update: This post is part of a blog series on Meta-Learning that I'm working on. Check out ...

  8. 分类 迁移学习_迁移学习时间序列分类

    迁移学习时间序列分类 题目: Transfer learning for time series classification 作者: Hassan Ismail Fawaz, Germain For ...

  9. 度量学习 流形学习_流形学习2

    度量学习 流形学习 潜图深度学习 (Deep learning with latent graphs) TL;DR: Graph neural networks exploit relational ...

最新文章

  1. 自定义控件:SlidingMenu,侧边栏,侧滑菜单
  2. 实战SSM_O2O商铺_29【商品】商品添加之Service层的实现及重构
  3. C++之临时对象的构造与析构
  4. 基本数据类型float和double的区别
  5. 匈牙利算法java实现_匈牙利算法(Hungarian Algorithm)
  6. 使用 Drone 构建 Coding 项目
  7. switch字符串jdk_从JDK 12删除原始字符串文字
  8. #把函数当作参数传给另一个函数
  9. Python将类对象转换为json
  10. GLKVector3参考
  11. 【BZOJ3144】[Hnoi2013]切糕 最小割
  12. 操作系统概念:系统引导过程、引导程序、固件
  13. “知识地图”助员工岗位成才
  14. java矩形面积_Java编程求矩形的面积
  15. 花与剑尚未获取服务器信息,花与剑澄心无忆攻略,触发条件及完成方式介绍
  16. PHP 版 微信小程序商城 源码和搭建
  17. 【Spring Boot】使用mockMvc模拟请求以及遇到的问题
  18. 云原生之下,百度智能云Palo如何驰骋大数据疆场?
  19. Nginx Error: socket() [::]:80 failed (97: Address family not supported by protocol)
  20. 中国无烟尼古丁袋市场深度研究分析报告(2021)

热门文章

  1. LeetCode 938. 二叉搜索树的范围和(二叉树遍历+搜索剪枝)
  2. android c 11 编译,Android NDK r9b和编译C 11
  3. 大型网站电商网站架构案例和技术架构的示例
  4. centos 7 ssh 安装mysql,Linux服务器远程ssh为centos7安装MySQL
  5. php mqtt qos,Mqtt Qos 深度解读
  6. 单反录像按钮在哪_单反与微单到底哪不同
  7. log4net异步写入日志_微信支付万亿日志在Hermes中的实践
  8. 漆桂林 | 人工智能的浪潮中,知识图谱何去何从?
  9. 十分钟搞定特征值和特征向量
  10. 【分布式训练】单机多卡—PyTorch