一.引言

使用 mean-std 归一化数值型 Tensor 时,出现 Nan 值,导致训练时出现 Nan Loss:

CSDN-BITDDD

通过下面几种方法简单处理下 Nan 值。

二.情景再现

出现 Nan 值是因为归一化时原始 Tensor 为全0导致 variance 为 0,从而 x - mean / std 得到 Nan

    # 初始化全0 Tensortensor = tf.constant(np.zeros(shape=(5, 3)), dtype='float32')# 获取方差均值mean, variance = tf.nn.moments(x=a, axes=[1])# -Meantensor -= tf.expand_dims(mean, axis=1)# /Stdtensor /= tf.expand_dims(variance, axis=1)print(tensor)
tf.Tensor(
[[nan nan nan][nan nan nan][nan nan nan][nan nan nan][nan nan nan]], shape=(5, 3), dtype=float32)

三.解决方案

通过情景再现我们定位了问题所在,方差为0导致除法得到 Nan,所以解决全 0 方差即可解决问题

1.tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,name=None):

tf.clip_by_value 函数中有两个参数 clip_value_min,clip_value_max,这两个值对 tensor t 中的值进行了限制,如果值小于等于 clip_value_min,则数值转换为 clip_value_min 对应的最小值,同样如果超过了 clip_value_max 的值,则会被替换为 clip_value_max。本例中最小值为0导致除法得到 Nan,所以可以限制最小值区间,例如 demo 中给到的 1e-8,这样全0的值都会转换为 1e-8,被除数不为0,归一化时就不会出现 Nan 了。

clip = tf.clip_by_value(variance, 1e-8, 1.0)
# clip 后的 variance
tf.Tensor([1.e-08 1.e-08 1.e-08 1.e-08 1.e-08], shape=(5,), dtype=float32)

2.tf.where

通过掩码 + where 的模式处理原始输入数据在数据预处理时经常用到,主要分两步:

A.计算mask

通过 tf.not_equal 判断 variance 中是否包含异常值 0

    mask_value = 0mask = tf.not_equal(variance, tf.constant(mask_value, dtype=variance.dtype))# Masktf.Tensor([False False False False False], shape=(5,), dtype=bool)

B.填充掩码

def where_v2(condition, x=None, y=None, name=None):

通过 tf.where 函数进行条件判断,condition 为 True 时选择 x 的值,为 False 时选择为 y 的值,默认值为 None,填充值 Padding 的选择一般有两个选择,填充后的 tensor 如果用于 softmax 函数,可以选择 -IntMax + 1,这样 exp 后会得到一个趋于0但不为0的值,如果使用 log 函数,可以使用一个极小值比如 1e-8 作为填充。

    # softmax paddings = tf.ones_like(variance) * (-2 ** 32 + 1)# logpaddings = tf.ones_like(variance) * 1e-8out = tf.where(mask, variance, paddings)# 掩码填充后的 variance tf.Tensor([1.e-08 1.e-08 1.e-08 1.e-08 1.e-08], shape=(5,), dtype=float32)

3.BatchNormalization

第三种方案参考了 BN 层的实现,BN 层通过滑动均值与滑动方差归一化时,在分母处添加了一个极小值 epsilon,这里也可以取 1e-8,在极小值的加持下保证了分母不为0从而避免了零除得到 Nan 的情况,简单实现的话也可以采用该方法。

# `(batch - self.moving_mean) / (self.moving_var + epsilon) * gamma + beta`.
    variance += 1e-8# + epsilon 后的 variancetf.Tensor([1.e-08 1.e-08 1.e-08 1.e-08 1.e-08], shape=(5,), dtype=float32)

四.总结

通过三种方案,tensor 归一化时 Nan 都会调整为 0,从而避免了报错。除了归一化可能遇到 Nan 值时,反向传播过程中也可能出现零除和 Nan 的情况,上述几种方法同样适用于其他步骤的数据处理。

​tf.Tensor(
[[nan nan nan][nan nan nan][nan nan nan][nan nan nan][nan nan nan]], shape=(5, 3), dtype=float32)↓↓↓↓↓↓↓↓↓↓tf.Tensor(
[[0. 0. 0.][0. 0. 0.][0. 0. 0.][0. 0. 0.][0. 0. 0.]], shape=(5, 3), dtype=float32)

Tensorflow - 训练中出现 Nan 值相关推荐

  1. TensorFlow中的Nan值的陷阱

    北京站 | NVIDIA DLI深度学习培训 2018年1月26日 NVIDIA 深度学习学院 带你快速进入火热的DL领域 阅读全文                           正文共1583 ...

  2. 深度学习网络训练中出现nan的原因分析

    报错: nan:Not a Number 该错误导致的后果:造成训练准确率的断崖式下跌 错误原因分析: 1)在loss函数中出现nan 出现原因:一般是因为tf中的log函数输入了'负数'或'0'值( ...

  3. mysql nan_在MySQL数据库中插入NaN值

    我有一些包含空值.浮点数和偶尔的Nan的数据.我正试图使用python和MySqldb将这些数据插入MySQL数据库. 这是插入语句:for row in zip(currents, voltages ...

  4. mysql中的nan_使用python-cod将MySql列中的“NAN”值转换为NULL

    我通过python在MySql中编写/存储数据.如果MySql中的列数据包含"NAN",那么我如何处理它们.目前我知道如何处理空白或无值,但在这里我被卡住了.数据看起来像这样# f ...

  5. python把nan值去掉_python – Keras Neural Nets,如何删除输出中的NaN值?

    我一直使用Keras从我的神经网络中获得一些NaN输出.我每10,000个结果只得到一个NaN.最初我有一个relu激活层进入最终的softmax层.这产生了更多的NaN结果.我将构成网络中最后两个密 ...

  6. 数值的加减会改变python中id,在python中调用Nan值并更改为数字

    ix已弃用,请不要使用它.在 选项1 我会用np.where-df = df.assign(pro=np.where(df.pro.isnull(), df.property_type, df.pro ...

  7. nc文件在ncl中取代nan值为缺省值

    begin f = addfile("tw.nc","w") tw = f->tw if (any(isnan_ieee(tw))) then       ...

  8. MATLAB去除矩阵中的NAN值

    //I1为待去除的矩阵 for i=1:m for j=1:n if isnan(I1(i,j)) I1(i,j)=0; end end end

  9. python找出值为nan_Python Numpy:找到list中的np.nan值方法

    这个问题源于在训练机器学习的一个模型时,使用训练数据时提示prepare的数据中存在np.nan 报错信息如下: ValueError: np.nan is an invalid document, ...

  10. python 时间序列异常值_python中缺少时间序列值

    插值和滤波: 由于是时间序列问题,我将在答案中使用o/p图图像进行解释: 假设我们有如下时间序列的数据:(在x轴上=天数,y=数量)pdDataFrame.set_index('Dates')['QU ...

最新文章

  1. php localcompare,GetDriveName 方法
  2. idea软件,如何不每次弹出“欢迎界面!”
  3. SpringMVC 框架系列之初识与入门实例
  4. 【MySQL】玩转定时器
  5. 阿里云云主机添加swap分区与swap性能优化
  6. 想学习Android开发
  7. 代码演示:先来后到的特例、优劣、源码分析
  8. 不相交集合求并的路径压缩
  9. 光端机的原理和使用范围
  10. 如果删除github上项目的文件
  11. pytorch 训练过程acc_Pytorch之Softmax多分类任务
  12. Linux 2 unit7 挂载网络共享
  13. RGB vs YCbCr(YUV)
  14. php访问属性两种方式,使用PHP访问对象的属性
  15. 阿里 java ide_纯JAVA版JAVA IDE环境(源码)
  16. PostgreSQL shapefile 导入导出
  17. html转pdf手机,html转pdf
  18. 夜神模拟器安装frida-server图文详解
  19. 【点云配准算法】【NDT】
  20. 基于OAuth2的认证

热门文章

  1. 第一讲 数系发展史纲
  2. unity 物体移动方式的一些笔记
  3. Linux 课程设计 每日小结
  4. 【元宇宙系列】元宇宙的创世居民——M 世代(Mateverse)
  5. 利用Hexo GitHub Page和 travis CI搭建播客
  6. 大牛深入讲解!6年老Android面经总结,系列教学
  7. python有哪些学习内容_python学习内容包括哪些
  8. 很抱歉,三维地图当前不能在你的国家/地区使用 Excel绘制三维地图问题解决
  9. Excel·VBA单元格重复值标记颜色
  10. 大压缩文件解压错误,台服wow common-2.MPQ 文件损坏