pixtopix(像素到像素)

原文连接:https://arxiv.org/pdf/1611.07004.pdf
输入一个域的图片转换为另一个域的图片(白天照片转成黑夜)
如下图,输入标记图片,输出真实图片缺点就是训练集两个域的图片要一一对应,所以叫pixtopix,

网络结构有点复杂,用到了语义分割的UNET网络结构

数据集:
地址忘了,也是官方的,想起来补
代码:这里是建筑物labels to facade的例子

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image# jpg是原始图片
images_path = glob.glob(r'base\*.jpg')
annos_path = glob.glob(r'base\*.png')
# png是分割的图片transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((256, 256)),transforms.Normalize(0.5, 0.5)
])class CMP_dataset(data.Dataset):def __init__(self, imgs_path, annos_path):self.imgs_path = imgs_pathself.annos_path = annos_pathdef __getitem__(self, item):img_path = self.imgs_path[item]anno_path = self.annos_path[item]pil_img = Image.open(img_path)pil_img = transform(pil_img)anno_img = Image.open(anno_path)anno_img = anno_img.convert('RGB')pil_anno = transform(anno_img)return pil_anno, pil_imgdef __len__(self):return len(self.imgs_path)dataset = CMP_dataset(images_path, annos_path)
batchsize = 32
dataloader = data.DataLoader(dataset,batch_size=batchsize,shuffle=True)annos_batch, images_batch = next(iter(dataloader))for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):anno = (anno.permute(1, 2, 0).numpy()+1)/2img = (img.permute(1, 2, 0).numpy()+1)/2plt.subplot(3, 2, i*2+1)plt.title('input_img')plt.imshow(anno)plt.subplot(3, 2, i*2+2)plt.title('output_img')plt.imshow(img)
plt.show()# 定义下采样模块
class Downsample(nn.Module):def __init__(self, in_channels, out_channels):super(Downsample, self).__init__()self.conv_relu = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 2, 1),nn.LeakyReLU(inplace=True))self.bn = nn.BatchNorm2d(out_channels)def forward(self, x, is_bn=True):x = self.conv_relu(x)if is_bn:x = self.bn(x)return x# 定义上采样模块
class Upsample(nn.Module):def __init__(self, in_channels, out_channels):super(Upsample, self).__init__()self.upconv_relu = nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, 3, 2, 1,output_padding=1),nn.LeakyReLU(inplace=True))self.bn = nn.BatchNorm2d(out_channels)def forward(self, x, is_drop=False):x = self.upconv_relu(x)x = self.bn(x)if is_drop:x = F.dropout2d(x)return x# 定义生成器,包含6个下采样,5上采样,1输出
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.down1 = Downsample(3, 64)   # 64,128,128self.down2 = Downsample(64, 128)  # 128,64,64self.down3 = Downsample(128, 256)  # 256,32,32self.down4 = Downsample(256, 512)  # 512, 16,16self.down5 = Downsample(512, 512)  # 512,8,8self.down6 = Downsample(512, 512)  # 512, 4,4self.up1 = Upsample(512, 512)      # 512 ,8,8self.up2 = Upsample(1024, 512)    # 512, 16,16self.up3 = Upsample(1024, 256)   # 256, 32,32self.up4 = Upsample(512, 128)   # 128,64,64self.up5 = Upsample(256, 64)   # 64,128,128self.last = nn.ConvTranspose2d(128, 3,kernel_size=3,stride=2,padding=1,output_padding=1)def forward(self,x):x1 = self.down1(x)x2 = self.down2(x1)x3 = self.down3(x2)x4 = self.down4(x3)x5 = self.down5(x4)x6 = self.down6(x5)x6 = self.up1(x6, is_drop=True)x6 = torch.cat([x6, x5], dim=1)x6 = self.up2(x6, is_drop=True)x6 = torch.cat([x6, x4], dim=1)x6 = self.up3(x6, is_drop=True)x6 = torch.cat([x6, x3], dim=1)x6 = self.up4(x6, is_drop=True)x6 = torch.cat([x6, x2], dim=1)x6 = self.up5(x6)x6 = torch.cat([x6, x1], dim=1)x6 = torch.tanh(self.last(x6))return x6# 定义判别器 输入anno + img
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.down1 = Downsample(6, 64)  # 64*128*128self.down2 = Downsample(64, 128)  # 128*64*64self.conv1 = nn.Conv2d(128, 256, 3)self.bn1 = nn.BatchNorm2d(256)self.conv2 = nn.Conv2d(256, 1, 3)def forward(self, anno, img):x = torch.cat([anno, img], axis=1)  # batch*6*h*wx = self.down1(x, is_bn=False)x = self.down2(x)x = F.dropout2d(self.bn1(F.leaky_relu(self.conv1(x))))x = torch.sigmoid(self.conv2(x))   # batch*1* 60*60return xdevice = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':print('using cuda:', torch.cuda.get_device_name(0))
else:print(device)Gen = Generator().to(device)
Dis = Discriminator().to(device)d_optimizer = torch.optim.Adam(Dis.parameters(), lr=1e-3, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(Gen.parameters(), lr=1e-3, betas=(0.5, 0.999))
# loss
# cgan损失
loss_fn = torch.nn.BCELoss()
# L1-loss 后面计算,求差绝对值的求和
# 绘图
def generator_images(model, test_anno, test_real):prediction = model(test_anno).permute(0, 2, 3, 1).detach().cpu().numpy()test_anno = test_anno.permute(0, 2, 3, 1).detach().cpu().numpy()test_real = test_real.permute(0, 2, 3, 1).detach().cpu().numpy()plt.figure(figsize=(10, 10))display_list = [test_anno[0], test_real[0], prediction[0]]title = ['input', 'ground truth', 'output']for i in range(3):plt.subplot(1, 3, i+1)plt.title(title[i])plt.imshow(display_list[i])plt.axis('off')plt.show()# 加载extend为测试
test_imgs_path = glob.glob('extended/*.jpg')
test_annos_path = glob.glob('extended/*.png')test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_daloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize
)
# 返回一个批次annos_batch, images_batch = next(iter(dataloader))plt.figure(figsize=(6, 10))
for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):anno = (anno.permute(1, 2, 0).numpy()+1)/2img = (img.permute(1, 2, 0).numpy()+1)/2plt.subplot(3, 2, i*2+1)plt.title('input_img')plt.imshow(anno)plt.subplot(3, 2, i*2+2)plt.title('output_img')plt.imshow(img)
plt.show()annos_batch, images_batch = annos_batch.to(device), images_batch.to(device)
LAMBDA = 7  # L1损失权重D_loss = []
G_loss = []
for epoch in range(300):D_epoch_loss = 0G_epoch_loss = 0count = len(dataloader)for step, (annos, imgs) in enumerate(dataloader):imgs = imgs.to(device)annos = annos.to(device)d_optimizer.zero_grad()disc_real_output = Dis(annos, imgs)  # 输入真实成对图片d_real_loss = loss_fn(disc_real_output, torch.ones_like(disc_real_output,device=device))d_real_loss.backward()gen_output = Gen(annos)dis_gen_output = Dis(annos, gen_output.detach())d_fake_loss = loss_fn(dis_gen_output, torch.zeros_like(dis_gen_output,device=device))d_fake_loss.backward()disc_loss = d_real_loss + d_fake_lossd_optimizer.step()disc_gen_out = Dis(annos, gen_output)gen_loss_crossentropyloss = loss_fn(disc_gen_out,torch.ones_like(disc_gen_out,device=device))gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))gen_loss = LAMBDA * gen_l1_loss + gen_loss_crossentropylossgen_loss.backward()g_optimizer.step()with torch.no_grad():D_epoch_loss += disc_loss.item()G_epoch_loss += gen_loss.item()with torch.no_grad():D_epoch_loss /= countG_epoch_loss /= countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print('Epoch', epoch)generator_images(Gen, annos_batch, images_batch)

给动漫素描自动上色的(AI上色)移步我的kaggle
https://www.kaggle.com/code/jiyuanhai/pix2pix-test-pytorch

CycleGAN

这个厉害

GAN-生成对抗网络(Pytorch)合集(2)--pixtopix-CycleGAN相关推荐

  1. 使用SAGAN生成二次元人物头像(GAN生成对抗网络)--pytorch实现

    这是训练250epoch左右的成果. 之前的文章里面,我们使用了残差网络的形式实现生成器与辨别器,它理论上可以实现很不错的效果,但有一个很致命的缺点,就是训练太慢,很难见到成果. 这一次,我们实现了一 ...

  2. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

    文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...

  3. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

    文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...

  4. 深度学习(九) GAN 生成对抗网络 理论部分

    GAN 生成对抗网络 理论部分 前言 一.Pixel RNN 1.图片的生成模型 2.Pixel RNN 3.Pixel CNN 二.VAE(Variational Autoencoder) 1.VA ...

  5. 深度学习 GAN生成对抗网络-1010格式数据生成简单案例

    一.前言 本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络. 二.GAN概念 生成对抗网络(Generative Adversarial Networks ...

  6. 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例

    1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...

  7. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

  8. GAN -- 生成对抗网络

    GAN -- 生成对抗网络 生成对抗网络(英语:Generative Adversarial Network,简称GAN)是非监督式学习的一种方法,通过让两个神经网络相互博弈的方式进行学习.该方法由伊 ...

  9. 【轩说AI】生成模型(2)—— GAN生成对抗网络 + WGAN + Conditional GAN + Cycle GAN

    文章目录 GAN生成对抗网络(Generative Adversarial Network) 神经网络的本质就是一个函数,一个用于拟合的函数 生成模型面临的前所未有的问题 GAN解决这一问题的思想 O ...

最新文章

  1. 7.1.1 [Enterprise Library]缓存应用程序块场景和目标
  2. 汇编实现地址对应值相加
  3. awk学习实战-原创
  4. 视觉slam发展史--从开始到未来
  5. 搭建H1ve-ctfd以及如何部署题目
  6. Nginx服务器版本升级需求分析
  7. DotNetty 实现 Modbus TCP 系列 (三) Codecs Handler
  8. guid判断是否有效_让我们一起啃算法----有效的括号
  9. 具有Tron效果的JavaFX 2 Form
  10. 解决SecureCRT登陆Linux显示中文乱码
  11. android手机分享app,Android Pie如何快捷分享文件至特定App
  12. 本年度读书计划-看几本必须好好琢磨的书
  13. VisualStudio2019 DLL生成并使用教程(C++)最详细Demo教程
  14. 后端服务接口都在测试什么?怎么测?
  15. mysql with 查询_mysql笔记(6)-多表查询之with
  16. 二进制文件转成文本保存,并可以读回
  17. 软件缺陷分析—软件测试之犯罪心理学
  18. android 调色板,所不了解的Android调色板
  19. python gtk_python 创建gtk应用程序
  20. 机器视觉测量原理及其优势

热门文章

  1. 【快速检索,稳定出版,强大委员会Speaker阵容】ICCCS 2022|第7届通信计算机大会
  2. python中的函数及面向对象的知识点
  3. python脱离环境运行_python 生成exe脱离python环境运行
  4. WEB前端开发,认认真真学4个月能学到初级吗?
  5. 中断、陷阱、软中断之间的异同
  6. 防火墙基础配置(二)
  7. MySQL创建数据库、创建数据表
  8. HTML5 —新增标签
  9. ExecuteNonQuery()返回-1的问题及解决
  10. 2.Enable ADB integration' to be enabled.