模型训练常用tricks
一、背景
- 背景:常见NLP模型训练tricks
- 目标群体:Trainer
- 技术应用场景:仅适用于深度学习(狭义)模型训练,未涉及机器学习模型
- 整体思路:按训练前、训练中、训练后三个阶段划分
二、模型训练常见tricks
假定数据、算力给定,如何提高模型的泛化性和鲁棒性?
2.1 模型训练前
1. 数据增强 (EDA,Easy Data Augmention)
- 定义:一个用于提高文本分类任务性能的简单数据增强技术
- 构成:同义词替换、随机插入、随机交换和随机删除
- 同义词替换(Synonyms Replace SR):随机从句子中抽取 n 个词 (抽取时不包括停用词),然后随机找出抽取这些词的同义词,用同义词将原词替换。例如将句子 “我比较喜欢猫” 替换成 “我有点喜好猫”。通过同义词替换后句子大概率还是会有相同的标签的;
- 随机插入(Randomly Insert RI):随机从句子中抽取 1 个词 (抽取时不包括停用词),然后随机选择一个该词的同义词,插入原来句子中的随机位置,重复这一过程 n 次。例如将句子 “我比较喜欢猫” 改为 “我比较喜欢猫有点”;
- 随机交换(Randomly Swap RS):在句子中,随机交换两个词的位置,重复这一过程 n 次。例如将句子 “我比较喜欢猫” 改为 “喜欢我猫比较”;
- 随机删除(Randomly Delete RD):对于句子的每一个单词,都有 p (=α) 的概率会被删除。例如将句子 “我比较喜欢猫” 改为 “我比较猫”。
- 参数:句子中单词修改比例 α,生成句子的个数 n_aug
- 收益:EDA 对于小数据集的结果提升明显;平均而言,仅使用≈50% 的训练集进行 EDA训练,便能达到使用全量数据进行正常训练相同的准度。
- 建议:控制样本数量,少量学习,不能扩充太多,因为EDA操作太过频繁可能会改变语义,从而降低模型性能。推荐的参数组合:
2. 数据增强 (mixup)
- 定义:mixup对两个样本-标签数据对按比例相加后生成新的样本-标签数据
- 构成:给定(x_i, y_i),(x_j, y_j),具体实现方法见附录一
x' = ∂x_i + (1 - ∂)x_j ,x为输入向量
y' = ∂y_i + (1 - ∂)y_j , y为标签的one-hot编码
∂ in [0, 1]是概率值,∂~Beta(a, a),即 ∂ 服从参数a的Beta分布
- 参数:a
- 收益:
- 泛化性:mixup在语音、文本、表格等数据上表现良好,平均而言,mixup能够取得1.2~1.5%的精确率收益;
- 鲁棒性:含噪声标签的数据和对抗样本攻击等场景皆适用;
- 建议:a = 0.2 ~ 2,大样本数据集下,a = 0.2 ~0.4;
- 进阶:使用mixup以后训练抖动会大一些,训练没有base稳定,改进:cutMix, manifold mixup,patchUp,puzzleMix, saliency Mix,fMix,co-Mix。
3. 长文本处理
- 常见方法:英文限定query max_sequence_length长度
- 截取,截取前510或后510或前128+后182(共记510)+ [cls]和[sep]补齐512;
- 滑动窗口(slide window),切分文本到若干重复段,分别作为输入,最后整合多个输出;
- 分段,段数=文本长度/510,对切分的各段做整合;常见整合方法:concat或mean_sqrt或max或加attention/lstm映射。
- 建议:
- 单句NLP任务,适当长度的截断文本的信息量足够涵盖文本语义;
2.2 模型训练中
1. focal loss
- 数据不平衡造成的模型性能问题,以二分类问题为例,损失函数可以写为
其中m为正样本个数,n为负样本个数,N为样本总数,m+n=N。
当样本分布失衡时,在损失函数L的分布也会发生倾斜,如m<<n时,负样本就会在损失函数占据主导地位。由于损失函数的倾斜,模型训练过程中会倾向于样本多的类别,造成模型对少样本类别的性能较差。
2. focal loss对交叉熵损失函数做优化,解决数据不平衡造成的模型性能问题 [3]:
可简写为:
- 原理:p_t反映与真实标签 y 的接近程度,越大说明越接近类别 y,即分类越准确(越易分);
- 对于分类准确(易分)的样本 p_t -> 1,(1 - p_t) ^ r 趋近于0;
- 对于分类不准确(难分)的样本 1-p_t -> 1,(1 - p_t) ^ r 趋近于1;
- 相比交叉熵损失,focal loss对于分类不准确(难分)的样本,损失没有改变;对于分类准确(易分)的样本,损失会变小。 整体而言,相当于增加分类不准确(难分)样本在损失函数中的权重,难分类样本占主导,因此学习过程更加聚焦难分类样本。
- 参数:r > 0
- 建议:r = 2 ~ 5
2. 对抗训练(Adversarial Training)
定义:原始输入样本 x 上加一个扰动 r_adv ,得到对抗样本后,用其进行训练。简单抽象为:
优化目标:
- 内部损失函数的最大化:找到worst-case的扰动(攻击),其中 L 为损失函数,r_adv 为扰动的范围空间;
- 外部经验风险的最小化:基于该攻击方式,找到最鲁棒的模型参数(防御),其中 D 是输入样本的分布。
扰动方法:
FGM(Fast Gradient Method)
PGD(Projected Gradient Descent)
参数:epsilon,alpha
收益:
- 平均而言,对抗训练能够取得2 ~ 3%的精确率收益,但通常提升不稳定且训练速度慢几倍;
推荐:epsilon=1 ~ 3(扰动因子),alpha=0.1 ~ 0.5(步长)
3. 对抗训练(Adversarial Training)
- 定义:类别/标签一致的前提下,使用训练好的模型权重来对模型进行初始化,然后只在新增数据集上训练,区别于加载开源预训练模型的checkpoint对全量数据做训练;
4. R-Drop
定义:同样的输入,同样的模型,分别走过两个 Dropout 得到的将是两个不同的分布,近似将这两个路径网络看作两个不同的模型网络
优化目标:
原理:假如目标类为第一个类别,预测结果是[0.5,0.2,0.3]或[0.5,0.3,0.2]
- 对交叉熵损失函数来说没区别;
- 对于KL散度项来说就不一样:每个类的得分都要参与计算,[0.5,0.2,0.3]或[0.5,0.3,0.2]有非零损失;
参数:a
收益:
- R-Drop帮助模型关注非目标类的稳定性,提高模型的鲁棒性,但会增加部分训练时间,不会增加推理时间;
推荐:a = 1 ~ 10
5. AMP混合精度
- 定义:混合精度通常指在训练期间在模型中同时使用16位和32位浮点类型,以使其运行更快并使用更少的内存;此外,也可在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的 。
- 支持:
- Tesla P100,Tesla V100、Tesla A100、GTX 20XX 和RTX 30XX等;
- TensorFlow ≥ 2.1,torch ≥ 1.7;
- 收益:损失精度可接受的范围内,AMP通常带来明显的内存节省和训练时长或推理时延收益
- 训练时显存占用减少≥20%,训练时长减少≥1%;
- 推理时显存占用减少≥25%,推理时延减少≥2%。
2.3 模型训练后
1. 模型平均(Stochastic Weight Averaging)
- 定义:SWA定义为对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能
- 原理:记训练过程第 i 个epoch的checkpoint为 w_i,
- 一般情况下我们会选择训练过程中最后的一个epoch的模型 w_n 或者在验证集上效果最好的一个模型 w_i_* 作为最终模型;
- SWA一般在最后采用较高的固定学习速率或者周期式学习速率额外训练一段时间,取多个checkpoints的平均值;
- 参数:SWA学习率lr_s
- 收益:SWA有助于模型收敛到loss平坦区域的中心,提升模型的泛化能力
- 推荐:lr_s = 0.05 ~ 0.1
2. 异构模型融合
- 原理:K折交叉,做模型预测的logits平均;
- 推荐:考虑集成模型的异构性;
3. 推理时内存按需增长
- 原理:加载tensorflow模型时,默认一次性加载打满显存,设置按需分配显存有效减少显存占用
physical_devices = tf.config.list_physical_devices('GPU')
try:tf.config.experimental.set_memory_growth(physical_devices[0], True)
except:# Invalid device or cannot modify virtual devices once initialized.pass and go on inference
三、总结
1. 技术经验
- 视不同业务场景,不同tricks通常组合使用:如比较通用的组合有mixup+amp+swa;
- tricks选定时,往往基于经验和穷举排列组合,启发性占主导;
- 模型的效果实际上仍是数据占主导,tricks只是辅助,验证trick是否生效的过程需要反复做正反向验证;
2. 亟待提升
- 思考多个trick叠加是否会产生负作用;
- “拍脑门”式trick往往取得很好的效果,是否能在“拍脑门”决策时与实际业务场景契合;
-
模型训练常用tricks相关推荐
- K210模型训练(物体分类)
目录 一.打开Maix IDE 的官网找到需要训练模型的平台Maix Hub 模型训练的分类: 二.如何使用Maix Hub模型训练平台 数据集的采集 三.训练模型后的文件分类 一.打开Maix ID ...
- ML:模型训练/模型评估中常用的两种方法代码实现(留一法一次性切分训练和K折交叉验证训练)
ML:模型训练/模型评估中常用的两种方法代码实现(留一法一次性切分训练和K折交叉验证训练) 目录 模型训练评估中常用的两种方法代码实现 T1.留一法一次性切分训练 T2.K折交叉验证训 模型训练评估中 ...
- 用什么tricks能让模型训练得更快?先了解下这个问题的第一性原理
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨Horace He 来源丨机器之心 编辑丨极市平台 导读 深度 ...
- YOLOv5的Tricks | 【Trick11】在线模型训练可视化工具wandb(Weights Biases)
如有错误,恳请指出. 与其说是yolov5的训练技巧,这篇博客更多的记录如何使用wandb这个在线模型训练可视化工具,感受到了yolov5作者对其的充分喜爱. 所以下面内容更多的记录下如何最简单的使用 ...
- 深度学习炼丹-超参数设定和模型训练
前言 网络层内在参数 使用 3x3 卷积 使用 cbr 组合 尝试不同的权重初始化方法 图片尺寸与数据增强 batch size 设定 背景知识 batch size 定义 选择合适大小的 batch ...
- RNN模型训练经验总结
文章目录 RNN模型训练经验总结 数据准备 "look at your data"!! 小步试错. 搭建模型 设置端到端的训练评估框架 forward propagation设置 ...
- DeepSpeed超大规模模型训练工具
DeepSpeed超大规模模型训练工具 2021年 2 月份发布了 DeepSpeed.这是一个开源深度学习训练优化库,包含的一个新的显存优化技术-- ZeRO(零冗余优化器),通过扩大规模,提升速度 ...
- yolov5模型训练
本文将介绍yolov5从环境搭建到模型训练的整个过程.最后训练识别哆啦A梦的模型. 1.anconda环境搭建 2.yolov5下载 3.素材整理 4.模型训练 5.效果预测 - Anconda环境搭 ...
- 【深度学习】Keras加载权重更新模型训练的教程(MobileNet)
[深度学习]Keras加载权重更新模型训练的教程(MobileNet) 文章目录 1 重新训练 2 keras常用模块的简单介绍 3 使用预训练模型提取特征(口罩检测) 4 总结 1 重新训练 重新建 ...
最新文章
- 自定义状态栏notification布局
- 我,谷歌AI编舞师,能根据音乐来10种freestyle,想看霹雳还是爵士芭蕾?
- mysql模式匹配用什么关键字_MYSQL模式匹配:REGEXP和like用法
- 新手学C语言会踩到什么样的坑?
- [渝粤教育] 中国地质大学 微积分(一) 复习题 (2)
- java 垃圾回收 null_java方法中把对象置null,到底能不能加速垃圾回收
- 聊天室程序python_Python聊天室程序(基础版)
- ctc网络结构中接口服务器信号交换的方式有哪些,通信工程复习资料
- svn 命令的使用(在linux下)
- Caffe学习:Forward and Backward
- Eclipse使用:Eclipse安装中文语言包
- 树莓派做无线打印服务器,用树莓派和 CUPS 打印服务器将你的打印机变成网络打印机...
- 怎么把图片压缩到30K以下?如何用手机快速压缩图片?
- java后台模板_Java服务端后台常用模板
- 微信号码检测是什么意思
- 博客美化总结(持续更新)
- 杂七杂八(9): IDEA初始化配置 插件收集
- Haskell编程指南 | Lynda教程 中文字幕
- 突破性进展什么意思_宣布突破性发展2011
- 天津最新建筑施工八大员之(安全员)考试真题及答案解析
热门文章
- 【01】一起学ASP之《ASP.NET MVC企业级实战》
- leetcode1049. 最后一块石头的重量 II(java)
- 阿里巴巴首席客户服务官戴珊:客服不再只是接电话
- mysql unknow column_Python/MySQL查询错误:`Unknown column`
- 扫雷游戏【敢看完就敢教会你】------- C语言
- Chrome 浏览器性能对比测试报告
- 一个人的旅行(Dijkstra)
- 简易html视频播放器
- markdown编辑器使用方法(对数学公式的编写方法做了全面详细的说明)
- ​推荐一个免费AI算力平台:OpenI​
- K210模型训练(物体分类)