动手学深度学习(PyTorch实现)(七)--LeNet模型
LeNet模型
- 1. LeNet模型
- 2. PyTorch实现
- 2.1 模型实现
- 2.2 获取数据与训练
1. LeNet模型
LeNet分为卷积层块和全连接层块两个部分。下面我们分别介绍这两个模块。
卷积层块里的基本单位是卷积层后接平均池化层:卷积层用来识别图像里的空间模式,如线条和物体局部,之后的平均池化层则用来降低卷积层对位置的敏感性。
卷积层块由两个这样的基本单位重复堆叠构成。在卷积层块中,每个卷积层都使用5×55 \times 55×5的窗口,并在输出上使用sigmoid激活函数。第一个卷积层输出通道数为6,第二个卷积层输出通道数则增加到16。
全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。
2. PyTorch实现
2.1 模型实现
面我们通过Sequential类来实现LeNet模型。
# 导入相应的包
import sys
sys.path.append("/home/kesci/input")
import d2lzh1981 as d2l
import torch
import torch.nn as nn
import torch.optim as optim
import time
# 展平操作,更改维度
class Flatten(torch.nn.Module): def forward(self, x):return x.view(x.shape[0], -1)
# 将图像大小重定型
class Reshape(torch.nn.Module): def forward(self, x):return x.view(-1,1,28,28) #(B x C x H x W)# LeNet的实现
net = torch.nn.Sequential( # 重新定型图像大小 Reshape(),# 第一层卷积层,输入通道数1,输出通道数6,卷积核尺寸5*5,填充为2nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2), #b*1*28*28 =>b*6*28*28# 经过sigmoid激活函数nn.Sigmoid(), # 平均池化层,核尺寸为2*2,步幅为2 nn.AvgPool2d(kernel_size=2, stride=2), #b*6*28*28 =>b*6*14*14# 第二层卷积层,输入通道为6,输出通道为16,卷积核为5*5nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), #b*6*14*14 =>b*16*10*10# 经过sigmoid激活函数nn.Sigmoid(),# 平均池化层,核尺寸为2*2,步幅为2 nn.AvgPool2d(kernel_size=2, stride=2), #b*16*10*10 => b*16*5*5# 展平操作Flatten(), #b*16*5*5 => b*400# 第一层全连接隐藏层,输入维度为16*5*5,输出维度为120nn.Linear(in_features=16*5*5, out_features=120),# 经过sigmoid激活函数nn.Sigmoid(),# 第二层全连接隐藏层,输入维度为120,输出维度为84nn.Linear(120, 84),# 经过sigmoid激活函数nn.Sigmoid(),# 第三层全连接输出层,输入维度为84,输出维度为10nn.Linear(84, 10)
)
在LeNet中,在卷积层块中输入的高和宽在逐层减小。卷积层由于使用高和宽均为5的卷积核,从而将高和宽分别减小4,而池化层则将高和宽减半,但通道数则从1增加到16。全连接层则逐层减少输出个数,直到变成图像的类别数10。
2.2 获取数据与训练
下面我们来实现LeNet模型。我们仍然使用Fashion-MNIST作为训练数据集。
# 数据批量数为256
batch_size = 256
# 获取训练数据与测试数据
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size, root='/home/kesci/input/FashionMNIST2065')
选择GPU进行训练,如果没有GPU,仍然采用CPU进行训练
def try_gpu():"""If GPU is available, return torch.device as cuda:0; else return torch.device as cpu."""if torch.cuda.is_available():device = torch.device('cuda:0')else:device = torch.device('cpu')return devicedevice = try_gpu()
我们实现evaluate_accuracy
函数,该函数用于计算模型net
在数据集data_iter
上的准确率。
#计算准确率
'''
(1). net.train()启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为True
(2). net.eval()
不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False
'''def evaluate_accuracy(data_iter, net,device=torch.device('cpu')):"""Evaluate accuracy of a model on the given data set."""acc_sum,n = torch.tensor([0],dtype=torch.float32,device=device),0for X,y in data_iter:# If device is the GPU, copy the data to the GPU.X,y = X.to(device),y.to(device)net.eval()with torch.no_grad():y = y.long()acc_sum += torch.sum((torch.argmax(net(X), dim=1) == y)) #[[0.2 ,0.4 ,0.5 ,0.6 ,0.8] ,[ 0.1,0.2 ,0.4 ,0.3 ,0.1]] => [ 4 , 2 ]n += y.shape[0]return acc_sum.item()/n
我们定义函数train_ch5
,用于训练模型。
'''
参数含义:
net: 要训练的网络
train_iter: 训练集
test_iter: 测试集
criterion: 损失函数
num_epochs: 训练周期
batch_size: 训练的小批量样本数
device: 训练使用的装置CPU或者GPU
lr: 学习率
'''
def train_ch5(net, train_iter, test_iter,criterion, num_epochs, batch_size, device,lr=None):"""Train and evaluate a model with CPU or GPU."""print('training on', device)net.to(device)# 随机梯度下降为优化函数optimizer = optim.SGD(net.parameters(), lr=lr)for epoch in range(num_epochs):# 初始化各种变量train_l_sum = torch.tensor([0.0],dtype=torch.float32,device=device)train_acc_sum = torch.tensor([0.0],dtype=torch.float32,device=device)n, start = 0, time.time()# 开始训练for X, y in train_iter:net.train()# 梯度参数清零optimizer.zero_grad()X,y = X.to(device),y.to(device) # y_hat为网络的输出值y_hat = net(X)# 计算损失loss = criterion(y_hat, y)# 反向传播loss.backward()# 更新参数optimizer.step()with torch.no_grad():# 转化为long型y = y.long()# 计算损失的和train_l_sum += loss.float()# 计算预测正确的个数train_acc_sum += (torch.sum((torch.argmax(y_hat, dim=1) == y))).float()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net,device)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, ''time %.1f sec'% (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc,time.time() - start))
我们重新将模型参数初始化到对应的设备device
(cpu
or cuda:0
)之上,并使用Xavier随机初始化。损失函数和训练算法则依然使用交叉熵损失函数和小批量随机梯度下降。
# 训练
lr, num_epochs = 0.9, 10def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:torch.nn.init.xavier_uniform_(m.weight)net.apply(init_weights)
net = net.to(device)
#交叉熵描述了两个概率分布之间的距离,交叉熵越小说明两者之间越接近
criterion = nn.CrossEntropyLoss()
train_ch5(net, train_iter, test_iter, criterion,num_epochs, batch_size,device, lr)
输出结果为:
对训练好的网络进行测试:
# test
for testdata,testlabe in test_iter:testdata,testlabe = testdata.to(device),testlabe.to(device)break
print(testdata.shape,testlabe.shape)
net.eval()
y_pre = net(testdata)
print(torch.argmax(y_pre,dim=1)[:10])
print(testlabe[:10])
测试结果为:
动手学深度学习(PyTorch实现)(七)--LeNet模型相关推荐
- 动手学深度学习(PyTorch实现)(八)--AlexNet模型
AlexNet模型 1. AlexNet模型介绍 1.1 AlexNet的特点 1.2 AlexNet的结构 1.3 AlexNet参数数量 2. AlexNet的PyTorch实现 2.1 导入相应 ...
- 动手学深度学习(PyTorch实现)(十一)--GoogLeNet模型
GoogLeNet模型 1. GoogLeNet介绍 1.1 背景 1.2 GoogLeNet网络结构 2. PyTorch实现 2.1 导入相应的包 2.2 定义Inception块结构 2.3 定 ...
- 动手学深度学习(PyTorch实现)(十三)--ResNet模型
ResNet模型 1. ResNet介绍 2. ResNet结构 3. ResNet的PyTorch实现 3.1 导入所需要的包 3.2 构建ResNet网络 3.3 开始训练 注:本文部分内容参考博 ...
- 动手学深度学习(PyTorch实现)(十)--NiN模型
NiN模型 1. NiN模型介绍 1.1 NiN模型结构 1.2 NiN结构与VGG结构的对比 2. PyTorch实现 2.1 导入相应的包 2.2 定义NiN block 2.3 全局最大池化层 ...
- 动手学深度学习(PyTorch实现)(九)--VGGNet模型
VGGNet模型 1. VGGNet模型介绍 1.1 VGGNet的结构 1.2 VGGNet结构举例 2. VGGNet的PyTorch实现 2.1 导入相应的包 2.2 基本网络单元block 2 ...
- 【动手学深度学习PyTorch版】6 权重衰退
上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...
- 伯禹公益AI《动手学深度学习PyTorch版》Task 05 学习笔记
伯禹公益AI<动手学深度学习PyTorch版>Task 05 学习笔记 Task 05:卷积神经网络基础:LeNet:卷积神经网络进阶 微信昵称:WarmIce 昨天打了一天的<大革 ...
- 【动手学深度学习PyTorch版】19 网络中的网络 NiN
上一篇请移步[动手学深度学习PyTorch版]18 使用块的网络 VGG_水w的博客-CSDN博客 目录 一.网络中的网络 NiN 1.1 NiN ◼ 全连接层的问题 ◼ 大量的参数会带来很多问题 ◼ ...
- 动手学深度学习Pytorch Task01
深度学习目前以及未来都有良好的发展前景.正值疫情期间,报名参加了动手学深度学习pytorch版的公开课,希望在以后的学习生活中能够灵活运用学到的这些知识. 第一次课主要包含三个部分:线性回归.soft ...
最新文章
- mysql中concat函数的使用相关总结
- Binary Tree Maximum Path Sum
- 破译手势在对话中的意义
- cinder配置多ceph储存池[Ceph and Cinder multi-backend]
- boost::python::upcast的测试程序
- ITK:两幅图像之差的绝对值
- 架构模式_Index
- Xcode代码提示联想功能失效,按command键点不进去类库,提示“?”
- LeetCode 217. 存在重复元素(哈希)
- oppo 手机侧滑快捷菜单_关于oppo手机菜单键调出的方法,原来是这样的
- 【转】雷军自曝创业第一年:掏自己的钱创业成功率最高
- OverIQ 中文系列教程【翻译完成】
- IIS7下 【请求被中止: 未能创建 SSL/TLS 安全通道 】 解决方法
- 【读fastclick源码有感】彻底解决tap“点透”,提升移动端点击响应速度
- 项目功能介绍 非常有用
- sql中exist()的用法
- Windows中使用http-server搭建一个本地服务
- aac蓝牙编解码协议_蓝牙协议总结
- 天若OCR专业版软件,现可无需联网本地使用了~
- rp文件,怎么用浏览器预览