keras padding_GAN整体思路以及使用Keras搭建DCGAN
整体思路:
1:使用噪音,通过一系列的转秩卷积(逆卷积)操作,生成一张图片;
2:使用正常的卷积神经网络判断图片的真假。
训练细节:
1:在训练判决器网络时,对真实图像,加上正标签,对假图像加上假标签,使得判别器能够完美的发现假图像;
2:训练生成器时,给假图像打上正标签,使得判别器能够返回假图像在变成真图像所需要拟合的变化。
模型细节:
1:整个网络是两个模型的组合。首先是生成器,然后是判别器;
2:在训练判别器时,使用的网络仅仅是判别器网络,训练的参数也仅仅是判别器网络的参数;
3:训练生成器时,使用的是组合的模型,即先用生成器网络生成图像,再用判别器网络判断优化方向。在优化生成器网络参数时,需要关闭判别器网络。也就是说,尽管反向传播回来的梯度是从判别器开始的,但是判别器的参数不参与优化工作,这也是为什么训练生成器时需要先将判别器网络的梯度更新关闭的原因。
代码重点讲解:
生成器网络搭建:
def generator_model():model = tf.keras.Sequential()#将100维的噪音升维到7*7*256,方便后面转换为图像矩阵model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU())#转换为图像矩阵model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size#使用逆卷积操作,将图片进行尺寸的放大(逆卷积讲解:https://blog.csdn.net/nima1994/article/details/83959495;逆卷积和向上采样的区别:https://www.zhihu.com/question/290376931)model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128) model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64) model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())#一维卷积,通道合并,生成28*28*1的图像model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model
判别器网络搭建:
def discriminator_model():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))#将特征图像展平,然后做一个汇总输出,输出维度为1,即判别图片为真假图片的概率model.add(layers.Flatten())model.add(layers.Dense(1))return model
损失函数:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
判别器损失函数:
def discriminator_loss(real_output, fake_output):
#对真图像,定义标签为1,ones_lisk为输出real_output维度的1,然后使用交叉熵损失,表示真图像距离被判断为真图像有多大的差距
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
#对假图片,定义标签为0,然后使用交叉熵损失,表示假图像距离被判断为假图像有多大的差距
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
#总的损失为在真假图像上的损失之和,即图片的预测与图片的标签之间的差距
total_loss = real_loss + fake_loss
return total_loss
生成器损失函数:
def generator_loss(fake_output):
#将假图像的标签定为1,衡量假图像距离真图像还有多大的差距,这也是生成器网络需要优化的方向
return cross_entropy(tf.ones_like(fake_output), fake_output)
生成对抗网络(自此开始引自github,与上文不是一个作者,可能代码有出入,主要是看整体的思路):
将生成器网络和判别器网络加入,生成器网络在前,判别器网络在后,同时将判别器网络的参数更新关闭。因为这个网络的主要作用就是优化生成器,因此判别器不参与优化。
进行训练:
首先,在83,84行初始化两个模型,然后构建生成对抗网络d_on_g。
上图中的88,89,91行是设置损失函数的,作用等同于上面写的损失函数内容。(截图和代码并不是来自同一个作者,因此这里只表达的了整体逻辑,而非真正没有错误的)
然后设置判别器网络可优化参数,对每一批次,首先先生成噪音数据,然后用噪音数据生成假图像。然后将假图像与真图像组合,并合并标签,假图像为0,真图像为1,,在106行,开始训练判别器网络,注意,此时训练的仅仅是判别器网络d。训练使用的是Keras中的train_on_batch,解释:https://blog.csdn.net/weixin_42886817/article/details/99855287
判别器网络训练完毕后,在使用噪音生成假图像,此时将其标签设置为1(在110行,[1]*BATCH_SIZE即设置y),然后关闭判别器网络的优化,开始训练d_on_g网络,此网络由生成器和判别器组成,由于前面将判别器网络的参数优化给关闭了,因此在这里只优化生成器网络。
代码来自:https://github.com/jacobgil/keras-dcgan/blob/master/dcgan.py
keras padding_GAN整体思路以及使用Keras搭建DCGAN相关推荐
- 好像还挺好玩的GAN2——Keras搭建DCGAN利用深度卷积神经网络实现图片生成
好像还挺好玩的GAN2--Keras搭建DCGAN利用深度卷积神经网络实现图片生成 注意事项 学习前言 什么是DCGAN 神经网络构建 1.Generator 2.Discriminator 训练思路 ...
- java 搭建个人博客_Spring boot 搭建个人博客系统(一)——整体思路
Spring boot 搭建个人博客系统(一)--整体思路 一直想用Spring boot 搭建一个属于自己的博客系统,刚好前段时间学习了叶神的牛客项目课受益匪浅,乘热打铁也主要是学习,好让自己熟悉这 ...
- 基于keras的CNN图片分类模型的搭建以及参数调试
基于keras的CNN图片分类模型的搭建与调参 更新一下这篇博客,因为最近在CNN调参方面取得了一些进展,顺便做一下总结. 我的项目目标是搭建一个可以分五类的卷积神经网络,然后我找了一些资料看了一些博 ...
- 【Keras】Win10系统 + Anaconda+TensorFlow+Keras 环境搭建教程
1. 安装 Anaconda 打开 Anaconda 的官方下载地址:https://www.anaconda.com/download/ 选择 Python 对应的version 下载.下载完成后直 ...
- keras神经网络回归预测_如何使用Keras建立您的第一个神经网络来预测房价
keras神经网络回归预测 by Joseph Lee Wei En 通过李维恩 一步一步的完整的初学者指南,可使用像Deep Learning专业版这样的几行代码来构建您的第一个神经网络! (A s ...
- keras构建卷积神经网络_在Keras中构建,加载和保存卷积神经网络
keras构建卷积神经网络 This article is aimed at people who want to learn or review how to build a basic Convo ...
- 菜鸟做设计必看!有关如何做设计的整体思路,以及能否综合的笔记
对Verilog 初学者比较有用的整理(转自它处) 作者: Ian11122840 时间: 2010-9-27 09:04 标题: 菜鸟做设计必看!有关如何做设计的整体思路,以及能否综合的笔记 所谓综 ...
- Keras官方中文文档:Keras安装和配置指南(Windows)
这里需要说明一下,笔者不建议在Windows环境下进行深度学习的研究,一方面是因为Windows所对应的框架搭建的依赖过多,社区设定不完全:另一方面,Linux系统下对显卡支持.内存释放以及存储空间调 ...
- 做旅游网站建设的整体思路
做旅游网站建设的整体思路 现如今互联网的应用已经延伸到了个个行业领域,那么接踵而来的就是互联网线上交易的时代,之前这块大蛋糕难分,因为没有合适的刀子,尤其是旅游行业,不过现在拥有一个网站来宣传你的旅游 ...
最新文章
- 用thttpd做Web Server
- 分别用BFS和DFS求给定的矩阵中“块”的个数
- [LintCode] Wildcard Matching
- 在vs2005中安装boost库
- CentOS6 安装 MySQL 并配置
- Xshell 连接ubuntu16.04 32位
- matlab实现盖尔圆,[理学]数值分析习题解答.doc
- 概率编程编程_概率编程语言的温和介绍
- 心情随笔(三):注入新的血液
- Sublime Text : 创建工程
- winform ctrl键单击多选_鼠标各键在CAD中的运用,左右键常用,但滚轮这个功能不一定用过...
- 你不知道的Retrofit缓存库RxCache
- CAD - 多段线、矩形、修订云线、样条曲线
- python绘图画猫咪_Turtle库画小猫咪
- 大前研一/聰明人必做的十件事
- 如何在最短的时间内完成立春主题的公众号图文排版?
- 什么叫一层交换机,二层交换机,三层交换机?
- (本人亲测有效)华为magicbook 16SE笔记本电脑重装系统过程
- 关于 NLP 中的 tokenize 总结
- dlt645协议电表数据采集接入PLC或scada等组态软件系统(转modbus)实现内网监控技术方案
热门文章
- node本地连接服务器的数据库_Linux本地连接阿里云服务器,以及下载node.js配置环境...
- 高德,百度,Google地图定位偏移以及坐标系转换
- php include不可用,无法设置PHP include_path
- 邮件发送类_SpringBoot优雅地发送邮件
- java在捕获异常并弹窗_Java捕获异常的问题
- php怎么在html上得到input值,怎么把一個php頁面的值傳到另一個html表單中的input里面去...
- abb机器人伺服电机报闸是什么_ABB机器人电池更换时回零程序Reference
- 电脑 win10 android,新版win10 20185来袭!微软:让你可以直接从PC访问手机App
- mysql数据生成词云图_CVPR2018关键字分析生成词云图与查找
- 计算机英语摘要,英语翻译摘要地理信息系统 (GIS,Geographic Information System) 是一种基于计算机的工具...