GAN生成哆啦A梦,亲测训练50000epoch
**
GAN生成哆啦A梦,亲测训练50000epoch
**
闲着没事学了一下GAN网络,感觉这个东西挺有趣的,所以就打算跟大家分享一下这个东西,其中的原理就不用跟大家说了,因为现在其他的博客介绍原理都很全面,大家可以去看一下其他对的博客,我这里就放上我参照其他博客进行修改的代码吧。
废话少说,先摆上我的训练结果,证明我的代码是可以运行的叭:
我的训练图片是从百度上面拿的,拿来的图片格式是.webp格式的图片
然后我进行训练的时候使用的图片分辨率是64X64,怎么说呢,就是尽量使用64X64的分辨率的图片进行训练叭,因为这样可以让你的电脑能够运行的起来,我刚开始进行训练的是医学的数据938X636结果就在电脑报错了:**RuntimeError: CUDA out of memory. Tried to allocate 6.71 GiB (GPU 0; 15.78 GiB total capacity; 13.49 GiB already allocated; 984.75 MiB free; 13.51 GiB reserved in total by PyTorch)**因为电脑的配置跟不上呀
所以最后我对收集到的数据进行分辨率的修改,编程64X64就能轻轻松松地进行训练了
先给大伙看看我整体的文件格式,因为路径那些都是配套,按照我这样才能进行运行,这个对小白比较友好,不然就自己进行修改代码了
解释一下叭
gan_A是总的文件夹
images是放置网页进行获取的图片,我拿的图片格式是.webp格式的图片,这个很重要,因为后面进行图片分辨率的转化的时候,我的代码就是针对这个格式的图片的,如果不是这种格式的话,在我的代码image_tool.py那里需要进行修改一些东西。#
img # 是每一万轮就对训练G产生的图片进行保存
result是对webp格式的图片进行转化后保存的一个文件夹,也就是gan_1进行数据的训练的数据
saved_models是每隔一万轮对G的模型进行保存
image_tool.py代码如下:
# -*- coding:utf-8 -*-
# @Time : 2022-04-18 15:07
# @Author : DaFuChen
# @File : image_tool.py
# @software: PyCharm# 进行分辨率的重新修改 进行值的变化修改# 导入需要的模块
from glob import glob
from PIL import Image
import os# 图片路径
# 使用 glob模块 获得文件夹内所有jpg图像
img_path = glob("./images/*.webp")
# img_path = glob("./images/*.jpg")
# img_path = glob("./images/*.jpeg")# 存储(输出)路径
path_save = "./result"for i, file in enumerate(img_path):name = os.path.join(path_save, "%d.jpg" % i)im = Image.open(file)# im.thumbnail((720,1280))reim = im.resize((64, 64))print(im.format, reim.size, reim.mode)reim.save(name, im.format)
gan_1.py代码如下:
# -*- coding:utf-8 -*-
# @Time : 2022-04-18 15:05
# @Author : DaFuChen
# @File : gan_1.py
# @software: PyCharmimport argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torch.autograd import Variable
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch# 输出图片保存路径 没有就会自动进行创建
os.makedirs("img", exist_ok=True)# 参数设置
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=1000001, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--gpu", type=int, default=0, help="number of cpu threads to use during batch generation")
# 输入噪声向量维度,默认100
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")# [enforce fail at .\c10 core\CPUAllocator.cpp:79]DefaultCPUAllocator:内存不足:您试图分配7203840000字节。 改改网络结构吧 tm
# 输入图片维度,默认64*64*3 但是进行修改之后两个值 938 625 但是这个设计的网络结构不稳定 很难搞 这个是撑不住了
parser.add_argument("--img_size1", type=int, default=64, help="size of each image dimension")
parser.add_argument("--img_size2", type=int, default=64, help="size of each image dimension")parser.add_argument("--channels", type=int, default=3, help="number of image channels")
# 每隔一个sample_interval的批次进行一次图片的保存
parser.add_argument("--sample_interval", type=int, default=10000, help="interval betwen image samples")# 其实是创建了一个对象 之后可以调用它其中的参数值 使用了它的这一个类里面刻画的一些属性
opt = parser.parse_args()
print(opt)# 图像的分辨率值
img_shape = (opt.channels, opt.img_size1, opt.img_size2)cuda = True if torch.cuda.is_available() else Falseclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(# 基本的一个操作 进行线性化的拟合 然后再relu一下取个好的效果nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# Loss function
adversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()# 将属性放进GPU进行训练
if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loader
img_transform = transforms.Compose([# transforms.ToPILImage(),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)) # (x-mean) / std
])class MyData(Dataset): # 继承Datasetdef __init__(self, root_dir, transform=None): # __init__是初始化该类的一些基础参数self.root_dir = root_dir # 文件目录self.transform = transform # 变换self.images = os.listdir(self.root_dir) # 目录里的所有文件def __len__(self): # 返回整个数据集的大小return len(self.images)def __getitem__(self, index): # 根据索引index返回dataset[index]image_index = self.images[index] # 根据索引index获取该图片img_path = os.path.join(self.root_dir, image_index) # 获取索引为index的图片的路径名img = Image.open(img_path) # 读取该图片if self.transform:img = self.transform(img)return img # 返回该样本# 输入图片所在文件夹
mydataset = MyData(root_dir='./result/', transform=img_transform
)# data loader 数据载入
dataloader = DataLoader(dataset=mydataset, batch_size=opt.batch_size, shuffle=True
)# 下面这一块是可以省略
# os.makedirs("./data/MNIST", exist_ok=True)
# dataloader = torch.utils.data.DataLoader(
# datasets.MNIST(
# "./data/MNIST",
# train=True,
# download=True,
# transform=transforms.Compose(
# [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
# ),
# ),
# batch_size=opt.batch_size,
# shuffle=True,
# )# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
# Training
# ----------for epoch in range(opt.n_epochs):for i, img in enumerate(dataloader):imgs = img# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(Tensor))# -----------------# Train Generator# -----------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------# Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "./img/%d.png" % batches_done, nrow=5, normalize=True)torch.save(generator, "saved_models/generator_%d.pth" % epoch)
可以使用了
GAN生成哆啦A梦,亲测训练50000epoch相关推荐
- LSGANs : Least Squares GAN(最小二乘GAN)--解决标准GAN生成的图片质量不高以及训练过程不稳定问题
LSGANs基本思想 LSGANs的英文全称是Least Squares GANs.这篇文章针对的是标准GAN生成的图片质量不高以及训练过程不稳定这两个缺陷进行改进.改进方法就是将GAN的目标函数由交 ...
- 微信小程序码的生成(JAVA完整版) 亲测可用
JAVA生成小程序码(太阳码) 首先准备工具类,这里我使用的是QrUtil;废话不多说,上工具类; 工具类是获取token使用; appid = 小程序appID secret = 小程序秘钥 /** ...
- windows下扩展yaf,并生成yaf框架文件(亲测)
YAF中文文档:http://www.laruence.com/manual/index.html 1 YAF框架是用C开发的,属于PHP的扩展框架: 2 YAF的性能相对于源生PHP,性能只降低不到 ...
- java pdf模板填充生成pdf打印 (亲测有效)
//先要制作好pdf模板(可以在word 里面画好,导出保存pdf文件), 下载Adobe Acrobat DC 工具 后打开 pdf 里面带格式的,然后 点击 准备表单按钮 你可以拖动 文本 和文本 ...
- 要让GAN生成想要的样本,可控生成对抗网络可能会成为你的好帮手
如何让GAN生成带有指定特征的图像?这是一个极有潜力.极有应用前景的问题,然而目前都没有理想的方法.韩国大学电子工程学院Minhyeok Lee和Junhee Seok近期发表论文,就生成对抗网络的控 ...
- java 自带写日志包_jdk自带的日志工具实操总结(亲测有效)
现在项目中,大多用log4j等第三方日志框架,用这些框架确实有原因,而且确实配置简单,好用.因为一个传统项目不想用第三方日志框架,想用jdk自带的日志来记录日志,所以总结了下经验,希望对大家有所帮助. ...
- 将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功)
将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功) 模型转换 声明:本文原创,未经许可严禁转载,原文地址https://blog.csdn.net/hutao1030813002/ ...
- 中国博士生提出最先进AI训练优化器,收敛快精度高,网友亲测:Adam可以退休了...
栗子 鱼羊 晓查 发自 凹非寺 量子位 报道 | 公众号 QbitAI 找到一种快速稳定的优化算法,是所有AI研究人员的目标. 但是鱼和熊掌不可兼得.Adam.RMSProp这些算法虽然收敛速度很快 ...
- mmdetection2.3.0版本安装过程,以及训练、测试、可视化等(亲测好用,很顺利)
欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 mmdetection2.3.0版本安装过程,以及训练.测试.可视化等(亲测好用,很顺利) 文章目录: 1 运行mmde ...
最新文章
- 《3D数学基础》系列视频 1.5 向量的夹角
- Oracle传输表空间
- C/C++/Java 的基本数据类型
- RPC框架几行代码就够了
- 查看linux的系统位数
- 禁止和恢复WIN7驱动强制签名
- 过年遇到前任借钱, 如何傲娇的拒绝?
- Beetlex服务网关1.8发布
- Java并发面试,幸亏有点道行,不然又被忽悠了 1
- DPDK收发包全景分析
- steam常用计算机,絮絮叨叨的繁星 篇二:新电脑必备——常用验机和跑分软件汇总...
- python 正则处理经纬度度分秒转换
- Excel: 批量去除空格的函数——trim函数, substitute函数,clean函数
- 收藏:国产服务器和处理器架构
- 主码,候选码,外码,全码,主属性,非主属性的区别
- python格式化千分位数字
- 液位系统c语言程序,基于STM32的液位控制系统设计
- 如何将GPS手持机航点数据导出、转换格式,并用不同软件Google Earth或者ArcGIS打开?
- 2020考研-王道数据结构-图-图的遍历
- 荣耀30s是否搭载鸿蒙系统,荣耀30S来袭 3月30日发布或配麒麟820芯片