作者 | BetterBench

出品 | 对白的算法屋

编者寄语:

搞懂Dropout和R-Drop看这篇就够了。

上一篇R-Drop:提升有监督任务性能最简单的方法,很多小伙伴们都私信我说,让我介绍一下Dropout和R-Drop之间的区别。相信大家看完这篇后,当面试官再问时,就可以轻松应对啦!

1、引言

在ML中存在两类严重的问题:过拟合和学习时间开销大

当过拟合时,得到的模型会在训练集上有非常好的表现,但是对新数据的预测结果会非常的不理想。为了解决过拟合问题,通常会采用训练多个模型来解决单模过拟合的问题。但又会带来时间开销大的问题。Dropout就很好的解决了这个问题,在单模内防止过拟合。对于时间开销大的地方是梯度下降,学习率衰减可以解决梯度下降中时间开销的问题。

Dropout是在训练过程中,随机地忽略部分神经元,即是在正向传播的过程中,这些被忽略的神经元对下游神经元的贡献暂时消失,在反向传播时,这些神经元也不会有任何权重的更新。

2、Dropout使用技巧

(1)经过验证,隐含节点Dropout率等于0.5的时候最佳,此时Dropout随机生成的网络结构最多。Dropout也可以用在输入层,作为一种添加噪声的方法。输入层设为更接近1时,使得输入变化不会太大,比如0.8。

(2)通常在网络中Dropout率为0.2~0.5。0.2是一个很好的起点,太低的概率产生的作用有限,太高的概率可能导致网络的训练不充分。

(3)当在较大的网络上使用Dropout时,可能会获得更好的表现,因为Dropout降低了模型训练过程中的干扰

(4)在输入层和隐藏层上使用Dropout。或者在网络的每一层都使用Dropout能有更佳的效果。

(5)使用较高的学习率,使用学习率衰减和设置较大的动量值,将学习率提高10~100倍,且使用0.9或0.99的动量值。

Keras中,momentum就是动量值
sgd = SGD(lr=0.1,momentum=0.8,decay=0.0,nesterov=False)

(6)限制网络权重的大小,打的学习率可能导致非常大的网络权重,对网络权重大小进行约束,例如大小为4或5的最大范数正则化(Max-norm Regularizationi)。

Keras中,通过指定Dense中的kernel_constrain=maxnorm(x)来限制网络权重

参考资料:Dropout: A Simple Way to Prevent Neural Networks from Overfitting

3、Dropout的拓展R-Dropout

3.1 简介

简单来说就是模型中加入dropout,在训练阶段的预测阶段,用同样的数据预测两次,去追求两次的结果尽可能接近,这种接近体现在损失函数上。虽然是同样的数据,但是因为模型中Dropout是随机丢弃神经元,会导致两次丢弃的神经元不一样,从而预测的结果也会不一样。R-Dropout思想就是去实现控制两次预测尽量保持一致,从而去优化模型。除了在NLP领域,其他的NLU、NLG、CV的分类等多种任务上都对R-Drop做了对比实验,大部分实验效果都称得上“明显提升”。

3.2 使用方法

和普通的Dropout方法不同,有封装的API可以一行代码使用。R-Dropout的使用需要自定义模型的输入和损失函数。举例如下,参考NLP 中Pytorch 实现R-Dropout

# define your task model, which outputs the classifier logits
model = TaskModel()def compute_kl_loss(self, p, q pad_mask=None):p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')# pad_mask is for seq-level tasksif pad_mask is not None:p_loss.masked_fill_(pad_mask, 0.)q_loss.masked_fill_(pad_mask, 0.)# You can choose whether to use function "sum" and "mean" depending on your taskp_loss = p_loss.sum()q_loss = q_loss.sum()loss = (p_loss + q_loss) / 2return loss# keep dropout and forward twice
logits = model(x)logits2 = model(x)# cross entropy loss for classifier
ce_loss = 0.5 * (cross_entropy_loss(logits, label) + cross_entropy_loss(logits2, label))kl_loss = compute_kl_loss(logits, logits2)# carefully choose hyper-parameters
loss = ce_loss + α * kl_loss

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载黄海广老师《机器学习课程》视频课黄海广老师《机器学习课程》711页完整版课件

本站qq群851320808,加入微信群请扫码:

【深度学习】算法工程师面试必考点:Dropout和R-Dropout的使用技巧相关推荐

  1. 深度学习算法工程师面试知识点总结(四)

    这是算法工程师面试知识点总结的第四篇,有兴趣的朋友可以看看前三篇的内容: 深度学习算法工程师面试知识点总结(一) 深度学习算法工程师面试知识点总结(二) 深度学习算法工程师面试知识点总结(三) 基于t ...

  2. 深度学习算法工程师面试知识点总结(一)

    深度学习算法工程师岗位需要具有的技术栈初步总结如下: 这个总结并不是很完整,这个方向所需要的知识体系非常的庞大,例如高等数学基础.线性代数.概率论的基础,这对很多的同学来说是一个比较大的挑战.还需要针 ...

  3. 岗位推荐 | 百度招聘计算机视觉、深度学习算法工程师(可实习)

    PaperWeekly 致力于推荐最棒的工作机会,精准地为其找到最佳求职者,做连接优质企业和优质人才的桥梁. 如果你需要我们来帮助你推广实习机会或全职岗位,请添加微信号「pwbot02」. 工作地点: ...

  4. 刚发布!开发者调查报告:机器学习/深度学习算法工程师急缺

    近日,CSDN发布了<2019-2020中国开发者调查报告>,本报告从2004年开始针对一年一度的CSDN开发者大调查数据分析结果形成,是迄今为止覆盖国内各类开发者人群数量最多.辐射地域. ...

  5. 推荐 | 一个统计硕士的深度学习算法工程师的成长之路

    公众号推荐 推荐人/文文 俗话说,一个人走得快,但一群人可以走的远.在数据科学和机器学习的道路上,相信每个人都不是闭门造车的人.技术学习除了在个人努力外,交流和分享也是很重要的一部分. 今天给大家推荐 ...

  6. 【杂谈】什么是我心目中深度学习算法工程师的标准

    有三AI平台只专心做原创输出很少扯淡也不蹭热点,不过最近询问的朋友多了,不得不统一写篇文章来回答一下这个大家都很关心的问题,当然,这仅仅是个人观点. 作者&编辑 | 言有三 目前利用深度学习这 ...

  7. 「杂谈」什么是我心目中深度学习算法工程师的标准

    http://blog.sina.com.cn/s/blog_cfa68e330102zoco.html 有三AI平台只专心做原创输出很少扯淡也不蹭热点,不过最近询问的朋友多了,不得不统一写篇文章来回 ...

  8. 网易北京研发中心-网易传媒部门深度学习算法实习生面试总结

    "微信公众号" 2018年6月13日,网易北京研发中心-网易传媒部门深度学习算法实习生面试总结 1. 问了一下能实习多久,以及实习开始的时间. 2. 问了一下目前去除水印的一些工作 ...

  9. python算法工程师招聘_经验 | 我心目中招聘深度学习算法工程师的标准

    原标题:经验 | 我心目中招聘深度学习算法工程师的标准 本文转载自有三AI 目前利用深度学习这个工具可以做很多事情,各大领域(图像,语音,NLP等),各大行业(娱乐,金融,医疗等)这几年都被玩的风生水 ...

最新文章

  1. Windows 服务入门指南
  2. 分布式系统的架构思路
  3. springboot2 虚拟路径设置_转载—springboot配置虚拟路径以外部访问
  4. 深入细枝末节,Python的字体反爬虫到底怎么一回事
  5. 【LeetCode】剑指 Offer 38. 字符串的排列
  6. 三 转码需求(智源GM813X多国语言OSD开发)
  7. IntelliJ IDEA使用教程(新手入门--持续更新)
  8. 驻马店计算机招聘信息网,2017河南职称计算机考试报名:驻马店职称计算机报名入口...
  9. 相机标定matlab版本,相机标定 matlab
  10. 【codeforces】【比赛题解】#960 CF Round #474 (Div. 1 + Div. 2, combined)
  11. 文件夹快速隐藏,文件夹选项中勾选隐藏目录依旧不能使其显示
  12. 手把手教你如何通过大厂面试
  13. Keras中Conv1D和Conv2D的区别
  14. 剖析 Android ART Runtime (2) – dex2oat
  15. lcs算法c语言代码,LCS算法
  16. python爬取微博数据存入数据库_Python爬取微博数据并存入mysql,excel中
  17. QT常用布局layout快速入门
  18. IP地址与int整数的转换
  19. 解决ubuntu下crossover中qq中文字体乱码问题
  20. 三态输出门实验报告注意事项_冬季行车注意事项 请广大驾驶员注意出行安全...

热门文章

  1. BZOJ3916 [Baltic2014]friends
  2. 解释一下python中的//,%和**运算符
  3. 洛谷 [SDOI2015]约数个数和 解题报告
  4. hdu 527 Necklace
  5. spring3创建RESTFul Web Service
  6. IOS atomic与nonatomic,assign,copy与retain的定义和区别
  7. Android 开机自动启动服务
  8. C# 多线程,解决处理大数据时窗体(不能拖动等)假死现象
  9. sql server 关联left join条件on和where条件的区别
  10. 如何测试机房的速度和带宽?