(pytorch-深度学习)实现稠密连接网络(DenseNet)
稠密连接网络(DenseNet)
ResNet中的跨层连接设计引申出了数个后续工作。稠密连接网络(DenseNet)与ResNet的主要区别在于在跨层连接上的主要区别:
- ResNet使用相加
- DenseNet使用连结
ResNet(左)与DenseNet(右):
图中将部分前后相邻的运算抽象为模块AAA和模块BBB。
- DenseNet里模块BBB的输出不是像ResNet那样和模块AAA的输出相加,而是在通道维上连结。
- 这样模块AAA的输出可以直接传入模块BBB后面的层。在这个设计里,模块AAA相当于直接跟模块BBB后面的所有层直接连接在了一起。这也是它被称为“稠密连接”的原因。
DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。
- 稠密块定义了输入和输出是如何连结的
- 过渡层用来控制通道数,控制其大小
稠密块
DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构:
import time
import torch
from torch import nn, optim
import torch.nn.functional as Fdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def conv_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))return blk
- 稠密块由多个conv_block组成,每块使用相同的输出通道数。
- 在前向计算时,我们将每块的输入和输出在通道维上连结。
class DenseBlock(nn.Module):def __init__(self, num_convs, in_channels, out_channels):super(DenseBlock, self).__init__()net = []for i in range(num_convs):in_c = in_channels + i * out_channelsnet.append(conv_block(in_c, out_channels))self.net = nn.ModuleList(net)self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数def forward(self, X):for blk in self.net:Y = blk(X)X = torch.cat((X, Y), dim=1) # 在通道维上将输入和输出连结return X
定义一个有2个输出通道数为10的卷积块。
- 使用通道数为3的输入时,我们会得到通道数为3+2×10=233+2\times 10=233+2×10=23的输出。
- 卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。
blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])
过渡层
- 每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。
- 过渡层用来控制模型复杂度。它通过1×11\times11×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。
def transition_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=1),nn.AvgPool2d(kernel_size=2, stride=2))return blk
对上例中稠密块的输出,使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。
blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])
DenseNet模型
DenseNet首先使用和ResNet一样的单卷积层和最大池化层。
net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
- 接着使用4个稠密块。
- 同ResNet一样,我们可以设置每个稠密块使用多少个卷积层(这里设成4)。
- 稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。
ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。DenseNet则使用过渡层来减半高和宽,并减半通道数。
num_channels, growth_rate = 64, 32 # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]for i, num_convs in enumerate(num_convs_in_dense_blocks):DB = DenseBlock(num_convs, num_channels, growth_rate)net.add_module("DenseBlosk_%d" % i, DB)# 上一个稠密块的输出通道数num_channels = DB.out_channels# 在稠密块之间加入通道数减半的过渡层if i != len(num_convs_in_dense_blocks) - 1:net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))num_channels = num_channels // 2
- 最后接上全局池化层和全连接层来输出。
class GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(num_channels, 10)))
- 打印每个子模块的输出维度
X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)
0 output shape: torch.Size([1, 64, 48, 48])
1 output shape: torch.Size([1, 64, 48, 48])
2 output shape: torch.Size([1, 64, 48, 48])
3 output shape: torch.Size([1, 64, 24, 24])
DenseBlosk_0 output shape: torch.Size([1, 192, 24, 24])
transition_block_0 output shape: torch.Size([1, 96, 12, 12])
DenseBlosk_1 output shape: torch.Size([1, 224, 12, 12])
transition_block_1 output shape: torch.Size([1, 112, 6, 6])
DenseBlosk_2 output shape: torch.Size([1, 240, 6, 6])
transition_block_2 output shape: torch.Size([1, 120, 3, 3])
DenseBlosk_3 output shape: torch.Size([1, 248, 3, 3])
BN output shape: torch.Size([1, 248, 3, 3])
relu output shape: torch.Size([1, 248, 3, 3])
global_avg_pool output shape: torch.Size([1, 248, 1, 1])
fc output shape: torch.Size([1, 10])
- 获取数据
def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0 # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter
batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)
训练模型
def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
《动手学深度学习》
(pytorch-深度学习)实现稠密连接网络(DenseNet)相关推荐
- pytorch深度学习实战——预训练网络
来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...
- PyTorch 深度学习:32分钟快速入门——DenseNet
DenseNet¶ 因为 ResNet 提出了跨层链接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 cvpr 2017 的 best paper,DenseNet. DenseNet ...
- pytorch | 深度学习分割网络U-net的pytorch模型实现
原文:https://blog.csdn.net/u014722627/article/details/60883185 pytorch | 深度学习分割网络U-net的pytorch模型实现 这个是 ...
- Pytorch深度学习实战教程:UNet语义分割网络
1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 本文的开发环境如下: 开发环境:Windows 开发语言:Python3. ...
- Pytorch 深度学习实战教程(二):UNet语义分割网络
本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...
- Pytorch深度学习实战教程(二):UNet语义分割网络
1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...
- 手把手教你搭建pytorch深度学习网络
总有人在后台问我,如今 TensorFlow 和 PyTorch 两个深度学习框架,哪个更流行? 就这么说吧,今年面试的实习生,问到常用的深度学习框架时,他们清一色的选择了「PyTorch」. 这并不 ...
- 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别
深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...
- 07.7. 稠密连接网络(DenseNet)
文章目录 7.7. 稠密连接网络(DenseNet) 7.7.1. 从ResNet到DenseNet 7.7.2. 稠密块体 7.7.3. 过渡层 7.7.4. DenseNet模型 7.7.5. 训 ...
最新文章
- [转]我们为什么需要工作流
- VS2010配置OpenCV
- 数据结构——队列(queue)
- 如何在word写小论文在正文分栏后第一页左下角添加 项目 基金 作者简介 (添加通栏脚注)
- S4HANA里至关重要的建模方式CDS view架构介绍
- java中如何调出字体对话框_java 字体对话框
- Collections.synchronizedList使用
- MDaemon邮件服务器解决方案之应急恢复解决方案
- Sendmail 邮件服务器安装和优化
- Microsoft SQL Server Protocols
- Linux系统管理---LVM分区管理
- 分享一份软件测试面试指南
- GRE词汇统计大全(二)
- 案例4——52周存钱挑战
- Python的学习笔记案例4--52周存钱挑战3.0
- 译:25个面试中最常问的问题和答案
- 【评测】照胶的仪器选购
- 运用Python爬虫爬取一个美女网址,爬取美女图
- 案例 | 巴别鸟为弘睿构建企业知识库
- 如何正确使用SIM卡呢?
热门文章
- python组合数据分类_Python解决数据样本类别分布不均衡问题
- python 除法取模_跟我一起学python | 探究05
- 【错误记录】Invalid character found in method name. HTTP method names must be tokens
- 回头看看NSURLConnection
- 选择排序算法流程图_常用排序算法之选择排序
- 照片识别出错_AI跨年龄人脸识别技术在跨年龄寻亲的应用简析
- jsp项目开发案例_Laravel中使用swoole项目实战开发案例一 (建立swoole和前端通信)
- 华硕 x86 android,【华硕X79评测】学不会不收费 几步教你安装Android x86-中关村在线...
- matlab二维数组最小值出错,矩阵求最小值问题 问题是: 错误使用空矩形矩阵进行赋值...
- 【LeetCode笔记】78. 子集(Java、dfs)