Ian Goodfellow在2014年的《Generative Adversarial Nets》中提出了生成对抗网络的概念,具体的思想大家恐怕都看烂了~整个模型的架构可以表示为

目标函数为 min ⁡ G max ⁡ D V ( G , D ) = E x − p data  ( x ) log ⁡ D ( x i ) + E x ∼ p z ( z ) log ⁡ ( 1 − D ( G ( z i ) ) ) \min _{G} \max _{D} V(G, D)=E_{x-p_{\text { data }}(x)} \log D\left(x_{i}\right)+E_{x \sim p_{z}(z)} \log \left(1-D\left(G\left(z_{i}\right)\right)\right) Gmin​Dmax​V(G,D)=Ex−p data ​(x)​logD(xi​)+Ex∼pz​(z)​log(1−D(G(zi​)))

下面来看一下我们用pytorch如何在MNIST数据集上实现GAN,以下的代码来源于pytorch-GAN。

  • 首先引入所需的库文件
## argparse是python用于解析命令行参数和选项的标准模块
# 使用步骤:
# 1 import argparse
# 2 parser = argparse.ArgumentParser()
# 3 parser.add_argument()
# 4 parser.parse_args()
import argparse
import os
import numpy as np
import math# 用于data augmentation
import torchvision.transforms as transforms
# 保存生成图像
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torch
  • 然后设置模型中的某些参数,这里使用了argparse来集中操作
# 如果根目录下不存在images文件夹,则创建images存放生成图像结果
os.makedirs("images", exist_ok=True)# 创建解析对象
parser = argparse.ArgumentParser()
# 向解析对象中添加命令行参数和选项
# epoch = 200,批大小 = 64,学习率 = 0.0002,衰减率 = 0.5/0.999,线程数 = 8,隐码维数 = 100,样本尺寸 = 28 * 28,通道数 = 1,样本间隔 = 400
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
# 解析参数
opt = parser.parse_args()
print(opt)
  • 创建生成器G
#-------------------------
#        生成器
#-------------------------
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):# 这里简单的只对输入数据做线性转换layers = [nn.Linear(in_feat, out_feat)]# 使用BNif normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))# 添加LeakyReLU非线性激活层layers.append(nn.LeakyReLU(0.2, inplace=True))return layers# 创建生成器网络模型self.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())# 前向def forward(self, z):# 生成假样本img = self.model(z)img = img.view(img.size(0), *img_shape)# 返回生成图像return img
  • 创建判别器D
#-------------------------
#        判别器
#-------------------------
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),# 因需判别真假,这里使用Sigmoid函数给出标量的判别结果nn.Sigmoid(),)# 判别def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)# 判别结果return validity
  • 损失函数和优化器
# 损失函数:二分类交叉熵函数
adversarial_loss = torch.nn.BCELoss()# 优化器,G和D都使用Adam
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
  • 加载数据集
os.makedirs("../../data/mnist", exist_ok=True)
#------------------------------------------
#      torch.utils.data.DataLoader
#------------------------------------------
# 数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化
#
#torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
#                            batch_sampler=None, num_workers=0, collate_fn=<function default_collate>,
#                            pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
# dataset:加载数据的数据集
# batch_size:每批次加载的数据量
# shuffle:默认false,若为True,表示在每个epoch打乱数据
# sampler:定义从数据集中绘制示例的策略,如果指定,shuffle必须为False
# ...
# 更多可参考: https://pytorch.org/docs/stable/data.html # 设置数据加载器,这里使用MNIST数据集
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)
  • 训练模型
#-----------------------
#      训练模型
#-----------------------
for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# 输入real_imgs = Variable(imgs.type(Tensor))# -----------------#  训练 G# -----------------optimizer_G.zero_grad()# 采样随机噪声向量z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# 训练得到一批次生成样本gen_imgs = generator(z)# 计算G的损失函数值g_loss = adversarial_loss(discriminator(gen_imgs), valid)# 更新Gg_loss.backward()optimizer_G.step()# ---------------------#  训练 D# ---------------------optimizer_D.zero_grad()# 评估D的判别能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2# 更新Dd_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))#保存结果batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

实验结果

因为只有一块1080Ti,所以这里设置epoch = 50 跑了一下实现,结果如下所示:



pytorch - GAN相关推荐

  1. 图像超分辨率进ASC19超算大赛,PyTorch+GAN受关注

    近日,2019 ASC世界大学生超级计算机竞赛(ASC19)公布了初赛赛题. 来自全球200余所高校的300多支大学生队伍,将在长达两个月的初赛阶段,尝试挑战一项当前热门的人工智能技术--单张图像超分 ...

  2. PyTorch - GAN与WGAN及其实战

    目录 GAN 基本结构 训练 对于生成器 对于判别器 训练流程 训练理论 min max公式 Where will D converge, given fixed G Where will G con ...

  3. Pytorch:GAN生成对抗网络实现二次元人脸的生成

    github:https://github.com/SPECTRELWF/pytorch-GAN-study 网络结构 最近在疯狂补深度学习一些基本架构的基础,看了一下大佬的GAN的原始论文,说实话一 ...

  4. Pytorch:GAN生成对抗网络实现MNIST手写数字的生成

    github:https://github.com/SPECTRELWF/pytorch-GAN-study 个人主页:liuweifeng.top:8090 网络结构 最近在疯狂补深度学习一些基本架 ...

  5. 涵盖18+ SOTA GAN实现,这个图像生成领域的PyTorch库火了

    视学算法报道 转载自:机器之心 作者:杜伟.陈萍 GAN 自从被提出后,便迅速受到广泛关注.我们可以将 GAN 分为两类,一类是无条件下的生成:另一类是基于条件信息的生成.近日,来自韩国浦项科技大学的 ...

  6. 这个图像生成领域的PyTorch库火了,涵盖18+ SOTA GAN实现

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 机器之心报道 近日,在 GitHub 上看到了一个非常有意义的项目 PyTorch-S ...

  7. 赞!图像生成PyTorch库火了,涵盖18+ SOTA GAN实现

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 近 ...

  8. api存在csrf攻击吗_使用rest api防止单页应用上的csrf攻击

    api存在csrf攻击吗 tl;dr - If your SPA uses a private REST API, use CORS and a CSRF Token header. If your ...

  9. 5天玩转PyTorch深度学习,从GAN到词嵌入都有实例丨教程资源

    郭一璞 发自 凹非寺  量子位 报道 | 公众号 QbitAI 学PyTorch深度学习,可能5天就够了. 法国深度学习研究者Marc Lelarge出品的这套名为<Hands-on tour ...

最新文章

  1. 年度盛宴——2012年最精彩的15个 CSS3 教程
  2. mysql 终端模拟_mysql客户端模拟脏读、幻读和可重复读
  3. 微课|中学生可以这样学Python(8.4节):递归算法例题讲解3
  4. 如何使用IEDA连接数据库
  5. SQL那些事儿(十四)--C#调用oracle存储过程(查询非查询and有参无参)深度好文
  6. 《从零开始学Swift》学习笔记(Day 53)——do-try-catch错误处理模式
  7. Visual Studio 2008 Designer.cs不能更新/自动添加控件声明的解决办法
  8. linux pptp服务器安装
  9. 水电远程预付费管理系统
  10. img标签图片自适应
  11. 动态代理实例——增强Waiter接口
  12. 2月9日,30秒知全网,精选7个热点
  13. 2020-12-25 PMP 群内练习题 - 光环
  14. U盘杀毒后文件不见/找回
  15. Yahoo! User Interface Library,YUI,YUI下载,YUI学习,YUI是什么,YUI浅谈,YUI研究(2)
  16. Apache的配置与应用【Apache访问控制】以及apache日志管理【日志分割、awstats日志分析】
  17. 深圳荣耀Java后端一面
  18. 算法与分析-实验一 算法设计基础
  19. Spring Boot整合JWT实现用户认证(附源码)
  20. 华硕驱动问题,(触摸板可以用,可是手势不可以用,很多方法都试了,还是本机自带的东西好用)

热门文章

  1. 《武则天正传》读后感
  2. 好家伙!上天入地混血儿料箱机器人
  3. 最热门的国人开发开源软件 TOP 50
  4. davinci平台Uboot移植
  5. io包下 文件类 字节流 字符流 缓冲流 转换流 序列化流 打印流
  6. hbase 使用lzo_装配HBase LZO
  7. eclipse egit_EGit迁移成为Eclipse的要求吗?
  8. 抢抓双城发展机遇 新川代表团赴渝交流
  9. Word自动生成目录页码靠右对齐
  10. crm组织服务中的xRM消息