1.概述

计算机视觉是当前深度学习研究最广泛、落地最成熟的技术领域,在手机拍照、智能安防、自动驾驶等场景有广泛应用。从2012年AlexNet在ImageNet比赛夺冠以来,深度学习深刻推动了计算机视觉领域的发展,当前最先进的计算机视觉算法几乎都是深度学习相关的。深度神经网络可以逐层提取图像特征,并保持局部不变性,被广泛应用于分类、检测、分割、跟踪、检索、识别、提升、重建等视觉任务中。 本次体验结合图像分类任务,介绍MindSpore如何应用于计算机视觉场景,如何训练模型,得出一个性能较优的模型。

2.图像分类

图像分类是最基础的计算机视觉应用,属于有监督学习类别。给定一张数字图像,判断图像所属的类别,如猫、狗、飞机、汽车等等。用函数来表示这个过程如下:

def classify(image):label = model(image)return label

定义的分类函数,以图片数据image为输入,通过model方法对image进行分类,最后返回分类结果。选择合适的model是关键。这里的model一般指的是深度卷积神经网络,如AlexNet、VGG、GoogLeNet、ResNet等等。
下面按照MindSpore的训练数据模型的正常步骤进行,当使用到MindSpore或者图像分类操作时,会增加相应的说明,本次体验的整体流程如下:

  1. 数据集的准备,这里使用的是CIFAR-10数据集。

  2. 构建一个卷积神经网络,这里使用ResNet-50网络。

  3. 定义损失函数和优化器。

  4. 调用Model高阶API进行训练和保存模型文件。

  5. 进行模型精度验证。

训练数据集下载

import mindspore

print(mindspore.__version__)

数据集准备

!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/cifar10.zip
!unzip -o cifar10.zip -d ./datasets
!tree ./datasets/cifar10

数据处理¶

数据集处理对于训练非常重要,好的数据集可以有效提高训练精度和效率。在加载数据集前,我们通常会对数据集进行一些处理。这里我们用到了数据增强,数据混洗和批处理。

数据增强主要是对数据进行归一化和丰富数据样本数量。常见的数据增强方式包括裁剪、翻转、色彩变化等等。MindSpore通过调用map方法在图片上执行增强操作。数据混洗和批处理主要是通过数据混洗shuffle随机打乱数据的顺序,并按batch读取数据,进行模型训练。

构建create_dataset函数,来创建数据集。通过设置 resize_heightresize_widthrescaleshift参数,定义map以及在图片上运用map实现数据增强。

import mindspore.nn as nn
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore import context
import numpy as np
import matplotlib.pyplot as pltcontext.set_context(mode=context.GRAPH_MODE, device_target="GPU")def create_dataset(data_home, repeat_num=1, batch_size=32, do_train=True, device_target="GPU"):"""create data for next use such as training or inferring"""cifar_ds = ds.Cifar10Dataset(data_home,num_parallel_workers=8, shuffle=True)c_trans = []if do_train:c_trans += [C.RandomCrop((32, 32), (4, 4, 4, 4)),C.RandomHorizontalFlip(prob=0.5)]c_trans += [C.Resize((224, 224)),C.Rescale(1.0 / 255.0, 0.0),C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),C.HWC2CHW()]type_cast_op = C2.TypeCast(mstype.int32)cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=8)cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)cifar_ds = cifar_ds.repeat(repeat_num)return cifar_dsds_train_path = "./datasets/cifar10/train/"
dataset_show = create_dataset(ds_train_path)
with open(ds_train_path+"batches.meta.txt","r",encoding="utf-8") as f:all_name = [name.replace("\n","") for name in f.readlines()]iterator_show= dataset_show.create_dict_iterator()
dict_data = next(iterator_show)
images = dict_data["image"].asnumpy()
labels = dict_data["label"].asnumpy()
count = 1
%matplotlib inline
for i in images:plt.subplot(4, 8, count)# Images[0].shape is (3,224,224).We need transpose as (224,224,3) for using in plt.show().picture_show = np.transpose(i,(1,2,0))picture_show = picture_show/np.amax(picture_show)picture_show = np.clip(picture_show, 0, 1)plt.title(all_name[labels[count-1]])picture_show = np.array(picture_show,np.float32)plt.imshow(picture_show)count += 1plt.axis("off")print("The dataset size is:", dataset_show.get_dataset_size())
print("The batch tensor is:",images.shape)
plt.show()

定义卷积神经网络

卷积神经网络已经是图像分类任务的标准算法了。卷积神经网络采用分层的结构对图片进行特征提取,由一系列的网络层堆叠而成,比如卷积层、池化层、激活层等等。 ResNet-50通常是较好的选择。首先,它足够深,常见的有34层,50层,101层。通常层次越深,表征能力越强,分类准确率越高。其次,可学习,采用了残差结构,通过shortcut连接把低层直接跟高层相连,解决了反向传播过程中因为网络太深造成的梯度消失问题。此外,ResNet-50网络的性能很好,既表现为识别的准确率,也包括它本身模型的大小和参数量。

下载构建好的resnet50网络源码文件。

!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/source-codes/resnet.py

--2022-08-15 20:04:35--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/source-codes/resnet.py
Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172
Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.
Proxy request sent, awaiting response... 304 Not Modified
File ‘resnet.py’ not modified on server. Omitting download.

下载下来的resnet.py在当前目录,可以使用import方法将resnet50网络导出。

from resnet import resnet50
net = resnet50(batch_size=32, num_classes=10)

定义损失函数和优化器

接下来需要定义损失函数(Loss)和优化器(Optimizer)。损失函数是深度学习的训练目标,也叫目标函数,可以理解为神经网络的输出(Logits)和标签(Labels)之间的距离,是一个标量数据。 常见的损失函数包括均方误差、L2损失、Hinge损失、交叉熵等等。图像分类应用通常采用交叉熵损失(CrossEntropy)。 优化器用于神经网络求解(训练)。由于神经网络参数规模庞大,无法直接求解,因而深度学习中采用随机梯度下降算法(SGD)及其改进算法进行求解。MindSpore封装了常见的优化器,如SGD、ADAM、Momemtum等等。本例采用Momentum优化器,通常需要设定两个参数,动量(moment)和权重衰减项(weight decay)。

通过调用MindSpore中的API:MomentumSoftmaxCrossEntropyWithLogits,设置损失函数和优化器的参数。

import mindspore.nn as nn

from mindspore.nn import SoftmaxCrossEntropyWithLogits
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)

调用Model高阶API进行训练和保存模型文件

完成数据预处理、网络定义、损失函数和优化器定义之后,就可以进行模型训练了。模型训练包含两层迭代,数据集的多轮迭代(epoch)和一轮数据集内按分组(batch)大小进行的单步迭代。其中,单步迭代指的是按分组从数据集中抽取数据,输入到网络中计算得到损失函数,然后通过优化器计算和更新训练参数的梯度。

为了简化训练过程,MindSpore封装了Model高阶接口。用户输入网络、损失函数和优化器完成Model的初始化,然后调用train接口进行训练,train接口参数包括迭代次数epoch和数据集dataset

模型保存是对训练参数进行持久化的过程。Model类中通过回调函数的方式进行模型保存,如下面代码所示。用户通过CheckpointConfig设置回调函数的参数,其中,save_checkpoint_steps指每经过固定的单步迭代次数保存一次模型,keep_checkpoint_max指最多保存的模型个数。

本次体验选择epoch_size为10,一共迭代了10次,大约耗时25分钟,得到如下的运行结果。体验者可以自行设置不同的epoch_size,生成不同的模型,在下面的验证部分查看模型精确度。

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore import load_checkpoint, load_param_into_net
import os
from mindspore import Modelmodel = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# As for train, users could use model.trainepoch_size = 10
ds_train_path = "./datasets/cifar10/train/"
model_path = "./models/ckpt/mindspore_vision_application/"
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))dataset = create_dataset(ds_train_path )
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory=model_path, config=config_ck)
loss_cb = LossMonitor(142)
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])

进行模型精度验证¶

调用model.eval得到最终精度超过0.80,准确度较高,验证得出模型是性能较优的。

# As for evaluation, users could use model.eval
ds_eval_path = "./datasets/cifar10/test/"
eval_dataset = create_dataset(ds_eval_path, do_train=False)
res = model.eval(eval_dataset)
print("result: ", res)

过程及结果

Resnet实现CIFAR-10图像分类相关推荐

  1. 基于SVM的思想做CIFAR 10图像分类

    #SVM 回顾一下之前的SVM,找到一个间隔最大的函数,使得正负样本离该函数是最远的,是否最远不是看哪个点离函数最远,而是找到一个离函数最近的点看他是不是和该分割函数离的最近的. 使用large ma ...

  2. 深度学习入门——利用卷积神经网络训练CIFAR—10数据集

    CIFAR-10数据集简介 CIFAR-10是由Hinton的学生Alex Krizhevsky和Ilya Sutskever整理的一个用于普适物体的小型数据集.它一共包含10个类别的RGB彩色图片: ...

  3. CIFAR彩色图像分类数据集

    一.CIFAR数据集介绍 1.1 CIFAR-10 数据集图像个数:60000张彩色图像:其中Train sets:50000:Test sets:10000,(测试批的数据里,取自10类中的每一类, ...

  4. cifar 10 最高正确率

    http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html 这个网站里有MNIST数据集 ...

  5. 【今日CV 计算机视觉论文速览 第111期】Fri, 3 May 2019

    今日CS.CV 计算机视觉论文速览 Fri, 3 May 2019 Totally 29 papers ?上期速览✈更多精彩请移步主页 Interesting: ?****Single Image P ...

  6. 元学习、迁移学习、对比学习、自监督学习与少样本学习的关系解读

    文章目录 前言 一.对比自监督学习与FSL 1.对比学习与自监督学习 2.自监督学习与FSL 二.元学习与FSL 1.元学习是什么 2.元学习与FSL 三.迁移学习与FSL 1.迁移学习 2.迁移学习 ...

  7. 「图像分类」从数据集和经典网络开始

    https://www.toutiao.com/i6715367170378826248/ 欢迎大家来到图像分类专栏,本篇简单介绍数据集和图像分类中的经典网络的进展. 作者 | 郭冰洋 编辑 言有三 ...

  8. 【图像分类】从数据集和经典网络开始

    欢迎大家来到图像分类专栏,本篇简单介绍数据集和图像分类中的经典网络的进展. 作者 | 郭冰洋 编辑  言有三 1 简介 一场完美的交响乐演出,指挥家需要充分结合每位演奏者和乐器的特点,根据演奏曲目把控 ...

  9. 【今日CV 计算机视觉论文速览 第128期】Mon, 10 Jun 2019

    今日CS.CV 计算机视觉论文速览 Mon, 10 Jun 2019 Totally 38 papers ?上期速览 ✈更多精彩请移步主页 Interesting: ?遮挡区域语义分割, 研究人员将语 ...

  10. 转载 深度学习---ResNet

    版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/qq_38906523/article/details/80098268 一. ResNet在2015 ...

最新文章

  1. UI调试神器 for ios:Reveal的使用与破解
  2. 牛客c语言数组,牛客网学习笔记 - C/C++
  3. 递归下降分析法(编译原理)
  4. 修正memcache.php中的错误示例
  5. Venkat 演讲翻译:你要清除代码中的异味
  6. 去医院看病如何开开心心出来? | 今日最佳
  7. c语言 将点同时保证x坐标从小到大,y坐标从小到大地排序,C语言第五六次作业.ppt...
  8. 【LeetCode】7. Reverse Integer
  9. 盘绕过苹果id方法_如何更换苹果ID?
  10. 解决:Eclipse SVN一直要求输出登陆密码
  11. 智慧环卫车辆监控管理系统方案
  12. 手游游戏资源提取 (破解、AssetStudio、VGMToolbox、disunity、Il2CppDumper、 .NET Reflector)...
  13. 光纤资料大全之光纤分类
  14. D. Binary Spiders(思维+字典树)
  15. 自定义流式布局的代码实现
  16. 如何开一场高效的迭代排期会 | 敏捷开发落地指南
  17. linux ip1180,canon ip1180驱动下载
  18. Zookeeper集群一致性原理(强一致性)
  19. 取反!和按位取反~的区别
  20. 关于人生中的第一篇博客

热门文章

  1. OSChina 周三乱弹 —— 泰迪转发了你的这条动弹
  2. 西门子博图功能指令(序列化)
  3. PyTorch:数据加载,数学原理,猫鱼分类,CNN,预训练,迁移学习
  4. 鸢尾花案例增加K值调优
  5. Java——BorderLayout(边界布局)
  6. 微信小程序实现分享至朋友圈的功能
  7. mac冒险游戏:死亡细胞Dead Cells 中文版
  8. mac 安装javaJDK教程
  9. Python-pyc文件
  10. python-》基于opencv2通过图片视觉处理+android adb tools 实现QQ自动点赞