GANs is short for Generative Adversarial Networks. The model first appeared in Goodfellow's 2014 paper of the same name: Generative Adversarial Networks. GANs are currently among the hottest model families in deep learning; Apple's recently published first research paper, SimGAN, is also about GANs.

Put simply, a GAN learns to generate data that resembles its training data; the most typical application is image generation. Suppose you have a pile of cat pictures: train a GAN on them and it will generate new cat pictures similar to the training data (it has learned what cats look like).

GANs combine two kinds of machine-learning model: a generative model and a discriminative model.

An analogy: suppose G is a master forger who makes a living producing fake antiques, and G's ultimate goal is to pass fakes off as genuine. Meanwhile, other people make their living authenticating antiques (D). At first you show D some genuine antiques and tell D they are authentic. Then G starts producing forgeries, hoping to fool D into being unable to tell real from fake. The more genuine pieces D has seen, the harder it becomes for G to fool him; G, of course, is no slouch either, and redoubles its efforts to trick D. As this contest continues, not only does D's eye for authenticity improve, G also gets better and better at forgery. That is the "generative adversarial" in the name.
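The forger-vs-appraiser game above can be sketched as alternating gradient steps. The toy below is an illustration only (a linear "forger" on 1-D data, a logistic "appraiser", made-up learning rate), not the architecture used later in this post:

```python
import numpy as np

rng = np.random.default_rng(0)

def appraiser(x, w, b):
    # D: probability that sample x is genuine
    return 1.0 / (1.0 + np.exp(-(w * x + b)))

def forger(z, theta):
    # G: shifts noise z toward the genuine distribution
    return z + theta

real_mean = 3.0          # genuine antiques: samples from N(3, 1)
w, b, theta = 1.0, 0.0, 0.0
lr = 0.05

for step in range(2000):
    fake = forger(rng.normal(0.0, 1.0, size=64), theta)
    x = rng.normal(real_mean, 1.0, size=64)

    # D step: ascend E[log D(real)] + E[log(1 - D(fake))]
    d_real = appraiser(x, w, b)
    d_fake = appraiser(fake, w, b)
    w += lr * (np.mean((1 - d_real) * x) - np.mean(d_fake * fake))
    b += lr * (np.mean(1 - d_real) - np.mean(d_fake))

    # G step: ascend E[log D(fake)] (the non-saturating generator objective)
    fake = forger(rng.normal(0.0, 1.0, size=64), theta)
    d_fake = appraiser(fake, w, b)
    theta += lr * np.mean(1 - d_fake) * w

# theta drifts toward real_mean: the forger has learned where the data lives
```

As D and G push against each other, theta settles near the real data's mean, at which point D can no longer tell the two apart.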

A discriminative model judges which class a piece of data belongs to; for example, the CNN trained in <TensorFlow练习23: 恶作剧> can judge whether a given face is mine. A generative model, by contrast, needs no predefined classes: it generates new samples that best match the distribution of the training data. A Gaussian mixture model, for instance, once trained, generates random data that follows the training distribution.
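Sampling from a Gaussian mixture makes the "generative" idea concrete. A minimal sketch (the weights, means, and standard deviations below are made up, standing in for a fitted model):

```python
import numpy as np

rng = np.random.default_rng(1)

# parameters a trained GMM would have estimated (hypothetical values)
weights = np.array([0.3, 0.7])
means = np.array([-2.0, 4.0])
stds = np.array([0.5, 1.0])

def sample_gmm(n):
    # pick a component for each sample, then draw from that Gaussian
    comp = rng.choice(len(weights), size=n, p=weights)
    return rng.normal(means[comp], stds[comp])

samples = sample_gmm(10000)
# the sample mean approaches the mixture mean 0.3*(-2) + 0.7*4 = 2.2
```

The samples are new data points, yet statistically they look like whatever the mixture was fitted to — exactly the property a GAN's generator aims for, with a neural network replacing the mixture.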

A simple diagram of a GAN (figure not reproduced here).

Related GAN implementations:

  • DCGAN in TensorFlow
  • Generating images from text descriptions (the reverse of image captioning: Show and Tell)
  • Image completion (inpainting)
  • TF-VAE-GAN-DRAW
  • Auxiliary Classifier GAN
  • InfoGAN for time-series data classification
  • Video generation
  • Generative Models (OpenAI)

A TensorFlow code example: generating celebrity faces with EBGAN

Dataset: the Large-scale CelebFaces Attributes (CelebA) Dataset. It contains roughly 200,000 celebrity face images and can also be used for tasks such as face detection and facial-attribute recognition.
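The script below feeds the network pixels scaled into [-1, 1] (matching the generator's tanh output) and maps generated images back to uint8 for saving. That round trip in isolation, with a tiny hypothetical image:

```python
import numpy as np

# forward: uint8 pixels [0, 255] -> floats in [-1, 1], matching tanh's range
img = np.array([[0, 127, 255]], dtype=np.uint8)
x = img.astype('float32') / 127.5 - 1.0

# inverse: generator output back to displayable uint8 (round before casting,
# since a bare cast truncates and can be off by one)
restored = np.clip(np.round((x + 1.0) * 127.5), 0, 255).astype(np.uint8)
```

The clip matters because a generator's tanh output can sit exactly at ±1, and floating-point error would otherwise push a pixel just outside the valid range.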

Download: Google Drive or Baidu Cloud.

  • Energy Based Generative Adversarial Networks (EBGAN)
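In EBGAN the discriminator is an autoencoder, and its reconstruction error serves as an energy: low for samples it models well, high otherwise. Before the full script, here is a minimal numpy sketch of the paper's margin losses; the "decoded" tensors are hypothetical stand-ins (noisy copies) for what the autoencoder would output:

```python
import numpy as np

rng = np.random.default_rng(2)
batch_size = 64
real = rng.normal(size=(batch_size, 8))
fake = rng.normal(size=(batch_size, 8))
# stand-in reconstructions, chosen so fakes reconstruct worse than real data
real_decoded = real + 0.1 * rng.normal(size=real.shape)
fake_decoded = fake + 2.0 * rng.normal(size=fake.shape)

def energy(decoded, x):
    # Frobenius norm of the reconstruction residual, averaged over the batch
    # (numpy analogue of tf.sqrt(2 * tf.nn.l2_loss(decoded - x)) / batch_size)
    return np.sqrt(np.sum((decoded - x) ** 2)) / batch_size

margin = 20.0
real_loss = energy(real_decoded, real)
fake_loss = energy(fake_decoded, fake)

# D pushes real energy down and fake energy up to the margin;
# G pushes fake energy down
D_loss = real_loss + max(margin - fake_loss, 0.0)
G_loss = fake_loss
```

Once the fake energy exceeds the margin, the hinge term goes to zero and D stops wasting effort pushing it higher, which is what stabilizes training relative to an unbounded loss.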

The code:

# -*- coding: utf-8 -*-
"""
Energy Based Generative Adversarial Networks (EBGAN): https://arxiv.org/pdf/1609.03126v2.pdf
<blog.topspeedsnail.com>
Since upgrading Python to 3.6 broke my development environment, this uses Python 2.7 for now.
Note: written against TensorFlow 0.x; tf.pack was renamed tf.stack in TF >= 1.0.
"""
import os
import random

import numpy as np
import scipy.misc as misc
import tensorflow as tf
from PIL import Image

CELEBA_DATA_DIR = 'img_align_celeba'

train_images = []
for image_filename in os.listdir(CELEBA_DATA_DIR):
    if image_filename.endswith('.jpg'):
        train_images.append(os.path.join(CELEBA_DATA_DIR, image_filename))
random.shuffle(train_images)

batch_size = 64
num_batch = len(train_images) // batch_size

# image size and number of channels
IMAGE_SIZE = 64
IMAGE_CHANNEL = 3
def get_next_batch(pointer):
    image_batch = []
    images = train_images[pointer * batch_size:(pointer + 1) * batch_size]
    for img in images:
        arr = Image.open(img)
        arr = arr.resize((IMAGE_SIZE, IMAGE_SIZE))
        # scale uint8 pixels [0, 255] to [-1, 1], matching the generator's tanh range
        arr = np.array(arr).astype('float32') / 127.5 - 1
        image_batch.append(arr)
    return image_batch
# noise vector fed to the generator
z_dim = 100
noise = tf.placeholder(tf.float32, [None, z_dim], name='noise')
X = tf.placeholder(tf.float32, [batch_size, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNEL], name='X')
# whether we are in the training phase (selects the batch-norm statistics)
train_phase = tf.placeholder(tf.bool)
# batch normalization, following http://stackoverflow.com/a/34634291/2267819
def batch_norm(x, beta, gamma, phase_train, scope='bn', decay=0.9, eps=1e-5):
    with tf.variable_scope(scope):
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # batch statistics during training, moving averages at inference time
        mean, var = tf.cond(phase_train, mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
    return normed
# variable reuse gave me some trouble, so explicit dicts are used for now
generator_variables_dict = {
    "W_1": tf.Variable(tf.truncated_normal([z_dim, 2 * IMAGE_SIZE * IMAGE_SIZE], stddev=0.02), name='Generator/W_1'),
    "b_1": tf.Variable(tf.constant(0.0, shape=[2 * IMAGE_SIZE * IMAGE_SIZE]), name='Generator/b_1'),
    'beta_1': tf.Variable(tf.constant(0.0, shape=[512]), name='Generator/beta_1'),
    'gamma_1': tf.Variable(tf.random_normal(shape=[512], mean=1.0, stddev=0.02), name='Generator/gamma_1'),
    "W_2": tf.Variable(tf.truncated_normal([5, 5, 256, 512], stddev=0.02), name='Generator/W_2'),
    "b_2": tf.Variable(tf.constant(0.0, shape=[256]), name='Generator/b_2'),
    'beta_2': tf.Variable(tf.constant(0.0, shape=[256]), name='Generator/beta_2'),
    'gamma_2': tf.Variable(tf.random_normal(shape=[256], mean=1.0, stddev=0.02), name='Generator/gamma_2'),
    "W_3": tf.Variable(tf.truncated_normal([5, 5, 128, 256], stddev=0.02), name='Generator/W_3'),
    "b_3": tf.Variable(tf.constant(0.0, shape=[128]), name='Generator/b_3'),
    'beta_3': tf.Variable(tf.constant(0.0, shape=[128]), name='Generator/beta_3'),
    'gamma_3': tf.Variable(tf.random_normal(shape=[128], mean=1.0, stddev=0.02), name='Generator/gamma_3'),
    "W_4": tf.Variable(tf.truncated_normal([5, 5, 64, 128], stddev=0.02), name='Generator/W_4'),
    "b_4": tf.Variable(tf.constant(0.0, shape=[64]), name='Generator/b_4'),
    'beta_4': tf.Variable(tf.constant(0.0, shape=[64]), name='Generator/beta_4'),
    'gamma_4': tf.Variable(tf.random_normal(shape=[64], mean=1.0, stddev=0.02), name='Generator/gamma_4'),
    "W_5": tf.Variable(tf.truncated_normal([5, 5, IMAGE_CHANNEL, 64], stddev=0.02), name='Generator/W_5'),
    "b_5": tf.Variable(tf.constant(0.0, shape=[IMAGE_CHANNEL]), name='Generator/b_5')
}
# Generator: a fully-connected layer followed by four transposed convolutions
def generator(noise):
    with tf.variable_scope("Generator"):
        out_1 = tf.matmul(noise, generator_variables_dict["W_1"]) + generator_variables_dict['b_1']
        out_1 = tf.reshape(out_1, [-1, IMAGE_SIZE // 16, IMAGE_SIZE // 16, 512])
        out_1 = batch_norm(out_1, generator_variables_dict["beta_1"], generator_variables_dict["gamma_1"], train_phase, scope='bn_1')
        out_1 = tf.nn.relu(out_1, name='relu_1')
        out_2 = tf.nn.conv2d_transpose(out_1, generator_variables_dict['W_2'], output_shape=tf.pack([tf.shape(out_1)[0], IMAGE_SIZE // 8, IMAGE_SIZE // 8, 256]), strides=[1, 2, 2, 1], padding='SAME')
        out_2 = tf.nn.bias_add(out_2, generator_variables_dict['b_2'])
        out_2 = batch_norm(out_2, generator_variables_dict["beta_2"], generator_variables_dict["gamma_2"], train_phase, scope='bn_2')
        out_2 = tf.nn.relu(out_2, name='relu_2')
        out_3 = tf.nn.conv2d_transpose(out_2, generator_variables_dict['W_3'], output_shape=tf.pack([tf.shape(out_2)[0], IMAGE_SIZE // 4, IMAGE_SIZE // 4, 128]), strides=[1, 2, 2, 1], padding='SAME')
        out_3 = tf.nn.bias_add(out_3, generator_variables_dict['b_3'])
        out_3 = batch_norm(out_3, generator_variables_dict["beta_3"], generator_variables_dict["gamma_3"], train_phase, scope='bn_3')
        out_3 = tf.nn.relu(out_3, name='relu_3')
        out_4 = tf.nn.conv2d_transpose(out_3, generator_variables_dict['W_4'], output_shape=tf.pack([tf.shape(out_3)[0], IMAGE_SIZE // 2, IMAGE_SIZE // 2, 64]), strides=[1, 2, 2, 1], padding='SAME')
        out_4 = tf.nn.bias_add(out_4, generator_variables_dict['b_4'])
        out_4 = batch_norm(out_4, generator_variables_dict["beta_4"], generator_variables_dict["gamma_4"], train_phase, scope='bn_4')
        out_4 = tf.nn.relu(out_4, name='relu_4')
        out_5 = tf.nn.conv2d_transpose(out_4, generator_variables_dict['W_5'], output_shape=tf.pack([tf.shape(out_4)[0], IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNEL]), strides=[1, 2, 2, 1], padding='SAME')
        out_5 = tf.nn.bias_add(out_5, generator_variables_dict['b_5'])
        out_5 = tf.nn.tanh(out_5, name='tanh_5')
    return out_5
discriminator_variables_dict = {
    "W_1": tf.Variable(tf.truncated_normal([4, 4, IMAGE_CHANNEL, 32], stddev=0.002), name='Discriminator/W_1'),
    "b_1": tf.Variable(tf.constant(0.0, shape=[32]), name='Discriminator/b_1'),
    'beta_1': tf.Variable(tf.constant(0.0, shape=[32]), name='Discriminator/beta_1'),
    'gamma_1': tf.Variable(tf.random_normal(shape=[32], mean=1.0, stddev=0.02), name='Discriminator/gamma_1'),
    "W_2": tf.Variable(tf.truncated_normal([4, 4, 32, 64], stddev=0.002), name='Discriminator/W_2'),
    "b_2": tf.Variable(tf.constant(0.0, shape=[64]), name='Discriminator/b_2'),
    'beta_2': tf.Variable(tf.constant(0.0, shape=[64]), name='Discriminator/beta_2'),
    'gamma_2': tf.Variable(tf.random_normal(shape=[64], mean=1.0, stddev=0.02), name='Discriminator/gamma_2'),
    "W_3": tf.Variable(tf.truncated_normal([4, 4, 64, 128], stddev=0.002), name='Discriminator/W_3'),
    "b_3": tf.Variable(tf.constant(0.0, shape=[128]), name='Discriminator/b_3'),
    'beta_3': tf.Variable(tf.constant(0.0, shape=[128]), name='Discriminator/beta_3'),
    'gamma_3': tf.Variable(tf.random_normal(shape=[128], mean=1.0, stddev=0.02), name='Discriminator/gamma_3'),
    "W_4": tf.Variable(tf.truncated_normal([4, 4, 64, 128], stddev=0.002), name='Discriminator/W_4'),
    "b_4": tf.Variable(tf.constant(0.0, shape=[64]), name='Discriminator/b_4'),
    'beta_4': tf.Variable(tf.constant(0.0, shape=[64]), name='Discriminator/beta_4'),
    'gamma_4': tf.Variable(tf.random_normal(shape=[64], mean=1.0, stddev=0.02), name='Discriminator/gamma_4'),
    "W_5": tf.Variable(tf.truncated_normal([4, 4, 32, 64], stddev=0.002), name='Discriminator/W_5'),
    "b_5": tf.Variable(tf.constant(0.0, shape=[32]), name='Discriminator/b_5'),
    'beta_5': tf.Variable(tf.constant(0.0, shape=[32]), name='Discriminator/beta_5'),
    'gamma_5': tf.Variable(tf.random_normal(shape=[32], mean=1.0, stddev=0.02), name='Discriminator/gamma_5'),
    "W_6": tf.Variable(tf.truncated_normal([4, 4, 3, 32], stddev=0.002), name='Discriminator/W_6'),
    "b_6": tf.Variable(tf.constant(0.0, shape=[3]), name='Discriminator/b_6')
}
# Discriminator: an autoencoder; its reconstruction error is the sample's energy
def discriminator(input_images):
    with tf.variable_scope("Discriminator"):
        # Encoder
        out_1 = tf.nn.conv2d(input_images, discriminator_variables_dict['W_1'], strides=[1, 2, 2, 1], padding='SAME')
        out_1 = tf.nn.bias_add(out_1, discriminator_variables_dict['b_1'])
        out_1 = batch_norm(out_1, discriminator_variables_dict['beta_1'], discriminator_variables_dict['gamma_1'], train_phase, scope='bn_1')
        out_1 = tf.maximum(0.2 * out_1, out_1, 'leaky_relu_1')
        out_2 = tf.nn.conv2d(out_1, discriminator_variables_dict['W_2'], strides=[1, 2, 2, 1], padding='SAME')
        out_2 = tf.nn.bias_add(out_2, discriminator_variables_dict['b_2'])
        out_2 = batch_norm(out_2, discriminator_variables_dict['beta_2'], discriminator_variables_dict['gamma_2'], train_phase, scope='bn_2')
        out_2 = tf.maximum(0.2 * out_2, out_2, 'leaky_relu_2')
        out_3 = tf.nn.conv2d(out_2, discriminator_variables_dict['W_3'], strides=[1, 2, 2, 1], padding='SAME')
        out_3 = tf.nn.bias_add(out_3, discriminator_variables_dict['b_3'])
        out_3 = batch_norm(out_3, discriminator_variables_dict['beta_3'], discriminator_variables_dict['gamma_3'], train_phase, scope='bn_3')
        out_3 = tf.maximum(0.2 * out_3, out_3, 'leaky_relu_3')
        encode = tf.reshape(out_3, [-1, 2 * IMAGE_SIZE * IMAGE_SIZE])
        # Decoder
        out_3 = tf.reshape(encode, [-1, IMAGE_SIZE // 8, IMAGE_SIZE // 8, 128])
        out_4 = tf.nn.conv2d_transpose(out_3, discriminator_variables_dict['W_4'], output_shape=tf.pack([tf.shape(out_3)[0], IMAGE_SIZE // 4, IMAGE_SIZE // 4, 64]), strides=[1, 2, 2, 1], padding='SAME')
        out_4 = tf.nn.bias_add(out_4, discriminator_variables_dict['b_4'])
        out_4 = batch_norm(out_4, discriminator_variables_dict['beta_4'], discriminator_variables_dict['gamma_4'], train_phase, scope='bn_4')
        out_4 = tf.maximum(0.2 * out_4, out_4, 'leaky_relu_4')
        out_5 = tf.nn.conv2d_transpose(out_4, discriminator_variables_dict['W_5'], output_shape=tf.pack([tf.shape(out_4)[0], IMAGE_SIZE // 2, IMAGE_SIZE // 2, 32]), strides=[1, 2, 2, 1], padding='SAME')
        out_5 = tf.nn.bias_add(out_5, discriminator_variables_dict['b_5'])
        out_5 = batch_norm(out_5, discriminator_variables_dict['beta_5'], discriminator_variables_dict['gamma_5'], train_phase, scope='bn_5')
        out_5 = tf.maximum(0.2 * out_5, out_5, 'leaky_relu_5')
        out_6 = tf.nn.conv2d_transpose(out_5, discriminator_variables_dict['W_6'], output_shape=tf.pack([tf.shape(out_5)[0], IMAGE_SIZE, IMAGE_SIZE, 3]), strides=[1, 2, 2, 1], padding='SAME')
        out_6 = tf.nn.bias_add(out_6, discriminator_variables_dict['b_6'])
        decoded = tf.nn.tanh(out_6, name="tanh_6")
    return encode, decoded
# energies: Frobenius norm of the autoencoder's reconstruction error, per batch
_, real_decoded = discriminator(X)
real_loss = tf.sqrt(2 * tf.nn.l2_loss(real_decoded - X)) / batch_size
fake_image = generator(noise)
_, fake_decoded = discriminator(fake_image)
fake_loss = tf.sqrt(2 * tf.nn.l2_loss(fake_decoded - fake_image)) / batch_size

# EBGAN losses: D drives real energy down and fake energy up to the margin;
# G drives fake energy down (the paper's pulling-away term is not used here)
margin = 20
D_loss = real_loss + tf.maximum(margin - fake_loss, 0)
G_loss = fake_loss
def optimizer(loss, d_or_g):
    optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5)
    # update only the variables that belong to the given sub-network
    var_list = [v for v in tf.trainable_variables() if v.name.startswith(d_or_g)]
    gradient = optim.compute_gradients(loss, var_list=var_list)
    return optim.apply_gradients(gradient)

train_op_G = optimizer(G_loss, 'Generator')
train_op_D = optimizer(D_loss, 'Discriminator')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # restore a previous training run, if there is one
    ckpt = tf.train.get_checkpoint_state('.')
    if ckpt is not None:
        print(ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print("no saved model found")

    step = 0
    for i in range(40):
        for j in range(num_batch):
            batch_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, z_dim]).astype(np.float32)
            batch_images = get_next_batch(j)
            # one D update, then two G updates per batch
            d_loss, _ = sess.run([D_loss, train_op_D], feed_dict={noise: batch_noise, X: batch_images, train_phase: True})
            g_loss, _ = sess.run([G_loss, train_op_G], feed_dict={noise: batch_noise, X: batch_images, train_phase: True})
            g_loss, _ = sess.run([G_loss, train_op_G], feed_dict={noise: batch_noise, X: batch_images, train_phase: True})
            print(step, d_loss, g_loss)

            # periodically save the model and write out sample images
            if step % 100 == 0:
                saver.save(sess, "celeba.model", global_step=step)
                test_noise = np.random.uniform(-1.0, 1.0, size=(5, z_dim)).astype(np.float32)
                images = sess.run(fake_image, feed_dict={noise: test_noise, train_phase: False})
                for k in range(5):
                    image = images[k, :, :, :]
                    # map tanh output [-1, 1] back to uint8 pixels [0, 255]
                    image = np.clip((image + 1) * 127.5, 0, 255).astype(np.uint8)
                    misc.imsave('fake_image' + str(step) + str(k) + '.jpg', image)
            step += 1

 

Rough face shapes already appear by around step 2000; keep training and tune the parameters further.

P.S. Last night I dreamed that I had turned into a matrix. Very strange.

If you republish this article, please keep it intact and credit the author @斗大的熊猫 with the original URL: http://blog.topspeedsnail.com/archives/10977

