gan网络原理如下:

mnist手写字体实战:

import torch
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from torch import nn
from torch.autograd import Variable
from torch import optim
import ostransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist, batch_size=100, shuffle=True)class Dnet(nn.Module):def __init__(self):super(Dnet, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, padding=2),  # batch, 6, 30,30nn.LeakyReLU(0.2, True),nn.MaxPool2d(2, stride=2),  # batch, 6, 15, 15)self.conv2 = nn.Sequential(nn.Conv2d(6, 12, 3, padding=2),  # batch, 12, 17, 17nn.LeakyReLU(0.2, True),nn.MaxPool2d(2, stride=2)  # batch, 12, 8, 8)self.fc = nn.Sequential(nn.Linear(12 * 8 * 8, 1024),nn.LeakyReLU(0.2, True),nn.Linear(1024, 1),nn.Sigmoid())# x.shape:[100,1,28,28]def forward(self, x):'''x: batch, width, height, channel=1'''x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)               # 将第二次卷积的输出拉伸为一行x = self.fc(x)x = x.squeeze(-1) # x.shape:[100,]return xclass Gnet(nn.Module):def __init__(self):super(Gnet, self).__init__()self.fc = nn.Linear(128, 784)  # batch, 1,28,28self.br = nn.Sequential(nn.BatchNorm2d(1),nn.ReLU(True))self.downsample1 = nn.Sequential(nn.Conv2d(1, 12, 3, stride=1, padding=1),  # batch, 12, 28, 28nn.BatchNorm2d(12),nn.ReLU(True))self.downsample2 = nn.Sequential(nn.Conv2d(12, 6, 3, stride=1, padding=1),  # batch, 6, 28, 28nn.BatchNorm2d(6),nn.ReLU(True))self.downsample3 = nn.Sequential(nn.Conv2d(6, 1, 3, stride=1, padding=1),  # batch, 1, 28, 28nn.Tanh())# x.shape:[100,128]def forward(self, x):x = self.fc(x) # # x.shape:[100,784]x = x.view(x.size(0), 1, 28, 28) # x.shape:[100,1,28,28]x = self.br(x) # x.shape:[100,1,28,28]x = self.downsample1(x) # x.shape:[100,12,28,28]x = self.downsample2(x) # x.shape:[100,6,28,28]x = self.downsample3(x) # x.shape:[100,1,28,28]return xdef to_img(x):y = (x + 1) * 0.5y = y.clamp(0, 1)y = y.view(-1, 1, 28, 28)return yclass Net:def __init__(self):self.dnet = Dnet()self.gnet = Gnet()self.dnet = self.dnet.cuda()self.gnet = self.gnet.cuda()self.Loss = nn.BCELoss()self.d_optimizer = optim.Adam(self.dnet.parameters(), lr=0.0002)self.g_optimizer = optim.Adam(self.gnet.parameters(), lr=0.0002)def forward(self, real_x, fack_x):self.real_d_out = self.dnet(real_x) # 将真样本输入到判别器,得到判别结果self.real_d_outg_out = self.gnet(fack_x) # 将噪声输入生成器产生假样本,g_out.shape:[100,1,28,28]self.g_d_out = net.dnet(g_out.detach()) #将假样本输入判别器,得到判别结果self.g_d_outdef backward(self, pos_y, nega_y, fack_xs):# 以下几行的目地是训练判别器,使得判别器遇到真实样本就给1,遇到假样本就给0;d_out_loss = self.Loss(self.real_d_out, pos_y) # 将真样本判别结果self.real_d_out和真实标签求lossg_d_loss = self.Loss(self.g_d_out, nega_y) # 将假样本判别结果self.g_d_out和假标签求lossself.d_loss = d_out_loss + g_d_lossself.d_optimizer.zero_grad()self.d_loss.backward(retain_graph = True)self.d_optimizer.step()# 以下几行的目的是训练生成器,使得生成器产生的假样本越来越接近真实;self.fack_g_out = self.gnet(fack_xs) # 将噪声输入生成器产生假样本self.fack_g_d_out = self.dnet(self.fack_g_out) #将假样本输入判别器,得到判别结果self.fack_g_d_outself.g_loss = self.Loss(self.fack_g_d_out, pos_y) # 将假样本的判别结果与正标签对比求loss,意思是让假样本越来越接近真实;self.g_optimizer.zero_grad()self.g_loss.backward()self.g_optimizer.step()if __name__ == '__main__':if not os.path.exists('img'):os.mkdir('img')net = Net()for i in range(100):for x, y in dataloader:# x = x.view(x.size(0),-1)real_x = Variable(x).cuda() # x [100,1,28,28] ,real_x:真样本输入fack_x = Variable(torch.randn(100, 128)).cuda() #fack_x:噪声  torch.randn(100, 128)标准正态分布pos_y = Variable(torch.ones(100)).cuda() # 真样本标签nega_y = Variable(torch.zeros(100)).cuda() # 假样本标签fack_xs = Variable(torch.randn(100, 128)).cuda() # fack_xs:噪声net.forward(real_x, fack_x) # 前向推理,输入real_x, fack_x,即真样本输入和噪声net.backward(pos_y, nega_y, fack_xs) #反向传播过程img = to_img(net.fack_g_out.data)D_Accuracy = ((net.real_d_out.mean() + 1 - net.fack_g_d_out.mean()) / 2).item()print(net.d_loss.item(), net.g_loss.item(), D_Accuracy)save_image(img, './img/fake_images-{}.png'.format(i + 1))

gan网络原理(通俗)+minist手写字体实战相关推荐

  1. 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

    深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...

  2. 《TensorFlow实例一 MINIST手写字体识别》

    Ubuntu python3 TensorFlow实例:使用RNN算法实现对MINST-data数字集识别,最终识别准确率达96.875% PS:小白一个,初级阶段,从调试到实现,step by st ...

  3. pytorch深度学习神经网络实现手写字体识别

    利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...

  4. 基于CNN的MINIST手写数字识别项目代码以及原理详解

    文章目录 项目简介 项目下载地址 项目开发软件环境 项目开发硬件环境 前言 一.数据加载的作用 二.Pytorch进行数据加载所需工具 2.1 Dataset 2.2 Dataloader 2.3 T ...

  5. 计算机视觉ch8 基于LeNet的手写字体识别

    文章目录 原理 LeNet的简单介绍 Minist数据集的特点 Python代码实现 原理 卷积神经网络参考:https://www.cnblogs.com/chensheng-zhou/p/6380 ...

  6. 【干货】JDK动态代理的实现原理以及如何手写一个JDK动态代理

    动态代理 代理模式是设计模式中非常重要的一种类型,而设计模式又是编程中非常重要的知识点,特别是在业务系统的重构中,更是有举足轻重的地位.代理模式从类型上来说,可以分为静态代理和动态代理两种类型. 在解 ...

  7. 「zi2zi」:用AI生成自己的手写字体

    导读 如果想要自己做一套字体,无论是电脑软件FontCreator还是网站flexifont都为我们带来了极大的便利. 但是最低的国标字体数量近7000个,若采用传统的方法则需要手写相同数量的汉字,这 ...

  8. 识别手写字体app_我如何构建手写识别器并将其运送到App Store

    识别手写字体app 从构建卷积神经网络到将OCR部署到iOS (From constructing a Convolutional Neural Network to deploying an OCR ...

  9. 基于kNN的手写字体识别——《机器学习实战》笔记

    看完一节<机器学习实战>,算是踏入ML的大门了吧!这里就详细讲一下一个demo:使用kNN算法实现手写字体的简单识别 kNN 先简单介绍一下kNN,就是所谓的K-近邻算法: [作用原理]: ...

最新文章

  1. 多线程:Callable
  2. IDEA 创建 SpringCloud项目-多项目方式
  3. LeetCode 295. 数据流的中位数(大小堆)
  4. 软件测试面试 (二) 如何测试网页的登录页面
  5. ExtJs2.0学习系列(1)--Ext.MessageBox
  6. SpringBoot集成Redis缓存
  7. 图两点间的最短路径,所有路径算法C语言实现
  8. 计算机ip保留地址,ip地址显示为保留地址怎么解决
  9. BoundsChecker用法
  10. java发展观_科学发展观的第一要义是以人为本。
  11. 多图片拼图怎么操作?这个方法不要错过
  12. 解决COVID-19的7个开放硬件项目
  13. 机器学习作业-FOGS: 基于学习图的一阶梯度监督交通流预测
  14. Linux粘滞位(粘着位)
  15. 蓝墨云班课与中职计算机课,蓝墨云环境下中职《计算机应用基础》的对分课堂教学研究...
  16. vue axios传递FormData填坑,headers不显示,后台报错等等问题
  17. 普通用户可以申请华为鸿蒙系统吗,鸿蒙OS系统普通用户有申请成功的吗?
  18. Github标星5.3K,进阶学习工作最全指南
  19. GPU版本安装Pytorch教程最新方法
  20. MacOS Ventura 13.2.1 (22D68) 正式版带 OC 0.8.9 and winPE 双分区原版黑苹果镜像

热门文章

  1. 测试覆盖率是软件测试的重要组成部分?当然是,必须是啊!
  2. opencv 常用脚本合集
  3. win10系统/用anaconda安装pytorch/再把虚拟环境配到pycharm的流程
  4. 云计算流派战争:硬件出身终失意,他们只是太老了
  5. CKEditor4.0行距插件
  6. WebRTC 混音分析
  7. 计算机的由来、组成及其功能
  8. Android中的GridView反复调用getView和getCount,并且getView中的position的值几乎都是0
  9. 【闪电侠学netty】第4章 服务端启动流程
  10. await和then的区别详解