dropout是Hinton老爷子提出来的一个用于训练的trick。在pytorch中,除了原始的用法以外,还有数据增强的用法(后文提到)。
首先要知道,dropout是专门用于训练的。在推理阶段,则需要把dropout关掉,而model.eval()就会做这个事情。
原文链接: https://arxiv.org/abs/1207.0580
通常意义的dropout解释为:在训练过程的前向传播中,让每个神经元以一定概率p处于不激活的状态。以达到减少过拟合的效果
然而,在pytorch中,dropout有另一个用法。如果把dropout加在输入张量上:

x = torch.randn(20, 16)
dropout = nn.Dropout(p=0.2)
x_drop = dropout(x)

那么,这个操作表示使x每个位置的元素都有一定概率归0,以此来模拟现实生活中的某些频道的数据缺失,以达到数据增强的目的。
在官方doc描述如下:

During training, randomly zeroes some of the elements of the input tensor with probability :attr:p using samples from a Bernoulli distribution. Each channel will be zeroed out independently on every forward call.

每个频道的数据缺失相互独立,以服从伯努利分布的概率值p来进行随机变为0。

pytorch中nn.Dropout的使用技巧相关推荐

  1. Pytorch中 nn.Transformer的使用详解与Transformer的黑盒讲解

    文章目录 本文内容 将Transformer看成黑盒 Transformer的推理过程 Transformer的训练过程 Pytorch中的nn.Transformer nn.Transformer简 ...

  2. PyTorch 中的 dropout Dropout2d Dropout3d

    文章目录 PyTorch 中的 dropout 1. [Pytoch 说明文档官网 PyTorch documentation 链接](https://pytorch.org/docs/stable/ ...

  3. pytorch中的dropout在drop什么?

    最近遇到了一个很基础的问题,就是pytorch中的dropout在面对一个n维的矩阵时,是会随机drop某一行.或者某一维上的一个向量,还是某一个元素呢?用试验稍微验证了下 import torch ...

  4. 什么是embedding(把物体编码为一个低维稠密向量),pytorch中nn.Embedding原理及使用

    文章目录 使embedding空前流行的word2vec 句子的表达 训练样本 损失函数 输入向量表达和输出向量表达vwv_{w}vw​ 从word2vec到item2vec 讨论环节 pytorch ...

  5. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  6. 对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用

    在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题. BCELoss是Binary CrossEntropyLoss的缩写,nn. ...

  7. Pytorch中nn.Conv2d数据计算模拟

    Pytorch中nn.Conv2d数据计算模拟 最近在研究dgcnn网络的源码,其网络架构部分使用的是nn.Conv2d模块.在Pytorch的官方文档中,nn.Conv2d的输入数据为(B, Cin ...

  8. Pytorch中nn.Module和nn.Sequencial的简单学习

    文章目录 前言 1.Python 类 2.nn.Module 和 nn.Sequential 2.1 nn.Module 2.1.1 torch.nn.Module类 2.1.2 nn.Sequent ...

  9. 总结PYTORCH中nn.lstm(自官方文档整理 包括参数、实例)

    参考pytorch官方文档 https://pytorch.org/docs/master/nn.html#torch.nn.LSTM 先上原图 | 这里是关键参数介绍 input_size:输入特征 ...

最新文章

  1. 如何在Python中获取字符串的子字符串?
  2. 【Socket网络编程】12. send()、recv()、sendto() 和 recvfrom() 函数解析
  3. html 属于mvvm框架,mvvm模式和mvc的区别是什么?
  4. libjpeg(1)
  5. Golang sort 排序
  6. Java获取网络IP
  7. python:数组和列表相互转化
  8. 令牌环网Token Ring协议
  9. iconfont 图标不生效
  10. python爬取淘宝数据魔方_淘宝数据魔方看人群情况
  11. 360 自动 html 极速模式,用Meta标签代码让360双核浏览器默认极速模式打开网站不是兼容模式(顺带解决很多兼容性问题)...
  12. JVM 为什么使用元空间替换了永久代?
  13. linux禅道在线迁移,禅道从windows迁移到linux
  14. Java web接入google身份验证器二次验证
  15. 小米node2红外_简单易懂,联动好用:小米 米家蓝牙温湿度计2 晒单
  16. 近期研究方向 (内部参考)
  17. 基于深度学习的花卉检测与识别系统(YOLOv5清新界面版,Python代码)
  18. 使用OpenSSL实现CA证书的搭建过程
  19. git WorkFlow规范
  20. STM32+IAP方案的实现,IAP实现原理(详细解决说明)。

热门文章

  1. 安装 suds 出现 问题
  2. 用tecplot导出圆柱绕流中的表面平均压力系数
  3. 朴素贝叶斯(naive bayes)分类
  4. python邮件定时发送短信_python实现自动定时给女朋友发手机短信,每天一个笑话!...
  5. 无监督对话数据清洗利器:Data Purification Framework
  6. EditPlus打开.tpl文件高亮显示代码
  7. 解决IDEA项目运行Tomcat时报错Cannot build artifact
  8. 计算机技术比武活动方案,计算机操作技能比赛方案
  9. 四川省凉山彝族自治州谷歌高清卫星地图下载
  10. 5折限时抢购移动开发者大会门票!