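Building WGAN-GP and WGAN-div with Keras to Generate Images
Contents
- 1. Introduction
- 2. How WGAN-GP works
- 3. How WGAN-div works
- 4. Code structure and design
- 4.1. Generating the tfrecord
- 4.2. Designing the residual network
- Building the ResBlock module
- 4.3. Building the generate network:
- 4.4. Building the discriminator network:
- 4.5. Defining the loss functions:
- a. First, get the discriminator and generator networks:
- b. Define gan_train_d for training the discriminator:
- c. Define gan_train_g for training the generator:
- 4.6. Defining the training loop:
- 4.7. Other functions:
- plot()
- 5. Code and training results
- 5.1. Running the code
- 5.2. Training progress with WGAN-GP (file names indicate the training step):
- 5.3. Training progress with WGAN-div (using LeakyReLU):
- 6. Summary
- References (code and papers):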
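1. Introduction
(1) The earliest DCGAN uses a cross-entropy loss function.
However, the discriminator's objective has a serious flaw. When the discriminator is too strong, i.e. it can separate generated data from real data perfectly, the generated and real images do not overlap at all, and the JS divergence between the two distributions is constantly log 2.
In that case the gradient of the objective with respect to the generator's parameters is 0, i.e. the gradient vanishes. The discriminator can no longer guide the generator to update in a consistent direction, the generated images become almost identical, and the discriminator's loss converges to 0.
Vanishing gradients are most likely caused by a discriminator that is much stronger than the generator, so that the real data distribution and the generated data distribution do not overlap.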
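The saturation at log 2 can be checked numerically. The following is a minimal sketch on toy discrete distributions (the four-bin distributions and the helper names `kl_divergence` and `js_divergence` are made up for illustration and are not part of gan.py): as long as the two distributions share no support, the JS divergence equals log 2 regardless of how the generated distribution is shaped, so it gives the generator no useful signal.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    """Jensen-Shannon divergence: mean KL of p and q to their mixture m."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

# Toy distributions over four bins (illustrative values only).
real = [0.5, 0.5, 0.0, 0.0]
fake_disjoint = [0.0, 0.0, 0.5, 0.5]     # no overlap with `real`
fake_far = [0.0, 0.0, 0.1, 0.9]          # still no overlap, different shape
fake_overlap = [0.25, 0.25, 0.25, 0.25]  # partially overlaps `real`

js_disjoint = js_divergence(real, fake_disjoint)
js_far = js_divergence(real, fake_far)
js_overlap = js_divergence(real, fake_overlap)

print(js_disjoint, js_far, js_overlap, math.log(2))
```

Both disjoint cases give the same value, while the overlapping case gives a smaller one; only when the supports overlap does the divergence respond to changes in the generated distribution.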
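Another possible failure is mode collapse: the generator produces images that have no realistic meaning and exist only to fit the discriminator, or it produces images of a single mode.
(2) WGAN (Wasserstein GAN) was proposed to solve these DCGAN problems. It removes the log function and uses the Wasserstein (W) distance to measure the gap between real and generated data.
WGAN also clips the network parameters to [-0.01, 0.01], but this pushes the weights toward the two clipping bounds and makes the network behave almost like a binary network, as shown in the figure:
This reduces the fitting capacity of the whole network, and the hard clipping can easily lead to exploding or vanishing gradients.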
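To make the clipping step concrete, here is a small sketch in plain Python. The weights are random numbers invented for the example (they are not taken from a trained critic) and `clip_weights` is a hypothetical helper; it only illustrates how a hard clamp to [-0.01, 0.01] piles values up on the two bounds when the unclipped weights are larger than the threshold.

```python
import random

def clip_weights(weights, c=0.01):
    """Clamp every weight into [-c, c], as WGAN weight clipping does."""
    return [max(-c, min(c, w)) for w in weights]

random.seed(0)
# Hypothetical critic weights drawn from N(0, 0.05); an assumption for the demo.
weights = [random.gauss(0.0, 0.05) for _ in range(10000)]
clipped = clip_weights(weights)

# Fraction of weights that ended up exactly on one of the two bounds.
at_bounds = sum(1 for w in clipped if abs(w) == 0.01) / len(clipped)
print("fraction at the clipping bounds: %.2f" % at_bounds)
```

In real training the clamp is applied after every critic update, so this one-shot example only hints at the effect; it is not a reproduction of the figure.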
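As a result, many WGAN-based algorithms appeared later. Most of them change the loss function to mitigate vanishing gradients and mode collapse, for example WGAN-GP and WGAN-div:
Note: the figure is from the Zhihu author '桑龙'.
The following sections introduce GP and div and the code that implements them: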
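2. How WGAN-GP works
Original paper: Improved Training of Wasserstein GANs
https://arxiv.org/pdf/1704.00028.pdf
The objective function and the training algorithm are given in the paper.
Here I did not follow the algorithm's schedule of training the discriminator 5 times before each generator update, because in my network that makes the discriminator too strong. I trained at a 1:1 ratio from the start.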
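For reference, the critic loss of WGAN-GP can be written as follows. This is my transcription of the objective in the cited paper, with the paper's default λ = 10, which is also the factor used in the code of section 4.5; the symbols D (critic), P_r (real distribution) and P_g (generated distribution) follow the paper's notation.

```latex
L = \mathbb{E}_{\tilde{x}\sim P_g}\big[D(\tilde{x})\big]
  - \mathbb{E}_{x\sim P_r}\big[D(x)\big]
  + \lambda\,\mathbb{E}_{\hat{x}\sim P_{\hat{x}}}
    \Big[\big(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\big)^2\Big],
\qquad
\hat{x} = \epsilon x + (1-\epsilon)\tilde{x},\quad \epsilon\sim U[0,1]
```

The generator minimizes -E[D(G(z))]. The code mixes samples as (1-u)*real + u*fake, which is the same interpolation because u and 1-u are both uniform on [0, 1].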
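3. How WGAN-div works
Original paper: Wasserstein Divergence for GANs
https://arxiv.org/pdf/1712.01026.pdf
The objective function, and the separate losses for the discriminator and the generator, are given in the paper.
Here k=2 and p=6.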
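As a reference sketch, the two losses can be written in the sign convention used by the code in section 4.5 (my reading of the paper's algorithm; check the paper for the exact statement). D is the discriminator, G the generator, and x̂ is a random interpolation between a real and a generated sample.

```latex
L_D = \mathbb{E}_{x\sim P_r}\big[D(x)\big]
    - \mathbb{E}_{z}\big[D(G(z))\big]
    + k\,\mathbb{E}_{\hat{x}\sim P_u}\big[\lVert\nabla_{\hat{x}} D(\hat{x})\rVert^{p}\big],
\qquad
L_G = \mathbb{E}_{z}\big[D(G(z))\big],
\qquad k = 2,\; p = 6
```

Compared with WGAN-GP, the signs of the score terms are flipped and the penalty pushes the gradient norm toward 0 instead of toward 1.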
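The training algorithm is also given in the paper.
Its network architecture uses ResBlocks.
Reference paper: Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf
Residual networks were designed to keep very deep networks trainable, which eases problems such as vanishing gradients, and they perform very well on image classification.
Only convolutional blocks are used here; identity blocks can of course be added to deepen and widen the network.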
4. Code structure and design
4.1. Generating the tfrecord
Data in this format is memory friendly, fast to read, and easy to move and store.
def create_tfrecords():
    global data_num
    # Skip the conversion if the tfrecord file already exists.
    if os.path.exists(tfrecords_path):
        return 0
    # The --data flag defaults to the string 'None', so check for both.
    if FLAGS.data in (None, 'None'):
        print('the data is none,use: python gan.py --data []')
        os._exit(0)
    writer_train = tf.python_io.TFRecordWriter(tfrecords_path)
    object_path = FLAGS.data
    total = os.listdir(object_path)
    num = len(total)
    num_i = 1
    value = 0
    print('-----------------------------making dataset tfrecord,waiting--------------------------')
    for index in total:
        img_path = os.path.join(object_path, index)
        # Force 3 channels so that every record holds dim*dim*3 bytes.
        img = Image.open(img_path).convert('RGB')
        img = img.resize((dim, dim))
        img_raw = img.tobytes()
        '''it is on my datasets, please change these codes!'''
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))
        writer_train.write(example.SerializeToString())  # serialize to a string
        # progress indicator
        sys.stdout.write('--------%.4f%%-----' % (num_i/float(num)*100))
        sys.stdout.write('\r')
        sys.stdout.flush()
        num_i = num_i + 1
    print('-------------------------------datasets has completed-----------------------------------')
    data_num = num  # number of images written (num_i would be one too many)
    writer_train.close()
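4.2. Designing the residual network
Reference blog: Getting Started with Keras and Building a Residual Network
Building the ResBlock module
The activation function I use here is LeakyReLU(). In my tests LeakyReLU() worked slightly better than relu.
As the figure below shows, the main path contains three convolution + BN layers, and the shortcut gets a convolution and a normalization. Both the main path and the shortcut have to change dimensions: downsampling for the discriminator and upsampling for the generator, implemented with a transposed convolution (Conv2DTranspose) or with UpSampling2D + Conv2D:
Define the upsampling and downsampling function:
def convolutional2D(x, num_filters, kernel_size, resampling, strides=2):
    # Compare strings with ==; `is` tests object identity and is unreliable here.
    if resampling == 'up':
        x = keras.layers.UpSampling2D()(x)
        x = keras.layers.Conv2D(num_filters, kernel_size=kernel_size, strides=1, padding='same',
                                kernel_initializer=keras.initializers.RandomNormal())(x)
        # Alternative: a transposed convolution
        #x = keras.layers.Conv2DTranspose(num_filters, kernel_size=kernel_size, strides=2, padding='same',
        #                                 kernel_initializer=keras.initializers.RandomNormal())(x)
    elif resampling == 'down':
        x = keras.layers.Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same',
                                kernel_initializer=keras.initializers.RandomNormal())(x)
    return x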
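The effect of the two resampling modes on the spatial size can be traced without building a model. This is a small sketch in plain Python; `resampled_size` and `trace` are hypothetical helpers, not part of gan.py, and they assume the defaults used above (UpSampling2D with its default factor of 2, padding='same', strides=2).

```python
import math

def resampled_size(size, resampling, strides=2):
    """Spatial size after one convolutional2D call with padding='same'."""
    if resampling == 'up':
        # UpSampling2D() doubles height and width; the stride-1 conv keeps them.
        return size * 2
    if resampling == 'down':
        # A strided 'same' convolution yields ceil(size / strides).
        return math.ceil(size / strides)
    raise ValueError("resampling must be 'up' or 'down'")

def trace(size, resampling, blocks=4):
    """Sizes before and after each of `blocks` consecutive resampling steps."""
    sizes = [size]
    for _ in range(blocks):
        size = resampled_size(size, resampling)
        sizes.append(size)
    return sizes

generator_sizes = trace(4, 'up')          # four ResBlocks starting from 4x4
discriminator_sizes = trace(64, 'down')   # four ResBlocks starting from 64x64
print(generator_sizes, discriminator_sizes)
```

Four upsampling blocks take the generator from 4×4 to 64×64, and four downsampling blocks take the discriminator from 64×64 back to 4×4, which matches the shape comments in the generate and discriminator functions.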
Define ResBlock:
def ResBlock(x, num_filters, resampling, strides=2):
    #F1,F2,F3 = num_filters
    X_shortcut = x
    # main path: up- or downsample
    x = convolutional2D(x, num_filters, kernel_size=(3,3), resampling=resampling, strides=strides)
    # BN + activation
    x = keras.layers.BatchNormalization()(x)
    #x = keras.layers.Activation('relu')(x)
    x = keras.layers.LeakyReLU()(x)
    # conv2d
    x = keras.layers.Conv2D(num_filters, kernel_size=(3,3), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(x)
    # BN + activation
    x = keras.layers.BatchNormalization()(x)
    #x = keras.layers.Activation('relu')(x)
    x = keras.layers.LeakyReLU()(x)
    # conv2d
    x = keras.layers.Conv2D(num_filters, kernel_size=(3,3), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(x)
    # BN only; the activation comes after the add
    x = keras.layers.BatchNormalization()(x)
    # shortcut: 1x1 conv with the same resampling, then BN
    X_shortcut = convolutional2D(X_shortcut, num_filters, kernel_size=(1,1), resampling=resampling, strides=strides)
    X_shortcut = keras.layers.BatchNormalization()(X_shortcut)
    # add the shortcut to the main path
    X_add = keras.layers.Add()([x, X_shortcut])
    #X_add = keras.layers.Activation('relu')(X_add)
    X_add = keras.layers.LeakyReLU()(X_add)
    return X_add
Note that the convolutions in the main path use 3×3 kernels rather than the 1×1 kernels of the original paper:
The reason is that with 1×1 kernels and only 4 ResBlocks, the discriminator and the generator have only a little over one million parameters each, which is too few for the discriminator to fit well. With 3×3 kernels the parameter count rises to over ten million. An IdentifyBlock can also be used to deepen and widen the network.
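The gap between the two kernel sizes follows directly from the parameter formula of a convolution layer. A minimal sketch, with `conv2d_params` as a hypothetical helper (it counts a single layer and is not meant to reproduce the whole-network totals quoted above):

```python
def conv2d_params(kernel_size, in_channels, out_channels, use_bias=True):
    """Number of trainable parameters in one 2D convolution layer."""
    kh, kw = kernel_size
    weights = kh * kw * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)

# One 512 -> 512 convolution, the widest case in the networks of this article.
params_1x1 = conv2d_params((1, 1), 512, 512)
params_3x3 = conv2d_params((3, 3), 512, 512)
print(params_1x1, params_3x3)
```

The weight count of a 3×3 kernel is nine times that of a 1×1 kernel for the same channel widths, which is consistent with the roughly tenfold jump in total parameters described above.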
Architecture from the original WGAN-div paper:
4.3. Building the generate network:
def generate(resampling='up'):
    nosie = keras.layers.Input(shape=(noise_dim,))
    g = keras.layers.Dense(512*4*4)(nosie)
    g = keras.layers.Reshape((4,4,512))(g)
    # BN + activation
    g = keras.layers.BatchNormalization()(g)
    #g = keras.layers.Activation('relu')(g)
    g = keras.layers.LeakyReLU()(g)
    #4*4*512
    g = ResBlock(g, num_filters=512, resampling=resampling)
    #8*8*512
    g = ResBlock(g, num_filters=256, resampling=resampling)
    #16*16*256
    g = ResBlock(g, num_filters=128, resampling=resampling)
    #32*32*128
    g = ResBlock(g, num_filters=64, resampling=resampling)
    #64*64*64
    g = keras.layers.Conv2D(3, kernel_size=(3,3), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(g)
    #64*64*3
    g_out = keras.layers.Activation('tanh')(g)
    g_model = keras.Model(nosie, g_out)
    return g_model
4.4. Building the discriminator network:
def discriminator(resampling='down'):
    real_in = keras.layers.Input(shape=(dim, dim, 3))
    d = keras.layers.Conv2D(64, kernel_size=(3,3), padding='same', strides=1,
                            kernel_initializer=keras.initializers.RandomNormal())(real_in)
    # BN + activation
    d = keras.layers.BatchNormalization()(d)
    #d = keras.layers.Activation('relu')(d)
    d = keras.layers.LeakyReLU()(d)
    #64*64*64
    d = ResBlock(d, num_filters=128, resampling=resampling)
    #32*32*128
    d = ResBlock(d, num_filters=256, resampling=resampling)
    #16*16*256
    d = ResBlock(d, num_filters=512, resampling=resampling)
    #8*8*512
    d = ResBlock(d, num_filters=512, resampling=resampling)
    #4*4*512
    '''
    GlobalAveragePooling: it can replace the fully connected layer;
    you can use Dense instead to test the network
    '''
    d = keras.layers.GlobalAveragePooling2D()(d)
    d_out = keras.layers.Dense(1, use_bias=False)(d)
    d_model = keras.Model(real_in, d_out)
    return d_model
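4.5. Defining the loss functions:
a. First, get the discriminator and generator networks:
#------------------------------
# define the generate model
#------------------------------
generate_model = generate()
#--------------------------------
# define the discriminator model
#--------------------------------
discriminator_model = discriminator()
b. Define gan_train_d for training the discriminator:
Approach:
1. Define three inputs (Input):
- the real image data
- the noise needed to generate data
- the uniform random numbers needed to mix real and fake data
2. Set the generator to non-trainable:
generate_model.trainable = False
3. From the inputs, obtain:
- D_fake_img
- D_fake_score
- D_real_score
4. Mix real and generated data:
x_ = (1.-u)*Dx_real_img + u*D_fake_img
5. Design the loss function following the algorithms in the papers, where grad_norm is the L2 norm of the gradient of discriminator_model(x_) with respect to x_:
wgan-div: K.mean(D_real_score - D_fake_score) + k * K.mean(grad_norm ** p)
wgan-gp: K.mean(D_fake_score - D_real_score) + 10 * K.mean(K.square(1 - grad_norm))
At this point you may wonder why the loss functions of the two papers seem to update in opposite directions. The two papers simply use opposite signs for the score. With either method, WGAN-div or WGAN-GP, the discriminator and the generator must update adversarially: the generator updates toward a distance of 0 between generated and real data, while the discriminator updates toward a larger distance, i.e. toward separating the two sets of data. As long as the generator loss uses the sign that matches the discriminator loss, both variants are adversarial.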
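The two gradient penalties behave quite differently, which is easiest to see on a critic whose gradient is known in closed form. The sketch below assumes a linear critic D(x) = w · x, whose gradient with respect to x is w at every point; the helper names and the example weight vectors are invented for illustration, and the batch mean used in the real loss is left out because the gradient is the same for every sample.

```python
import math

def grad_norm_linear_critic(w):
    """For a linear critic D(x) = w . x the gradient w.r.t. x is w everywhere."""
    return math.sqrt(sum(wi * wi for wi in w))

def gp_penalty(grad_norm, lam=10.0):
    """WGAN-GP: two-sided penalty that pulls the gradient norm toward 1."""
    return lam * (1.0 - grad_norm) ** 2

def div_penalty(grad_norm, k=2.0, p=6.0):
    """WGAN-div: penalty k * ||grad||^p, which only vanishes at norm 0."""
    return k * grad_norm ** p

w_unit = [0.6, 0.8]   # gradient norm 1
w_big = [3.0, 4.0]    # gradient norm 5

for w in (w_unit, w_big):
    n = grad_norm_linear_critic(w)
    print(n, gp_penalty(n), div_penalty(n))
```

With a unit-norm gradient the GP term vanishes while the div term equals k; with a large gradient both grow, but the div term grows much faster because of the power p = 6.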
#-------------------------------------------------------------------
# train the Discriminator
#-------------------------------------------------------------------
'''define new Input layers here rather than reusing the earlier ones'''
#Input para
Dx_real_img = keras.layers.Input(shape=(dim, dim, 3))
Dz_noise = keras.layers.Input(shape=(noise_dim,))
D_uniform = keras.layers.Input(shape=(1,1,1))
#set the trainable
generate_model.trainable = False
#get the score
D_fake_img = generate_model(Dz_noise)
D_fake_score = discriminator_model(D_fake_img)
D_real_score = discriminator_model(Dx_real_img)
#train net
gan_train_d = keras.Model([Dx_real_img, Dz_noise, D_uniform], [D_real_score, D_fake_score])
#set the loss function according to the algorithm
k = 2
p = 6
u = D_uniform
#then, build a new input mixed from fake and real
x_ = (1.-u)*Dx_real_img + u*D_fake_img
#-------------------------------------------------------------------
# wgan div loss function
# arxiv.org/pdf/1712.01026.pdf
#-------------------------------------------------------------------
if FLAGS.type == 'div':
    gradients = K.gradients(discriminator_model(x_), [x_])[0]
    grad_norm = K.sqrt(K.sum(gradients ** 2, axis=[1, 2, 3])) ** p
    grad_penalty = k * K.mean(grad_norm)
    discriminator_loss = K.mean(D_real_score - D_fake_score)
#-------------------------------------------------------------------
# wgan gp loss function
# arxiv.org/pdf/1704.00028.pdf
#-------------------------------------------------------------------
if FLAGS.type == 'gp':
    gradients = K.gradients(discriminator_model(x_), [x_])[0]
    grad_norm = K.sqrt(K.sum(gradients ** 2, axis=[1, 2, 3]))
    grad_norm = K.square(1-grad_norm)
    grad_penalty = 10*K.mean(grad_norm)
    discriminator_loss = K.mean(D_fake_score-D_real_score)
#loss function
discriminator_loss_all = grad_penalty + discriminator_loss
#compile the model
gan_train_d.add_loss(discriminator_loss_all)  #min
gan_train_d.compile(optimizer=keras.optimizers.Adam(learning_rate, 0.5))
gan_train_d.metrics_names.append('DistanceFromRealAndFake')
gan_train_d.metrics_tensors.append(-discriminator_loss)  #max
c. Define gan_train_g for training the generator:
Approach:
1. Define one input (Input):
- the noise needed to generate data
2. Freeze the discriminator and make the generator trainable:
discriminator_model.trainable = False
generate_model.trainable = True
3. From the input, obtain:
- G_fake_img
- G_fake_score
4. Loss function:
if FLAGS.type == 'div':
    generate_loss = K.mean(G_fake_score)
if FLAGS.type == 'gp':
    generate_loss = -K.mean(G_fake_score)  #min this value
#-------------------------------------------------------------------
# train the Generator
#-------------------------------------------------------------------
#Input para
Gz_nosie = keras.layers.Input(shape=(noise_dim,))
#set the trainable
discriminator_model.trainable = False
generate_model.trainable = True
#get the score
G_fake_img = generate_model(Gz_nosie)
G_fake_score = discriminator_model(G_fake_img)
#train net
gan_train_g = keras.Model(Gz_nosie, G_fake_score)
#loss function
if FLAGS.type == 'div':
    generate_loss = K.mean(G_fake_score)
if FLAGS.type == 'gp':
    generate_loss = -K.mean(G_fake_score)  #min this value
#compile the model
gan_train_g.add_loss(generate_loss)  #min
gan_train_g.compile(optimizer=keras.optimizers.Adam(learning_rate, 0.5))
4.6. Defining the training loop:
The body of the loop:
First fetch the inputs (a batch of data, the noise and the uniform random numbers), then train the discriminator and the generator in turn.
#datasets
train_datas_ = sess.run(train_datas)
'''if the last batch is smaller than batch_size, restart the iterator'''
if train_datas_[0].shape[0] != batch_size:
    sess.run(iter.initializer)
    train_datas_ = sess.run(train_datas)
z_noise = np.random.normal(size=(batch_size, noise_dim))
u_niform = np.random.uniform(low=0.0, high=1.0, size=(batch_size,1,1,1))
#-----------------------------------------
# phase 1 - training the discriminator
#-----------------------------------------
for step_critic in range(n_critic):
    d_loss, distance = gan_train_d.train_on_batch([train_datas_[0], z_noise, u_niform], None)
#-----------------------------------------
# phase 2 - training the generator
#-----------------------------------------
for step_generate in range(n_generate):
    g_loss = gan_train_g.train_on_batch(z_noise, None)
4.7. Other functions:
plot()
Plots how the loss values change during training and saves the figure.
def plot(history):
    # each row of history is [step, d_loss, distance, g_loss]
    history = np.array(history)
    plt.ion()
    plt.figure(figsize=(12,4))
    plt.title('Train History')
    plt.plot(history[:,0], history[:,1])
    plt.ylabel('loss')
    plt.plot(history[:,0], history[:,2])
    plt.plot(history[:,0], history[:,3])
    plt.xlabel('step')
    plt.legend(['d_loss','distance','g_loss'], loc='upper left')
    plt.savefig(os.path.join(model_path,'history.png'))
    plt.pause(1)
    plt.close()
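5. Code and training results
5.1. Running the code
How to run:
The script runs directly on a dataset you prepare yourself.
Please excuse the clumsy English in my comments, haha.
python gan.py --data [image path] --type ['gp' or 'div']
gan.py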
#! -*- coding: utf-8 -*-
'''
Designer: zyl
use: python gan.py --data [image path] --type ['gp' or 'div']
Written against the TensorFlow 1.x API (tf.Session, tf.app.flags) with standalone Keras.
'''
import time
import numpy as np
import tensorflow as tf
import keras
from keras import backend as K
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2
import sys

noise_dim = 128
dim = 64
epochs = 1000
batch_size = 64
data_num = 12500
learning_rate = 2e-4
save_step = 300
n_critic = 1
n_generate = 1
tfrecords_path = 'data/train.tfrecords'
save_path = 'image/'
model_path = 'model/'
#log_path = 'log/'

tf.app.flags.DEFINE_string('data', 'None', 'where the datas?.')
tf.app.flags.DEFINE_string('type', 'gp', 'what is the type?.')
FLAGS = tf.app.flags.FLAGS

if not os.path.exists('data'):
    os.mkdir('data')
if not os.path.exists('image'):
    os.mkdir('image')
if not os.path.exists('model'):
    os.mkdir('model')
#if not os.path.exists('log'):
#    os.mkdir('log')

#-------------------------------------------------------------------
# create the tfrecords
#-------------------------------------------------------------------
def create_tfrecords():
    global data_num
    # Skip the conversion if the tfrecord file already exists.
    if os.path.exists(tfrecords_path):
        return 0
    # The --data flag defaults to the string 'None', so check for both.
    if FLAGS.data in (None, 'None'):
        print('the data is none,use: python gan.py --data []')
        os._exit(0)
    writer_train = tf.python_io.TFRecordWriter(tfrecords_path)
    object_path = FLAGS.data
    total = os.listdir(object_path)
    num = len(total)
    num_i = 1
    value = 0
    print('-----------------------------making dataset tfrecord,waiting--------------------------')
    for index in total:
        img_path = os.path.join(object_path, index)
        # Force 3 channels so that every record holds dim*dim*3 bytes.
        img = Image.open(img_path).convert('RGB')
        img = img.resize((dim, dim))
        img_raw = img.tobytes()
        '''it is on my datasets, please change these codes!'''
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))
        writer_train.write(example.SerializeToString())  # serialize to a string
        sys.stdout.write('--------%.4f%%-----' % (num_i/float(num)*100))
        sys.stdout.write('\r')
        sys.stdout.flush()
        num_i = num_i + 1
    print('-------------------------------datasets has completed-----------------------------------')
    data_num = num  # number of images written (num_i would be one too many)
    writer_train.close()

#-------------------------------------------------------------------
# datatfrecords
#-------------------------------------------------------------------
def load_image(serialized_example):
    features = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw': tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example, features)
    image = tf.decode_raw(parsed_example['img_raw'], tf.uint8)
    image = tf.reshape(image, [-1, dim, dim, 3])
    # scale pixels to [0, 1]
    image = tf.cast(image, tf.float32)*(1./255)
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label, [-1, 1])
    return image, label

def dataset_tfrecords(tfrecords_path, use_keras_fit=True):
    # whether tf.keras fit is used
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    '''
    The list may hold several files [tfrecords_name1, tfrecords_name2, ...],
    for example collected with os.listdir(tfrecords_path)
    '''
    # note: shuffle must come before batch
    dataset = dataset\
        .repeat(epochs_data)\
        .shuffle(1000)\
        .batch(batch_size)\
        .map(load_image, num_parallel_calls=8)
    iter = dataset.make_initializable_iterator()  #make_one_shot_iterator()
    train_datas = iter.get_next()  # read the values with train_datas[0], train_datas[1]
    return train_datas, iter

#-------------------------------------------------------------------
# define resBlock
#-------------------------------------------------------------------
def convolutional2D(x, num_filters, kernel_size, resampling, strides=2):
    # Compare strings with ==; `is` tests object identity and is unreliable here.
    if resampling == 'up':
        x = keras.layers.UpSampling2D()(x)
        x = keras.layers.Conv2D(num_filters, kernel_size=kernel_size, strides=1, padding='same',
                                kernel_initializer=keras.initializers.RandomNormal())(x)
        #x = keras.layers.Conv2DTranspose(num_filters, kernel_size=kernel_size, strides=2, padding='same',
        #                                 kernel_initializer=keras.initializers.RandomNormal())(x)
    elif resampling == 'down':
        x = keras.layers.Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same',
                                kernel_initializer=keras.initializers.RandomNormal())(x)
    return x

def ResBlock(x, num_filters, resampling, strides=2):
    '''
    1. If the training set is small, lower the BN momentum to 0.9 or even 0.8
       (default 0.99), i.e. BatchNormalization(momentum=0.8).
       With a large training set the default 0.99 is fine.
    2. keras.layers.LeakyReLU() can replace relu so that the negative part keeps
       some gradient; the alpha parameter sets the negative slope, e.g. alpha=0.2.
       relu is closer to a biological neuron: after the convolution it maps the
       data to non-negative values, and the gradient is zero for negative inputs.
    '''
    #F1,F2,F3 = num_filters
    X_shortcut = x
    # main path: up- or downsample
    x = convolutional2D(x, num_filters, kernel_size=(3,3), resampling=resampling, strides=strides)
    # BN + activation
    x = keras.layers.BatchNormalization()(x)
    #x = keras.layers.Activation('relu')(x)
    x = keras.layers.LeakyReLU()(x)
    # conv2d
    x = keras.layers.Conv2D(num_filters, kernel_size=(3,3), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(x)
    # BN + activation
    x = keras.layers.BatchNormalization()(x)
    #x = keras.layers.Activation('relu')(x)
    x = keras.layers.LeakyReLU()(x)
    # conv2d
    x = keras.layers.Conv2D(num_filters, kernel_size=(3,3), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(x)
    # BN only; the activation comes after the add
    x = keras.layers.BatchNormalization()(x)
    # shortcut: 1x1 conv with the same resampling, then BN
    X_shortcut = convolutional2D(X_shortcut, num_filters, kernel_size=(1,1), resampling=resampling, strides=strides)
    X_shortcut = keras.layers.BatchNormalization()(X_shortcut)
    X_add = keras.layers.Add()([x, X_shortcut])
    #X_add = keras.layers.Activation('relu')(X_add)
    X_add = keras.layers.LeakyReLU()(X_add)
    return X_add

def IdentifyBlock(x, num_filters):
    #F1,F2,F3 = num_filters
    X_shortcut = x
    # conv2d
    x = keras.layers.Conv2D(num_filters//4, kernel_size=(1,1), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(x)
    # BN + relu
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation('relu')(x)
    # conv2d
    x = keras.layers.Conv2D(num_filters//4, kernel_size=(1,1), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(x)
    # BN + relu
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation('relu')(x)
    # conv2d
    x = keras.layers.Conv2D(num_filters, kernel_size=(1,1), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(x)
    # BN only; the activation comes after the add
    x = keras.layers.BatchNormalization()(x)
    # add the identity shortcut
    X_add = keras.layers.Add()([x, X_shortcut])
    X_add = keras.layers.Activation('relu')(X_add)
    return X_add

#-------------------------------------------------------------------
# define generator
#-------------------------------------------------------------------
def generate(resampling='up'):
    nosie = keras.layers.Input(shape=(noise_dim,))
    g = keras.layers.Dense(512*4*4)(nosie)
    g = keras.layers.Reshape((4,4,512))(g)
    # BN + activation
    g = keras.layers.BatchNormalization()(g)
    #g = keras.layers.Activation('relu')(g)
    g = keras.layers.LeakyReLU()(g)
    #4*4*512
    g = ResBlock(g, num_filters=512, resampling=resampling)
    #8*8*512
    g = ResBlock(g, num_filters=256, resampling=resampling)
    #16*16*256
    g = ResBlock(g, num_filters=128, resampling=resampling)
    #32*32*128
    g = ResBlock(g, num_filters=64, resampling=resampling)
    #64*64*64
    g = keras.layers.Conv2D(3, kernel_size=(3,3), strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal())(g)
    #64*64*3
    g_out = keras.layers.Activation('tanh')(g)
    g_model = keras.Model(nosie, g_out)
    return g_model

#-------------------------------------------------------------------
# define discriminator
#-------------------------------------------------------------------
def discriminator(resampling='down'):
    real_in = keras.layers.Input(shape=(dim, dim, 3))
    d = keras.layers.Conv2D(64, kernel_size=(3,3), padding='same', strides=1,
                            kernel_initializer=keras.initializers.RandomNormal())(real_in)
    # BN + activation
    d = keras.layers.BatchNormalization()(d)
    #d = keras.layers.Activation('relu')(d)
    d = keras.layers.LeakyReLU()(d)
    #64*64*64
    d = ResBlock(d, num_filters=128, resampling=resampling)
    #32*32*128
    d = ResBlock(d, num_filters=256, resampling=resampling)
    #16*16*256
    d = ResBlock(d, num_filters=512, resampling=resampling)
    #8*8*512
    d = ResBlock(d, num_filters=512, resampling=resampling)
    #4*4*512
    '''
    GlobalAveragePooling: it can replace the fully connected layer;
    you can use Dense instead to test the network
    '''
    d = keras.layers.GlobalAveragePooling2D()(d)
    d_out = keras.layers.Dense(1)(d)
    d_model = keras.Model(real_in, d_out)
    return d_model

#-------------------------------------------------------------------
# show process of train
#-------------------------------------------------------------------
def plot(history):
    # each row of history is [step, d_loss, distance, g_loss]
    history = np.array(history)
    plt.ion()
    plt.figure(figsize=(12,4))
    plt.title('Train History')
    plt.plot(history[:,0], history[:,1])
    plt.ylabel('loss')
    plt.plot(history[:,0], history[:,2])
    plt.plot(history[:,0], history[:,3])
    plt.xlabel('step')
    plt.legend(['d_loss','distance','g_loss'], loc='upper left')
    plt.savefig(os.path.join(model_path,'history.png'))
    plt.pause(1)
    plt.close()

def main():
    #------------------------------
    # define the generate model
    #------------------------------
    generate_model = generate()
    #--------------------------------
    # define the discriminator model
    #--------------------------------
    discriminator_model = discriminator()
    #print the networks
    discriminator_model.summary()
    generate_model.summary()
    #-------------------------------------------------------------------
    # train the Discriminator
    #-------------------------------------------------------------------
    '''define new Input layers here rather than reusing the earlier ones'''
    #Input para
    Dx_real_img = keras.layers.Input(shape=(dim, dim, 3))
    Dz_noise = keras.layers.Input(shape=(noise_dim,))
    D_uniform = keras.layers.Input(shape=(1,1,1))
    #set the trainable
    generate_model.trainable = False
    #get the score
    D_fake_img = generate_model(Dz_noise)
    D_fake_score = discriminator_model(D_fake_img)
    D_real_score = discriminator_model(Dx_real_img)
    #train net
    gan_train_d = keras.Model([Dx_real_img, Dz_noise, D_uniform], [D_real_score, D_fake_score])
    #set the loss function according to the algorithm
    k = 2
    p = 6
    u = D_uniform
    #then, build a new input mixed from fake and real
    x_ = (1.-u)*Dx_real_img + u*D_fake_img
    #-------------------------------------------------------------------
    # wgan div loss function
    # n_critic = 1
    # arxiv.org/pdf/1712.01026.pdf
    #-------------------------------------------------------------------
    if FLAGS.type == 'div':
        gradients = K.gradients(discriminator_model(x_), [x_])[0]
        grad_norm = K.sqrt(K.sum(gradients ** 2, axis=[1, 2, 3])) ** p
        grad_penalty = k * K.mean(grad_norm)
        discriminator_loss = K.mean(D_real_score - D_fake_score)
    #-------------------------------------------------------------------
    # wgan gp loss function
    # n_critic = 5
    # arxiv.org/pdf/1704.00028.pdf
    #-------------------------------------------------------------------
    if FLAGS.type == 'gp':
        gradients = K.gradients(discriminator_model(x_), [x_])[0]
        grad_norm = K.sqrt(K.sum(gradients ** 2, axis=[1, 2, 3]))
        grad_norm = K.square(1-grad_norm)
        grad_penalty = 10*K.mean(grad_norm)
        discriminator_loss = K.mean(D_fake_score-D_real_score)
    #loss function
    discriminator_loss_all = grad_penalty + discriminator_loss
    #compile the model
    gan_train_d.add_loss(discriminator_loss_all)  #min
    gan_train_d.compile(optimizer=keras.optimizers.Adam(learning_rate, 0.5))
    gan_train_d.metrics_names.append('DistanceFromRealAndFake')
    gan_train_d.metrics_tensors.append(-discriminator_loss)  #max
    #-------------------------------------------------------------------
    # train the Generator
    #-------------------------------------------------------------------
    #Input para
    Gz_nosie = keras.layers.Input(shape=(noise_dim,))
    #set the trainable
    discriminator_model.trainable = False
    generate_model.trainable = True
    #get the score
    G_fake_img = generate_model(Gz_nosie)
    G_fake_score = discriminator_model(G_fake_img)
    #train net
    gan_train_g = keras.Model(Gz_nosie, G_fake_score)
    #loss function
    if FLAGS.type == 'div':
        generate_loss = K.mean(G_fake_score)
    if FLAGS.type == 'gp':
        generate_loss = -K.mean(G_fake_score)  #min this value
    #compile the model
    gan_train_g.add_loss(generate_loss)  #min
    gan_train_g.compile(optimizer=keras.optimizers.Adam(learning_rate, 0.5))
    #---------------------------------------------------------------------
    #print the networks
    gan_train_d.summary()
    gan_train_g.summary()
    #create the session, get the dataset from tfrecords
    sess = tf.Session()
    train_datas, iter = dataset_tfrecords(tfrecords_path, use_keras_fit=False)
    sess.run(iter.initializer)
    print("-----------------------------------------start---------------------------------------")
    #continue from a previous run if saved weights exist
    if os.path.exists(os.path.join(model_path,'gan.weights')):
        gan_train_g.load_weights(os.path.join(model_path,'gan.weights'))
        if os.path.exists(os.path.join(model_path,'history.npy')):
            history = np.load(os.path.join(model_path,'./history.npy'), allow_pickle=True).tolist()
            #read the last record with index -1, and the first with index 0
            last_iter = int(history[-1][0])
            print('Find the npy file, the last save iter:%d' % (last_iter))
        else:
            history = []
            last_iter = -1
    else:
        print('There is no .npy file, creating a new file---------')
        history = []
        last_iter = -1
    #state the global vars
    #you can change them in this function body, so that it makes the training stable
    global n_critic
    global n_generate
    #the loop body
    for step in range(last_iter+1, int(epochs*data_num/batch_size+1)):
        try:
            #get the time
            start_time = time.time()
            #datasets
            train_datas_ = sess.run(train_datas)
            '''if the last batch is smaller than batch_size, restart the iterator'''
            if train_datas_[0].shape[0] != batch_size:
                sess.run(iter.initializer)
                train_datas_ = sess.run(train_datas)
            z_noise = np.random.normal(size=(batch_size, noise_dim))
            u_niform = np.random.uniform(low=0.0, high=1.0, size=(batch_size,1,1,1))
            #-----------------------------------------
            # phase 1 - training the discriminator
            #-----------------------------------------
            for step_critic in range(n_critic):
                d_loss, distance = gan_train_d.train_on_batch([train_datas_[0], z_noise, u_niform], None)
            #-----------------------------------------
            # phase 2 - training the generator
            #-----------------------------------------
            for step_generate in range(n_generate):
                g_loss = gan_train_g.train_on_batch(z_noise, None)
            #get the time
            duration = time.time()-start_time
            #-----------------------------------------
            # print the loss
            #-----------------------------------------
            if step % 5 == 0:
                print("The step is %s,d_loss:%s,distance:%s,g_loss:%s, " % (step, d_loss, distance, g_loss), end=' ')
                print('%.2f s/step' % (duration))
            #-----------------------------------------
            # plot the train history
            #-----------------------------------------
            if step % 5 == 0:
                history.append([step, d_loss, distance, g_loss])
            #-----------------------------------------
            # save the model_weights
            #-----------------------------------------
            if step % save_step == 0 and step != 0:
                # save the train steps
                np.save(os.path.join(model_path,'./history.npy'), history)
                gan_train_g.save_weights(os.path.join(model_path,'gan.weights'))
                plot(history)
            #-----------------------------------------
            # save the image of generate
            #-----------------------------------------
            if step % 50 == 0 and step != 0:
                noise_test = np.random.normal(size=[1, noise_dim]).astype(np.float32)
                fake_image = generate_model.predict(noise_test, steps=1)
                '''
                Restore the image:
                1. after multiplying by 255 the values must be cast to uint8
                2. keeping float32 values in [0,1] also works for direct output
                '''
                arr_img = np.array([fake_image], np.float32).reshape([dim,dim,3])*255
                # tanh can output negative values; clip before the uint8 cast to avoid wrap-around
                arr_img = np.clip(arr_img, 0, 255).astype(np.uint8)
                #the tfrecords were written with PIL.Image (RGB), so convert to BGR for cv2
                arr_img = cv2.cvtColor(arr_img, cv2.COLOR_RGB2BGR)
                cv2.imwrite(save_path+str(step)+'.jpg', arr_img)
                #cv2.imshow('fake image',arr_img)
                #cv2.waitKey(1500)  #show the fake image 1.5s
                #cv2.destroyAllWindows()
        except tf.errors.OutOfRangeError:
            sess.run(iter.initializer)
    plot(history)
    #summary_writer.close()

create_tfrecords()
main()
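5.2. Training progress with WGAN-GP (file names indicate the training step):
5.3. Training progress with WGAN-div (using LeakyReLU):
I only trained for a little over 30,000 steps and stopped there.
Training history (keep distance close to 0):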
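6. Summary
Training a GAN well is hard, because the GAN has to stay stable: neither the generator nor the discriminator may become too strong. The methods available today mainly address training stability, vanishing gradients and mode collapse. In my view, methods that modify the loss function, such as WGAN-GP and WGAN-div, actually matter less than tuning (hyperparameters and network architecture). Still, for most GANs I think the following points deserve attention:
1. The discriminator should dominate and be slightly stronger than the generator;
2. The number of training steps and batch_size also affect the final quality of the generator;
3. The two learning rates do not have to be equal. Equal learning rates do not guarantee that the generator and the discriminator update in a synchronized, stable way; try different learning rates when necessary;
4. The final goal of optimizing the objective is to drive the 'distance' between generated and real data (distance in a broad, loose sense) toward 0. My code exposes this value as a metric so it can be watched during training. If distance moves further and further away from 0, the discriminator is too strong or the generator is too weak, and the parameters need to be adjusted and training restarted;
5. In general the discriminator is easier to train and the generator is harder to tune. This is why, for example in DCGAN, the discriminator loss often reaches 0 first and stays there. In that case, lower the discriminator's learning rate or shrink the discriminator;
6. For vanishing gradients, use one of the WGAN-based algorithms, ResNet-style blocks, the LeakyReLU activation, and so on;
7. A model built with Keras reports its parameter count. For 64×64 images, for example, more than a million parameters are generally needed. In general, the more parameters and the deeper the network, the stronger its fitting capacity, so for a discriminator and a generator with similar structure the discriminator should have slightly more parameters;
8. The dataset also affects training. Datasets differ in how varied their features are. If the 'feature distance' within a dataset is small and the features overlap heavily (face data, for example), the trained generator works better. If the 'feature distance' is large, the task is a big challenge for the discriminator because the data is scattered, and the generated images are sometimes disappointing. When designing the discriminator, take the spread of the dataset into account and deepen or widen the network to strengthen it;
9. Read more papers! GAN research has made great progress. The teacher of our graduate deep learning course is from the School of Intelligence and Computing and works mainly on GANs and computer vision; he told us a lot about his own major progress in the GAN field (a real expert, after all he and his students have had their photo taken with Yann LeCun, haha!). GAN research is still very active. After several years of development there are many GAN architectures and algorithms with good results, and I think the direction is worth pursuing. My own field is medical surgical robots, so GANs could for example be applied to medical imaging: generation, fusion, segmentation and so on.
I am only a mediocre student, so take this personal summary as a reference only.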
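Point 4 above can be turned into a simple automatic check on the logged history. The sketch below is one possible heuristic, not something gan.py does: `distance_is_drifting` is a hypothetical helper, the window size is an arbitrary choice, and the two histories are synthetic numbers made up for the demonstration. It assumes rows of the form [step, d_loss, distance, g_loss], which is how the training loop logs them.

```python
def distance_is_drifting(history, window=20, tolerance=0.0):
    """Return True if the mean |distance| of the latest window exceeds
    the mean of the window before it by more than `tolerance`.

    `history` holds rows [step, d_loss, distance, g_loss].
    """
    distances = [abs(row[2]) for row in history]
    if len(distances) < 2 * window:
        return False  # not enough data to compare two windows
    previous = sum(distances[-2 * window:-window]) / window
    latest = sum(distances[-window:]) / window
    return latest > previous + tolerance

# Synthetic logs, invented for illustration only.
converging = [[s, 0.0, 5.0 / (s + 1), 0.0] for s in range(100)]
diverging = [[s, 0.0, 0.1 * s, 0.0] for s in range(100)]

print(distance_is_drifting(converging), distance_is_drifting(diverging))
```

Such a check could be called next to plot(history) at every save step, for example to lower the discriminator's learning rate or stop the run once the distance keeps growing.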
A medical image fusion model combining transfer learning and GANs
References (code and papers):
https://github.com/ABaoccy/wgan-div/blob/master/wgan_div.py
https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py
https://github.com/bojone/gan/blob/master/keras/wgan_div_celeba.py
https://github.com/eriklindernoren/Keras-GAN/blob/master/wgan/wgan.py
1. Deep Residual Learning for Image Recognition
2. Wasserstein Divergence for GANs
3. Wasserstein GAN
4. Improved Training of Wasserstein GANs
5. Deep Residual Shrinkage Networks for Fault Diagnosis
6. Getting Started with Keras and Building a Residual Network (blog post)