好像还挺好玩的GAN重制版4——Pytorch搭建SRGAN平台进行图片超分辨率提升

  • 学习前言
  • 源码下载地址
  • 网络构建
    • 一、什么是SRGAN
    • 二、生成网络的构建
    • 三、判别网络的构建
  • 训练思路
    • 一、判别器的训练
    • 二、生成器的训练
  • 利用SRGAN生成图片
    • 一、数据集的准备
    • 二、数据集的处理
    • 三、模型训练

学习前言

我又死了我又死了我又死了!

源码下载地址

https://github.com/bubbliiiing/srgan-pytorch

喜欢的可以点个star噢。

网络构建

一、什么是SRGAN

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。

该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节

SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感

二、生成网络的构建


生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。

SRGAN的生成网络由三个部分组成。
1、低分辨率图像进入后会经过一个卷积+RELU函数
2、然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。
3、然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升

前两个部分用于特征提取,第三部分用于提高分辨率。

import math
import torch
from torch import nnclass ResidualBlock(nn.Module):def __init__(self, channels):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(channels)self.prelu = nn.PReLU(channels)self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(channels)def forward(self, x):short_cut = xx = self.conv1(x)x = self.bn1(x)x = self.prelu(x)x = self.conv2(x)x = self.bn2(x)return x + short_cutclass UpsampleBLock(nn.Module):def __init__(self, in_channels, up_scale):super(UpsampleBLock, self).__init__()self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)self.pixel_shuffle = nn.PixelShuffle(up_scale)self.prelu = nn.PReLU(in_channels)def forward(self, x):x = self.conv(x)x = self.pixel_shuffle(x)x = self.prelu(x)return xclass Generator(nn.Module):def __init__(self, scale_factor, num_residual=16):upsample_block_num = int(math.log(scale_factor, 2))super(Generator, self).__init__()self.block_in = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4),nn.PReLU(64))self.blocks = []for _ in range(num_residual):self.blocks.append(ResidualBlock(64))self.blocks = nn.Sequential(*self.blocks)self.block_out = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64))self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))self.upsample = nn.Sequential(*self.upsample)def forward(self, x):x = self.block_in(x)short_cut = xx = self.blocks(x)x = self.block_out(x)upsample = self.upsample(x + short_cut)return torch.tanh(upsample)

三、判别网络的构建


判别网络的构成如上图所示:

SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。
对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果

判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

实现代码如下:

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(64),nn.LeakyReLU(0.2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),nn.AdaptiveAvgPool2d(1),nn.Conv2d(512, 1024, kernel_size=1),nn.LeakyReLU(0.2),nn.Conv2d(1024, 1, kernel_size=1))def forward(self, x):batch_size = x.size(0)return torch.sigmoid(self.net(x).view(batch_size))

训练思路

SRGAN的训练可以分为生成器训练和判别器训练:
每一个step中一般先训练判别器,然后训练生成器。

一、判别器的训练

训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签

因此判别器的训练步骤如下:

1、随机选取batch_size个真实高分辨率图片。
2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。

二、生成器的训练

训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。

因此生成器的训练步骤如下:

1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。
2、将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss

利用SRGAN生成图片

SRGAN的库整体结构如下:

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。

训练过程中,可在results文件夹内查看训练效果:

好像还挺好玩的GAN重制版4——Pytorch搭建SRGAN平台进行图片超分辨率提升相关推荐

  1. 好像还挺好玩的GAN重制版2——Keras搭建SRGAN平台进行图片超分辨率提升

    好像还挺好玩的GAN重制版2--Keras搭建SRGAN平台进行图片超分辨率提升 学习前言 源码下载地址 网络构建 一.什么是SRGAN 二.生成网络的构建 三.判别网络的构建 训练思路 一.判别器的 ...

  2. 好像还挺好玩的GAN8——SRGAN实现图像的分辨率提升

    好像还挺好玩的GAN8--SRGAN实现图像的分辨率提升 注意事项 学习前言 什么是SRGAN 代码与训练数据的下载 神经网络组成 1.生成网络 2.判别网络 训练思路 1.对判别模型进行训练 2.对 ...

  3. 好像还挺好玩的GAN2——Keras搭建DCGAN利用深度卷积神经网络实现图片生成

    好像还挺好玩的GAN2--Keras搭建DCGAN利用深度卷积神经网络实现图片生成 注意事项 学习前言 什么是DCGAN 神经网络构建 1.Generator 2.Discriminator 训练思路 ...

  4. 憨批的语义分割重制版9——Pytorch 搭建自己的DeeplabV3+语义分割平台

    憨批的语义分割重制版9--Pytorch 搭建自己的DeeplabV3+语义分割平台 注意事项 学习前言 什么是DeeplabV3+模型 代码下载 DeeplabV3+实现思路 一.预测部分 1.主干 ...

  5. 憨批的语义分割重制版6——Pytorch 搭建自己的Unet语义分割平台

    憨批的语义分割重制版6--Pytorch 搭建自己的Unet语义分割平台 注意事项 学习前言 什么是Unet模型 代码下载 Unet实现思路 一.预测部分 1.主干网络介绍 2.加强特征提取结构 3. ...

  6. 语义分割重制版1——Pytorch 搭建自己的Unet语义分割平台

    转载:https://blog.csdn.net/weixin_44791964/article/details/108866828?spm=1001.2014.3001.5501 对应b站视频:ht ...

  7. python tkinter火柴人_用Python实现童年小游戏俄罗斯方块!别说还挺好玩!

    原标题:用Python实现童年小游戏俄罗斯方块!别说还挺好玩! 前言 大三上学期的程序设计实训大作业,挑了其中一个我认为最简单的的<图书管理系统>来写.用python写是因为py有自带的G ...

  8. 简陋无比的 Python 抠图方案,好像还挺像回事儿?

    Python编程学习点击免费领取 背景介绍 从某APP中截取了我的背单词曲线之后,我敏锐地发现了蕴藏在其中的数学规律. 每六个月达到一次峰值,峰值的高度不断减小.为了在图上画一条线来拟合这个折线,我打 ...

  9. 憨批的语义分割重制版11——Keras 搭建自己的HRNetV2语义分割平台

    憨批的语义分割重制版11--Keras 搭建自己的HRNetV2语义分割平台 学习前言 什么是HRNetV2模型 代码下载 HRNetV2实现思路 一.预测部分 1.主干网络介绍 a.Section- ...

最新文章

  1. C语言接收一个整数划分成5的倍数,整数划分为连续整数;整数划分
  2. 好程序员大数据笔记之:Hadoop集群搭建
  3. java scriptrunner_ScriptRunner.java
  4. .zip.001 -- .zip.003解压缩
  5. linux中pak命令,如何在Linux系统中安装Flatpak
  6. 乔布斯18岁求职信拍卖价22.24万美元,值吗?
  7. python 菜鸟-Python3 集合
  8. cdh 安装_使用Cloudera的CDH部署Hadoop:第二步,安装JDK
  9. 简单构建一个xmlhttp对象池合理创建和使用xmlhttp对象
  10. 银行行号和银行代码是一样的吗?区别是什么?
  11. toc如何判断 word_在WORD为什么点插入目录显示{TOC\o1
  12. axis2 异常OMElement
  13. Chrome安装程序遇到错误 0xe0000008解决办法
  14. 《 ERP高级计划》书的解读-APS算法分析之七分解技术(DT)(蔡颖)(转)
  15. 无法给变量添加属性导致出问题
  16. 从程序员到项目经理(21):谁都需要成就感
  17. 格力(GREE)家用移动空调免安装一体机空调KY-23NK 清灰拆装教程
  18. 柯尼卡/KonicaFTP扫描设置
  19. 电子商务思维导图精品荟萃:电子商务思维导图大全
  20. After Effects错误: CT generic: not ascii (83::2)(非原创)

热门文章

  1. 条件概率、全概率、先验概率、后验概率、类条件概率
  2. 怎样制作网关服务器,如何设计自己的网关(一)
  3. ubuntu 14.04开机出现错误“Error found when loading /root/.profile”解决(root用户登录时才会出现)
  4. Vue + Element UI 表格分页记忆选中
  5. 计算机网络:移动IP
  6. 国内各个界面库比较,告诉你怎么选择界面库?
  7. Mysql设置初始化密码和修改密码
  8. Arduino UNO控制带AB相磁通量式编码器电动推杆(测试阻尼)实录(L289N电机驱动)
  9. 绘制线性回归和多元线性回归
  10. 浅谈Spring定时任务