GAN介绍

理解GAN的直观方法是从博弈论的角度来理解它。GAN由两个参与者组成,即一个生成器和一个判别器,它们都试图击败对方。生成备从分巾中狄取一些随机噪声,并试图从中生成一些类似于输出的分布。生成器总是试图创建与真实分布没有区别的分布。也就是说,伪造的输出看起来应该是真实的图像。 然而,如果没有显式训练或标注,那么生成器将无法判别真实的图像,并且其唯一的来源就是随机浮点数的张量。

之后,GAN将在博弈中引入另一个参与者,即判别器。判别器仅负责通知生成器其生成的输出看起来不像真实图像,以便生成器更改其生成图像的方式以使判别器确信它是真实图像。 但是判别器总是可以告诉生成器其生成的图像不是真实的,因为判别器知道图像是从生成器生成的。为了解决这个事情,GAN将真实的图像引入博弈中,并将判别器与生成器隔离。现在,判别器从一组真实图像中获取一个图像,并从生成器中获取一个伪图像,而它必须找出每个图像的来源。

最初,判别器什么都不知道,而是随机预测结果。 但是,可以将判别器的任务修改为分类任务。判别器可以将输入图像分类为原始图像或生成图像,这是二元分类。同样,我们训练判别器网络以正确地对图像进行分类,最终,通过反向传播,判别器学会了区分真实图像和生成图像。

代码实例

数据集简介:
本次实验我们选用花卉数据集做图像的生成,本数据集共六类。

模型训练
训练判别器:
对于真图片,输出尽可能是1
对于假图片,输出尽可能是0
训练生成器:
对于假图片,输出尽可能是1
1、训练生成器时,无须调整判别器的参数;训练判别器时,无须调整生成器的参数。
2、在训练判别器时,需要对生成器生成的图片用detach操作进行计算图截断,避免反向传播将梯度传到生成器中。因为在训练判别器时我们不需要训练生成器,也就不需要生成器的梯度。
3、在训练判别器时,需要反向传播两次,一次是希望把真图片判为1,一次是希望把假图片判为0。也可以将这两者的数据放到一个batch中,进行一次前向传播和一次反向传播即可。
4、对于假图片,在训练判别器时,我们希望它输出0;而在训练生成器时,我们希望它输出1.因此可以看到一对看似矛盾的代码 error_d_fake = criterion(output, fake_labels)和error_g = criterion(output, true_labels)。判别器希望能够把假图片判别为fake_label,而生成器则希望能把他判别为true_label,判别器和生成器互相对抗提升。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from dataloader import MyDataset
from model import Generator, Discriminator
import torchvision
import numpy as np
import matplotlib.pyplot as plt
if __name__ == '__main__':LR = 0.0002EPOCH = 1000  # 50BATCH_SIZE = 40N_IDEAS = 100EPS = 1e-10TRAINED = False#path = r'./data/image'train_data = MyDataset(path=path, resize=96, Len=10000, img_type='jpg')train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)torch.cuda.empty_cache()if TRAINED:G = torch.load('G.pkl').cuda()D = torch.load('D.pkl').cuda()else:G = Generator(N_IDEAS).cuda()D = Discriminator(3).cuda()optimizerG = torch.optim.Adam(G.parameters(), lr=LR)optimizerD = torch.optim.Adam(D.parameters(), lr=LR)for epoch in range(EPOCH):tmpD, tmpG = 0, 0for step, x in enumerate(train_loader):x = x.cuda()rand_noise = torch.randn((x.shape[0], N_IDEAS, 1, 1)).cuda()G_imgs = G(rand_noise)D_fake_probs = D(G_imgs)D_real_probs = D(x)p_d_fake = torch.squeeze(D_fake_probs)p_d_real = torch.squeeze(D_real_probs)D_loss = -torch.mean(torch.log(p_d_real + EPS) + torch.log(1. - p_d_fake + EPS))G_loss = -torch.mean(torch.log(p_d_fake + EPS))# D_loss = -torch.mean(torch.log(D_real_probs) + torch.log(1. - D_fake_probs))# G_loss = torch.mean(torch.log(1. - D_fake_probs))optimizerD.zero_grad()D_loss.backward(retain_graph=True)optimizerD.step()optimizerG.zero_grad()G_loss.backward(retain_graph=True)optimizerG.step()tmpD_ = D_loss.cpu().detach().datatmpG_ = G_loss.cpu().detach().datatmpD += tmpD_tmpG += tmpG_tmpD /= (step + 1)tmpG /= (step + 1)print('epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG))# if (epoch+1) % 5 == 0:select_epoch = [1, 5, 10, 20, 50, 80, 100, 150, 200, 400, 500, 800, 999, 1500, 2000, 3000, 4000, 5000, 6000, 8000, 9999]if epoch in select_epoch:
plt.imshow(np.squeeze(G_imgs[0].cpu().detach().numpy().transpose((1, 2, 0))) * 0.5 + 0.5)plt.savefig('./result1/_%d.png' % epoch)torch.save(G, 'G.pkl')torch.save(D, 'D.pkl')

下面是训练多次的效果






完整代码如下:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)real_image = real_image.cuda()if (i + 1) % d_every == 0:optimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i))# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0))i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.savefig("images/%dcgan.png" % num)if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布if torch.cuda.is_available() == True:print('Cuda is available!')Generator.cuda()Discriminator.cuda()criterion.cuda()true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()fix_noises, noises = fix_noises.cuda(), noises.cuda()plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]for i in range(3000):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

生成对抗网络(GAN)详解与实例相关推荐

  1. 生成对抗网络入门详解及TensorFlow源码实现--深度学习笔记

    生成对抗网络入门详解及TensorFlow源码实现–深度学习笔记 一.生成对抗网络(GANs) 生成对抗网络是一种生成模型(Generative Model),其背后最基本的思想就是从训练库里获取很多 ...

  2. 【深度学习】GAN生成对抗网络原理详解(1)

    一个 GAN 框架,最少(但不限于)拥有两个组成部分,一个是生成模型 G,一个是判别模型 D.在训练过程中,会把生成模型生成的样本和真实样本随机地传送一张(或者一个 batch)给判别模型 D.判别模 ...

  3. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  4. 万字详解什么是生成对抗网络GAN

    摘要:这篇文章将详细介绍生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN).发展历程.预备知识,并通过Keras搭建最简答的手写数字图片生成案. ...

  5. [Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:[Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleG ...

  6. [人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  7. 生成对抗网络gan原理_中国首个“芯片大学”即将落地;生成对抗网络(GAN)的数学原理全解...

    开发者社区技术周刊又和大家见面了,萌妹子主播为您带来第三期"开发者技术联播".让我们一起听听,过去一周有哪些值得我们开发者关注的重要新闻吧. 中国首个芯片大学,南京集成电路大学即将 ...

  8. pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)

    1.简介 这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码.数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G.生成器与判 ...

  9. 生成对抗网络(GAN)简单梳理

    作者:xg123321123 - 时光杂货店 出处:http://blog.csdn.net/xg123321123/article/details/78034859 声明:版权所有,转载请联系作者并 ...

  10. 【通知】《生成对抗网络GAN原理与实践》代码开源,勘误汇总!

    有三上个月出版了新书<生成对抗网络GAN:原理与实践>,Generative Adversarial Networks(中文名生成对抗网络,简称GAN)自从被提出来后,其发展就非常迅猛,几 ...

最新文章

  1. 让Bootstrap 3兼容IE8浏览器
  2. 【美文】接受生活的无力感,才能更好的出发
  3. .net 面试题系列文章二(附答案)
  4. 机器学习-聚类之K均值(K-means)算法原理及实战
  5. httpclient java多线程_Apache HttpClient4.5多个HTTP请求使用多线程执行
  6. Atitit 读取文本文件内容功能的实现 艾提拉 总结 attilax总结 1.1. FileUtilsAti.readFileToStringAutoDetectEncode(txtF); 1 1
  7. Uiautomator 2.0之UiObject2类学习小记
  8. “select count (1)”是什么意思?
  9. Graphql 初体验 第十一章 | #13 Hitting the API(实现了登录注册表单)
  10. 2022联想小新pro14和联想小新pro16 区别 哪个好
  11. 戴尔台式机修复计算机,dell电脑win10自动修复你的电脑未正确启动怎么修复
  12. Rocket的启动流程
  13. 最新综述 | 皮层内外无线神经信号记录系统为脑机接口技术注入全新血液
  14. “让你更值钱”的八个项目管理习惯
  15. nw362 linux 驱动下载,支持OpenGL 3.2 NVIDIA全新Linux驱动官方发布
  16. 让电脑在局域网中隐身
  17. 基于airtest的安卓ui自动化实践
  18. 计算机等级一考通2021,驾校驾考一点通2021最新版电脑版
  19. js返回计算机ip地址吗,js获取电脑IP地址???电脑连WIFI的
  20. 识别无效对象和不可用对象

热门文章

  1. 政府大数据之数据治理
  2. 一大波科研交流群出现,种类齐全,名额有限,请大家抓紧入坑!
  3. 数学建模之2016国赛A题程序(来源于cclplus)
  4. 每天提醒自己学习的软件?每天提醒学习任务怎么设置?
  5. 服务器创建和附加虚拟磁盘,详解Hyper-V创建虚拟磁盘存储配置攻略
  6. 关于数字货币到底是啥?而区块链又是什么呢
  7. 你知道ai文字绘画生成的软件有哪些吗?我来分享三个实用的软件
  8. Grafana重置admin密码方法
  9. Unity delegate
  10. 项目管理~会议成本=每小时平均工资的3倍*2*开会人数*会议小时数