学习目标:

  1. 理解条件生成对抗网络的基本原理。
  2. 掌握利用条件生成对抗网络生成新样本的方法。

学习内容:

fashion_mnist数据库(from keras.datasets import fashion_minist)数据集包含了10个类别的图像,分别是:t-shirt(T恤),trouser(牛仔裤),pullover(套衫),dress(裙子),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包),ankle boot(短靴),如下图。利用fashion_mnist数据库的训练数据构造条件生成对抗网络,并分别生成10个类别的新的图片显示出来。

  


学习过程:

网络结构:

设置训练间隔和批量大小为500/5000:

运行结果如下图:

设置训练间隔和批量大小为500/10000:

运行结果如下图:

把图片保存在与源码相同目录下的文件夹中:


源码:

# 条件GANfrom keras.layers import Dense,BatchNormalization,concatenate
from keras.layers import Conv2D, Flatten,LeakyReLU
from keras.layers import Reshape, Conv2DTranspose, Activation
from keras import Model,Sequential,Input,utils
from keras.datasets import fashion_mnist
from keras.optimizers import RMSpropimport os
import numpy as np
import matplotlib.pyplot as plt
import math# In[1]: 构造生成网络
# 生成网络将一维向量(100,)及其类别向量(10,),反向构造成图片所对应的矩阵(28,28,1)
def  build_generator(latent_shape, label_shape, image_shape):# latent_dim = 100 # label_shape=(10,)# 由于有2个输入,所以使用函数式模型构造网络比较方便 # different!input_latent = Input(latent_shape, name='generator_input')input_label = Input(label_shape, name='input_label')#将100维的输入向量与10维的One-hot-vector结合在一起,成为110维的xx = concatenate([input_latent, input_label], axis = 1) # different!begin_shape = (image_shape[0] // 4, image_shape[1] // 4)model = Sequential( [Dense(begin_shape[0] * begin_shape[1] * 128),Reshape((begin_shape[0], begin_shape[1], 128)),BatchNormalization(),Activation('relu'),# (7,7,128) -> (14,14,128)Conv2DTranspose(filters=128, kernel_size=5,strides=2,padding='same'),BatchNormalization(),Activation('relu'),# (14,14,128) -> (28,28,64)Conv2DTranspose(filters=64, kernel_size=5,strides=2,padding='same'),BatchNormalization(),Activation('relu'),# (28,28,64) -> (28,28,32)Conv2DTranspose(filters=32, kernel_size=5,strides=1,padding='same'),# (28,28,32) -> (28,28,1)BatchNormalization(),Activation('relu'),Conv2DTranspose(filters=1, kernel_size=5,strides=1,padding='same'),Activation('sigmoid') # 输出一个 (28,28,1) 的矩阵,每个像素值为0到1])image_rec = model(x) generator = Model([input_latent,input_label],image_rec,name='generator')return generator# In[2]: 构造判别网络
# 判别网络输入一个 (28,28,1) 的图片,输出一个0到1的数,0:假样本,1:真样本
def  build_discriminator(image_shape,label_shape):# image_shape=(28,28,1)# label_shape=(10,)# 由于有2个输入,所以使用函数式模型构造网络比较方便 # different!input_image = Input(image_shape, name='discriminator_input')input_label = Input(shape=label_shape, name='input_label')# 对10个分量的one-hot标签向量,经过全连接和reshape层得到图像大小的矩阵y = Dense(image_shape[0] * image_shape[1])(input_label) y = Reshape((image_shape[0], image_shape[1], 1))(y)#把图片数据与one-hot-vector拼接起来,这里是唯一与前面代码不同之处x = concatenate([input_image, y]) # shape=(28, 28, 2) # different!model = Sequential( [# (28,28,1) -> (14,14,32)LeakyReLU(alpha=0.2),Conv2D(32, kernel_size=5, strides=2, padding="same"), # (14,14,32) -> (7,7,64)LeakyReLU(alpha=0.2),Conv2D(64, kernel_size=5, strides=2, padding="same"), # (7,7,64) -> (4,4,128) LeakyReLU(alpha=0.2),Conv2D(128, kernel_size=5, strides=2, padding="same"), # (4,4,128) -> (4,4,256)LeakyReLU(alpha=0.2),Conv2D(256, kernel_size=5, strides=1, padding="same"), Flatten(),Dense(1),Activation('sigmoid') # 输出一个0到1的数,0:假样本,1:真样本])score =  model(x)discriminator = Model([input_image,input_label],score,name='discriminator') return discriminator# In[3]: 显示和保存生成器构造的一批图片(5*5=25张)
def plot_images(generator, noise_input, noise_class, show=False, step=0, model_name = ''):os.makedirs(model_name, exist_ok=True)filename = os.path.join(model_name, "%05d.png" % step)images = generator.predict([noise_input,noise_class])plt.figure(figsize = (5, 5))num_images = images.shape[0]rows = int(math.sqrt(noise_input.shape[0]))for i in range(num_images):plt.subplot(rows, rows, i + 1)image = np.reshape(images[i], [images.shape[1], images.shape[2]])plt.imshow(image, cmap= 'gray')plt.axis('off')plt.savefig(filename)if show:plt.show()else:plt.close('all')# In[4]: 构建判别网络 和 对抗网络(生成网络+判别网络),并设置训练参数
# 设置训练相关的参数
model_name = 'DCGAN_mnist_condition'
latent_dim = 100
batch_size = 64
train_steps = 10000 # 训练train_steps个batch,这里可更改为10000或5000
lr = 2e-4
decay = 6e-8latent_shape=(latent_dim,)# 读取数据,获取图片大小。分类别的GAN,需要标签。只是为了生成新样本,不需要测试样本进行对比
(x_train, y_train), (_, _) = fashion_mnist.load_data()
image_shape = (x_train.shape[1],x_train.shape[2],1)# 数据预处理,二维卷积操作的输入数据要求:[样本数,宽度,高度,通道数]
x_train = np.reshape(x_train, [-1, image_shape[0], image_shape[1], 1])
x_train = x_train.astype('float32') / 255  # 生成网络的输出的像素值是0到1之间的
y_train = utils.to_categorical(y_train)
label_shape = (y_train.shape[-1],) # different!# 编译判别网络
discriminator = build_discriminator(image_shape,label_shape) # different!
discriminator.compile(loss = 'binary_crossentropy', optimizer = RMSprop(lr=lr, decay=decay),metrics = ['accuracy'])
discriminator.summary()# 构建并编译对抗网络(生成网络+判别网络)
generator = build_generator(latent_shape,label_shape,image_shape) # different!
generator.summary()
discriminator.trainable = False # 训练生成者时识别者网络要保持不变input_latent = Input(latent_shape, name='adversarial_input')
input_label = Input(label_shape, name='input_label')
outputs = discriminator([generator([input_latent, input_label]), input_label])
adversarial = Model([input_latent, input_label], outputs, name='adversarial')
adversarial.compile(loss = 'binary_crossentropy',optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5),metrics = ['accuracy'])
adversarial.summary()# In[5]: 训练网络
'''
1) 先冻结生成网络,采样 真实图片 和 生成网络输出的假样本,训练判别网络,区分两类样本
2) 然后冻结判别网络,让生成网络构造图片输入给判别网络,训练生成网络,使得判别网络输出越接近1越好
'''save_interval = 500 # 训练每间隔500个batch把生成网络输出的图片保存下来# 构造给生成网络的一维随机向量,每隔500个batch训练后,都生成同样的这100个伪造样本,方便对比
noise_input = np.random.uniform(-1.0, 1.0, size = [10*10, latent_dim])
noise_class = np.eye(label_shape[0])[np.arange(0, 10*10) % label_shape[0]] # different!
train_size = x_train.shape[0]for i in range(train_steps):# 1. 先训练判别网络,将真实图片和伪造图片同时输入判别网络,让判别网络学会区分真假图片# 随机选取真实图片rand_indexes = np.random.randint(0, train_size, size = batch_size)real_images = x_train[rand_indexes]real_labels = y_train[rand_indexes]  # different!#让生成网络构造伪造图片noise = np.random.uniform(-1.0, 1.0, size = [batch_size, latent_dim])# 随机指定每个伪造样本的类别,并转化为one-hot向量, different!fake_labels = np.eye(label_shape[0])[np.random.choice(label_shape[0], batch_size)]fake_images = generator.predict([noise, fake_labels])# 合并真实图片和伪造图片x = np.concatenate((real_images, fake_images))#将真实图片对应的one-hot-vecotr和虚假图片对应的One-hot-vector连接起来, different!y_labels = np.concatenate((real_labels, fake_labels))y = np.ones([2 * batch_size, 1])#上半部分图片为真,下半部分图片为假y[batch_size:, :] = 0.0# 训练判别网络,用一个batch的真实图片和一个batch的伪造图片# 注意这里需要将图片及对应的one-hot-vector输入loss, acc = discriminator.train_on_batch([x, y_labels], y) # different!log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)# 2. 然后再训练生成网络:冻结判别网络,让生成网络构造图片输入给判别网络,使得输出越接近1越好noise = np.random.uniform(-1.0, 1.0, size = [batch_size, latent_dim])fake_labels = np.eye(label_shape[0])[np.random.choice(label_shape[0], batch_size)]y = np.ones([batch_size, 1]) # 注意此时假样本的标签为1,即要使得输出越接近1越好# 训练生成网络时需要使用到判别网络返回的结果,因此从两者连接后的对抗网络进行训练loss, acc = adversarial.train_on_batch([noise, fake_labels], y)log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)# 每隔save_interval次保存训练结果if (i+1) % save_interval == 0:print(log)if (i + 1) == train_steps:show = Trueelse:show = False#将生成者构造的图片绘制出来plot_images(generator, noise_input = noise_input,noise_class = noise_class, # different!show = show, step = i+1,model_name = model_name)# 保存生成网络的权重generator.save_weights(model_name + "_generator.h5")# In[6]: 读取训练好得权重,显示结果
noise_input = np.random.uniform(-1.0, 1.0, size = [10*10, latent_dim])
noise_class = np.eye(label_shape[0])[np.arange(0, 10*10) % label_shape[0]] # different!
generator.load_weights(model_name + "_generator.h5")
plot_images(generator, noise_input = noise_input,noise_class = noise_class, # different!show = True, step = 5000,model_name = model_name)

源码下载


学习产出:

  1. 把批量大小更改为5000和10000后,每500个间隔就把图片保存下来,训练需要的时间比较长,但效果比较好,能辨别出是fashion_mnist数据库中的10类图像;

人工智能--条件生成对抗网络相关推荐

  1. [深度学习-实践]条件生成对抗网络cGAN的例子-Tensorflow2.x Keras

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之DCGAN基于CIFAR10数据集的例子 深度学习GAN(三)之DCGAN基于手写体Mnist数据集的例子 深度学习GAN(四)之c ...

  2. CGAN之条件生成对抗网络(Matlab)

    代码来源 代码全文: clear all; close all; clc; %% Conditional Generative Adversarial Network %% Load Data loa ...

  3. CGAN(条件生成-对抗网络)简述教程

    一直以来提醒自己说,你要坚持,要坚持,要坚持更新,坚持读paper,要是现在都坚持不下去,还有将近五年该怎么办呐. 所幸,我还是有底线的,在今天抽了一个下午,看了看论文,跑了跑代码,过程谈不上舒服,也 ...

  4. 深度学习故障诊断之-使用条件生成对抗网络CGAN生成泵流量信号

    开始填坑 MATLAB统计机器学习,深度学习,计算机视觉 - 哥廷根数学学派的文章 - 知乎 MATLAB统计机器学习,深度学习,计算机视觉 - 知乎 之前写过在使用深度学习对机械系统或电气系统进行故 ...

  5. GAN(生成对抗网络) and CGAN(条件生成对抗网络)

    前言 GAN(生成对抗网络)是2014年由Goodfellow大佬提出的一种深度生成模型,适用于无监督学习.监督学习.但是GAN进行生成时是不可控的,所以后来又有人提出可控的CGAN(条件生成对抗网络 ...

  6. 条件生成对抗网络(CGAN)

    记录一下 上资源:(github), 基于Pytorch的条件对抗生成网络 条件对抗生成网络和生成对抗网络的区别在于,条件对抗网络生成器和鉴别器额外输入了条件信息(以minist为例,就是额外输入了标 ...

  7. GAN生成对抗网络-CGAN原理与基本实现-条件生成对抗网络04

    CGAN - 条件GAN 原始GAN的缺点 代码实现 import tensorflow as tf from tensorflow import keras from tensorflow.kera ...

  8. 实战生成对抗网络[1]:简介

    引言 2016年3月,AlphaGO横空出世,击败人类顶尖职业棋手,引爆了人工智能热潮.之后AlphaGO Master和AlphaGO Zero更是无情的碾压人类棋手,人们终于认识到,人类迎来了可怕 ...

  9. 必读论文 | 生成对抗网络经典论文推荐10篇

    生成式对抗网络(Generative adversarial networks, GAN)是当前人工智能学界最为重要的研究热点之一.其突出的生成能力不仅可用于生成各类图像和自然语言数据,还启发和推动了 ...

最新文章

  1. Tomcat 源码阅读记录(1)
  2. 迁移学习:如何为您的机器学习问题选择正确的预训练模型
  3. Linux配置环境变量source时报错:export `=‘ not a valid identifier的一般原因
  4. linux脚本判断流程控制,Shell 脚本-6- 流程控制之判断分支
  5. Reporting Services 4: Web Service
  6. Asp.net 2.0生命周期
  7. linux tar命令压缩_Linux tar命令来压缩和提取文件
  8. 从用户洞察到数据应用 诸葛io让“增长”深入场景
  9. 招聘 | 清华大学计算机系知识工程实验室博士后
  10. PDF文件中失效链接修改
  11. php商品评价,商品评价,评价,商品详情,商品评价api,api,评价api,商品详情
  12. 玻尔兹曼机、深度信念网络、编码器等生成模型
  13. Soul App打造社交元宇宙,打破次元壁
  14. 浅析linux下的回收站以及U盘中的.Trash文件夹
  15. 一键即可实现图片翻译成中文,多国语言任意选
  16. 开关电源空载吱吱声_导致开关电源啸叫的六种情况及解决方法
  17. python的matplotlib画饼状图
  18. Hello CTP(五)——CTP仓位计算
  19. 网速慢、WIFI信号差?这样操作路由器就可以
  20. PowerShell 学习笔记:压缩、解压缩文件

热门文章

  1. 最详细的【微信小程序+阿里云Web服务】开发部署指引(十二):开发小程序用户反馈功能
  2. 按钮点击事件的实现方式---原生js
  3. pytorch中AdaGrad优化器源码解读
  4. 单元测试-JMockit
  5. 可重入锁(ReentrantLock为例)
  6. 作业管理系统系统流程图
  7. 合格率计算(只作新手参考)
  8. LeetCode 每日一题——535. TinyURL 的加密与解密
  9. 17、 数组和字符串的应用 输入一行字符,统计其中有多少个单词,单词之间用空格分隔开
  10. 互联网行业女孩子做什么比较好?