章的思想是,利用网络层的权重共享约束,训练GAN网络.模型包括两个生成网络,两个判别网络,

训练数据为不成对的两个域Domain1,Domain2的图片,我们希望的是训练的两个生成网络g1,g2能够在输入向量z相同的情况下,生成的图片高频信息相同,低频信息不同.因此在觉得高频特征的生成网络的前几层,将两个生成网络的权重共享,并且,将两个判别网络f1,f2的最后几层网络权重共享,如上图所示.

github代码为:https://github.com/andrewliao11/CoGAN-tensorflow

两个生成网络,判别网络的结构相同,通过输入参数share_params控制权重是否共享.

生成网络代码,

def generator(self, z, y=None, share_params=False, reuse=False, name='G'):

if '1' in name:
            branch = '1'
        elif '2' in name:
            branch = '2'

# layers that share the variables 
        s = self.output_size
        s2, s4 = int(s/2), int(s/4) 
        h0 = prelu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin', reuse=share_params), reuse=share_params), 
                        name='g_h0_prelu', reuse=share_params)

h1 = prelu(self.g_bn1(linear(z, self.gf_dim*2*s4*s4,'g_h1_lin',reuse=share_params),reuse=share_params),
                        name='g_h1_prelu', reuse=share_params)
        h1 = tf.reshape(h1, [self.batch_size, s4, s4, self.gf_dim * 2])

h2 = prelu(self.g_bn2(deconv2d(h1, [self.batch_size,s2,s2,self.gf_dim * 2], 
            name='g_h2', reuse=share_params), reuse=share_params), name='g_h2_prelu', reuse=share_params)

# layers that don't share the variable
    with tf.variable_scope(name):
        if reuse:
        tf.get_variable_scope().reuse_variables()
        output = tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s, s, self.c_dim], name='g'+branch+'_h3', reuse=False))

return output
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
判别网络代码,

def discriminator(self, image, y=None, share_params=False, reuse=False, name='D'):

# select the corresponding batchnorm1(not shared)
        if '1' in name:
            d_bn1 = self.d1_bn1
        branch = '1'
        elif '2' in name:
            d_bn1 = self.d2_bn1
        branch = '2'

# layers that don't share variable
    with tf.variable_scope(name):
        if reuse:
        tf.get_variable_scope().reuse_variables()

h0 = prelu(conv2d(image, self.c_dim, name='d'+branch+'_h0_conv', reuse=False), 
                    name='d'+branch+'_h0_prelu', reuse=False)

h1 = prelu(d_bn1(conv2d(h0, self.df_dim, name='d'+branch+'_h1_conv', reuse=False), reuse=reuse), 
                    name='d'+branch+'_h1_prelu', reuse=False)
            h1 = tf.reshape(h1, [self.batch_size, -1])

# layers that share variables
        h2 = prelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin', reuse=share_params),reuse=share_params), 
                    name='d_h2_prelu', reuse=share_params)

h3 = linear(h2, 1, 'd_h3_lin', reuse=share_params)

return tf.nn.sigmoid(h3), h3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
输入向量z(噪声向量),以及attribute向量y(头发颜色,年龄等特征向量),生成图片,

# input of the generator is the concat of z, y
        self.G1 = self.generator(self.z, self.y, share_params=False, reuse=False, name='G1')
    self.G2 = self.generator(self.z, self.y, share_params=True, reuse=False, name='G2')
1
2
3
两个域的输入图像(real),以及生成图像(fake)分别输入判别网络,

#input the real images
        self.D1_logits, self.D1 = self.discriminator(self.images1, self.y, share_params=False, reuse=False, name='D1')
    self.D2_logits, self.D2 = self.discriminator(self.images2, self.y, share_params=True, reuse=False, name='D2')
# input the fake images
        self.D1_logits_, self.D1_ = self.discriminator(self.G1, self.y, share_params=True, reuse=True, name='D1')
    self.D2_logits_, self.D2_ = self.discriminator(self.G2, self.y, share_params=True, reuse=True, name='D2')
1
2
3
4
5
6
在损失函数中加入了权重参数,

GAN1损失函数,

self.d1_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits, tf.ones_like(self.D1)*0.9))
        self.d1_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits_,tf.ones_like(self.D1_)*0.1))
        self.g1_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits_, tf.ones_like(self.D1_)*0.9))

self.d1_loss = self.d1_loss_real + self.d1_loss_fake
1
2
3
4
5
6
GAN2损失函数,

self.d2_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits, tf.ones_like(self.D2)*0.9))
        self.d2_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits_,tf.ones_like(self.D2_)*0.1))
        self.g2_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits_, tf.ones_like(self.D2_)*0.9))
        self.d2_loss = self.d2_loss_real + self.d2_loss_fake
1
2
3
4
5
试验效果
手写字体,上下两行分别为生成网络G1,G2的生成效果,

--------------------- 
作者:imperfect00 
来源:CSDN 
原文:https://blog.csdn.net/u011961856/article/details/79122036 
版权声明:本文为博主原创文章,转载请附上博文链接!

CoGAN pytorch相关推荐

  1. jittor和pytorch生成网络对比之cogan

    pytorch代码 import argparse import os import numpy as np import math import scipy import itertoolsimpo ...

  2. 通过anaconda2安装python2.7和安装pytorch

    ①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...

  3. 记录一次简单、高效、无错误的linux上安装pytorch的过程

    1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...

  4. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  5. PyTorch代码调试利器_TorchSnooper

    GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...

  6. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

  7. API pytorch tensorflow

    pytorch与tensorflow API速查表 方法名称 pytroch tensorflow numpy 裁剪 torch.clamp(x, min, max) tf.clip_by_value ...

  8. tensor转换 pytorch tensorflow

    一.tensorflow的numpy与tensor互转 1.数组(numpy)转tensor 利用tf.convert_to_tensor(numpy),将numpy转成tensor >> ...

  9. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

最新文章

  1. Linux访问Windows磁盘实现共享
  2. java stringbuilder 替换字符串_java中的经典问题StringBuilder替换String
  3. leetcode18
  4. vue 自定义指令实现,滚动条百分比进度条。
  5. mysql报错2_MySQL基于报错注入2
  6. Ado.net连接池 sp_reset_connection 概念
  7. python爬虫应用实战-如何爬取好看的小姐姐照片?
  8. 102、如何滚动更新 Service (Swarm09)
  9. How to Create a Development Package ?
  10. 韩国造智能手机时代走向终结:昔日巨头纷纷关闭生产线或削减产量
  11. HTML画笔移出画布停止,html5 canvas画布无法清除
  12. php输出下载地址,PHP实现的文件直接输出下载
  13. HSQL转换成MapReduce过程
  14. 12.TCP的成块数据流
  15. KiCad 5 版本体验记录
  16. 颜色的RGBnbsp;指数
  17. 开源组件安全漏洞检测主流工具对比
  18. CleanMyMac2023免费版系统清理优化工具
  19. c语言三重积分程序求法,D9_3三重积分[同济大学高等数学]..docx
  20. MATLAB的基本用法

热门文章

  1. python logging命令注入_整理后的手动注入脚本命令
  2. 华南师范大学计算机学院拟录取,华南师范大学各学院2015年硕士拟录取名单公示...
  3. 有序数组二分查找java_详解Java数据结构和算法(有序数组和二分查找)
  4. input file文件上传_微服务间的文件上传与下载-Feign
  5. 春意袭人,春装网店大比拼!
  6. emc celerra(一)--界面概览
  7. nginx的小总结(二)
  8. python解析response_python:解析requests返回的response(json格式)说明
  9. elk日志分析系统_部署ELK企业内部日志分析系统
  10. matlab中rat=1函数,matlab中的format rat是什么意思