深度学习Deep Learning: dropout策略防止过拟合
本文参考 hogo在youtube上的视频 :https://www.youtube.com/watch?v=UcKPdAM8cnI
一、理论基础
dropout的提出是为了防止在训练过程中的过拟合现象,那就有人想了,能不能对每一个输入样本训练一个模型,然后在test阶段将每个模型取均值,这样通过所有模型共同作用,可以将样本最有用的信息提取出来,而把一些噪声过滤掉。
那如何来实现这种想法呢?在每一轮训练过程中,我们对隐含层的每个神经元以一定的概率p舍弃掉,这样相当于每一个样本都训练出一个模型。假设有H个神经元,那么就有2H种可能性,对应2H模型,训练起来时间复杂度太高。我们通过权重
共享(weights sharing)的方法来简化训练过程,每个样本所对应模型是部分权重共享的,只有被舍弃掉那部分权重不同。
使用dropout可以使用使一个隐含结点不能与其它隐含结点完全协同合作,因此其它的隐含结点可能被舍弃,这样就不能通过所有的隐含结点共同作用训练出复杂的模型(只针对某一个训练样本),我们不能确定其它隐含结点此时是否被激活,这样
就有效的防止了过拟合现象。
如下图所示,在训练过程中神经元以概率p出现,而在测试阶段它一直都存在。
注:如果有多个隐含层,那么对每一个隐含层分别使用dropout策略
1.1 训练阶段
forward propagation
在前向传播过程中,使用掩模m(k)uq将部分隐含层结点舍弃。
backpropagation
反向传播阶段,即权重调整阶段,通过掩模只调整那些未被舍弃的结点的权重。
1.2 测试阶段
在前面介绍过,我们可能训练出很多种模型,在测试阶段对其取平均,有两种取平均的方法:
假设有两个模型m1,m2,输出分别为O1,O2,最终输出为O
1、mixture
O = (O1+O2) / 2
2、product (geometric mean)
O = sqrt( O1*O2 )
以上这两种方法都是非常耗时的,我们使用一种挖的方法,即对模型的输出乘以0.5(假设dropout的概率是0.5),如果仅包含一个隐含结点,那么这种方法与geometric mean结果相同,反之,也可以很好的近似。
如果dropout的概率是0.5,那么就对所有输出乘以0.5,
dropout是在每一轮的权重调整时(backpropagation时)在隐含层以一定的概率舍弃某些神经元(一般取0.5),因此每个神经元只以上一层的一部分神经元相关,即隐含层每个神经元相当于单独训练,即每个神经元模型独立。
1.3 对输入层的dropout
以上讲的是如何在隐含层做dropout,其实也可以在输入层做dropout,这就是前面提到的denoising策略,只是我们以比较大的概率将输入保留下来 。
1.4 denoising and dropout
- denoising用于输入层,dropout用于隐含层
- denoising是用于无监督训练,dropout用于有监督训练,denoising可用于有监督训练的预训练过程。
- 两者都用来防止过拟合
二、实验部分
本实验使用deepLearnToolbox 工具包,将autoencoder模型使用dropout前后的结果进行比较。
dropout并没有明显的降低误差率,可能需要调参吧。。作者在论文中的效果非常之明显。
误差率:
without dropout 0.18300
with dropout : 0.144000
实验主要代码:
nn.dropoutFraction = 0.5; 用来设置dropout的百分比,一般0.5的效果最好。
load mnist_uint8; train_x = double(train_x(1:2000,:)) / 255; test_x = double(test_x(1:1000,:)) / 255; train_y = double(train_y(1:2000,:)); test_y = double(test_y(1:1000,:));%% //实验一without dropout rand('state',0) sae = saesetup([784 100]); sae.ae{1}.activation_function = 'sigm'; sae.ae{1}.learningRate = 1 opts.numepochs = 10; opts.batchsize = 100; sae = saetrain(sae , train_x , opts ); visualize(sae.ae{1}.W{1}(:,2:end)');nn = nnsetup([784 100 10]);% //初步构造了一个输入-隐含-输出层网络,其中包括了% //权值的初始化,学习率,momentum,激发函数类型,% //惩罚系数,dropout等nn.W{1} = sae.ae{1}.W{1};opts.numepochs = 10; % //Number of full sweeps through data opts.batchsize = 100; % //Take a mean gradient step over this many samples [nn, L] = nntrain(nn, train_x, train_y, opts); [er, bad] = nntest(nn, test_x, test_y); str = sprintf('testing error rate is: %f',er); disp(str)%% //实验二:with dropout rand('state',0) sae = saesetup([784 100]); sae.ae{1}.activation_function = 'sigm'; sae.ae{1}.learningRate = 1;opts.numepochs = 10; opts.bachsize = 100; sae = saetrain(sae , train_x , opts ); figure; visualize(sae.ae{1}.W{1}(:,2:end)');nn = nnsetup([784 100 10]);% //初步构造了一个输入-隐含-输出层网络,其中包括了% //权值的初始化,学习率,momentum,激发函数类型,% //惩罚系数,dropout等 nn.dropoutFraction = 0.5; nn.W{1} = sae.ae{1}.W{1}; opts.numepochs = 10; % //Number of full sweeps through data opts.batchsize = 100; % //Take a mean gradient step over this many samples [nn, L] = nntrain(nn, train_x, train_y, opts); [er, bad] = nntest(nn, test_x, test_y); str = sprintf('testing error rate is: %f',er); disp(str)
参考文献:
hintin dropout youtube视频:https://www.youtube.com/watch?v=5t-mVtrFVyY
hogo dropout youtube视频:https://www.youtube.com/watch?v=UcKPdAM8cnI
deepLearn Toolbox使用: http://www.cnblogs.com/dupuleng/articles/4340293.html
hinton原文 Dropout: A simple Way to prevent Neural Networks from Overfitting
from: http://www.cnblogs.com/dupuleng/articles/4341265.html
深度学习Deep Learning: dropout策略防止过拟合相关推荐
- 机器学习(Machine Learning)深度学习(Deep Learning)资料(Chapter 2)
机器学习(Machine Learning)&深度学习(Deep Learning)资料(Chapter 2) - tony的专栏 - 博客频道 - CSDN.NET 注:机器学习资料篇目一共 ...
- 【深度学习Deep Learning】资料大全
感谢关注天善智能,走好数据之路↑↑↑ 欢迎关注天善智能,我们是专注于商业智能BI,人工智能AI,大数据分析与挖掘领域的垂直社区,学习,问答.求职一站式搞定! 对商业智能BI.大数据分析挖掘.机器学习, ...
- 机器学习(Machine Learning)深度学习(Deep Learning)资料汇总
本文来源:https://github.com/ty4z2008/Qix/blob/master/dl.md 机器学习(Machine Learning)&深度学习(Deep Learning ...
- 深度学习Deep Learning 资料大全
转自:http://www.cnblogs.com/charlotte77/ [深度学习Deep Learning]资料大全 最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: ...
- 机器学习(Machine Learning)深度学习(Deep Learning)资料【转】
转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...
- 机器学习(Machine Learning)深度学习(Deep Learning)资料集合
机器学习(Machine Learning)&深度学习(Deep Learning)资料 原文链接:https://github.com/ty4z2008/Qix/blob/master/dl ...
- 大量机器学习(Machine Learning)深度学习(Deep Learning)资料
机器学习目前比较热,网上也散落着很多相关的公开课和学习资源,这里基于课程图谱的机器学习公开课标签做一个汇总整理,便于大家参考对比. 1.Coursera上斯坦福大学Andrew Ng教授的" ...
- (转)机器学习(Machine Learning)深度学习(Deep Learning)资料
原文链接:https://github.com/ty4z2008/Qix/blob/master/dl.md 机器学习(Machine Learning)&深度学习(Deep Learning ...
- 深度学习Deep learning小白入门笔记——PanGu模型训练分析
书接上回 深度学习Deep learning小白入门笔记--在AI平台上训练LLM--PanGu 对训练模型重新认知与评估. 模型评估 在训练过程中或训练完成后,通常使用验证集或测试集来评估模型的性能 ...
最新文章
- 当PrintForm遇到RPC服务不可用的错误”
- 在解决方案中所使用 NuGet 管理软件包依赖
- char* 大小_SQL Server中char, nchar, varchar和nvarchar数据类型有何区别
- CVPR 2016 SINT:《Siamese Instance Search for Tracking》论文笔记
- 关于Flex-Mvc的几个框架的简单介绍
- Exception in thread http-bio-8081-exec-3 java.lang.OutOfMemoryError: PermGen space
- 阿里发布天猫精灵X1 探索人机交互新大陆
- 当年叱咤风云的框架Struts2,你可知Struts2内功如何修炼之体系结构
- 当点击ListView的列头时,对ListView排序
- python逻辑型数据也叫什么_python基础(三)python数据类型
- Linux环境下编译运行大型C语言项目
- JDBC与数据库连接池
- 十次方社交系统开发项目 源码 视频 文档 工具 合集百度云下载地址
- 操作系统实验一:进程管理(含成功运行C语言源代码)
- 如何将视频网站的视频下载为mp4格式
- 入门级动态规划:2018年第九届蓝桥杯省赛B组第四题—测试次数( 摔手机 )
- Chisel 手册(中文part1)
- 异步传输模式 Asynchronous Transfer Mode
- 二维数组与指向指针的指针
- 二次规划(1):Lagrange法
热门文章
- 图数据库应用:金融反欺诈实践
- vue语法 `${ }` (模版字符串)
- 人工智能靠人工:标注员1天要听1000条录音
- 机器学习入门系列四(关键词:BP神经网络)
- c语言开发游戏趋势,都9012年了,为何我还坚持用C语言开发游戏
- linux添加硬盘分区设置柱面,linux 下添加新硬盘设备和硬盘分区格式化挂载使用磁盘配额限制...
- MyBatis-24MyBatis缓存配置【集成EhCache】
- java compare 返回值_关于Java你不知道的那些事之Java8新特性[Lambda表达式和函数式接口]...
- hashmap 判断key是否存在
- spring之java配置(springboot推荐的配置方式)