学习目标:

  1. 理解生成对抗网络的基本原理。
  2. 掌握利用生成对抗网络生成新样本的方法。

学习内容:

fashion_mnist数据库(from keras.datasets import fashion_minist)数据集包含了10个类别的图像,分别是:t-shirt(T恤),trouser(牛仔裤),pullover(套衫),dress(裙子),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包),ankle boot(短靴),如下图。利用fashion_mnist数据库的训练数据构造生成对抗网络,并生成新的图片显示出来。

  


学习过程:

网络结构:

设置训练间隔和批量大小设置为500/10000:

运行结果如下图:

设置训练间隔和批量大小设置为500/5000:

运行结果如下图:

把图片保存在与源码相同目录下的文件夹中:


源码:

from keras.layers import Dense,BatchNormalization
from keras.layers import Conv2D, Flatten,LeakyReLU
from keras.layers import Reshape, Conv2DTranspose, Activation
from keras import Model,Sequential,Input
from keras.datasets import fashion_mnist
from keras.optimizers import RMSpropimport os,math
import numpy as np
import matplotlib.pyplot as plt# In[1]: 构造生成网络
# 生成网络将一维向量(100,)反向构造成图片所对应的矩阵(28,28,1)
def  build_generator(latent_shape, image_shape):# latent_shape = (100,)# image_shape = (28,28,1)begin_shape = (image_shape[0] // 4, image_shape[1] // 4)model = Sequential( [#Input(latent_shape),   # (100,) -> (7*7*128,) -> (7,7,128)Dense(begin_shape[0] * begin_shape[1] * 128,input_shape=latent_shape),Reshape((begin_shape[0], begin_shape[1], 128)),BatchNormalization(),Activation('relu'),# (7,7,128) -> (14,14,128)Conv2DTranspose(filters=128, kernel_size=5,strides=2,padding='same'),BatchNormalization(),Activation('relu'),# (14,14,128) -> (28,28,64)Conv2DTranspose(filters=64, kernel_size=5,strides=2,padding='same'),BatchNormalization(),Activation('relu'),# (28,28,64) -> (28,28,32)Conv2DTranspose(filters=32, kernel_size=5,strides=1,padding='same'),# (28,28,32) -> (28,28,1)BatchNormalization(),Activation('relu'),Conv2DTranspose(filters=1, kernel_size=5,strides=1,padding='same'),Activation('sigmoid') # 输出一个 (28,28,1) 的矩阵,每个像素值为0到1],name='generator')# 需要和判别器一起构造 对抗网络,用对抗网络训练生成器的参数return model# In[2]: 构造判别网络
# 判别网络输入一个 (28,28,1) 的图片,输出一个0到1的数,0:假样本,1:真样本
def  build_discriminator(image_shape):# image_shape=(28,28,1)discriminator = Sequential( [# (28,28,1) -> (14,14,32)LeakyReLU(alpha=0.2,input_shape=image_shape),Conv2D(32, kernel_size=5, strides=2, padding="same"), # (14,14,32) -> (7,7,64)LeakyReLU(alpha=0.2),Conv2D(64, kernel_size=5, strides=2, padding="same"), # (7,7,64) -> (4,4,128) LeakyReLU(alpha=0.2),Conv2D(128, kernel_size=5, strides=2, padding="same"), # (4,4,128) -> (4,4,256)LeakyReLU(alpha=0.2),Conv2D(256, kernel_size=5, strides=1, padding="same"), Flatten(),Dense(1),Activation('sigmoid') # 输出一个0到1的数,0:假样本,1:真样本],name='discriminator')return discriminator# In[3]: 显示和保存生成器构造的一批图片(5*5=25张)
def plot_images(generator, noise_input, show=False, step=0, model_name = ''):os.makedirs(model_name, exist_ok=True)filename = os.path.join(model_name, "%05d.png" % step)images = generator.predict(noise_input)plt.figure(figsize = (2.2, 2.2))num_images = images.shape[0]rows = int(math.sqrt(noise_input.shape[0]))for i in range(num_images):plt.subplot(rows, rows, i + 1)image = np.reshape(images[i], [images.shape[1], images.shape[2]])plt.imshow(image, cmap= 'gray')plt.axis('off')plt.savefig(filename)if show:plt.show()else:plt.close('all')# In[4]: 构建判别网络 和 对抗网络(生成网络+判别网络),并设置训练参数
# 设置训练相关的参数
model_name = 'DCGAN_mnist'
latent_dim = 100
batch_size = 64
train_steps = 10000 # 训练train_steps个batch,这里可更改为10000或5000
lr = 2e-4
decay = 6e-8
latent_shape = (latent_dim,)# 读取数据,获取图片大小。无监督训练,不需要标签。只是为了生成新样本,不需要测试样本进行对比
(x_train, _), (_, _) = fashion_mnist.load_data()
image_shape = (x_train.shape[1],x_train.shape[2],1)# 数据预处理,二维卷积操作的输入数据要求:[样本数,宽度,高度,通道数]
x_train = np.reshape(x_train, [-1, image_shape[0], image_shape[1], 1])
x_train = x_train.astype('float32') / 255  # 生成网络的输出的像素值是0到1之间的# 编译判别网络
discriminator = build_discriminator(image_shape)
discriminator.compile(loss = 'binary_crossentropy', optimizer = RMSprop(lr=lr, decay=decay),metrics = ['accuracy'])
discriminator.summary()# 构建并编译对抗网络(生成网络+判别网络)
generator = build_generator(latent_shape,image_shape)
generator.summary()discriminator.trainable = False # 训练生成者时识别者网络要保持不变
input_latent = Input(latent_shape, name='adversarial_input')
outputs = discriminator(generator([input_latent]))
adversarial = Model([input_latent], outputs, name='adversarial')
adversarial.compile(loss = 'binary_crossentropy',optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5),metrics = ['accuracy'])
adversarial.summary()# In[5]: 训练网络
'''
1) 先冻结生成网络,采样 真实图片 和 生成网络输出的假样本,训练判别网络,区分两类样本
2) 然后冻结判别网络,让生成网络构造图片输入给判别网络,训练生成网络,使得判别网络输出越接近1越好
'''save_interval = 500 # 训练每间隔500个batch把生成网络输出的图片保存下来# 构造给生成网络的一维随机向量,每隔500个batch训练后,都生成同样的这25个伪造样本,方便对比
noise_input = np.random.uniform(-1.0, 1.0, size = [5*5, latent_dim])
train_size = x_train.shape[0]for i in range(train_steps):# 1. 先训练判别网络,将真实图片和伪造图片同时输入判别网络,让判别网络学会区分真假图片# 随机选取真实图片rand_indexes = np.random.randint(0, train_size, size = batch_size)real_images = x_train[rand_indexes]#让生成网络构造伪造图片noise = np.random.uniform(-1.0, 1.0, size = [batch_size, latent_dim])fake_images = generator.predict(noise)# 合并真实图片和伪造图片,设置真实图片对应标签1,虚假图片对应标签0   x = np.concatenate((real_images, fake_images))y = np.ones([2 * batch_size, 1])y[batch_size:, :] = 0.0# 训练判别网络,用一个batch的真实图片和一个batch的伪造图片loss, acc = discriminator.train_on_batch(x, y)log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)# 2. 然后再训练生成网络:冻结判别网络,让生成网络构造图片输入给判别网络,使得输出越接近1越好noise = np.random.uniform(-1.0, 1.0, size = [batch_size, latent_dim])y = np.ones([batch_size, 1]) # 注意此时假样本的标签为1,即要使得输出越接近1越好# 训练生成网络时需要使用到判别网络返回的结果,因此从两者连接后的对抗网络进行训练loss, acc = adversarial.train_on_batch(noise, y)log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)# 每隔save_interval次保存训练结果if (i+1) % save_interval == 0:print(log)if (i + 1) == train_steps:show = Trueelse:show = False#将生成者构造的图片绘制出来plot_images(generator, noise_input = noise_input,show = show, step = i+1,model_name = model_name)# 保存生成网络的权重generator.save_weights(model_name + "_generator.h5")# In[6]: 直接读取以前训练的权重(可以不用重复执行步骤[5]训练网络),生成伪造图片
#构造一批随机初始化的一维向量让生成者网络创造图片
generator.load_weights(model_name + "_generator.h5")
noise = np.random.randint(-1.0, 1.0, size=[5*5, 100])
plot_images(generator, noise_input = noise, show=True, model_name=model_name)

源码下载


学习产出:

  1. 把批量大小更改为5000和10000后,每500个间隔就把图片保存下来,训练需要的时间比较长,但效果比较好,能辨别出是fashion_mnist数据库的图像;

人工智能--生成对抗网络相关推荐

  1. 从生成对抗网络到更自动化的人工智能

    来源:中国计算机协会 作者:黄鹤   王长虎 概要:"What I cannot create, I do not understand." 这是著名物理学家费曼的一句名言.把这句 ...

  2. [Python人工智能] 二十九.什么是生成对抗网络GAN?基础原理和代码普及(1)

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CN ...

  3. 人工智能知识全面讲解:生成对抗网络的应用

    13.2.1 GAN的特点 GAN从2015年提出至今,短短4年的时间已经发展成为人工智能学界一个热 门的研究方向,吸引了大批研究人员来研究 GAN.除了学术界的理论研究以 外,许多科技公司已经付诸行 ...

  4. 人工智能知识全面讲解:初识生成对抗网络

    13.1.1 猫和老鼠的游戏 在2016年7月,一款国外的照片处理软件火遍了全世界,同时也引爆了国 人的朋友圈.这款产品就是Prisma.Prisma可以按照你提供的图片内容和指定 的风格,生成一副指 ...

  5. [人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  6. 人工智能 - paddlepaddle飞桨 - 深度学习基础教程 - 生成对抗网络

    生成对抗网络 本教程源代码目录在book/09.gan,初次使用请您参考Book文档使用说明. 说明:¶ 硬件环境要求: 本文可支持在CPU.GPU下运行 Docker镜像支持的CUDA/cuDNN版 ...

  7. [人工智能-深度学习-63]:生成对抗网络GAN - 图片创作:普通GAN, pix2pix, CycleGAN和pix2pixHD的演变过程

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  8. 人工智能--条件生成对抗网络

    学习目标: 理解条件生成对抗网络的基本原理. 掌握利用条件生成对抗网络生成新样本的方法. 学习内容: fashion_mnist数据库(from keras.datasets import fashi ...

  9. 2021-01-24过去十年十大AI研究热点,分别为深度神经网络、特征抽取、图像分类、目标检测、语义分割、表示学习、生成对抗网络、语义网络、协同过滤和机器翻译。

    专利申请量全球第一!清华人工智能发展报告:国内215所高校成立相关本科专业 发布时间:01-2415:20万象大会年度获奖创作者,东方财富网官方帐号 1月20日,清华大学人工智能研究院.清华-中国工程 ...

最新文章

  1. js控制input框输入数字时,累计求和
  2. Python_Statsmodels包_时间序列分析_ARIMA模型
  3. Py之torchvision:torchvision库的简介、安装、使用方法之详细攻略
  4. html多语言国际化,gMIS吉密斯i18n多语言国际化更新
  5. 最近做了一个安装包的安装流程图
  6. 16进制数用空格分开 tcp_面试时,你是否被问到过TCP/IP协议?
  7. 用c#控制台模拟双色球随机选
  8. github使用_简单使用Git与github
  9. 传说中的80后的17条潜规则,你占了几条...
  10. B/S架构 Web打印程序(Argox)
  11. Nginx 常用配置,避坑指南!
  12. 计算机系第一学期电脑,公共课第一学期《计算机基础》.doc
  13. 魔鬼训练Day2作业
  14. SqlMapTamper使用指南
  15. 技术领导者空降后,如何管理全新的团队
  16. 普渡大学统计与计算机科学,普渡大学西拉法叶分校
  17. 日语随记_(文本编辑*)
  18. uniapp的uniapp navigateTo 点击无法跳转的问题
  19. 人人视频显示服务器睡着了,人人视频显示连接超时
  20. 在职研究生读计算机专业,读计算机专业在职研究生让我择业自如高升有望

热门文章

  1. linux系统挂载逻辑卷和扩展逻辑卷组
  2. 【vue】输入文字生成二维码
  3. 三立期货:掌财社怎么用筹码分布图看主力成本?
  4. JAVA入门算法题(十五)
  5. JAVA入门算法题(十二)
  6. 记得绑定邮箱 接收CSDN停用通知
  7. AI在招聘领域的这些应用,你会是被第一轮淘汰的吗
  8. 【四】软考—计算机网络
  9. 关于JavaScript的继承
  10. leetcode系列-242.有效的字母异位词