http://www.sohu.com/a/150686946_116235

作者:崔静闯

神经网络算法利用了随机性,比如初始化随机权重,因此用同样的数据训练同一个网络会得到不同的结果。

初学者可能会有些懵圈,因为算法表现得不太稳定。但实际上它们就是这么设计的。随机初始化可以让网络通过学习,得到一个所学函数的很好的近似。

然而, 有时候用同样的数据训练同一个网络,你需要每次都得到完全相同的结果。例如在教学和产品上。

在这个教程中,你会学到怎样设置随机数生成器,才能每次用同样的数据训练同一网络时,都能得到同样的结果。

我们开始。

教程概览

这个教程分为六部分:

为啥我每次得到的结果都不一样?不同结果的演示解决方法用Theano 后端设置随机数种子用TensorFlow 后端设置随机数种子得到的结果还是不同,咋办?

运行环境

该教程需要你安装了Python SciPy。你能用Python2或3来演示这个例子需要你安装Keras (v2.0.3+),后台为TensorFlow (v1.1.0+)或Theano (v0.9+)还需要你安装了scikit-learn,Pandas,NumPy以及Matplotlib

如果在Python环境的设置方面需要帮助,请看下面这个帖子:

How to Setup a Python Environment for Machine Learning and Deep Learning with Anaconda

为啥我每次得到的结果都不一样?

我发现这对神经网络和深度学习的初学者而言是个常见问题。

这种误解可能出于以下问题:

我如何得到稳定的结果?我如何得到可重复的结果我应该如何设置种子点

神经网络特意用随机性来保证,能通过有效学习得到问题的近似函数。采用随机性的原因是:用它的机器学习算法,要比不用它的效果更好。

在神经网络中,最常见的利用随机性的方式是网络权值的随机初始化,尽管在其他地方也能利用随机性,这有一个简短的清单:

初始化的随机性,比如权值正则化的随机性,比如dropout层的随机性,比如词嵌入最优化的随机性,比如随机优化

这些甚至更多的随机性来源意味着,当你对同一数据运行同一个神经网络算法时,注定得到不同的结果。

想了解更多关于随机算法的原委,参考下面的帖子

Embrace Randomness in Machine Learning

不同结果的演示

我们可以用一个小例子来演示神经网络的随机性.

在这一节中,我们会建立一个多层感知器模型来学习一个以0.1为间隔的从0.0到0.9的短序列。给出0.0,模型必须预测出0.1;给出0.1,模型必须预测出0.2;以此类推。

下面是准备数据的代码

# create sequence length = 10 sequence = [i/float(length) for i in range(length)] # create X/y pairs df = DataFrame(sequence) df = concat([df.shift(1), df], axis=1) df.dropna(inplace=True) # convert to MLPfriendly format values = df.values X, y = values[:,0], values[:,1]

我们要用的网络,有1个输入,10个隐层节点和1个输出。这个网络将采用均方差作为损失函数,用高效的ADAM算法来训练数据

这个网络需要约1000轮才能有效的解决这个问题,但我们只对它训练100轮。这样是为了确保我们在预测时能得到一个有误差的模型。

网络训练完之后,我们要对数据集进行预测并且输出均方差

建立网络的代码如下

# design network model = Sequential() model.add(Dense(10, input_dim=1)) model.add(Dense(1)) model.compile(loss='mean_squared_error', optimizer='adam') # fit network model.fit(X, y, epochs=100, batch_size=len(X), verbose=0) # forecast yhat = model.predict(X, verbose=0) print(mean_squared_error(y, yhat[:,0]))

在这个例子中,我们要建立10次网络并且输出10个不同的网络得分

完整的代码如下

from pandas import DataFrame from pandas import concat from keras.models import Sequential from keras.layers import Dense from sklearn.metrics import mean_squared_error # fit MLP to dataset and print error def fit_model(X, y): # design network model = Sequential() model.add(Dense(10, input_dim=1)) model.add(Dense(1)) model.compile(loss='mean_squared_error', optimizer='adam') # fit network model.fit(X, y, epochs=100, batch_size=len(X), verbose=0) # forecast yhat = model.predict(X, verbose=0) print(mean_squared_error(y, yhat[:,0])) # create sequence length = 10 sequence = [i/float(length) for i in range(length)] # create X/y pairs df = DataFrame(sequence) df = concat([df.shift(1), df], axis=1) df.dropna(inplace=True) # convert to MLP friendly format values = df.values X, y = values[:,0], values[:,1] # repeat experiment repeats = 10 for _ in range(repeats): fit_model(X, y)

运行这个例子会在每一行输出一个不同的精确值,具体结果也都不同。

下面是一个输出的示例

0.0282584265697 0.0457025913022 0.145698137198 0.0873461454407 0.0309397604521 0.046649185173 0.0958450337178 0.0130660263779 0.00625176026631 0.00296055161492

解决方案

下面是两个主要的解决方案。

解决方案#1:重复实验

解决这个问题传统且切实可行的方法是多次运行网络(30+),然后运用统计学方法概括模型的性能,并与其他模型作比较。

我强烈推荐这种方法,但是由于有些模型的训练时间太长,这种方法并不总是可行的。

解决方案#2:设置随机数字生成器的种子

另一种解决方案是为随机数字生成器使用固定的种子。

随机数由伪随机数生成器生成。一个随机生成器就是一个数学函数,该函数将生成一长串数字,这些数字对于一般目的的应用足够随机。

随机生成器需要一个种子点开启该进程,在大多数实现中,通常默认使用以毫秒为单位的当前时间。这是为了确保,默认情况下每次运行代码都会生成不同的随机数字序列。该种子点可以是指定数字,比如“1”,来保证每次代码运行时生成相同的随机数序列。只要运行代码时指定的种子的值不变,它是什么并不重要。

设置随机数生成器的具体方法取决于后端,我们将探究下在Theano和TensorFlow后端下怎样做到这点。

用Theano后端设置随机数种子

通常,Keras从NumPy随机数生成器中获得随机源。

大部分情况下,Theano后端也是这样。

我们可以通过从random模块中调用seed()函数的方式,设置NumPy随机数生成器的种子,如下面所示:

from numpy.random import seed seed(1)

最好在代码文件的顶部导入和调用seed函数。

这是最佳的实现方式(best practice),这是因为当各种各样的Keras或者Theano(或者其他的)库作为初始化的一部分被导入时,甚至在直接使用他们之前,可能会用到一些随机性。

我们可以在上面示例的顶端再加两行,并运行两次。

每次运行代码时,可以看到相同的均方差值的列表(在不同的机器上可能会有一些微小变化,这取决于机器的精度),如下面的示例所示:

0.169326527063 2.75750621228e-05 0.0183287291562 1.93553737255e-07 0.0549871087449 0.0906326807824 0.00337575114075 0.00414857518259 8.14587362008e-08 0.0522927019639

你的结果应该跟我的差不多(忽略微小的精度差异)。

用TensorFlow后端设置随机数种子

Keras从NumPy随机生成器中获得随机源,所以不管使用Theano或者TensorFlow后端的哪一个,都必须设置种子点。

必须在其他模块的导入或者其他代码之前,文件的顶端部分通过调用seed()函数设置种子点。

from numpy.random import seed seed(1)

另外,TensorFlow有自己的随机数生成器,该生成器也必须在NumPy随机数生成器之后通过立马调用 set_random_seed() 函数设置种子点。

from tensorflow import set_random_seed set_random_seed(2)

要明确的是,在代码文件的顶端,在其他之前,一定要有以下4行:

from numpy.random import seed seed(1) from tensorflow import set_random_seed set_random_seed(2)

你可以使用两个相同或者不同的种子。我认为这不会造成多大差别,因为随机源进入了不同的进程。

在以上示例中增加这4行,可以使代码每次运行时都产生相同的结果。你应该看到与下面列出的相同的均方差值(也许有一些微小差别,这取决于不同机器的精度):

0.224045112999 0.00154879478823 0.00387589994044 0.0292376881968 0.00945528404353 0.013305765525 0.0206255228201 0.0359538356108 0.00441943512128 0.298706569397

你的结果应该与我的差不多(忽略精度的微小差异)。

如果我仍然得到不同的结果,怎么办?

为了重复迭代,报告结果和比较模型鲁棒性最好的做法是多次(30+)重复实验,并使用汇总统计。如果这是不可行的,你可以通过为代码使用的随机数发生器设置种子来获得100%可重复的结果。

如果你已经按照上面的说明去做,仍然用相同的数据从相同的算法中获得了不同的结果,怎么办?

这可能是有其他的随机源你还没有考虑到。

来自第三方库的随机性

也许你的代码使用了另外的库,该库使用不同的也必须设置种子的随机数生成器。

试着将你的代码简化到最低要求(例如,一个数据样本,一轮训练等等),并仔细阅读API文档,尽力减少可能引入随机性的第三方库。

使用GPU产生的随机性

以上所有示例都假设代码是在一个CPU上运行的。

这种情况也是有可能的,就是当使用GPU训练模型时,可能后端设置的是使用一套复杂的GPU库,这些库中有些可能会引入他们自己的随机源,你可能会或者不会考虑到这个。

例如,有证据显示如果你在堆栈中使用了 Nvidia cuDNN,这可能引入额外的随机源( introduce additional sources of randomness),并且使结果不能准确再现。

来自复杂模型的随机性

由于模型的复杂性和训练的并行性,你可能会得到不可复现的结果。

这很可能是由后端库的效率造成的,或者是不能在内核中使用随机数序列。我自己没有遇到过这个,但是在一些GitHub问题和StackOverflowde问题中看到了一些案例。

如果只是缩小成因的范围的话,你可以尝试降低模型的复杂度,看这样是否影响结果的再现。

我建议您阅读一下你的后端是怎么使用随机性的,并看一下是否有任何选项向你开放。在Theano中,参考:

Random NumbersFriendly random numbersUsing Random Numbers

在TensorFlow中,参考:

Constants, Sequences, and Random Values

tf.set_random_seed

另外,为了更深入地了解,考虑一下寻找拥有同样问题的其他人。一些很好的搜寻平台包括GitHub、StackOverflow 和 CrossValidated。

总结

在本教程中,你了解了如何在Keras上得到神经网络模型的可重复结果。特别是,你学习到了:

神经网络是有意设计成随机的,固定随机源可以使结果可复现。你可以为NumPy和TensorFlow的随机数生成器设置种子点,这将使大多数的Keras代码100%的可重复使用。在有些情况下存在另外的随机源,并且你知道如何找出他们,或许也是固定它们。

End.

用深度学习每次得到的结果都不一样,怎么办?相关推荐

  1. 深度学习每次结果不一样

    文章目录 为什么不一样 如何使得他们一样 深度学习算法在开始训练的时候,都会对神经网络进行初始化,这个初始化是由随机数来确定的.我们如果使用同一个数据,同一个网络,同样的参数设置,由于随机初始化的不同 ...

  2. 深度学习的光环背后,都有哪些机器学习的新进展被忽视了?

    2020-01-27 19:22 导语:机器学习领域的下一场革命开始萌芽了吗? 雷锋网 AI 科技评论按:从神经网络被学术界排挤,到计算机科学界三句话不离人工智能.各种建模和预测任务被深度学习大包大揽 ...

  3. 【深度学习】每个数据科学家都必须了解的 6 种神经网络类型

    神经网络是强大的深度学习模型,能够在几秒钟内合成大量数据.有许多不同类型的神经网络,它们帮助我们完成各种日常任务,从推荐电影或音乐到帮助我们在线购物. 与飞机受到鸟类启发的方式类似,神经网络(NNs) ...

  4. 人工智能之深度学习常见应用方向你都了解吗?(文末包邮送书5本)

    文章目录 本文导读 1. 数字识别 2. 图像识别 3. 图像分类 4. 目标检测 5. 人脸识别 6. 文本分类 7. 聊天机器人 8. 书籍推荐(包邮送书5本) 本文导读 从零带你了解深度学习常见 ...

  5. 一文读懂AI圣经,凡研究《深度学习》都知道的一本书!

    由深度学习领域三位前沿.权威的专家Ian Goodfellow.Yoshua Bengio和Aaron Courville合著的人工智能领域的圣经.长期位居美国亚马逊人工智能类图书榜首的<深度学 ...

  6. 深度学习——数据预处理篇

    深度学习--数据预处理篇 文章目录 深度学习--数据预处理篇 一.前言 二.常用的数据预处理方法 零均值化(中心化) 数据归一化(normalization) 主成分分析(PCA.Principal ...

  7. 深度学习的算法实践和演进

    1. 前言 如果说高德纳的著作奠定了第一代计算机算法,那么传统机器学习则扩展出第二代,而近十年崛起的深度学习则是传统机器学习上进一步发展的第三代算法.深度学习算法的魅力在于它核心逻辑的简单且通用. 在 ...

  8. 从0开始,基于Python探究深度学习神经网络

    来源 |  Data Science from Scratch, Second Edition 作者 | Joel Grus 全文共6778字,预计阅读时间50分钟. 深度学习 1.  张量 2.  ...

  9. 医生再添新助手!深度学习诊断传染病 | 完整代码+实操

    作者 | Dipanjan (DJ) Sarkar 译者 | Monanfei 编辑 | Rachel.Jane 出品 | AI科技大本营(id:rgznai100) [导读]文本基于深度学习和迁移学 ...

最新文章

  1. 给JFinal添加 Sqlite 数据库支持
  2. mysql int 默认值 为ull_mysql的 约束 数据库设计 数据库 存储 触发器 mysql 权限问题...
  3. 查看Linux服务器网卡流量小脚本shell和Python各一例
  4. lsnrctl 与 tnsnames.ora 的联系
  5. 【数据挖掘】高斯混合模型 ( 模型简介 | 软聚类 | 概率作用 | 高斯分布 | 概率密度函数 | 高斯混合模型参数 | 概率密度函数 )
  6. internship weekly task update
  7. C语音的预处理,编译,汇编,链接过程分析
  8. Java知多少(96)绘图之设置字型和颜色
  9. SAP License:销售流程
  10. Qt4_实现Edit菜单
  11. 服务器设置客户端网页安装,在Windows 7环境下安装并配置web、SSH、E-mail、FTP等服务器...
  12. 程序员文档写作能力(三)-如何处理好微信、邮件、开会时的话术
  13. 屏幕录像专家 - 视频压缩教程
  14. bat 等待输入_继续提速——双拼的进阶,音形输入
  15. 参考文献的序号怎么对齐_word序号对齐方式 word中如何让编号自动对齐
  16. java-setBounds方法
  17. KECRS: Towards Knowledge-Enriched Conversational Recommendation System
  18. hbase snappy 安装_hbase自带snappy压缩测试出错
  19. 奔梦向前-代码实现表白男生女生-2020-06-15
  20. 线性回归算法梳理——Test1

热门文章

  1. iphone常用代码锦集(二)
  2. 基于SSH的校园二手物品交易系统
  3. bc在计算机领域是什么意思,“BC”是“Before Computers”的缩写,意思是“在计算机之前”...
  4. 设计全局ER模型 数据库系统原理(2007版) 课程代码4735 笔记
  5. ps将logo变透明
  6. Php ui 3dmax,Unity3d和3dMax美工功能简介
  7. HDU 4389 - X mod f(x)
  8. Web 服务系列标准和规范
  9. java 操作 word 表格和样式_java 处理word文档 (含图片,表格内容)
  10. python判断是不是字母_python判断字符是否为字母和数字