参考

3.13 丢弃法

过拟合问题的另一种解决办法是丢弃法。当对隐藏层使用丢弃法时,隐藏单元有一定概率被丢弃。

3.12.1 方法

3.13.2 从零开始实现

import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2ldef dropout(X, drop_prob):X = X.float()assert 0 <= drop_prob <= 1keep_prob = 1 - drop_prob# 这种情况下把全部元素都丢弃if keep_prob == 0:return torch.zeros_like(X)mask = (torch.rand(X.shape) < keep_prob).float()return mask * X / keep_prob
X = torch.arange(16).view(2, 8)
X

dropout(X, 0.5)

dropout(X, 1)

3.13.2.1 定义模型参数

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)
b1 = torch.zeros(num_hiddens1, requires_grad=True)
W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)
b2 = torch.zeros(num_hiddens2, requires_grad=True)
W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)
b3 = torch.zeros(num_outputs, requires_grad=True)params = [W1, b1, W2, b2, W3, b3]

3.13.2.2 定义模型

drop_prob1, drop_prob2 = 0.2, 0.5def net(X, is_training=True):X = X.view(-1, num_inputs)H1 = (torch.matmul(X, W1) + b1).relu()if is_training:  # 只在训练模型时使用丢弃法H1 = dropout(H1, drop_prob1)  # 在第一层全连接后添加丢弃层H2 = (torch.matmul(H1, W2) + b2).relu()if is_training:H2 = dropout(H2, drop_prob2)  # 在第二层全连接后添加丢弃层return torch.matmul(H2, W3) + b3# 本函数已保存在d2lzh_pytorch
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutacc_sum += (net(X).argmax(dim=1) == y).float().sum().item()net.train() # 改回训练模式else: # 自定义的模型if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n

3.13.2.3 训练和测试模型

num_epochs, lr, batch_size = 5, 100.0, 256
loss = torch.nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

3.13.3 简洁实现

net = nn.Sequential(d2l.FlattenLayer(),nn.Linear(num_inputs, num_hiddens1),nn.ReLU(),nn.Dropout(drop_prob1),nn.Linear(num_hiddens1, num_hiddens2),nn.ReLU(),nn.Dropout(drop_prob2),nn.Linear(num_hiddens2, 10)
)for param in net.parameters():nn.init.normal_(param, mean=0, std= 0.01)optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

[pytorch、学习] - 3.13 丢弃法相关推荐

  1. 动手学深度学习V2.0(Pytorch)——13.丢弃法

    文章目录 1. 课件讲解 插一句(正则的分类) 2. Q&A 2.1 dropout是初次生效,还是每次都重新选取概率 2.2 dropout的感性评价 2.3 dropout随机置0对求梯度 ...

  2. PyTorch学习笔记(13)——强力的可视化工具visdom

    今天,让我们来放松一下大脑,学习点轻松的东西----可视化工具Visdom,它可以让我们在使用PyTorch训练模型的时候,可视化中间的训练情况,无论是loss变化还是中间结果比较.相比干呆呆的瞪着命 ...

  3. 【pytorch】过拟合的应对办法 —— 丢弃法(dropout)

    文章目录 一.什么是丢弃法,为什么丢弃法可以缓解过拟合? 二.丢弃法的手动实现 三.丢弃法的pytorch实现 参考 关于过拟合.欠拟合的解释可以参考我的博文:[pytorch]过拟合和欠拟合详解,并 ...

  4. 《动手学深度学习》丢弃法(dropout)

    丢弃法(dropout) 丢弃法 方法 从零开始实现 定义模型参数 定义模型 训练和测试模型 简洁实现 小结 参考文献 丢弃法 除了前一节介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout ...

  5. Pytorch与drop_out(丢弃法)

    简述 深度学习模型常常使用丢弃法(dropout)[1] 来应对过拟合问题.丢弃法有一些不同的变体.文中提到的丢弃法特指倒置丢弃法(inverted dropout). 对于激活函数而言有: hi=ϕ ...

  6. PyTorch——Dropout(丢弃法)

    参考链接 https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.13_dropout dropout 深度学习模型常常使 ...

  7. PyTorch学习笔记(13)--现有网络模型的使用及修改

    PyTorch学习笔记(13)–现有网络模型的使用及修改     本博文是PyTorch的学习笔记,第13次内容记录,主要介绍如何使用现有的神经网络模型,如何修改现有的网络模型. 目录 PyTorch ...

  8. 【深度学习】丢弃法(dropout)

    丢弃法 在小虾的这篇文章中介绍了权重衰减来应对过拟合问题(https://blog.csdn.net/qq_33432841/article/details/107879937),下面在介绍一种应对过 ...

  9. pytorch学习笔记(十三):Dropout

    文章目录 1. 方法 2. 从零开始实现 2.1 定义模型参数 2.2 定义模型 2.3 训练和测试模型 3. 简洁实现 小结 除了前一节介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout ...

最新文章

  1. python list列表与array区别
  2. react-native 集成极光推送jpush-react-native时的小问题
  3. 计算机指令格式哪几部分组成,计算机的指令格式,通常是由()两部分组成。 - 百科题库网...
  4. RocketMQ中的死信队列
  5. 使用 idea 创建第一个 springboot 项目
  6. 详解6G系统数据治理方案的设计要点和原则
  7. 详细描述一下 Elasticsearch 搜索的过程?
  8. [原]无法删除openstack nova的image instance
  9. fatal error C1010: 是否忘记了向源中添加“#include stdafx.h”?
  10. 存储服务器内的温度检测信号线 用线,常用的3线和4线电阻温度检测器介绍
  11. 基于遗传算法(deap)的配词问题与deap框架
  12. python3解析纯真ip数据库
  13. 记录:2018年CCF优秀博士学位论文奖信息
  14. 不同浏览器JS获取浏览器高度和宽度
  15. 171023—各进制数输出:二进制转换用格式控制符输出八,十,十六进制数
  16. android设置加密步长,非稳态计算时Fluent 时间步长如何设置(转载)
  17. 华北屋脊:大秦岭的前世今生
  18. 全程免费 - 挖矿转录组学大数据,中科院胡松年、方向东等开讲
  19. python拼写_用 Python 27 行实现拼写纠正
  20. content(contents)

热门文章

  1. 计算机机房用户不规则行为,网络及网管机房管理理论练习
  2. ubuntu20.04中安装划词翻译_教你轻松玩转免安装的网页翻译插件“有道网页翻译2.0”...
  3. 大学计算机用的笔记本,推荐一款大学生用笔记本电脑
  4. python自动化办公知识点整理汇总_python自动化办公小结
  5. java s1_转!!Java 基础面试题的剖析: short s1=1;s1 = s1 +1 报错? s1+=1 呢
  6. c语言使用未初始化的内存怎么解决_C语言快速入门——数组与调试进阶
  7. Java生鲜电商平台-深入订单拆单架构与实战
  8. 各视频、各音频之间格式任意玩弄(图文详解)
  9. asp.net web api集成微信服务(使用Senparc微信SDK)
  10. WSDL文件生成WEB service server端C#程序