CV算法复现(分类算法1/6):LeNet5(1998年 LeCun)
致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609
目录
致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609
1 本次要点
1.1 Python库语法
1.2 Pytorch框架语法
2 环境
3 网络结构
4 代码结构
4.1 model.py
4.2 utils.py
4.3 train.py
4.4 test.py
1 本次要点
1.1 Python库语法
- PIL 和 numpy 中维度顺序:H*W*C
- enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。如
- with:上下文管理器。with 语句适用于对资源进行访问的场合,相当于try….except….finlally,确保使用过程中不管是否发生异常都会执行必要的“清理”操作,释放资源,比如文件使用后自动关闭、线程中锁的自动获取和释放等。
:解决多层继承中可能出现的一些问题。使用多继承时,一般要用此函数。
1.2 Pytorch框架语法
- pytorch 中 tensor 维度顺序:C*H*W
optimizer.zero_grad():每计算一次batch后,要将历史梯度清零,防止累加。
- item():得到元素张量里面的元素值。(将张量值变为可计算的值?)
#不计算损失和梯度。(节省内存和计算量)
2 环境
- win10,GPU 1060 3G
- pytorch 1.4
- Python 3.6
3 网络结构
4 代码结构
- model.py
- utils.py
- train.py
- test.py
- data(存放cifar数据集:需要解压,不能更改压缩包名字)
- 1.jpg(测试图)
4.1 model.py
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self): #初始化函数super(LeNet, self).__init__() #super解决多层继承中可能出现的一些问题。使用多继承,一般要用此函数。self.conv1 = nn.Conv2d(3, 6, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16*5*5, 120) #输入要展平成1维向量(16通道,每通道5*5特征图)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): #x代表输入的数据x = F.relu(self.conv1(x)) # input(3, 32, 32) output(6, 28, 28)x = self.pool1(x) # output(6, 14, 14)x = F.relu(self.conv2(x)) # output(16, 10, 10)x = self.pool2(x) # output(16, 5, 5)x = x.view(-1, 16*5*5) # output(16*5*5 = 400)x = F.relu(self.fc1(x)) # output(120)x = F.relu(self.fc2(x)) # output(84)x = self.fc3(x) # output(10)return x# # 测试网络输入输出维度是否写对
# import torch
# input1 = torch.rand([2, 3, 32, 32]) #B C H W
# print(input1)# model = LeNet()
# print(model)# output = model(input1)
# print(output)
4.2 utils.py
import torchvision.transforms as transformstransform_train = transforms.Compose([transforms.ToTensor(), #将数据转为tensor,维度顺序c*h*w, 值归一化[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 再对数据进行标准化]
)transform_test = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)clases_cifar10 = ('plane', 'car', 'bird', 'cat', 'deer''dog', 'frog', 'horse', 'ship', 'truck')
4.3 train.py
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transformsimport matplotlib.pyplot as plt
import numpy as np from utils import clases_cifar10, transform_train
from model import LeNet# 1 加载训练集(5万),预处理,打乱顺序并分割成一批批的batch
train_data = torchvision.datasets.CIFAR10(root='M:/CV_data/cifar-10/', train=True,download=False, transform=transform_train)
# win系统下num_work要设为0.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32,shuffle=True, num_workers=0)# 2 加载验证集(1万),预处理,打乱顺序并分割成1个batch
val_data = torchvision.datasets.CIFAR10(root='M:/CV_data/cifar-10/', train=False,download=False, transform=transform_train)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=10000,shuffle=False, num_workers=0)
#创建迭代器对象(每次调用.next(),就会自动迭代集合中下一个元素,由于val集batch就1个,所以调用一次.next()就全部取完了)
val_data_iter = iter(val_loader)
val_images, val_labels = val_data_iter.next()# 3 初始化模型,损失函数,优化器
net = LeNet()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)# 4 开始训练
for epoch in range(5):running_loss = 0.0 #累加损失for step, data in enumerate(train_loader, start=0): #遍历训练数据,并返回当前indexinputs, labels = data# 每计算一次batch, 将历史梯度清零,防止累加。optimizer.zero_grad()# forward backward optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step() # 参数更新# 打印训练过程信息running_loss += loss.item() # item()得到元素张量里面的数值if step % 500 == 499:with torch.no_grad():#不计算损失和梯度。(节省内存和计算量)outputs = net(val_images) #[batch=10000, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == val_labels).sum().item() / val_labels.size(0)print('[%d, %5d] train_loss: %.3f val_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0# 5 保存模型
print("finished training")
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)
输出
4.4 test.py
import torch
from PIL import Image
from model import LeNet
from utils import clases_cifar10, transform_testnet = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))im = Image.open('1.jpg')
im = transform_test(im) # [c, h, w]
im = torch.unsqueeze(im, dim=0) # [n, c, h, w]with torch.no_grad(): # 此句可以不要,但大批量测试时,必须加此句,节省内存和计算量。outputs = net(im) # 输出值谁最大,预测的就是谁predict = torch.softmax(outputs, dim=1) # 将值转换成预测概率print(predict)max_index = torch.max(predict, dim=1)[1].data.numpy() # 返回一个1*1数组。print(clases_cifar10[int(max_index)]) # 打印对应的便签
输出
CV算法复现(分类算法1/6):LeNet5(1998年 LeCun)相关推荐
- NLP专栏简介:数据增强、智能标注、意图识别算法|多分类算法、文本信息抽取、多模态信息抽取、可解释性分析、性能调优、模型压缩算法等
NLP专栏简介:数据增强.智能标注.意图识别算法|多分类算法.文本信息抽取.多模态信息抽取.可解释性分析.性能调优.模型压缩算法等 专栏链接:NLP领域知识+项目+码源+方案设计 订阅本专栏你能获得什 ...
- 算法杂货铺——分类算法之决策树(Decision tree)
算法杂货铺--分类算法之决策树(Decision tree) 2010-09-19 16:30 by T2噬菌体, 88978 阅读, 29 评论, 收藏, 编辑 3.1.摘要 在前面两篇文章中,分别 ...
- 算法杂货铺——分类算法之贝叶斯网络(Bayesian networks)
算法杂货铺--分类算法之贝叶斯网络(Bayesian networks) 2010-09-18 22:50 by T2噬菌体, 66011 阅读, 25 评论, 收藏, 编辑 2.1.摘要 在上一篇文 ...
- k近邻算法(KNN)-分类算法
k近邻算法(KNN)-分类算法 1 概念 定义:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. k-近邻算法采用测量不同特征值之间的 ...
- 数据挖掘算法——常用分类算法总结
常用分类算法总结 分类算法 NBC算法 LR算法 SVM算法 ID3算法 C4.5 算法 C5.0算法 KNN 算法 ANN 算法 分类算法 分类是在一群已经知道类别标号的样本中,训练一种分类器,让其 ...
- 分类算法列一下有多少种?应用场景?分类算法介绍、常见分类算法优缺点、如何选择分类算法、分类算法评估
分类算法 分类算法介绍 概念 分类算法 常见分类算法 NBS LR SVM算法 ID3算法 C4.5 算法 C5.0算法 KNN 算法 ANN 算法 选择分类算法 分类算法性能评估 分类算法介绍 概念 ...
- ML算法基础——分类算法-决策树、随机森林
文章目录 1.决策树 1.1 认识决策树 1.2 信息论基础-银行贷款分析 1.2.1 信息论基础-信息熵 1.2.2 决策树的划分依据之一-信息增益 1.3 泰坦尼克号乘客生存分类 1.3.1 sk ...
- ML之监督学习算法之分类算法一 ———— k-近邻算法(最邻近算法)
一.概述 最近邻规则分类(K-Nearest Neighbor)KNN算法 由Cover 和Hart在1968年提出了最初的邻近算法, 这是一个分类(classification)算法 输入基于实例的 ...
- 算法杂货铺——分类算法之朴素贝叶斯分类(Naive Bayesian classification)
FROM: http://www.cnblogs.com/leoo2sk/archive/2010/09/17/1829190.html 0.写在前面的话 我个人一直很喜欢算法一类的东西,在我看来算法 ...
- 常见的分类算法及分类算法的评估方法
文章目录 贝叶斯分类法(Bayes) 决策树(Decision Tree) 支持向量机(SVM) K近邻(K-NN) 逻辑回归(Logistics Regression) 线性回归和逻辑回归的区别 神 ...
最新文章
- 100天59万行代码_如何抽出100天的代码时间
- 4-2-串的堆存储结构-串-第4章-《数据结构》课本源码-严蔚敏吴伟民版
- ios 获取一个枚举的所有值_Java enum枚举在实际项目中的常用方法
- c语言中有关void,sizeof,结构体的一些问题
- Backbone学习日记第二集——Model
- SLAM之g2o安装
- HDOJ 1007(T_T)
- 为何 Emoji 能给产品设计(营销)带来如此大的数据增长?
- Linux C Serial串口编程
- GreenOpenPaint简介
- 计算机考试怎么切换到桌面,考试系统很多考试系统全屏无法切换桌面,只能 – 手机爱问...
- android 开发书签大全
- UCHome源码阅读
- Android App 启动优化全记录
- JavaEye论坛热点推荐-2009年1月
- count在python中是什么意思_python count返回什么
- matlab 矩阵与数比较,MATLAB 对矩阵中的数据进行大小比较
- 《可转债入门十讲》笔记
- wxpython中表格顶角怎么设置_当wxGrid中的某个单元格以编程方式更改时,突出显示该行中的一行(使用wxPython)...
- imu matlab,IMU姿态解算matlab