首先我们要先了解深度学习的概念和AI计算框架的角色(https://zhuanlan.zhihu.com/p/463019160),本篇文章将演示怎么利用MindSpore来训练一个AI模型。和上一章的场景一致,我们要训练的模型是用来对手写数字图片进行分类的LeNet5模型

请参考(http://yann.lecun.com/exdb/lenet/)。

图1 MindSpore使用流程

安装MindSpore

MindSpore提供给用户使用的是Python接口(什么是Python,请参考:

https://zhuanlan.zhihu.com/p/462756985),所以我们首先需要安装MindSpore的whl包,安装之后就可以导入(import)MindSpore提供的方法接口了。安装whl包有两种方式:

方式一:进入MindSpore官网,根据自己的设备和Python版本选择安装命令。比如我的Python版本是3.7.5,我的设备是笔记本(CPU),那么我就复制下图红框中的命令进行安装:

图2 MindSpore安装界面

安装过程如下:

图3 MindSpore安装过程

注意:由于MindSpore还依赖于其他的Python三方库,所以在安装过程中,系统还会自动下载、安装其他的Python三方库,如numpy、pillow、scipy等等,安装结束后,如果能 import mindspore 成功,说明MindSpore安装成功了:

图4 MindSpore安装成功

方式二:可以在版本列表中找到对应的whl包,点击就能下载:

图5 MindSpore版本下载列表

下载完成后,把whl包放到自己的目录下,执行 pip install xxx.whl:

图6 MindSpore第二种安装方式

定义模型

安装好MindSpore之后,我们就可以导入MindSpore提供的算子(卷积、全连接、池化等函数:https://zhuanlan.zhihu.com/p/463019160)来构建我们的模型了。可以这么比喻:我们构建一个AI模型就像建一个房子,而MindSpore提供给我们的算子就像是砖块、窗户、地板等基本组件。

图7 定义LeNet5模型

如上图所示,我们用到的“砖块”都是mindspore.nn模块提供的。注意:这里用到了Python的类(class),由②和③两部分组成。我们这里定义的类是class LeNet5,它由初始化函数 __init__(self) 和构造函数construct(self, x)组成。初始化函数定义了我们构造模型所需要用到的算子,比如conv算子、relu算子、flatten算子等等,这些算子都是从mindspore.nn获取的;构造函数就是把我们在初始化函数中导入的算子按顺序排放,构成我们最终的模型。construct()函数的输入就是我们这个模型预测的对象,比如第一章讲的黑白图片像素矩阵;而“return y”中的就是预测的结果,对应于第一章讲到的10分类手写数字数据集,就是一个行10列的数组(这里的是指输入图片的数量,AI模型支持多张图片同时推理)。

导入训练数据集

什么是训练数据集?刚刚定义好的模型是不能对图片进行正确分类的,我们要通过“训练”过程来调整模型的参数矩阵的值。训练过程就需要用到训练样本,也就是打上了正确标签的图片。这就好比我们教小孩儿认识动物,需要拿几张图片给他们看,然后告诉他们这是什么、那是什么,教了几遍之后,小孩儿就能认识了。那么我们训练LeNet5模型就需要用到MNIST数据集,请参考(http://yann.lecun.com/exdb/mnist/)。这个数据集由两部分组成:训练集(6万张图片)和测试集(1万张图片),都是0~9的黑白手写数字图片。训练集是用来训练AI模型的,测试集是用来测试训练后的模型分类准确率的。

下载得到的数据集最初是压缩文件,还不能直接传给MindSpore的训练接口使用,我们要先用MindSpore提供的数据处理接口把他们读进来:

import mindspore.dataset as ds
mnist_ds = ds.MnistDataset(data_path)  # 导入下载的MNIST数据集

然后进行数据增强(比如把图片大小转化成相同的尺寸、像素值标准化、归一化等操作),提升训练效率:

import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype# 定义数据增强函数
def create_dataset(data_path, batch_size=32):  # batch_size是每一步训练使用的图片数量,一般取32"""create dataset for train or testArgs:data_path (str): Data pathbatch_size (int): The number of data records in each group"""# define datasetmnist_ds = ds.MnistDataset(data_path)  # 导入下载的MNIST数据集# define some parameters needed for data enhancement and rough justificationresize_height, resize_width = 32, 32rescale = 1.0 / 255.0shift = 0.0rescale_nml = 1 / 0.3081shift_nml = -1 * 0.1307 / 0.3081# according to the parameters, generate the corresponding data enhancement methodresize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)rescale_op = CV.Rescale(rescale, shift)hwc2chw_op = CV.HWC2CHW()type_cast_op = C.TypeCast(mstype.int32)# using map to apply operations to a datasetmnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label")mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image")mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image")mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image")mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image")# process the generated datasetbuffer_size = 10000mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)return mnist_ds

训练模型

训练数据集和模型定义完成之后呢,我们就可以开始训练模型了。但是在训练之前,我们还需要从MindSpore导入两个函数:

  • 损失函数,也就是衡量预测结果和真实标签之间的差距的函数。看过上一章的同学可能会记得,我们之前用的损失函数是真实值与预测值之差的2-范数:

图8 2-范数损失

在这里,我们使用业界最常用的交叉熵损失函数SoftmaxCrossEntropyWithLogits,对于真实标签

和预测值,它们之间的交叉熵损失计算公式为:

其中J代表数组的下标,。从MindSpore导入损失函数:

from mindspore.nn import SoftmaxCrossEntropyWithLogits
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') 
  • 优化器,优化器就是用来求解损失函数关于模型参数的更新梯度的,它是整个训练过程中最重要的工具!我们这里用MindSpore提供的Momentum优化器:

import mindspore.nn as nnlr = 0.01  # 定义学习率
momentum = 0.9  # 定义Momentum优化器的超参
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)  # 导入mindspore提供

准备好损失函数和优化器之后我们就可以开始训练模型了,也非常简单,我们先把前面定义好的模型、损失函数、优化器封装成一个Model:

from mindspore import Model
net = LeNet5()
model = Model(net, net_loss , net_opt , metrics={'acc', 'loss'})

然后使用model.train接口就可以训练我们定义的LeNet5模型了:

loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())  # 用于监控训练过程中损失函数值的变化
ds_train = create_dataset(train_data_dir)  # 传入下载的训练集的路径
model.train(num_epochs, ds_train, callbacks=[loss_cb])  # num_epochs是训练的轮数,往往训练多轮才能使模型收敛

测试训练后的模型准确率

训练结束后,调用model.eval()计算训练后的模型在测试集上面的分类准确率:

ds_eval = create_dataset(test_data_dir)  # 传入下载的训练集的路径
metrics = model.eval(ds_eval)

小结

祝贺你耐心看完了MindSpore训练模型的完整过程,如果你想动手操作一遍,但是又没有现成的环境,那么你可以使用官网提供的“在线运行”来体验一番:

图9 MindSpore官网提供的免费体验入口

这是体验过程的实操视频:

https://zhuanlan.zhihu.com/p/463229660

欢迎投稿

欢迎大家踊跃投稿,有想投稿技术干货、项目经验等分享的同学,可以添加MindSpore官方小助手:小猫子(mindspore0328)的微信,告诉猫哥哦!

昇思MindSpore官方交流QQ群 : 486831414群里有很多技术大咖助力答疑!

MindSpore官方资料

GitHub : https://github.com/mindspore-ai/mindspore

Gitee : https : //gitee.com/mindspore/mindspore

官方QQ群 : 486831

手把手教你用MindSpore训练一个AI模型!相关推荐

  1. 手把手教你用自己训练的AI模型玩王者荣耀

    击上方"Python爬虫与数据挖掘",进行关注 回复"书籍"即可获赠Python从入门到进阶共10本电子书 今 日 鸡 汤 浮云一别后,流水十年间. 大家好,我 ...

  2. 手把手教你搭建机器学习+深度学习AI模型

    说起现在搞什么最挣钱,10 人里 11 个都要回答人工智能! 早在几年前,华为就开出百万年薪招聘 AI 专家,当是很多人认为噱头大于实际价值.但如果今天还有谁质疑人工智能的前景,那显然已经和时代脱轨了 ...

  3. 手把手教你使用 YOLOV5 训练目标检测模型

    作者 | 肆十二 来源 | CSDN博客 这次要使用YOLOV5来训练一个口罩检测模型,比较契合当下的疫情,并且目标检测涉及到的知识点也比较多. 先来看看我们要实现的效果,我们将会通过数据来训练一个口 ...

  4. 手把手教你用Python搭建一个AI智能问答系统

    导读:智能问答系统是自然语言处理的一个重要分支.今天我们将利用分词处理以及搜索引擎搭建一个智能问答系统. 本文经授权转自公众号CSDN(ID:CSDNnews) 作者:李秋键 具体的效果如下所示: 下 ...

  5. 手把手教你用Python搭建一个AI智能问答系统!

    导读:智能问答系统是自然语言处理的一个重要分支.今天我们将利用分词处理以及搜索引擎搭建一个智能问答系统. 具体的效果如下所示: 私信小编01  领取完整代码! 下面简单了解下智能问答系统和自然语言处理 ...

  6. 手把手教你用fairseq训练一个NMT机器翻译系统

    以构建英-中NMT为例,在linux上运行,fairseq版本为0.8.0 环境准备 Requirements: fairseq:pytorch包,包括许多网络结构,https://github.co ...

  7. python做一个问答系统_手把手教你用Python搭建一个AI智能问答系统

    导读:智能问答系统是自然语言处理的一个重要分支.今天我们将利用分词处理以及搜索引擎搭建一个智能问答系统. 本文经授权转自公众号CSDN(ID:CSDNnews) 作者:李秋键 具体的效果如下所示: 下 ...

  8. python界面设计-手把手教你用Python设计一个简单的命令行界面

    原标题:手把手教你用Python设计一个简单的命令行界面 对 Python 程序来说,完备的命令行界面可以提升团队的工作效率,减少调用时可能碰到的困扰.今天,我们就来教大家如何设计功能完整的 Pyth ...

  9. 手把手教你使用TensorFlow训练出自己的模型

    手把手教你使用TensorFlow训练出自己的模型 一.前言 搭建TensorFlow开发环境一直是初学者头疼的问题,为了帮忙初学者快速使用TensorFlow框架训练出自己的模型,作者开发了一款基于 ...

最新文章

  1. python 检测文件或文件夹是否存在
  2. python字典的键可以用列表吗_python字典多键值及重复键值的使用方法(详解)
  3. 炒冷饭系列:设计模式 装饰模式
  4. 我们的2009 梦想照进了现实
  5. 如何利用C/C++逐行读取txt文件中的字符串(可以顺便实现文本文件的复制)
  6. 7.Spring Cloud Alibaba教程:整合Dubbo实现RPC调用
  7. 小微商户申请php,微信小微商户申请入驻 - osc_r8q2esik的个人空间 - OSCHINA - 中文开源技术交流社区...
  8. Ajax联动下拉框的实现例子
  9. android mvvm流程图,MVVM框架模式详解
  10. C++ 实现分块查找(顺序存储结构)(完整代码)
  11. ps分辨率像素英寸和厘米的区别_南南带你免费学习超级强大的做图软件-PS(第一章:第二节)...
  12. Mac蓝牙无法使用怎么办?教你7个修复蓝牙的技巧
  13. 【15】蓝桥杯之史丰收速算(程序填空题)
  14. 利用three建立一个3d园区
  15. ACP.敏捷概念梳理1
  16. 海康服务器装win7系统,详解win7旗舰版系统必须重装的四种情况
  17. 微信公众号的缩略图/封面图下载方法详细介绍
  18. 计算机软件知识产权保护主要保护哪些内容,计算机软件知识产权保护制度.pptx...
  19. 面试经验-(1)宁波银行
  20. 信息抽取在知识图谱构建中的实践与应用

热门文章

  1. Godaddy打不开和支付时没有支付宝选项的解决方法
  2. 物联网网关程序设计-3
  3. 1skp素材和草图溜溜是不是一样的免费?我看到了建模的乐趣!
  4. 六、Matlab 批量保存多格式图片
  5. 基本函数依赖和候选键_白话详解数据库函数依赖和Armstrong公理及其引理
  6. 允许网站使用相机和麦克风_Windows 10 相机、麦克风和隐私
  7. vs运行出现无法启动IIS Express服务器
  8. php去除最后一个空格,php去除头尾空格的2种方法
  9. 使用visio创建跨职能流程图。
  10. 别人的研究生宿舍男女混住,而我没地方住?