import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import timetime_start = time.time()# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),  # 会做0-1归一化,也会channels, height, widthtransforms.Normalize((0.5,), (0.5,))
])train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform)
dataLoader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.view(-1, 28, 28, 1)return img# 判别器网络定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return xdevice = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失函数
loss_fn = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i] + 1)/2)plt.axis('off')plt.show()test_input = torch.randn(16, 100, device=device)# GAN训练
D_loss = []
G_loss = []# 训练循环
for epoch in range(20):d_epoch_loss = 0g_epoch_loss = 0count = len(dataLoader)  # 返回批次数for step, (img, _) in enumerate(dataLoader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)# 判别器的损失与优化d_optimizer.zero_grad()real_output = dis(img)  # 对判别器输入真实图片, real_output是对真实图片的判断结果d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 判别器在真实图像上的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())  # 判别器输入生成的图片,fake_output对生成图片的预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 判别器在生成图像上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 生成器的损失与优化g_optimizer.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 生成器的损失g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print("Epoch:", epoch)gen_img_plot(gen, test_input)time_end = time.time()
print("花费总时间为:", time_end - time_start)

GAN实战——生成手写字体相关推荐

  1. 用c语言实现knn算法要有训练集和测试集,KNN算法实战:手写字体识别

    我们已经知道手写字体数据集是一个8×8的矩阵,共有64个特征.让我们看一下K最近邻算法对手写字体数据集处理的效果. 1) 导入相关包 这里我们将用到 datasets 中的手写字体数据,使用 trai ...

  2. #21天学习挑战赛—深度学习实战100例#——生成手写字体

    ​ ​ 活动地址:CSDN21天学习挑战赛 本文为

  3. 生成式对抗网络GAN之实现手写字体的生成(基于keras Tensorflow2.0实现)详细分析训练过程和代码

  4. 使用WGAN生成手写字体

    import sys; sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages") import nump ...

  5. Pytorch实现GAN之生成手写数字图片

    1.导入所需库 import torch import torch.optim as optim import torch.nn as nn import torch.nn.functional as ...

  6. gan网络原理(通俗)+minist手写字体实战

     gan网络原理如下: mnist手写字体实战: import torch import torchvision from torchvision import transforms from tor ...

  7. 《Gans in Action》第三章 用GAN生成手写数字

    此为<Gans in Action>(对抗神经网络实战)第三章读书笔记 Chapter 3. Your first GAN: Generating handwritten digits 用 ...

  8. GAN变种ACGAN利用手写数字识别mnist生成手写数字

    1.摘要 本文主要讲解:GAN变种ACGAN利用手写数字识别mnist数据集进行训练,最终生成手写数字图片 主要思路: Initialize generator and discriminator I ...

  9. 「zi2zi」:用AI生成自己的手写字体

    导读 如果想要自己做一套字体,无论是电脑软件FontCreator还是网站flexifont都为我们带来了极大的便利. 但是最低的国标字体数量近7000个,若采用传统的方法则需要手写相同数量的汉字,这 ...

最新文章

  1. TensorFlow基础5-可训练变量和自动求导机制
  2. 简单删除我的电脑里的wps云文档图标
  3. python替换缺失值_详解Pandas 处理缺失值指令大全
  4. 获取SQL SERVER某个数据库中所有存储过程的参数
  5. Mac下和Windows下UnrealEngine 4体验对比
  6. 在危机中呈现转机的网络管理
  7. RocketMQ 源码之 异步和同步请求 以及异步的回调 是怎么做到的
  8. e服务器系统可以用PE来装吗,U盘安装系统有哪些方式可以启动WinPE
  9. 软件著作权申请流程和费用
  10. 软件测试周刊(第15期):将军赶路 不追小兔
  11. 西雅图本地创业公司大盘点
  12. 思科模拟器 --- 路由器RIP动态路由配置
  13. Latent semantic analysis (LSA)
  14. anaconda安装、部署、卸载——Mac
  15. STM32之HAL库的Bootloader跳转到APP
  16. 网易企业邮箱删除的的邮件服务器,网易企业邮箱进行全选删除移动操作步骤说明...
  17. 2022年国内运营商最全号段,联通、移动、电信、广电四大运营商
  18. Kuick:创业大军中脱颖而出的少数派
  19. 深度学习框架Keras的安装
  20. 手机java短片_多媒体、JAVA游戏

热门文章

  1. 仿魅族手机消息通知效果
  2. bwz是什么缩写_bwz是什么意思,bwz缩写代表什么意思,bwz是什么含义
  3. 疯狂英语之突破功能-视频
  4. 【图像分割】UNet 和 UNet++
  5. 云虚拟主机mysql下载_云虚拟主机如何安装mysql
  6. 游戏介绍——钢琴块2
  7. firebase到底怎么用android
  8. SSM框架-spirng、springboot、mybatis
  9. 微信新消息,任务栏一闪一闪问题处理
  10. 让AI 作画更快一点