SRGAN模型——pytorch实现
论文传送门:https://arxiv.org/pdf/1609.04802.pdf
SRGAN模型目的:输入低分辨率图像,生成高分辨率图像。
生成网络由三部分构成:
①卷积+PReLU激活函数;
②(卷积+BN+PReLU+卷积+BN,连接残差边)x16+卷积+BN,连接残差边;
③(卷积+像素重组+PReLU)x2+卷积;
①②用于提取图像特征,③用于图像上采样,实现超分。
生成网络的目的:输入低分辨率图像,输出高分辨率图像。
鉴别网络类似VGG结构,由(卷积+BN+LeakyReLU)组成。
鉴别网络目的:输入高分辨图像,判断输入图像是真实图像还是生成图像。
class D_Block(nn.Module): # 定义判别器中结构块(卷积+标准化+激活函数)def __init__(self, in_channel, out_channle, strid): # 初始化方法,参数:输入通道数,输出通道数,卷积步长super(D_Block, self).__init__() # 继承初始化方法self.block = nn.Sequential( # 结构块nn.Conv2d(in_channel, out_channle, 3, strid, 1), # convnn.BatchNorm2d(out_channle), # bnnn.LeakyReLU(0.2) # leakyrelu)def forward(self, x): # 前传函数return self.block(x)class Discriminator(nn.Module): # 定义判别器def __init__(self): # 初始化方法super(Discriminator, self).__init__() # 继承初始化方法self.conv1 = nn.Conv2d(3, 64, 3, 1, 1) # convself.leakyrelu = nn.LeakyReLU(0.2) # leakyreluself.downsample = nn.Sequential( # 下采样结构块,与VGG相同D_Block(64, 64, 2), # 卷积+标准化+激活函数D_Block(64, 128, 1), # 卷积+标准化+激活函数D_Block(128, 128, 2), # 卷积+标准化+激活函数D_Block(128, 256, 1), # 卷积+标准化+激活函数D_Block(256, 256, 2), # 卷积+标准化+激活函数D_Block(256, 512, 1), # 卷积+标准化+激活函数D_Block(512, 512, 2) # 卷积+标准化+激活函数)self.linear = nn.Sequential( # 线性映射结构块nn.AdaptiveAvgPool2d(1), # 平均自适应池化nn.Conv2d(512, 1024, 1, 1, 0), # conv,使用1x1卷积代替全连接nn.LeakyReLU(0.2), # leakyrelunn.Conv2d(1024, 1, 1, 1, 0), # conv,使用1x1卷积代替全连接nn.Sigmoid() # sigmoid)def forward(self, x): # 前传函数,输入高分辨率图像x = self.leakyrelu(self.conv1(x)) # conv+leakyrelu,(n,3,256,256)-->(n,64,256,256)x = self.downsample(x) # 下采样,(n,64,256,256)-->(n,64,128,128)-->(n,128,128,128)-->(n,128,64,64)-->(n,256,64,64)-->(n,256,32,32)-->(n,512,32,32)-->(n,512,16,16)x = self.linear(x) # 线性映射,(n,512,16,16)-->(n,512,1,1)-->(n,1024,1,1)-->(n,1,1,1)x = x.squeeze() # 删除多余的维度,(n,1,1,1)-->(n)return x # 返回图片真假的得分class G_Block(nn.Module): # 定义生成器中结构块(残差结构)def __init__(self, channel): # 初始化方法,参数:通道数,残差结构前后通道数不变super(G_Block, self).__init__() # 继承初始化方法self.block = nn.Sequential( # 结构块nn.Conv2d(channel, channel, 3, 1, 1), # convnn.BatchNorm2d(channel), # bnnn.PReLU(channel), # prelu,带参数的relu激活函数nn.Conv2d(channel, channel, 3, 1, 1), # convnn.BatchNorm2d(channel) # bn)def forward(self, x): # 前传函数return x + self.block(x) # F(x) + xclass Generator(nn.Module): # 定义生成器def __init__(self): # 初始化方法super(Generator, self).__init__() # 继承初始化方法self.conv1 = nn.Conv2d(3, 64, 9, 1, 4) # convself.prelu1 = nn.PReLU(64) # preluself.blocks = [] # 存放残差块的列表for _ in range(16): # 共16个残差块self.blocks.append(G_Block(64)) # 添加残差块self.blocks = nn.Sequential(*self.blocks) # 列表转化为模型结构序列self.conv2 = nn.Conv2d(64, 64, 3, 1, 1) # convself.bn2 = nn.BatchNorm2d(64) # bnself.upsample = nn.Sequential( # 上采样块nn.Conv2d(64, 256, 3, 1, 1), # convnn.PixelShuffle(2), # pixelshuffle,像素重组,将通道拆分重组至(H,W)nn.PReLU(64), # prelunn.Conv2d(64, 256, 3, 1, 1), # convnn.PixelShuffle(2), # pixelshufflenn.PReLU(64), # prelunn.Conv2d(64, 3, 9, 1, 4) # conv)def forward(self, x): # 前传函数,输入低分辨率图像x = self.prelu1(self.conv1(x)) # conv+prelu,(n,3,64,64)-->(n,64,64,64)x += self.bn2(self.conv2(self.blocks(x))) # F(x)+x,F(x):16层残差结构+conv+bn,(n,64,64,64)-->(n,64,64,64)x = self.upsample(x) # 上采样,(n,64,64,64)-->(n,256,64,64)-->(n,64,128,128)-->(n,256,128,128)-->(n,64,256,256)-->(n,3,256,256)return x # 返回高分辨率图像
SRGAN模型——pytorch实现相关推荐
- DIN模型pytorch代码逐行细讲
DIN模型pytorch代码逐行细讲 文章目录 DIN模型pytorch代码逐行细讲 一.DIN模型的结构 二.代码介绍 三.导入包 四.导入数据 五.数据处理 六.模型定义 七.封装训练集,测试集 ...
- python人脸识别训练模型生产_深度学习-人脸识别DFACE模型pytorch训练(二)
首先介绍一下MTCNN的网络结构,MTCNN有三种网络,训练网络的时候需要通过三部分分别进行,每一层网络都依赖前一层网络产生训练数据供当前训练网络,这样也推动了两个网络之间的最小损耗. Pnet Rn ...
- 轴承故障诊断经典模型pytorch复现(一)——WDCNN
论文地址:<A New Deep Learning Model for Fault Diagnosis with Good Anti-Noise and Domain Adaptation Ab ...
- 车牌识别 远距离监控视角 自创简化模型 Pytorch
甲方一拍脑门,让我去实现车牌识别,还是远距离监控视角的,真开心. 数据?呵~ 不会有人期待甲方提供数据吧?? 先逛逛某宝,一万张车辆图片,0.4元/张. 甲方:阿巴阿巴- 嗯,那没事了. 再逛逛全球同 ...
- Seq2Seq模型PyTorch版本
Seq2Seq模型介绍以及Pytorch版本代码详解 一.Seq2Seq模型的概述 Seq2Seq是一种循环神经网络的变种,是一种端到端的模型,包括 Encoder编码器和 Decoder解码器部分, ...
- resnet18到resnet152模型pytorch实现
resnet在深度学习领域的重要性不言而喻,自从15年resnet提出后,被各种深度学习模型大量引用.得益于其残差结构的设计,使得深度学习模型可以训练更深层的网络.常见的resnet有resnet18 ...
- python调用yolov3模型,pytorch版yolov3训练自己的数据(数据,代码,预训练模型下载链接)...
1.数据预处理 准备图片数据(JPEGImages),标注文件(Annotations),以及划分好测试集训练集的索引号(ImageSets) 修改代码中voc_label.py文件中的路径以及类别, ...
- WGAN模型——pytorch实现
论文传送门:https://arxiv.org/pdf/1701.07875.pdf 参考文章:令人拍案叫绝的Wasserstein GAN - 知乎 WGAN的目的:解决GAN的梯度不稳 ...
- WGAN-gp模型——pytorch实现
论文传送门:https://arxiv.org/pdf/1704.00028.pdf WGAN存在的问题:在WGAN中,为使得判别器D(x)满足Lipschitz连续条件,从而对网络参数进行了[-c, ...
最新文章
- zabbix3.0安装
- 博士申请 | 美国罗格斯大学王灏助理教授招收机器学习方向博士生
- ssms 缺少索引信息_MySQL3:索引
- 一个memset函数使用时的坑
- SAP License:系统退货处理流程
- php7 viewmodel,【初念科技】| php框架实例: Laravel之Model Observer模型
- T-SQL Enhancement in SQL Server 2005[下篇]
- java提取(获取)博客信息(内容)
- 购物商城Web开发第二十二天
- PSP3000破解原理——缓冲区溢出漏洞随谈
- 鸿蒙音响怎么调,好音质是调出来的 大师教你咋调试音响
- 用Python输出100以内的质数
- WebStorm+Chrome插件JetBrains IDE Support进行实时调试
- 5.5 Go语言项目实战:多人聊天室
- Tomcat启动,提示 The JRE_HOME environment variable is not defined correctly 问题。
- AIOT:基于智能家居谈AIOT
- The projiect you are opening contains compilation errors
- RTE2020首批嘉宾公布 技术+行业专家携手顶级投资人参会
- 解决matplotlib绘制图片时plt.savefig()后图片全黑的问题
- 面向对象的3个基本对象
热门文章
- java四则运算考试系统_小学生四则运算出题软件-基于java控制台的实现
- Scratch软件编程等级考试四级——20200913
- 腾讯人力资源体系全曝光 附下载
- ADSP-21489的图形化编程详解(3:音效开发例程-直通三个例程讲清楚)
- 彻底解决浏览器被ie.75011.net搜狗导航劫持的问题
- 2022CTF培训(一)脱壳技术Hook入门
- fm25cl64 linux,FM25CL64铁电存储器的问题
- java file 工具_JAVA文件类工具
- Ant Design of Vue 中 日期时间控件 禁止选中的(日期——)设置
- MacOS 查看硬盘分区参数