pytorch复现经典生成对抗式的超分辨率网络
论文原文:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
论文的中文翻译:翻译:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
网络结构如下图所示:
上面和下面分别是生成网络和判别网络:
废话不多说,直接看代码。比较不喜欢一堆废话的博客,懂得都懂,最有用的就是代码!
代码的实现参考pytorch torchvision中的网络实现优点:模块化、简洁易读、而且容易修改网络宽度和深度(方便调整网络架构做对比实验,消融实验)。
实现生成网络:
# -*- coding: utf-8 -*-
# @Use :
# @Time : 2022/8/17 21:32
# @FileName: models.py
# @Software: PyCharm
# @inference:https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.pyimport torch
from torch import nn
import torchvision
from torch import Tensorclass GeneratorBasicBlock(nn.Module):"""生成器重复的部分"""def __init__(self, channel, kernel_size) -> None:super(GeneratorBasicBlock, self).__init__()self.channel = channelself.conv1 = nn.Conv2d(in_channels=channel, out_channels=channel,kernel_size=(kernel_size, kernel_size),stride=(1, 1), padding=(1, 1))self.bn1 = nn.BatchNorm2d(num_features=channel)self.p_relu1 = nn.PReLU()self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel,kernel_size=(kernel_size, kernel_size),stride=(1, 1), padding=(1, 1))self.bn2 = nn.BatchNorm2d(num_features=channel)def forward(self, x: Tensor) -> Tensor:"""前向推断:param x::return:"""identity = xout = self.conv1(x)out = self.bn1(out)out = self.p_relu1(out)out = self.conv2(out)out = self.bn2(out)out += identityreturn outclass PixelShufflerBlock(nn.Module):"""生成器最后的pixelshuffler"""def __init__(self, in_channel, out_channel) -> None:super(PixelShufflerBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.pixels_shuffle = nn.PixelShuffle(upscale_factor=2)self.prelu = nn.PReLU()def forward(self, x: Tensor) -> Tensor:"""前向"""out = self.conv1(x)out = self.pixels_shuffle(out)out = self.prelu(out)return outclass Generator(nn.Module):"""生成器"""def __init__(self, config) -> None:# Generator parameterssuper(Generator, self).__init__()large_kernel_size = config.G.large_kernel_size # = 9small_kernel_size = config.G.small_kernel_size # = 3n_channels = config.G.n_channels # = 64n_blocks = config.G.n_blocks # = 16base_block_type = config.G.base_block_type # 'depthwise_conv_residual' # 'conv_residual' or 'depthwise_conv_residual'# base blockif base_block_type == 'depthwise_conv_residual':self.repeat_block = GeneratorDepthwiseBlockif base_block_type == 'conv_residual':self.repeat_block = GeneratorBasicBlockself.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,kernel_size=(large_kernel_size, large_kernel_size),stride=(1, 1), padding=(4, 4))self.prelu1 = nn.PReLU()self.B_residul_block = self._make_layer(self.repeat_block, n_channels,n_blocks, small_kernel_size)self.conv2 = nn.Conv2d(in_channels=n_channels, out_channels=n_channels,kernel_size=(small_kernel_size, small_kernel_size),stride=(1, 1), padding=(1, 1))self.bn1 = nn.BatchNorm2d(n_channels)self.pixel_shuffle_block1 = PixelShufflerBlock(n_channels, 4 * n_channels)self.pixel_shuffle_block2 = PixelShufflerBlock(n_channels, 4 * n_channels)self.conv3 = nn.Conv2d(in_channels=n_channels, out_channels=3,kernel_size=(large_kernel_size, large_kernel_size),stride=(1, 1), padding=(4, 4))def _make_layer(self, base_block, n_channels, n_block, kernel_size) -> nn.Sequential:"""构建重复的B个基本块:param base_block: 基本块:param n_channels: 块里面的通道数:param n_block: 块数:return:"""layers = []self.base_block = base_blockfor _ in range(n_block):layers.append(self.base_block(n_channels, kernel_size))return nn.Sequential(*layers)def _forward_impl(self, x: Tensor) -> Tensor:"""前向的实现"""out = self.conv1(x)out = self.prelu1(out)identity = outout = self.B_residul_block(out)out = self.conv2(out)out = self.bn1(out)out += identityout = self.pixel_shuffle_block1(out)out = self.pixel_shuffle_block2(out)out = self.conv3(out)return outdef forward(self, x: Tensor) -> Tensor:"""前向"""return self._forward_impl(x)
判别网络实现:
class DiscriminatorBaseblock(nn.Module):"""判别器的基本块"""def __init__(self, inchannel, out_chanel, kernel_size, stride):super(DiscriminatorBaseblock, self).__init__()self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=out_chanel,kernel_size=kernel_size, stride=stride, padding=(1, 1))self.bn1 = nn.BatchNorm2d(out_chanel)self.act1 = nn.LeakyReLU(0.2)def forward(self, x: Tensor) -> Tensor:"""前向"""out = self.conv1(x)out = self.bn1(out)out = self.act1(out)return outclass Discriminator(nn.Module):"""判别器"""
def __init__(self, config):super(Discriminator, self).__init__()# Discriminator parameterskernel_size = config.D.kernel_size = 3n_channels = config.D.n_channels = 64n_blocks = config.D.n_blocks = 8fc_size = config.D.fc_size = 1024self.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,kernel_size=(kernel_size, kernel_size), stride=(1, 1), padding=(1, 1))self.leaky_relu1 = nn.LeakyReLU(0.2)conv_configs = [[3, 64, 2], # 配置二维数组[3, 128, 1],[3, 128, 2],[3, 256, 1],[3, 256, 2],[3, 512, 1],[3, 512, 2]]self.base_blocks = self._make_layer(n_channels, DiscriminatorBaseblock, conv_configs)self.dense1 = nn.Linear(512 * 6 * 6, 1024)self.leaky_relu2 = nn.LeakyReLU(0.2)self.dense2 = nn.Linear(1024, 1)self.sigmod1 = nn.Sigmoid()def _make_layer(self, in_channel, base_block, conv_configs: list) -> nn.Sequential:""":param base_block: DiscriminatorBaseblock:param conv_configs: (kernel, channel, stride):return:"""layers = []self.base_block = base_blockself.in_channel = in_channelfor i in range(len(conv_configs)):# in_channel, out_chanel, kernel_size, stridekernel_size = (conv_configs[i][0], conv_configs[i][0])stride = (conv_configs[i][2], conv_configs[i][2])out_channel = conv_configs[i][1]layers.append(self.base_block(self.in_channel, out_channel, kernel_size, stride))self.in_channel = out_channelreturn nn.Sequential(*layers)def _forward_impl(self, x: Tensor) -> Tensor:"""前向"""out = self.conv1(x)out = self.leaky_relu1(out)out = self.base_blocks(out)out = torch.flatten(out, 1)out = self.dense1(out)out = self.leaky_relu2(out)out = self.dense2(out)out = self.sigmod1(out)return outdef forward(self, x: Tensor) -> Tensor:"""前向"""return self._forward_impl(x)
pytorch复现经典生成对抗式的超分辨率网络相关推荐
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下
文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上
文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...
- 分离潜变量自动编码器超分辨率网络 SLAESR
分离潜变量自动编码器超分辨率网络 Separating latent automatic encoder super-resolution network. 简称:SLAESR 关于训练部分,请看仓库 ...
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别
文章目录 1 生成对抗网络基本概念 2 生成对抗网络建模 2.1 建立MnistDataset类 2.2 建立鉴别器 2.3 测试鉴别器 2.4 Mnist生成器制作 3 模型的训练 4 模型表现的判 ...
- 经典论文复现 | 基于深度学习的图像超分辨率重建
过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含"伪代码".这是今年 AAAI ...
- 生成对抗式网络 GAN及其衍生CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理介绍、应用介绍及简单Tensorflow实现
生成式对抗网络(GAN,Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.学界大牛Yann Lecun 曾说,令他最激 ...
- 【深度学习】GAN生成对抗式网络原理
生成模型和判别模型 理解对抗网络,首先要了解生成模型和判别模型.判别模型比较好理解,就像分类一样,有一个判别界限,通过这个判别界限去区分样本.从概率角度分析就是获得样本x属于类别y的概率,是一个条件概 ...
- 对抗性神经网络百度百科,生成对抗式神经网络
深度学习什么是对抗式神经网络? 对抗式神经网络GAN让机器学会"左右互搏"GAN网络的原理本质上就是这两篇小说中主人公练功的人工智能或机器学习版本. 一个网络中有两个角色,修炼的过 ...
- 【PyTorch】12 生成对抗网络实战——用GAN生成动漫头像
GAN 生成动漫头像 1. 获取数据 2. 用GAN生成 2.1 Generator 2.2 Discriminator 2.3 其它细节 2.4 训练思路 3. 全部代码 4. 结果展示与分析 小结 ...
最新文章
- 敏捷软件开发实践——估算与计划(01)
- 词性标注,实体识别,ICTCLAS分析系统的学习
- OK,让我们开始吧!
- mysql dba系统学习(3)mysql的启动停止
- ImportError:cannot import name ‘WordCloud’的解决办法
- System.Text.Json 中的字符编码
- 架构必备「RESTful API」设计技巧经验总结
- linux日志服务器配置在哪个文件,Linux中日志的基本配置(syslog)
- JS 停留几秒后返回上一页
- C站最全Python库总结丨标准库+高级库
- 104. 二叉树的最大深度【LeetCode】
- 我的Java之路(7)
- 通向码农的道路(enet开源翻译计划 二)
- Label高度根据内容变化SnapKi
- 193.有效电话号码
- 毫米波雷达探测技术,雷达人体存在感应器,实时检测静止存在应用
- JAVA中输出姓王的姓名,没出过国的人,不配姓王
- android简繁体切换快捷键,我的Android进阶之旅------Android中如何高效率的进行简繁体转换...
- Excel表格中无法中间插入新行列! 提示:在当前工作表的最后一行或列中,存在非空单元格,解决方案
- 超千万人同时在线,抖音快手,是怎么抗住高并发?
热门文章
- 【API调用】人脸检测+人脸属性(旷视 / 百度)
- 周训练计划之(韦德分化训练法:胸、肩、背、腿、腹)
- asterisk连接sip139网络电话
- 经济专业需要学c语言吗,学c语言要什么基础?
- jquery判断元素内容是否为空的方法
- 桂林理工大学 程序设计实践课程 实习报告
- 通信行业英文缩写整理(待更新)
- 小程序商城和社区团购小程序,商家应该选哪个?
- SPDA-CNN:Unifying Semantic Part Detection and Abstraction for Fine-grained Recognition
- [HTML] HTML简单实现网络测速