【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

  • 概述
  • GAN 网络结构
  • GAN 训练流程
  • 模型详解
    • 生成器
    • 判别器

概述

GAN (Generative Adversarial Network) 即生成对抗网络. GAN 网络包括一个生成器 (Generator) 和一个判别器 (Discriminator). GAN 可以自动提取特征, 并判断和优化.

GAN 网络结构

生成器 (Generator) 输入一个向量, 输出手写数字大小的像素图像.

判别器 (Discriminator) 输入图片, 判断图片是来自数据集还是来自生成器的, 输出标签 (Real / Fake)

GAN 训练流程


第一阶段:

  • 固定判别器, 训练生成器: 使得生成器的技能不断提升, 骗过判别器

第二阶段:

  • 固定生成器, 训练判别器: 使得判别器的技能不断提升, 生成器无法骗过判别器

然后:

  • 循环第一阶段和第二阶段, 使得生成器和判别器都越来越强

模型详解

生成器

class Generator(nn.Module):"""生成器"""def __init__(self, latent_dim, img_shape):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):"""block:param in_feat: 输入的特征维度:param out_feat: 输出的特征维度:param normalize: 归一化:return: block"""layers = [nn.Linear(in_feat, out_feat)]# 归一化if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))# 激活layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(# [b, 100] => [b, 128]*block(latent_dim, 128, normalize=False),# [b, 128] => [b, 256]*block(128, 256),# [b, 256] => [b, 512]*block(256, 512),# [b, 512] => [b, 1024]*block(512, 1024),# [b, 1024] => [b, 28 * 28 * 1] => [b, 784]nn.Linear(1024, int(np.prod(img_shape))),# 激活nn.Tanh())def forward(self, z, img_shape):# [b, 100] => [b, 784]img = self.model(z)# [b, 784] => [b, 1, 28, 28]img = img.view(img.size(0), *img_shape)# 返回生成的图片return img

网络结构:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1                  [-1, 128]          12,928LeakyReLU-2                  [-1, 128]               0Linear-3                  [-1, 256]          33,024BatchNorm1d-4                  [-1, 256]             512LeakyReLU-5                  [-1, 256]               0Linear-6                  [-1, 512]         131,584BatchNorm1d-7                  [-1, 512]           1,024LeakyReLU-8                  [-1, 512]               0Linear-9                 [-1, 1024]         525,312BatchNorm1d-10                 [-1, 1024]           2,048LeakyReLU-11                 [-1, 1024]               0Linear-12                  [-1, 784]         803,600Tanh-13                  [-1, 784]               0
================================================================
Total params: 1,510,032
Trainable params: 1,510,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 5.76
Estimated Total Size (MB): 5.82
----------------------------------------------------------------

判别器

class Discriminator(nn.Module):"""判断器"""def __init__(self, img_shape):super(Discriminator, self).__init__()self.model = nn.Sequential(# 就是个线性回归nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):# 压平img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity

网络结构:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1                  [-1, 512]         401,920LeakyReLU-2                  [-1, 512]               0Linear-3                  [-1, 256]         131,328LeakyReLU-4                  [-1, 256]               0Linear-5                    [-1, 1]             257Sigmoid-6                    [-1, 1]               0
================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.04
Estimated Total Size (MB): 2.05

【万物皆可 GAN】生成对抗网络生成手写数字 Part 1相关推荐

  1. 【DCGAN】生成对抗网络,手写数字识别

    基于paddle,aistudio的DCGAN 主要用于记录自己学习经历. 1   导入必要的包 import os import random import paddle import paddle ...

  2. 生成对抗网络生成多维数据集_生成没有数据集的新颖内容

    生成对抗网络生成多维数据集 介绍(Introduction) GAN architecture has been the standard for generating content through ...

  3. 利用生成对抗网络生成海洋塑料合成图像

    问题陈述 过去十年来,海洋塑料污染一直是气候问题的首要问题.海洋中的塑料不仅能够通过勒死或饥饿杀死海洋生物,而且也是通过捕获二氧化碳使海洋变暖的一个主要因素. 近年来,非营利组织海洋清洁组织(Ocea ...

  4. pytorch生成对抗网络生成动漫图像

    代码地址:pytorch实战,使用生成对抗网络生成动漫图像 dataset from torchvision import transforms from torch.utils.data impor ...

  5. CNN网络实现手写数字(MNIST)识别 代码分析

    CNN网络实现手写数字(MNIST)识别 代码分析(自学用) Github代码源文件 本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别 #导入需要的包 import num ...

  6. 基于Python的BP网络实现手写数字识别

    资源下载地址:https://download.csdn.net/download/sheziqiong/86790047 资源下载地址:https://download.csdn.net/downl ...

  7. 基于改进型生成对抗网络生成异构故障样本的方法

    文章地址:A Modified Generative Adversarial Network for Fault Diagnosis in High-Speed Train Components wi ...

  8. 生物神经网络与机器学习的碰撞,Nature论文提出DNA试管网络识别手写数字

    近日,来自加州理工学院的研究人员开发出一种由 DNA 制成的新型人工神经网络.该网络解决了一个经典的机器学习问题:正确识别手写数字.该项研究中,研究者用了 36 个手写数字 6 和 7 作为测试例子, ...

  9. “万物皆可Seq2Seq” | 忠于原文的T5手写论文翻译

    <Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer> 摘要 / Abstr ...

最新文章

  1. Unable to instantiate Action,
  2. 非常好的JavaScript学习资源推荐
  3. 在MAC系统的eclipse里打开android sdk manager
  4. 响应式系统reactive system初探
  5. 什么是 SAP Spartacus FacadeFactoryService 中的 Resolver
  6. Hibernate bean 对象配制文件
  7. 如何简单区分web前后端与MVC框架
  8. java mysql websocket_javaweb-ajax-websocket-mysql
  9. Linux make menuconfig打开失败
  10. Linux下编译(安装)程序、编译库整理
  11. jquery 添加可操作,编辑不可操作
  12. hadoop ha环境下的datanode启动报错java.lang.NumberFormatException: For input string: 10m
  13. CCF NOI1097 数列
  14. 全球虚拟化服务器排行榜,全球云服务器厂商排名
  15. vue 加headers_vue上传图片设置headers表头信息
  16. 计算机英语六级时间,计算机一级考试_6月英语六级报名时间
  17. 冰冻三尺,非一日之寒。数据解析——bs4
  18. 19年程序员薪酬报告:平均年薪超70万,40岁后普遍遭遇收入天花板
  19. 梯度下降算法_梯度下降算法的工作原理
  20. HTML onmouseover, onmouseout , onmousemove 事件属性

热门文章

  1. ue4 改变枢轴位置_houdini+ue4道路(2):思路
  2. 下载文件时,文件名称乱码问题解决方法
  3. 一个女孩从十楼跳下所看到的...
  4. 2023 404 收音机动画HTML源码
  5. 【C++】忽略逗号或者自定义符号输入
  6. Python对字典列表多维数组排序
  7. Verilog中generate语句的用法
  8. 机械转行哪些行业容易上手?机械转行it是个坑嘛?
  9. MD5加密概述,原理以及实现
  10. c语言crypt,通过C中的文件通过crypt函数进行身份验证(分段错误)