简述

之前认真学习了网上的一份,代码做了很详细的笔记。
【Gans入门】Pytorch实现Gans代码详解【70+代码】

但是上面的任务只是画一条在一定区间下的曲线。
这里对这个进行迁移,到可以进行图像的生成。

图像的很多数据都没有,但是突然想到在sklearn上的digits是一个非常简单的图片。
这里我想到之前的一份笔记
sklearn学习(一)

这里会使用sklearn自带的小数据来做训练
目标是让神经网络自己学会生成数字。

任务描述

为了让神经网络操作更简单。这里的输入数据只会选择特定数值的数字图片数据。然后丢给对抗生成神经网络学习。让其中的生成器学会如何生成手写数字。

下面是选择用数值1的生成过程

其实可以发现其实是有点这样的感觉了。

下面的这个是让它学习数字0的效果

可能是由于数字0的细节更粗糙一点,所以,可以发现,我们认为这个0生成的更好。(数字1和数字4其实是有点像的,所以会有点问题,还有这是因为图片像素有点低

代码详解

导入包

  • torch,numpy这些都是数据处理过程中需要的包
  • matplotlib为了画图
  • sklearn主要是为了它本身带的数据
  • random主要是为了选择标准数据更具有随机性
  • os,shutil,imageio这三个库是为了画出gif动态图
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageio

创建临时文件夹

PNGFILE = './png/'
if not os.path.exists(PNGFILE):os.mkdir(PNGFILE)
else:shutil.rmtree(PNGFILE)os.mkdir(PNGFILE)

这里会创建一个临时的文件夹png,会把中途生成的那些图片都存在这,然后我就可以用这些png来生成gif文件

模型参数

  • BATCH_SIZE这个参数表示每次用多少的数据来进行考量。(数值多的话模型进化的会稍微快点)
  • LR_GLR_D表示两个模型的学习率
  • N_IDEAS:启发式因子(生成函数的初始层的节点数)。因为我们要操作的节点数量会特别大(特别是图像问题,但是如果输入节点过于大的话,会需要大量的计算资源。所以用小一点的这个基本够用就行了)
  • target_num :表示的是想要生成的数字。由于数据集中只有(0到9)所以,这里也只能取0到9。
  • image_max表示图片像素点的最大值,这个一开始我用到了,但是后来我修改了代码之后,就用不到了。
  • ART_COMPONENTS:像素点数量(其实本质上跟前一个版本的参考节点数都是一样的)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001  # learning rate for generator
LR_D = 0.00001  # learning rate for discriminator
N_IDEAS = 6  # think of this as number of ideas for generating an art work (Generator)
target_num = 0 # target Numberdigits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]
image_max = max(data.reshape((-1,)))
ART_COMPONENTS = data.shape[-1]  # it could be total point G can draw in the canvas

标准数据

这个函数本质上,这个区间上选BATCH_SIZE个标准数据。
但是,random.sample只能输入的是list所以需要先把data转成list,但是转出来的list又不能直接变成torch中的Tensor,这里需要再转成ndarray,之后再转成Tensor,但是要注意在后面加一个.float()函数的操作。

def artist_works():  # painting from the famous artist (real target)return torch.from_numpy(np.array(random.sample(list(data), BATCH_SIZE))).float()

构建模型

生成器模型,但是Linear转成的数据是有可能有负数的数据的,但是作为图片肯定是不可以有这样的数据的。因为数据一定是需要为大于等于0的数据。

所以搭建的这个模型最后一定要加一个ReLU()这样的类似的,来保证没有0的情况。

G = nn.Sequential(  # Generatornn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideasnn.ReLU(),
)D = nn.Sequential(  # Discriminatornn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(),  # tell the probability that the art work is made by artist
)

构建最优化的模型

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

迭代优化

这跟之前的是类似的。

for step in range(10000):artist_paintings = artist_works()  # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideasG_paintings = G(G_ideas)  # fake painting from G (random ideas)prob_artist0 = D(artist_paintings)  # D try to increase this probprob_artist1 = D(G_paintings)  # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True)  # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward(retain_graph=True)opt_G.step()

画图并保存

    if step % 100 == 0:  # plottingplt.cla()tempdata = G_paintings[0].detach().numpy()tempdata = tempdata.reshape((8, 8))plt.imshow(tempdata, cmap=plt.cm.gray_r)# plt.draw()plt.savefig(PNGFILE + '%d.png' % times)filedatalist.append(PNGFILE + '%d.png' % times)times += 1plt.pause(0.01)

生成gif

generated_images = []
for png_path in filedatalist:generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)

全部代码

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageioPNGFILE = './png/'
if not os.path.exists(PNGFILE):os.mkdir(PNGFILE)
else:shutil.rmtree(PNGFILE)os.mkdir(PNGFILE)# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001  # learning rate for generator
LR_D = 0.00001  # learning rate for discriminator
N_IDEAS = 6  # think of this as number of ideas for generating an art work (Generator)
target_num = 0 # target Numberdigits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]
image_max = max(data.reshape((-1,)))
ART_COMPONENTS = data.shape[-1]  # it could be total point G can draw in the canvasdef artist_works():  # painting from the famous artist (real target)return torch.from_numpy(np.array(random.sample(list(data), BATCH_SIZE))).float()G = nn.Sequential(  # Generatornn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideasnn.ReLU(),
)D = nn.Sequential(  # Discriminatornn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(),  # tell the probability that the art work is made by artist
)opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
times = 0filedatalist = []for step in range(10000):artist_paintings = artist_works()  # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideasG_paintings = G(G_ideas)  # fake painting from G (random ideas)prob_artist0 = D(artist_paintings)  # D try to increase this probprob_artist1 = D(G_paintings)  # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True)  # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward(retain_graph=True)opt_G.step()if step % 100 == 0:  # plottingplt.cla()tempdata = G_paintings[0].detach().numpy()tempdata = tempdata.reshape((8, 8))plt.imshow(tempdata, cmap=plt.cm.gray_r)# plt.draw()plt.savefig(PNGFILE + '%d.png' % times)filedatalist.append(PNGFILE + '%d.png' % times)times += 1plt.pause(0.01)generated_images = []
for png_path in filedatalist:generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)

【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)相关推荐

  1. 可下载:60分钟入门PyTorch(中文翻译全集)

    来源:机器学习初学者本文约9500字,建议阅读20分钟官网教程翻译:60分钟入门PyTorch(全集) 前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute ...

  2. 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

  3. 60分钟快速入门PyTorch

    点击上方"算法猿的成长",关注公众号,选择加"星标"或"置顶" 总第 136 篇文章,本文大约 26000 字,阅读大约需要 60 分钟 P ...

  4. 60分钟快速入门 PyTorch

    PyTorch 是由 Facebook 开发,基于 Torch 开发,从并不常用的 Lua 语言转为 Python 语言开发的深度学习框架,Torch 是 TensorFlow 开源前非常出名的一个深 ...

  5. 快速入门PyTorch(2)--如何构建一个神经网络

    2019 第 43 篇,总第 67 篇文章 本文大约 4600 字,阅读大约需要 10 分钟 快速入门 PyTorch 教程第二篇,这篇介绍如何构建一个神经网络.上一篇文章: 快速入门Pytorch( ...

  6. 使用Pytorch实现手写数字识别

    使用Pytorch实现手写数字识别 1. 思路和流程分析 流程: 准备数据,这些需要准备DataLoader 构建模型,这里可以使用torch构造一个深层的神经网络 模型的训练 模型的保存,保存模型, ...

  7. 60分钟入门PyTorch,官方教程手把手教你训练第一个深度学习模型(附链接)

    来源:机器之心 本文约800字,建议阅读5分钟. 本文介绍了官方教程入门PyTorch的技巧训练. 近期的一份调查报告显示:PyTorch 已经力压 TensorFlow 成为各大顶会的主流深度学习框 ...

  8. 【深度学习】新人如何入门Pytorch的路线?有哪些资源推荐?

    初学者如何入门Pytorch,看看过来人的建议 作者:范星.xfanplus 链接:https://www.zhihu.com/question/55720139/answer/294449487 高 ...

  9. 【深度学习】翻译:60分钟入门PyTorch(三)——神经网络

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

最新文章

  1. 四种JOIN简单实例
  2. Bash shell
  3. 数据结构特性解析 (四)LinkedList
  4. Java面试中与源码有关的问题分享
  5. 运行中的Nginx进程间的关系
  6. 优秀学生专栏——李浩然
  7. 关于我曾经做过的一个商业社区的ui框架
  8. 创业版上市与SAP管理软件系统的关系
  9. IOS学习笔记07---C语言函数-scanf函数
  10. c语言答辩中期报告,安徽工程大学毕业设计(论文)中期检查总结
  11. 前端埋点方法解析及优缺点分析
  12. jquery easyui 输入框 禁止输入负数 设置属性data-options=min:0,required:true
  13. Edwin 的基本使用
  14. STM32:使用外部中断控制对射式红外传感器并计次
  15. python的多元数据类型(上)
  16. java hasfocus_说说Flutter中的无名英雄 —— Focus
  17. 软件工程c语言2000行代码,C语言教务管理系统(2000行代码)
  18. 拿到别人的vue项目之后如何运行
  19. ZCash零知识证明
  20. Android 千变万化 TextView:神奇的 SpannableString

热门文章

  1. WINCE下实现基于USB的camera
  2. Java程序员涨薪必备技能
  3. 超强、超详细Redis入门教程【转】
  4. LinkedIn公司采用超大规模数据中心设计
  5. 原创哈希数据导出算法
  6. windows下qt5 kinect 2.0开发与环境配置
  7. ArcEngine编辑功能的实现(二)
  8. ora-12514: tns: 监听程序当前无法识别连接描述符中请求的服务 问题解决
  9. 我的世界服务器反作弊不起作用,我的世界服务器反作弊怎么搞 | 手游网游页游攻略大全...
  10. 脑动力:C语言函数速查效率手册(附DVD光盘1张) [平