每天给你送来NLP技术干货!


作者 | BetterBench

出品 | 对白的算法屋

编者寄语:

搞懂Dropout和R-Drop看这篇就够了。

大家好,我是对白。

既上一篇R-Drop:提升有监督任务性能最简单的方法,很多小伙伴们都私信我说,让我介绍一下Dropout和R-Drop之间的区别。相信大家看完这篇后,当面试官再问时,就可以轻松应对啦!

1、引言

在ML中存在两类严重的问题:过拟合和学习时间开销大

当过拟合时,得到的模型会在训练集上有非常好的表现,但是对新数据的预测结果会非常的不理想。为了解决过拟合问题,通常会采用训练多个模型来解决单模过拟合的问题。但又会带来时间开销大的问题。Dropout就很好的解决了这个问题,在单模内防止过拟合。对于时间开销大的地方是梯度下降,学习率衰减可以解决梯度下降中时间开销的问题。

Dropout是在训练过程中,随机地忽略部分神经元,即是在正向传播的过程中,这些被忽略的神经元对下游神经元的贡献暂时消失,在反向传播时,这些神经元也不会有任何权重的更新。

2、Dropout使用技巧

(1)经过验证,隐含节点Dropout率等于0.5的时候最佳,此时Dropout随机生成的网络结构最多。Dropout也可以用在输入层,作为一种添加噪声的方法。输入层设为更接近1时,使得输入变化不会太大,比如0.8。

(2)通常在网络中Dropout率为0.2~0.5。0.2是一个很好的起点,太低的概率产生的作用有限,太高的概率可能导致网络的训练不充分。

(3)当在较大的网络上使用Dropout时,可能会获得更好的表现,因为Dropout降低了模型训练过程中的干扰

(4)在输入层和隐藏层上使用Dropout。或者在网络的每一层都使用Dropout能有更佳的效果。

(5)使用较高的学习率,使用学习率衰减和设置较大的动量值,将学习率提高10~100倍,且使用0.9或0.99的动量值。

Keras中,momentum就是动量值
sgd = SGD(lr=0.1,momentum=0.8,decay=0.0,nesterov=False)

(6)限制网络权重的大小,打的学习率可能导致非常大的网络权重,对网络权重大小进行约束,例如大小为4或5的最大范数正则化(Max-norm Regularizationi)。

Keras中,通过指定Dense中的kernel_constrain=maxnorm(x)来限制网络权重

参考资料:Dropout: A Simple Way to Prevent Neural Networks from Overfitting

3、Dropout的拓展R-Dropout

3.1 简介

简单来说就是模型中加入dropout,在训练阶段的预测阶段,用同样的数据预测两次,去追求两次的结果尽可能接近,这种接近体现在损失函数上。虽然是同样的数据,但是因为模型中Dropout是随机丢弃神经元,会导致两次丢弃的神经元不一样,从而预测的结果也会不一样。R-Dropout思想就是去实现控制两次预测尽量保持一致,从而去优化模型。除了在NLP领域,其他的NLU、NLG、CV的分类等多种任务上都对R-Drop做了对比实验,大部分实验效果都称得上“明显提升”。

3.2 使用方法

和普通的Dropout方法不同,有封装的API可以一行代码使用。R-Dropout的使用需要自定义模型的输入和损失函数。举例如下,参考NLP 中Pytorch 实现R-Dropout

# define your task model, which outputs the classifier logits
model = TaskModel()def compute_kl_loss(self, p, q pad_mask=None):p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')# pad_mask is for seq-level tasksif pad_mask is not None:p_loss.masked_fill_(pad_mask, 0.)q_loss.masked_fill_(pad_mask, 0.)# You can choose whether to use function "sum" and "mean" depending on your taskp_loss = p_loss.sum()q_loss = q_loss.sum()loss = (p_loss + q_loss) / 2return loss# keep dropout and forward twice
logits = model(x)logits2 = model(x)# cross entropy loss for classifier
ce_loss = 0.5 * (cross_entropy_loss(logits, label) + cross_entropy_loss(logits2, label))kl_loss = compute_kl_loss(logits, logits2)# carefully choose hyper-parameters
loss = ce_loss + α * kl_loss

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

整理不易,还望给个在看!

算法工程师面试必考点:Dropout和R-Dropout的使用技巧相关推荐

  1. 【深度学习】算法工程师面试必考点:Dropout和R-Dropout的使用技巧

    作者 | BetterBench 出品 | 对白的算法屋 编者寄语: 搞懂Dropout和R-Drop看这篇就够了. 上一篇R-Drop:提升有监督任务性能最简单的方法,很多小伙伴们都私信我说,让我介 ...

  2. 算法工程师面试必考项:二叉树

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 1 二叉树简介 二叉树是最基本的数据结构之一,二叉树(Binary ...

  3. 算法工程师面试必考项——链表

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 1 知识点 1.1 什么是链表 提到链表,我们大家都不陌生,在平时 ...

  4. 算法工程师面试问题及相关资料集锦(附链接)

    来源:专知 本文约9800字,建议阅读20分钟. 本文为你介绍算法工程师面试问题及相关资料集锦,相当全面,值得收藏. 目录 算法工程师 Github.牛客网.知乎.个人博客.微信公众号.其他 机器学习 ...

  5. 拉勾网《32个Java面试必考点》学习笔记之一------Java职业发展路径

    本文为拉勾网<32个Java面试必考点>学习笔记.只是对视频内容进行简单整理,详细内容还请自行观看视频<32个Java面试必考点>.若本文侵犯了相关所有者的权益,请联系:txz ...

  6. 算法工程师面试问题及资料超详细合集(多家公司算法岗面经/代码实战/网课/竞赛等)

    这里是算法江湖,传授AI武林秘籍. 资源目录: 一.算法工程师 Github.牛客网.知乎.个人博客.微信公众号.其他 二.机器学习 面试问题.资料.代码实战 三.深度学习 面试.资料.代码实战Pyt ...

  7. 拉勾网《32个Java面试必考点》学习笔记之二------操作系统与网络知识

    本文为拉勾网<32个Java面试必考点>学习笔记.只是对视频内容进行简单整理,详细内容还请自行观看视频<32个Java面试必考点>.若本文侵犯了相关所有者的权益,请联系:txz ...

  8. 机器学习算法工程师面试知识点汇总

    机器学习算法工程师面试知识点汇总 机器学习 梯度下降 k-means 1 × 1卷积核 模型 SVM Bagging & Boosting 随机森林 激活函数 Sigmod tanh ReLU ...

  9. 深度学习算法工程师面试知识点总结(四)

    这是算法工程师面试知识点总结的第四篇,有兴趣的朋友可以看看前三篇的内容: 深度学习算法工程师面试知识点总结(一) 深度学习算法工程师面试知识点总结(二) 深度学习算法工程师面试知识点总结(三) 基于t ...

  10. 决战春招!算法工程师面试问题及资料超详细合集(算法岗面经/代码实战/网课/竞赛等)...

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! Awesome-AI-algorithm 目录 算法面试 1. Github 20 ...

最新文章

  1. Windows 7 部署(一):安装和部署简述
  2. 在使用import语句时
  3. CentOS系统Nginx配置免费https证书
  4. 阿里 mysql 架构_阿里java架构教你怎么用mysql怒怼面试官
  5. 关卡2-1 简单的模拟 1540 机器翻译
  6. 一个快速实现彩屏应用的跨平台快速原型开发工具平台,最重要的是还免费!8ms.xyz平台原以为是单片机版墨刀,今天上去玩了才知道平台厉害的很,基于WEB端免搭建开发环境,跑的还是C代码编译出来的程序!
  7. Java 函数式编程入门
  8. 网易2019实习生Java编程题
  9. HDU5794 - A Simple Chess
  10. linux网络编程应用于生活,[Linux网络编程]应用实例--获取网络时间
  11. html+cs入门实例,CS50 HTML和CSS基础(介绍最简单的HTML和CSS)
  12. 大型任务处理:为虚拟现实游戏施展混合现实魔法
  13. nhibernate GetType
  14. Xilisoft iPad Magic Platinum for Mac如何制作铃声?将联系人传输到计算机/设备?
  15. 微信公众号开发详细笔记
  16. 我建议你自己写一个疫情数据监控
  17. 新操作系统有哪些新功能?一起来看看吧!
  18. 大数据MBA 通过大数据实现与分析驱动企业决策与转型
  19. SEM数据分析之做好关键词报告
  20. python:select interpreter resulted in an error python.setINterpreter not found

热门文章

  1. Object.create()和深拷贝
  2. 四、bootstrap-Table
  3. Python虚拟环境的搭建
  4. 关于数据分析用到的统计学知识
  5. HTML入门之003
  6. 用HTML5 Canvas为Web图形创建特效
  7. 在Migration中操作新添加的字段
  8. Ext 介绍入门之 Templates(模板)
  9. vue-awesome-swiper 的安装和使用
  10. #JS 窗口resize避免触发多次