原文链接:https://arxiv.org/pdf/1604.07379.pdf

简介

背景:从“Conditinal GAN”到“InfoGAN”,我们都在致力于解决一个问题就是如何通过人为控制的范式,来控制GAN网络生成我们想要的数据,但是之前的几种方法大多是在针对模型的输入与输出在做文章。而本文是通过构思了一种双(或多)GAN的架构,通过不同域的相互约束来达到控制生成数据的样式。

核心思想:使用一对GAN通过权重共享的方式,是的两个数据空间域同时约束生成数据。

由上图来看,我们从中间划开来看的话就是两个独立的GAN,但是有所不同的是这两个GAN在生成器前几层与生成器后几层是共享像网络权重的。为什么我们要共享网络权重呢?可以感性地这样理解,我们将GAN网络应用到了两个任务中,于是乎GAN会受到两个任务的约束,而约束越多就越方便我们控制GAN的优化方向。

基础结构

生成器

我们将生成器的每一层网络拆分开,即可得到上面的式子,由于两个生成器是从相同的随机变量映射到不同的数据空间,因此共享的网络层只能是接近噪声z输入端的,即m较小的层。

判别器

我们将判别器的每一层网络拆分开,即可得到上面的式子,由于两个判别器是从不同的数据空间映射到真假的判断结果,因此共享的网络层只能是接近结果输出端的,即n较大的层。

LOSS

用一张简单的示意图表示:

上图两个大的矩形就是联合分布域形成的约束。

代码与实践结果

参考链接:https://github.com/WingsofFAN/PyTorch-GAN/blob/master/implementations/cogan/cogan.py

import argparse
import os
import numpy as np
import math
import scipy
import itertoolsimport mnistmimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else Falsedef weights_init_normal(m):classname = m.__class__.__name__if classname.find("Linear") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)class CoupledGenerators(nn.Module):def __init__(self):super(CoupledGenerators, self).__init__()self.init_size = opt.img_size // 4self.fc = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.shared_conv = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),)self.G1 = nn.Sequential(nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)self.G2 = nn.Sequential(nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise):out = self.fc(noise)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img_emb = self.shared_conv(out)img1 = self.G1(img_emb)img2 = self.G2(img_emb)return img1, img2class CoupledDiscriminators(nn.Module):def __init__(self):super(CoupledDiscriminators, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])return blockself.shared_conv = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4self.D1 = nn.Linear(128 * ds_size ** 2, 1)self.D2 = nn.Linear(128 * ds_size ** 2, 1)def forward(self, img1, img2):# Determine validity of first imageout = self.shared_conv(img1)out = out.view(out.shape[0], -1)validity1 = self.D1(out)# Determine validity of second imageout = self.shared_conv(img2)out = out.view(out.shape[0], -1)validity2 = self.D2(out)return validity1, validity2# Loss function
adversarial_loss = torch.nn.MSELoss()# Initialize models
coupled_generators = CoupledGenerators()
coupled_discriminators = CoupledDiscriminators()if cuda:coupled_generators.cuda()coupled_discriminators.cuda()# Initialize weights
coupled_generators.apply(weights_init_normal)
coupled_discriminators.apply(weights_init_normal)# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader1 = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)os.makedirs("../../data/mnistm", exist_ok=True)
dataloader2 = torch.utils.data.DataLoader(mnistm.MNISTM("../../data/mnistm",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(coupled_generators.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(coupled_discriminators.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, ((imgs1, _), (imgs2, _)) in enumerate(zip(dataloader1, dataloader2)):batch_size = imgs1.shape[0]# Adversarial ground truthsvalid = Variable(Tensor(batch_size, 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(batch_size, 1).fill_(0.0), requires_grad=False)# Configure inputimgs1 = Variable(imgs1.type(Tensor).expand(imgs1.size(0), 3, opt.img_size, opt.img_size))imgs2 = Variable(imgs2.type(Tensor))# ------------------#  Train Generators# ------------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))# Generate a batch of imagesgen_imgs1, gen_imgs2 = coupled_generators(z)# Determine validity of generated imagesvalidity1, validity2 = coupled_discriminators(gen_imgs1, gen_imgs2)g_loss = (adversarial_loss(validity1, valid) + adversarial_loss(validity2, valid)) / 2g_loss.backward()optimizer_G.step()# ----------------------#  Train Discriminators# ----------------------optimizer_D.zero_grad()# Determine validity of real and generated imagesvalidity1_real, validity2_real = coupled_discriminators(imgs1, imgs2)validity1_fake, validity2_fake = coupled_discriminators(gen_imgs1.detach(), gen_imgs2.detach())#真实图片输入对应两个loss#生成图片输入对应两个lossd_loss = (adversarial_loss(validity1_real, valid)+ adversarial_loss(validity1_fake, fake)+ adversarial_loss(validity2_real, valid)+ adversarial_loss(validity2_fake, fake)) / 4d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader1), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader1) + iif batches_done % opt.sample_interval == 0:gen_imgs = torch.cat((gen_imgs1.data, gen_imgs2.data), 0)save_image(gen_imgs, "images/%d.png" % batches_done, nrow=8, normalize=True)

mnist与mnistm测试结果

   

mnist                                                         mnistm

Coupled Generative Adversarial Networks(小白学GAN 九)相关推荐

  1. Coupled Generative Adversarial Networks

    https://blog.csdn.net/carrierlxksuper/article/details/60479883 这篇文章(NIPS2016)是基于Generative Adversari ...

  2. GAN系列:代码阅读——Generative Adversarial Networks 李宏毅老师GAN课程P1+P4

    看了一上午简直要头疼死.GAN之前没接触过,学习的时候产生了很多乱七八糟的联想.从上篇文章开始,很多内容都是自己的理解,估计有很多错误,以后学习中发现了可能会回来修改的. 找的是机器之心i的代码:ht ...

  3. [解读] Coupled Generative Adversarial Networks

    论文链接: https://arxiv.org/abs/1606.07536v2 Github 项目地址: https://github.com/mingyuliutw/cogan 本文提出耦合的生成 ...

  4. Generative Adversarial Networks(CGAN、CycleGAN、CoGAN)

    很久前整理了GAN和DCGAN,主要是GAN的基本原理和训练方法,以及DCGAN在图像上的应用,模式崩溃问题等.其核心思想就是通过训练两个神经网络,一个用来生成数据,另一个用于在假数据中分类出真数据, ...

  5. (SAGAN)Self-Attention Generative Adversarial Networks

    core idea:将self-attention机制引入到GANs的图像生成当中,来建模像素间的远距离关系,用于图像生成任务 CGAN的缺点: 1.依赖卷积建模图像不同区域的依赖关系,由于卷积核比较 ...

  6. 生成对抗网络(Generative Adversarial Networks,GAN)初探

    1. 从纳什均衡(Nash equilibrium)说起 我们先来 看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句 ...

  7. GAN网络学习入门之:A Beginner's Guide to Generative Adversarial Networks (GANs)-翻译

    译自:https://wiki.pathmind.com/generative-adversarial-network-gan 你可能认为编码者不是艺术家,但是编程是一个极具创意的职业.它是基于逻辑的 ...

  8. 史上最全GAN综述2020版:算法、理论及应用(A Review on Generative Adversarial Networks: Algorithms, Theory, and Applic)

    ** ** 史上最全GAN综述2020版:算法.理论及应用** 论文地址:https://arxiv.org/pdf/2001.06937.pdf ** 摘要:生成对抗网络(GANs)是近年来的一个研 ...

  9. GAN Dissection: Visualizing and Understanding Generative Adversarial Networks

    GAN Dissection: Visualizing and Understanding Generative Adversarial Networks 该论文介绍了一个可视化和理解生成网络学得结构 ...

最新文章

  1. Linux环境变量说明与配置
  2. 解决SQL Server里sp_helptext输出格式错行问题
  3. 数组运用_1-19 编程练习
  4. 人工智能产品化的关键是基础架构和数据,而非算法
  5. 硬盘引导安装windows7系统的方法
  6. python梦幻西游鼠标偏移_PYTHONPYGAME如何向鼠标位置移动和旋转多边形?
  7. 二叉树两个节点的公共节点
  8. 史上最简单的SpringCloud教程 | 第十三篇: 断路器聚合监控(Hystrix Turbine)
  9. openfeign调用 HttpServletRequest作为参数 报错..
  10. Sharepoint学习笔记—ECM系列--3 从.CSV文件导入术语集(Term Sets)
  11. 每天进步一点点013
  12. 对无序的边界点排序(顺时针绘制边界)
  13. 苹果好用的测试软件,Mac上有什么好用的Mac内存检测软件?
  14. win7更新错误代码80072efe怎么解决?
  15. 使用switch的注意
  16. Debian修改时区
  17. 计算机老师开学第一堂课,开学第一堂课作文
  18. 联想微型计算机开机密码忘记了,lenovo台式电脑忘了开机密码简单解决的方法,小孩子就能搞定的...
  19. 数据分析常用知识体系
  20. 数据库设计-逻辑设计

热门文章

  1. 心内科临床护理管理中护士分层培训的应用分析
  2. ABOUT ME/OI回忆录
  3. 报错:exception is java.io.InvalidClassException
  4. iReport3.0.0属性说明
  5. vue拖拽组件插件 vue-draggable-resizable-gorkys
  6. JVM垃圾回收-记忆集和卡表
  7. Go 区块链 Input Data 解析
  8. ATM取款机_存款功能
  9. 编程中的燕尾槽刀具该如何加工设置?
  10. javascript原生实现图片相似度识别算法