来自课程GAN生成对抗网络精讲 tensorflow2.0代码实战 全网最简洁易懂的GAN课程
的代码

tensorflow版本为2.0
升级版本看https://blog.csdn.net/qq_43620967/article/details/108835207

jupyter 文件
链接:https://pan.baidu.com/s/1PR4phiKAEoK4sAAmTWR18A
提取码:z9od

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
tf.__version__

Out[104]: ‘2.3.1’

输入

(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
train_images.shape

Out[106]: (60000, 28, 28)

60000张图,都是28*28 像素

train_images.dtype

Out[107]: dtype(‘uint8’)

数据预处理

train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images.shape

Out[109]: (60000, 28, 28, 1)

train_images=(train_images-127.5)/127.5#归一化 到【-1,1】
BATCH_SIZE=256
BUFFER_SIZE=600000
datasets=tf.data.Dataset.from_tensor_slices(train_images)
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
datasets

Out[114]: <BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

第一维度表示个数

生成器模型

def generator_model():model=tf.keras.Sequential()model.add(layers.Dense(256,input_shape=(100,),use_bias=False))#Dense全连接层,input_shape=(100,)长度100的随机向量,use_bias=False,因为后面有BN层model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())#激活#第二层model.add(layers.Dense(512,use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())#激活#输出层model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))model.add(layers.BatchNormalization())model.add(layers.Reshape((28,28,1)))#变成图片 要以元组形式传入return model

辨别器模型

def discriminator_model():#输入图片model=keras.Sequential()model.add(layers.Flatten())#把输入的三维图片扁平化model.add(layers.Dense(512,use_bias=False))    model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())#激活    model.add(layers.Dense(256,use_bias=False))    model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())#激活      model.add(layers.Dense(1))#输出数字,>0.5真实图片return model

loss函数

cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)#from_logits=True因为最后的输出没有激活

判别器损失函数

def discriminator_loss(real_out,fake_out):#辨别器的输出 真实图片判1,假的图片判0real_loss=cross_entropy(tf.ones_like(real_out),real_out)fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)return real_loss+fake_loss

生成器损失函数

def generator_loss(fake_out):#希望fakeimage的判别输出fake_out判别为真return cross_entropy(tf.ones_like(fake_out),fake_out)

优化器

generator_opt=tf.keras.optimizers.Adam(1e-4)#学习速率
discriminator_opt=tf.keras.optimizers.Adam(1e-4)
EPOCHS=100
noise_dim=100 #长度为100的随机向量生成手写数据集
​
num_exp_to_generate=16 #每步生成16个样本
​
seed=tf.random.normal([num_exp_to_generate,noise_dim]) #生成随机向量观察变化情况

训练

generator=generator_model()
discriminator=discriminator_model()

定义批次训练函数

def train_step(images):noise=tf.random.normal([BATCH_SIZE,noise_dim]) #生成随机数with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: #记录训练过程real_out = discriminator(images,training=True)gen_image = generator(noise,training=True)fake_out=discriminator(gen_image,training=True)gen_loss = generator_loss(fake_out)disc_loss = discriminator_loss(real_out,fake_out)#训练过程gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables) #生成器与生成器可训练参数的梯度gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables) generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

可视化

def generator_plot_image(gen_model,test_noise): #观察训练出的图象pre_images = gen_model(test_noise,training=False)fig = plt.figure(figsize=(4,4))for i in range(pre_images.shape[0]):plt.subplot(4,4,i+1) #从1开始排plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray') #归一化,灰色度plt.axis('off') #不显示坐标轴plt.show()

训练

def train(dataset,epochs):for epoch in range(epochs):for image_batch in dataset:train_step(image_batch)#print('.',end='')print('第'+str(epoch+1)+'次训练结果')generator_plot_image(generator,seed)
train(datasets,EPOCHS)

结果

第1次训练结果

第100次训练结果

手写数字识别--日月光华的gan小例子相关推荐

  1. 机器学习:手写数字识别(Hand-written digits recognition)小项目

    该项目的所有代码在我的github上,欢迎有兴趣的同学与我探讨研究~ 地址:Machine-Learning/machine-learning-ex3/ 1. Introduction 手写数字识别( ...

  2. 深度学习数字仪表盘识别_【深度学习系列】手写数字识别实战

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  3. 手写数字识别的小优化

    在用KNN实现手写数字识别的时候突发奇想是否可以根据数字的特点对其进行一个分类,以此来提高判断的准确率.经过许多天的改进与完善终于实现了此算法,就称其为洞拐法吧.(第一次写,有许多不足之处,还望多多包 ...

  4. GAN变种ACGAN利用手写数字识别mnist生成手写数字

    1.摘要 本文主要讲解:GAN变种ACGAN利用手写数字识别mnist数据集进行训练,最终生成手写数字图片 主要思路: Initialize generator and discriminator I ...

  5. pytorch手写数字识别【源码实现-小清新版】

    引言 手写数字识别,也就是让机器能够习得图片中的手写数字,并能正确归类. 本文使用 pytorch 搭建一个简单的神经网络,实现手写数字的识别, 从本文,你可了解到: 1.搭建神经网络的流程 2.完成 ...

  6. PYQT5+CNN(TensorFlow-keras)做一个简单的手写数字识别PC端图形化小程序

    目录 前言 一.功能介绍 1.画板识别 2.图片识别 二.UI设计 1.整体设计思想 2.颜色设计 3.Logo 设计 4.按钮设计 三.算法介绍 1.图片预处理 2.数字分割和显示 3.识别算法 4 ...

  7. 【学习日记】手写数字识别及神经网络基本模型

    2021.10.7 [学习日记]手写数字识别及神经网络基本模型 1 概述 张量(tensor)是数字的容器,是矩阵向任意维度的推广,其维度称为轴(axis).深度学习的本质是对张量做各种运算处理,其分 ...

  8. 深度学习--TensorFlow(项目)Keras手写数字识别

    目录 效果展示 基础理论 1.softmax激活函数 2.神经网络 3.隐藏层及神经元最佳数量 一.数据准备 1.载入数据集 2.数据处理 2-1.归一化 2-2.独热编码 二.神经网络拟合 1.搭建 ...

  9. 教程 | 基于LSTM实现手写数字识别

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 基于tensorflow,如何实现一个简单的循环神经网络,完成手写 ...

最新文章

  1. K近邻算法KNN的简述
  2. Py之gensim:gensim的简介、安装、使用方法之详细攻略
  3. 产品必备:注册登录完整解决方案 | 含原型下载
  4. The directory '*' or its parent directory is not owned by the current user
  5. [传奇单机架设]DBC2000数据库使用教程
  6. EdgeRouter X设置外网远程访问和HTTPS连接指定出口网关
  7. Flume实战采集文件内容存入HDFS
  8. BI与大数据之间的差距有哪些
  9. LTE下行物理层传输机制(5)-DCI格式的选择和DCI1A
  10. 最简单的单片机c语言程序,单片机的C语言编程基础知识(初学注意)
  11. (实例解析)Python 函数调用的几种方式(类里面,类之间,类外面)
  12. Hourglass网络的理解和代码分析
  13. PHP自学教程之PHP语法基础
  14. Redis I/O 多路复用
  15. MySQL大厂优化方案轻松应对高并发!真牛!
  16. vue-router防跳墙控制
  17. 面试侃集合 | ArrayBlockingQueue篇
  18. 比上清华更难的,是加入这支中国顶级黑客战队
  19. Flutter如何集成第三方插件
  20. 分享7个实用的电脑软件,满满的干货,大家低调收藏

热门文章

  1. 在JSP中连接数据库及使用
  2. 端到端机器学习_使用automl进行端到端的自动化机器学习过程
  3. 空间大地测量与GPS导航定位时间系统相互转换,格里高利时通用时儒略日,GPS时,年积日相互转换
  4. json mysql 字段 默认值_MySQL新增JSON类型字段的使用总结
  5. linux怎么撤销关机命令,Linux的shutdown命令
  6. 以太网PLC无线WIFI跨网段通讯和Modbus仪表数据采集
  7. 权重 缩写 英文_常用英语术语缩写--采购
  8. WebService(腾讯QQ在线状态 WEB 服务)
  9. JavaScript 常见的设计模式
  10. 【大数据分析软件应用在足球预测实例】足球滚球走地大小球分析方法和技巧