需要源码和数据集请点赞关注收藏后评论区留言私信~~~

一、生成对抗网络的概念

生成对抗网络(GANs,Generative Adversarial Nets),由Ian Goodfellow在2014年提出的,是当今计算机科学中最有趣的概念之一。GAN最早提出是为了弥补真实数据的不足,生成高质量的人工数据。GAN的主要思想是通过两个模型的对抗性训练。随着训练过程的推进,生成网络(Generator,G)逐渐变得擅长创建看起来真实的图像,而判别网络(Discriminator,D)则变得更擅长区分真实图像和生成器生成的图像。GAN网络不局限于提高单一网络的性能,而是希望实现生成器和鉴别器之间的纳什均衡。

假设在低维空间Z存在一个简单容易采样的分布p(z),例如正态分布 ,生成网络构成一个映射函数G:Z→X,判别网络需要判别输入是来自真实数据X_real还是生成网络生成的数据X_fake,结构示意图如图8-1所示

下面给出DCGAN利用LSUN数据库生成卧室样本的例子和生成人脸样本的例子,虽然DCGAN还难以生成高精度的图像样本,但这样的结果已经足够让世人感到惊艳

二、DCGAN在MNIST手写数据集上实战

通过本程序可以完成两个模型的训练。一个是生成模型,一个是判别模型

1:项目结构如下

代码大致可以分为以下几部分

1:构建生成网络

2:构建判别网络

3:DCGAN网络训练

开始下载模型

2:效果展示

生成图片如下 可以说效果十分逼真

这是第一张生成图片 可以看出里面有些字体还是略微不够真实,容易被判别器鉴别出来

这一张是图片生成的十分逼真,几乎没有什么缺点

三、代码

部分代码如下 全部代码和数据集请点赞关注收藏后评论区留言私信~~~

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from tensorflow.python.keras.layers import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from tensorflow.keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import mathdef generator_model():model = Sequential()model.add(Dense(input_dim=100, units=1024))model.add(Activation('tanh'))model.add(Dense(128*7*7))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))model.add(UpSampling2D(size=(2, 2)))model.add(Conv2D(64, (5, 5), padding='same'))model.add(Activation('tanh'))model.add(UpSampling2D(size=(2, 2)))model.add(Conv2D(1, (5, 5), padding='same'))model.add(Activation('tanh'))return modeldef discriminator_model():model = Sequential()model.add(Conv2D(64, (5, 5),padding='same',input_shape=(28, 28, 1)))model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(128, (5, 5)))model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Flatten())model.add(Dense(1024))model.add(Activation('tanh'))model.add(Dense(1))model.add(Activation('sigmoid'))return modeldef generator_containing_discriminator(g, d):model = Sequential()model.add(g)d.trainable = Falsemodel.add(d)return modeldef combine_images(generated_images):num = generated_images.shape[0]width = int(math.sqrt(num))height = int(math.ceil(float(num)/width))shape = generated_images.shape[1:3]image = np.zeros((height*shape[0], width*shape[1]),dtype=generated_images.dtype)for index, img in enumerate(generated_images):i = int(index/width)j = index % widthimage[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \img[:, :, 0]return imagedef train(BATCH_SIZE,path):(X_train, y_train), (X_test, y_test) = mnist.load_data()X_train = (X_train.astype(np.float32) - 127.5)/127.5X_train = X_train[:, :, :, None]X_test = X_test[:, :, :, None]# X_train = X_train.reshape((X_train.shape, 1) + X_train.shape[1:])d = discriminator_model()g = generator_model()d_on_g = generator_containing_discriminator(g, d)d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)g.compile(loss='binary_crossentropy', optimizer="SGD")d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)d.trainable = Trued.compile(loss='binary_crossentropy', optimizer=d_optim)for epoch in range(100):print("Epoch is", epoch)print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))for index in range(int(X_train.shape[0]/BATCH_SIZE)):noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]generated_images = g.predict(noise, verbose=0)if index % 20 == 0:image = combine_images(generated_images)image = image*127.5+127.5Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")X = np.concatenate((image_batch, generated_images))y = [1] * BATCH_SIZE + [0] * BATCH_SIZEd_loss = d.train_on_batch(X, y)print("batch %d d_loss : %f" % (index, d_loss))noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))d.trainable = Falseg_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE)d.trainable = Trueprint("batch %d g_loss : %f" % (index, g_loss))if index % 10 == 9:g.save_weights('generator', True)d.save_weights('discriminator', True)def generate(BATCH_SIZE, nice=False):g = generator_model()g.compile(loss='binary_crossentropy', optimizer="SGD")g.load_weights('generator')if nice:s = g.predict(noise, verbose=1)d_pret = d.predict(generated_images, verbose=1)index = np.arange(0, BATCH_SIZE*20)index.resize((BATCH_SIZE*20, 1))pre_with_index = list(np.append(d_pret, index, axis=1))pre_with_index.sort(key=lambda x: x[0], reverse=True)nice_images = np.zeros((BATCH_SIZE,) + generated_images.shape[1:3], dtype=np.float32)nice_images = nice_images[:, :, :, None]for i in range(BATCH_SIZE):idx = int(pre_with_index[i][1])nice_images[i, :, :, 0] = generated_images[idx, :, :, 0].predict(noise, verbose=1)image = combine_images(generated_images)image = image*127.5+127.5Image.fromarray(image.astype(np.uint8)).save("generated_image.png")def get_args():parser = argparse.ArgumentParser()parser.add_argument("--mode", type=str,default = 'train',)# parser.add_argument("--mode", type=str,default = 'generate',)parser.add_argument("--batch_size", type=int, default=8)parse
if __name__ == "__main__":args = get_args()if args.mode == "train":train(BATCH_SIZE=args.batch_size,path =args.path )elif args.mode == "generate":generate(BATCH_SIZE=args.batch_size, nice=args.nice)

创作不易 觉得有帮助请点赞关注收藏~~~

【Keras+计算机视觉+Tensorflow】DCGAN对抗生成网络在MNIST手写数据集上实战(附源码和数据集 超详细)相关推荐

  1. 对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析

  2. 【Android App】二维码的讲解及生成属于自己的二维码实战(附源码和演示 超详细必看)

    需要全部代码请点赞关注收藏后评论区留言~~~ 一.二维码基本内容介绍 条形码只能表达十几位数字编码,无法表示更复杂的数据. 二维码在二维方格上描出一个个黑点,从而表达更丰富的信息. 二维码早已在手机A ...

  3. 利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结。

    利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结. 摘要 一.神经网络与卷积网络的对比 1.数据处理 2.对获取到的数据进行归一化和独热编码 二.开始我们的tensorflow神经 ...

  4. 【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)

    需要源码和数据集请点赞关注收藏后评论区留言私信~~~ 一.OCR文字识别简介 利用计算机自动识别字符的技术,是模式识别应用的一个重要领域.人们在生产和生活中,要处理大量的文字.报表和文本.为了减轻人们 ...

  5. Tensorflow之 CNN卷积神经网络的MNIST手写数字识别

    点击"阅读原文"直接打开[北京站 | GPU CUDA 进阶课程]报名链接 作者,周乘,华中科技大学电子与信息工程系在读. 前言 tensorflow中文社区对官方文档进行了完整翻 ...

  6. PyTorch使用快速梯度符号攻击(FGSM)实现对抗性样本生成(附源码和数据集MNIST手写数字)

    需要源码和数据集请点赞关注收藏后评论区留言或者私信~~~ 一.威胁模型 对抗性机器学习,意思是在训练的模型中添加细微的扰动最后会导致模型性能的巨大差异,接下来我们通过一个图像分类器上的示例来进行讲解, ...

  7. Tensorflow GAN对抗生成网络实战

    这一节的回顾主要针对使用JS散度得DCGAN和基于GP理论和Wasserstein Distance理论的WGAN首先是DCGAN 我们的训练数据集是一堆这种二次元的动漫头像的图片,那么我们就是要训练 ...

  8. 深度学习核心技术精讲100篇(十二)-DCGAN(对抗生成网络)算法应用及代码实现

    前言 一次偶然看到一个换脸的视频,觉得实在是很神奇,于是饶有兴致的去了解一下换脸算法.原来背后有一个极为有意思的算法思想--对抗生成. 随后各种各样的GAN算法以指数级增长的方式涌现出来,比如WGAN ...

  9. Pytorch:GAN生成对抗网络实现MNIST手写数字的生成

    github:https://github.com/SPECTRELWF/pytorch-GAN-study 个人主页:liuweifeng.top:8090 网络结构 最近在疯狂补深度学习一些基本架 ...

最新文章

  1. Wiki为什么会流行
  2. UA MATH564 概率论IV 次序统计量例题1
  3. 摸清全国农村集体家底-农业大健康:产权改革谋定清产核资
  4. leetcode 310. Minimum Height Trees | 310. 最小高度树(图的邻接矩阵DFS / 拓扑排序)
  5. Aixs2发布webservice服务
  6. 收银机服务器操作系统,第二章 超市收银机操作系统最终版.doc
  7. 关于maven pom
  8. C语言日字,【C语言日日练(二)】static关键字
  9. ceb怎么转换成word_pdf怎么转换成word?这个方法值得一试
  10. 装机软件备忘、分类介绍 评点
  11. java开发入职注意
  12. 《算法导论》第三版第4章 分治策略 练习思考题 个人答案
  13. 网易云音乐直链提取及下载
  14. 计算机网络管理员期末,计算机网络管理员期中考试统一试题(A)
  15. docker部署案例
  16. div css 会员登录表单,html5 css3谷歌会员登录表单界面特效
  17. Java版Spring Cloud B2B2C o2o鸿鹄云商平台--概述
  18. C++写入并追加内容到txt中
  19. MySQL基础知识点集合
  20. 基于RESTful的FastAPI服务模板

热门文章

  1. DNS地址解析的设置
  2. js调用pc摄像头实现拍照、录视频等,新版Chrome无访问http页面无法打开麦克风、摄像头
  3. 班主任如何展开期中表彰班会
  4. 【虹科】“天问一号”着落的火星,你也想亲眼见证吗?——天体物理观测、短波红外技术与SIRIS相机
  5. 【虹科案例】虹科数字化仪在氢燃料电池测试中的应用
  6. HTML基础 - HTML列表
  7. python计算平方根保留两位小数_python 使用二分法计算平方根
  8. openLDAP安装经验分享
  9. 【学习之路】spring boot 整合mybatis报错 “serverTimezone=UTC“
  10. 同济大学高等数学上册第七章微分方程以及每日一题