keras 中 reuse 问题
- 结论:在搭建GAN判别器时,慎用keras。
- 原因:tf.variable_scope() 的 reuse=True 对keras变量不起作用,导致判别real_img的判别器与判别fake_img的判别器是两个独立的判别器,使对抗训练失效。
- 导致现象:real_img 的判别结果快速收敛到1, fake_img的判别结果快速收敛到0, 且不再变化, 快速收敛是指在100~1000次迭代内完全收敛,不再变化。当判别器损失函数中用到log函数时, 损失函数会出现 nan 现象。
- 其他: 无论是用keras.layers 还是tensorflow.keras.layers都存在这个问题。具体在keras中如何使用变量重用,我还没有探究,希望有看到的可以指教。
-------------------------------------------------------------------------------------------------------------------------------------------------
个人经历:
我之前在搭建神经网络时,用的卷积层是定义在tf.layers中的:
from tensorflow.compat.v1.layers import Conv2D
最近,发现keras 很好用,尤其对LSTM的封装很好用,所以开始转用keras:
from keras.layers import Conv2D
一直用的挺顺利的,直到这两天在训练GAN时出现问题:判别器对fake_img的判别分值很快就趋近于0,对real_img的判别分值很快就趋近于1。
我百思不得其解,调了学习率,多次检查了GAN的损失函数,都没有问题,折腾了一天时间。直到刚刚突然想到变量重用的问题,检查了变量空间,发现以下问题:用keras的Conv2D构建的判别器,在计算real_map与fake_map时,重复创建了判别器的计算节点,并没有执行变量重用,示例代码如下:
import numpy as np
import tensorflow.compat.v1 as tf
from keras.layers import Conv2D
from tensorflow.keras.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=True):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)"""
返回:
dis/conv1/kernel:0
dis/conv1/bias:0
dis_1/conv1/kernel:0
dis_1/conv1/bias:0
"""
由代码可见,我在命名空间中设置了 reuse=True, 但是代码仍旧重复创建了相同的计算节点。当在训练GAN时遇到这个问题时,就会导致 计算real_map的判别器与计算fake_map的判别器是两个独立的判别器,他们两个完全不相关,所以就达不到对抗训练的目的了。
很明显,这个问题是不应该出现的,tensorflow设置了reuse参数,应该是要起作用的,于是我又测试了tf.layer.Conv2D的情况,发现用tf.layer.Conv2D 不存在这个问题,示例代码如下(仅仅改变了Conv2D的来源):
import numpy as np
import tensorflow.compat.v1 as tf
# from keras.layers import Conv2D
from tensorflow.compat.v1.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=True):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)
"""
返回:
dis/conv1/kernel:0
dis/conv1/bias:0
"""
可以发现,当改用 tf.compat.v1.layers 的源时不存在这个问题。
另外, 在reuse = True的命名空间中,只能重复使用已经在该空间中定义的变量,不能创建新的变量,否则会报错。
如果,既希望能够变量重用,又不耽误创建新的变量,请用 reuse=tf.AUTO_REUSE,示例代码如下(与上例相比增加了一个新变量):
import numpy as np
import tensorflow.compat.v1 as tf
# from keras.layers import Conv2D
from tensorflow.compat.v1.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=True):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)conv3 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv2")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)
"""
报错:
ValueError: Variable dis/conv2/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
"""
当把 reuse=True 改为 reuse=tf.AUTO_REUSE后,错误消失:
import numpy as np
import tensorflow.compat.v1 as tf
# from keras.layers import Conv2D
from tensorflow.compat.v1.layers import Conv2Dinputs = tf.convert_to_tensor(np.random.random((1,256,256,3)).astype(np.float32))with tf.variable_scope("dis"):conv1 = Conv2D(32,(3,3),strides=(1,1),padding="same",activation='relu',name="conv1")(inputs)with tf.variable_scope("dis",reuse=tf.AUTO_REUSE):conv2 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv1")(inputs)conv3 = Conv2D(32, (3, 3), strides=(1, 1), padding="same", activation='relu', name="conv2")(inputs)var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for v in var_list:print(v.name)
"""
返回:
dis/conv1/kernel:0
dis/conv1/bias:0
dis/conv2/kernel:0
dis/conv2/bias:0
"""
可见,在已有的变量命名空间中新定义的变量成功了。
keras 中 reuse 问题相关推荐
- 深度学习布料交换:在Keras中实现条件类比GAN
2017年10月26日SHAOANLU 条件类比GAN:交换人物形象的时尚文章(链接) 给定三个输入图像:人穿着布A,独立布A和独立布B,条件类比GAN(CAGAN)生成穿着布B的人类图像.参见下图. ...
- CNN在Keras中的实践|机器学习你会遇到的“坑”
2018-12-16 23:43:37 本文作为上一节<卷积之上的新操作>的补充篇,将会关注一些读者关心的问题,和一些已经提到但并未解决的问题: 到底该如何理解padding中的valid ...
- 神经网络在Keras中不work!博士小哥证明何恺明的初始化方法堪比“CNN还魂丹”...
铜灵 发自 凹非寺 量子位 出品 | 公众号 QbitAI 南巴黎电信学院(Télécom SudParis)的在读博士生Nathan Hubens在训练CNN时遇到点难题. 使用在CIFAR10数据 ...
- Keras中神经网络可视化模块keras.utils.visualize_util安装配置方法
Keras中提供了一个神经网络可视化的函数plot,并可以将可视化结果保存在本地.plot使用方法如下: from keras.utils.visualize_util import plot plo ...
- Keras之ML~P:基于Keras中建立的回归预测的神经网络模型(根据200个数据样本预测新的5+1个样本)——回归预测
Keras之ML~P:基于Keras中建立的回归预测的神经网络模型(根据200个数据样本预测新的5+1个样本)--回归预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- coding: u ...
- Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5个样本)——概率预测
Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5个样本)--概率预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- coding: ...
- Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5+1个样本)——类别预测
Keras之ML~P:基于Keras中建立的简单的二分类问题的神经网络模型(根据200个数据样本预测新的5+1个样本)--类别预测 目录 输出结果 核心代码 输出结果 核心代码 # -*- codin ...
- 【小白学习keras教程】十一、Keras中文本处理Text preprocessing
@Author:Runsen 文章目录 Text preprocessing Tokenization of a sentence One-hot encoding Padding sequences ...
- keras构建卷积神经网络_在Keras中构建,加载和保存卷积神经网络
keras构建卷积神经网络 This article is aimed at people who want to learn or review how to build a basic Convo ...
最新文章
- PPTPD×××服务器架设
- CollegeStudent
- 第三章:3.5 傅里叶变换
- 我们需要一个时期,把我们之前的愿景用实际行动实现
- mvc3中正确处理ajax访问需要登录的页面
- Git 笔记——如何处理分支合并冲突
- Win32汇编学习笔记(罗云彬)(二)
- 魔兽国服修改服务器地址,魔兽国服退役服务器上架暴雪官方商店
- ajax下载表格文件
- 虚拟机VMware的安装及使用
- 1013_MISRA C规范学习笔记9
- 一周信创舆情观察(2.1~2.7)
- 测试开发 - 十年磨一剑(序)
- installshield 如何实现Oracle数据库脚本的执行功能
- 电驴维持友情链接地址、更新服务器列表
- 制作轮播图经验分享——element ui走马灯的使用(附源码,效果截图)
- c++小游戏——忍者必须死
- Flowable工作流之核心流程操作的本质
- 去除html双击后选中有蓝色背景
- python 拼多多_python 拼多多_拼多多2018校招编程题汇总 Python实现
热门文章
- 三次样条插值-轨迹规划
- 将视频抽取成图片,并对图片进行批量命名opencv代码
- 卡尔曼滤波器的一种形象表达
- 华为鸿蒙HarmonyOS,华为鸿蒙HarmonyOS-系统概述
- opengl光线追踪的程序_【PathTracing】实时光线追踪和BSSRDF的那些事
- java语言生日蛋糕代码_AcWing 168. 【Java】生日蛋糕
- 在Python中查找字符串长度
- ios8升级ios12教程_iOS SpriteKit教程
- testng_TestNG @工厂注释
- jsf表单验证_JSF验证示例教程–验证器标签,定制验证器