1、Dropout概念

Dropout:随机失活,随机是dropout probability,失活是指weight=0。

通过下面的示例图理解随机失活:

左边的图是正常的全连接网络,右边的图是使用dropout的神经网络,dropout是以一定的概率让一部分的神经元失活,这可以让神经元学习到更鲁棒的特征,减轻过度的依赖性,从而缓解过拟合,降低方差达到正则化效果,这种操作可以使模型更多样化,因为每一次前向传播神经元都会随机失活,每次训练得到的模型都是不一样的。

为什么dropout能够达到很好的正则化效果呢?

  1. 从特征依赖性角度
    假设一个神经元会接收上一层的五个神经元的输出值,可以理解为上一层的特征,如果当前神经元特别依赖于某一个特征。如果加了dropout之后,当前神经元就不知道上一层所有神经元中哪些神经元会出现,这样当前神经元就不会过度依赖上一层神经元中的某些神经元。

数据尺度变化
测试时,所有权重乘以1-drop_prob,例如drop_prob=0.3,1-drop_prob=0.7;

1.2 nn.Dropout

功能:Dropout层;
参数

  • P:被舍弃概率,失活概率;
    注意:dropout层通常放在需要dropout的网络层的前一层;
torch.nn.Dropout(p=0.5,inplace=False)

下面通过代码分析Dropout层的作用:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from toolss.common_tools import set_seed
from torch.utils.tensorboard import SummaryWriterset_seed(1)  # 设置随机种子
n_hidden = 200
max_iter = 2000
disp_interval = 400
lr_init = 0.01# ============================ step 1/5 数据 ============================
def gen_data(num_data=10, x_range=(-1, 1)):w = 1.5train_x = torch.linspace(*x_range, num_data).unsqueeze_(1)train_y = w*train_x + torch.normal(0, 0.5, size=train_x.size())test_x = torch.linspace(*x_range, num_data).unsqueeze_(1)test_y = w*test_x + torch.normal(0, 0.3, size=test_x.size())return train_x, train_y, test_x, test_ytrain_x, train_y, test_x, test_y = gen_data(x_range=(-1, 1))# ============================ step 2/5 模型 ============================
class MLP(nn.Module):def __init__(self, neural_num, d_prob=0.5):super(MLP, self).__init__()self.linears = nn.Sequential(nn.Linear(1, neural_num),nn.ReLU(inplace=True),nn.Dropout(d_prob),nn.Linear(neural_num, neural_num),nn.ReLU(inplace=True),nn.Dropout(d_prob),nn.Linear(neural_num, neural_num),nn.ReLU(inplace=True),nn.Dropout(d_prob),nn.Linear(neural_num, 1),)def forward(self, x):return self.linears(x)net_prob_0 = MLP(neural_num=n_hidden, d_prob=0.)
net_prob_05 = MLP(neural_num=n_hidden, d_prob=0.5)# ============================ step 3/5 优化器 ============================
optim_normal = torch.optim.SGD(net_prob_0.parameters(), lr=lr_init, momentum=0.9)
optim_reglar = torch.optim.SGD(net_prob_05.parameters(), lr=lr_init, momentum=0.9)# ============================ step 4/5 损失函数 ============================
loss_func = torch.nn.MSELoss()# ============================ step 5/5 迭代训练 ============================writer = SummaryWriter(comment='_test_tensorboard', filename_suffix="12345678")
for epoch in range(max_iter):pred_normal, pred_wdecay = net_prob_0(train_x), net_prob_05(train_x)loss_normal, loss_wdecay = loss_func(pred_normal, train_y), loss_func(pred_wdecay, train_y)optim_normal.zero_grad()optim_reglar.zero_grad()loss_normal.backward()loss_wdecay.backward()optim_normal.step()optim_reglar.step()if (epoch+1) % disp_interval == 0:net_prob_0.eval()net_prob_05.eval()# 可视化for name, layer in net_prob_0.named_parameters():writer.add_histogram(name + '_grad_normal', layer.grad, epoch)writer.add_histogram(name + '_data_normal', layer, epoch)for name, layer in net_prob_05.named_parameters():writer.add_histogram(name + '_grad_regularization', layer.grad, epoch)writer.add_histogram(name + '_data_regularization', layer, epoch)test_pred_prob_0, test_pred_prob_05 = net_prob_0(test_x), net_prob_05(test_x)# 绘图plt.scatter(train_x.data.numpy(), train_y.data.numpy(), c='blue', s=50, alpha=0.3, label='train')plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='red', s=50, alpha=0.3, label='test')plt.plot(test_x.data.numpy(), test_pred_prob_0.data.numpy(), 'r-', lw=3, label='d_prob_0')plt.plot(test_x.data.numpy(), test_pred_prob_05.data.numpy(), 'b--', lw=3, label='d_prob_05')plt.text(-0.25, -1.5, 'd_prob_0 loss={:.8f}'.format(loss_normal.item()), fontdict={'size': 15, 'color': 'red'})plt.text(-0.25, -2, 'd_prob_05 loss={:.6f}'.format(loss_wdecay.item()), fontdict={'size': 15, 'color': 'red'})plt.ylim((-2.5, 2.5))plt.legend(loc='upper left')plt.title("Epoch: {}".format(epoch+1))plt.show()plt.close()net_prob_0.train()net_prob_05.train()

代码的图输出如下所示:

pytorch —— 正则化之Dropout相关推荐

  1. PyTorch框架学习十六——正则化与Dropout

    PyTorch框架学习十六--正则化与Dropout 一.泛化误差 二.L2正则化与权值衰减 三.正则化之Dropout 补充: 这次笔记主要关注防止模型过拟合的两种方法:正则化与Dropout. 一 ...

  2. 深度学习笔记5:正则化与dropout

    出处:数据科学家养成记 深度学习笔记5:正则化与dropout 在笔记 4 中,笔者详细阐述了机器学习中利用正则化防止过拟合的基本方法,对 L1 和 L2 范数进行了通俗的解释.为了防止深度神经网络出 ...

  3. 吴恩达作业5:正则化和dropout

    构建了三层神经网络来验证正则化和dropout对防止过拟合的作用. 首先看数据集,reg_utils.py包含产生数据集函数,前向传播,计算损失值等,代码如下: import numpy as np ...

  4. keras添加L1正则化,L2正则化和Dropout正则化及其原理

    一.什么是正则化,用来干嘛的? 正则化(regularization),是指在线性代数理论中,不适定问题通常是由一组线性代数方程定义的,而且这组方程组通常来源于有着很大的条件数的不适定反问题.大条件数 ...

  5. PyTorch 中的 dropout Dropout2d Dropout3d

    文章目录 PyTorch 中的 dropout 1. [Pytoch 说明文档官网 PyTorch documentation 链接](https://pytorch.org/docs/stable/ ...

  6. pytorch中的dropout在drop什么?

    最近遇到了一个很基础的问题,就是pytorch中的dropout在面对一个n维的矩阵时,是会随机drop某一行.或者某一维上的一个向量,还是某一个元素呢?用试验稍微验证了下 import torch ...

  7. pytorch中nn.Dropout的使用技巧

    dropout是Hinton老爷子提出来的一个用于训练的trick.在pytorch中,除了原始的用法以外,还有数据增强的用法(后文提到). 首先要知道,dropout是专门用于训练的.在推理阶段,则 ...

  8. Pytorch 正则化方法(权重衰减和Dropout)

    正则化方法(权重衰退和Dropout) 正则化方法和以前学过的正则表达式没有任何关系! 花书 p141 说到: 能显式地减少测试误差(可能会以增大训练误差为代价)的方法都被称为正则化. 0. 环境介绍 ...

  9. 偏差与方差、L1正则化、L2正则化、dropout正则化、神经网络调优、批标准化Batch Normalization(BN层)、Early Stopping、数据增强

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 3.2 深度学习正则化 3.2.1 偏差与方差 3.2.1.1 ...

最新文章

  1. 要不要读博,以及读博后如何顺利毕业并找到理想工作?五个最接地气的忠告...
  2. Edraw Max——亿图图示设计软件基本使用教程
  3. 大数据WEB阶段 后台和页面之间传递日期格式数据的400问题
  4. mysql create routine 权限的一些说明
  5. Power Automate生产现场实例分享回顾
  6. 关于主函数main(int argc,char *argv[])
  7. 赢利定位是网站建设前提
  8. Jar包冲突解决方法
  9. PM3GUI 和 RDV4GUI 专业版软件试用手札
  10. 云片网短信模版自定义变量的替换内容
  11. java js css 压缩工具_JS/CSS压缩工具(YUI Compressor)使用方法
  12. 再见,2017,你好,2018
  13. linux/android系统开发,高级adb 命令汇总
  14. qt项目在Linux平台上面发布成可执行程序.run
  15. Android BLE HIDS Data ,从问询DB 到写入Android 节点的flow 之五
  16. 如何识别媒体偏见_超越偏见:为什么我们不能仅仅“修正”面部识别
  17. iOS 如何连接打印机
  18. 博客系统程序(页面设计)
  19. Multisim基础 有极性的电容 添加元件的位置
  20. 格网DEM生成不规则三角网TIN

热门文章

  1. Exchange2003-2010迁移系列之十一,Exchange2010 OWA配置
  2. 通过Repository Manager 1.3来管理戴尔驱动程序更新
  3. 宁静——一种心灵的奢望
  4. Hystrix面试 - 基于本地缓存的 fallback 降级机制
  5. Hystrix面试 - 深入 Hystrix 线程池隔离与接口限流
  6. 数据可视化组件Grafana详细解读--MacOSX上的安装
  7. CCNA初认识——OSPF(开放式最短路径优先协议)配置命令
  8. linux find命令mtime/atime/ctime +n -n n 全网最正确的总结
  9. 苹果cms10 官方QQ微信防红防封代码
  10. C#LeetCode刷题之#665-非递减数列( Non-decreasing Array)