系列文章目录

李宏毅作业九 Anomaly Detection异常检测
李宏毅作业八unsupervised无监督聚类学习
李宏毅作业七其三 Network Compression (Network Pruning)
李宏毅作业七其二 Network Compression (Knowledge Distillation)
李宏毅作业七其一 Network Compression (Architecuture Design)
李宏毅作业六 Adversarial Attack对抗攻击

Generative Adversarial Network 生成对抗网络

  • 系列文章目录
  • 前言
  • 一、生成对抗网络
    • 1.生成对抗网络是什么
    • 2.数学公式
  • 二、代码
    • 1.下载数据
    • 2.数据预处理
    • 3.随机种子
    • 4.模型
    • 5.准备训练
    • 6.训练开始
    • 7.使用生成器生成图片
  • 总结

前言

本篇以代码为主,不过多涉及理论。
平台colab,语言python


一、生成对抗网络

1.生成对抗网络是什么

生成对抗网络中包含了两个模型,一个是生成模型G,另一个是判别模型D,下面通过一个生成图片的例子来解释两个模型的作用:

生成模型G:不断学习训练集中真实数据的概率分布,目标是将输入的随机噪声转化为可以以假乱真的图片(生成的图片与训练集中的图片越相似越好)
判别模型D:判断一个图片是否是真实的图片,目标是将生成模型G产生的“假”图片与训练集中的“真”图片分辨开。
GANs的实现方法是让D和G进行博弈,训练过程中通过相互竞争让这两个模型同时得到增强。由于判别模型D的存在,使得 G 在没有大量先验知识以及先验分布的前提下也能很好的去学习逼近真实数据,并最终让模型生成的数据达到以假乱真的效果

2.数学公式


判别模型D和生成模型G均采用多层感知机。GANs定义了一个噪声pz(x) 作为先验,用于学习生成模型G 在训练数据x上的概率分布pg,G(z)表示将输入的噪声z映射成数据(例如生成图片)。D(x)代表x 来自于真实数据分布pdata而不是pg的概率。
理论解释来源

二、代码

1.下载数据

除了下载数据集外,根据自身需要,来下载python套件。

""" Uncomment these lines to mount your own gdrive. """
# from google.colab import drive
# drive.mount('/content/drive')""" You can replace the workspace directory with your gdrive if you want. """
workspace_dir = '.'
# workspace_dir = './drive/My Drive/Machine Learning/hw11 - GAN/colab_tmp'""" Download the dataset. """
!gdown --id 1IGrTr308mGAaCKotpkkm8wTKlWs9Jq-p --output "{workspace_dir}/crypko_data.zip"
  • 解压文件
!unzip -q "{workspace_dir}/crypko_data.zip" -d "{workspace_dir}/"

2.数据预处理

  • 使用torchvision套件存图,使用cv2读取,并将cv2读取的图片(BGR)转换成torchvision格式(RGB)。
  • 将图片输入大小调整为(64,64),方便实验,并将value(0-1)线性转换成(-1~1)。
  • 代码里我注释的很清楚
from torch.utils.data import Dataset, DataLoader
import cv2
import osclass FaceDataset(Dataset):#定义类似私有变量的def __init__(self, fnames, transform):self.transform = transformself.fnames = fnamesself.num_samples = len(self.fnames)def __getitem__(self,idx):fname = self.fnames[idx]#图片索引赋给fnameimg = cv2.imread(fname)#用cv2读取图片img = self.BGR2RGB(img) #because "torchvision.utils.save_image" use RGBimg = self.transform(img)#转换格式,调用下方的transformreturn imgdef __len__(self):return self.num_samplesdef BGR2RGB(self,img):return cv2.cvtColor(img,cv2.COLOR_BGR2RGB)import glob
import torchvision.transforms as transformsdef get_dataset(root):fnames = glob.glob(os.path.join(root, '*'))# resize the image to (64, 64)# linearly map [0, 1] to [-1, 1]线性转换#transform这个我前面的作业解释的很清楚,#这就不再重复解释了,看过一遍就会觉得很简单transform = transforms.Compose([transforms.ToPILImage(),transforms.Resize((64, 64)),transforms.ToTensor(),transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3) ] )dataset = FaceDataset(fnames, transform)return dataset

3.随机种子

  • 使用random seed函数,方便我们固定随机值,方便多次reproduce
import random
import torch
import numpy as np#随机种子
def same_seeds(seed):torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.np.random.seed(seed)  # Numpy module.random.seed(seed)  # Python random module.torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = True

4.模型


这里 DCGAN 作为 baseline mode。 DCGAN 架构示意图,图中数字仅供参考。

  • 这里大致解释了如何在特征Z的基础上生成图片
  • 实质上就是逆卷积和上采样的过程,在上采样的过程中,会用到一些函数,弥补图像在下采样后的细节损失。
import torch.nn as nn
import torch.nn.functional as Fdef weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:#find() 返回字符串第一次出现的索引,如果没有匹配项则返回-1m.weight.data.normal_(0.0, 0.02)#归一化elif classname.find('BatchNorm') != -1:m.weight.data.normal_(1.0, 0.02)m.bias.data.fill_(0)#偏置填充0#生成器模型
class Generator(nn.Module):"""input (N, in_dim)output (N, 3, 64, 64)"""def __init__(self, in_dim, dim=64):super(Generator, self).__init__()#定义解码层的架构,后面会用到#逆卷积神经网络ConvTranspose2d是对图片进行上采样,图片越来越大,卷积与逆卷积是互相对应的。#其实就是还原图像def dconv_bn_relu(in_dim, out_dim):return nn.Sequential(nn.ConvTranspose2d(in_dim, out_dim, 5, 2,padding=2, output_padding=1, bias=False),nn.BatchNorm2d(out_dim),nn.ReLU())self.l1 = nn.Sequential(nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),nn.BatchNorm1d(dim * 8 * 4 * 4),nn.ReLU())self.l2_5 = nn.Sequential(dconv_bn_relu(dim * 8, dim * 4),dconv_bn_relu(dim * 4, dim * 2),dconv_bn_relu(dim * 2, dim),nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),nn.Tanh())#dataframe.apply(func,axis=0) #表示: 默认(默认axis=0)的情况下将dataframe表中的每一列的每一个元素分别作为实参传入函数func, 然后得出结果返回;self.apply(weights_init)def forward(self, x):y = self.l1(x)y = y.view(y.size(0), -1, 4, 4)y = self.l2_5(y)return y#判别器模型
class Discriminator(nn.Module):"""input (N, 3, 64, 64)output (N, )"""def __init__(self, in_dim, dim=64):super(Discriminator, self).__init__()#这里就是正常的编码器def conv_bn_lrelu(in_dim, out_dim):return nn.Sequential(nn.Conv2d(in_dim, out_dim, 5, 2, 2),nn.BatchNorm2d(out_dim),nn.LeakyReLU(0.2))self.ls = nn.Sequential(nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2),conv_bn_lrelu(dim, dim * 2),conv_bn_lrelu(dim * 2, dim * 4),conv_bn_lrelu(dim * 4, dim * 8),nn.Conv2d(dim * 8, 1, 4),nn.Sigmoid())self.apply(weights_init)        def forward(self, x):y = self.ls(x)y = y.view(-1)return y

5.准备训练

  • 设定超参数。准备数据加载,模型,损失标准,优化设定
    hyperparameters, dataloader, model, loss criterion, optimizer。
import torch
from torch import optim
from torch.autograd import Variable
import torchvision# hyperparameters超参数
batch_size = 64
z_dim = 120#特征维度
lr = 5e-5
n_epoch = 12
save_dir = os.path.join(workspace_dir, 'logs')
#os.path.join语法: os.path.join(path1[,path2[,……]])
#返回值:将多个路径组合后返回
os.makedirs(save_dir, exist_ok=True)#os.makedirs(path),
#他可以一次创建多级目录,哪怕中间目录不存在也能正常的(替你)创建# model
G = Generator(in_dim=z_dim).cuda()
D = Discriminator(3).cuda()
G.train()
D.train()# loss criterion
#BCELoss 是CrossEntropyLoss的一个特例,只用于二分类问题,
criterion = nn.BCELoss()# optimizer
opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))same_seeds(0)
# dataloader (You might need to edit the dataset path if you use extra dataset.)
#这里可以加载自己想加载的数据
dataset = get_dataset(os.path.join(workspace_dir, 'faces'))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  • 查看一张
# 随便选一张查看
import matplotlib.pyplot as plt
plt.imshow(dataset[10].numpy().transpose(1,2,0))

6.训练开始

# for logging
z_sample = Variable(torch.randn(100, z_dim)).cuda()for e, epoch in enumerate(range(n_epoch)):#遍历for i, data in enumerate(dataloader):imgs = dataimgs = imgs.cuda()#图片格式改为适合GPU处理的张量bs = imgs.size(0)#batch.size大小是图片0维的数值""" Train D 训练判别器"""z = Variable(torch.randn(bs, z_dim)).cuda()r_imgs = Variable(imgs).cuda()f_imgs = G(z)#生成器通过特征Z生成的图片#判别器对原始图像和生成的图片进行判别# label        r_label = torch.ones((bs)).cuda()f_label = torch.zeros((bs)).cuda()# disr_logit = D(r_imgs.detach())f_logit = D(f_imgs.detach())# compute lossr_loss = criterion(r_logit, r_label)f_loss = criterion(f_logit, f_label)loss_D = (r_loss + f_loss) / 2# update modelD.zero_grad()loss_D.backward()opt_D.step()""" train G """# leafz = Variable(torch.randn(bs, z_dim)).cuda()f_imgs = G(z)# disf_logit = D(f_imgs)# compute lossloss_G = criterion(f_logit, r_label)# update modelG.zero_grad()loss_G.backward()opt_G.step()# logprint(f'\rEpoch [{epoch+1}/{n_epoch}] {i+1}/{len(dataloader)} Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}', end='')G.eval()f_imgs_sample = (G(z_sample).data + 1) / 2.0filename = os.path.join(save_dir, f'Epoch_{epoch+1:03d}.jpg')torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)print(f' | Save some samples to {filename}.')# show generated image图片的可视化grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)plt.figure(figsize=(10,10))plt.imshow(grid_img.permute(1, 2, 0))plt.show()G.train()if (e+1) % 5 == 0:torch.save(G.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth'))torch.save(D.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth'))

7.使用生成器生成图片

import torch
# load pretrained model加载预训练模型
G = Generator(z_dim)
G.load_state_dict(torch.load(os.path.join(workspace_dir, 'dcgan_g.pth')))
G.eval()
G.cuda()
# generate images and save the result
n_output = 20
z_sample = Variable(torch.randn(n_output, z_dim)).cuda()
imgs_sample = (G(z_sample).data + 1) / 2.0
save_dir = os.path.join(workspace_dir, 'logs')
filename = os.path.join(save_dir, f'result.jpg')
torchvision.utils.save_image(imgs_sample, filename, nrow=10)
# show image
grid_img = torchvision.utils.make_grid(imgs_sample.cpu(), nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

总结

这里的对抗生成网络还是简单一些的,代码还是很容易就看懂的,不过要想实现它,还是需要自己在多敲几下。看懂不代表会用。

李宏毅作业十 Generative Adversarial Network生成对抗网络(代码)相关推荐

  1. 【深度学习】Generative Adversarial Network 生成式对抗网络(GAN)

    文章目录 一.神经网络作为生成器 1.1 什么是生成器? 1.2 为什么需要输出一个分布? 1.3 什么时候需要生成器? 二.Generative Adversarial Network 生成式对抗网 ...

  2. Generative Adversarial Nets 生成对抗网络

    Generative Adversarial Nets 生成对抗网络 论文作者 Yan 跟随论文精读 (bilibili李沐) 同时会训练模型 G,生成模型要对整个数据的分布进行建模,就是想生成 尽量 ...

  3. GAN(Generative Adversarial Nets (生成对抗网络))

    一.GAN 1.应用 GAN的应用十分广泛,如图像生成.图像转换.风格迁移.图像修复等等. 2.简介 生成式对抗网络是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成 ...

  4. Generative Adversarial Networks 生成对抗网络的简单理解

    1. 引言 在对抗网络中,生成模型与判别相竞争,判别模型通过学习确定样本是来自生成模型分布还是原始数据分布.生成模型可以被认为是类似于一组伪造者,试图产生假币并在没有检测的情况下使用它,而判别模型类似 ...

  5. 【GAN ZOO阅读】Generative Adversarial Nets 生成对抗网络 原文翻译 by zk

    Ian J. Goodfellow, Jean Pouget-Abadie ∗ , Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair † ...

  6. 2020李宏毅机器学习笔记-Generative Adversarial Network - Conditional GAN

    目录 摘要 1. Text-to-Image 1.1 Traditional supervised approach 1.2 Conditional GAN 1.3 Conditional GAN - ...

  7. [Python人工智能] 二十九.什么是生成对抗网络GAN?基础原理和代码普及(1)

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CN ...

  8. BEGAN-边界均衡生成对抗网络-代码解读

    当前论文代码 首先注意: 不同点: 该论文的输入是噪音,鉴别器和生成器都是哑铃型结构, 相同点: 输出是一张图片,D都是用真实图像去比对. 已知信息 可见,是从main.py开始训练的.测试的时候,只 ...

  9. 李宏毅作业十二 Transfer Learning(迁移学习)

    系列文章目录 李宏毅作业十 Generative Adversarial Network生成对抗网络(代码) 李宏毅作业九 Anomaly Detection异常检测 李宏毅作业八unsupervis ...

  10. 生成对抗网络(GAN,Generative Adversarial Network)介绍

    生成对抗网络(GAN,Generative Adversarial Network)介绍 flyfish 在无监督学习中,最近的突破有哪些? 看一个GAN的应用 第一张图是用GAN将一副古代女子的画像 ...

最新文章

  1. FD.io/VPP — VPP 的安装部署
  2. 详解:UML类图符号、各种关系说明以及举例
  3. 语音特征提取: MFCC的理解
  4. CDH Kerberos 认证下Kafka 消费方式
  5. Android URI简介
  6. Git master branch has no upstream branch的解决
  7. 请假系统特例规则详细设计
  8. 下载链接|从CAD2004到CAD2022下载安装软件,提升CAD施工图大师一点儿也不难!
  9. C# 简单实现QQ截图功能
  10. 前端开发主流框架整理推荐
  11. 解决最新版搜狗输入法的软键盘快捷键Ctrl + Shift + K和Typora的热键冲突问题
  12. 优秀员工评审表 模板
  13. shader 获取法线_Unity Shader教程 三、法线方向
  14. 优秀logo,最基础的设计技巧(四)
  15. Java Heap - Percolate Up, Percolate Down, and Heapify
  16. java memorystream 包_C#使用MemoryStream类读写内存
  17. 手把手教你如何抵制法国货
  18. 程序员修炼之道(通俗版)——第一章
  19. Linux上ftp的安装
  20. 【无线通信】无线通信系统结构演进(1)

热门文章

  1. (毕业设计资料)基于单片机自行车码表系统设计
  2. java作品欣赏_[Java教程]推荐25个强大的 jQuery 网页布局设计作品欣赏
  3. php避免超卖,thinkphp防止超卖
  4. mybatis case when
  5. JS 的5个不良编码习惯,现在就改掉吧
  6. 左岸读书-你是想读书,还是想读完书?
  7. 左岸语不惊人死不休系列摘录
  8. dojo query 实现Ajax,Dojo Query 详解
  9. [语义分割]CTNet: Context-based Tandem Network for Semantic Segmentation
  10. 8、Flume 日志采集工具