在抑制过拟合的方法中,我们前面有讲到一个方法:抑制过拟合的方法之权值衰减 ,在某种程度上能够很好的抑制过拟合,如果神经网络的模型很复杂,只用权值衰减就难以应对了,这样的情况下,我们一般选择Dropout方法,也就是在训练的过程中,随机选出隐藏层的神经元,然后将其删除,被删除的神经元不再进行信号的传递。代码在权值衰减中有出现,layers.py里面,摘录出来

class Dropout:'''随机删除神经元self.mask:保存的是False和True的数组,False的值为0是删除的数据'''def __init__(self,dropout_ratio=0.5):self.dropout_ratio=dropout_ratioself.mask=Nonedef forward(self,x,train_flg=True):if train_flg:self.mask=np.random.rand(*x.shape)>self.dropout_ratioreturn x*self.maskelse:return x*(1.0-self.dropout_ratio)def backward(self,dout):return dout*self.mask

随机删除的意思是指每次正向传播时,self.mask中都会以False的形式保存要删除的神经元。

np.random.rand(2,3) 随机生成[0,1)形状为(2,3)的数组
np.random.rand(2,3)>0.5 把大于0.5的值设为True,其余为False(而不是删除一半的意思,因为数据是随机的)
x * self.mask 结果就是False为0,True还是x原来的值

正向传播时传递了信号的神经元,反向传播时按照原样传递信号,正向传播时没有传递信号的神经元,反向传播时信号将停在那里。
        现在我们来比较使用Dropout和不使用Dropout的情况,还是基于MNIST数据集来测试
训练类(common.trainer.py)

import numpy as np
from common.optimizer import *class Trainer:'''把前面用来训练的代码做一个类'''def __init__(self,network,x_train,t_train,x_test,t_test,epochs=20,mini_batch_size=100,optimizer='SGD',optimizer_param={'lr':0.01},evaluate_sample_num_per_epoch=None,verbose=True):self.network=networkself.verbose=verbose#是否打印数据(调试或查看)self.x_train=x_trainself.t_train=t_trainself.x_test=x_testself.t_test=t_testself.epochs=epochsself.batch_size=mini_batch_sizeself.evaluate_sample_num_per_epoch=evaluate_sample_num_per_epochoptimizer_dict={'sgd':SGD,'momentum':Momentum,'nesterov':Nesterov,'adagrad':AdaGrad,'rmsprop':RMSprop,'adam':Adam}self.optimizer=optimizer_dict[optimizer.lower()](**optimizer_param)self.train_size=x_train.shape[0]self.iter_per_epoch=max(self.train_size/mini_batch_size,1)self.max_iter=int(epochs*self.iter_per_epoch)self.current_iter=0self.current_epoch=0self.train_loss_list=[]self.train_acc_list=[]self.test_acc_list=[]def train_step(self):batch_mask=np.random.choice(self.train_size,self.batch_size)x_batch=self.x_train[batch_mask]t_batch=self.t_train[batch_mask]grads=self.network.gradient(x_batch,t_batch)self.optimizer.update(self.network.params,grads)loss=self.network.loss(x_batch,t_batch)self.train_loss_list.append(loss)if self.verbose:print('训练损失值:'+str(loss))if self.current_iter%self.iter_per_epoch==0:self.current_epoch+=1x_train_sample,t_train_sample=self.x_train,self.t_trainx_test_sample,t_test_sample=self.x_test,self.t_testif not self.evaluate_sample_num_per_epoch is None:t=self.evaluate_sample_num_per_epochx_train_sample,t_train_sample=self.x_test[:t],self.t_test[:t]      train_acc=self.network.accuracy(x_train_sample,t_train_sample)test_acc=self.network.accuracy(x_test_sample,t_test_sample) self.train_acc_list.append(train_acc)self.test_acc_list.append(test_acc)if self.verbose:print('epoch:'+str(self.current_epoch)+',train acc:'+str(train_acc)+' | test acc:'+str(test_acc))self.current_iter+=1def train(self):for i in range(self.max_iter):self.train_step()test_acc=self.network.accuracy(self.x_test,self.t_test)if self.verbose:print('最终测试的正确率:'+str(format(test_acc,'.2%')))        
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net_extend import MultiLayerNetExtend
from common.trainer import Trainer(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True)
#截取少量数据,让它再现过拟合
x_train=x_train[:300]#(300,784)
t_train=t_train[:300]#构建7层神经网络(6个隐藏层)
epochsNum=300
network=MultiLayerNetExtend(inputSize=784,hiddenSizeList=[100,100,100,100,100,100],outputSize=10,use_dropout=True,dropout_ration=0.2)
trainer=Trainer(network,x_train,t_train,x_test,t_test,epochs=epochsNum,mini_batch_size=100,optimizer='sgd',optimizer_param={'lr':0.01},verbose=True)
trainer.train()#画图
train_acc_list,test_acc_list=trainer.train_acc_list,trainer.test_acc_list
x=np.arange(len(train_acc_list))
plt.plot(x,train_acc_list,marker='s',label='train',markevery=10)
plt.plot(x,test_acc_list,marker='d',label='test',markevery=10)
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.ylim(0,1.0)
plt.legend(loc='lower right')
plt.show()

不使用Dropout的情况,修改成use_dropout=False,我们会发现下图中,train数据过拟合了,所以很多时候我们都会优选Dropout来抑制过拟合。

其中需要用到的多层神经网络扩展版本(支持Dropout)multi_layer_net_extend.py

基于MNIST数据集的Batch Normalization(批标准化层)https://blog.csdn.net/weixin_41896770/article/details/121557928

抑制过拟合的方法之Dropout(随机删除神经元)相关推荐

  1. 抑制过拟合的方法之权值衰减

    机器学习很常见的一个需要解决的问题就是过度拟合(overift),过拟合的意思是它能够很好的拟合训练数据,但是对于训练数据之外的数据可能就显得差强人意了,也就是常说的泛化能力比较差,所以抑制过拟合就显 ...

  2. 抑制过拟合之正则化与Dropout

    避免过拟合: 1.增大数据集合 – 使用更多的数据,噪声点比减少(减少数据扰动所造成的影响) 2.减少数据特征 – 减少数据维度,高维空间密度小(减少模型复杂度) 3.正则化 / dropout / ...

  3. 深度学习-Tensorflow2.2-深度学习基础和tf.keras{1}-优化函数,学习速率,反向传播,网络优化与超参数选择,Dropout 抑制过拟合概述-07

    多层感知器: 优化使用梯度下降算法 学习速率 学习速率选取原则 反向传播 SGD RMSprop Adam learning_rate=0.01 # -*- coding: utf-8 -*- # - ...

  4. DL之DNN优化技术:DNN中抑制过拟合/欠拟合、提高泛化能力技术的简介、使用方法、案例应用详细攻略

    DL之DNN优化技术:DNN中抑制过拟合.提高泛化能力技术的简介.使用方法.案例应用详细攻略 目录 抑制过拟合.提高泛化能力技术的简介 1.过拟合现象的表述

  5. DL之DNN:利用MultiLayerNetExtend模型【6*100+ReLU+SGD,dropout】对Mnist数据集训练来抑制过拟合

    DL之DNN:利用MultiLayerNetExtend模型[6*100+ReLU+SGD,dropout]对Mnist数据集训练来抑制过拟合 目录 输出结果 设计思路 核心代码 更多输出 输出结果 ...

  6. Dropout抑制过拟合

    dropout 可以看出,网络中的的一层中的某些神经元被丢弃,网络变得简单了一些. Dropout解决过拟合的原因 (1)取平均的作用 (2)减少神经元之间复杂的共适应关系: 因为dropout程序导 ...

  7. 深度学习中防止过拟合的方法

    在深度学习中,当数据量不够大时候,常常采用下面4中方法: 1. 人工增加训练集的大小. 通过平移, 翻转, 加噪声等方法从已有数据中创造出一批"新"的数据.也就是Data Augm ...

  8. 神经网络防止过拟合的方法

    知乎上的回答:https://www.zhihu.com/question/59201590 深度学习防止过拟合的方法 过拟合即在训练误差很小,而泛化误差很大,因为模型可能过于的复杂,使其" ...

  9. 机器学习中用来防止过拟合的方法有哪些?

     机器学习中用来防止过拟合的方法有哪些? 雷锋网(公众号:雷锋网)按:本文作者 qqfly,上海交通大学机器人所博士生,本科毕业于清华大学机械工程系,主要研究方向机器视觉与运动规划,会写一些好玩的 ...

最新文章

  1. Java必刷100题
  2. PHP中的多行字符串传递给JavaScript方法两则
  3. python flask 如何修改默认端口号
  4. Spring ActiveMQ教程
  5. Kali Linux 网络扫描秘籍 第二章 探索扫描(二)
  6. jQuery 倒计时插件
  7. redis 雪崩、击穿、穿透
  8. SAP ERP和ORACLE ERP的区别是哪些?
  9. 移动ESP分区到磁盘最前端
  10. 网吧游戏二层更新linux,图文细说网吧游戏更新软件【处理手段】
  11. 惊喜! UE4 + ftrack开源了!
  12. devops summary
  13. win10应用闪退解决方法
  14. GUTI,Globally Unique Temporary UE Identity,全球唯一临时UE标识。
  15. python教程68--cufflinks库绘图功能
  16. python学习笔记——小插曲
  17. 7-16 然后是几点(15 分)
  18. 人生十鉴:大喜易失言,大哀易失值
  19. 2022年10月16日 记
  20. 并发-MESI缓存一直协议详解

热门文章

  1. jmap与jstat工具实战分析
  2. JS模拟实现数组的map方法
  3. Docker环境搭建,K8s
  4. 数据库表在join时的三种方式
  5. 201521123060 《Java程序设计》第12周学习总结
  6. redhat下升级gcc编译器
  7. 如何使用HttpContext对象
  8. Visual Assist X Options 常用宏
  9. C++_类和对象_对象特性_成员变量占用对象内存_成员函数_静态成员函数_静态变量_都不占用对象内存_他们是分开存储的---C++语言工作笔记048
  10. 软考信息系统项目管理师_体系介绍_证书作用价值_报考条件_考生分析---软考高级之信息系统项目管理师001