1 生成器判别器实现

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class Generator(keras.Model):
    """DCGAN-style generator: z [b, 100] -> image [b, 64, 64, 3] in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
        self.fc = layers.Dense(3 * 3 * 512)
        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    def call(self, inputs, training=None):
        # [b, 100] => [b, 3*3*512] => [b, 3, 3, 512]
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        x = tf.nn.leaky_relu(x)
        # BUGFIX: this stage was commented out, which made the generator
        # produce [b, 28, 28, 3] while the discriminator (and the comment
        # above) expect [b, 64, 64, 3]. With all three transposed convs the
        # 'valid' output sizes are 3 -> 9 -> 21 -> 64.
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = self.conv3(x)
        # tanh maps the output into [-1, 1], matching the dataset scaling.
        x = tf.tanh(x)
        return x


class Discriminator(keras.Model):
    """DCGAN-style critic: image [b, 64, 64, 3] -> unbounded score [b, 1]."""

    def __init__(self):
        super(Discriminator, self).__init__()
        # [b, 64, 64, 3] => [b, 1]
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()
        # [b, h, w, c] => [b, -1]
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        # [b, -1] => [b, 1]; raw logits (no sigmoid), as WGAN requires.
        logits = self.fc(x)
        return logits


def main():
    # Smoke test: one forward pass through each network.
    d = Discriminator()
    g = Generator()

    x = tf.random.normal([2, 64, 64, 3])
    z = tf.random.normal([2, 100])

    prob = d(x)
    print(prob)
    x_hat = g(z)
    print(x_hat.shape)


if __name__ == '__main__':
    main()

2 loss function

重点改变

''' WGAN gradient_penalty '''
def gradient_penalty(discriminator, batch_x, fake_image):
    """WGAN-GP term: mean of (||grad_x D(x_hat)||_2 - 1)^2 evaluated at
    random per-sample mixes of real and generated images."""
    n = batch_x.shape[0]

    # One mixing coefficient per sample, broadcast across [b, h, w, c].
    alpha = tf.random.uniform([n, 1, 1, 1])
    alpha = tf.broadcast_to(alpha, batch_x.shape)
    mixed = alpha * batch_x + (1 - alpha) * fake_image

    # Differentiate the critic's score w.r.t. the interpolated images.
    with tf.GradientTape() as tape:
        tape.watch([mixed])
        critic_logits = discriminator(mixed, training=True)
    grad = tape.gradient(critic_logits, mixed)

    # [b, h, w, c] => [b, -1]; per-sample L2 norm, then squared deviation from 1.
    flat = tf.reshape(grad, [grad.shape[0], -1])
    norms = tf.norm(flat, axis=1)  # [b]
    return tf.reduce_mean((norms - 1) ** 2)
def celoss_ones(logits):
    # WGAN objective for samples that should score high: minimize -mean(D(x)).
    # (The "celoss" name is historical; no cross-entropy is computed.)
    return -tf.reduce_mean(logits)


def celoss_zeros(logits):
    # WGAN objective for samples that should score low: minimize mean(D(G(z))).
    return tf.reduce_mean(logits)


def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """Critic loss: -D(real) + D(fake) + 10 * gradient penalty.

    Returns (loss, gp) so the caller can log the penalty separately.
    """
    # Score generated images as "fake" and real images as "real".
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)

    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)
    gp = gradient_penalty(discriminator, batch_x, fake_image)

    # 10 is the gradient-penalty weight (lambda) from the WGAN-GP paper.
    loss = d_loss_real + d_loss_fake + 10. * gp
    return loss, gp


def g_loss_fn(generator, discriminator, batch_z, is_training):
    """Generator loss: make the critic score generated images highly."""
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    return celoss_ones(d_fake_logits)

3 WGAN 完整代码

import os

# Silence TensorFlow's C++ INFO/WARNING logs; must be set before importing tf.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import glob

import numpy as np
import tensorflow as tf
from tensorflow import keras
# BUGFIX: the keras and PIL imports were fused onto one line
# ("...import kerasfrom    PIL import Image"), which is a SyntaxError.
from PIL import Image
from .gan import Generator, Discriminator
from .dataset import make_anime_dataset


def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of generated images into a grid and save it as one PNG.

    val_out: float array in [-1, 1], shape [b, h, w, c]; the grid is
    val_block_size images wide. color_mode is kept for API compatibility
    but is not used.
    """
    def preprocess(img):
        # [-1, 1] float -> [0, 255] uint8 for PIL.
        return ((img + 1.0) * 127.5).astype(np.uint8)

    preprocesed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # Concatenate images horizontally into the current row.
        if single_row.size == 0:
            single_row = preprocesed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)

        # Row complete: stack it under the grid and start a new row.
        if (b + 1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)
            single_row = np.array([])

    if final_image.shape[2] == 1:
        # Grayscale: drop the channel axis so PIL accepts the array.
        final_image = np.squeeze(final_image, axis=2)
    # BUGFIX: in the collapsed source this save call was fused onto the
    # squeeze line; it belongs at function level so every image is saved.
    Image.fromarray(final_image).save(image_path)


def celoss_ones(logits):
    # WGAN objective for samples that should score high: minimize -mean(D(x)).
    # (The "celoss" name is historical; no cross-entropy is computed.)
    return -tf.reduce_mean(logits)


def celoss_zeros(logits):
    # WGAN objective for samples that should score low: minimize mean(D(G(z))).
    return tf.reduce_mean(logits)


def gradient_penalty(discriminator, batch_x, fake_image):
    """WGAN-GP term: mean of (||grad_x D(x_hat)||_2 - 1)^2 at random
    per-sample mixes of real and generated images."""
    batchsz = batch_x.shape[0]
    # One mixing coefficient per sample, broadcast across [b, h, w, c].
    t = tf.random.uniform([batchsz, 1, 1, 1])
    t = tf.broadcast_to(t, batch_x.shape)
    interplate = t * batch_x + (1 - t) * fake_image

    with tf.GradientTape() as tape:
        tape.watch([interplate])
        d_interplote_logits = discriminator(interplate, training=True)
    grads = tape.gradient(d_interplote_logits, interplate)

    # grads: [b, h, w, c] => [b, -1]; per-sample L2 norm.
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1)  # [b]
    gp = tf.reduce_mean((gp - 1) ** 2)
    return gp


def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """Critic loss: -D(real) + D(fake) + 10 * gradient penalty."""
    # Score generated images as "fake" and real images as "real".
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)

    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)
    gp = gradient_penalty(discriminator, batch_x, fake_image)

    loss = d_loss_real + d_loss_fake + 10. * gp
    return loss, gp


def g_loss_fn(generator, discriminator, batch_z, is_training):
    """Generator loss: make the critic score generated images highly."""
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    return celoss_ones(d_fake_logits)


def main():
    tf.random.set_seed(233)
    np.random.seed(233)
    assert tf.__version__.startswith('2.')

    # Hyper parameters.
    z_dim = 100
    epochs = 3000000
    batch_size = 512
    learning_rate = 0.0005
    is_training = True

    # NOTE(review): hard-coded dataset location — adjust to your machine.
    img_path = glob.glob(r'C:\Users\Jackie\Downloads\faces\*.jpg')
    assert len(img_path) > 0

    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    # beta_1=0.5 is the customary Adam setting for GAN training.
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # BUGFIX: create the output directory up front; saving previously crashed
    # with FileNotFoundError when 'images' did not exist.
    os.makedirs('images', exist_ok=True)

    for epoch in range(epochs):
        # Train the critic 5 times per generator step (WGAN schedule).
        for _ in range(5):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)
            with tf.GradientTape() as tape:
                d_loss, gp = d_loss_fn(generator, discriminator,
                                       batch_z, batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # One generator step.
        batch_z = tf.random.normal([batch_size, z_dim])
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss),
                  'g-loss:', float(g_loss), 'gp:', float(gp))
            # Save a 10x10 grid of samples. Use a distinct variable name —
            # the original reused `img_path`, clobbering the dataset file list.
            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            save_path = os.path.join('images', 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, save_path, color_mode='P')


if __name__ == '__main__':
    main()
  • 数据集加载与转换 dataset.py
import multiprocessing

import tensorflow as tf


def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    """Build a tf.data pipeline over anime-face JPEGs.

    Returns (dataset, img_shape, len_dataset): batches of float images
    scaled to [-1, 1] with shape (resize, resize, 3), plus the number of
    full batches per pass.
    """
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        # [0, 255] -> [-1, 1], matching the generator's tanh range.
        return img / 127.5 - 1

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size
    return dataset, img_shape, len_dataset


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Apply shuffle / filter / map / batch / repeat / prefetch to a dataset."""
    # Defaults: one map thread per CPU, and a shuffle buffer of at least 2048.
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)

    # Shuffle before map/filter — those stages can be costly per element.
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
    else:
        # Mapping first is slower, but supported for filters that need
        # the mapped representation.
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
        if filter_fn:
            dataset = dataset.filter(filter_fn)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)
    return dataset


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of in-memory data.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
    """
    dataset = tf.data.Dataset.from_tensor_slices(memory_data)
    return batch_dataset(dataset,
                         batch_size,
                         drop_remainder=drop_remainder,
                         n_prefetch_batch=n_prefetch_batch,
                         filter_fn=filter_fn,
                         map_fn=map_fn,
                         n_map_threads=n_map_threads,
                         filter_after_map=filter_after_map,
                         shuffle=shuffle,
                         shuffle_buffer_size=shuffle_buffer_size,
                         repeat=repeat)


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of on-disk PNG/JPEG images.

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists
    """
    memory_data = img_paths if labels is None else (img_paths, labels)

    def parse_fn(path, *label):
        # Decode with channels fixed to 3 so grayscale files still yield RGB.
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=3)
        return (img,) + label

    if map_fn:
        # Fuse decode and user map into a single pipeline stage.
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    return memory_data_batch_dataset(memory_data,
                                     batch_size,
                                     drop_remainder=drop_remainder,
                                     n_prefetch_batch=n_prefetch_batch,
                                     filter_fn=filter_fn,
                                     map_fn=map_fn_,
                                     n_map_threads=n_map_threads,
                                     filter_after_map=filter_after_map,
                                     shuffle=shuffle,
                                     shuffle_buffer_size=shuffle_buffer_size,
                                     repeat=repeat)

tensorflow WGAN 实现相关推荐

  1. 生成对抗式网络 GAN及其衍生CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理介绍、应用介绍及简单Tensorflow实现

    生成式对抗网络(GAN,Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.学界大牛Yann Lecun 曾说,令他最激 ...

  2. 对抗生成网络学习(四)——WGAN+爬虫生成皮卡丘图像(tensorflow实现)

    一.背景 WGAN的全称为Wasserstein GAN, 是Martin Arjovsky等人于17年1月份提出的一个模型,该文章可以参考[1].WGAN针对GAN存在的问题进行了有针对性的改进,但 ...

  3. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  4. 手把手教你在Tensorflow实现BEGAN 达到惊人的人脸图像生成效果

    全球人工智能 文章来源:GitHub 作者:Heumi 翻译:马卓奇 文章投稿:news@top25.cn 相关文章: 导读:本文是基于谷歌大脑(Google Brain)发表在 arXiv 的最新论 ...

  5. 收敛速度更快更稳定的Wasserstein GAN(WGAN)

    生成对抗网络(GANs)是一种很有力的生成模型,它解决生成建模问题的方式就像在两个对抗式网络中进行比赛:给出一些噪声源,生成器网络能够产生合成的数据,鉴别器网络在真实数据和生成器的输出中进行鉴别.GA ...

  6. WGAN-GP与GAN及WGAN的比较

    WGAN-GP 转载自:https://www.e-learn.cn/content/qita/814071 from datetime import datetime import os impor ...

  7. PaperNotes(6)-GAN/DCGAN/WGAN/WGAN-GP/WGAN-SN-网络结构/实验效果

    GAN模型网络结构+实验效果演化 1.GAN 1.1网络结构 1.2实验结果 2.DCGAN 2.1网络结构 2.2实验结果 3.WGAN 3.1网络结构 3.2实验结果 4.WGAN-GP 4.1网 ...

  8. SAGAN生成更为精细的人脸图像(tensorflow实现)

    一.背景 SAGAN全称为Self-Attention Generative Adversarial Networks,是由Han Zhang等人[1]于18年5月提出的一种模型.文章中作者解释到,传 ...

  9. Tensorflow GAN对抗生成网络实战

    这一节的回顾主要针对使用JS散度得DCGAN和基于GP理论和Wasserstein Distance理论的WGAN首先是DCGAN 我们的训练数据集是一堆这种二次元的动漫头像的图片,那么我们就是要训练 ...

最新文章

  1. 《大规模web服务开发技术》阅读笔记
  2. Node.js Web开发框架
  3. iOS - Analyze 静态分析
  4. c语言调用串口扫码枪,C#利用控件mscomm32.ocx读取串口datalogic扫描枪数据
  5. 盘点那些年用过的机械键盘,为什么我最爱Keychron键盘呢
  6. 每天一种设计模式之抽象工厂模式(Java实现)
  7. 局域网php服务器搭建,php局域网服务器搭建
  8. html页面设置过期时间,meta标签http-equiv=Expires属性写法及用法
  9. 一步步教你使用云端服务器yeelink远程监控
  10. Calcite-学习笔记(入门篇)
  11. 【Python课程作业】食物数据的爬取及分析(详细介绍及分析)
  12. mysql 5.7.24-winx64_mysql-5.7.24-winx64下载与安装
  13. 英语语音中的调核例子_英语调核研究.pdf
  14. CUDA+VS2017+win环境下 cuda工程环境搭建(解决标识符未定义或申明)
  15. 【可见光室内定位】(一)概览
  16. APP性能测试--内存测试
  17. 及时总结工作中的经验是个人成长的关键
  18. 创业的捷径!打造黄金人脉!
  19. martin fowler_Martin Kleppmann的大型访谈:“弄清楚分布式数据系统的未来”
  20. 迅为恩智浦i.MX8MM开发平台虚拟机安装Ubuntu16.04系统

热门文章

  1. 自动更新开奖数据的excel文件,供大家下载
  2. 常见接口状态码状态码
  3. 快速 Building ONL 网络操作系统 X86 平台image
  4. 用数据说话,BCH众多指标已经碾压LTC
  5. Lucene(全文检索)入门
  6. 为了懒,我痛心学习Ajax
  7. 扫码盒获取微信支付宝付款码等信息的前端处理
  8. 【Spring学习34】Spring事务(4):事务属性之7种传播行为
  9. C# 访问修饰符含义与注意事项
  10. 拆分又遇变数,传赛门铁克或将出售VERITAS,这又是挖的什么坑?