Tensorflow 2.x代码中如何控制随机性

  • 引言
  • 随机性控制
    • 基本常用包中的随机性
    • Tensorflow中的随机性
  • 总结
  • References

引言

控制实验的随机性非常有必要:(1) 保证结果的可复现/重复性一直都是研究中的一个基本问题; (2)在验证所提方法/系统中往往需要做分离/消融实验来对结果进行拆分,以验证各个模块是否有效以及对总体结果的贡献,控制随机性可以消除因随机性的引入产生的影响。基于此,在此总结tensorflow2.x实验环境的随机性控制方法。

随机性控制

基本常用包中的随机性

numpyrandom, os中的随机性控制:

import numpy as np
import randomseed_value = 0
random.seed(seed_value )
np.random.seed(seed_value)
os.environ['PYTHONHASHSEED'] = str(seed_value)

在重复实验时,往往需要利用随机性来产生数据集的不同划分,例如:
(基于sklearn通过控制随机性来产生不同的划分)

from sklearn.model_selection import train_test_split, StratifiedKFold
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=seed_value)
cv = StratifiedKFold(n_splits=10, random_state=seed_value) # 10折

Tensorflow中的随机性

tensorflow2.x中的随机性控制

import tensorflow as tf
tf.random.set_seed(0)

tensorflow2.x 中各层的随机性控制
(1)各层初始化过程中的随机性
1D conv.为例,其API格式为:

tf.keras.layers.Conv1D(kernel_initializer=None, ....)

可以看到kernel_initializer默认为Null, 如何控制初始化中的随机性?

tn_initializer = tf.initializers.TruncatedNormal(seed=seed_value)
conv_1d_f_pointwise = tf.keras.layers.Conv1D(filters=1, kernel_size=1, activation=None, kernel_initializer=tn_initializer)

注:Conv1D还有其他参数初始化会引入随机性,此处仅以kernel_initializer为例.
(2)Dropout
Dropout层的构造函数如下:

__init__(self, rate=0.5, noise_shape=None, seed=None, name=None, **kwargs)

在使用过程中需要指定随机种子来控制dropout操作本身带来的随机性:

tf.keras.layers.Dropout(rate=self.drop_rate, seed=self.random_seed_alg_part)

(3) 训练过程中的随机性
如果调用model.fit()进行训练,API接口如下:

fit(self,x=None, y=None, batch_size=None,  epochs=1, verbose=1, shuffle=True, **kwargs)

其中参数shuffle默认True会带来随机性,需要将参数shuffle显式置为False.

(4) GPU多线程训练中的随机性
tensorflow-determinism包(https://github.com/NVIDIA/framework-determinism):该包解决了GPU上训练时不确定性的问题。注意:安装时通过命令:pip install tensorflow-determinism 安装, 安装后调用如下:

from tfdeterminism import patch
patch()

可惜的是:tfdeterminism 暂时只适用于小于2.1的tensorflow版本,因为目前没有适用于TensorFlow 2.1版或更高版本的修补程序

对于Tensorflow2.3版本(通过命令:pip package tensorflow=2.3.0安装),因为其本身实现了大多数当前可用的GPU确定性操作解决方案, 因此可以通过如下代码控制GPU中的不确定性:

import tensorflow as tf
import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'

另外, 还看到如下代码: os.environ['TF_CUDNN_DETERMINISTIC'] = str(seed_value), 暂时还不清楚其适用场景。

总结

未完待续。

References

1.https://github.com/NVIDIA/framework-determinism
2.https://zhuanlan.zhihu.com/p/95416326
3.https://stackoverflow.com/questions/50744565/how-to-handle-non-determinism-when-training-on-a-gpu

Tensorflow 2.x代码中如何控制随机性以保证结果可重复性相关推荐

  1. 从Tensorflow代码中理解LSTM网络

    目录 RNN LSTM 参考文档与引子 缩略词  RNN (Recurrent neural network) 循环神经网络  LSTM (Long short-term memory) 长短期记忆人 ...

  2. tensorflow代码中的tf.app.run()

    一般 if __name__ == '__main__':之后紧接着的是主函数的运行入口,但在tensorflow的代码里头经常可以看到其后面的是tf.app.run(),这个究竟是什么意思呢??? ...

  3. swagger openapi开放平台 pyhton3.7实现http发送请求,pyhon中代码中发送http请求控制4g物联网开关

    swagger openapi开放平台 pyhton3.7实现http发送请求 pyhon中代码中控制 4g物联网开关,此代码与python2.7不兼容,具体体现在加密解: get_authoriza ...

  4. Camstar开发思考:如何在C#代码中控制事务

    目录 开发现状 开发问题 解决方案 1)自定义UserFunction 2)预调用服务 预调用服务方案设计与实现 代码设计 实现结果 开发现状 Camstar开发过程中,业务代码通常写在以下位置: 1 ...

  5. html中圆形单元格,HTML代码中关于table的边框控制以及单元格艰巨

    HTML代码中关于table的边框控制以及单元格艰巨 (2011-06-22 14:37:36) 标签: 南国 鹿灵子鹿胎膏 杂谈 1.若有一个1行1列的表格,设放table属性,使患上他的右.上两条 ...

  6. FOC 无感 代码 算法 电机控制 PMSM 基于中颍SH32F2601的洗衣机量产无感bldc控制方案,电机控制算法完全手写

    FOC 无感 代码 算法 电机控制 PMSM 基于中颍SH32F2601的洗衣机量产无感bldc控制方案,电机控制算法完全手写,MCU寄存器配置完全手写,未用到任何库文件 ID:34500065518 ...

  7. c# mysql代码中写事务_代码中添加事务控制 VS(数据库存储过程+事务) 保证数据的完整性与一致性...

    [c#]代码库代码中使用事务前提:务必保证一个功能(或用例)在同一个打开的数据连接上,放到同一个事务里面操作. 首先是在D层添加一个类为了保存当前操作的这一个连接放到一个事务中执行,并事务执行打开同一 ...

  8. tensorflow代码中tf.app.run()什么意思

    # 前面的代码省略了... ... ... ... def main(argv=None):mnist = input_data.read_data_sets("F:\mydata\Tens ...

  9. 独家 | 手把手教TensorFlow(附代码)

    上一期我们发布了"一文读懂TensorFlow(附代码.学习资料)",带领大家对TensorFlow进行了全面了解,并分享了入门所需的网站.图书.视频等资料,本期文章就来带你一步步 ...

  10. 【CVPR Oral】TensorFlow实现StarGAN代码全部开源,1天训练完

    [CVPR Oral]TensorFlow实现StarGAN代码全部开源,1天训练完 原文:https://mp.weixin.qq.com/s?__biz=MzI3MTA0MTk1MA==& ...

最新文章

  1. 十万浙企上云 阿里云崛起的最大征候?
  2. LeetCode上稀缺的四道shell编程题解析
  3. pytorch拼接函数:torch.stack()和torch.cat()--详解及例子
  4. URAL 1233 Amusing Numbers 好题
  5. lnmp之PDO_mysql.so
  6. python_selenium简单的滑动验证码
  7. 优控触摸屏使用手册_中达优控PLC触摸屏一体机说明书.pdf
  8. Linux忘记密码的找回方法
  9. EasyAr聚焦模式
  10. Word另存为PDF后无导航栏解决办法
  11. java三角形角度_利用java解决三角形角度问题
  12. Kalman Fuzzy Actor-Critic Learning Automaton Algorithm for the Pursuit-Evasion Differential Game
  13. win10系统电脑点击桌面图标没反应怎么处理
  14. ubuntu18.04更新内核导致显卡驱动失效
  15. POSE estimation,肢体估计HPE
  16. 快递查询方法一键查询物流信息
  17. Android 高德地图so包太大,高德地图sdk配置心得(jar文件与so文件导入)
  18. SAP-财务-统驭科目
  19. 手写springboot自动装配 autoConfiguration
  20. 1.调查问卷-接口文档

热门文章

  1. 初学Laravel框架与ThinkPHP框架的不同
  2. 创建类模式(零):简单/静态工厂(Static Factory)
  3. Cocos2d-x 3.0 开发(四)使用CocoStudio创建UI并载入到程序中
  4. 64位sql server 如何使用链接服务器连接Access
  5. 借助Sigar API获取网络信息
  6. Eclipse—在Eclipse中如何创建JavaWeb工程
  7. 关于@NotNull 和 @Nullable
  8. Java设计模式之单例(Singleton)模式解析
  9. jdbc数据库连接池连接
  10. 二级公共基础知识_二级公共基础知识 01