1. VGG块

VGG块的组成规律是:连续使用数个相同的填充为1、窗口形状为3×33\times 33×3的卷积层后接上一个步幅为2、窗口形状为2×22\times 22×2的最大池化层。卷积层保持输入的高和宽不变,而池化层则对其减半。我们使用vgg_block函数来实现这个基础的VGG块,它可以指定卷积层的数量和输入输出通道数。

import time
import torch
from torch import nn, optimimport sys
sys.path.append("..") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def vgg_block(num_convs, in_channels, out_channels):blk = []for i in range(num_convs):if i == 0:blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))else:blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))blk.append(nn.ReLU())blk.append(nn.MaxPool2d(kernel_size=2, stride=2)) # 这里会使宽高减半return nn.Sequential(*blk)

2. VGG网络

与AlexNet和LeNet一样,VGG网络由卷积层模块后接全连接层模块构成。卷积层模块串联数个vgg_block,其超参数由变量conv_arch定义。该变量指定了每个VGG块里卷积层个数和输入输出通道数。全连接模块则跟AlexNet中的一样。

现在我们构造一个VGG网络。它有5个卷积块,前2块使用单卷积层,而后3块使用双卷积层。第一块的输入输出通道分别是1(因为下面要使用的Fashion-MNIST数据的通道数为1)和64,之后每次对输出通道数翻倍,直到变为512。因为这个网络使用了8个卷积层和3个全连接层,所以经常被称为VGG-11。

conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))
# 经过5个vgg_block, 宽高会减半5次, 变成 224/32 = 7
fc_features = 512 * 7 * 7 # c * w * h
fc_hidden_units = 64 # 任意

下面我们实现VGG-11。

def vgg(conv_arch, fc_features, fc_hidden_units=4096):net = nn.Sequential()# 卷积层部分for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):# 每经过一个vgg_block都会使宽高减半net.add_module("vgg_block_" + str(i+1), vgg_block(num_convs, in_channels, out_channels))# 全连接层部分net.add_module("fc", nn.Sequential(d2l.FlattenLayer(),nn.Linear(fc_features, fc_hidden_units),nn.ReLU(),nn.Dropout(0.5),nn.Linear(fc_hidden_units, fc_hidden_units),nn.ReLU(),nn.Dropout(0.5),nn.Linear(fc_hidden_units, 10)))return net

下面构造一个高和宽均为224的单通道数据样本来观察每一层的输出形状。

net = vgg(conv_arch, fc_features, fc_hidden_units)
X = torch.rand(1, 1, 224, 224)# named_children获取一级子模块及其名字(named_modules会返回所有子模块,包括子模块的子模块)
for name, blk in net.named_children(): X = blk(X)print(name, 'output shape: ', X.shape)

输出:

vgg_block_1 output shape:  torch.Size([1, 64, 112, 112])
vgg_block_2 output shape:  torch.Size([1, 128, 56, 56])
vgg_block_3 output shape:  torch.Size([1, 256, 28, 28])
vgg_block_4 output shape:  torch.Size([1, 512, 14, 14])
vgg_block_5 output shape:  torch.Size([1, 512, 7, 7])
fc output shape:  torch.Size([1, 10])

可以看到,每次我们将输入的高和宽减半,直到最终高和宽变成7后传入全连接层。与此同时,输出通道数每次翻倍,直到变成512。因为每个卷积层的窗口大小一样,所以每层的模型参数尺寸和计算复杂度与输入高、输入宽、输入通道数和输出通道数的乘积成正比。VGG这种高和宽减半以及通道翻倍的设计使得多数卷积层都有相同的模型参数尺寸和计算复杂度。

3. 获取数据和训练模型

因为VGG-11计算上比AlexNet更加复杂,出于测试的目的我们构造一个通道数更小,或者说更窄的网络在Fashion-MNIST数据集上进行训练。

ratio = 8
small_conv_arch = [(1, 1, 64//ratio), (1, 64//ratio, 128//ratio), (2, 128//ratio, 256//ratio), (2, 256//ratio, 512//ratio), (2, 512//ratio, 512//ratio)]
net = vgg(small_conv_arch, fc_features // ratio, fc_hidden_units // ratio)
print(net)

输出:

Sequential((vgg_block_1): Sequential((0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(vgg_block_2): Sequential((0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(vgg_block_3): Sequential((0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(vgg_block_4): Sequential((0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(vgg_block_5): Sequential((0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fc): Sequential((0): FlattenLayer()(1): Linear(in_features=3136, out_features=512, bias=True)(2): ReLU()(3): Dropout(p=0.5)(4): Linear(in_features=512, out_features=512, bias=True)(5): ReLU()(6): Dropout(p=0.5)(7): Linear(in_features=512, out_features=10, bias=True))
)

模型训练过程与上一节的AlexNet中的类似。

def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
batch_size = 64
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0101, train acc 0.755, test acc 0.859, time 255.9 sec
epoch 2, loss 0.0051, train acc 0.882, test acc 0.902, time 238.1 sec
epoch 3, loss 0.0043, train acc 0.900, test acc 0.908, time 225.5 sec
epoch 4, loss 0.0038, train acc 0.913, test acc 0.914, time 230.3 sec
epoch 5, loss 0.0035, train acc 0.919, test acc 0.918, time 153.9 sec

小结

  • VGG-11通过5个可以重复使用的卷积块来构造网络。根据每块里卷积层个数和输出通道数的不同可以定义出不同的VGG模型。

pytorch学习笔记(二十五):VGG相关推荐

  1. JVM 学习笔记二十五、JVM监控及诊断工具-命令行篇

    二十五.JVM监控及诊断工具-命令行篇 1.概述 性能诊断是软件工程师在日常工作中经常面对和解决的问题,在用户体验至上的今天,解决好应用软件的性能问题能带来非常大的收益. Java作为最流行的编程语言 ...

  2. Java学习笔记二十五:Java面向对象的三大特性之多态

    Java面向对象的三大特性之多态 一:什么是多态: 多态是同一个行为具有多个不同表现形式或形态的能力. 多态就是同一个接口,使用不同的实例而执行不同操作. 多态性是对象多种表现形式的体现. 现实中,比 ...

  3. pytorch学习笔记(十五):模型构造

    文章目录 1. 继承Module类来构造模型 2. Module的子类 2.1 Sequential类 2.2 ModuleList类 2.3 ModuleDict类 3. 构造复杂的模型 小结 这里 ...

  4. angular学习笔记(二十五)-$http(3)-转换请求和响应格式

    本篇主要讲解$http(config)的config中的tranformRequest项和transformResponse项 1. transformRequest: $http({transfor ...

  5. Mr.J-- jQuery学习笔记(二十五)--监听DOM加载

    页面元素 <body> <div></div> <div></div> <div></div> <div> ...

  6. java沙盒模式_JavaScript学习笔记(二十五) 沙箱模式

    沙箱模式(Sandbox Pattern) 沙箱模式可以避免命名空间的一些缺点(namespacing pattern),比如: 依赖一个唯一全局的变量作为程序的全局符号.在命名空间模式中,没有办法存 ...

  7. JavaScript学习笔记(十五)

    JavaScript学习笔记(十五) 事件 事件是DOM(文档对象模型)的一部分.事件流就是事件发生顺序,这是IE和其他浏览器在事件支持上的主要差别. 一.事件流 1.冒泡型事件 IE上的解决方案就是 ...

  8. OpenCV学习笔记(十五):图像仿射变换:warpAffine(),getRotationMatrix2D()

    OpenCV学习笔记(十五):图像仿射变换:warpAffine(),getRotationMatrix2D() 一个任意的仿射变换都能表示为乘以一个矩阵(线性变换)接着再加上一个向量(平移)的形式. ...

  9. 学习笔记(十五)——镜像的知识点与注意事项

    学习笔记(十五)--镜像的知识点与注意事项 一.基础知识 1.SQL Server镜像只有两种模式:高安全模式和高性能模式.两种模式的主要区别在于在事务提交后的操作. 在高性能模式下,主体服务器不需要 ...

  10. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

最新文章

  1. Java TreeMap 源码解析
  2. 2019 Power BI最Top50面试题,助你面试脱颖而出系列中
  3. 如何截取滚动的页面,窗口
  4. 不好意思昨天断更了,今天聊聊创业
  5. 比Redis快50倍的中间件,为啥这么快?
  6. hive删除EXTERNAL外表
  7. RocketMQ独孤九剑-总纲
  8. 线性方法求欧拉数-POJ2478
  9. mysql5.6.1安装步骤_mysql5.6安装步骤
  10. Ubuntu图形界面升级方法
  11. creator网页调试工具(ccc-devtools v3.0.1)
  12. 算法器之AVR的ISP烧录
  13. 齐鲁医药学院计算机考试,齐鲁医药学院2020年单独招生和综合评价招生考试时间及考试科目...
  14. 程序员自制游戏:超级玛丽100%真实版,能把你玩哭了~【附源码】
  15. 6端口车载以太网交换机
  16. 【无标题】离婚起诉状范文17篇
  17. one body.one heart.一个人,一颗心
  18. 用Java写春联:一键自动发送微信祝福给喜欢的人【撩】
  19. html5 p2p直播源,websocket – 使用HTML5或Javascript进行P2P视频配置
  20. 删除重复数据只保留一条

热门文章

  1. Audio Session Programming Guide
  2. 试验IFTTT同步发微博
  3. OLTP与OLAP介绍
  4. Hadoop源代码分析之Configuration
  5. Java中的StringBuffer、StringBuilder和包装器类型
  6. Struts2的Action中访问servletAPI方式
  7. sqoop 命令在crontab 不能自定执行
  8. hibernate4中主要的配置文件配置
  9. 【数据库实验】《小型MIS的开发》— JavaFx 开发 民航票务管理系统
  10. Linux操作Oracle(3)——Oracle OPatch打补丁遇到问题详细汇总详细记录