使用Dropout解决过拟合的情况发生

修改代码

import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
#训练集
train_dataset=datasets.MNIST(root='./', #存放到项目目录下train=True,  #是训练数据transform=transforms.ToTensor(),  #转换成基本类型tensor数据download=True) #需要下载
#测试集
test_dataset=datasets.MNIST(root='./',train=False,transform=transforms.ToTensor(),download=True)#每次训练图片的数量
batch_size=64
#装在训练集数据
train_loader=DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
#加载训练集
test_loader=DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):inputs,labels=dataprint(inputs.shape)print(labels.shape)break#定义网络结构(使用Dropout)
class Net(nn.Module):def __init__(self):super(Net,self).__init__()#调用父类方法#Dropout一部分神经元工作一部分不工作self.layer1=nn.Sequential(nn.Linear(784,500),nn.Dropout(p=0.5),nn.Tanh())#p=0.5表示50%的神经元不工作self.layer2 = nn.Sequential(nn.Linear(500, 300), nn.Dropout(p=0.5), nn.Tanh())self.layer3 = nn.Sequential(nn.Linear(300, 10), nn.Softmax(dim=1))def forward(self,x):#[64,1,28,28]----(64,784)四维数据编程2维数据x=x.view(x.size()[0],-1)#-1表示自动匹配x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)return xLR=0.5
#定义模型
model=Net()
#定义代价函数(均方差)
mse_loss=nn.MSELoss()
#定义优化器
optimizer=optim.SGD(model.parameters(),LR)def train():# 表示训练状态, #Dropout一部分神经元工作一部分不工作model.train()for i,data in enumerate(train_loader):#获得一个皮次数据和标签inputs,labels=data#获得模型预测结果(64,10)out=model(inputs)#to onehot,把数据编码变成独热编码#(64)编程(64,-1)labels=labels.reshape(-1,1)#tensor.scatter(dim,index,src)#dim对那个维度进行独热编码#index:要将src中对应的值放到tensor的哪个位置#src:插入index的数值one_hot=torch.zeros(inputs.shape[0],10).scatter(1,labels,1)#计算loss,mse_loss的俩个数据的shape要一致loss=mse_loss(out,one_hot)#梯度清零optimizer.zero_grad()#计算梯度loss.backward()#修改权值optimizer.step()#测试
def test():# 表示模型测试状态,#Dropout所有神经元都要工作model.eval()correct=0for i,data in enumerate(test_loader):#获取一个批次的数据和标签inputs,labels=data#获得模型的预测结果(64,10)out=model(inputs)#获取最大值,以及最大值所在的位置_,predicted=torch.max(out,1)# 预测正确的数量correct += (predicted == labels).sum()print("Test acc:{0}".format(correct.item()/len(test_dataset)))correct = 0for i, data in enumerate(train_loader):# 获取一个批次的数据和标签inputs, labels = data# 获得模型的预测结果(64,10)out = model(inputs)# 获取最大值,以及最大值所在的位置_, predicted = torch.max(out, 1)# 预测正确的数量correct += (predicted == labels).sum()print("Train acc:{0}".format(correct.item() / len(train_dataset)))for epoch in range(10):print("epoch:",epoch)train()test()

主要代码:

#定义网络结构(使用Dropout)
class Net(nn.Module):def __init__(self):super(Net,self).__init__()#调用父类方法#Dropout一部分神经元工作一部分不工作self.layer1=nn.Sequential(nn.Linear(784,500),nn.Dropout(p=0.5),nn.Tanh())#p=0.5表示50%的神经元不工作self.layer2 = nn.Sequential(nn.Linear(500, 300), nn.Dropout(p=0.5), nn.Tanh())self.layer3 = nn.Sequential(nn.Linear(300, 10), nn.Softmax(dim=1))def forward(self,x):#[64,1,28,28]----(64,784)四维数据编程2维数据x=x.view(x.size()[0],-1)#-1表示自动匹配x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)return x

测试结果对比:

F:\开发工具\pythonProject\tools\venv\Scripts\python.exe F:/开发工具/pythonProject/tools/pytools/pytools032.py
torch.Size([64, 1, 28, 28])
torch.Size([64])
epoch: 0
Test acc:0.8875
Train acc:0.8811
epoch: 1
Test acc:0.9111
Train acc:0.9065833333333333
epoch: 2
Test acc:0.9178
Train acc:0.91535
epoch: 3
Test acc:0.9234
Train acc:0.92
epoch: 4
Test acc:0.9259
Train acc:0.9243166666666667
epoch: 5
Test acc:0.9285
Train acc:0.9270833333333334
epoch: 6
Test acc:0.9296
Train acc:0.9291
epoch: 7
Test acc:0.9331
Train acc:0.9323
epoch: 8
Test acc:0.9357
Train acc:0.9348
epoch: 9
Test acc:0.9366
Train acc:0.93645Process finished with exit code 0

另外一种正则化解决过拟合方案:建议网络结构复杂的情况下使用

LR=0.5
#定义模型
model=Net()
#定义代价函数(均方差)
mse_loss=nn.MSELoss()
#定义优化器,weight_decay设置L2正则化
optimizer=optim.SGD(model.parameters(),LR,weight_decay=0.001)

测试结果:

F:\开发工具\pythonProject\tools\venv\Scripts\python.exe F:/开发工具/pythonProject/tools/pytools/pytools033.py
torch.Size([64, 1, 28, 28])
torch.Size([64])
epoch: 0
Test acc:0.883
Train acc:0.8743166666666666
epoch: 1
Test acc:0.8962
Train acc:0.8936833333333334
epoch: 2
Test acc:0.9044
Train acc:0.9002833333333333
epoch: 3
Test acc:0.9043
Train acc:0.90135
epoch: 4
Test acc:0.9051
Train acc:0.9009
epoch: 5
Test acc:0.9035
Train acc:0.9002166666666667
epoch: 6
Test acc:0.9067
Train acc:0.9016833333333333
epoch: 7
Test acc:0.9052
Train acc:0.9004166666666666
epoch: 8
Test acc:0.9036
Train acc:0.9023333333333333
epoch: 9

pytorch的优化器介绍:优化器原理都是随机梯度下降法

Adadelta、Adagrad、Adam、Adamx、AdamW、ASGD、LBFGS、RMSprop、Rprop、SGD、SparseAdam

总结:

Dropout和正则化看情况适用,建议网络复杂的情况下使用,用了不一定效果就好。

pytorch之过拟合的处理(Dropout)(笔记五)相关推荐

  1. (pytorch-深度学习系列)pytorch避免过拟合-dropout丢弃法的实现-学习笔记

    pytorch避免过拟合-dropout丢弃法的实现 对于一个单隐藏层的多层感知机,其中输入个数为4,隐藏单元个数为5,且隐藏单元hih_ihi​(i=1,-,5i=1, \ldots, 5i=1,- ...

  2. 【pytorch】过拟合的应对办法 —— 丢弃法(dropout)

    文章目录 一.什么是丢弃法,为什么丢弃法可以缓解过拟合? 二.丢弃法的手动实现 三.丢弃法的pytorch实现 参考 关于过拟合.欠拟合的解释可以参考我的博文:[pytorch]过拟合和欠拟合详解,并 ...

  3. (pytorch-深度学习系列)pytorch避免过拟合-权重衰减的实现-学习笔记

    pytorch避免过拟合-权重衰减的实现 首先学习基本的概念背景 L0范数是指向量中非0的元素的个数:(L0范数难优化求解) L1范数是指向量中各个元素绝对值之和: L2范数是指向量各元素的平方和然后 ...

  4. 【Keras】减少过拟合的秘诀——Dropout正则化

    摘要: Dropout正则化是最简单的神经网络正则化方法.阅读完本文,你就学会了在Keras框架中,如何将深度学习神经网络Dropout正则化添加到深度学习神经网络模型里. Dropout正则化是最简 ...

  5. MATLAB,Python,Pytorch实现数据拟合

    目录 1.MATLAB实现数据拟合 2.纯python实现数据拟合 3.pytorch实现数据拟合 1.MATLAB实现数据拟合 %MATLAB 数据拟合 x=linspace(-1,1,100); ...

  6. 吴恩达《机器学习》学习笔记五——逻辑回归

    吴恩达<机器学习>学习笔记五--逻辑回归 一. 分类(classification) 1.定义 2.阈值 二. 逻辑(logistic)回归假设函数 1.假设的表达式 2.假设表达式的意义 ...

  7. 论文阅读笔记(五)——狐猴识别系统:一种便于狐猴个体识别的面部识别系统

    论文阅读笔记(五)--狐猴识别系统:一种便于狐猴个体识别的面部识别系统 论文简介 论文中文翻译:狐猴识别系统:一种便于狐猴个体识别的面部识别系统 论文名称:<LemurFaceID: a fac ...

  8. python函数是一段具有特定功能的语句组_Python学习笔记(五)函数和代码复用

    本文将为您描述Python学习笔记(五)函数和代码复用,具体完成步骤: 函数能提高应用的模块性,和代码的重复利用率.在很多高级语言中,都可以使用函数实现多种功能.在之前的学习中,相信你已经知道Pyth ...

  9. Ethernet/IP 学习笔记五

    Ethernet/IP 学习笔记五 Accessing data within a device using a non-time critical message (an explicit mess ...

  10. StackExchange.Redis学习笔记(五) 发布和订阅

    StackExchange.Redis学习笔记(五) 发布和订阅 原文:StackExchange.Redis学习笔记(五) 发布和订阅 Redis命令中的Pub/Sub Redis在 2.0之后的版 ...

最新文章

  1. 云服务蓬勃发展,平均年增长率高达28%
  2. 开发环境和运行环境的区别_生产环境 VS 开发环境,关于Kubernetes的四大认识误区...
  3. spring融合activitymq-all启动报错的解决办法
  4. Python开发:初识Python
  5. 游戏引擎cocos2d-android使用大全
  6. 怎样编译libdb_比特币编译(Ubuntu 16.04)
  7. 用计算机求正有理数算术平方根的步骤,用计算器求算数平方根、用有理数估计算数平方根的大小.ppt...
  8. 爆红的变老神器 FaceApp,夹杂着安全隐患?
  9. CCCC-GPLT L1-035. 情人节 团体程序设计天梯赛
  10. Windows10 关闭自动更新
  11. 合并出错:svn Working copy and merge source not ready for reintegration
  12. 19n20c的参数_供应IC芯片 745653-3 329056 品牌、价格、PDF参数 - 电子产品资料
  13. 热血传奇C++版官网
  14. 英语单词记忆(词缀 / 前缀)
  15. select下拉框如何显示提示语,不要出现下拉选项中
  16. translate()方法
  17. Netty高性能之道1-传统RPC调用性能差的三宗罪
  18. Android设置沉浸式
  19. 从图形界面到会话界面
  20. 『杭电1726』God’s cutter

热门文章

  1. 2018.3.13 浮动 定位
  2. 分享到:空间等各大网站 代码
  3. android webview远程调试
  4. noip2006提高组-金明的预算方案解题报告
  5. 如何将shapefile进行拆分
  6. c语言链表版百度云,链表详解(C语言版)
  7. java中的等待_Java中更好的等待语法
  8. 文件设置索引_Linux文件系统是怎么工作的?
  9. python 给类添加属性_python – 如何动态添加属性到类中?
  10. activity中获取fragment布局_安卓开发入门教程Fragment