目录

DCGAN理论讲解

DCGAN的改进:

DCGAN的设计技巧

DCGAN纯代码实现

导入库

导入数据和归一化

定义生成器

定义鉴别器

初始化和 模型训练

运行结果


DCGAN理论讲解

DCGAN也叫深度卷积生成对抗网络,DCGAN就是将CNN与GAN结合在一起,生成模型和判别模型都运用了深度卷积神经网络的生成对抗网络。

DCGAN将GAN与CNN相结合,奠定了之后几乎所有GAN的基本网络架构。DCGAN极大地提升了原始GAN训练的稳定性以及生成结果的质量

DCGAN主要是在网络架构上改进了原始的GAN,DCGAN的生成器与判别器都利用CNN架构替换了原始GAN的全连接网络,主要改进之处有如下几个方面,

DCGAN的改进:

(1)DCGAN的生成器和判别器都舍弃了CNN的池化层,判别器保留CNN的整体架构,生成器则是将卷积层替换成了反卷积层。

(2)在判别器和生成器中使用了BatchNormalization(BN)层,这里有助于处理初始化不良导致的训练问题,加速模型训练提升训练的稳定性。要注意,在生成器的输出层和判别器的输入层不使用BN层。

(3)在生成器中除输出层使用Tanh()激活函数,其余层全部使用Relu激活函数,在判别器中,除输出层外所有层都使用LeakyRelu激活函数,防止梯度稀疏

自己画的,凑合着看吧/(*/ω\*)捂脸/

DCGAN的设计技巧

一,取消所有pooling层,G网络中使用转置卷积进行上采样,D网络中加入stride的卷积(为防止梯度稀疏)代替pooling

二,去掉FC层(全连接),使网络变成全卷积网络

三,G网络中使用Relu作为激活函数,最后一层用Tanh

四,D网络中使用LeakyRelu激活函数

五,在generator和discriminator上都使用batchnorm,解决初始化差的问题,帮助梯度传播到每一层,防止generator把所有的样本都收敛到同一点。直接将BN应用到所有层会导致样本震荡和模型不稳定,因此在生成器的输出层和判别器的输入层不使用BN层,可以防止这种现象。

六,使用Adam优化器

七,参数设置参考:LeakyRelu的斜率是0.2;Learing rate = 0.0002;batch size是128.

DCGAN纯代码实现

导入库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim #优化
import numpy as np
import matplotlib.pyplot as plt #绘图
import torchvision #加载图片
from torchvision import transforms #图片变换

导入数据和归一化

#对数据做归一化(-1,1)
transform=transforms.Compose([#将shanpe为(H,W,C)的数组或img转为shape为(C,H,W)的tensortransforms.ToTensor(), #转为张量并归一化到【0,1】;数据只是范围变了,并没有改变分布transforms.Normalize(mean=0.5,std=0.5)#数据归一化处理,将数据整理到[-1,1]之间;可让数据呈正态分布
])
#下载数据到指定的文件夹
train_ds = torchvision.datasets.MNIST('data/',train=True,transform=transform,download=True)
#数据的输入部分
train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)

定义生成器

使用长度为100的noise作为输入,也可以使用torch.randn(batchsize,100,1,1)

class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.linear1 = nn.Linear(100,256*7*7)self.bn1=nn.BatchNorm1d(256*7*7)self.deconv1 = nn.ConvTranspose2d(256,128,kernel_size=(3,3),stride=1,padding=1)  #生成(128,7,7)的二维图像self.bn2=nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128,64,kernel_size=(4,4),stride=2,padding=1)  #生成(64,14,14)的二维图像self.bn3=nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64,1,kernel_size=(4,4),stride=2,padding=1)  #生成(1,28,28)的二维图像def forward(self,x):x=F.relu(self.linear1(x))x=self.bn1(x)x=x.view(-1,256,7,7)x=F.relu(self.deconv1(x))x=self.bn2(x)x=F.relu(self.deconv2(x))x=self.bn3(x)x=torch.tanh(self.deconv3(x))return x

定义鉴别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.conv1 = nn.Conv2d(1,64,kernel_size=3,stride=2)self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6,1)def forward(self,x):x= F.dropout2d(F.leaky_relu(self.conv1(x)))x= F.dropout2d(F.leaky_relu(self.conv2(x)) )  #(batch,128,6,6)x = self.bn(x)x = x.view(-1,128*6*6) #展平x = torch.sigmoid(self.fc(x))return x

初始化和 模型训练

#设备的配置
device='cuda' if torch.cuda.is_available() else 'cpu'
#初化生成器和判别器把他们放到相应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)
#交叉熵损失函数
loss_fn = torch.nn.BCELoss()
#训练器的优化器
d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5)
#训练生成器的优化器
g_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-4)
def generate_and_save_images(model,epoch,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4,4))for i in range(prediction.shape[0]):plt.subplot(4,4,i+1)plt.imshow((prediction[i]+1)/2,cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()
test_input = torch.randn(16,100 ,device=device) #16个长度为100的随机数
D_loss = []
G_loss = []
#训练循环
for epoch in range(30):#初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(train_dl.dataset) #返回批次数#对数据集进行迭代for step,(img,_) in enumerate(train_dl):img =img.to(device) #把数据放到设备上size = img.shape[0] #img的第一位是size,获取批次的大小random_seed = torch.randn(size,100,device=device)#判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化d_optimizer.zero_grad()#梯度归零#判别器对于真实图片产生的损失real_output = dis(img) #判别器输入真实的图片,real_output对真实图片的预测结果d_real_loss = loss_fn(real_output,torch.ones_like(real_output,device=device))d_real_loss.backward()#计算梯度#在生成器上去计算生成器的损失,优化目标是判别器上的参数generated_img = gen(random_seed) #得到生成的图片#因为优化目标是判别器,所以对生成器上的优化目标进行截断fake_output = dis(generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了#判别器在生成图像上产生的损失d_fake_loss = loss_fn(fake_output,torch.zeros_like(fake_output,device=device))d_fake_loss.backward()#判别器损失disc_loss = d_real_loss + d_fake_loss#判别器优化d_optimizer.step()#生成器上损失的构建和优化g_optimizer.zero_grad() #先将生成器上的梯度置零fake_output = dis(generated_img)gen_loss = loss_fn(fake_output,torch.ones_like(fake_output,device=device))  #生成器损失gen_loss.backward()g_optimizer.step()#累计每一个批次的losswith torch.no_grad():D_epoch_loss +=disc_lossG_epoch_loss +=gen_loss#求平均损失with torch.no_grad():D_epoch_loss /=countG_epoch_loss /=countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)generate_and_save_images(gen,epoch,test_input)print('Epoch:',epoch)

运行结果

因篇幅有限,这里展示第一张和最后一张,这里我训练了30个epoch,有条件的可以多训练几次,训练越多效果越明显哦


希望我的文章能对你有所帮助。欢迎

DCGAN理论讲解及代码实现相关推荐

  1. 扩展卡尔曼滤波(EKF)理论讲解与实例(matlab、python和C++代码)

    扩展卡尔曼滤波(EKF)理论讲解与实例(matlab.python和C++代码) 文章目录 扩展卡尔曼滤波(EKF)理论讲解与实例(matlab.python和C++代码) 理论讲解 KF和EKF模型 ...

  2. MSCKF理论推导与代码解析

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 在SLAM后端中,主要有两种主流方法用于优化:基于滤波的方法和基于非线性的方法.基于滤波的方法主要有M ...

  3. 机器学习和特征工程理论与python代码实现 晓物智联

    文章来源于:http://www.52phm.cn/blog/detail/23 最初来源于本人的kesci专栏 课题:特征工程理论及代码实现 日期:2019.9.21 作者:小知同学 描述:本篇比较 ...

  4. 无迹(损)卡尔曼滤波(UKF)理论讲解与实例

    无迹(损)卡尔曼滤波(UKF)理论讲解与实例 文章目录 无迹(损)卡尔曼滤波(UKF)理论讲解与实例 理论讲解 模型对比 UT变换 UKF算法步骤 预测部分 更新部分 应用实例 CTRV模型 预测处理 ...

  5. C#冒泡排序原理讲解及代码块

    C#冒泡排序原理讲解及代码块 一.冒泡排序理论 (1)基本概念由于在排序过程中总是小数往前放,大数往后放,相当于气泡往上升,所以称作冒泡排序.冒泡排序的时间复杂度为O(n*n). (2)逻辑分析依次比 ...

  6. 【总结】关于玻尔兹曼机(BM)、受限玻尔兹曼机(RBM)、深度玻尔兹曼机(DBM)、深度置信网络(DBN)理论总结和代码实践

    近期学习总结 前言 玻尔兹曼机(BM) 波尔兹曼分布推导过程 吉布斯采样 受限玻尔兹曼机(RBM) 能量函数 CD学习算法 代码实现受限玻尔兹曼机 深度玻尔兹曼机(DBM) 代码实现深度玻尔兹曼机 深 ...

  7. ML之FE:特征工程中常用的五大数据集划分方法(特殊类型数据分割,如时间序列数据分割法)讲解及其代码

    ML之FE:特征工程中常用的五大数据集划分方法(特殊类型数据分割,如时间序列数据分割法)讲解及其代码 目录 特殊类型数据分割 5.1.时间序列数据分割TimeSeriesSplit 特殊类型数据分割 ...

  8. 基于python的随机森林回归实现_随机森林理论与python代码实现

    1,初品随机森林 随机森林,森林就是很多决策树放在一起一起叫森林,而随机体现在数据集的随机采样中和特征的随机选取中,具体下面再讲.通俗的说随机森林就是建立多颗决策树(CART),来做分类(回归),以多 ...

  9. 【MATLAB】无人驾驶车辆的模型预测控制技术(精简讲解和代码)【运动学轨迹规划】

    文章目录 0.友情链接 1.引言 2.预测模型 3.滚动优化 3.1.线性化 3.2.UrU_rUr​的求取 3.3.离散化与序列化 3.4.实现增量控制 4.仿真示例 0.友情链接 B站链接1-北京 ...

最新文章

  1. Java新手看招 常用开发工具介绍
  2. mysql max_allowed_packet 参数 限制接受的数据包大小
  3. Minio分布式集群部署注意事项
  4. MYSQL数据库索引设计的原则
  5. 教室信息管理系统mysql_教师信息管理系统(方式一:数据库为oracle数据库;方式二:存储在文件中)...
  6. CVPR 2019 | Adobe提出新型超分辨率方法:用神经网络迁移参照图像纹理
  7. 数据科学 IPython 笔记本 9.1 NumPy
  8. java fastjson vector转为json_java对象与Json字符串之间的转化(fastjson)
  9. 黑马程序员传智播客 python生成器 学习笔记
  10. 服务器配置文档模板,服务器配置模板
  11. 安装金山词霸2007
  12. houdini flowmap
  13. FMCW毫米波雷达中CFAR研究初探(附Python代码)
  14. 惠普服务器关机自动重启,HP笔记本关机自动重启的解决办法
  15. 音频加速 foobar_如何使用Foobar2000将音频CD翻录到FLAC
  16. 解决台式机前耳机插孔没有声音
  17. ps cc2019 安装教程
  18. 判断当前时间为本月的第几周,本周的第几天
  19. 《OD学Oozie》20160807Oozie
  20. 七种寻址方式(基址加变址寻址方式)

热门文章

  1. dedecms cookies泄漏导致SQL漏洞 article_add.php 的解决方法
  2. javascript实现生成并下载txt文件
  3. mysql序列表,自增序列生成合同编号
  4. php中isnumeric,$.isNumeric(value)
  5. sklearn.decomposition.LatentDirichletAllocation接口详解
  6. 商业模式得到肯定 投资ofo朱啸虎信心满满
  7. QT实时视频播放界面设计
  8. SpringBoot Dao层常用注解
  9. 计算两个人相爱的天数
  10. Flutter 新闻客户端 - 17 headless strapi centos 发布部署 + jmeter 压测