最简单的示例,基本框架,方便改写。

大型案例,博客根本写不下。

# 导入依赖项
import os
import glob
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
print("tensorflow版本号:", tf.__version__)
tensorflow版本号: 2.7.0
# 加载数据
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images.shape  # 60000张28*28像素的图片
(60000, 28, 28)
train_images.dtype  # 图片类型
dtype('uint8')
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')  # 数据类型改成float32
train_images.shape
(60000, 28, 28, 1)
train_images = (train_images-127.5)/127.5  # 归一化数据集
BATCH_SIZE = 256  # 批训练大小
BUFFRE_SIZE = 60000  # 数据取出
# 创建datasets数据集,只要图片。
datasets = tf.data.Dataset.from_tensor_slices(train_images)
datasets = datasets.shuffle(BUFFRE_SIZE).batch(BATCH_SIZE)  # 数据打乱、数据批次
datasets
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
# 生成网络(最简单的写法,复杂的自行修改。)
def generator_model():model = keras.Sequential()model.add(layers.Dense(256,  # 输出形状input_shape=(100,),  # 输入形状use_bias=False  # 不要偏置参数,只要权重参数))model.add(layers.BatchNormalization())  # 归一化model.add(layers.LeakyReLU())  # 激活层model.add(layers.Dense(512,  # 输出形状use_bias=False  # 不要偏置参数,只要权重参数))model.add(layers.BatchNormalization())  # 归一化model.add(layers.LeakyReLU())  # 激活层model.add(layers.Dense(28*28*1,  # 输出形状use_bias=False,  # 不要偏置参数,只要权重参数activation='tanh'  # 激活函数))model.add(layers.BatchNormalization())  # 归一化model.add(layers.Reshape((28, 28, 1)))  # 生成图片形状return model
# 鉴别网络
def discriminator_model():model = keras.Sequential()model.add(layers.Flatten())  # 输入的图片降成一维model.add(layers.Dense(512,  # 输出形状use_bias=False  # 不要偏置参数,只要权重参数))model.add(layers.BatchNormalization())  # 归一化model.add(layers.LeakyReLU())  # 激活层model.add(layers.Dense(256,  # 输出形状use_bias=False  # 不要偏置参数,只要权重参数))model.add(layers.BatchNormalization())  # 归一化model.add(layers.LeakyReLU())  # 激活层model.add(layers.Dense(1))  # 二分类交叉熵损失函数需要。return model
# 损失函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# 鉴别网络损失函数
def discriminator_loss(real_out,  # image -> 鉴别网络 -> real_outfake_out  # noise -> 生成网络 -> 鉴别网络 -> fake_out
):# 真实图片损失函数real_loss = cross_entropy(tf.ones_like(real_out),  # 期望值是1real_out  # image -> 鉴别网络 -> real_out)# 生成图片损失函数fake_loss = cross_entropy(tf.zeros_like(fake_out),  # 期望值是0fake_out  # noise -> 生成网络 -> 鉴别网络 -> fake_out)return real_loss + fake_loss  # 真实图片和生成图片的loss的和(目标就是使这个loss值最小)
# 生成网络损失函数
def generator_loss(fake_out  # noise -> 生成网络 -> 鉴别网络 -> fake_out
):return cross_entropy(tf.ones_like(fake_out),  # 期望值是1fake_out  # 真实图片)
# 生成网络优化器
generator_opt = tf.keras.optimizers.Adam(1e-4  # 学习率
)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)  # 鉴别网络优化器
EPOCHS = 100  # 训练次数
nosie_dim = 100  # 输入数据的维度
num_exp_to_generator = 16  # 生成图片数量
seed = tf.random.normal([num_exp_to_generator, nosie_dim])  # 输入的随机向量
# 创建生成网络和鉴别网络
generator = generator_model()
discriminator = discriminator_model()

# 训练步骤
def train_step(images):noise = tf.random.normal([BATCH_SIZE, nosie_dim])  # 输入的随机向量# 创建生成网络梯度和鉴别网络梯度with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:real_out = discriminator(  # 鉴别网络鉴别真实图片images,  # 真实图片training=True  # 可训练的)gen_image = generator(  # 生成网络生成图片noise,  # 随机生成的向量training=True  # 可训练的)fake_out = discriminator(  # 鉴别网络鉴别生成图片gen_image,  # 生成图片training=True  # 可训练的)gen_loss = generator_loss(fake_out)  # 生成网络损失值disc_loss = discriminator_loss(real_out, fake_out)  # 鉴别网络损失值# 生成网络梯度值gradient_gen = gen_tape.gradient(gen_loss,  # 生成网络损失值generator.trainable_variables  # 生成网络权重参数)# 鉴别网络梯度值gradient_disc = disc_tape.gradient(disc_loss,  # 鉴别网络损失值discriminator.trainable_variables  # 鉴别网络权重参数)# 生成网络反向传播,优化权重参数generator_opt.apply_gradients(zip(gradient_gen,  # 生成网络梯度值generator.trainable_variables  # 生成网络权重参数))# 鉴别网络反向传播,优化权重参数discriminator_opt.apply_gradients(zip(gradient_disc,  # 鉴别网络梯度值discriminator.trainable_variables  # 鉴别网络权重参数))
# 展示生成的图片
def generate_plot_image(gen_model,  # 生成器训练好的模型test_noise  # 输入的噪音向量
):pre_images = gen_model(test_noise,  # 输入的噪音向量training=False  # 不可训练,不更新模型参数)# 绘图fig = plt.figure(figsize=(4, 4))for i in range(pre_images.shape[0]):  # 遍历图片。plt.subplot(4, 4, i+1)  # 子图 4 * 4 从第1张开始# 显示图片设置plt.imshow(# 取第i张图片,全部的宽,全部的高,第一张图片。(pre_images[i, :, :, 0] + 1)/2,  # /2:取0和1之间cmap='gray'  # 设置颜色为灰度图)# 一起显示图片plt.show()
# 训练函数
def train(dataset,  # 数据集epochs  # 训练次数
):for epoch in range(epochs):  # 循环训练for image_batch in dataset:  # 遍历数据集train_step(image_batch)  # 训练步骤print('.', end='')  # 训练1次,出现一个点,打印不要换行# 每一个epoch都绘图generate_plot_image(generator,  # 生成器网络seed  # 随机输入向量)
# 开始训练
train(datasets, EPOCHS)

tensorflow gan 网络 示例相关推荐

  1. tensorflow gan网络流程图

    tensorflow gan 网络流程图

  2. 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例

    1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...

  3. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  4. #教计算机学画卡通人物#生成式对抗神经网络GAN原理、Tensorflow搭建网络生成卡通人脸

    生成式对抗神经网络GAN原理.Tensorflow搭建网络生成卡通人脸 下面这张图是我教计算机学画画,计算机学会之后画出来的,具体实现在下面. ▲以下是对GAN形象化地表述 ●赵某不务正业.游手好闲, ...

  5. 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)

    图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow) 文章目录 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网 ...

  6. GAN背后的理论依据,以及为什么只使用GAN网络容易产生

    花了一下午研究的文章,解答了我关于GAN网络的很多疑问,内容的理论水平很高,只能尽量理解,但真的是一篇非常好的文章转自http://www.dataguru.cn/article-10570-1.ht ...

  7. GAN网络图像翻译机:图像复原、模糊变清晰、素描变彩图

    贴个文章,记录学习历程 http://www.sohu.com/a/169212360_473283 本文介绍深度学习方法在图像翻译领域的应用,通过实现一个编码解码"图像翻译机"进 ...

  8. Nat. Commun. | 条件GAN网络和基因表达特征用于类苗头化合物的发现

    今天给大家介绍的是拜耳作物科学公司.拜耳公司机器学习研发部和遗传毒理学部于2020年1月联合发表在Nature Communications上的一篇论文,这篇文章通过一种生成模型进行分子的从头设计以及 ...

  9. GAN网络生成手写体数字图片

    Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的. 目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接 ...

最新文章

  1. attention retain_Attention-Aware Compositional Network
  2. 【剪枝算法】通过网络瘦身学习高效的卷积网络Learning Efficient Convolutional Networks through Network Slimming论文翻译
  3. 直播预告|中台基石腾讯云TStack的正确使用姿势
  4. Eclipse Java注释模板设置详解
  5. C++:08---成员变量初始化方式
  6. 高质量JAVA代码编写规范
  7. 计组之概述:计算机系统
  8. 一个软件系统哪些可独立实现
  9. Tensorflow2.0数据和部署(二)——基于设备的模型与TensorFlow Lite
  10. eventlistener java_EventListener原理
  11. 25.MySQL sql_mode
  12. Android List的混排 随机排序
  13. 195-Redigo基本操作mget-mset
  14. Python一些常用的网站
  15. python arp断网攻击_arp断网攻击,小编教你arp断网攻击怎么解决
  16. win7系统提示“此windows副本不是正版” 解决方案
  17. 哲学家是如何思考问题的-2.0版
  18. 什么是Google Voice
  19. 电路中VCC、VDD、VEE和VSS的区别
  20. [DeploymentService:290066]Error occurred while downloading files from admin server for deployment re

热门文章

  1. Nature Methods | TooManyCells:单细胞聚类和可视化方法
  2. 标准氨基酸和质子化氨基酸 三字母 单字母 对应表
  3. C库函数-perror()
  4. 计算机二级word保存要不要加.docx,计算机二级word实操题.docx
  5. mysql int和bigdecimal,mysql的 int 类型,刨析返回类型为BigDicemal 类型的奇怪现象
  6. Nature、Science的绘图新宠,博导人论文覆盖率高达78%...
  7. 微生物组数据揭示中国稻谷产毒真菌分布及仓储动态变化
  8. 35张图,看懂肠道和大脑的魔性关系,绝对涨知识!
  9. 差点被人类消灭的疾病,科学家说是苏联让它重新肆虐全球?
  10. php7 swoole 扩展,PHP7.2加入swoole扩展