稠密连接网络(DenseNet)

ResNet中的跨层连接设计引申出了数个后续工作。稠密连接网络(DenseNet)与ResNet的主要区别在于在跨层连接上的主要区别:

  • ResNet使用相加
  • DenseNet使用连结

ResNet(左)与DenseNet(右):

图中将部分前后相邻的运算抽象为模块AAA和模块BBB。

  • DenseNet里模块BBB的输出不是像ResNet那样和模块AAA的输出相加,而是在通道维上连结。
  • 这样模块AAA的输出可以直接传入模块BBB后面的层。在这个设计里,模块AAA相当于直接跟模块BBB后面的所有层直接连接在了一起。这也是它被称为“稠密连接”的原因。

DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。

  • 稠密块定义了输入和输出是如何连结的
  • 过渡层用来控制通道数,控制其大小

稠密块

DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构:

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def conv_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))return blk
  • 稠密块由多个conv_block组成,每块使用相同的输出通道数
  • 在前向计算时,我们将每块的输入和输出在通道维上连结。
class DenseBlock(nn.Module):def __init__(self, num_convs, in_channels, out_channels):super(DenseBlock, self).__init__()net = []for i in range(num_convs):in_c = in_channels + i * out_channelsnet.append(conv_block(in_c, out_channels))self.net = nn.ModuleList(net)self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数def forward(self, X):for blk in self.net:Y = blk(X)X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结return X

定义一个有2个输出通道数为10的卷积块。

  • 使用通道数为3的输入时,我们会得到通道数为3+2×10=233+2\times 10=233+2×10=23的输出。
  • 卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。
blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])

过渡层

  • 每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。
  • 过渡层用来控制模型复杂度。它通过1×11\times11×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。
def transition_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=1),nn.AvgPool2d(kernel_size=2, stride=2))return blk

对上例中稠密块的输出,使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。

blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])

DenseNet模型

DenseNet首先使用和ResNet一样的单卷积层和最大池化层。

net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
  • 接着使用4个稠密块。
  • 同ResNet一样,我们可以设置每个稠密块使用多少个卷积层(这里设成4)。
  • 稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。DenseNet则使用过渡层来减半高和宽,并减半通道数。

num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]for i, num_convs in enumerate(num_convs_in_dense_blocks):DB = DenseBlock(num_convs, num_channels, growth_rate)net.add_module("DenseBlosk_%d" % i, DB)# 上一个稠密块的输出通道数num_channels = DB.out_channels# 在稠密块之间加入通道数减半的过渡层if i != len(num_convs_in_dense_blocks) - 1:net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))num_channels = num_channels // 2
  • 最后接上全局池化层和全连接层来输出。
class GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(num_channels, 10)))
  • 打印每个子模块的输出维度
X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)
0  output shape:  torch.Size([1, 64, 48, 48])
1  output shape:     torch.Size([1, 64, 48, 48])
2  output shape:     torch.Size([1, 64, 48, 48])
3  output shape:     torch.Size([1, 64, 24, 24])
DenseBlosk_0  output shape:  torch.Size([1, 192, 24, 24])
transition_block_0  output shape:    torch.Size([1, 96, 12, 12])
DenseBlosk_1  output shape:  torch.Size([1, 224, 12, 12])
transition_block_1  output shape:    torch.Size([1, 112, 6, 6])
DenseBlosk_2  output shape:  torch.Size([1, 240, 6, 6])
transition_block_2  output shape:    torch.Size([1, 120, 3, 3])
DenseBlosk_3  output shape:  torch.Size([1, 248, 3, 3])
BN  output shape:    torch.Size([1, 248, 3, 3])
relu  output shape:  torch.Size([1, 248, 3, 3])
global_avg_pool  output shape:   torch.Size([1, 248, 1, 1])
fc  output shape:    torch.Size([1, 10])
  • 获取数据
def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter
batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)

训练模型

def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

《动手学深度学习》

(pytorch-深度学习)实现稠密连接网络(DenseNet)相关推荐

  1. pytorch深度学习实战——预训练网络

    来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...

  2. PyTorch 深度学习:32分钟快速入门——DenseNet

    DenseNet¶ 因为 ResNet 提出了跨层链接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 cvpr 2017 的 best paper,DenseNet. DenseNet ...

  3. pytorch | 深度学习分割网络U-net的pytorch模型实现

    原文:https://blog.csdn.net/u014722627/article/details/60883185 pytorch | 深度学习分割网络U-net的pytorch模型实现 这个是 ...

  4. Pytorch深度学习实战教程:UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 本文的开发环境如下: 开发环境:Windows 开发语言:Python3. ...

  5. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  6. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  7. 手把手教你搭建pytorch深度学习网络

    总有人在后台问我,如今 TensorFlow 和 PyTorch 两个深度学习框架,哪个更流行? 就这么说吧,今年面试的实习生,问到常用的深度学习框架时,他们清一色的选择了「PyTorch」. 这并不 ...

  8. 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

    深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...

  9. 07.7. 稠密连接网络(DenseNet)

    文章目录 7.7. 稠密连接网络(DenseNet) 7.7.1. 从ResNet到DenseNet 7.7.2. 稠密块体 7.7.3. 过渡层 7.7.4. DenseNet模型 7.7.5. 训 ...

最新文章

  1. [转]我们为什么需要工作流
  2. VS2010配置OpenCV
  3. 数据结构——队列(queue)
  4. 如何在word写小论文在正文分栏后第一页左下角添加 项目 基金 作者简介 (添加通栏脚注)
  5. S4HANA里至关重要的建模方式CDS view架构介绍
  6. java中如何调出字体对话框_java 字体对话框
  7. Collections.synchronizedList使用
  8. MDaemon邮件服务器解决方案之应急恢复解决方案
  9. Sendmail 邮件服务器安装和优化
  10. Microsoft SQL Server Protocols
  11. Linux系统管理---LVM分区管理
  12. 分享一份软件测试面试指南
  13. GRE词汇统计大全(二)
  14. 案例4——52周存钱挑战
  15. Python的学习笔记案例4--52周存钱挑战3.0
  16. 译:25个面试中最常问的问题和答案
  17. 【评测】照胶的仪器选购
  18. 运用Python爬虫爬取一个美女网址,爬取美女图
  19. 案例 | 巴别鸟为弘睿构建企业知识库
  20. 如何正确使用SIM卡呢?

热门文章

  1. python组合数据分类_Python解决数据样本类别分布不均衡问题
  2. python 除法取模_跟我一起学python | 探究05
  3. 【错误记录】Invalid character found in method name. HTTP method names must be tokens
  4. 回头看看NSURLConnection
  5. 选择排序算法流程图_常用排序算法之选择排序
  6. 照片识别出错_AI跨年龄人脸识别技术在跨年龄寻亲的应用简析
  7. jsp项目开发案例_Laravel中使用swoole项目实战开发案例一 (建立swoole和前端通信)
  8. 华硕 x86 android,【华硕X79评测】学不会不收费 几步教你安装Android x86-中关村在线...
  9. matlab二维数组最小值出错,矩阵求最小值问题 问题是: 错误使用空矩形矩阵进行赋值...
  10. 【LeetCode笔记】78. 子集(Java、dfs)