http://arxiv.org/abs/1611.07004

2017年的一篇CVPR。是非常经典的一个模型。

pix2pix是基于Conditional-GAN,也就是CGAN。相比于一般的CGAN输入为一个较为常见的label(one-hot编码的标签)。这里将输入控制为一个图片。

CGAN的模型思路如下:

如果是图片作为输入的话,其实要求就会高了很多了。

同时,pix2pix也是之前提到的DualGAN,还有还没有提到的CycleGAN这些模型的基石。

不同于后续的模型,在要求上更加宽松,不需要成对的数据,pix2pix其实对于数据集做了要求的,必须是成对的数据来用于训练。

pix2pix的主要贡献:

提出PatchGAN的思路:简单来讲就是,D的输出不是一个scale(标量),而是一个矩阵Patch * Patch。然后来计算这个矩阵和real data(全一矩阵),以及fake data(全0矩阵)之间的距离(这里常用L2)。

为了捕捉高频的信息(这里使用PatchGAN的模型);低频的信息用L1norm来保证。

使用L1范数,而不是L2范数:这里是指衡量生成数据和真实数据之间的距离的时候给G添加的一个损失。这个损失的距离计算方式不是我们常用的L2范数,而是L1范数,目的就是为了捕获低频的信息。(使用L1的模糊程度会小很多)

不用z做G的输入,而是添加Dropout:这个也是DualGAN在这学。实验结果显示,这样效果更加好。

G使用U-Net结构而不是Encoder-Decoder结构:DualGAN关于G的设计就是学这个。也就是需要把encoder的信息concat到对称的Decoder的部分。避免低维的信息在计算的过程中消失掉,使得能更好的保存图像的原始特征。(有点像风格迁移的时候,需要保存初始图像该有的信息)


恰饭


相关阅读

  • CGAN模型理论以及Python实现

  • DualGAN模型理论以及Python实现

实验

实验部分基本上是在DualGAN的代码上改改。但是实际上,是DualGAN学习pix2pix。

第一列是,模型生成的素描图;

第二列是,真实数据中对应的素描图;

第三列是,真实数据中的输入图(实拍照片)。

dataloader.py

import torch.utils.data as dataimport globimport osimport torchvision.transforms as transformsfrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as npimport torchimport piexifimport imghdrimport numberstry:    import accimageexcept ImportError:    accimage = Noneclass MyCrop(object):    """Crops the given PIL Image at the center.    Args:        size (sequence or int): Desired output size of the crop. If size is an            int instead of sequence like (h, w), a square crop (size, size) is            made.    """    def __init__(self, i, j, size):        if isinstance(size, numbers.Number):            self.size = (int(size), int(size))        else:            self.size = size        self.i, self.j = i, j    def __call__(self, img):        """        Args:            img (PIL Image): Image to be cropped.        Returns:            PIL Image: Cropped image.        """        if not isinstance(img, Image.Image):            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))        th, tw = self.size        return img.crop((self.j, self.i, self.j + tw, self.i + th))    def __repr__(self):        return self.__class__.__name__ + '(size={0})'.format(self.size)class MyDataset(data.Dataset):    def __init__(self, path_sketch, path_photo, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False,                 default=False):        self.Train = Train        self.sketch_dataset = self.init_dataset(path_sketch, Len=Len, resize=resize, img_type=img_type,                                                remove_exif=remove_exif, sketch=True, default=default)        self.photo_dataset = self.init_dataset(path_photo, Len=Len, resize=resize, img_type=img_type,                                               remove_exif=remove_exif, sketch=False, default=default)    def init_dataset(self, path, Len=-1, resize=-1, img_type='png', remove_exif=False, sketch=True, default=False):        if resize != -1:            if default:                transform = transforms.Compose([                    transforms.Resize(resize),                    transforms.CenterCrop(resize),                    transforms.ToTensor(),                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))                ])            elif sketch:                transform = transforms.Compose([                    transforms.Resize(resize),                    MyCrop(30, 0, resize),                    transforms.ToTensor(),                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))                    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))                ])            else:                transform = transforms.Compose([                    transforms.Resize(resize + 20),                    MyCrop(15, 26, resize),                    transforms.ToTensor(),                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))                    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))                ])        else:            transform = transforms.Compose([                transforms.ToTensor(),            ])        img_format = '*.%s' % img_type        if remove_exif:            for name in glob.glob(os.path.join(path, img_format)):                try:                    piexif.remove(name)  # 去除exif                except Exception:                    continue        # imghdr.what(img_path) 判断是否为损坏图片        if Len == -1:            dataset = [np.array(transform(Image.open(name).convert("L"))) for name in                       glob.glob(os.path.join(path, img_format)) if imghdr.what(name)]        else:            dataset = [np.array(transform(Image.open(name).convert("L"))) for name in                       glob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]        dataset = np.array(dataset)        dataset = torch.Tensor(dataset)        return dataset    def __len__(self):        return len(self.photo_dataset)    def __getitem__(self, idx):        return self.sketch_dataset[idx], self.photo_dataset[idx]if __name__ == '__main__':    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=10, img_type='jpg')    print(len(dataset))    for i in range(5):        plt.imshow(np.squeeze(dataset[i][0].numpy()) * 0.5 + 0.5, cmap='gray')        plt.show()        print(dataset[i][0].max(), dataset[i][0].min())        plt.imshow(np.squeeze(dataset[i][1].numpy()) * 0.5 + 0.5, cmap='gray')        plt.show()        print(dataset[i][1].max(), dataset[i][1].min())

main.py

import osimport torchfrom torch.utils.data import Dataset, DataLoaderimport torch.nn as nnfrom model import Generator, Discriminator, gp_loss# from model import gp_loss# from github_model import Generator, Discriminatorimport torchvisionfrom dataloader import MyDatasetimport matplotlib.pyplot as pltimport itertoolsimport numpy as npimport torchvision.utils as vutilsif __name__ == '__main__':    LR = 0.0002    EPOCH = 100  # 50    BATCH_SIZE = 4    drop_rate = 0.5    lam = 10    TRAINED = False    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=88, img_type='jpg')    data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)    torch.cuda.empty_cache()    if not TRAINED:        G = Generator(1, drop_rate=drop_rate).cuda()        D = Discriminator(1).cuda()    else:        G = torch.load("G.pkl").cuda()        D = torch.load("D.pkl").cuda()    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)    l1_c = nn.L1Loss()    mse_c = nn.MSELoss()    # PATCH SHAPE IS (1, 12, 12)    real_label = torch.ones((BATCH_SIZE, 1, 12, 12)).cuda()    fake_label = torch.zeros((BATCH_SIZE, 1, 12, 12)).cuda()    for epoch in range(EPOCH):        tmpD, tmpG = 0, 0        for step, (x, y) in enumerate(data_loader):            x = x.cuda()            y = y.cuda()            G_x = G(y)            D_xy = D(x, y)  # PatchGAN            D_gxy = D(G_x, y)            # print(D_xy.shape, D_gxy.shape)            D_loss = mse_c(D_xy, real_label) + mse_c(D_gxy, fake_label)            G_loss = mse_c(D_gxy, real_label) + lam * l1_c(G_x, x)            optimizerG.zero_grad()            G_loss.backward(retain_graph=True)            optimizerG.step()            tmpD_ = D_loss.cpu().detach().data            tmpG_ = G_loss.cpu().detach().data            tmpD += tmpD_            tmpG += tmpG_        tmpD /= (step + 1)        tmpG /= (step + 1)        print(            'epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG)        )        if (epoch + 1) % 5 == 0:            fig = plt.figure(figsize=(10, 10))            plt.axis("off")            plt.imshow(np.transpose(                vutils.make_grid(torch.stack([G_x[0].cpu().detach(), x[0].cpu().detach(), y[0].cpu().detach(), ]),                                 nrow=3, padding=0, normalize=True, scale_each=True), (1, 2, 0)), cmap='gray')            plt.show()    torch.save(G, 'G.pkl')    torch.save(D, 'D.pkl')

model.py

import osimport torchimport torch.nn as nnimport torch.utils.data as Dataimport torchvisionfrom torch.utils.data import DataLoaderfrom dataloader import MyDatasetimport torch.autograd as autogradclass ResidualBlock(nn.Module):    def __init__(self, in_channel=1, out_channel=1, stride=1):        super(ResidualBlock, self).__init__()        self.weight_layer = nn.Sequential(            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1),            nn.BatchNorm2d(out_channel),            nn.ReLU(),            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1),        )        self.active_layer = nn.Sequential(            nn.BatchNorm2d(out_channel),            nn.ReLU()        )    def forward(self, x):        residual = x        x = self.weight_layer(x)        x += residual        return self.active_layer(x)class Generator(nn.Module):    def __init__(self, input_channel=1, drop_rate=0.5):        super(Generator, self).__init__()        self.c_e1 = nn.Sequential(            nn.Conv2d(in_channels=input_channel, out_channels=64, kernel_size=4, stride=2, padding=1),            nn.LeakyReLU(0.2),            ResidualBlock(in_channel=64, out_channel=64))        self.c_e2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),                                  nn.BatchNorm2d(128), nn.LeakyReLU(0.2),                                  ResidualBlock(in_channel=128, out_channel=128))        self.c_e3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2),                                  nn.BatchNorm2d(256), nn.LeakyReLU(0.2),                                  ResidualBlock(in_channel=256, out_channel=256), nn.Dropout2d(drop_rate))        self.c_e4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),                                  nn.BatchNorm2d(256), nn.LeakyReLU(0.2),                                  ResidualBlock(in_channel=256, out_channel=256), nn.Dropout2d(drop_rate))        self.d_e1 = nn.Sequential(            nn.ConvTranspose2d(in_channels=128, out_channels=input_channel, kernel_size=4, stride=2, padding=1),            nn.Tanh())        self.d_e2 = nn.Sequential(nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=4, stride=2),                                  nn.BatchNorm2d(64), nn.ReLU(), )        self.d_e3 = nn.Sequential(nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=5, stride=2),                                  nn.BatchNorm2d(128), nn.Dropout2d(drop_rate))        self.d_e4 = nn.Sequential(nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),                                  nn.BatchNorm2d(256), nn.Dropout2d(drop_rate))    def forward(self, x):        e1 = self.c_e1(x)        e2 = self.c_e2(e1)        e3 = self.c_e3(e2)        e4 = self.c_e4(e3)        d4 = self.d_e4(e4)        # print(d4.shape, e3.shape)        d4 = torch.cat([d4, e3], dim=1)        # d4 = d4 + e3        d3 = self.d_e3(d4)        # print(d3.shape, e2.shape)        d3 = torch.cat([d3, e2], dim=1)        # d3 = d3 + e2        d2 = self.d_e2(d3)        # print(d2.shape, e1.shape)        d2 = torch.cat([d2, e1], dim=1)        # d2 = d2 + e1        # print(d2.shape)        d1 = self.d_e1(d2)        # print(d1.shape)        return d1class Discriminator(nn.Module):    def __init__(self, input_size):        super(Discriminator, self).__init__()        strides = [1, 2, 2, 2]        padding = [0, 1, 1, 1]        channels = [input_size * 2,                    64, 128, 256, 1]  # 1表示一维        kernels = [4, 4, 4, 3]        model = []        for i, stride in enumerate(strides):            model.append(                nn.Conv2d(                    in_channels=channels[i],                    out_channels=channels[i + 1],                    stride=stride,                    kernel_size=kernels[i],                    padding=padding[i]                )            )            model.append(nn.BatchNorm2d(channels[i + 1]))            model.append(                nn.LeakyReLU(0.2)            )        self.main = nn.Sequential(*model)    def forward(self, fake_x, real_x):        x = torch.cat([fake_x, real_x], dim=1)        x = self.main(x)        return x  # .view(x.shape[0], -1)        # return self.fc(x)def gp_loss(D, real_x, fake_x, cuda=False):    if cuda:        alpha = torch.rand((real_x.shape[0], 1, 1, 1)).cuda()    else:        alpha = torch.rand((real_x.shape[0], 1, 1, 1))    x_ = (alpha * real_x + (1 - alpha) * fake_x).requires_grad_(True)    y_ = D(x_)    # cal f'(x)    grad = autograd.grad(        outputs=y_,        inputs=x_,        grad_outputs=torch.ones_like(y_),        create_graph=True,        retain_graph=True,        only_inputs=True,    )[0]    grad = grad.view(x_.shape[0], -1)    gp = ((grad.norm(2, dim=1) - 1) ** 2).mean()    return gpif __name__ == '__main__':    drop_rate = 0.5    G = Generator(1, drop_rate)    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=10, img_type='jpg')    train_loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)    D = Discriminator(1)    rs = ResidualBlock(1, 1, stride=1)  # only for stride 1    for step, (x, y) in enumerate(train_loader):        print(x.shape)        print(G(x).shape)        print(D(x, x).shape)        print(rs(x).shape)        break

judge.py

import numpy as npimport torchimport matplotlib.pyplot as pltfrom model import Generator, Discriminatorfrom dataloader import MyDatasetfrom torch.utils.data import Dataset, DataLoaderimport itertoolsimport torchvision.utils as vutilsif __name__ == '__main__':    BATCH_SIZE = 3    TIMES = 5    img_shape = (1, 28, 28)    G = torch.load("G.pkl").cuda()    D = torch.load("D.pkl").cuda()    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=BATCH_SIZE * TIMES, img_type='jpg')    data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)    for step, (x, y) in enumerate(data_loader):        x = x.cuda()        y = y.cuda()        G_x = G(y)        fig = plt.figure(figsize=(10, 10))        plt.axis("off")        plt.imshow(np.transpose(            vutils.make_grid(                torch.stack([G_x.cpu().detach(), x.cpu().detach(), y.cpu().detach()]).transpose(1, 0).contiguous().view(                    BATCH_SIZE * 3, 1, 96, 96), nrow=3, padding=0,                normalize=True, scale_each=True), (1, 2, 0)), cmap='gray')        plt.savefig(str(step) + '.png', dpi=300)        plt.show()

ds证据理论python实现_pix2pix模型理论以及Python实现相关推荐

  1. ds证据理论python实现_ALI模型理论以及Python实现

    https://openreview.net/forum?id=B1ElR4cgg 模型结构和明天要发BiGAN模型一模一样,但是两篇论文的作者都是独立完成自己的内容的.而且从写作的风格来看emmm完 ...

  2. python医学图像读取_对python读取CT医学图像的实例详解

    需要安装OpenCV和SimpleItk. SimpleItk比较简单,直接pip install SimpleItk即可. 代码如下: #coding:utf-8 import SimpleITK ...

  3. 小猪的Python学习之旅 —— 19.Python微信自动好友验证,自动回复,发送群聊链接

    小猪的Python学习之旅 -- 19.Python微信自动好友验证,自动回复,发送群聊链接 标签:Python 一句话概括本文: 上一节利用itchat这个库,做了小宇宙早报的监测与转发, 本节新增 ...

  4. 使用python预测基金_使用python先知3 1创建预测

    使用python预测基金 This tutorial was created to democratize data science for business users (i.e., minimiz ...

  5. python学习之路:python连接阿里云ODPS

    python学习之路:python连接阿里云ODPS 前言 本人最近在学习使用ODPS,希望把学习过程记录下来,方便自己查阅. 1.安装ODPS pip install ODPS 2.连接阿里云odp ...

  6. 用python计算个人所得税税率表,Python 小案例 计算个人所得税

    Python 小案例 计算个人所得税 Python 小案例 计算个人所得税 #coding=utf-8 monthMoney=input("请输入月收入:") ds=3500 #扣 ...

  7. Python学习细节总结以及python与c语言区别比较(1)

    本文python学习基于廖雪峰老师的学习网站:字符串和编码 - 廖雪峰的官方网站 (liaoxuefeng.com),其内容相对完整,适合初学者学习.由于楼主之前有c语言的学习经验,在此本文仅对其中与 ...

  8. 面试前赶紧看了5道Python Web面试题,Python面试题No17

    目录 本面试题题库,由公号:非本科程序员 整理发布 第1题: Flask中的请求上下文和应用上下文是什么? 第2题:django中间件的使用? 第3题: django开发中数据做过什么优化? 第4题: ...

  9. python queue 调试_学Python不是盲目的,是有做过功课认真去了解的

    有多少伙伴是因为一句'人生苦短,我用Python'萌生想法学Python的!我跟大家更新过很多Python学习教程普及过多次的Python相关知识,不过大家还是还得计划一下Python学习路线!Pyt ...

最新文章

  1. java 以什么开头_判断字符串以什么开头
  2. 用Python编写博客导出工具
  3. 多行单列CV小技能----Alt加鼠标滚轮
  4. Spring Framework中的作用域代理
  5. 【转】细说.NET中的多线程 (二 线程池)
  6. 漆桂林 | 知识图谱的应用
  7. 2020年Q3最具社交影响力KOL盘点报告
  8. Hadoop大数据平台环境搭建注意事项,分布式数据采集,武汉数道云科技
  9. qt调用外部程序(exe)
  10. 瑞星对Windows7捆绑杀毒软件等消息的回应
  11. [luogu2414 NOI2011]阿狸的打字机 (AC自动机)
  12. php array函数 array_sum 求数组所有值和
  13. ARM开发7.3.2 基础实训( 2 ) 单个按键的输入系统设计( 2)--LPC21XX
  14. 【原创】微信最新表情js代码
  15. F28335的ePWM模块
  16. dp hp oracle 备份软件_HPDP备份软件设置
  17. 【Python】模拟登陆并抓取拉勾网信息(selenium+phantomjs)
  18. 【律联云知产课堂】商标注册需要什么条件?
  19. AI会玩魔方了!全是自学,比任何人都快,包括机器人
  20. Tensorflow使用object_detetcion安装教程

热门文章

  1. 人生需要积极勇敢的去面对
  2. CN DBMove 过程中一些最常见最需要注意的问题
  3. 【转】一个基于Ajax的通用(组合)查询(ASP.NET)
  4. VMware与 Device/Credential Guard 不兼容.
  5. FFmpeg之获取yuv分量(二十二)
  6. Fuchsia编译及运行
  7. 国内的Android SDK镜像
  8. 安卓ashmem(匿名共享内存映射)学习native篇
  9. 人生的四大天规,越早明白,越有福气
  10. Java基础教程【第四章:Java流程控制】