提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

@GAN入门实例【个人理解】


前言

本文记录了B站[GAN生成对抗网络精讲 tensorflow2.0代码实战],相关理解(https://b23.tv/jlDcosT)

代码流程

1.引入库

import tensorflow as tf
from tensorflow import keras
from tensorflow .keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os

TensorFlow包

  • 是一个深度学习库,由 Google 开源,可以对定义在 Tensor(张量)上的函数自动求导
  • TensorFlow会给我们一张空白的数据流图,我们往这张数据流图填充(创建节点),从而实现想要效果。
  • 可以参考该链接](https://www.jianshu.com/p/6766fbcd43b9)
    keras
  • Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行
  • 允许简单而快速的原型设计(由于用户友好,高度模块化,可扩展性
  • 同时支持卷积神经网络和循环神经网络,以及两者的组合。
  • 在 CPU 和 GPU 上无缝运行
    matplotlib
  • 画图,在这里主要是使acc/loss可视化
    numpy
  • NumPy是使用Python进行科学计算的基础软件包
  • 功能强大的N维数组对象
  • 精密广播功能函数
  • 强大的线性代数、傅立叶变换和随机数功能
    glob
  • 查找文件路径
    os
  • os 模块提供了非常丰富的方法用来处理文件和目录

2.查看版本

tf.__version__

查看tensorflow版本

3.查看版本

(train_images,train_labels), (_, _)=tf.keras.datasets.mnist.load_data()

等号左边第一括号是训练数据,第二个括号是测试数据,测试数据暂时不需要,就先用占位符占着,等号右边是加载mnist数据集

4.查看数据集shape

train_images.shape

  • 60000:这个数据集有60000张图片
  • 2828:每张图片像素值是2828

5.查看数据类型

train_images.dtype

6.查看版本

train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')

reshape成一个四维张量

7.归一化,定义两个常量

train_images = (train_images-127.5)/127.5
BATCH_SIZE = 256
BUFFER_SIZE = 6000

归一化有两种方式:

  • img/255.0:图像x取值范围[0,1]
  • List item

( img/127.5)-1:图像x取值范围[-1,1]

8.from_tensor_slices

datasets = tf.data.Dataset.from_tensor_slices(train_images)

tf.data.Dataset.from_tensor_slices

-该函数是dataset核心函数之一,它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外维度的数据切开,分成一组一组的

  • 假设我们现在有两组数据,分别是特征和标签,为了简化说明问题,我们假设每两个特征对应一个标签。之后把特征和标签组合成一个tuple,那么我们的想法是让每个标签都恰好对应2个特征,而且像直接切片,比如:[f11, f12] [t1]。f11表示第一个数据的第一个特征,f12表示第1个数据的第二个特征,t1表示第一个数据标签。那么tf.data.Dataset.from_tensor_slices就是做了这件事情:

9.shuffle()

datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

shuffle:

  • shuffle() 方法将序列的所有元素随机排序。

10.查看数据集类型

datasets

11.定义生成器模型

def generator_model():model = keras.Sequential()model.add(layers.Dense(256,input_shape=(100,), use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(28*28*1, use_bias=False,activation = 'tanh'))model.add(layers.BatchNormalization())model.add(layers.Reshape((28,28,1)))return model

Sequential

  • Sequential模型的核心操作是添加layers(图层)
  • 首先,输入层,中间层,输出层
  • 然后,进入最重要的部分:编译,包括优化器和损失函数
  • 接着,调用fit()函数
  • 最后可以用evaluate()方法进行评估

参考 https://blog.csdn.net/mogoweb/article/details/82152174

Dense
卷积取的是局部特征,全连接就是把以前的局部特征重新通过权值矩阵组装成完整的图。因为用到了所有的局部特征,所以叫全连接

BatchNormalization

  • 结果:均值约为0,标准差约为1
  • 作用:
    加快收敛速度;
    控制过拟合,可以少用或不用dropout和正则
    降低网络对初始权重的不敏感
    允许使用较大的学习率
    LeakyReLU
  • 激活函数,当x>0,f(x)=x;当x<=0,f(x)=alpha*x

12.定义判别器模型

def discriminator_model():model = keras.Sequential()model.add(layers.Flatten())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(256, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(1))return model

Flatten

  • 生成器生成的是28281是三维的,现在把它变成一维的

13.交叉熵

为什么可以用交叉熵做损失函数?

因为随着预测值越来越准确,交叉熵的值越来越小。

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

binaryCrossentropy()
tf.keras.losses.BinaryCrossentropy(yTrue,yPred)

  • yTrue:它是事实的二进制张量输入,可以是tf.Tensor类型。
  • yPred:它是预测的指定二进制张量输入,可以是tf.Tensor类型。
  • 返回值:它返回tf.Tensor对象。

from_logits

  • False:(即输出层是带softmax激活函数的),那么其结果会被clip在 [epsilon_, 1 - epsilon_] 的范围之内。这是由于简单的softmax函数会有数值溢出的问题 (参见 softmax溢出问题)。因此如果是先计算softmax再计算cross-entropy,那么要通过clip防止数据溢出问题。
  • Ture:模型会将 softmax和cross-entropy 结合在一起计算

转载https://blog.csdn.net/muyuu/article/details/122762442

14.判别器损失函数

def discriminator_loss(real_out,fake_out):read_loss = cross_entropy(tf.ones_like(real_out),real_out)fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)return read_loss+fake_loss

我们希望real_out被判定为1;fake_out被判定为0.

  • ones_like():全1矩阵
  • zeros_like():全0矩阵

15.生成器损失函数

def generator_loss(fake_out):return cross_entropy(tf.ones_like(fake_out),fake_out)

我们希望生成的图片尽可能多的1;

16.生成器和判别器的优化器

generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)

几种常用优化器:

  1. 随机梯度下降法(SGD):
    keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)

    lr:大或等于0的浮点数,学习率
    momentum:大或等于0的浮点数,动量参数
    decay:大或等于0的浮点数,每次更新后的学习率衰减值
    nesterov:布尔值,确定是否使用Nesterov动量

  2. Adagrad:
    keras.optimizers.Adagrad(lr=0.01, epsilon=1e-06)

    epsilon:大或等于0的小浮点数,防止除0错误

  3. Adam
    keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    beta_1/beta_2:浮点数, 0<beta<1,通常很接近1

17.tf.random_normal()

EPOCHS = 100
noise_dim = 100
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate,noise_dim])
  • noise_dim = 100:
    用长度为100的随机向量来生成手写数据集

  • num_exp_to_generate = 16
    16个

  • tf.random_normal()
    tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)

    服从指定正态分布的序列”中随机取出指定个数的值。
    shape: 输出张量的形状,必选

转载:https://blog.csdn.net/dcrmg/article/details/79028043

18.调用生成器模型

generator = generator_model()

19.调用判别器模型

discriminator = discriminator_model()

20.定义批次训练的函数

def train_step(images):noise = tf.random.normal([BATCH_SIZE,noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:real_out = discriminator(images,training=True)gen_image = generator(noise,training=True)fake_out = discriminator(gen_image,training=True)gen_loss = generator_loss(fake_out)disc_loss = discriminator_loss(real_out,fake_out)gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables)generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:追踪两个model生成的梯度。
with enter() as variable:执行过程

  • step1:variable = enter()
  • step2:exit()

关于梯度函数gradients,可参考
https://blog.csdn.net/QKK612501/article/details/115335437?spm=1001.2101.3001.6650.1&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_default&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_default&utm_relevant_index=2

apply_gradients(
grads_and_vars, name=None, experimental_aggregate_gradients=True
)

  • grads_and_vars (梯度,变量)对的列表
  • ame 返回操作的可选名称。默认为传递给Optimizer 构造函数的名称。
  • experimental_aggregate_gradients 是否在存在 tf.distribute.Strategy 的情况下对来自不同副本的梯度求和。如果为 False,则聚合梯度是用户的责任。默认为真。

21.绘制生成图片

def generator_plot_image(gen_model,test_noise):pre_images = gen_model(test_noise,training=False)fig = plt.figure(figsize=(4,4))for i in range(pre_images.shape[0]):plt.subplot(4,4,i+1)plt.imshow((pre_images[i, :, :, 0]+1)/2,cmap='gray')plt.axis('off')plt.show()

figure(figsize=(4,4)):初始4*4的画布,因为噪音有16个
subplot(4,4,i+1):四行四列第i+1个
imshow((pre_images[i, :, :, 0]+1)/2,cmap=‘gray’):显示第i张图片,取所有的长和宽,tanh的范围是[-1,1]将其转化为[0,1]上,画布是灰色的
plt.axis(‘off’):不显示坐标

22.训练

def train(dataset,epochs):for epoch in range(epochs):for image_batch in dataset:train_step(image_batch)print('.',end='')generator_plot_image(generator,seed)

23.运行

train(datasets,EPOCHS)

总结

懵懵懂懂的过了一遍,对全连接层还是很不理解,继续加油吧!争取早日毕业

GAN入门实例【个人理解】相关推荐

  1. linux怎么运行datastage,ETL工具Datastage入门+实例(易理解)

    引言 传统的数据整合方式需要大量的手工编码,而采用 IBM WebSphere DataStage 进行数据 整合可以大大的减少手工编码的数量,而且更加容易维护.数据整合的核心内容是从数据源中抽取 数 ...

  2. Datastage入门+实例(易理解)

    转自http://www.ibm.com/developerworks/cn/data/library/techarticles/dm-0602zhoudp/ 传统的数据整合方式需要大量的手工编码,而 ...

  3. 《HFSS电磁仿真设计从入门到精通》一第2章 入门实例——T形波导的内场分析和优化设计...

    本节书摘来自异步社区<HFSS电磁仿真设计从入门到精通>一书中的第2章,作者 易迪拓培训 , 李明洋 , 刘敏,更多章节内容可以访问云栖社区"异步社区"公众号查看 第2 ...

  4. Java Socket入门实例

    基于测试驱动的Socket入门实例(代码的具体功能可以看我的程序中的注释,不理解的可以短信我) 先看Server的代码: package socketStudy; import java.io.Buf ...

  5. linux Shell(脚本)编程入门实例讲解详解

    linux Shell(脚本)编程入门实例讲解详解 为什么要进行shell编程 在Linux系统中,虽然有各种各样的图形化接口工具,但是sell仍然是一个非常灵活的工具.Shell不仅仅是命令的收集, ...

  6. Windows 外壳扩展编程入门实例

    Windows 外壳扩展编程入门实例 -- Delphi 篇 作者的话 关于Windows 外壳扩展方面的文章私心以为最好的应当算是Michael Dunn 的TheComplete Idiot's ...

  7. JUnit学习摘要+入门实例 (junit4)

    http://www.cnblogs.com/xwdreamer/archive/2012/03/29/2423136.html 1.学习摘要 看<重构-改善既有代码的设计>这本书的时候, ...

  8. php页面get方法实现ajax,入门实例教程

    ajax,入门实例教程 本例针对php页面,做了一个小的demo加深对ajax的理解 1.文档结构: 共有ajax.php 和action.php 2个页面. 2.源码如下: /*ajax.php页面 ...

  9. wxpython使用实例_wxPython中文教程入门实例

    wxPython中文教程入门实例 wx.Window 是一个基类,许多构件从它继承.包括 wx.Frame 构件. 可以在所有的子类中使用 wx.Window 的方法. wxPython的几种方法: ...

  10. [深度学习-实践]GAN入门例子-利用Tensorflow Keras与数据集CIFAR10生成新图片

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之基于CIFAR10数据集的例子; 深度学习GAN(三)之基于手写体Mnist数据集的例子; 深度学习GAN(四)之PIX2PIX G ...

最新文章

  1. Spring事务管理--嵌套事务详解
  2. hive安装测试及Hive 元数据的三种存储方式
  3. 项目中的加减法--《最后期限》读书笔记(1)
  4. 为什么那么好的女孩子还单身?
  5. Anaconda下安装tensorflow-gpu踩坑日记
  6. Spring攻略学习笔记(13)------继承Bean配置
  7. android 项目将csv文件写入sqlite数据库 代码,如何将csv文件大容量插入sqlite c#
  8. 转] 两种自定义表单设计方案
  9. tomcat绿色版及安装版修改内存大小的方法
  10. [LeetCode] Rotate Array
  11. C语言循环语句的用法——while循环
  12. 记一次Comparator.comparing(XXX::getStartTime).reversed()失效
  13. 《东周列国志》第五十一回 责赵盾董狐直笔 诛斗椒绝缨大会
  14. 安装配置apache
  15. java的入口函数_java中有几种入口函数
  16. 学习代码要先学会“学习”
  17. 使用EJS脚本实现花生壳动态域名更新服务(二)
  18. oracle 检查索引失效,oracle 索引失效原因_汇总
  19. 2005-04-26 星期二
  20. 容器微服务的前世今生

热门文章

  1. python实现——处理Excel表格(超详细)
  2. 直流屏电源模块GF22007-2高频充电模块R22007
  3. 电脑能正常上网上网,某些软件不能上网
  4. 遇到服务器网络偶尔断线如何检查
  5. POJ 1088 滑雪 题解
  6. android qq音乐歌词怎么实现,Android自定义View,高仿QQ音乐歌词滚动控件!
  7. 自动加减工单结存算法实现
  8. 职场必备:十句外企 office 常用英语
  9. 我的appstore新游戏--LeBallon 拿码了
  10. 敏感词过滤的算法原理之 Aho-Corasick 算法