0. 写作目的

通过实验分析keras中Dropout在训练阶段和测试阶段的使用情况。

结论: Keras使用的 Inverted Dropout,因此测试时不需要修改 Dropout中的参数(rate)。

1.  Dropout 的实现方式

Dropout的实现方式有两种。

Dropout:(使用较少, AlexNet使用的是这种Dropout)

训练阶段:

keepProb: 保留该神经元的概率。

d3 = np.random.rand( a3.shape[0], a3.shape[1] ) < keepProb

测试阶段: 计算的结果需要乘以keepProb:

Inverted Dropout:(目前常用的方法)

训练阶段:

d3 = np.random.rand( a3.shape[0], a3.shape[1] ) < keepProb

a3 = a3 / keepProb

测试阶段:  

2. 实验验证Dropout的实现(也可以通过源码查看)

实验思路:通过训练带有Dropout的网络,然后加载训练的模型,并修改其中的Dropout的参数。观察在相同的数据集上的预测结果是否相同,为避免实验的随机性,对于测试实验运行10次观察结果。

实验过程:首先运行代码一,然后运行代码二,然后对比代码一与代码二的结果。

实验结果猜测:如果代码一的结果是代码二结果的 1/2,说明Keras中Dropout是采用AlexNet中的Dropout,如果代码一二的结果近似相等,说明Keras中Dropout使用的是Inverted Dropout。

代码一:

# _*_ coding:utf-8 _*_import keras
from keras.layers import Dense, Dropout, Input
from keras.optimizers import SGD
import numpy as np
from keras.models import Model, load_model
import tensorflow as tf## y = 2 * x1 + x2
def generateData():X = np.array([[3, 2], [2, 4], [1, 6]])y = np.array([[8], [8], [8]])return X, ydef Net(rate=0):tf.reset_default_graph()input_x = Input( shape=(2, ) )x = Dense(units=100, activation='linear')(input_x)x = Dropout(rate=rate)(x)x = Dense(units=100, activation='linear')(x)x = Dense(units=1, activation='linear')(x)model = Model(inputs=input_x, outputs=x)model.summary()return modeldef main():model_with = Net(rate=0.5)model_with.compile(optimizer=SGD(0.001), loss='mse')X, y = generateData()model_with.fit(X, y, nb_epoch=1000, verbose=0)model_with.save('model.h5')for ii in range(10):y_with = model_with.predict( X )print( 'model with dropout:{}'.format(y_with) )if __name__ == "__main__":main()

代码二:

#!/usr/bin/env python
# _*_ coding:utf-8 _*_import keras
from keras.layers import Dense, Dropout, Input
from keras.optimizers import SGD
import numpy as np
from keras.models import Model, load_model
import tensorflow as tf## y = 2 * x1 + x2
def generateData():X = np.array([[3, 2], [2, 4], [1, 6]])y = np.array([[8], [8], [8]])return X, ydef Net(rate=0):tf.reset_default_graph()input_x = Input( shape=(2, ) )x = Dense(units=100, activation='linear')(input_x)x = Dropout(rate=rate)(x)x = Dense(units=100, activation='linear')(x)x = Dense(units=1, activation='linear')(x)model = Model(inputs=input_x, outputs=x)model.summary()return modeldef main():X, y = generateData()model_without = Net(rate=0)model_without.load_weights('model.h5', by_name=True)# model_without = load_model( 'model.h5' )for ii in range(10):y_without = model_without.predict(X)print('model without dropout: {}'.format(y_without))if __name__ == "__main__":main()

3. 实验结果

代码一结果:

model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]
model with dropout:[[8.249627][8.171895][8.094164]]

代码二结果:

model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]
model without dropout: [[8.249627][8.171895][8.094164]]

4. 实验结论

通过实验结果可以看出,Keras中Dropout使用的是Inverted Dropout。

[Reference]

https://github.com/keras-team/keras/issues/5357

Keras 实现细节——dropout在训练阶段与测试阶段的使用分析相关推荐

  1. keras系列︱图像多分类训练与利用bottleneck features进行微调(三)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72861152 中文文档:http://keras-cn.readthedocs.io/ ...

  2. BN和Dropout在训练和测试时有哪些差别?

    作者丨海晨威@知乎 来源丨https://zhuanlan.zhihu.com/p/61725100 编辑丨极市平台 Batch Normalization BN,Batch Normalizatio ...

  3. 为什么极度随机树比随机森林更随机?这个极度随机的特性有什么好处?在训练阶段、极度随机数比随机森林快还是慢?

    为什么极度随机树比随机森林更随机?这个极度随机的特性有什么好处?在训练阶段.极度随机数比随机森林快还是慢? ExtRa Trees是Extremely Randomized Trees的缩写,意思就是 ...

  4. keras和tensorflow使用 keras.callbacks.EarlyStopping 提前结束训练

    此文首发于我的个人博客:keras和tensorflow使用 keras.callbacks.EarlyStopping 提前结束训练 - zhang0peter的个人博客 一般来说机器学习的训练次数 ...

  5. Fast-RCNN解析:训练阶段代码导读

    转载自:http://blog.csdn.net/linj_m/article/details/48930179#0-tsina-1-35514-397232819ff9a47a7b7e80a4061 ...

  6. HRBU 2021年暑期训练阶段二Day3

    目录 A - Shuffle'm Up 题目链接: 题意: 做法: B - Prime Path 题目链接: 题意: 做法: C - Function Run Fun 题目链接: 题意: 做法: D ...

  7. BN和Dropout在训练和测试时的差别

    Batch Normalization BN,Batch Normalization,就是在深度神经网络训练过程中使得每一层神经网络的输入保持相近的分布. BN训练和测试时的参数是一样的嘛? 对于BN ...

  8. keras指定gpu_Keras多GPU训练指南

    摘要:随着Keras(v2.0.8)最新版本的发布,使用多GPU 训练深度神经网络将变得非常容易,就跟调用函数一样简单!利用多GPU,能够获得准线性的提速. Keras是我最喜欢的Python深度学习 ...

  9. HRBU 2021年暑期训练阶段三Day1

    目录 A - Similar Strings 题目链接: 题意: 做法: B - card card card 题目链接: 题意: 做法: C - String 题目链接: 题意: 做法: D - C ...

最新文章

  1. MMD_5a_Clustering
  2. Oracle针对SCOTT下EMP表的练习题
  3. 请键入 net helpmsg 3534 以获得更多的帮助。_相遇不易,请珍惜
  4. 【Linux 内核 内存管理】RCU 机制 ④ ( RCU 模式下更新链表项 list_replace_rcu 函数 | 链表操作时使用 smp_wmb() 函数保证代码执行顺序 )
  5. OpenGL帧缓存对象(FBO:Frame Buffer Object)(转载)
  6. [转]JS脚本抢腾讯云学生1元代金券
  7. 关于 /dev/null 与 /dev/zero
  8. 客户端与服务器之间的文件传输,客户端与服务器的文件传输
  9. Redis 分布式集群搭建2022版本+密码(linux环境)
  10. 文件上传控件 css,CSS3 自定义文件上传输入控件界面
  11. python求解二次规划_Python二次规划和线性规划使用实例
  12. 【华为云技术分享】前端快速建⽴Mock App
  13. Spring Cloud 之 Ribbon,Spring RestTemplate 调用服务
  14. 原型和原型链 及 instanceof函数
  15. 使用python制作趣味小游戏—投骰子
  16. 服务器调用税务数字系统失败,终于等到你!网上报税常见问题解决方案大集锦!!!...
  17. ps 透明底和改变颜色
  18. 孙子兵法始计篇读后感&心得(下)
  19. wine 微信输入框不能正常显示(不显示)输入的文字
  20. web渗透--vnc密码破解

热门文章

  1. 利用计算机求该货车,吉林大学汽车理论第二次作业[7页].doc
  2. linux同步bios时间指令,Shell实现系统时间和BIOS时间同步校准脚本分享
  3. Jasper创建柱状图(6.17.0)
  4. 4.1-文本分类+超参搜索
  5. 【腾讯Bugly干货分享】从0到1打造直播 App 1
  6. NLP-python3 translate()报错问题-TypeError: translate() takes exactly one argument (2 given)
  7. 解决系统提示msvcr71.dll文件丢失的错误
  8. ABP中的依赖注入思想
  9. Delphi 7启动后提示Unable to rename delphi32.dro的解决办法
  10. 基于MATLAB GUI的匀速目标回波模拟器设计