【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)
简述
之前认真学习了网上的一份,代码做了很详细的笔记。
【Gans入门】Pytorch实现Gans代码详解【70+代码】
但是上面的任务只是画一条在一定区间下的曲线。
这里对这个进行迁移,到可以进行图像的生成。
图像的很多数据都没有,但是突然想到在sklearn上的digits是一个非常简单的图片。
这里我想到之前的一份笔记
sklearn学习(一)
这里会使用sklearn自带的小数据来做训练
目标是让神经网络自己学会生成数字。
任务描述
为了让神经网络操作更简单。这里的输入数据只会选择特定数值的数字图片数据。然后丢给对抗生成神经网络学习。让其中的生成器学会如何生成手写数字。
下面是选择用数值1的生成过程
其实可以发现其实是有点这样的感觉了。
下面的这个是让它学习数字0的效果
可能是由于数字0的细节更粗糙一点,所以,可以发现,我们认为这个0生成的更好。(数字1和数字4其实是有点像的,所以会有点问题,还有这是因为图片像素有点低)
代码详解
导入包
torch,numpy
这些都是数据处理过程中需要的包matplotlib
为了画图sklearn
主要是为了它本身带的数据random
主要是为了选择标准数据更具有随机性os,shutil,imagei
o这三个库是为了画出gif动态图
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageio
创建临时文件夹
PNGFILE = './png/'
if not os.path.exists(PNGFILE):os.mkdir(PNGFILE)
else:shutil.rmtree(PNGFILE)os.mkdir(PNGFILE)
这里会创建一个临时的文件夹png,会把中途生成的那些图片都存在这,然后我就可以用这些png来生成gif文件
模型参数
BATCH_SIZE
这个参数表示每次用多少的数据来进行考量。(数值多的话模型进化的会稍微快点)LR_G
跟LR_D
表示两个模型的学习率N_IDEAS
:启发式因子(生成函数的初始层的节点数)。因为我们要操作的节点数量会特别大(特别是图像问题,但是如果输入节点过于大的话,会需要大量的计算资源。所以用小一点的这个基本够用就行了)target_num
:表示的是想要生成的数字。由于数据集中只有(0到9)所以,这里也只能取0到9。image_max
表示图片像素点的最大值,这个一开始我用到了,但是后来我修改了代码之后,就用不到了。ART_COMPONENTS
:像素点数量(其实本质上跟前一个版本的参考节点数都是一样的)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001 # learning rate for generator
LR_D = 0.00001 # learning rate for discriminator
N_IDEAS = 6 # think of this as number of ideas for generating an art work (Generator)
target_num = 0 # target Numberdigits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]
image_max = max(data.reshape((-1,)))
ART_COMPONENTS = data.shape[-1] # it could be total point G can draw in the canvas
标准数据
这个函数本质上,这个区间上选BATCH_SIZE个标准数据。
但是,random.sample只能输入的是list所以需要先把data转成list,但是转出来的list又不能直接变成torch中的Tensor,这里需要再转成ndarray,之后再转成Tensor,但是要注意在后面加一个.float()
函数的操作。
def artist_works(): # painting from the famous artist (real target)return torch.from_numpy(np.array(random.sample(list(data), BATCH_SIZE))).float()
构建模型
生成器模型,但是Linear转成的数据是有可能有负数的数据的,但是作为图片肯定是不可以有这样的数据的。因为数据一定是需要为大于等于0的数据。
所以搭建的这个模型最后一定要加一个ReLU()这样的类似的,来保证没有0的情况。
G = nn.Sequential( # Generatornn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideasnn.ReLU(),
)D = nn.Sequential( # Discriminatornn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(), # tell the probability that the art work is made by artist
)
构建最优化的模型
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
迭代优化
这跟之前的是类似的。
for step in range(10000):artist_paintings = artist_works() # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideasG_paintings = G(G_ideas) # fake painting from G (random ideas)prob_artist0 = D(artist_paintings) # D try to increase this probprob_artist1 = D(G_paintings) # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True) # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward(retain_graph=True)opt_G.step()
画图并保存
if step % 100 == 0: # plottingplt.cla()tempdata = G_paintings[0].detach().numpy()tempdata = tempdata.reshape((8, 8))plt.imshow(tempdata, cmap=plt.cm.gray_r)# plt.draw()plt.savefig(PNGFILE + '%d.png' % times)filedatalist.append(PNGFILE + '%d.png' % times)times += 1plt.pause(0.01)
生成gif
generated_images = []
for png_path in filedatalist:generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)
全部代码
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageioPNGFILE = './png/'
if not os.path.exists(PNGFILE):os.mkdir(PNGFILE)
else:shutil.rmtree(PNGFILE)os.mkdir(PNGFILE)# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001 # learning rate for generator
LR_D = 0.00001 # learning rate for discriminator
N_IDEAS = 6 # think of this as number of ideas for generating an art work (Generator)
target_num = 0 # target Numberdigits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]
image_max = max(data.reshape((-1,)))
ART_COMPONENTS = data.shape[-1] # it could be total point G can draw in the canvasdef artist_works(): # painting from the famous artist (real target)return torch.from_numpy(np.array(random.sample(list(data), BATCH_SIZE))).float()G = nn.Sequential( # Generatornn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideasnn.ReLU(),
)D = nn.Sequential( # Discriminatornn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(), # tell the probability that the art work is made by artist
)opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
times = 0filedatalist = []for step in range(10000):artist_paintings = artist_works() # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideasG_paintings = G(G_ideas) # fake painting from G (random ideas)prob_artist0 = D(artist_paintings) # D try to increase this probprob_artist1 = D(G_paintings) # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True) # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward(retain_graph=True)opt_G.step()if step % 100 == 0: # plottingplt.cla()tempdata = G_paintings[0].detach().numpy()tempdata = tempdata.reshape((8, 8))plt.imshow(tempdata, cmap=plt.cm.gray_r)# plt.draw()plt.savefig(PNGFILE + '%d.png' % times)filedatalist.append(PNGFILE + '%d.png' % times)times += 1plt.pause(0.01)generated_images = []
for png_path in filedatalist:generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)
【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)相关推荐
- 可下载:60分钟入门PyTorch(中文翻译全集)
来源:机器学习初学者本文约9500字,建议阅读20分钟官网教程翻译:60分钟入门PyTorch(全集) 前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute ...
- 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器
前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...
- 60分钟快速入门PyTorch
点击上方"算法猿的成长",关注公众号,选择加"星标"或"置顶" 总第 136 篇文章,本文大约 26000 字,阅读大约需要 60 分钟 P ...
- 60分钟快速入门 PyTorch
PyTorch 是由 Facebook 开发,基于 Torch 开发,从并不常用的 Lua 语言转为 Python 语言开发的深度学习框架,Torch 是 TensorFlow 开源前非常出名的一个深 ...
- 快速入门PyTorch(2)--如何构建一个神经网络
2019 第 43 篇,总第 67 篇文章 本文大约 4600 字,阅读大约需要 10 分钟 快速入门 PyTorch 教程第二篇,这篇介绍如何构建一个神经网络.上一篇文章: 快速入门Pytorch( ...
- 使用Pytorch实现手写数字识别
使用Pytorch实现手写数字识别 1. 思路和流程分析 流程: 准备数据,这些需要准备DataLoader 构建模型,这里可以使用torch构造一个深层的神经网络 模型的训练 模型的保存,保存模型, ...
- 60分钟入门PyTorch,官方教程手把手教你训练第一个深度学习模型(附链接)
来源:机器之心 本文约800字,建议阅读5分钟. 本文介绍了官方教程入门PyTorch的技巧训练. 近期的一份调查报告显示:PyTorch 已经力压 TensorFlow 成为各大顶会的主流深度学习框 ...
- 【深度学习】新人如何入门Pytorch的路线?有哪些资源推荐?
初学者如何入门Pytorch,看看过来人的建议 作者:范星.xfanplus 链接:https://www.zhihu.com/question/55720139/answer/294449487 高 ...
- 【深度学习】翻译:60分钟入门PyTorch(三)——神经网络
前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...
最新文章
- 四种JOIN简单实例
- Bash shell
- 数据结构特性解析 (四)LinkedList
- Java面试中与源码有关的问题分享
- 运行中的Nginx进程间的关系
- 优秀学生专栏——李浩然
- 关于我曾经做过的一个商业社区的ui框架
- 创业版上市与SAP管理软件系统的关系
- IOS学习笔记07---C语言函数-scanf函数
- c语言答辩中期报告,安徽工程大学毕业设计(论文)中期检查总结
- 前端埋点方法解析及优缺点分析
- jquery easyui 输入框 禁止输入负数 设置属性data-options=min:0,required:true
- Edwin 的基本使用
- STM32:使用外部中断控制对射式红外传感器并计次
- python的多元数据类型(上)
- java hasfocus_说说Flutter中的无名英雄 —— Focus
- 软件工程c语言2000行代码,C语言教务管理系统(2000行代码)
- 拿到别人的vue项目之后如何运行
- ZCash零知识证明
- Android 千变万化 TextView:神奇的 SpannableString
热门文章
- WINCE下实现基于USB的camera
- Java程序员涨薪必备技能
- 超强、超详细Redis入门教程【转】
- LinkedIn公司采用超大规模数据中心设计
- 原创哈希数据导出算法
- windows下qt5 kinect 2.0开发与环境配置
- ArcEngine编辑功能的实现(二)
- ora-12514: tns: 监听程序当前无法识别连接描述符中请求的服务 问题解决
- 我的世界服务器反作弊不起作用,我的世界服务器反作弊怎么搞 | 手游网游页游攻略大全...
- 脑动力:C语言函数速查效率手册(附DVD光盘1张) [平