一个简单的GAN搭建

  • 1.Generator
  • 2.Discriminator
  • 3.完整代码
  • 4.项目实践
    • 4.1 下载项目
    • 4.2 环境配置
    • 4.3 运行代码
    • 4.4 查看运行结果

此博客为学习 他人博客, bilibili视频解析所作的笔记,在看这个作者的项目之前,可以看看 这位up主的深度学习科普
项目源码: https://github.com/bubbliiiing/GAN-keras
此博文使用代码是 gan.py

1.Generator

def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))#把100维全连接到256个节点上model.add(LeakyReLU(alpha=0.2))#激活函数model.add(BatchNormalization(momentum=0.8))#标准化model.add(Dense(512))#把256的神经元映射到512的神经元上model.add(LeakyReLU(alpha=0.2))#激活函数model.add(BatchNormalization(momentum=0.8))#标准化model.add(Dense(1024))#把512的神经元映射到1024的神经元上model.add(LeakyReLU(alpha=0.2))#激活函数model.add(BatchNormalization(momentum=0.8))#标准化model.add(Dense(np.prod(self.img_shape), activation='tanh'))#np.prod==28*28*1,把1024映射带784的神经元上model.add(Reshape(self.img_shape))#再reshape成28*28*1noise = Input(shape=(self.latent_dim,))#输入n维变量,例如100维img = model(noise)#模型就能生成一张图片了return Model(noise, img)

2.Discriminator

def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))#28*28*1 Flatten:把28*28*1平铺成向量model.add(Dense(512))#将784映射到512神经元上model.add(LeakyReLU(alpha=0.2))#激活函数,alpha是学习率model.add(Dense(256))#将512映射到256神经元上model.add(LeakyReLU(alpha=0.2))#激活函数# 判断真伪model.add(Dense(1, activation='sigmoid'))#全连接到1维向量img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

3.完整代码

from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adamimport matplotlib.pyplot as pltimport sys
import os
import numpy as npclass GAN():def __init__(self):# --------------------------------- ##   行28,列28,也就是mnist的shape# --------------------------------- #self.img_rows = 28self.img_cols = 28self.channels = 1# 28,28,1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100# adam优化器,学习率0.0002optimizer = Adam(0.0002, 0.5)self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])self.generator = self.build_generator()#生成生成网络模型gan_input = Input(shape=(self.latent_dim,))#产生噪声输入img = self.generator(gan_input)#生成一张图片# 在训练generate的时候不训练discriminatorself.discriminator.trainable = False# 对生成的假图片进行预测validity = self.discriminator(img)#将生成图片传入判别模型中得到预测结果self.combined = Model(gan_input, validity)#结合判别模型对生成模型进行训练self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, sample_interval=50):# 获得数据(X_train, _), (_, _) = mnist.load_data()#加载数据集# 进行标准化,标准化X_train = X_train / 127.5 - 1.#28,28->28,28,1X_train = np.expand_dims(X_train, axis=3)# 创建标签valid = np.ones((batch_size, 1))#把真图片标签为1fake = np.zeros((batch_size, 1))#把假图片标记为0for epoch in range(epochs):# --------------------------- ##   随机选取batch_size个图片#   对discriminator进行训练# --------------------------- #idx = np.random.randint(0, X_train.shape[0], batch_size)#随机选取几张真是图片imgs = X_train[idx]#放入这个数组noise = np.random.normal(0, 1, (batch_size, self.latent_dim))#生成一堆noisegen_imgs = self.generator.predict(noise)#传入生成模型中,生成生成模型d_loss_real = self.discriminator.train_on_batch(imgs, valid)#传入真实图片,结果和1对比,train_on_batch一个batch一个batch的训练d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)#传入假图片,结果和0对d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# --------------------------- ##  训练generator# --------------------------- #noise = np.random.normal(0, 1, (batch_size, self.latent_dim))#生成一组noiseg_loss = self.combined.train_on_batch(noise, valid)#生成模型和1对比print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = GAN()gan.train(epochs=30000, batch_size=256, sample_interval=200)

4.项目实践

4.1 下载项目

项目地址:https://github.com/bubbliiiing/GAN-keras

4.2 环境配置

在annaconda里搭建一个虚拟环境,并配置

tensorflow==1.13.1
keras==2.1.5

作者说需要tensorflow-gpu,但是其实这个项目挺小的,完全可以使用CPU

4.3 运行代码

跳转到gan目录下,直接运行python文件就ok

python gan.py

4.4 查看运行结果



第三万次结果:

GAN学习:一个简单的GAN搭建相关推荐

  1. 说一下dubbo项目简单的搭建过程_dubbo学习(1)--简单的入门搭建实例

    1 简介 dubbo是一个分布式服务框架,由阿里巴巴的工程师开发,致力于提供高性能和透明化的RPC远程服务调用.可惜的是该项目在2012年之后就没有再更新了,之后由当当基于dubbo开发了dubbox ...

  2. GAN学习笔记-李宏毅:GAN Lecture 7 (2018): Info GAN, VAE-GAN, BiGAN

    李宏毅老师讲解的GAN Lecture 7 (2018): Info GAN, VAE-GAN, BiGAN 问题:input feature 对output影响不明确这件事:本来假设不同的特性在la ...

  3. c4d学习笔记-简单自行车的搭建--自行车轴的制作、改变轴心旋转、立体三角制作

    自行车的搭建 具体的搭建过程也不是太难,只用到了基础的模型,只记录一下搭建时候的关键知识点吧 1.自行车车轮中间杠的克隆 首先绘制一个圆环面 旋转到正确的位置,再绘制一个圆柱体–放到合适的位置 新建一 ...

  4. springboot+mybatis+thymeleaf学习一个简单的管理系统

    在淘宝上买的课程的一个例子,看了视频,抄了一遍代码,那时候刚开始学springboot,所以感觉没什么用,然后就又学习了一段时间.最近回想起来有这样的一个系统符合我现阶段的学习程度,然后就又写了一遍. ...

  5. 帆软(FineReport)报表学习——一个简单的报表

    客户要用帆软做东西,就下载了一个,弄了一些报表出来. 废话不说,走起! 先建立一个数据源连接. 这个Mysql的连接没什么可讲的,就是后面的连接参数需要注意一下,useUnicode=true& ...

  6. GAN学习:Keras入门

    Keras入门 1.环境配置 2.搭建一个简单的网络 3.多元线性回归 3.1 准备数据集 3.2 环境配置 3.2代码 3.4 结果 4.全链接模型之手写数字识别模型 5.优化手写数字识别模型 6. ...

  7. [GAN学习系列3]采用深度学习和 TensorFlow 实现图片修复(上)

    在之前的两篇 GAN 系列文章–[GAN学习系列1]初识GAN以及[GAN学习系列2] GAN的起源中简单介绍了 GAN 的基本思想和原理,这次就介绍利用 GAN 来做一个图片修复的应用,主要采用的也 ...

  8. float在python中的书写形式错误的是_在Python3.7.1中,编写简单的GAN时,“TypeError:”float“对象不能解释为整数”错误...

    我对Python和编程是全新的.我试图编写一个简单的GAN来使用Keras数据集(参见下面的教程超链接). 我收到两个警告,然后是一个错误:TypeError: 'float' object cann ...

  9. PyTorch学习笔记(10)--搭建简单的神经网络以及Sequential的使用

    PyTorch学习笔记(10)–搭建简单的神经网络以及Sequential的使用     本博文是PyTorch的学习笔记,第10次内容记录,主要搭建一个简单的神经网络,并介绍Sequential的使 ...

最新文章

  1. Angular应用一个创建场景的问题分析
  2. P2522 [HAOI2011]Problem b
  3. php pathseparator,在PHP拥有与命名空间和通过set_include_path()的一个问题
  4. mysql数据库1129错误
  5. 汉罗塔python_基于Python的汉诺塔算法
  6. Unity 入门笔记 - 02 - 各种动画
  7. php图片点击查看大图,jQuery点击小图看大图,大图查看内容详情所有图片
  8. 双均线策略代码【利用聚宽平台】
  9. 从比特保存和信息保存看数字资源长期保存
  10. 推荐官方开源 PInvoke 库 包含大量 win32 封装
  11. DO、DTO、BO、AO、VO、POJO
  12. Windows下验证https证书
  13. 使用 Indy WEB Server 支持 https
  14. 二极管包络检波器电路仿真实验
  15. 【阅读笔记】《TDN: Temporal Difference Networks for Efficient Action Recognition》阅读笔记
  16. 调查显示,多数受众最常用微博搜索明星、Vlog、新品发布和活动信息
  17. 如何用python做表_如何使用Python中的Tkinter制作钟表?
  18. 坦克世界登录服务器未响应,笔者详解win7系统坦克世界登录连接不上服务器的解决方案...
  19. 谭铁牛院士荣获国际模式识别最高奖
  20. 用python写期货量化策略,期货单品种MACD择时加ATR止损

热门文章

  1. 某电子计算机有400个终端,(概率四习题.doc
  2. K-Means聚类算法 — 算法原理、质心计算、距离度量、聚类效果评价及优缺点
  3. PbootCMS采集-PbootCMS自动采集-PbootCMS免登录发布插件
  4. IBM MQ常用命令
  5. asp mysql 留言本_适用于ASP.NET的留言本(翻译)
  6. 0ra-12170 tns 连接超时
  7. 4.1二维曲线绘制(plot与fplot)
  8. DEJA_VU3D - Cesium功能集 之 083-Cesium热力图实现完整版
  9. java计算机毕业设计网上宠物商店源程序+mysql+系统+lw文档+远程调试
  10. 标题隐藏_经典街机游戏《三国志》,二十多年后你告诉我还有隐藏必杀