彩色星球图片生成4:转置卷积层+插值缩放+卷积收缩(pytorch版)

  • 1. 改进方面
    • 1.1 优化器与优化步长
    • 1.2 交叉熵损失函数
    • 1.3 Patch判别器
    • 1.4 输入分辨率
    • 1.5 转置卷积+插值缩放+卷积收缩
    • 1.6 bias=False
    • 1.7 历史版本额外保存
  • 2. 代码
    • 2.1 训练集裁剪代码
    • 2.2 模型代码model.py
    • 2.3 训练代码train.py
  • 3. 最终效果
  • 4. 缺陷与下一步改进

上一集: 彩色星球图片生成3:代码改进(pytorch版)

在上一集代码的基础上,进行了更多的修改以改进生成效果。
主要针对转置卷积的棋盘效应进行了进一步的优化,也在其它方面做了一些工作。
参考文献:Deconvolution and Checkerboard Artifacts

训练集图片(共192张,来自于Space Engine):

1. 改进方面

1.1 优化器与优化步长

将原本的AdamW优化器重新更换为了Adam优化器,因为AdamW优化器的L2正则化效果似乎在GAN图像生成方面并没有表现出什么优势。
将两个判别器的步长设置为3e-4,而生成器的步长设置为1e-4,使得生成器用更小的步伐追赶判别器的步伐。

1.2 交叉熵损失函数

将损失函数从MSELoss均方差损失函数修改为二分类交叉熵,由于apex不支持BCELoss,因此使用了BCEWithLogitsLoss,因此也对Patch判别器进行了部分修改。

1.3 Patch判别器

不再通过一个label矩阵计算Patch判别器的输出的MSELoss,取而代之的是使用了一个全局平均池化层来统一Patch判别器的输出结果。
虽然从论文思想上来说,应该对Patch判别器的每一个输出结果使用sigmoid激活后再进行平均,但由于apex不支持BCELoss,只能使用BCEWithLogitsLoss,而它对减少显存占用和训练时间至关重要,因此最后的做法是去除了sigmoid直接做全局平均池化,相当于先平均池化再激活。

1.4 输入分辨率

训练集的输入图像使用了更大的512x512的图像作为输入,在将原图收缩到512x512的过程中,使用Image.ANTIALIAS缩放算法来保留更多高质量的细节信息,并且使用PNG格式保存以尽可能保留原图信息。
在之前版本的基础上,扩大训练集数量到192张图片。

1.5 转置卷积+插值缩放+卷积收缩

将图像的缩放改为逐层放大,由于资源与时间限制,暂时未使用LapGan的分段式训练结构,但使用了类似的二倍放大思想,每一层都是上一层长宽的两倍。
基于谷歌论文中对棋盘效应的研究描述,将所有kernal size改为stride的整数倍,并且使用最近邻插值进行二次缩放。
每个模块分为三步:
1、使用转置卷积将图像放大为2倍长宽。
2、使用最近邻插值将图像再次放大为2倍长宽。
3、使用卷积将图像压缩为0.5倍长宽。
整体结果相当于将图像放大为2倍长宽,最终输出为512x512。

1.6 bias=False

将所有层都设置为bias=False,不需要产生额外的数据偏移。

1.7 历史版本额外保存

添加了一个配置项,当电脑连接到一个大容量外界硬盘时,会将所有历史版本额外保存在指定路径中,方便追溯之前表现优秀的训练模型版本。

2. 代码

2.1 训练集裁剪代码

将训练集图片裁剪为正方形并使用高质量缩放算法进行缩放。

import cv2
import os
from PIL import Imageimg_size = 512# 数据集来源
img_path = "train_images/"for path, dirs, files in os.walk(img_path, topdown=False):file_list = list(files)
for file in file_list:image_path = img_path + fileimg = cv2.imread(image_path, 1)bias = (img.shape[1] - img.shape[0]) // 2img = img[:, bias:bias+img.shape[0], :](B, G, R) = cv2.split(img)# 颜色通道合并img = cv2.merge([R, G, B])# 缩放img = Image.fromarray(img)img = img.resize((img_size, img_size), Image.ANTIALIAS)os.remove(image_path)img.save(image_path.rstrip('.jpg') + '.png')

2.2 模型代码model.py

import torch
import torch.nn as nn# 通过最近邻插值的方式将长宽加倍
def amplify_img(imgs):return nn.functional.interpolate(imgs, torch.Size([imgs.shape[-2] * 2, imgs.shape[-1] * 2]), mode='nearest')# 生成器
class G_net(nn.Module):def __init__(self):super(G_net, self).__init__()self.expand = nn.Sequential(nn.Linear(128, 2048, bias=False),nn.BatchNorm1d(2048),nn.Dropout(0.5),nn.LeakyReLU(0.2, inplace=True),nn.Linear(2048, 4096, bias=False),nn.BatchNorm1d(4096),nn.Dropout(0.5),nn.LeakyReLU(0.2, inplace=True),)self.convTran1 = nn.Sequential(nn.ConvTranspose2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),)self.conv1 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),)self.convTran2 = nn.Sequential(nn.ConvTranspose2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),)self.conv2 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),)self.convTran3 = nn.Sequential(nn.ConvTranspose2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),)self.conv3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),)self.convTran4 = nn.Sequential(nn.ConvTranspose2d(512, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),)self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),)self.convTran5 = nn.Sequential(nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(32),nn.LeakyReLU(0.2, inplace=True),)self.conv5 = nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(32),nn.LeakyReLU(0.2, inplace=True),)self.convTran6 = nn.Sequential(nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(16),nn.LeakyReLU(0.2, inplace=True),)self.conv6 = nn.Sequential(nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(16),nn.LeakyReLU(0.2, inplace=True),)self.conv7 = nn.Sequential(nn.Conv2d(16, 3, kernel_size=1, stride=1, bias=False),# 将输出约束到[-1,1]nn.Tanh())def forward(self, img_seeds):img_seeds = self.expand(img_seeds)# 将线性数据重组为二维图片imgs = img_seeds.view(-1, 64, 8, 8)# 用转置卷积放大图片imgs = self.convTran1(imgs)# 用最近邻插值放大图片imgs = amplify_img(imgs)# 压缩图片为16x16imgs = self.conv1(imgs)# 用转置卷积放大图片imgs = self.convTran2(imgs)# 用最近邻插值放大图片imgs = amplify_img(imgs)# 压缩图片为32x32imgs = self.conv2(imgs)# 用转置卷积放大图片imgs = self.convTran3(imgs)# 用最近邻插值放大图片imgs = amplify_img(imgs)# 压缩图片为64x64imgs = self.conv3(imgs)# 用转置卷积放大图片imgs = self.convTran4(imgs)# 用最近邻插值放大图片imgs = amplify_img(imgs)# 压缩图片为128x128imgs = self.conv4(imgs)# 用转置卷积放大图片imgs = self.convTran5(imgs)# 用最近邻插值放大图片imgs = amplify_img(imgs)# 压缩图片为256x256imgs = self.conv5(imgs)# 用转置卷积放大图片imgs = self.convTran6(imgs)# 用最近邻插值放大图片imgs = amplify_img(imgs)# 压缩图片为512x512imgs = self.conv6(imgs)# 1x1维度整合卷积层,整合为3通道图片imgs = self.conv7(imgs)return imgs# 全局判别器,传统gan
class D_net_global(nn.Module):def __init__(self):super(D_net_global,self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, kernel_size=6, stride=3, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 16, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(16),nn.LeakyReLU(0.2, inplace=True),)self.classifier = nn.Sequential(nn.Linear(5184, 1),#nn.Sigmoid(),)def forward(self, img):features = self.features(img)features = features.view(features.shape[0], -1)output = self.classifier(features)return output# 局部判别器,patchgan
class D_net_patch(nn.Module):def __init__(self):super(D_net_patch,self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 32, kernel_size=4, stride=2, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(32, 64, kernel_size=4, stride=2, bias=False),nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, kernel_size=4, stride=2, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, kernel_size=4, stride=2, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, kernel_size=4, stride=2, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, kernel_size=4, stride=1, bias=False),#nn.Sigmoid(),)def forward(self, img):# 利用patch判别器输出矩阵features = self.features(img)# 全局平均池化,输出尺寸1x1features = nn.functional.adaptive_avg_pool2d(features, 1)# 展平features = features.view(features.shape[0], -1)#print("patch shape", features.shape)return features# 返回对应的生成器
def get_G_model(from_old_model, device, model_path):model = G_net()# 从磁盘加载之前保存的模型参数if from_old_model:model.load_state_dict(torch.load(model_path))# 将模型加载到用于运算的设备的内存model = model.to(device)return model# 返回全局判别器的模型
def get_D_model_global(from_old_model, device, model_path):model = D_net_global()# 从磁盘加载之前保存的模型参数if from_old_model:model.load_state_dict(torch.load(model_path))# 将模型加载到用于运算的设备的内存model = model.to(device)return model# 返回局部判别器的模型
def get_D_model_patch(from_old_model, device, model_path):model = D_net_patch()# 从磁盘加载之前保存的模型参数if from_old_model:model.load_state_dict(torch.load(model_path))# 将模型加载到用于运算的设备的内存model = model.to(device)return model

2.3 训练代码train.py

from torch.utils.data import Dataset, DataLoader
import time
from torch.optim import AdamW, RMSprop, SGD, Adam
from model import *
from torchvision.utils import save_image
import random
from torch.autograd import Variable
import os
import cv2
from albumentations import Normalize, Compose, Resize, IAAAdditiveGaussianNoise, GaussNoise, HorizontalFlip, VerticalFlip
from albumentations.pytorch import ToTensorV2
from apex import amp
import pickle# ------------------------------------config------------------------------------
class config:# 设置种子数,配置是否要固定种子数seed = 26use_seed = False# 配置是否要从磁盘加载之前保存的模型参数继续训练from_old_model = False# 使用apex加速训练use_apex = True# 运行多少个epoch之后停止epochs = 20000# 配置batch sizebatchSize = 8# 每次保存模型时输出多少张样图save_img_size = 64# 训练图片输入分辨率,在训练前都预处理完成缩放img_size = 512# 配置喂入生成器的随机正态分布种子数有多少维(如果改动,需要在model中修改网络对应参数)img_seed_dim = 128# 有多大概率在训练判别器D时交换正确图片的标签和伪造图片的标签D_train_label_exchange = 0.1# 将数据集保存在内存中还是磁盘中# 小型数据集可以整个载入内存加快速度read_from = "Memory"# read_from = "Disk"# 保存模型参数文件的路径G_model_path = "G_model.pth"D_model_global_path = "D_model_global.pth"D_model_patch_path = "D_model_patch.pth"# 保存优化器参数文件的路径G_optimizer_path = "G_optimizer.pth"D_optimizer_global_path = "D_optimizer_global.pth"D_optimizer_patch_path = "D_optimizer_patch.pth"# 保存当前保存模型的历史总计训练epoch数epoch_record_path = "epoch_count.pkl"# 当连接大容量移动硬盘时,对每个版本文件都进行单独备份,以方便回退历史版本# 如果这个路径不存在,则什么都不做extra_backup_path = 'F:/version12'# 损失函数# # 使用均方差损失函数# criterion = nn.MSELoss()# 使用二分类交叉熵损失函数criterion = nn.BCEWithLogitsLoss()# 多少个epoch之后保存一次模型save_step = 10# ------------------------------------路径配置------------------------------------# 数据集来源img_path = "train_images/"# 输出图片的文件夹路径output_path = "output_images/"# 如果继续训练,则读取之前进行过多少次epoch的训练if from_old_model:with open(epoch_record_path, "rb") as file:last_epoch_number = pickle.load(file)else:last_epoch_number = 0# 固定随机数种子
def seed_all(seed):random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueif config.use_seed:seed_all(seed=config.seed)# -----------------------------------transforms------------------------------------
def get_transforms():# 缩放分辨率并转换到0-1之间return Compose([ HorizontalFlip(p=0.5),VerticalFlip(p=0.5),Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0, p=1.0),ToTensorV2(p=1.0)])# ------------------------------------dataset------------------------------------
# 从磁盘读取数据的dataset
if config.read_from == "Disk":class image_dataset(Dataset):def __init__(self, file_list, img_path, transform):# files listself.file_list = file_listself.img_path = img_pathself.transform = transformdef __getitem__(self, index):image_path = self.img_path + self.file_list[index]img = cv2.imread(image_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = self.transform(image=img)['image']return imgdef __len__(self):return len(self.file_list)# 从内存读取数据的dataset
elif config.read_from == "Memory":class image_dataset(Dataset):def __init__(self, file_list, img_path, transform):self.imgs = []for file in file_list:image_path = img_path + fileimg = cv2.imread(image_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)self.transform = transformself.imgs.append(img)def __getitem__(self, index):img = self.imgs[index]img = self.transform(image=img)['image']return imgdef __len__(self):return len(self.imgs)# ------------------------------------main------------------------------------
def main():# 如果可以使用GPU运算,则使用GPU,否则使用CPUdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print("Use " + str(device))# 创建输出文件夹if not os.path.exists(config.output_path):os.mkdir(config.output_path)# 如果检测到额外存储路径存在,发出通报if os.path.exists(config.extra_backup_path):print(f"Extra backup path [{config.extra_backup_path}] exists, extra backup version will be saved.")# 创建dataset# create datasetfile_list = Nonefor path, dirs, files in os.walk(config.img_path, topdown=False):file_list = list(files)train_dataset = image_dataset(file_list, config.img_path, transform=get_transforms())train_loader = DataLoader(dataset=train_dataset, batch_size=config.batchSize, shuffle=True)# 从model中获取判别器D和生成器G的网络模型# 判别器分为global全局判别器与patch局部判别器G_model = get_G_model(config.from_old_model, device, config.G_model_path)D_model_global = get_D_model_global(config.from_old_model, device, config.D_model_global_path)D_model_patch = get_D_model_patch(config.from_old_model, device, config.D_model_patch_path)G_model.train()D_model_global.train()D_model_patch.train()# 定义G和D的优化器,此处使用AdamW优化器G_optimizer = Adam(G_model.parameters(), lr=1e-4)D_optimizer_global = Adam(D_model_global.parameters(), lr=3e-4)D_optimizer_patch = Adam(D_model_patch.parameters(), lr=3e-4)# D_optimizer_global = AdamW(D_model_global.parameters(), lr=3e-4, weight_decay=1e-6)# D_optimizer_global = RMSprop(D_model_global.parameters(), lr=3e-4, alpha=0.9)# D_optimizer_global = SGD(D_model_global.parameters(), lr=3e-4)# D_optimizer_patch = AdamW(D_model_patch.parameters(), lr=3e-4, weight_decay=1e-6)# D_optimizer_patch = RMSprop(D_model_patch.parameters(), lr=3e-4, alpha=0.9)# D_optimizer_patch = SGD(D_model_patch.parameters(), lr=3e-4)# 如果是读取之前训练的数据,则加载保存的优化器参数if config.from_old_model:G_optimizer.load_state_dict(torch.load(config.G_optimizer_path))D_optimizer_global.load_state_dict(torch.load(config.D_optimizer_global_path))D_optimizer_patch.load_state_dict(torch.load(config.D_optimizer_patch_path))# 损失函数criterion = config.criterion# 混合精度加速if config.use_apex:G_model, G_optimizer = amp.initialize(G_model, G_optimizer, opt_level="O1")D_model_global, D_optimizer_global = amp.initialize(D_model_global, D_optimizer_global, opt_level="O1")D_model_patch, D_optimizer_patch = amp.initialize(D_model_patch, D_optimizer_patch, opt_level="O1")# 记录训练时间train_start = time.time()# 定义标签,单值标签用于传统判别器,多值标签用于patch判别器# 定义真标签,使用标签平滑的策略,全0.9real_labels = Variable(torch.ones(config.batchSize, 1)-0.1).to(device)# 定义假标签,单向平滑,因此不对生成器标签进行平滑处理,全0fake_labels = Variable(torch.zeros(config.batchSize, 1)).to(device)# 开始训练的每一个epochfor epoch in range(config.epochs):print("start epoch "+str(epoch+1)+":")# 定义一些变量用于记录进度和损失batch_num = len(train_loader)D_loss_sum_global = 0D_loss_sum_patch = 0G_loss_sum = 0count = 0# 从dataloader中提取数据for index, images in enumerate(train_loader):count += 1# 将图片放入运算设备的内存images = images.to(device)# 记录真假标签是否被交换过exchange_labels = False# 有一定概率在训练判别器时交换labelif random.uniform(0, 1) < config.D_train_label_exchange:real_labels, fake_labels = fake_labels, real_labelsexchange_labels = True# 训练判断器D_globalD_optimizer_global.zero_grad()# 将随机的初始数据喂入生成器生成假图像img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)fake_images = G_model(img_seeds)# 用真样本输入判别器real_output = D_model_global(images)# 对于数据集末尾的数据,长度不够一个batch size时需要去除过长的真实标签if len(real_labels) > len(real_output):D_loss_real = criterion(real_output, real_labels[:len(real_output)])else:D_loss_real = criterion(real_output, real_labels)# 用假样本输入判别器fake_output = D_model_global(fake_images)D_loss_fake = criterion(fake_output, fake_labels)# 将真样本与假样本损失相加,得到判别器的损失D_loss_global = D_loss_real + D_loss_fakeD_loss_sum_global += D_loss_global.item()# 重置优化器D_optimizer_global.zero_grad()# 用损失更新判别器if config.use_apex:with amp.scale_loss(D_loss_global, D_optimizer_global) as scaled_loss:scaled_loss.backward()else:D_loss_global.backward()D_optimizer_global.step()# 训练判断器D_patchD_optimizer_patch.zero_grad()# 将随机的初始数据喂入生成器生成假图像img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)fake_images = G_model(img_seeds)# 用真样本输入判别器real_output = D_model_patch(images)# # 对于数据集末尾的数据,长度不够一个batch size时需要去除过长的真实标签if len(real_labels) > len(real_output):D_loss_real = criterion(real_output, real_labels[:len(real_output)])else:D_loss_real = criterion(real_output, real_labels)# 用假样本输入判别器fake_output = D_model_patch(fake_images)D_loss_fake = criterion(fake_output, fake_labels)# 将真样本与假样本损失相加,得到判别器的损失D_loss_patch = D_loss_real + D_loss_fakeD_loss_sum_patch += D_loss_patch.item()# 重置优化器D_optimizer_patch.zero_grad()# 用损失更新判别器if config.use_apex:with amp.scale_loss(D_loss_patch, D_optimizer_patch) as scaled_loss:scaled_loss.backward()else:D_loss_patch.backward()D_optimizer_patch.step()# 如果之前交换过真假标签,此时再换回来if exchange_labels:real_labels, fake_labels = fake_labels, real_labels# 将随机种子数喂入生成器G生成假数据img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)fake_images = G_model(img_seeds)# 将假数据输入判别器fake_output_global = D_model_global(fake_images)fake_output_patch = D_model_patch(fake_images)# 将假数据的判别结果与真实标签对比得到损失G_loss_global = criterion(fake_output_global, real_labels)G_loss_patch = criterion(fake_output_patch, real_labels)G_loss = G_loss_global + G_loss_patchG_loss_sum += G_loss.item()# 重置优化器G_optimizer.zero_grad()# 利用损失更新生成器Gif config.use_apex:with amp.scale_loss(G_loss, G_optimizer) as scaled_loss:scaled_loss.backward()else:G_loss.backward()G_optimizer.step()# 打印程序工作进度print("\rEpoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num), end="")print()if (config.last_epoch_number+epoch+1) % config.save_step == 0:print("Start saving model files to current path...", end='')# 在每N个epoch结束时保存模型参数到磁盘文件torch.save(G_model.state_dict(), config.G_model_path)torch.save(D_model_global.state_dict(), config.D_model_global_path)torch.save(D_model_patch.state_dict(), config.D_model_patch_path)# 在每N个epoch结束时保存优化器参数到磁盘文件torch.save(G_optimizer.state_dict(), config.G_optimizer_path)torch.save(D_optimizer_global.state_dict(), config.D_optimizer_global_path)torch.save(D_optimizer_patch.state_dict(), config.D_optimizer_patch_path)# 保存历史训练总数with open(config.epoch_record_path, "wb") as file:pickle.dump(config.last_epoch_number + epoch + 1, file, 1)# 在每N个epoch结束时输出一组生成器产生的图片到输出文件夹,拼接出一张含有config.save_img_size张图的大图save_imgs = []with torch.no_grad():for _ in range(config.save_img_size // config.batchSize):img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)fake_images = G_model(img_seeds).cuda().data# 将假图像缩放到[0,1]的区间fake_images = 0.5 * (fake_images + 1)fake_images = fake_images.clamp(0, 1)# 连接所有生成的图片然后用自带的save_image()函数输出到磁盘文件fake_images = fake_images.view(-1, 3, config.img_size, config.img_size)save_imgs.append(fake_images)save_imgs = torch.cat(save_imgs, 0)save_image(save_imgs, config.output_path+str(config.last_epoch_number + epoch + 1)+'.png')print("Success.")# 当连接大容量移动硬盘时,对每个版本文件都进行单独备份,以方便取出历史版本if os.path.exists(config.extra_backup_path):print("Start saving model files to extra backup path...", end='')extra_backup_path = f'{config.extra_backup_path}/{config.last_epoch_number + epoch + 1}/'# 创建该版本的历史总epoch存放目录if not os.path.exists(extra_backup_path):os.mkdir(extra_backup_path)# 保存模型参数到磁盘文件torch.save(G_model.state_dict(), extra_backup_path + config.G_model_path)torch.save(D_model_global.state_dict(), extra_backup_path + config.D_model_global_path)torch.save(D_model_patch.state_dict(), extra_backup_path + config.D_model_patch_path)# 保存优化器参数到磁盘文件torch.save(G_optimizer.state_dict(), extra_backup_path + config.G_optimizer_path)torch.save(D_optimizer_global.state_dict(), extra_backup_path + config.D_optimizer_global_path)torch.save(D_optimizer_patch.state_dict(), extra_backup_path + config.D_optimizer_patch_path)# 保存历史训练总数with open(extra_backup_path + config.epoch_record_path, "wb") as file:pickle.dump(config.last_epoch_number + epoch + 1, file, 1)# 保存对应版本的预览图片save_image(save_imgs, extra_backup_path + str(config.last_epoch_number + epoch + 1) + '.png')print("Success.")# 打印该epoch的损失,时间等数据用于参考print("D_loss_global:", round(D_loss_sum_global / count, 3))print("D_loss_patch:", round(D_loss_sum_patch / count, 3))print("G_loss:", round(G_loss_sum / count, 3))current_time = time.time()pass_time = int(current_time - train_start)time_string = str(pass_time // 3600) + " hours, " + str((pass_time % 3600) // 60) + " minutes, " + str(pass_time % 60) + " seconds."print("Time pass:", time_string)print()# 运行结束print("Done.")if __name__ == '__main__':main()

3. 最终效果

最终生成的图片效果(与原本训练集完全不同的图像):







可以看出,画面与以往的版本比起来已经有了巨大的提升,能够看到星球表面非常清晰的细节和立体感。

4. 缺陷与下一步改进

在训练过程中,模型仍然存在着很容易产生模式坍塌的缺陷,而且在少数情况下仍然会产生少数有网格状的粗糙图像,也容易在训练过程中产生质量的震荡,部分改进方式已经完成代码的编写工作,等待实际运行验证后将发布新的版本。
真希望能拥有一张3090来跑模型啊QAQ


下一集:彩色星球图片生成5:先验条件约束与LapGAN(pytorch版)

彩色星球图片生成4:转置卷积+插值缩放+卷积收缩(pytorch版)相关推荐

  1. 彩色星球图片生成5:先验条件约束与LapGAN(pytorch版)

    彩色星球图片生成5:先验条件约束与LapGAN(pytorch版) 1. 改进方面 1.1 训练集信息的人工标注 1.2 先验信息的条件约束 1.3 分类器C 1.4 LapGAN的分层残差拟合 2. ...

  2. 彩色星球图片生成3:代码改进(pytorch版)

    彩色星球图片生成3:代码改进(pytorch版) 1. 修改 1.1 预处理缩放 1.2 随机翻转 1.3 修改全局判别器 1.4 修改进度打印 2. 效果 3. 总结 上一集: 彩色星球图片生成2: ...

  3. 彩色星球图片生成1:使用Gan实现(pytorch版)

    彩色星球图片生成1:使用Gan实现(pytorch版) 1. 描述 2. 代码 2.1 模型代码model.py 2.2 训练代码main.py 3. 效果 4. 趣图 上一集: 使用Gan实现MNI ...

  4. 图形图像处理-之-高质量的快速的图像缩放 中篇 二次线性插值和三次卷积插值

    from:http://blog.csdn.net/housisong/article/details/1452249 图形图像处理-之-高质量的快速的图像缩放 中篇 二次线性插值和三次卷积插值    ...

  5. 好像还挺好玩的GAN2——Keras搭建DCGAN利用深度卷积神经网络实现图片生成

    好像还挺好玩的GAN2--Keras搭建DCGAN利用深度卷积神经网络实现图片生成 注意事项 学习前言 什么是DCGAN 神经网络构建 1.Generator 2.Discriminator 训练思路 ...

  6. 使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上

    阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读. 相关资料链接为:使用PyTorch构建GAN生成对抗 本次训练代码使用了 ...

  7. 一文搞懂转置卷积(反卷积)

    ↑ 点击蓝字 关注极市平台 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/158933003 极市导读 转置卷积在一些文献中也被称为反卷积,人们如果希望网络学习到上 ...

  8. 分组卷积/转置卷积/空洞卷积/反卷积/可变形卷积/深度可分离卷积/DW卷积/Ghost卷积/

    文章目录 1. 常规卷积 2. 分组卷积 3. 转置卷积 4. 空洞卷积 5. 可变形卷积 6. 深度可分离卷积(Separable Convolution) 6.1 Depthwise Convol ...

  9. 深度学习中常见卷积(普通卷积、1×1卷积、转置卷积、可分离卷积、膨胀(空洞)卷积、3D卷积)

      总是在网络上看到各种名词的卷积,但是有搞不懂是什么含义,于是结合网上查阅的资料,总结一下.目前比较常用的卷积主要有常规的卷积.1×1卷积.转置卷积.可分离卷积.膨胀卷积.3D卷积.   以下是一些 ...

最新文章

  1. BF算法优化-------KMP算法
  2. C++ NULL nullptr和0的区别
  3. ADO.NET实用经验 转载
  4. Linux chmod命令
  5. 《Python编程从入门到实践》记录之函数编写指南
  6. android超级管理员权限作用,Android获取超级管理员权限的实现
  7. mybatis-plus中like的使用说明
  8. 学术 | 如何写一篇合格的NLP论文
  9. 在微信小程序中绘制图表(part1)
  10. MTK 6589暗码切换开机LOGO(不适应NAND 的FLASH)
  11. 南大衣哥、北袁长标,恭喜谷传民新歌准备报送央视春晚
  12. K-S检验两样本分布是否相同
  13. 按位与、按位异或、按位取反
  14. 操作系统内核Hack:(二)底层编程基础
  15. 怎样P漫画脸?这三个简单方法分享给你
  16. ICG-PEG-Biotin结构式,吲哚菁绿-聚乙二醇-生物素 荧光染料聚乙二醇衍生物
  17. uniapp-小程序点击底部导航跳转并刷新页面
  18. MFC打开一个文件方法汇总
  19. 自适应网站适合什么行业
  20. java并发编程实战(二)

热门文章

  1. 谈谈互联网企业的人员分工和角色管理
  2. 画一个奥利奥(python+opencv)
  3. 深度学习基础——简单了解meta learning(来自李宏毅课程笔记)
  4. LV和Dior所属集团推出区块链平台以验证奢侈品真伪
  5. 字典中setdefault()函数用法
  6. java构造方法与重载
  7. CHINAPLAS线上展会启动!足不出户发现全球热点橡塑科技
  8. 如何在Linux中“快速阅读”?
  9. C#复习之委托(Delegate)和事件(Event)
  10. automake生成静态库文件_Automake 详解