GANs

GANs(生成对抗网络),顾名思义,这个网络第一部分是生成网络,第二部分对抗模型严格来讲是一个判别器;简单来说,就是让两个网络相互竞争,生成网络来生成假的数据,对抗网络通过判别器去判别真伪,最后希望生成器生成的数据能够以假乱真。

可以用下图来简单的看一看这两个过程。

下面我们就来依次介绍。

Discriminator Network

首先我们来讲一下对抗过程,因为这个过程更加简单。

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,我们输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的label没有关系了,不管原图片到底是一个多少类别的图片,他们都统一称为真的图片,label是1表示真实的;而生成的假的图片的label是0表示假的。

我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片,这其实就是一个简单的二分类问题,对于这个问题可以用我们前面讲过的很多方法去处理,比如logistic回归,深层网络,卷积神经网络,循环神经网络都可以。

Generative Network

接着我们要看看如何生成一张假的图片。首先给出一个简单的高维的正态分布的噪声向量,如上图所示的D-dimensional noise vector,这个时候我们可以通过仿射变换,也就是xw+b将其映射到一个更高的维度,然后将他重新排列成一个矩形,这样看着更像一张图片,接着进行一些卷积、池化、激活函数处理,最后得到了一个与我们输入图片大小一模一样的噪音矩阵,这就是我们所说的假的图片,这个时候我们如何去训练这个生成器呢?就是通过判别器来得到结果,然后希望增大判别器判别这个结果为真的概率,在这一步我们不会更新判别器的参数,只会更新生成器的参数。

如下图所示

以上的过程已经简单的阐述了生成对抗网络的学习过程,如果仍然不太清楚这个过程,下面我们会通过代码来更清晰地展示整个过程。

Code

我们会使用mnist手写数字来做数据集,通过生成对抗网络我们希望生成一些“以假乱真”的手写字体。为了加快训练过程,我们不使用卷积网络来做判别器,我们使用简单的多层网络来进行判别。

Discriminator Network

  1. class discriminator(nn.Module):

  2. def __init__(self):

  3. super(discriminator, self).__init__()

  4. self.dis = nn.Sequential(

  5. nn.Linear(784, 256),

  6. nn.LeakyReLU(0.2),

  7. nn.Linear(256, 256),

  8. nn.LeakyReLU(0.2),

  9. nn.Linear(256, 1),

  10. nn.Sigmoid()

  11. )

  12. def forward(self, x):

  13. x = self.dis(x)

  14. return x

以上这个网络是一个简单的多层神经网络,将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。之所以使用LeakyRelu而不是用ReLU激活函数是因为经过实验LeakyReLU的表现更好。

Generative Network

  1. class generator(nn.Module):

  2. def __init__(self, input_size):

  3. super(generator, self).__init__()

  4. self.gen = nn.Sequential(

  5. nn.Linear(input_size, 256),

  6. nn.ReLU(True),

  7. nn.Linear(256, 256),

  8. nn.ReLU(True),

  9. nn.Linear(256, 784),

  10. nn.Tanh()

  11. )

  12. def forward(self, x):

  13. x = self.gen(x)

  14. return x

输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,然后通过ReLU激活函数,接着进行一个线性变换,再经过一个ReLU激活函数,然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。

Discriminator Train

判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。

首先我们需要定义loss的度量方式和优化器,loss度量使用二分类的交叉熵,优化器注意使用的学习率是0.0003

  1. criterion = nn.BCELoss()

  2. d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)

  3. g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

接着进入训练

  1. img = img.view(num_img, -1) # 将图片展开乘28x28=784

  2. real_img = Variable(img).cuda() # 将tensor变成Variable放入计算图中

  3. real_label = Variable(torch.ones(num_img)).cuda() # 定义真实label为1

  4. fake_label = Variable(torch.zeros(num_img)).cuda() # 定义假的label为0

  5. # compute loss of real_img

  6. real_out = D(real_img) # 将真实的图片放入判别器中

  7. d_loss_real = criterion(real_out, real_label) # 得到真实图片的loss

  8. real_scores = real_out # 真实图片放入判别器输出越接近1越好

  9. # compute loss of fake_img

  10. z = Variable(torch.randn(num_img, z_dimension)).cuda() # 随机生成一些噪声

  11. fake_img = G(z) # 放入生成网络生成一张假的图片

  12. fake_out = D(fake_img) # 判别器判断假的图片

  13. d_loss_fake = criterion(fake_out, fake_label) # 得到假的图片的loss

  14. fake_scores = fake_out # 假的图片放入判别器越接近0越好

  15. # bp and optimize

  16. d_loss = d_loss_real + d_loss_fake # 将真假图片的loss加起来

  17. d_optimizer.zero_grad() # 归0梯度

  18. d_loss.backward() # 反向传播

  19. d_optimizer.step() # 更新参数

我已经把每一步都注释在了代码上,这样更加便于大家阅读,这是一个判别器的训练过程,我们希望判别器能够正确辨别出真假图片。

Generative Train

在生成网络的训练中,我们希望生成一张假的图片,然后经过判别器之后希望他能够判断为真的图片,在这个过程中,我们将判别器固定,将假的图片传入判别器的结果与真实label对应,反向传播更新的参数是生成网络里面的参数,这样我们就可以通过跟新生成网络里面的参数来使得判别器判断生成的假的图片为真,这样就达到了生成对抗的作用。

  1. # compute loss of fake_img

  2. z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到随机噪声

  3. fake_img = G(z) # 生成假的图片

  4. output = D(fake_img) # 经过判别器得到结果

  5. g_loss = criterion(output, real_label) # 得到假的图片与真实图片label的loss

  6. # bp and optimize

  7. g_optimizer.zero_grad() # 归0梯度

  8. g_loss.backward() # 反向传播

  9. g_optimizer.step() # 更新生成网络的参数

这样我们就写好了一个简单的生成网络,通过不断地训练我们希望能够生成很真的图片。

Result

通过不断训练,我们可以得到下面的图片

这是真实图片

第1幅为第一次生成的噪声图片,之后分别是跑完15次生成的图片,跑完30次,跑完50次,跑完70次,最后一个是跑完100次生成的图片

怎么样,是不是特别神奇,我们居然可以生成一副看着很真的图片,这里我们只是用了简单的多层感知器来生成和判别模型,我们可以用更复杂的卷积神经网络来做同样的事情,代码将和本文的代码放在一起,有兴趣的同学可以自己去看看,然后放几张卷积网络生成的图片

可以发现产生的噪声更少了,训练也更加稳定,主要是里面引入了Batchnormalization,另外gan的训练过程是特别困难的,两个对偶网络相互学习,这个时候有一些训练技巧可以使得训练生成更加稳定。

最后我们来说一下为何Gans能够成为最近20年来机器学习以及深度学习界革命性的发现。这是因为不管是深度学习还是机器学习仍然很大一部分是监督学习,但是创建这么多有label的数据集所需要的人力物力是极大的,同时遇到的新的任务时我们很容易得到原始的没有label的数据集,这是我们需要花大量的时间去给其标定label,所以很多人都认为无监督学习才是机器学习的未来,这个时候Gans的出现为无监督学习提供了有力的支持,这当然引起了学界的大量关注,同时基于Gans的应用也越来越多,业界对其也非常狂热。

最后引用Yan Lecun的话:”它(Gans)为创建无监督学习模型提供了强有力的算法框架,有望帮助我们为 AI 加入常识(common sense)。我们认为,沿着这条路走下去,有不小的成功机会能开发出更智慧的 AI 。”

以上我们简单的介绍了Gans,通过网络实现了手写字体的生成,当然还有更多的变形和应用,有兴趣的同学可以自己阅读相关论文深入了解。

全部代码

简单网络(非卷积),训练快

  1. import torch

  2. import torchvision

  3. import torch.nn as nn

  4. import torch.nn.functional as F

  5. from torchvision import datasets

  6. from torchvision import transforms

  7. from torchvision.utils import save_image

  8. from torch.autograd import Variable

  9. import os

  10. if not os.path.exists('./img'):

  11. os.mkdir('./img')

  12. def to_img(x):

  13. out = 0.5 * (x + 1)

  14. out = out.clamp(0, 1)

  15. out = out.view(-1, 1, 28, 28)

  16. return out

  17. batch_size = 128

  18. num_epoch = 100

  19. z_dimension = 100

  20. # Image processing

  21. img_transform = transforms.Compose([

  22. transforms.ToTensor(),

  23. transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

  24. ])

  25. # MNIST dataset

  26. mnist = datasets.MNIST(

  27. root='./data/', train=True, transform=img_transform, download=True)

  28. # Data loader

  29. dataloader = torch.utils.data.DataLoader(

  30. dataset=mnist, batch_size=batch_size, shuffle=True)

  31. # Discriminator

  32. class discriminator(nn.Module):

  33. def __init__(self):

  34. super(discriminator, self).__init__()

  35. self.dis = nn.Sequential(

  36. nn.Linear(784, 256),

  37. nn.LeakyReLU(0.2),

  38. nn.Linear(256, 256),

  39. nn.LeakyReLU(0.2),

  40.             nn.Linear(256, 1),

  41.             nn.Sigmoid())

  42. def forward(self, x):

  43. x = self.dis(x)

  44. return x

  45. # Generator

  46. class generator(nn.Module):

  47. def __init__(self):

  48. super(generator, self).__init__()

  49. self.gen = nn.Sequential(

  50. nn.Linear(100, 256),

  51. nn.ReLU(True),

  52. nn.Linear(256, 256),

  53.             nn.ReLU(True),

  54.             nn.Linear(256, 784),

  55.             nn.Tanh())

  56. def forward(self, x):

  57. x = self.gen(x)

  58. return x

  59. D = discriminator()

  60. G = generator()

  61. if torch.cuda.is_available():

  62. D = D.cuda()

  63. G = G.cuda()

  64. # Binary cross entropy loss and optimizer

  65. criterion = nn.BCELoss()

  66. d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)

  67. g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

  68. # Start training

  69. for epoch in range(num_epoch):

  70. for i, (img, _) in enumerate(dataloader):

  71. num_img = img.size(0)

  72. # =================train discriminator

  73. img = img.view(num_img, -1)

  74. real_img = Variable(img).cuda()

  75. real_label = Variable(torch.ones(num_img)).cuda()

  76. fake_label = Variable(torch.zeros(num_img)).cuda()

  77. # compute loss of real_img

  78. real_out = D(real_img)

  79. d_loss_real = criterion(real_out, real_label)

  80. real_scores = real_out # closer to 1 means better

  81. # compute loss of fake_img

  82. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  83. fake_img = G(z)

  84. fake_out = D(fake_img)

  85. d_loss_fake = criterion(fake_out, fake_label)

  86. fake_scores = fake_out # closer to 0 means better

  87. # bp and optimize

  88. d_loss = d_loss_real + d_loss_fake

  89. d_optimizer.zero_grad()

  90. d_loss.backward()

  91. d_optimizer.step()

  92. # ===============train generator

  93. # compute loss of fake_img

  94. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  95. fake_img = G(z)

  96. output = D(fake_img)

  97. g_loss = criterion(output, real_label)

  98. # bp and optimize

  99. g_optimizer.zero_grad()

  100. g_loss.backward()

  101. g_optimizer.step()

  102. if (i + 1) % 100 == 0:

  103. print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '

  104. 'D real: {:.6f}, D fake: {:.6f}'.format(

  105. epoch, num_epoch, d_loss.data[0], g_loss.data[0],

  106. real_scores.data.mean(), fake_scores.data.mean()))

  107. if epoch == 0:

  108. real_images = to_img(real_img.cpu().data)

  109. save_image(real_images, './img/real_images.png')

  110. fake_images = to_img(fake_img.cpu().data)

  111. save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

  112. torch.save(G.state_dict(), './generator.pth')

  113. torch.save(D.state_dict(), './discriminator.pth')

卷积网络版

  1. import torch

  2. import torch.nn as nn

  3. from torch.autograd import Variable

  4. from torch.utils.data import DataLoader

  5. from torchvision import transforms

  6. from torchvision import datasets

  7. from torchvision.utils import save_image

  8. import os

  9. if not os.path.exists('./dc_img'):

  10. os.mkdir('./dc_img')

  11. def to_img(x):

  12. out = 0.5 * (x + 1)

  13. out = out.clamp(0, 1)

  14. out = out.view(-1, 1, 28, 28)

  15. return out

  16. batch_size = 128

  17. num_epoch = 100

  18. z_dimension = 100 # noise dimension

  19. img_transform = transforms.Compose([

  20. transforms.ToTensor(),

  21. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

  22. ])

  23. mnist = datasets.MNIST('./data', transform=img_transform)

  24. dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,

  25. num_workers=4)

  26. class discriminator(nn.Module):

  27. def __init__(self):

  28. super(discriminator, self).__init__()

  29. self.conv1 = nn.Sequential(

  30. nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28

  31. nn.LeakyReLU(0.2, True),

  32. nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14

  33. )

  34. self.conv2 = nn.Sequential(

  35. nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14

  36. nn.LeakyReLU(0.2, True),

  37. nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7

  38. )

  39. self.fc = nn.Sequential(

  40. nn.Linear(64*7*7, 1024),

  41. nn.LeakyReLU(0.2, True),

  42. nn.Linear(1024, 1),

  43. nn.Sigmoid()

  44. )

  45. def forward(self, x):

  46. '''

  47. x: batch, width, height, channel=1

  48. '''

  49. x = self.conv1(x)

  50. x = self.conv2(x)

  51. x = x.view(x.size(0), -1)

  52. x = self.fc(x)

  53. return x

  54. class generator(nn.Module):

  55. def __init__(self, input_size, num_feature):

  56. super(generator, self).__init__()

  57. self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56

  58. self.br = nn.Sequential(

  59. nn.BatchNorm2d(1),

  60. nn.ReLU(True)

  61. )

  62. self.downsample1 = nn.Sequential(

  63. nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56

  64. nn.BatchNorm2d(50),

  65. nn.ReLU(True)

  66. )

  67. self.downsample2 = nn.Sequential(

  68. nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56

  69. nn.BatchNorm2d(25),

  70. nn.ReLU(True)

  71. )

  72. self.downsample3 = nn.Sequential(

  73. nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28

  74. nn.Tanh()

  75. )

  76. def forward(self, x):

  77. x = self.fc(x)

  78. x = x.view(x.size(0), 1, 56, 56)

  79. x = self.br(x)

  80. x = self.downsample1(x)

  81. x = self.downsample2(x)

  82. x = self.downsample3(x)

  83. return x

  84. D = discriminator().cuda() # discriminator model

  85. G = generator(z_dimension, 3136).cuda() # generator model

  86. criterion = nn.BCELoss() # binary cross entropy

  87. d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)

  88. g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

  89. # train

  90. for epoch in range(num_epoch):

  91. for i, (img, _) in enumerate(dataloader):

  92. num_img = img.size(0)

  93. # =================train discriminator

  94. real_img = Variable(img).cuda()

  95. real_label = Variable(torch.ones(num_img)).cuda()

  96. fake_label = Variable(torch.zeros(num_img)).cuda()

  97. # compute loss of real_img

  98. real_out = D(real_img)

  99. d_loss_real = criterion(real_out, real_label)

  100. real_scores = real_out # closer to 1 means better

  101. # compute loss of fake_img

  102. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  103. fake_img = G(z)

  104. fake_out = D(fake_img)

  105. d_loss_fake = criterion(fake_out, fake_label)

  106. fake_scores = fake_out # closer to 0 means better

  107. # bp and optimize

  108. d_loss = d_loss_real + d_loss_fake

  109. d_optimizer.zero_grad()

  110. d_loss.backward()

  111. d_optimizer.step()

  112. # ===============train generator

  113. # compute loss of fake_img

  114. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  115. fake_img = G(z)

  116. output = D(fake_img)

  117. g_loss = criterion(output, real_label)

  118. # bp and optimize

  119. g_optimizer.zero_grad()

  120. g_loss.backward()

  121. g_optimizer.step()

  122. if (i+1) % 100 == 0:

  123. print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '

  124. 'D real: {:.6f}, D fake: {:.6f}'

  125. .format(epoch, num_epoch, d_loss.data[0], g_loss.data[0],

  126. real_scores.data.mean(), fake_scores.data.mean()))

  127. if epoch == 0:

  128. real_images = to_img(real_img.cpu().data)

  129. save_image(real_images, './dc_img/real_images.png')

  130. fake_images = to_img(fake_img.cpu().data)

  131. save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch+1))

  132. torch.save(G.state_dict(), './generator.pth')

  133. torch.save(D.state_dict(), './discriminator.pth')

参考:

1.https://zhuanlan.zhihu.com/p/27386749

pytorch gans相关推荐

  1. 基于MNIST的GANs实现【Pytorch】

    简述 其实是根据我之前写的两个代码改的.(之前已经有过非常详细的解释了,可以去看看) [GANs入门]pytorch-GANs任务迁移-单个目标(数字的生成) [Gans入门]Pytorch实现Gan ...

  2. 【Gans入门】Pytorch实现Gans代码详解【70+代码】

    简述 由于科技论文老师要求阅读Gans论文并在网上找到类似的代码来学习. 文章目录 简述 代码来源 代码含义概览 代码分段解释 导入包: 设置参数: 给出标准数据: 构建模型: 构建优化器 迭代细节 ...

  3. 理解和创建GANs|使用PyTorch来做深度学习

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者:Venkatesh Tata 编译:ronghuaiyang ...

  4. PyTorch实战GANs

    GANs简介 GANs(Generative Adversarial Networks ),全名又叫做生成式对抗网络,设计者使用的是一种类似于"左右手互博"的思想,所以GANs的作 ...

  5. pytorch深度学习_了解如何使用PyTorch进行深度学习

    pytorch深度学习 PyTorch is an open source machine learning library for Python that facilitates building ...

  6. 2018热点总结:BERT最热,GANs最活跃,每20分钟就有一篇论文...

    作者 | Ross Taylor 译者 | linstancy 整理 | Jane 出品 | AI科技大本营 [导读]本文的作者 Ross Taylor 和 Robert Stojnic 在今年一起启 ...

  7. Pytorch Lightning 完全攻略!

    Datawhale干货 作者:Takanashi@知乎,编辑:极市平台 来源 | https://zhuanlan.zhihu.com/p/353985363 极市导读 作者实践中发现Pytorch- ...

  8. 13个你一定要知道的PyTorch特性

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 编译 | ronghuaiyang 来源 | 人工智能前沿讲习 编辑 ...

  9. PyTorch入门与代码模板

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 PyTorch1.0之后,越来越多的人选择使用PyTorch ...

最新文章

  1. 某厂:有微服务经验者优先!
  2. groovy 兼容 java,升级Groovy 1.7 - 2.1不兼容
  3. 必须要改变这样的生活
  4. 劣势者效应下,B站等短视频平台UP主“卖惨”吸睛又吸金?
  5. 合成小丹(dp+二进制按位或+结论)
  6. Qt 设置当前窗口出现在左右窗口的最前面
  7. 三星电子通信撤出中国!
  8. 4核a5中断linux,Cortex A5 MPcore寄存器TPIDRPRW复位值不为零,造成Linux Kernel不能启动的问题...
  9. Atitit 建设自己的财政体系 attilax总结 1.1. 收入理论 2 1.2. 收入分类 2 1.3. 2 1.4. 非货币收入 2 1.5. 2 1.6. 降低期望 2 1.7.
  10. Rust: trim(),trim_matches()等江南六怪......
  11. TideSec远控免杀学习四(BackDoor-Factory+Avet+TheFatRat)
  12. vivox50支持鸿蒙,【vivoX50Pro评测】轻薄机身内大有玄妙 深挖vivo X50系列技术创新-中关村在线...
  13. java编程规范换行_Java源代码的换行规则
  14. 通过python操作GeoLite2-City.mmdb库将nginx日志访问IP转换为城市写入数据库
  15. 计算机病毒学,计算机病毒学.doc
  16. Python格式控制之九九乘法表打印
  17. 《HarmonyOS开发 - 小凌派-RK2206开发笔记》第3章 应用开发
  18. Android开发之使用贝塞尔曲线实现黏性水珠下拉效果
  19. 上海黄金交易所交易操作基本知识
  20. 实战Nagios NSCA方式监控Linux系统资源使用情况 -- Nagios配置篇 -- 被监控端

热门文章

  1. html广告位代码,一段CSS代码让你的广告位“立起来”
  2. PTA 基础编程题目集 7-17 爬动的蠕虫 C语言
  3. 一步步教你下载centos镜像
  4. 7-3 逆序的三位数(C语言)
  5. 汇编转c语言,如何把汇编语言转换成C语言
  6. java 开发工具_Java开发工具和环境,你了解多少?
  7. HashMap 详解七
  8. js中push和pop的用法
  9. 第一款支持容器和云部署的开源数据库Neo4j 3.0
  10. nginx配置文件说明