DenseNet¶

因为 ResNet 提出了跨层链接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 cvpr 2017 的 best paper,DenseNet。

DenseNet 和 ResNet 不同在于 ResNet 是跨层求和,而 DenseNet 是跨层将特征在通道维度进行拼接, DenseNet因为是在通道维度进行特征的拼接,所以底层的输出会保留进入所有后面的层,这能够更好的保证梯度的传播,同时能够使用低维的特征和高维的特征进行联合训练,能够得到更好的结果。

DenseNet 主要由 dense block 构成,下面我们来实现一个 densen block

import sys
sys.path.append('..')import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10首先定义一个卷积块,这个卷积块的顺序是 bn -> relu -> convdef conv_block(in_channel, out_channel):layer = nn.Sequential(nn.BatchNorm2d(in_channel),nn.ReLU(True),nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False))return layerdense block 将每次的卷积的输出称为 growth_rate,因为如果输入是 in_channel,有 n 层,那么输出就是 in_channel + n * growh_rateclass dense_block(nn.Module):def __init__(self, in_channel, growth_rate, num_layers):super(dense_block, self).__init__()block = []channel = in_channelfor i in range(num_layers):block.append(conv_block(channel, growth_rate))channel += growth_rateself.net = nn.Sequential(*block)def forward(self, x):for layer in self.net:out = layer(x)x = torch.cat((out, x), dim=1)return x我们验证一下输出的 channel 是否正确test_net = dense_block(3, 12, 3)
test_x = Variable(torch.zeros(1, 3, 96, 96))
print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
test_y = test_net(test_x)
print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))input shape: 3 x 96 x 96
output shape: 39 x 96 x 96

除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block),因为 DenseNet 会不断地对维度进行拼接, 所以当层数很高的时候,输出的通道数就会越来越大,参数和计算量也会越来越大,为了避免这个问题,需要引入过渡层将输出通道降低下来,同时也将输入的长宽减半,这个过渡层可以使用 1 x 1 的卷积

def transition(in_channel, out_channel):trans_layer = nn.Sequential(nn.BatchNorm2d(in_channel),nn.ReLU(True),nn.Conv2d(in_channel, out_channel, 1),nn.AvgPool2d(2, 2))return trans_layer

验证一下过渡层是否正确

test_net = transition(3, 12)
test_x = Variable(torch.zeros(1, 3, 96, 96))
print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
test_y = test_net(test_x)
print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))input shape: 3 x 96 x 96
output shape: 12 x 48 x 48

最后我们定义 DenseNet

class densenet(nn.Module):def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):super(densenet, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channel, 64, 7, 2, 3),nn.BatchNorm2d(64),nn.ReLU(True),nn.MaxPool2d(3, 2, padding=1))channels = 64block = []for i, layers in enumerate(block_layers):block.append(dense_block(channels, growth_rate, layers))channels += layers * growth_rateif i != len(block_layers) - 1:block.append(transition(channels, channels // 2)) # 通过 transition 层将大小减半,通道数减半channels = channels // 2self.block2 = nn.Sequential(*block)self.block2.add_module('bn', nn.BatchNorm2d(channels))self.block2.add_module('relu', nn.ReLU(True))self.block2.add_module('avg_pool', nn.AvgPool2d(3))self.classifier = nn.Linear(channels, num_classes)def forward(self, x):x = self.block1(x)x = self.block2(x)x = x.view(x.shape[0], -1)x = self.classifier(x)return xtest_net = densenet(3, 10)
test_x = Variable(torch.zeros(1, 3, 96, 96))
test_y = test_net(test_x)
print('output: {}'.format(test_y.shape))output: torch.Size([1, 10])
from utils import traindef data_tf(x):x = x.resize((96, 96), 2) # 将图片放大到 96 x 96x = np.array(x, dtype='float32') / 255x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到x = x.transpose((2, 0, 1)) # 将 channel 放到第一维,只是 pytorch 要求的输入方式x = torch.from_numpy(x)return xtrain_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)net = densenet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()train(net, train_data, test_data, 20, optimizer, criterion)Epoch 0. Train Loss: 1.374316, Train Acc: 0.507972, Valid Loss: 1.203217, Valid Acc: 0.572884, Time 00:01:44
Epoch 1. Train Loss: 0.912924, Train Acc: 0.681506, Valid Loss: 1.555908, Valid Acc: 0.492286, Time 00:01:50
Epoch 2. Train Loss: 0.701387, Train Acc: 0.755794, Valid Loss: 0.815147, Valid Acc: 0.718354, Time 00:01:49
Epoch 3. Train Loss: 0.575985, Train Acc: 0.800911, Valid Loss: 0.696013, Valid Acc: 0.759494, Time 00:01:50
Epoch 4. Train Loss: 0.479812, Train Acc: 0.836957, Valid Loss: 1.013879, Valid Acc: 0.676226, Time 00:01:51
Epoch 5. Train Loss: 0.402165, Train Acc: 0.861413, Valid Loss: 0.674512, Valid Acc: 0.778481, Time 00:01:50
Epoch 6. Train Loss: 0.334593, Train Acc: 0.888247, Valid Loss: 0.647112, Valid Acc: 0.791634, Time 00:01:50
Epoch 7. Train Loss: 0.278181, Train Acc: 0.907149, Valid Loss: 0.773517, Valid Acc: 0.756527, Time 00:01:51
Epoch 8. Train Loss: 0.227948, Train Acc: 0.922714, Valid Loss: 0.654399, Valid Acc: 0.800237, Time 00:01:49
Epoch 9. Train Loss: 0.181156, Train Acc: 0.940157, Valid Loss: 1.179013, Valid Acc: 0.685225, Time 00:01:50
Epoch 10. Train Loss: 0.151305, Train Acc: 0.950208, Valid Loss: 0.630000, Valid Acc: 0.807951, Time 00:01:50
Epoch 11. Train Loss: 0.118433, Train Acc: 0.961077, Valid Loss: 1.247253, Valid Acc: 0.703323, Time 00:01:52
Epoch 12. Train Loss: 0.094127, Train Acc: 0.969789, Valid Loss: 1.230697, Valid Acc: 0.723101, Time 00:01:51
Epoch 13. Train Loss: 0.086181, Train Acc: 0.972047, Valid Loss: 0.904135, Valid Acc: 0.769284, Time 00:01:50
Epoch 14. Train Loss: 0.064248, Train Acc: 0.980359, Valid Loss: 1.665002, Valid Acc: 0.624209, Time 00:01:51
Epoch 15. Train Loss: 0.054932, Train Acc: 0.982996, Valid Loss: 0.927216, Valid Acc: 0.774723, Time 00:01:51
Epoch 16. Train Loss: 0.043503, Train Acc: 0.987272, Valid Loss: 1.574383, Valid Acc: 0.707377, Time 00:01:52
Epoch 17. Train Loss: 0.047615, Train Acc: 0.985154, Valid Loss: 0.987781, Valid Acc: 0.770471, Time 00:01:51
Epoch 18. Train Loss: 0.039813, Train Acc: 0.988012, Valid Loss: 2.248944, Valid Acc: 0.631824, Time 00:01:50
Epoch 19. Train Loss: 0.030183, Train Acc: 0.991168, Valid Loss: 0.887785, Valid Acc: 0.795392, Time 00:01:51

DenseNet 将残差连接改为了特征拼接,使得网络有了更稠密的连接

PyTorch 深度学习:32分钟快速入门——DenseNet相关推荐

  1. PyTorch 深度学习:32分钟快速入门——ResNet

    ResNet 当大家还在惊叹 GoogLeNet 的 inception 结构的时候,微软亚洲研究院的研究员已经在设计更深但结构更加简单的网络 ResNet,并且凭借这个网络子在 2015 年 Ima ...

  2. PyTorch 深度学习:36分钟快速入门——GAN

    自动编码器和变分自动编码器,不管是哪一个,都是通过计算生成图像和输入图像在每个像素点的误差来生成 loss,这一点是特别不好的,因为不同的像素点可能造成不同的视觉结果,但是可能他们的 loss 是相同 ...

  3. PyTorch 深度学习:34分钟快速入门——自动编码器

    自动编码器最开始是作为一种数据压缩方法,同时还可以在卷积网络中进行逐层预训练,但是随后更多结构复杂的网络,比如 resnet 的出现使得我们能够训练任意深度的网络,自动编码器就不再使用在这个方面,下面 ...

  4. PyTorch 深度学习:33分钟快速入门——VGG

    CIFAR 10¶ cifar 10 这个数据集一共有 50000 张训练集,10000 张测试集,两个数据集里面的图片都是 png 彩色图片,图片大小是 32 x 32 x 3,一共是 10 分类问 ...

  5. PyTorch 深度学习:30分钟快速入门

    卷积¶ 卷积在 pytorch 中有两种方式,一种是 torch.nn.Conv2d(),一种是 torch.nn.functional.conv2d(),这两种形式本质都是使用一个卷积操作 这两种形 ...

  6. PyTorch 深度学习:37分钟快速入门——FCN 做语义分割

    语义分割是一种像素级别的处理图像方式,对比于目标检测其更加精确,能够自动从图像中划分出对象区域并识别对象区域中的类别 在 2015 年 CVPR 的一篇论文 Fully Convolutional N ...

  7. PyTorch 深度学习:38分钟快速入门——RNN 做图像分类

    RNN 特别适合做序列类型的数据,那么 RNN 能不能想 CNN 一样用来做图像分类呢?下面我们用 mnist 手写字体的例子来展示一下如何用 RNN 做图像分类,但是这种方法并不是主流,这里我们只是 ...

  8. PyTorch 深度学习:35分钟快速入门——变分自动编码器

    变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成. 回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编 ...

  9. PyTorch 深度学习:31分钟快速入门——Batch Normalization

    Batch Normalization¶ 前面在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好.但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相 ...

最新文章

  1. mysql 绑定参数_MySQL 使用 Perl 绑定参数和列
  2. LSOF 安装与使用(功能强大)
  3. c++实现,对象池 object_pool
  4. 第九节: EF的性能篇(二) 之 Z.EntityFramework.Extensions程序集解决EF的性能问题
  5. latex表格名的引用问题
  6. Go Web编程--应用ORM
  7. 对AngularJS的编译和链接过程讲解一步到位的文章
  8. __VA_ARGS__和##__VA_ARGS__的区别(转载)
  9. Kafka权威指南,Kafka生产者
  10. vs 2005 sp1 安装失败的解决方案 安装VS2005 sp1的方法
  11. 为什么我们应该使用 HTML5 开发网站
  12. 巨人综合音源优化版 – East West Quantum Leap Colossus Kontakt
  13. Vue项目设置浏览器小图标
  14. 【成功解决】Ubuntu下U盘文件夹不存在
  15. 羽毛球·印尼赛 | 国羽男双新高塔组合惊喜进决赛
  16. 2022年资料员-岗位技能(资料员)操作证考试题模拟考试平台操作
  17. python中矩阵的表示方法_在python中创建数字的二进制表示形式的矩阵 - python
  18. 达梦8 DCA培训总结
  19. STG游戏中瞄具的基本原理
  20. 能温柔的时候,请别尖锐

热门文章

  1. Wootrade宣布加入SushiSwa旗下Mirin协议和子池计划
  2. 数据:灰度比特币信托基金溢价达41%创近一年新高
  3. SAP License:ERP系统管理软件该有的“魅力”
  4. scrapy 基本操作
  5. 关于hive中的reduce个数的设置。
  6. YAML书写规则与数据结构
  7. selenium的定位方式
  8. python基础3之文件操作、字符编码解码、函数介绍
  9. PHP模板引擎smarty详细介绍
  10. linux线程相关函数接口