致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609

目录

致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609

1 本次要点

1.1 Python库语法

1.2 Pytorch框架语法

2 环境

3 网络结构

4 代码结构

4.1 model.py

4.2 utils.py

4.3 train.py

4.4 test.py


1 本次要点

1.1 Python库语法

  1. PIL 和 numpy 中维度顺序:H*W*C
  2. enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。如
  3. with:上下文管理器。with 语句适用于对资源进行访问的场合,相当于try….except….finlally,确保使用过程中不管是否发生异常都会执行必要的“清理”操作,释放资源,比如文件使用后自动关闭、线程中锁的自动获取和释放等。
  4. :解决多层继承中可能出现的一些问题。使用多继承时,一般要用此函数。

1.2 Pytorch框架语法

  1. pytorch 中 tensor 维度顺序:C*H*W
  2. optimizer.zero_grad():每计算一次batch后,要将历史梯度清零,防止累加。

  3. item():得到元素张量里面的元素值。(将张量值变为可计算的值?)
  4. #不计算损失和梯度。(节省内存和计算量)

2 环境

  • win10,GPU 1060 3G
  • pytorch 1.4
  • Python 3.6

3 网络结构

4 代码结构

  • model.py
  • utils.py
  • train.py
  • test.py
  • data(存放cifar数据集:需要解压,不能更改压缩包名字)
  • 1.jpg(测试图)

4.1 model.py

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self): #初始化函数super(LeNet, self).__init__() #super解决多层继承中可能出现的一些问题。使用多继承,一般要用此函数。self.conv1 = nn.Conv2d(3, 6, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16*5*5, 120) #输入要展平成1维向量(16通道,每通道5*5特征图)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): #x代表输入的数据x = F.relu(self.conv1(x)) # input(3, 32, 32) output(6, 28, 28)x = self.pool1(x)         # output(6, 14, 14)x = F.relu(self.conv2(x)) # output(16, 10, 10)x = self.pool2(x)         # output(16, 5, 5)x = x.view(-1, 16*5*5)    # output(16*5*5 = 400)x = F.relu(self.fc1(x))   # output(120)x = F.relu(self.fc2(x))   # output(84)x = self.fc3(x)           # output(10)return x# # 测试网络输入输出维度是否写对
# import torch
# input1 = torch.rand([2, 3, 32, 32]) #B C H W
# print(input1)# model = LeNet()
# print(model)# output = model(input1)
# print(output)

4.2 utils.py

import torchvision.transforms as transformstransform_train = transforms.Compose([transforms.ToTensor(), #将数据转为tensor,维度顺序c*h*w, 值归一化[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 再对数据进行标准化]
)transform_test = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)clases_cifar10 = ('plane', 'car', 'bird', 'cat', 'deer''dog', 'frog', 'horse', 'ship', 'truck')

4.3 train.py

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transformsimport matplotlib.pyplot as plt
import numpy as np from utils import clases_cifar10, transform_train
from model import LeNet# 1 加载训练集(5万),预处理,打乱顺序并分割成一批批的batch
train_data = torchvision.datasets.CIFAR10(root='M:/CV_data/cifar-10/', train=True,download=False, transform=transform_train)
# win系统下num_work要设为0.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32,shuffle=True, num_workers=0)# 2 加载验证集(1万),预处理,打乱顺序并分割成1个batch
val_data = torchvision.datasets.CIFAR10(root='M:/CV_data/cifar-10/', train=False,download=False, transform=transform_train)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=10000,shuffle=False, num_workers=0)
#创建迭代器对象(每次调用.next(),就会自动迭代集合中下一个元素,由于val集batch就1个,所以调用一次.next()就全部取完了)
val_data_iter = iter(val_loader)
val_images, val_labels = val_data_iter.next()# 3 初始化模型,损失函数,优化器
net = LeNet()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)# 4 开始训练
for epoch in range(5):running_loss = 0.0 #累加损失for step, data in enumerate(train_loader, start=0): #遍历训练数据,并返回当前indexinputs, labels = data# 每计算一次batch, 将历史梯度清零,防止累加。optimizer.zero_grad()# forward backward optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step() # 参数更新# 打印训练过程信息running_loss += loss.item() # item()得到元素张量里面的数值if step % 500 == 499:with torch.no_grad():#不计算损失和梯度。(节省内存和计算量)outputs = net(val_images) #[batch=10000, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == val_labels).sum().item() / val_labels.size(0)print('[%d, %5d] train_loss: %.3f val_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0# 5 保存模型
print("finished training")
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

输出

4.4 test.py

import torch
from PIL import Image
from model import LeNet
from utils import clases_cifar10, transform_testnet = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))im = Image.open('1.jpg')
im = transform_test(im) # [c, h, w]
im = torch.unsqueeze(im, dim=0) # [n, c, h, w]with torch.no_grad(): # 此句可以不要,但大批量测试时,必须加此句,节省内存和计算量。outputs = net(im) # 输出值谁最大,预测的就是谁predict = torch.softmax(outputs, dim=1) # 将值转换成预测概率print(predict)max_index = torch.max(predict, dim=1)[1].data.numpy() # 返回一个1*1数组。print(clases_cifar10[int(max_index)]) # 打印对应的便签

输出

CV算法复现(分类算法1/6):LeNet5(1998年 LeCun)相关推荐

  1. NLP专栏简介:数据增强、智能标注、意图识别算法|多分类算法、文本信息抽取、多模态信息抽取、可解释性分析、性能调优、模型压缩算法等

    NLP专栏简介:数据增强.智能标注.意图识别算法|多分类算法.文本信息抽取.多模态信息抽取.可解释性分析.性能调优.模型压缩算法等 专栏链接:NLP领域知识+项目+码源+方案设计 订阅本专栏你能获得什 ...

  2. 算法杂货铺——分类算法之决策树(Decision tree)

    算法杂货铺--分类算法之决策树(Decision tree) 2010-09-19 16:30 by T2噬菌体, 88978 阅读, 29 评论, 收藏, 编辑 3.1.摘要 在前面两篇文章中,分别 ...

  3. 算法杂货铺——分类算法之贝叶斯网络(Bayesian networks)

    算法杂货铺--分类算法之贝叶斯网络(Bayesian networks) 2010-09-18 22:50 by T2噬菌体, 66011 阅读, 25 评论, 收藏, 编辑 2.1.摘要 在上一篇文 ...

  4. k近邻算法(KNN)-分类算法

    k近邻算法(KNN)-分类算法 1 概念 定义:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. k-近邻算法采用测量不同特征值之间的 ...

  5. 数据挖掘算法——常用分类算法总结

    常用分类算法总结 分类算法 NBC算法 LR算法 SVM算法 ID3算法 C4.5 算法 C5.0算法 KNN 算法 ANN 算法 分类算法 分类是在一群已经知道类别标号的样本中,训练一种分类器,让其 ...

  6. 分类算法列一下有多少种?应用场景?分类算法介绍、常见分类算法优缺点、如何选择分类算法、分类算法评估

    分类算法 分类算法介绍 概念 分类算法 常见分类算法 NBS LR SVM算法 ID3算法 C4.5 算法 C5.0算法 KNN 算法 ANN 算法 选择分类算法 分类算法性能评估 分类算法介绍 概念 ...

  7. ML算法基础——分类算法-决策树、随机森林

    文章目录 1.决策树 1.1 认识决策树 1.2 信息论基础-银行贷款分析 1.2.1 信息论基础-信息熵 1.2.2 决策树的划分依据之一-信息增益 1.3 泰坦尼克号乘客生存分类 1.3.1 sk ...

  8. ML之监督学习算法之分类算法一 ———— k-近邻算法(最邻近算法)

    一.概述 最近邻规则分类(K-Nearest Neighbor)KNN算法 由Cover 和Hart在1968年提出了最初的邻近算法, 这是一个分类(classification)算法 输入基于实例的 ...

  9. 算法杂货铺——分类算法之朴素贝叶斯分类(Naive Bayesian classification)

    FROM: http://www.cnblogs.com/leoo2sk/archive/2010/09/17/1829190.html 0.写在前面的话 我个人一直很喜欢算法一类的东西,在我看来算法 ...

  10. 常见的分类算法及分类算法的评估方法

    文章目录 贝叶斯分类法(Bayes) 决策树(Decision Tree) 支持向量机(SVM) K近邻(K-NN) 逻辑回归(Logistics Regression) 线性回归和逻辑回归的区别 神 ...

最新文章

  1. 100天59万行代码_如何抽出100天的代码时间
  2. 4-2-串的堆存储结构-串-第4章-《数据结构》课本源码-严蔚敏吴伟民版
  3. ios 获取一个枚举的所有值_Java enum枚举在实际项目中的常用方法
  4. c语言中有关void,sizeof,结构体的一些问题
  5. Backbone学习日记第二集——Model
  6. SLAM之g2o安装
  7. HDOJ 1007(T_T)
  8. 为何 Emoji 能给产品设计(营销)带来如此大的数据增长?
  9. Linux C Serial串口编程
  10. GreenOpenPaint简介
  11. 计算机考试怎么切换到桌面,考试系统很多考试系统全屏无法切换桌面,只能 – 手机爱问...
  12. android 开发书签大全
  13. UCHome源码阅读
  14. Android App 启动优化全记录
  15. JavaEye论坛热点推荐-2009年1月
  16. count在python中是什么意思_python count返回什么
  17. matlab 矩阵与数比较,MATLAB 对矩阵中的数据进行大小比较
  18. 《可转债入门十讲》笔记
  19. wxpython中表格顶角怎么设置_当wxGrid中的某个单元格以编程方式更改时,突出显示该行中的一行(使用wxPython)...
  20. imu matlab,IMU姿态解算matlab

热门文章

  1. Python数据挖掘:数据转换-数据规范化
  2. Kotlin implements 的实现
  3. android Style(样式)的解析
  4. 吴裕雄--天生自然 高等数学学习:无穷级数
  5. 6.微信小程序的如何使用全局属性
  6. hadoop2.4.1集群搭建
  7. mac 使用nvm安装node
  8. 持续集成之戏说Check-in Dance
  9. 学习Modern UI for WPF
  10. ASP.NET文件的下载