**

GAN生成哆啦A梦,亲测训练50000epoch

**
闲着没事学了一下GAN网络,感觉这个东西挺有趣的,所以就打算跟大家分享一下这个东西,其中的原理就不用跟大家说了,因为现在其他的博客介绍原理都很全面,大家可以去看一下其他对的博客,我这里就放上我参照其他博客进行修改的代码吧。

废话少说,先摆上我的训练结果,证明我的代码是可以运行的叭:

我的训练图片是从百度上面拿的,拿来的图片格式是.webp格式的图片
然后我进行训练的时候使用的图片分辨率是64X64,怎么说呢,就是尽量使用64X64的分辨率的图片进行训练叭,因为这样可以让你的电脑能够运行的起来,我刚开始进行训练的是医学的数据938X636结果就在电脑报错了:**RuntimeError: CUDA out of memory. Tried to allocate 6.71 GiB (GPU 0; 15.78 GiB total capacity; 13.49 GiB already allocated; 984.75 MiB free; 13.51 GiB reserved in total by PyTorch)**因为电脑的配置跟不上呀
所以最后我对收集到的数据进行分辨率的修改,编程64X64就能轻轻松松地进行训练了

先给大伙看看我整体的文件格式,因为路径那些都是配套,按照我这样才能进行运行,这个对小白比较友好,不然就自己进行修改代码了

解释一下叭

gan_A是总的文件夹

images是放置网页进行获取的图片,我拿的图片格式是.webp格式的图片,这个很重要,因为后面进行图片分辨率的转化的时候,我的代码就是针对这个格式的图片的,如果不是这种格式的话,在我的代码image_tool.py那里需要进行修改一些东西。#

img # 是每一万轮就对训练G产生的图片进行保存

result是对webp格式的图片进行转化后保存的一个文件夹,也就是gan_1进行数据的训练的数据

saved_models是每隔一万轮对G的模型进行保存

image_tool.py代码如下:

# -*- coding:utf-8 -*-
# @Time : 2022-04-18 15:07
# @Author : DaFuChen
# @File : image_tool.py
# @software: PyCharm# 进行分辨率的重新修改 进行值的变化修改# 导入需要的模块
from glob import glob
from PIL import Image
import os# 图片路径
# 使用 glob模块 获得文件夹内所有jpg图像
img_path = glob("./images/*.webp")
# img_path = glob("./images/*.jpg")
# img_path = glob("./images/*.jpeg")# 存储(输出)路径
path_save = "./result"for i, file in enumerate(img_path):name = os.path.join(path_save, "%d.jpg" % i)im = Image.open(file)# im.thumbnail((720,1280))reim = im.resize((64, 64))print(im.format, reim.size, reim.mode)reim.save(name, im.format)

gan_1.py代码如下:

# -*- coding:utf-8 -*-
# @Time : 2022-04-18 15:05
# @Author : DaFuChen
# @File : gan_1.py
# @software: PyCharmimport argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torch.autograd import Variable
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch# 输出图片保存路径 没有就会自动进行创建
os.makedirs("img", exist_ok=True)# 参数设置
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=1000001, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--gpu", type=int, default=0, help="number of cpu threads to use during batch generation")
# 输入噪声向量维度,默认100
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")# [enforce fail at .\c10 core\CPUAllocator.cpp:79]DefaultCPUAllocator:内存不足:您试图分配7203840000字节。 改改网络结构吧 tm
# 输入图片维度,默认64*64*3  但是进行修改之后两个值 938 625 但是这个设计的网络结构不稳定 很难搞  这个是撑不住了
parser.add_argument("--img_size1", type=int, default=64, help="size of each image dimension")
parser.add_argument("--img_size2", type=int, default=64, help="size of each image dimension")parser.add_argument("--channels", type=int, default=3, help="number of image channels")
# 每隔一个sample_interval的批次进行一次图片的保存
parser.add_argument("--sample_interval", type=int, default=10000, help="interval betwen image samples")# 其实是创建了一个对象 之后可以调用它其中的参数值 使用了它的这一个类里面刻画的一些属性
opt = parser.parse_args()
print(opt)# 图像的分辨率值
img_shape = (opt.channels, opt.img_size1, opt.img_size2)cuda = True if torch.cuda.is_available() else Falseclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(# 基本的一个操作   进行线性化的拟合 然后再relu一下取个好的效果nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# Loss function
adversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()# 将属性放进GPU进行训练
if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loader
img_transform = transforms.Compose([# transforms.ToPILImage(),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # (x-mean) / std
])class MyData(Dataset):  # 继承Datasetdef __init__(self, root_dir, transform=None):  # __init__是初始化该类的一些基础参数self.root_dir = root_dir  # 文件目录self.transform = transform  # 变换self.images = os.listdir(self.root_dir)  # 目录里的所有文件def __len__(self):  # 返回整个数据集的大小return len(self.images)def __getitem__(self, index):  # 根据索引index返回dataset[index]image_index = self.images[index]  # 根据索引index获取该图片img_path = os.path.join(self.root_dir, image_index)  # 获取索引为index的图片的路径名img = Image.open(img_path)  # 读取该图片if self.transform:img = self.transform(img)return img  # 返回该样本# 输入图片所在文件夹
mydataset = MyData(root_dir='./result/', transform=img_transform
)# data loader 数据载入
dataloader = DataLoader(dataset=mydataset, batch_size=opt.batch_size, shuffle=True
)# 下面这一块是可以省略
# os.makedirs("./data/MNIST", exist_ok=True)
# dataloader = torch.utils.data.DataLoader(
#     datasets.MNIST(
#         "./data/MNIST",
#         train=True,
#         download=True,
#         transform=transforms.Compose(
#             [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
#         ),
#     ),
#     batch_size=opt.batch_size,
#     shuffle=True,
# )# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, img in enumerate(dataloader):imgs = img# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(Tensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "./img/%d.png" % batches_done, nrow=5, normalize=True)torch.save(generator, "saved_models/generator_%d.pth" % epoch)

可以使用了

GAN生成哆啦A梦,亲测训练50000epoch相关推荐

  1. LSGANs : Least Squares GAN(最小二乘GAN)--解决标准GAN生成的图片质量不高以及训练过程不稳定问题

    LSGANs基本思想 LSGANs的英文全称是Least Squares GANs.这篇文章针对的是标准GAN生成的图片质量不高以及训练过程不稳定这两个缺陷进行改进.改进方法就是将GAN的目标函数由交 ...

  2. 微信小程序码的生成(JAVA完整版) 亲测可用

    JAVA生成小程序码(太阳码) 首先准备工具类,这里我使用的是QrUtil;废话不多说,上工具类; 工具类是获取token使用; appid = 小程序appID secret = 小程序秘钥 /** ...

  3. windows下扩展yaf,并生成yaf框架文件(亲测)

    YAF中文文档:http://www.laruence.com/manual/index.html 1 YAF框架是用C开发的,属于PHP的扩展框架: 2 YAF的性能相对于源生PHP,性能只降低不到 ...

  4. java pdf模板填充生成pdf打印 (亲测有效)

    //先要制作好pdf模板(可以在word 里面画好,导出保存pdf文件), 下载Adobe Acrobat DC 工具 后打开 pdf 里面带格式的,然后 点击 准备表单按钮 你可以拖动 文本 和文本 ...

  5. 要让GAN生成想要的样本,可控生成对抗网络可能会成为你的好帮手

    如何让GAN生成带有指定特征的图像?这是一个极有潜力.极有应用前景的问题,然而目前都没有理想的方法.韩国大学电子工程学院Minhyeok Lee和Junhee Seok近期发表论文,就生成对抗网络的控 ...

  6. java 自带写日志包_jdk自带的日志工具实操总结(亲测有效)

    现在项目中,大多用log4j等第三方日志框架,用这些框架确实有原因,而且确实配置简单,好用.因为一个传统项目不想用第三方日志框架,想用jdk自带的日志来记录日志,所以总结了下经验,希望对大家有所帮助. ...

  7. 将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功)

    将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功) 模型转换 声明:本文原创,未经许可严禁转载,原文地址https://blog.csdn.net/hutao1030813002/ ...

  8. 中国博士生提出最先进AI训练优化器,收敛快精度高,网友亲测:Adam可以退休了...

    栗子 鱼羊 晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 找到一种快速稳定的优化算法,是所有AI研究人员的目标. 但是鱼和熊掌不可兼得.Adam.RMSProp这些算法虽然收敛速度很快 ...

  9. mmdetection2.3.0版本安装过程,以及训练、测试、可视化等(亲测好用,很顺利)

    欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 mmdetection2.3.0版本安装过程,以及训练.测试.可视化等(亲测好用,很顺利) 文章目录: 1 运行mmde ...

最新文章

  1. 《3D数学基础》系列视频 1.5 向量的夹角
  2. Oracle传输表空间
  3. C/C++/Java 的基本数据类型
  4. RPC框架几行代码就够了
  5. 查看linux的系统位数
  6. 禁止和恢复WIN7驱动强制签名
  7. 过年遇到前任借钱, 如何傲娇的拒绝?
  8. Beetlex服务网关1.8发布
  9. Java并发面试,幸亏有点道行,不然又被忽悠了 1
  10. DPDK收发包全景分析
  11. steam常用计算机,絮絮叨叨的繁星 篇二:新电脑必备——常用验机和跑分软件汇总...
  12. python 正则处理经纬度度分秒转换
  13. Excel: 批量去除空格的函数——trim函数, substitute函数,clean函数
  14. 收藏:国产服务器和处理器架构
  15. 主码,候选码,外码,全码,主属性,非主属性的区别
  16. python格式化千分位数字
  17. 液位系统c语言程序,基于STM32的液位控制系统设计
  18. 如何将GPS手持机航点数据导出、转换格式,并用不同软件Google Earth或者ArcGIS打开?
  19. 2020考研-王道数据结构-图-图的遍历
  20. 荣耀30s是否搭载鸿蒙系统,荣耀30S来袭 3月30日发布或配麒麟820芯片

热门文章

  1. su自带模型库怎么打开_草图大师Sketchup打不开3d模型库,该怎么解决?
  2. uni-app 语音播报-前台后台离线推送语音播报、到账xx元、收款播报、自定义推送铃(ios)
  3. 半桥LLC谐振工作原理及模态分析
  4. linux中如何解压.tgz
  5. 华为机试部分刷题记录
  6. C语言数组用到的动态内存分配
  7. cacti mysql 修复_cacti数据库修复
  8. 解决CentOS7 经HBA卡接sata硬盘不能启动系统
  9. 光通量与辐射通量之间的换算
  10. NLP5:NLTK词性标注