前言

本学习笔记参考自B站up主霹雳吧啦Wz

代码均来自其github开源项目WZMIAOMIAO/deep-learning-for-image-processing: deep learning for image processing including classification and object-detection etc. (github.com)

视频链接在这里:2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili


LeNet——训练和预测篇

2、trian

老样子,先给出代码

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = val_data_iter.next()# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)for epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

下面来开始讲一下写一个训练模块的步骤

首先我们需要导入数据集

这里用到的是pytorch官网提供的[CIFAR10数据集](CIFAR-10 and CIFAR-100 datasets (toronto.edu))

该数据集包含60000张32*32的图片,10个分类,每个分类6000张,其中50000张训练集,10000张测试集

导包我们就不讲了,在上面的代码自己复制就好

在下载数据集之前我们先设置一下标准化的代码

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

关于为什么要使用到标准化,可以看一下这个博主的解释(2条消息) 数据标准化:减去均值,除以方差的理解_仙女修炼史的博客-CSDN博客_减均值除方差的意义

其中transforms.ToTensor()是讲图片的格式改为[C, H, W],前面的文章我们提到pytorch支持的张量通道为[batch, C, H, W],而一般图片的通道为[H, W, C]。Normalize函数实现了数据的标准化,再由Compose函数将这两个步骤组合起来。

这里使用到pytorch提供的下载方式,注意trian=True才会下载训练集,否则就是测试集

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

接着我们把数据集进行一下划分,先取batch_size=4拿四张图片出来看一下数据集的内容

train_loader = torch.utils.data.DataLoader(train_set, batch_size=4,shuffle=True, num_workers=0)

这里用到一个显示图片的函数,就不详细介绍了,官网有

import matplotlib.pyplot as plt
import numpy as npclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.next()# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

可以看到拿出来的四张图片和对应的标签,由于像素是32*32的,所以糊

接着我们把前面的写在一个主函数里面,方便与后面的函数衔接调用,顺便把测试集也写上,其中用到了一个iter函数把测试集的图片和标签分开,方便后面测试精度的时候用

def main():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)var_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)var_loader = torch.utils.data.DataLoader(var_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(var_loader)val_image, val_label = val_data_iter.next()

数据集我们就弄好了

接下来开始训练阶段

首先把之前写的model实例化,确定损失函数,使用Adam优化器进行参数优化

 net = LeNet()loss_function = nn.CrossEntropyLoss()        # 这里就解释了为什么之前网络层结构搭建时最后没有用交叉熵函数,因为                                                     nn.CorssEntropyLoss()这个函数里面已经内置了交叉熵函数optimizer = optim.Adam(net.parameters(), lr=0.001)

net.parameters为网络层结构中所有可以优化的参数,同时设置学习率为0.001

接着我们设置5次训练次数,同时赋值一个loss来记录误差

 for epoch in range(5):running_loss = 0.0

我们对每个批次的训练集进行训练,同时对每个批次的数据进行一个编号,enumerate可以在拿到批次的同时在前面建立序号,注意这里用一个start参数让step从1开始

     for step, data in enumerate(train_loader, start=1):#这里data是一个[img, laber]的listinputs, lables = data

每一次开始的时候将梯度进行清零,防止梯度爆炸

         optimizer.zero_grad()

然后将我们的数据集进行一个输入,和标签计算误差

         outputs = net(inputs)loss = loss_function(outputs, labels)

误差反向转播,参数更新

         loss.backward()optimizer.step()

这样,一个完整的训练过程就完成了!

然后把每批次训练的误差进行一个累加,因为loss是一个张量,所以这里用到item这个函数拿到里面的元素

         running_loss += loss.item()

我们每500批次训练进行一次精度检验,这里用到了val_set,所以前面说的测试集其实不准确,应该是验证集(在训练过程中对精度进行检验)

因为是验证,所以不需要用到梯度更新,如果不禁用梯度更新,torch会自动计算占用大量内存

         if step % 500 == 0:with torch.no_grad():outputs = net(val_image)

可以看到outputs的形状,5000为我们设置的验证集的batch_size,10是一个输入对应的10个类别的概率,很显然,我们要的是后面这个数据

                 predict_y = torch.max(outputs, dim=1)[1]

这里用到一个max函数取出10个概率的最大值,就是我们预测的值,后面的【1】意思是我们取得是max函数返回值的序号,用序号来获取最后对应的标签。max函数具体的用法大家可以看这个博客torch.max()使用讲解 - 简书 (jianshu.com)

                 accuracy = torch.eq(predict_y, val_label).sum().item()/val_label.size(0)

这里通过torch.eq对预测值和标签进行比较,然后总和正确的标签数量,因为这里求和后还是tenser格式,所以我们要通过一个item()函数拿到里面的数字,然后除以总标签数就是精度

                 print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step, running_loss / 500, accuracy))running_loss = 0.0

最后进行一个参数保存

    print('Finished Training')save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)

普天同庆!训练模块就搭建完成了!!


3、预测

代码

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))im = Image.open('Plan.jpg')im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()

这里就不详讲了,跟训练大差不差


然后就是希望下一次的AlexNet这个星期能做出来,干巴爹噢!

LeNet——训练和预测篇相关推荐

  1. Pythonnbsp;实现LeNet网络模型的训练及预测

    更多python教程请到友情连接: 菜鸟教程https://www.piaodoo.com 初中毕业读什么技校 http://cntkd.net 茂名一技http://www.enechn.com p ...

  2. 基于mlr3工具包的机器学习(1)——数据、模型、训练、预测

    专注系列化.高质量的R语言教程 (查看推文索引) mlr3是一个关于机器学习的工具包,关于它的详细介绍可参见: 网页版:https://mlr3book.mlr-org.com/intro.html ...

  3. pytorch建立mobilenetV3-ssd网络并进行训练与预测

    pytorch建立mobilenetV3-ssd网络并进行训练与预测 前言 Step1:搭建mobilenetV3-ssd网络框架 需要提前准备的函数和类. mobilenetV3_large 调用m ...

  4. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

  5. 三两下实现NLP训练和预测,这四个框架你要知道

    作者 | 狄东林 刘元兴 朱庆福 胡景雯 编辑 | 刘元兴,崔一鸣 来源 | 哈工大SCIR(ID:HIT_SCIR) 引言 随着人工智能的发展,越来越多深度学习框架如雨后春笋般涌现,例如PyTorc ...

  6. Spark技术在京东智能供应链预测的应用——按照业务进行划分,然后利用scikit learn进行单机训练并预测...

    3.3 Spark在预测核心层的应用 我们使用Spark SQL和Spark RDD相结合的方式来编写程序,对于一般的数据处理,我们使用Spark的方式与其他无异,但是对于模型训练.预测这些需要调用算 ...

  7. LeNet训练MNIST

    jupyter notebook: https://github.com/Penn000/NN/blob/master/notebook/LeNet/LeNet.ipynb LeNet训练MNIST ...

  8. 特征训练、预测一致性管理工具:开源项目Feast

    在机器学习的流程大体可以分成模型训练和模型服务两个阶段.无论是训练和服务阶段,其实都需要进行特征工程相关的工作,这块的技术挑战就是如何保证训练和预测过程中使用的特征是一致的.这个问题困扰了很多机器学习 ...

  9. MAT之GUI:GUI的方式创建/训练/仿真/预测神经网络

    MAT之GUI:GUI的方式创建/训练/仿真/预测神经网络 目录 操作步骤 (0).打开 (1).导入数据 (2)创建模型network_Jason_niu (3)设置参数并训练 (4)仿真预测 操作 ...

最新文章

  1. 如何在Linux中使用Shell脚本终止用户会话?
  2. CPU,MPU,MCU,SOC,SOPC联系与差别
  3. mysql更新多条数据6_mysql语句:批量更新多条记录的不同值
  4. 力扣——字符串转换整数 (atoi)
  5. jQuery 异步上传插件 Uploadify 使用 (Java平台)
  6. argparse模块
  7. 第六届中国电子信息博览会今日正式开幕,智享新时代!
  8. HTML 限制文本框只能输入特定字符(比如数字 onkeyup+onafterpaste)
  9. Android仿自如客APP裸眼3D效果
  10. 分子动力学模拟学习3-Gromacs数据处理
  11. 文件及文件夹管理规范
  12. docker-compose 部署shipyard
  13. 微积分 —— 有限覆盖定理
  14. FPGA学习心得分享——交通灯(EGO1)
  15. 外卖匹配系统_外卖平台派单规则浅析
  16. matlab语音识别系统(源代码),matlab语音识别系统(源代码)最新版DOC.doc
  17. 释放数据生产力 我们该如何思考、如何行动?
  18. root权限获取排行榜,root权限软件排行榜
  19. Java计算机毕业设计小区物业管理系统
  20. 很漂亮的字体闪烁效果

热门文章

  1. python 之 第一次亲密接触
  2. 乘风破浪的5G,与隐藏在深海的EMC暗礁
  3. 图像局部特征(七)--SURF原理总结
  4. 阿里负责人揭秘面试潜规则【转】
  5. Java学习笔记 | 尚硅谷项目三详解
  6. c语言编程解释,c语言编程,请高手一字一句解释
  7. 一亩三分地新手上路答案
  8. 地理信息系统复习摘要
  9. java基础第二十五天 数据库
  10. 将扩散模型用于目标检测任务,从随机框中直接检测!