• 结论:在搭建GAN判别器时,慎用keras。
  • 原因:tf.variable_scope()  的 reuse=True 对keras变量不起作用,导致判别real_img的判别器与判别fake_img的判别器是两个独立的判别器,使对抗训练失效。
  • 导致现象:real_img 的判别结果快速收敛到1, fake_img的判别结果快速收敛到0, 且不再变化, 快速收敛是指在100~1000次迭代内完全收敛,不再变化。当判别器损失函数中用到log函数时, 损失函数会出现 nan 现象。
  • 其他: 无论是用keras.layers 还是tensorflow.keras.layers都存在这个问题。具体在keras中如何使用变量重用,我还没有探究,希望有看到的可以指教。

-------------------------------------------------------------------------------------------------------------------------------------------------

个人经历:

我之前在搭建神经网络时,用的卷积层是定义在tf.layers中的:

from tensorflow.compat.v1.layers import Conv2D

最近,发现keras 很好用,尤其对LSTM的封装很好用,所以开始转用keras:

from keras.layers import Conv2D

一直用的挺顺利的,直到这两天在训练GAN时出现问题:判别器对fake_img的判别分值很快就趋近于0,对real_img的判别分值很快就趋近于1。

我百思不得其解,调了学习率,多次检查了GAN的损失函数,都没有问题,折腾了一天时间。直到刚刚突然想到变量重用的问题,检查了变量空间,发现以下问题:用keras的Conv2D构建的判别器,在计算real_map与fake_map时,重复创建了判别器的计算节点,并没有执行变量重用,示例代码如下:

import numpy as np
import tensorflow.compat.v1 as tf
from keras.layers import Conv2D
from tensorflow.keras.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=True):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)"""
返回:
dis/conv1/kernel:0
dis/conv1/bias:0
dis_1/conv1/kernel:0
dis_1/conv1/bias:0
"""

由代码可见,我在命名空间中设置了 reuse=True, 但是代码仍旧重复创建了相同的计算节点。当在训练GAN时遇到这个问题时,就会导致 计算real_map的判别器与计算fake_map的判别器是两个独立的判别器,他们两个完全不相关,所以就达不到对抗训练的目的了。

很明显,这个问题是不应该出现的,tensorflow设置了reuse参数,应该是要起作用的,于是我又测试了tf.layer.Conv2D的情况,发现用tf.layer.Conv2D 不存在这个问题,示例代码如下(仅仅改变了Conv2D的来源):

import numpy as np
import tensorflow.compat.v1 as tf
# from keras.layers import Conv2D
from tensorflow.compat.v1.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=True):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)
"""
返回:
dis/conv1/kernel:0
dis/conv1/bias:0
"""

可以发现,当改用 tf.compat.v1.layers 的源时不存在这个问题。

另外, 在reuse = True的命名空间中,只能重复使用已经在该空间中定义的变量,不能创建新的变量,否则会报错。

如果,既希望能够变量重用,又不耽误创建新的变量,请用 reuse=tf.AUTO_REUSE,示例代码如下(与上例相比增加了一个新变量):

import numpy as np
import tensorflow.compat.v1 as tf
# from keras.layers import Conv2D
from tensorflow.compat.v1.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=True):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)conv3 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv2")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)
"""
报错:
ValueError: Variable dis/conv2/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
"""

当把 reuse=True 改为 reuse=tf.AUTO_REUSE后,错误消失:

import numpy as np
import tensorflow.compat.v1 as tf
# from keras.layers import Conv2D
from tensorflow.compat.v1.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=tf.AUTO_REUSE):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)conv3 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv2")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)
"""
返回:
dis/conv1/kernel:0
dis/conv1/bias:0
dis/conv2/kernel:0
dis/conv2/bias:0
"""

可见,在已有的变量命名空间中新定义的变量成功了。

keras 中 reuse 问题相关推荐

  1. 深度学习布料交换:在Keras中实现条件类比GAN

    2017年10月26日SHAOANLU 条件类比GAN:交换人物形象的时尚文章(链接) 给定三个输入图像:人穿着布A,独立布A和独立布B,条件类比GAN(CAGAN)生成穿着布B的人类图像.参见下图. ...

  2. CNN在Keras中的实践|机器学习你会遇到的“坑”

    2018-12-16 23:43:37 本文作为上一节<卷积之上的新操作>的补充篇,将会关注一些读者关心的问题,和一些已经提到但并未解决的问题: 到底该如何理解padding中的valid ...

  3. 神经网络在Keras中不work!博士小哥证明何恺明的初始化方法堪比“CNN还魂丹”...

    铜灵 发自 凹非寺 量子位 出品 | 公众号 QbitAI 南巴黎电信学院(Télécom SudParis)的在读博士生Nathan Hubens在训练CNN时遇到点难题. 使用在CIFAR10数据 ...

  4. Keras中神经网络可视化模块keras.utils.visualize_util安装配置方法

    Keras中提供了一个神经网络可视化的函数plot,并可以将可视化结果保存在本地.plot使用方法如下: from keras.utils.visualize_util import plot plo ...

  5. Keras之ML~P:基于Keras中建立的回归预测的神经网络模型(根据200个数据样本预测新的5+1个样本)——回归预测

    Keras之ML~P:基于Keras中建立的回归预测的神经网络模型(根据200个数据样本预测新的5+1个样本)--回归预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- coding: u ...

  6. Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5个样本)——概率预测

    Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5个样本)--概率预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- coding: ...

  7. Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5+1个样本)——类别预测

    Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5+1个样本)--类别预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- codin ...

  8. 【小白学习keras教程】十一、Keras中文本处理Text preprocessing

    @Author:Runsen 文章目录 Text preprocessing Tokenization of a sentence One-hot encoding Padding sequences ...

  9. keras构建卷积神经网络_在Keras中构建,加载和保存卷积神经网络

    keras构建卷积神经网络 This article is aimed at people who want to learn or review how to build a basic Convo ...

最新文章

  1. PPTPD×××服务器架设
  2. CollegeStudent
  3. 第三章:3.5 傅里叶变换
  4. 我们需要一个时期,把我们之前的愿景用实际行动实现
  5. mvc3中正确处理ajax访问需要登录的页面
  6. Git 笔记——如何处理分支合并冲突
  7. Win32汇编学习笔记(罗云彬)(二)
  8. 魔兽国服修改服务器地址,魔兽国服退役服务器上架暴雪官方商店
  9. ajax下载表格文件
  10. 虚拟机VMware的安装及使用
  11. 1013_MISRA C规范学习笔记9
  12. 一周信创舆情观察(2.1~2.7)
  13. 测试开发 - 十年磨一剑(序)
  14. installshield 如何实现Oracle数据库脚本的执行功能
  15. 电驴维持友情链接地址、更新服务器列表
  16. 制作轮播图经验分享——element ui走马灯的使用(附源码,效果截图)
  17. c++小游戏——忍者必须死
  18. Flowable工作流之核心流程操作的本质
  19. 去除html双击后选中有蓝色背景
  20. python 拼多多_python 拼多多_拼多多2018校招编程题汇总 Python实现

热门文章

  1. 三次样条插值-轨迹规划
  2. 将视频抽取成图片,并对图片进行批量命名opencv代码
  3. 卡尔曼滤波器的一种形象表达
  4. 华为鸿蒙HarmonyOS,华为鸿蒙HarmonyOS-系统概述
  5. opengl光线追踪的程序_【PathTracing】实时光线追踪和BSSRDF的那些事
  6. java语言生日蛋糕代码_AcWing 168. 【Java】生日蛋糕
  7. 在Python中查找字符串长度
  8. ios8升级ios12教程_iOS SpriteKit教程
  9. testng_TestNG @工厂注释
  10. jsf表单验证_JSF验证示例教程–验证器标签,定制验证器