论文传送门:https://arxiv.org/pdf/1609.04802.pdf

SRGAN模型目的:输入低分辨率图像,生成高分辨率图像。

生成网络由三部分构成:

①卷积+PReLU激活函数;

②(卷积+BN+PReLU+卷积+BN,连接残差边)x16+卷积+BN,连接残差边;

③(卷积+像素重组+PReLU)x2+卷积;

①②用于提取图像特征,③用于图像上采样,实现超分。

生成网络的目的:输入低分辨率图像,输出高分辨率图像。

鉴别网络类似VGG结构,由(卷积+BN+LeakyReLU)组成。

鉴别网络目的:输入高分辨图像,判断输入图像是真实图像还是生成图像。

class D_Block(nn.Module):  # 定义判别器中结构块(卷积+标准化+激活函数)def __init__(self, in_channel, out_channle, strid):  # 初始化方法,参数:输入通道数,输出通道数,卷积步长super(D_Block, self).__init__()  # 继承初始化方法self.block = nn.Sequential(  # 结构块nn.Conv2d(in_channel, out_channle, 3, strid, 1),  # convnn.BatchNorm2d(out_channle),  # bnnn.LeakyReLU(0.2)  # leakyrelu)def forward(self, x):  # 前传函数return self.block(x)class Discriminator(nn.Module):  # 定义判别器def __init__(self):  # 初始化方法super(Discriminator, self).__init__()  # 继承初始化方法self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)  # convself.leakyrelu = nn.LeakyReLU(0.2)  # leakyreluself.downsample = nn.Sequential(  # 下采样结构块,与VGG相同D_Block(64, 64, 2),  # 卷积+标准化+激活函数D_Block(64, 128, 1),  # 卷积+标准化+激活函数D_Block(128, 128, 2),  # 卷积+标准化+激活函数D_Block(128, 256, 1),  # 卷积+标准化+激活函数D_Block(256, 256, 2),  # 卷积+标准化+激活函数D_Block(256, 512, 1),  # 卷积+标准化+激活函数D_Block(512, 512, 2)  # 卷积+标准化+激活函数)self.linear = nn.Sequential(  # 线性映射结构块nn.AdaptiveAvgPool2d(1),  # 平均自适应池化nn.Conv2d(512, 1024, 1, 1, 0),  # conv,使用1x1卷积代替全连接nn.LeakyReLU(0.2),  # leakyrelunn.Conv2d(1024, 1, 1, 1, 0),  # conv,使用1x1卷积代替全连接nn.Sigmoid()  # sigmoid)def forward(self, x):  # 前传函数,输入高分辨率图像x = self.leakyrelu(self.conv1(x))  # conv+leakyrelu,(n,3,256,256)-->(n,64,256,256)x = self.downsample(x)  # 下采样,(n,64,256,256)-->(n,64,128,128)-->(n,128,128,128)-->(n,128,64,64)-->(n,256,64,64)-->(n,256,32,32)-->(n,512,32,32)-->(n,512,16,16)x = self.linear(x)  # 线性映射,(n,512,16,16)-->(n,512,1,1)-->(n,1024,1,1)-->(n,1,1,1)x = x.squeeze()  # 删除多余的维度,(n,1,1,1)-->(n)return x  # 返回图片真假的得分class G_Block(nn.Module):  # 定义生成器中结构块(残差结构)def __init__(self, channel):  # 初始化方法,参数:通道数,残差结构前后通道数不变super(G_Block, self).__init__()  # 继承初始化方法self.block = nn.Sequential(  # 结构块nn.Conv2d(channel, channel, 3, 1, 1),  # convnn.BatchNorm2d(channel),  # bnnn.PReLU(channel),  # prelu,带参数的relu激活函数nn.Conv2d(channel, channel, 3, 1, 1),  # convnn.BatchNorm2d(channel)  # bn)def forward(self, x):  # 前传函数return x + self.block(x)  # F(x) + xclass Generator(nn.Module):  # 定义生成器def __init__(self):  # 初始化方法super(Generator, self).__init__()  # 继承初始化方法self.conv1 = nn.Conv2d(3, 64, 9, 1, 4)  # convself.prelu1 = nn.PReLU(64)  # preluself.blocks = []  # 存放残差块的列表for _ in range(16):  # 共16个残差块self.blocks.append(G_Block(64))  # 添加残差块self.blocks = nn.Sequential(*self.blocks)  # 列表转化为模型结构序列self.conv2 = nn.Conv2d(64, 64, 3, 1, 1)  # convself.bn2 = nn.BatchNorm2d(64)  # bnself.upsample = nn.Sequential(  # 上采样块nn.Conv2d(64, 256, 3, 1, 1),  # convnn.PixelShuffle(2),  # pixelshuffle,像素重组,将通道拆分重组至(H,W)nn.PReLU(64),  # prelunn.Conv2d(64, 256, 3, 1, 1),  # convnn.PixelShuffle(2),  # pixelshufflenn.PReLU(64),  # prelunn.Conv2d(64, 3, 9, 1, 4)  # conv)def forward(self, x):  # 前传函数,输入低分辨率图像x = self.prelu1(self.conv1(x))  # conv+prelu,(n,3,64,64)-->(n,64,64,64)x += self.bn2(self.conv2(self.blocks(x)))  # F(x)+x,F(x):16层残差结构+conv+bn,(n,64,64,64)-->(n,64,64,64)x = self.upsample(x)  # 上采样,(n,64,64,64)-->(n,256,64,64)-->(n,64,128,128)-->(n,256,128,128)-->(n,64,256,256)-->(n,3,256,256)return x  # 返回高分辨率图像

SRGAN模型——pytorch实现相关推荐

  1. DIN模型pytorch代码逐行细讲

    DIN模型pytorch代码逐行细讲 文章目录 DIN模型pytorch代码逐行细讲 一.DIN模型的结构 二.代码介绍 三.导入包 四.导入数据 五.数据处理 六.模型定义 七.封装训练集,测试集 ...

  2. python人脸识别训练模型生产_深度学习-人脸识别DFACE模型pytorch训练(二)

    首先介绍一下MTCNN的网络结构,MTCNN有三种网络,训练网络的时候需要通过三部分分别进行,每一层网络都依赖前一层网络产生训练数据供当前训练网络,这样也推动了两个网络之间的最小损耗. Pnet Rn ...

  3. 轴承故障诊断经典模型pytorch复现(一)——WDCNN

    论文地址:<A New Deep Learning Model for Fault Diagnosis with Good Anti-Noise and Domain Adaptation Ab ...

  4. 车牌识别 远距离监控视角 自创简化模型 Pytorch

    甲方一拍脑门,让我去实现车牌识别,还是远距离监控视角的,真开心. 数据?呵~ 不会有人期待甲方提供数据吧?? 先逛逛某宝,一万张车辆图片,0.4元/张. 甲方:阿巴阿巴- 嗯,那没事了. 再逛逛全球同 ...

  5. Seq2Seq模型PyTorch版本

    Seq2Seq模型介绍以及Pytorch版本代码详解 一.Seq2Seq模型的概述 Seq2Seq是一种循环神经网络的变种,是一种端到端的模型,包括 Encoder编码器和 Decoder解码器部分, ...

  6. resnet18到resnet152模型pytorch实现

    resnet在深度学习领域的重要性不言而喻,自从15年resnet提出后,被各种深度学习模型大量引用.得益于其残差结构的设计,使得深度学习模型可以训练更深层的网络.常见的resnet有resnet18 ...

  7. python调用yolov3模型,pytorch版yolov3训练自己的数据(数据,代码,预训练模型下载链接)...

    1.数据预处理 准备图片数据(JPEGImages),标注文件(Annotations),以及划分好测试集训练集的索引号(ImageSets) 修改代码中voc_label.py文件中的路径以及类别, ...

  8. WGAN模型——pytorch实现

    论文传送门:https://arxiv.org/pdf/1701.07875.pdf 参考文章:令人拍案叫绝的Wasserstein GAN - 知乎​​​​​​ WGAN的目的:解决GAN的梯度不稳 ...

  9. WGAN-gp模型——pytorch实现

    论文传送门:https://arxiv.org/pdf/1704.00028.pdf WGAN存在的问题:在WGAN中,为使得判别器D(x)满足Lipschitz连续条件,从而对网络参数进行了[-c, ...

最新文章

  1. zabbix3.0安装
  2. 博士申请 | 美国罗格斯大学王灏助理教授招收机器学习方向博士生
  3. ssms 缺少索引信息_MySQL3:索引
  4. 一个memset函数使用时的坑
  5. SAP License:系统退货处理流程
  6. php7 viewmodel,【初念科技】| php框架实例: Laravel之Model Observer模型
  7. T-SQL Enhancement in SQL Server 2005[下篇]
  8. java提取(获取)博客信息(内容)
  9. 购物商城Web开发第二十二天
  10. PSP3000破解原理——缓冲区溢出漏洞随谈
  11. 鸿蒙音响怎么调,好音质是调出来的 大师教你咋调试音响
  12. 用Python输出100以内的质数
  13. WebStorm+Chrome插件JetBrains IDE Support进行实时调试
  14. 5.5 Go语言项目实战:多人聊天室
  15. Tomcat启动,提示 The JRE_HOME environment variable is not defined correctly 问题。
  16. AIOT:基于智能家居谈AIOT
  17. The projiect you are opening contains compilation errors
  18. RTE2020首批嘉宾公布 技术+行业专家携手顶级投资人参会
  19. 解决matplotlib绘制图片时plt.savefig()后图片全黑的问题
  20. 面向对象的3个基本对象

热门文章

  1. java四则运算考试系统_小学生四则运算出题软件-基于java控制台的实现
  2. Scratch软件编程等级考试四级——20200913
  3. 腾讯人力资源体系全曝光 附下载
  4. ADSP-21489的图形化编程详解(3:音效开发例程-直通三个例程讲清楚)
  5. 彻底解决浏览器被ie.75011.net搜狗导航劫持的问题
  6. 2022CTF培训(一)脱壳技术Hook入门
  7. fm25cl64 linux,FM25CL64铁电存储器的问题
  8. java file 工具_JAVA文件类工具
  9. Ant Design of Vue 中 日期时间控件 禁止选中的(日期——)设置
  10. MacOS 查看硬盘分区参数