一、论文亮点

论文地址:https://arxiv.org/abs/1511.06434

论文第三章讲了改进点,如下:

  1. 将pooling层替换成带strides的卷积层。判别器中就是带strides的卷积,生成器中,论文中说是fractional-strided
  2. TF中用的conv2d_transpose,总之是上采样。
  3. 消除顶层卷积特征中的全连接层,为了实现更深的网络。顶层特征指的是生成器的输入,以及判别器的输出。
  4. 使用BatchNorm。直接对所有层使用batchnorm会导致震荡和不稳定。所以在生成器的输出层和辨别其的输入层不用。
  5. 在生成器中使用ReLU激活,除了输出层,用的是tanh激活。辨别器使用的是leakyReLU激活,尤其对于高分辨率建模。

第四章讲了训练超参数细节:

  1. 对于输入的训练图,预处理缩放到tanh的范围[-1,1];
  2. 模型用SGD训练,batchsize为128;
  3. 权重用零中心正态分布、标准偏差0.02初始化;
  4. LeakyReLU的leak rate设为0.2;
  5. 如果使用Adam加速训练,推荐的0.001学习率太高,改用0.0002;
  6. 另外,momentum beta1用推荐的0.9导致振荡不稳定,降到0.5会稳定很多。

在LSUN(Large-scale Scene Understanding (LSUN))场景下,使用的生成器网络结构如下:

PS:论文中总共在三个数据集上做了测试:Large-scale Scene Understanding (LSUN),Imagenet-1k,a newly assembled Faces dataset。实验结果什么的我就不贴了。

二、简化代码解读

git上有份很好的demo,可以试试

https://github.com/carpedm20/DCGAN-tensorflow

参考这个demo的代码,我把DCGAN的关键的生成器、判别器、loss、train用最简单的代码写出来,先整理下思路,然后就放飞自我随便修改做各种尝试了,毕竟后面还有很多很多的CNN GAN网络,WGAN,WGAN-GP,LSGAN,等等等

判别器和生成器使用的函数:

其中的很多函数,新版的tf都是有的,看看就好

# 输出shape= [batch_size, output_size],可能也会输出w和b
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):shape = input_.get_shape().as_list()  # [batch, dim]with tf.variable_scope(scope or "Linear"):try:matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,tf.random_normal_initializer(stddev=stddev))# matrix shape = [dim, output_size]except ValueError as err:msg = "NOTE: Usually, this is due to an issue with the image dimensions.  Did you correctly set '--crop' or '--input_height' or '--output_height'?"err.args = err.args + (msg,)raisebias = tf.get_variable("bias", [output_size],initializer=tf.constant_initializer(bias_start))# bias shape = [output_size]# tf.matmul shape = [batch, output_size]if with_w:return tf.matmul(input_, matrix) + bias, matrix, biaselse:return tf.matmul(input_, matrix) + bias# kernel_size=5,strides = 2, padding='same',带bias的卷积
def conv2d(input_, output_dim,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d"):with tf.variable_scope(name):w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],initializer=tf.truncated_normal_initializer(stddev=stddev))conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')biases = tf.get_variable('biases', [output_dim],  initializer=tf.constant_initializer(0.0))conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())return conv# kernel_size=5,strides = 2, padding='same',带bias的反卷积(转置卷积)
def deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="deconv2d", with_w=False):with tf.variable_scope(name):# filter : [height, width, output_channels, in_channels]w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],initializer=tf.random_normal_initializer(stddev=stddev))try:deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,strides=[1, d_h, d_w, 1])# Support for verisons of TensorFlow before 0.7.0except AttributeError:deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,strides=[1, d_h, d_w, 1])biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())if with_w:return deconv, w, biaseselse:return deconvdef conv_out_size_same(size, stride):return int(math.ceil(float(size) / float(stride)))def bn(x,epsilon=1e-5, momentum=0.9, name="batch_norm"):tf.layers.batch_normalization(inputs=x,momentum=momentum,epsilon=epsilon,name=name)def lrelu(x, leak=0.2, name="lrelu"):return tf.maximum(x, leak * x,name)

辨别器

# 假设输入的batch_size = 16,第一层输出通道df_dim = 64
# 输入的image的shape=[batch_size,96,96,3]
def discriminator(image, reuse=False):with tf.variable_scope("discriminator") as scope:# 辨别器可能会被使用多次,所以可能需要resue=Trueif reuse:scope.reuse_variables()# lrelu、bn、conv2d、linear的代码后面会有h0 = lrelu(conv2d(image, df_dim, name='d_h0_conv'))h1 = lrelu(bn(conv2d(h0, df_dim * 2, name='d_h1_conv')))h2 = lrelu(bn(conv2d(h1, df_dim * 4, name='d_h2_conv')))h3 = lrelu(bn(conv2d(h2, df_dim * 8, name='d_h3_conv')))# h3的shape=[batchsize,6,6,512]h4 = linear(tf.reshape(h3, [batch_size, -1]), 1, 'd_h4_lin')  # similar as full connect# h4 的shape=[batchsize,1]return tf.nn.sigmoid(h4), h4

生成器

# output_height, output_width都是96, gf_dim = 64
# 输入的z的shape=[batch_size,100]
def generator(z):with tf.variable_scope("generator") as scope:s_h, s_w = output_height, output_width  # [96,96]s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)  # [48,48]s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)  # [24,24]s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)  # [12,12]s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)  # [6,6]z_, h0_w, h0_b = linear(z, gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin', with_w=True)# z_的shape=[batch_size,512*6*6]h0 = tf.reshape(z_, [-1, s_h16, s_w16, gf_dim * 8])h0 = tf.nn.relu(bn(h0))h1 = deconv2d(h0, [batch_size, s_h8, s_w8, gf_dim * 4], name='g_h1', with_w=True)h1 = tf.nn.relu(bn(h1))h2 = deconv2d(h1, [batch_size, s_h4, s_w4, gf_dim * 2], name='g_h2', with_w=True)h2 = tf.nn.relu(bn(h2))h3 = deconv2d(h2, [batch_size, s_h2, s_w2, gf_dim * 1], name='g_h3', with_w=True)h3 = tf.nn.relu(bn(h3))h4 = deconv2d(h3, [batch_size, s_h, s_w, c_dim], name='g_h4', with_w=True)# h4的shape=[batch_size, 96, 96, 3]return tf.nn.tanh(h4)

loss

# 生成的假图
G = self.generator(z)
# 真图判别结果
D, D_logits = discriminator(inputs, reuse=False)
# 假图判别结果
D_, D_logits_ = discriminator(G, reuse=True)# 判别器loss: 真图-真
d_loss_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D)))
# 判别器loss: 假图-假
d_loss_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_)))
# 判别器总loss
d_loss =d_loss_real + d_loss_fake# 生成器loss: 假图-真
g_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_)))

optimizer

这里,生成器的optimizer只优化生成器里面的参数;同理,判别器只优化判别器参数

t_vars = tf.trainable_variables()
# 判别器参数变量
d_vars = [var for var in t_vars if 'd_' in var.name]
# 生成器参数变量
g_vars = [var for var in t_vars if 'g_' in var.name]
# 只优化判别器
d_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
#只优化生成器
g_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)

train

with tf.Session() as sess:# 生成batch_size个zbatch_z = np.random.uniform(-1, 1, [batch_size, z_dim]).astype(np.float32)# 当然还有读入batch_size个image,这里就不写了# 然后feed给sesssess.run(d_optim,feed_dict={inputs: batch_images,z: batch_z})  # 更新判别器sess.run(g_optim,feed_dict={z: batch_z})  # 更新生成器

DCGAN论文改进之处+简化代码相关推荐

  1. aop 代码_项目学生:使用AOP简化代码

    aop 代码 这是Project Student的一部分. 许多人坚信方法应适合您的编辑器窗口(例如20行),而有些人认为方法应小于此范围. 这个想法是一种方法应该做一件事,而只能做一件事. 如果它做 ...

  2. 项目学生:使用AOP简化代码

    这是Project Student的一部分. 许多人坚信方法应适合您的编辑器窗口(例如20行),而有些人则认为方法应小于此范围. 这个想法是一种方法应该做一件事,而只能做一件事. 如果它做的还不止于此 ...

  3. [2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析

    [2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析 论文:https:/ ...

  4. [论文阅读] (23)恶意代码作者溯源(去匿名化)经典论文阅读:二进制和源代码对比

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  5. DCGAN 论文简单解读

    DCGAN的全称是Deep Convolution Generative Adversarial Networks(深度卷积生成对抗网络).是2014年Ian J.Goodfellow 的那篇开创性的 ...

  6. 利用位与运算简化代码

    利用位与运算简化代码 题目及代码来源:<数据结构习题解析(第三版)>,邓俊辉编著,ISBN: 978-7-302-33065-3 题目要求:改进教材中的countOnes()算法,使得时间 ...

  7. C#泛型简化代码量示例

    泛型简化代码量 下是我在项目中通过泛型来简化工作的一个Demo,记录一下: using System; using System.Collections.Generic; namespace MyCo ...

  8. FE之DR之线性降维:PCA/白化、LDA算法的数学知识(协方差矩阵)、相关论文、算法骤、代码实现、案例应用等相关配图之详细攻略

    FE之DR之线性降维:PCA/白化.LDA算法的数学知识(协方差矩阵).相关论文.算法骤.代码实现.案例应用等相关配图之详细攻略 目录 PCA 1.PCA的数学知识 1.协方差矩阵计算 2.PCA算法 ...

  9. ML之DR之SVD:SVD算法相关论文、算法过程、代码实现、案例应用之详细攻略

    ML之DR之SVD:SVD算法相关论文.算法过程.代码实现.案例应用之详细攻略 目录 SVD算法相关论文 SVD算法过程 1.公式的推导 2.SVD算法两步过程 SVD代码实现 SVD的案例应用 1. ...

最新文章

  1. mac最好用的markdown_Markdown 语法简明教程 amp; Markdown 编辑器推荐
  2. RASPBERRY PI PICO 树莓派PICO开发板双核高性能低功耗RP2040芯片
  3. SolrQuery的使用
  4. Ubuntu ADSL 拨号上网时断时续问题
  5. XIII Open Grodno SU Championship
  6. 成功解决WARNING: You do not appear to have an NVIDIA GPU supported by the 430.34 NVIDIA Linux graph
  7. 如何给字体添加底色indesign_“美哉汉字”2020字体设计专家工作坊预告+报名
  8. BZOJ2844 albus就是要第一个出场
  9. 用java发送邮件(黄海已测试通过)
  10. html展示pdf文件流,使用pdfjs提供的viewer.html展示pdf文件流
  11. Python 判断字符串是否为IP(字符串中是否包含IP)
  12. C#操作XML方法详解
  13. UVA10427 Naughty Sleepy Boys【数学】
  14. 【深度学习】你不了解的细节问题(四)
  15. 【数学建模】数据包络分析法
  16. ospf路由 华3_华三路由器命令信息
  17. java服务器如何群发消息,java TCP编程简单实现一个消息群发功能
  18. 电子计算机工作最主要特征,电子计算机最重要的工作特征是( )
  19. C++单个配置与多个配置
  20. 央视网商城app_央视网商城“中国好产品联合招商大会”召开

热门文章

  1. LSA(潜在语义分析)
  2. 2022焊工(初级)试题及在线模拟考试
  3. Unity 摄像机切换镜头
  4. unity摄像机跟随鼠标旋转
  5. linux下的lib文件知识
  6. 【项目管理/PMP/PMBOK第六版/新考纲】纯干货!Sprint冲刺/冲刺计划会/每日站立会/冲刺评审会/冲刺回顾会,系列文章建议收藏
  7. 51单片机+DS18B20+数码管显示+串口通讯+Proteus仿真
  8. iOS开发:兼容适配iPhone X
  9. Serializable的含义
  10. word无法验证服务器,Win8系统打开office文件提示“无法验证此产品的许可证”如何解决...