MindSpore实现手写数字识别
具体流程参考教程:MindSpore快速入门 MindSpore 接口文档
注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用。
数据的流水线处理
defdatapipe(dataset, batch_size):'''数据处理流水线'''image_transform = [vision.Rescale(1.0 / 255.0, 0), # 缩放 output = image * rescale + shift.vision.Normalize(mean=(0.1307,), std=(0.3081,)), # 根据平均值和标准偏差对输入图像进行归一化vision.HWC2CHW() # 转换为NCHW格式]label_transform = transforms.TypeCast(mindspore.int32) # 转为mindspore的int32格式dataset = dataset.map(image_transform, 'image') # 对各个图像按照流水线处理dataset = dataset.map(label_transform, 'label') # 对各个标签转换为int32dataset = dataset.batch(batch_size)return dataset
这段代码中对输入图片进行了缩放、归一化和格式转换三个操作,按照流水线运行。
流水线操作
数据流水线处理的介绍:【AI设计模式】03-数据处理-流水线(Pipeline)模式
总结而言,海量数据下,流水线模式可以实现高效的数据处理,当然也会占用更多的CPU和内存资源。
map操作
MindSpore下dataset的map操作:第一个参数是处理函数列表,第二个参数是需要处理的列。
map函数会将数据集中第二个参数的指定的列作为输入,调用第一个参数的处理函数执行处理,如果有多个处理函数,上一个函数的输出作为下一个函数的输入。
NCHW和NHWC格式的优缺点
NCHW
缺点:必须等所有通道输入准备好才能得到最终输出结果,需要占用较大的临时空间。
优点:是 Nvidia cuDNN 默认格式,使用 GPU 加速时用 NCHW 格式速度会更快。(这个是什么原因呢?没找到资料_(:з」∠)_)
NHWC
缺点:GPU 加速较NCHW更慢
优点:访存局部性更好(每三个输入像素即可得到一个输出像素)
参考文章:【深度学习框架输入格式】NCHW还是NHWC?
为什么pytorch中transforms.ToTorch要把(H,W,C)的矩阵转为(C,H,W)?
模型
classNetwork(nn.Cell):'''Network model'''def__init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))defconstruct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logits
基类
MindSpore的模型基类是mindspore.nn.Cell
Pytorch的模型基类是torch.nn.Module
全连接层
MindSpore的全连接层是mindspore.nn.Dense
Pytorch的全连接层是torch.nn.Linear
模型连接
MindSpore的顺序容器是mindspore.nn.SequentialCell
Pytorch的顺序容器是torch.nn.Sequential
前向传播
MindSpore的前向传播函数(要执行的计算逻辑)基类函数为construct(self, xxx)
Pytorch的前向传播函数基类函数为forward(self, xxx)
损失函数和优化策略
my_loss_fn = nn.CrossEntropyLoss()
my_optimizer = nn.SGD(model.trainable_params(), 1e-2)
交叉熵:把来自一个分布q的消息使用另一个分布p的最佳代码传达方式计算得到的平均消息长度,即为交叉熵。针对交叉熵,这个文章讲的较好:损失函数:交叉熵详解
MindSpore的交叉熵函数和Pytorch类似:
前者是mindspore.nn.CrossEntropyLoss(),后者是torch.nn.CrossEntropyLoss()
训练
deftrain(model_train, dataset, loss_fn, optimizer):'''训练函数'''# Define forward functiondefforward_fn(data, label):logits = model_train(data)loss = loss_fn(logits, label)return loss, logits# Get gradient functiongrad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step trainingdeftrain_step(data, label):(loss, _), grads = grad_fn(data, label)loss = ops.depend(loss, optimizer(grads))return losssize = dataset.get_dataset_size()model_train.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
value_and_grad
官网对value_and_grad函数的介绍如下:mindspore.ops.value_and_grad
按照官网的描述,这个函数的作用是:生成求导函数,用于计算给定函数的正向计算结果和梯度。
我们需要给这个函数传入模型的正向传输函数和待求导的参数
其中模型的正向传输函数需要封装一下,返回loss的计算, 用于后续优化器的梯度计算;
待求导的参数可以写为model.trainable_params(),也可以由优化器提供(optimizer.parameters),因为优化器初始化时已经传入需要求导的参数。
总之,这个接口返回的是一个函数,函数的作用是把正向传播、反向传播的整个流程走一遍,最后的输出为正向传输函数的返回值+待求导参数的梯度值
depend算子
在训练时使用到了depend算子,官网对Depend函数的介绍如下:mindspore.ops.Depend
# Define function of one-step trainingdeftrain_step(data, label):(loss, _), grads = grad_fn(data, label)loss = ops.depend(loss, optimizer(grads))return loss
经询问分析,使用depend算子的原因是,在静态图模式下,函数执行的先后顺序可能会被优化,这就可能存在loss在grad_fn(value_and_grad)之前就被返回使用的情况,导致返回的loss不正确。
因此通过使用depend算子,来保证loss的返回动作在optimizer之后执行,而optimizer的输入依赖grad_fn,因此optimizer一定在grad_fn之后执行,这就保证了depend返回的loss确实是经过grad_fn计算的最新结果。
当然,mindspore也是支持动态图模式的,只需加一行代码:
# 设置为动态图模式
mindspore.set_context(mode=mindspore.PYNATIVE_MODE)
# 设置为静态图模式# mindspore.set_context(mode=mindspore.GRAPH_MODE)model = Network()
print(model)
然后训练函数就可以这么写:
deftrain_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return loss
但是实测,动态图模式下,训练速度相比静态图慢了很多。
关于mindspore动态图和静态图模式的介绍,可看这个官方文档:动静态图
训练尺寸
各个文章在介绍梯度下降法时,通常介绍的是批量梯度下降法,但是训练模型时用的最多的是小批量梯度下降法。这里先讲下批量梯度下降、随机梯度下降和小批量梯度下降的区别。
批量梯度下降
批量梯度下降法的流程是:假设有1000个数据,经过正向计算,得到1000个计算结果,误差函数的计算公式依赖这1000个计算结果;再对误差函数进行反向传播求导,得到模型里参数的梯度值;同样地,对误差函数求导得梯度,也依赖这1000个计算结果;最后基于学习率更新参数,然后进入下一轮训练。
因此,标准的批量梯度下降,需要每次计算出1000个数据的正向传播结果,才可以得到参数梯度值,然后下一轮训练,重新计算1000个计算结果…这就存在大量的运算量,使得训练容易变得非常耗时。
随机梯度下降
随机梯度下降法的流程是,假设有1000个数据,我们随机取1个数据,经过正向计算,得到1个计算结果,误差函数的计算公式就只依赖这1个计算结果;然后反向传播求导,得到基于1个计算结果的梯度值,最后基于学习率更新参数,然后进入下一轮训练。下一轮训练时,随机取另1个数据,重复上述操作…
这种方法下,极大地降低了计算量,而且理论上,只要数据量够大,数据足够随机,最后也总会下降到所需极值点,毕竟计算数据量小了很多,算得更快了,下降速度也会快很多。但是每次只依赖1个数据,就使得梯度的下降方向在整体方向上不稳定,容易到处飘,最后的结果可能不会是全局最优。
小批量梯度下降
小批量梯度下降法的流程是:假设有1000个数据,我们随机取100个数据,经过正向计算,得到100个计算结果,误差函数的计算公式依赖这100个计算结果;然后反向传播求导,得到基于100个计算结果的梯度值,最后基于学习率更新参数,然后进入下一轮训练。下一轮训练时,随机取另100个数据,重复上述操作…
可以看出,小批量梯度下降 结合了 批量梯度下降 和 随机梯度下降 的优缺点,使得计算即不那么耗时,又保证参数更新路径和结果相对稳定。
实例
mindspore的这个例子用的是小批量梯度下降,train_step每次输入64个数据,然后前向传播、计算梯度、更新参数,再进入下一个epoch,随机取新的64个数据,重复训练…
size = dataset.get_dataset_size()
model_train.set_train()
for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
在将数据集进行datapipe后,返回的train_dataset和test_dataset都是以batch_size=64个为一组进行输出的,此处dataset.get_dataset_size()返回的size是有多少组数据。测试集返回的size为938,表示一共有938组,每组64个图片数据。实际上MNIST只有60000个测试集图片,因此最后一组只有32个图片。
运行结果
Epoch 1
-------------------------------
loss: 2.303684 [ 0/938]
loss: 2.291476 [100/938]
loss: 2.273411 [200/938]
loss: 2.212310 [300/938]
loss: 1.969760 [400/938]
loss: 1.600426 [500/938]
loss: 1.004380 [600/938]
loss: 0.735266 [700/938]
loss: 0.672223 [800/938]
loss: 0.578563 [900/938]
Test: Accuracy: 85.3%, Avg loss: 0.528851Epoch 2
-------------------------------
loss: 0.384008 [ 0/938]
loss: 0.453575 [100/938]
loss: 0.277697 [200/938]
loss: 0.317674 [300/938]
loss: 0.294471 [400/938]
loss: 0.519272 [500/938]
loss: 0.253794 [600/938]
loss: 0.389252 [700/938]
loss: 0.383196 [800/938]
loss: 0.334877 [900/938]
Test: Accuracy: 90.2%, Avg loss: 0.334850
此处跑了两轮训练,可以看出,第一轮的938组数据的训练过程中,参数快速调整至合理范围(loss从2.3降低到0.5),但是第二轮的938组数据的训练过程中,loss出现了上下波动(0.3->0.4->0.2->0.3…),即模型参数向当前数据组的梯度下降的方向走了一小步后,新的数据组算出的loss反而比之前还提高了。
这主要是因为当前数据组的梯度下降方向 无法代表 替他数据组/所有数据的梯度下降方向,当然也可能是学习率(步长)太大导致跨过了最低点,这个就具体问题具体分析了。
总结
MindSpore和Pytorch在接口命名上存在区别,但是实际使用过程中,开发思路还是一致的。因此最关键的还是要熟悉深度学习的思路和流程,至于思路和代码实现的映射,这就唯手熟尔。
MindSpore实现手写数字识别相关推荐
- 【mindspore】mindspore实现手写数字识别
mindspore实现手写数字识别 具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用 ...
- MindSpore实现手写数字识别代码
MindSpore是华为自研的一套AI框架,最佳匹配昇腾处理器,最大程度地发挥硬件能力.作为AI入门的LeNet手写字体识别网络,网络大小和数据集都不大,可以在CPU上面进行训练和推理.下面是基于Mi ...
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
- 基于LeNet5的手写数字识别,在ModelArts和GPU上复现
基于LeNet5的手写数字识别 实验介绍 LeNet5 + MNIST被誉为深度学习领域的"Hello world".本实验主要介绍使用MindSpore在MNIST手写数字数据集 ...
- MindSpore手写数字识别体验
文章目录 1. 环境准备 2. 安装minspore及其套件 3. 程序撰写 4. 总结 今天带大家体验一下 MindSpore 这个 AI 框架来完成手写数字识别的任务 1. 环境准备 使用Anac ...
- MNIST手写数字识别 —— ResNet-经典卷积神经网络
了解ResNet18的网络结构:掌握模型的保存和加载方法:掌握批量测试图片的方法. 结合图像分类任务,使用典型的图像分类网络ResNet18,实现手写数字识别. ResNet作为经典的图像分类网络有其 ...
- MNIST手写数字识别 —— 图像分析法实现二分类
手写数字任务识别简介 MNIST 数据集来自美国国家标准与技术研究所(National Institute of Standards and Technology,简称 NIST ),总共有7万张图, ...
- 深蓝学院第三章:基于卷积神经网络(CNN)的手写数字识别实践
参看之前篇章的用全连接神经网络去做手写识别:https://blog.csdn.net/m0_37957160/article/details/114105389?spm=1001.2014.3001 ...
- 深蓝学院第二章:基于全连接神经网络(FCNN)的手写数字识别
如何用全连接神经网络去做手写识别??? 使用的是jupyter notebook这个插件进行代码演示.(首先先装一个Anaconda的虚拟环境,然后自己构建一个自己的虚拟环境,然后在虚拟环境中安装ju ...
最新文章
- oracle 判断11位数字,45个非常有用的 Oracle 查询语句小结
- python学习之路二
- tsf php,TSF:腾讯推出的 PHP 协程方案
- 正在使用.NET Framework 2.0 Beta 2的开发者要注意了!
- python企业微信回调_python 微信企业号-回调模式接收微信端客户端发送消息并被动返回消息...
- 搭建iscsi存储系统
- MyISAM 和 InnoDB 讲解
- 他毕业两年,博客一年,时间
- 微软开源故事 | 开启 .NET 开源革命
- 好消息,关于2005的default provider
- 宏块与宏块对(附图)
- 进制转换练习-其它进制转换为十进制
- Vue-Cli4笔记
- 建筑工程师的转行学计算机科学与技术的抉择
- linux pam鉴定令牌错误,linux – chsh:PAM身份验证失败
- 【网络】能远程电脑,但ping不通
- 三极管为什么可以放大电流?
- globalsign代码签名最新步骤
- Linux 之软连接
- android studio 56 下载网络歌曲 代码
热门文章
- DVD-Video 解谜 - VOB文件
- nodejs安装及环境配置
- 2014年7月份第3周51Aspx源码发布详情
- 韦根w34是多少位_Levi's裤子尺码中的W34和L34各是多少厘米?
- ​在商还得言商 | 【常垒·常识】
- 苹果付费app共享公众号_【苹果iOS付费游戏应用帐号分享】新增一款40元iOS游戏应用共享帐号...
- which must be escaped when used within the value
- java object比较排序
- 虹膜识别1.opencv3同心圆的提取
- 针对德尔塔等变异株!国产皮卡新冠疫苗开启临床实验;重庆成都互为人才外流第一目标城市 | 美通社头条...