GAN入门实例【个人理解】
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
@GAN入门实例【个人理解】
前言
本文记录了B站[GAN生成对抗网络精讲 tensorflow2.0代码实战],相关理解(https://b23.tv/jlDcosT)
代码流程
1.引入库
import tensorflow as tf
from tensorflow import keras
from tensorflow .keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
TensorFlow包
- 是一个深度学习库,由 Google 开源,可以对定义在 Tensor(张量)上的函数自动求导
- TensorFlow会给我们一张空白的数据流图,我们往这张数据流图填充(创建节点),从而实现想要效果。
- 可以参考该链接](https://www.jianshu.com/p/6766fbcd43b9)
keras - Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行
- 允许简单而快速的原型设计(由于用户友好,高度模块化,可扩展性
- 同时支持卷积神经网络和循环神经网络,以及两者的组合。
- 在 CPU 和 GPU 上无缝运行
matplotlib - 画图,在这里主要是使acc/loss可视化
numpy - NumPy是使用Python进行科学计算的基础软件包
- 功能强大的N维数组对象
- 精密广播功能函数
- 强大的线性代数、傅立叶变换和随机数功能
glob - 查找文件路径
os - os 模块提供了非常丰富的方法用来处理文件和目录
2.查看版本
tf.__version__
查看tensorflow版本
3.查看版本
(train_images,train_labels), (_, _)=tf.keras.datasets.mnist.load_data()
等号左边第一括号是训练数据,第二个括号是测试数据,测试数据暂时不需要,就先用占位符占着,等号右边是加载mnist数据集
4.查看数据集shape
train_images.shape
- 60000:这个数据集有60000张图片
- 2828:每张图片像素值是2828
5.查看数据类型
train_images.dtype
6.查看版本
train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
reshape成一个四维张量
7.归一化,定义两个常量
train_images = (train_images-127.5)/127.5
BATCH_SIZE = 256
BUFFER_SIZE = 6000
归一化有两种方式:
- img/255.0:图像x取值范围[0,1]
- List item
( img/127.5)-1:图像x取值范围[-1,1]
8.from_tensor_slices
datasets = tf.data.Dataset.from_tensor_slices(train_images)
tf.data.Dataset.from_tensor_slices
-该函数是dataset核心函数之一,它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外维度的数据切开,分成一组一组的
- 假设我们现在有两组数据,分别是特征和标签,为了简化说明问题,我们假设每两个特征对应一个标签。之后把特征和标签组合成一个tuple,那么我们的想法是让每个标签都恰好对应2个特征,而且像直接切片,比如:[f11, f12] [t1]。f11表示第一个数据的第一个特征,f12表示第1个数据的第二个特征,t1表示第一个数据标签。那么tf.data.Dataset.from_tensor_slices就是做了这件事情:
9.shuffle()
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
shuffle:
- shuffle() 方法将序列的所有元素随机排序。
10.查看数据集类型
datasets
11.定义生成器模型
def generator_model():model = keras.Sequential()model.add(layers.Dense(256,input_shape=(100,), use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(28*28*1, use_bias=False,activation = 'tanh'))model.add(layers.BatchNormalization())model.add(layers.Reshape((28,28,1)))return model
Sequential
- Sequential模型的核心操作是添加layers(图层)
- 首先,输入层,中间层,输出层
- 然后,进入最重要的部分:编译,包括优化器和损失函数
- 接着,调用fit()函数
- 最后可以用evaluate()方法进行评估
参考 https://blog.csdn.net/mogoweb/article/details/82152174
Dense
卷积取的是局部特征,全连接就是把以前的局部特征重新通过权值矩阵组装成完整的图。因为用到了所有的局部特征,所以叫全连接
BatchNormalization
- 结果:均值约为0,标准差约为1
- 作用:
加快收敛速度;
控制过拟合,可以少用或不用dropout和正则
降低网络对初始权重的不敏感
允许使用较大的学习率
LeakyReLU - 激活函数,当x>0,f(x)=x;当x<=0,f(x)=alpha*x
12.定义判别器模型
def discriminator_model():model = keras.Sequential()model.add(layers.Flatten())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(256, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(1))return model
Flatten
- 生成器生成的是28281是三维的,现在把它变成一维的
13.交叉熵
为什么可以用交叉熵做损失函数?
因为随着预测值越来越准确,交叉熵的值越来越小。
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
binaryCrossentropy()
tf.keras.losses.BinaryCrossentropy(yTrue,yPred)
- yTrue:它是事实的二进制张量输入,可以是tf.Tensor类型。
- yPred:它是预测的指定二进制张量输入,可以是tf.Tensor类型。
- 返回值:它返回tf.Tensor对象。
from_logits
- False:(即输出层是带softmax激活函数的),那么其结果会被clip在 [epsilon_, 1 - epsilon_] 的范围之内。这是由于简单的softmax函数会有数值溢出的问题 (参见 softmax溢出问题)。因此如果是先计算softmax再计算cross-entropy,那么要通过clip防止数据溢出问题。
- Ture:模型会将 softmax和cross-entropy 结合在一起计算
转载https://blog.csdn.net/muyuu/article/details/122762442
14.判别器损失函数
def discriminator_loss(real_out,fake_out):read_loss = cross_entropy(tf.ones_like(real_out),real_out)fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)return read_loss+fake_loss
我们希望real_out被判定为1;fake_out被判定为0.
- ones_like():全1矩阵
- zeros_like():全0矩阵
15.生成器损失函数
def generator_loss(fake_out):return cross_entropy(tf.ones_like(fake_out),fake_out)
我们希望生成的图片尽可能多的1;
16.生成器和判别器的优化器
generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
几种常用优化器:
随机梯度下降法(SGD):
keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)lr:大或等于0的浮点数,学习率
momentum:大或等于0的浮点数,动量参数
decay:大或等于0的浮点数,每次更新后的学习率衰减值
nesterov:布尔值,确定是否使用Nesterov动量Adagrad:
keras.optimizers.Adagrad(lr=0.01, epsilon=1e-06)epsilon:大或等于0的小浮点数,防止除0错误
Adam
keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)beta_1/beta_2:浮点数, 0<beta<1,通常很接近1
17.tf.random_normal()
EPOCHS = 100
noise_dim = 100
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate,noise_dim])
noise_dim = 100:
用长度为100的随机向量来生成手写数据集num_exp_to_generate = 16
16个tf.random_normal()
tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)服从指定正态分布的序列”中随机取出指定个数的值。
shape: 输出张量的形状,必选
转载:https://blog.csdn.net/dcrmg/article/details/79028043
18.调用生成器模型
generator = generator_model()
19.调用判别器模型
discriminator = discriminator_model()
20.定义批次训练的函数
def train_step(images):noise = tf.random.normal([BATCH_SIZE,noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:real_out = discriminator(images,training=True)gen_image = generator(noise,training=True)fake_out = discriminator(gen_image,training=True)gen_loss = generator_loss(fake_out)disc_loss = discriminator_loss(real_out,fake_out)gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables)generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:追踪两个model生成的梯度。
with enter() as variable:执行过程
- step1:variable = enter()
- step2:exit()
关于梯度函数gradients,可参考
https://blog.csdn.net/QKK612501/article/details/115335437?spm=1001.2101.3001.6650.1&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_default&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_default&utm_relevant_index=2
apply_gradients(
grads_and_vars, name=None, experimental_aggregate_gradients=True
)
- grads_and_vars (梯度,变量)对的列表
- ame 返回操作的可选名称。默认为传递给Optimizer 构造函数的名称。
- experimental_aggregate_gradients 是否在存在 tf.distribute.Strategy 的情况下对来自不同副本的梯度求和。如果为 False,则聚合梯度是用户的责任。默认为真。
21.绘制生成图片
def generator_plot_image(gen_model,test_noise):pre_images = gen_model(test_noise,training=False)fig = plt.figure(figsize=(4,4))for i in range(pre_images.shape[0]):plt.subplot(4,4,i+1)plt.imshow((pre_images[i, :, :, 0]+1)/2,cmap='gray')plt.axis('off')plt.show()
figure(figsize=(4,4)):初始4*4的画布,因为噪音有16个
subplot(4,4,i+1):四行四列第i+1个
imshow((pre_images[i, :, :, 0]+1)/2,cmap=‘gray’):显示第i张图片,取所有的长和宽,tanh的范围是[-1,1]将其转化为[0,1]上,画布是灰色的
plt.axis(‘off’):不显示坐标
22.训练
def train(dataset,epochs):for epoch in range(epochs):for image_batch in dataset:train_step(image_batch)print('.',end='')generator_plot_image(generator,seed)
23.运行
train(datasets,EPOCHS)
总结
懵懵懂懂的过了一遍,对全连接层还是很不理解,继续加油吧!争取早日毕业
GAN入门实例【个人理解】相关推荐
- linux怎么运行datastage,ETL工具Datastage入门+实例(易理解)
引言 传统的数据整合方式需要大量的手工编码,而采用 IBM WebSphere DataStage 进行数据 整合可以大大的减少手工编码的数量,而且更加容易维护.数据整合的核心内容是从数据源中抽取 数 ...
- Datastage入门+实例(易理解)
转自http://www.ibm.com/developerworks/cn/data/library/techarticles/dm-0602zhoudp/ 传统的数据整合方式需要大量的手工编码,而 ...
- 《HFSS电磁仿真设计从入门到精通》一第2章 入门实例——T形波导的内场分析和优化设计...
本节书摘来自异步社区<HFSS电磁仿真设计从入门到精通>一书中的第2章,作者 易迪拓培训 , 李明洋 , 刘敏,更多章节内容可以访问云栖社区"异步社区"公众号查看 第2 ...
- Java Socket入门实例
基于测试驱动的Socket入门实例(代码的具体功能可以看我的程序中的注释,不理解的可以短信我) 先看Server的代码: package socketStudy; import java.io.Buf ...
- linux Shell(脚本)编程入门实例讲解详解
linux Shell(脚本)编程入门实例讲解详解 为什么要进行shell编程 在Linux系统中,虽然有各种各样的图形化接口工具,但是sell仍然是一个非常灵活的工具.Shell不仅仅是命令的收集, ...
- Windows 外壳扩展编程入门实例
Windows 外壳扩展编程入门实例 -- Delphi 篇 作者的话 关于Windows 外壳扩展方面的文章私心以为最好的应当算是Michael Dunn 的TheComplete Idiot's ...
- JUnit学习摘要+入门实例 (junit4)
http://www.cnblogs.com/xwdreamer/archive/2012/03/29/2423136.html 1.学习摘要 看<重构-改善既有代码的设计>这本书的时候, ...
- php页面get方法实现ajax,入门实例教程
ajax,入门实例教程 本例针对php页面,做了一个小的demo加深对ajax的理解 1.文档结构: 共有ajax.php 和action.php 2个页面. 2.源码如下: /*ajax.php页面 ...
- wxpython使用实例_wxPython中文教程入门实例
wxPython中文教程入门实例 wx.Window 是一个基类,许多构件从它继承.包括 wx.Frame 构件. 可以在所有的子类中使用 wx.Window 的方法. wxPython的几种方法: ...
- [深度学习-实践]GAN入门例子-利用Tensorflow Keras与数据集CIFAR10生成新图片
系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之基于CIFAR10数据集的例子; 深度学习GAN(三)之基于手写体Mnist数据集的例子; 深度学习GAN(四)之PIX2PIX G ...
最新文章
- Spring事务管理--嵌套事务详解
- hive安装测试及Hive 元数据的三种存储方式
- 项目中的加减法--《最后期限》读书笔记(1)
- 为什么那么好的女孩子还单身?
- Anaconda下安装tensorflow-gpu踩坑日记
- Spring攻略学习笔记(13)------继承Bean配置
- android 项目将csv文件写入sqlite数据库 代码,如何将csv文件大容量插入sqlite c#
- 转] 两种自定义表单设计方案
- tomcat绿色版及安装版修改内存大小的方法
- [LeetCode] Rotate Array
- C语言循环语句的用法——while循环
- 记一次Comparator.comparing(XXX::getStartTime).reversed()失效
- 《东周列国志》第五十一回 责赵盾董狐直笔 诛斗椒绝缨大会
- 安装配置apache
- java的入口函数_java中有几种入口函数
- 学习代码要先学会“学习”
- 使用EJS脚本实现花生壳动态域名更新服务(二)
- oracle 检查索引失效,oracle 索引失效原因_汇总
- 2005-04-26 星期二
- 容器微服务的前世今生