什么是GAN














实现代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
# 显存自适应分配
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:tf.config.experimental.set_memory_growth(gpu,True)
gpu_ok = tf.test.is_gpu_available()
print("tf version:", tf.__version__)
print("use GPU", gpu_ok) # 判断是否使用gpu进行训练
# 手写数据集
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()


train_images = train_images.reshape(train_images.shape[0],28,28,1).astype("float32")

# 归一化
train_images = (train_images-127.5)/127.5
BATCH_SIZE = 256
BUFFER_SIZE = 60000
# 创建数据集
datasets = tf.data.Dataset.from_tensor_slices(train_images)
# 乱序
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


编写模型

# 生成器模型
def generator_model():model = keras.Sequential() # 顺序模型model.add(layers.Dense(256,input_shape=(100,),use_bias=False)) # 输出256个单元,随机数输入数据形状长度100的向量model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU())# LeakyReLU()激活model.add(layers.Dense(512,use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())# 激活model.add(layers.Dense(28*28*1,use_bias=False,activation="tanh")) # 输出28*28*1形状 使用tanh激活得到-1 到1 的值model.add(layers.BatchNormalization())model.add(layers.Reshape((28,28,1))) # reshape成28*28*1的形状return model
# 判别模型
def discriminator_model():model = keras.Sequential()model.add(layers.Flatten())model.add(layers.Dense(512,use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())# 激活model.add(layers.Dense(256,use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())# 激活model.add(layers.Dense(1))return model
# 编写loss    binary_crossentropy(对数损失函数)即 log loss,与 sigmoid 相对应的损失函数,针对于二分类问题。
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# 辨别器loss
def discriminator_loss(real_out,fake_out):read_loss = cross_entropy(tf.ones_like(real_out),real_out) # 使用binary_crossentropy 对真实图片判别为1fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out) # 生成的图片 判别为0return read_loss + fake_loss
# 生成器loss
def generator_loss(fake_out):return cross_entropy(tf.ones_like(fake_out),fake_out) # 希望对生成的图片返回为1
# 优化器
generator_opt = tf.keras.optimizers.Adam(1e-4) # 学习速率1e-4   0.0001
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
EPOCHS = 100 # 训练步数
noise_dim = 100 num_exp_to_generate = 16seed = tf.random.normal([num_exp_to_generate,noise_dim]) # 16,100 # 生成16个样本,长度为100的随机数
generator = generator_model()
discriminator = discriminator_model()
# 训练一个epoch
def train_step(images):noise = tf.random.normal([BATCH_SIZE,noise_dim])with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape: # 梯度real_out = discriminator(images,training=True)gen_image = generator(noise,training=True)fake_out = discriminator(gen_image,training=True)gen_loss = generator_loss(fake_out)disc_loss = discriminator_loss(real_out,fake_out)gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables)generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))
def genrate_plot_image(gen_model,test_noise):pre_images = gen_model(test_noise,training=False)fig = plt.figure(figsize=(4,4))for i in range(pre_images.shape[0]):plt.subplot(4,4,i+1)plt.imshow((pre_images[i,:,:,0]+1)/2,cmap="gray")plt.axis("off")plt.show()
def train(dataset,epochs):for epoch in range(epochs):for image_batch in dataset:train_step(image_batch)print(".",end="")genrate_plot_image(generator,seed)
# 训练模型
train(datasets,EPOCHS)



GAN生成对抗网络-GAN原理与基本实现-入门实例02相关推荐

  1. GAN生成对抗网络的原理及CycleGAN、Pixel2Pixel、starGAN的的原理即实现

    生成对抗网络 1.生成对抗网络的定义 生成式对抗网络是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成模型和判别模型的互相博弈学习产生相当好的输出 ...

  2. GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

    CycleGAN的原理可以概述为: 将一类图片转换成另一类图片 .也就是说,现在有两个样 本空间,X和Y,我们希望把X空间中的样本转换成Y空间中 的样本.(获取一个数据集的特征,并转化成另一个数据 集 ...

  3. GAN生成对抗网络-PIX2PIXGAN原理与基本实现-图像翻译09

    什么是pix2pix Gan 普通的GAN接收的G部分的输入是随机向量,输出是图像 :D部分接收的输入是图像(生成的或是真实的),输出是对或 者错.这样G和D联手就能输出真实的图像. 对于图像翻译任务 ...

  4. GAN生成对抗网络-CGAN原理与基本实现-条件生成对抗网络04

    CGAN - 条件GAN 原始GAN的缺点 代码实现 import tensorflow as tf from tensorflow import keras from tensorflow.kera ...

  5. GAN生成对抗网络-DCGAN原理与基本实现-深度卷积生成对抗网络03

    什么是DCGAN 实现代码 import tensorflow as tf from tensorflow import keras from tensorflow.keras import laye ...

  6. GAN生成对抗网络-SSGAN原理与基本实现-半监督学习GAN-08

  7. GAN生成对抗网络-INFOGAN原理与基本实现-可解释的生成对抗网络-06

    代码 import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers import m ...

  8. GAN生成对抗网络-GAN原理与基本实现-去噪与卷积自编码器01

    基本去噪自编码器 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np # 显存自适应分配 gpus = ...

  9. 生成对抗网络(GAN)简单梳理

    作者:xg123321123 - 时光杂货店 出处:http://blog.csdn.net/xg123321123/article/details/78034859 声明:版权所有,转载请联系作者并 ...

最新文章

  1. github/python/ show me the code 25题(一)
  2. TensorFlow RNN tutorial解读
  3. github文件上传全流程-新手入门系列
  4. java商城_java开源商城系统的优势是什么?
  5. 企业微信加密消息体_用企业微信小程序发送消息
  6. Python并发编程Futures
  7. 1021个位数字统计
  8. python中debug有什么用途_python中调试或排错的五种方法示例
  9. ubuntu 常用命令锦集
  10. KB4484127 更新导致ACCESS数据库查询报 Query '' is corrupt 异常解决方案
  11. 怎么卸载apowerrec_怎么卸载win10自带应用 工具
  12. 计算机科学与技术有几大类,计算机科学与技术类包括哪些专业
  13. Excel 点击单元格打钩,再点击取消
  14. ltp测试操作步详解(压力测试网站最详、下载、使用)
  15. 高级程序员的自我修养:如何才能成长为牛逼的高级程序员?
  16. 安装青龙面板(不用购买服务器即可薅羊毛)Ubuntu
  17. 如何查看MySql的安装位置?
  18. 树莓派桌面多出个计算机,树莓派|计算机实验室之树莓派:课程 9 屏幕04
  19. python学习:向Firebird数据库表中插入数据
  20. 草料二维码生成器怎么连接打通其他应用?

热门文章

  1. 怎么判断间隙过渡过盈配合_间隙配合过盈配合过渡配合之间的区别
  2. thinkphp5将时间戳直接转换成时间格式
  3. PHP执行一个http请求
  4. PHP对抗web扫描器的脚本技巧
  5. uniapp连接php,thinkphp5 对接手机uni-app的unipush推送(个推)
  6. python登录代码思路_终于找到一个思路比较清晰的可以模拟登录百度的代码!
  7. java wate_Trapping Rain Water leetcode java
  8. 计算机二级vf笔试,计算机二级(VF)笔试120.doc
  9. linux如何导出加密卡私钥,linux – 如何使用gpg中的私钥加密文件
  10. 【springboot】模板路径、静态资源路径、WebRoot的本地路径