原文:https://blog.csdn.net/qq_43360533/article/details/107448369

目录

1 Introduction

3 过渡层

4 DenseNet模型

5 获取数据并训练


1 Introduction

与ResNet的主要区别在于,DenseNet里模块B的输出不是像ResNet那样和模块A的输出相加,而是在通道维上连结。这样模块A的输出可以直接传入模块B后面的层。在这个设计里,模块A直接跟模块B后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。
如果用公式表示的话,传统的网络在 [公式] 层的输出为:


而对于ResNet,增加了来自上一层输入的identity函数:


在DenseNet中,会连接前面所有层作为输入:

 
DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。


在DenseBlock中,各个层的特征图大小一致,是如下图的结构,可以在channel维度上连接。DenseBlock中的非线性组合函数H(·)采用的是BN+ReLU+3x3 Conv的结构,如下图所示。另外值得注意的一点是,与ResNet不同,所有DenseBlock中各个层卷积之后均输出k个特征图,即得到的特征图的channel数为 k,或者说采用 k 个卷积核。k 在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的k (比如12),就可以得到较佳的性能。假定输入层的特征图的channel数为 k0,那么l 层输入的channel数为 k0+k(l-1) ,因此随着层数增加,尽管 k设定得较小,DenseBlock的输入会非常多,不过这是由于特征重用所造成的,每个层仅有 k个特征是自己独有的。

2 稠密块

DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构,我们首先在conv_block函数里实现这个结构。

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def conv_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))return blk

# 稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,将每块的输入和
# 输出在通道维上连结。
class DenseBlock(nn.Module):def __init__(self, num_convs, in_channels, out_channels):super(DenseBlock, self).__init__()net = []for i in range(num_convs):in_c = in_channels + i * out_channelsnet.append(conv_block(in_c, out_channels))self.net = nn.ModuleList(net)self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数def forward(self, X):for blk in self.net:Y = blk(X)X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结return X

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])

3 过渡层

# 由于每个稠密块连结都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用
# 来控制模型复杂度。它通过1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和
# 宽,从而进一步降低模型复杂度。
def transition_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=1),nn.AvgPool2d(kernel_size=2, stride=2))return blk

对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。

blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])

4 DenseNet模型

DenseNet首先使用同ResNet一样的单卷积层和最大池化层。

类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,可以设置每个稠密块使用多少个卷积层,这里设成4,与ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

最后接上全局池化层和全连接层来输出。

net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]for i, num_convs in enumerate(num_convs_in_dense_blocks):DB = DenseBlock(num_convs, num_channels, growth_rate)net.add_module("DenseBlosk_%d" % i, DB)# 上一个稠密块的输出通道数num_channels = DB.out_channels# 在稠密块之间加入通道数减半的过渡层if i != len(num_convs_in_dense_blocks) - 1:net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))num_channels = num_channels // 2# 同ResNet一样,最后接上全局池化层和全连接层来输出。
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 

我们尝试打印每个子模块的输出维度确保网络无误:

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)

输出:

0  output shape:     torch.Size([1, 64, 48, 48])
1  output shape:     torch.Size([1, 64, 48, 48])
2  output shape:     torch.Size([1, 64, 48, 48])
3  output shape:     torch.Size([1, 64, 24, 24])
DenseBlosk_0  output shape:     torch.Size([1, 192, 24, 24])
transition_block_0  output shape:     torch.Size([1, 96, 12, 12])
DenseBlosk_1  output shape:     torch.Size([1, 224, 12, 12])
transition_block_1  output shape:     torch.Size([1, 112, 6, 6])
DenseBlosk_2  output shape:     torch.Size([1, 240, 6, 6])
transition_block_2  output shape:     torch.Size([1, 120, 3, 3])
DenseBlosk_3  output shape:     torch.Size([1, 248, 3, 3])
BN  output shape:     torch.Size([1, 248, 3, 3])
relu  output shape:     torch.Size([1, 248, 3, 3])
global_avg_pool  output shape:     torch.Size([1, 248, 1, 1])
fc  output shape:     torch.Size([1, 10])

5 获取数据并训练

由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

参考原文:《动手学深度学习(pyTorch)》
为了深入性能方面的可以继续学习:

DenseNet稠密连接网络(pyTorch源码)相关推荐

  1. MXNet对DenseNet(稠密连接网络)的实现

    论文地址:Densely Connected Convolutional Networks DenseNet其实跟前面的ResNet是很相似的,我们知道ResNet的梯度可以直接通过身份函数(激活函数 ...

  2. (pytorch-深度学习)实现稠密连接网络(DenseNet)

    稠密连接网络(DenseNet) ResNet中的跨层连接设计引申出了数个后续工作.稠密连接网络(DenseNet)与ResNet的主要区别在于在跨层连接上的主要区别: ResNet使用相加 Dens ...

  3. 07.7. 稠密连接网络(DenseNet)

    文章目录 7.7. 稠密连接网络(DenseNet) 7.7.1. 从ResNet到DenseNet 7.7.2. 稠密块体 7.7.3. 过渡层 7.7.4. DenseNet模型 7.7.5. 训 ...

  4. pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)

    写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...

  5. ELMo解读(论文 + PyTorch源码)

    ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...

  6. Transformer-XL解读(论文 + PyTorch源码)

    前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...

  7. pytorch源码解析2——数据处理torch.utils.data

    迭代器 理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键. 在 Dataset, Sampler 和 DataLoader 这三个类中都会用到 py ...

  8. autojs网络验证,权朗网络验证源码

    auto.js一个简单的网络验证源码 threads.start(function(){toastLog("开始查询...") //验证地址 r = http.get(" ...

  9. 基于Pytorch源码对SGD、momentum、Nesterov学习

    目前神经网络的监督学习过程通常为: 数据加载(load)进神经网络 经过网络参数对数据的计算,得出预测值(predict) 根据预测值与标注值(label)之间的差距,产生损失(loss) 通过反向传 ...

最新文章

  1. 零起点学算法01——第一个程序Hello World!
  2. C语言编程语言科技 c语言中的= 和= =有什么区别?(精华篇)
  3. linux异常 - 无法分配内存
  4. 基于 Bochs 的操作系统内核实现
  5. PHP预定义常量DIRECTORY_SEPARATOR
  6. JS事件冒泡与事件捕获
  7. 如何在线把网站html生成xml文件_快速抓取网站信息工具
  8. 限制input 内部字数
  9. python 打开excel并在屏幕上呈现_excel-检查文件是否在Python中打开
  10. jsTree工作笔记001---jsTree的基本使用_js实现树形结构
  11. 遗传算法的简介与应用详细过程
  12. Redis客户端与服务端
  13. 【Excel 教程系列第 15 篇】Excel 中的简单排序(升序 / 降序)、多条件排序、按颜色排序、自定义排序、以及巧用“升序“制作工资条
  14. 畅游陈德文:中国网游的发展与未来趋势
  15. 【python--爬虫】彼岸图网高清壁纸爬虫
  16. 服务器iis的作用,Web 服务器 (IIS) 概述
  17. java设计模式学习-代理模式
  18. ffmpeg flv转MP4
  19. c语言计算10以内之和,求一个C语言程序,随机产生50道10以内的加法算术题
  20. [半监督学习] Combining Labeled and Unlabeled Data with Co-Training

热门文章

  1. 【Oracle】使用Function计算去除周末及法定节假日天数
  2. 3D模型欣赏:猫耳少女 唯美 可爱
  3. oracle segment undo_Oracle undo管理详解
  4. rails pry使用_使用Rails Active Resource简化Web应用程序之间的互操作性
  5. json txt格式转换器_BIOM:生物观测矩阵——微生物组数据通用数据格式
  6. 三个分布式计算软件(Prime95、Folding@Home、BOINC)
  7. Python中Unicode字符串(字符串编码问题)
  8. 宁波大学2023年MBA招生考试初试成绩查询的通知
  9. 基于autojs的安卓免root脚本引擎编写的QQ文字换语言发送全自动脚本
  10. google reader分享计划(北邮制造)