DenseNet稠密连接网络(pyTorch源码)
原文:https://blog.csdn.net/qq_43360533/article/details/107448369
目录
1 Introduction
3 过渡层
4 DenseNet模型
5 获取数据并训练
1 Introduction
与ResNet的主要区别在于,DenseNet里模块B的输出不是像ResNet那样和模块A的输出相加,而是在通道维上连结。这样模块A的输出可以直接传入模块B后面的层。在这个设计里,模块A直接跟模块B后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。
如果用公式表示的话,传统的网络在 [公式] 层的输出为:
而对于ResNet,增加了来自上一层输入的identity函数:
在DenseNet中,会连接前面所有层作为输入:
DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。
在DenseBlock中,各个层的特征图大小一致,是如下图的结构,可以在channel维度上连接。DenseBlock中的非线性组合函数H(·)采用的是BN+ReLU+3x3 Conv的结构,如下图所示。另外值得注意的一点是,与ResNet不同,所有DenseBlock中各个层卷积之后均输出k个特征图,即得到的特征图的channel数为 k,或者说采用 k 个卷积核。k 在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的k (比如12),就可以得到较佳的性能。假定输入层的特征图的channel数为 k0,那么l 层输入的channel数为 k0+k(l-1) ,因此随着层数增加,尽管 k设定得较小,DenseBlock的输入会非常多,不过这是由于特征重用所造成的,每个层仅有 k个特征是自己独有的。
2 稠密块
DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构,我们首先在conv_block
函数里实现这个结构。
import time
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def conv_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))return blk
# 稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,将每块的输入和
# 输出在通道维上连结。
class DenseBlock(nn.Module):def __init__(self, num_convs, in_channels, out_channels):super(DenseBlock, self).__init__()net = []for i in range(num_convs):in_c = in_channels + i * out_channelsnet.append(conv_block(in_c, out_channels))self.net = nn.ModuleList(net)self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数def forward(self, X):for blk in self.net:Y = blk(X)X = torch.cat((X, Y), dim=1) # 在通道维上将输入和输出连结return X
blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])
3 过渡层
# 由于每个稠密块连结都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用
# 来控制模型复杂度。它通过1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和
# 宽,从而进一步降低模型复杂度。
def transition_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=1),nn.AvgPool2d(kernel_size=2, stride=2))return blk
对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。
blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])
4 DenseNet模型
DenseNet首先使用同ResNet一样的单卷积层和最大池化层。
类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,可以设置每个稠密块使用多少个卷积层,这里设成4,与ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。
最后接上全局池化层和全连接层来输出。
net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))num_channels, growth_rate = 64, 32 # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]for i, num_convs in enumerate(num_convs_in_dense_blocks):DB = DenseBlock(num_convs, num_channels, growth_rate)net.add_module("DenseBlosk_%d" % i, DB)# 上一个稠密块的输出通道数num_channels = DB.out_channels# 在稠密块之间加入通道数减半的过渡层if i != len(num_convs_in_dense_blocks) - 1:net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))num_channels = num_channels // 2# 同ResNet一样,最后接上全局池化层和全连接层来输出。
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10)))
我们尝试打印每个子模块的输出维度确保网络无误:
X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)
输出:
0 output shape: torch.Size([1, 64, 48, 48])
1 output shape: torch.Size([1, 64, 48, 48])
2 output shape: torch.Size([1, 64, 48, 48])
3 output shape: torch.Size([1, 64, 24, 24])
DenseBlosk_0 output shape: torch.Size([1, 192, 24, 24])
transition_block_0 output shape: torch.Size([1, 96, 12, 12])
DenseBlosk_1 output shape: torch.Size([1, 224, 12, 12])
transition_block_1 output shape: torch.Size([1, 112, 6, 6])
DenseBlosk_2 output shape: torch.Size([1, 240, 6, 6])
transition_block_2 output shape: torch.Size([1, 120, 3, 3])
DenseBlosk_3 output shape: torch.Size([1, 248, 3, 3])
BN output shape: torch.Size([1, 248, 3, 3])
relu output shape: torch.Size([1, 248, 3, 3])
global_avg_pool output shape: torch.Size([1, 248, 1, 1])
fc output shape: torch.Size([1, 10])
5 获取数据并训练
由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。
batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
参考原文:《动手学深度学习(pyTorch)》
为了深入性能方面的可以继续学习:
DenseNet稠密连接网络(pyTorch源码)相关推荐
- MXNet对DenseNet(稠密连接网络)的实现
论文地址:Densely Connected Convolutional Networks DenseNet其实跟前面的ResNet是很相似的,我们知道ResNet的梯度可以直接通过身份函数(激活函数 ...
- (pytorch-深度学习)实现稠密连接网络(DenseNet)
稠密连接网络(DenseNet) ResNet中的跨层连接设计引申出了数个后续工作.稠密连接网络(DenseNet)与ResNet的主要区别在于在跨层连接上的主要区别: ResNet使用相加 Dens ...
- 07.7. 稠密连接网络(DenseNet)
文章目录 7.7. 稠密连接网络(DenseNet) 7.7.1. 从ResNet到DenseNet 7.7.2. 稠密块体 7.7.3. 过渡层 7.7.4. DenseNet模型 7.7.5. 训 ...
- pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)
写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...
- ELMo解读(论文 + PyTorch源码)
ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...
- Transformer-XL解读(论文 + PyTorch源码)
前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...
- pytorch源码解析2——数据处理torch.utils.data
迭代器 理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键. 在 Dataset, Sampler 和 DataLoader 这三个类中都会用到 py ...
- autojs网络验证,权朗网络验证源码
auto.js一个简单的网络验证源码 threads.start(function(){toastLog("开始查询...") //验证地址 r = http.get(" ...
- 基于Pytorch源码对SGD、momentum、Nesterov学习
目前神经网络的监督学习过程通常为: 数据加载(load)进神经网络 经过网络参数对数据的计算,得出预测值(predict) 根据预测值与标注值(label)之间的差距,产生损失(loss) 通过反向传 ...
最新文章
- 零起点学算法01——第一个程序Hello World!
- C语言编程语言科技 c语言中的= 和= =有什么区别?(精华篇)
- linux异常 - 无法分配内存
- 基于 Bochs 的操作系统内核实现
- PHP预定义常量DIRECTORY_SEPARATOR
- JS事件冒泡与事件捕获
- 如何在线把网站html生成xml文件_快速抓取网站信息工具
- 限制input 内部字数
- python 打开excel并在屏幕上呈现_excel-检查文件是否在Python中打开
- jsTree工作笔记001---jsTree的基本使用_js实现树形结构
- 遗传算法的简介与应用详细过程
- Redis客户端与服务端
- 【Excel 教程系列第 15 篇】Excel 中的简单排序(升序 / 降序)、多条件排序、按颜色排序、自定义排序、以及巧用“升序“制作工资条
- 畅游陈德文:中国网游的发展与未来趋势
- 【python--爬虫】彼岸图网高清壁纸爬虫
- 服务器iis的作用,Web 服务器 (IIS) 概述
- java设计模式学习-代理模式
- ffmpeg flv转MP4
- c语言计算10以内之和,求一个C语言程序,随机产生50道10以内的加法算术题
- [半监督学习] Combining Labeled and Unlabeled Data with Co-Training
热门文章
- 【Oracle】使用Function计算去除周末及法定节假日天数
- 3D模型欣赏:猫耳少女 唯美 可爱
- oracle segment undo_Oracle undo管理详解
- rails pry使用_使用Rails Active Resource简化Web应用程序之间的互操作性
- json txt格式转换器_BIOM:生物观测矩阵——微生物组数据通用数据格式
- 三个分布式计算软件(Prime95、Folding@Home、BOINC)
- Python中Unicode字符串(字符串编码问题)
- 宁波大学2023年MBA招生考试初试成绩查询的通知
- 基于autojs的安卓免root脚本引擎编写的QQ文字换语言发送全自动脚本
- google reader分享计划(北邮制造)