MindSpore实现手写数字识别代码
MindSpore是华为自研的一套AI框架,最佳匹配昇腾处理器,最大程度地发挥硬件能力。作为AI入门的LeNet手写字体识别网络,网络大小和数据集都不大,可以在CPU上面进行训练和推理。下面是基于MindSpore的LeNet手写字体识别代码,直接复制到ubuntu的Jupyter即可以运行,但是要确保安装了Mindspore包哦~
MNIST数据集需要提前准备好放在目录中。
import os
import argparse
from mindspore import context
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
"""!mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte --no-check-certificate
!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte --no-check-certificate
!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte --no-check-certificate
!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte --no-check-certificate
!tree ./datasets/MNIST_Data
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
# 定义数据集
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# 定义所需要操作的map映射
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# 使用map映射函数,将数据操作应用到数据集
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
# 进行shuffle、batch操作
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds
import mindspore.nn as nn
from mindspore.common.initializer import Normal
class LeNet5(nn.Cell):
"""
Lenet网络结构
"""
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
# 定义所需要的运算
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
# 使用定义好的运算构建前向网络
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
# 实例化网络
net = LeNet5()
# 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 定义优化器
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
"""定义训练的方法"""
# 加载训练数据集
ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)
def test_net(network, model, data_path):
"""定义验证的方法"""
ds_eval = create_dataset(os.path.join(data_path, "test"))
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("{}".format(acc))
#定义超参数
from mindspore import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
train_epoch = 2
mnist_path = "./datasets/MNIST_Data"
dataset_size = 1
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# save the network model and parameters for subsequence fine-tuning# 设置模型保存参数
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
# group layers into an object with training and evaluation features# 应用模型保存参数
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
#定义训练、测试方法
train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint_cb, False)
test_net(net, model, mnist_path)
PS:华为昇腾CANN训练营第三期正在进行中,感兴趣的小伙伴可以报名看一下,免费参加。
报名地址:昇腾CANN训练营第三期_开发者-华为云
MindSpore实现手写数字识别代码相关推荐
- 【mindspore】mindspore实现手写数字识别
mindspore实现手写数字识别 具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用 ...
- Pytorch入门——MNIST手写数字识别代码
MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...
- 卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...
LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧. 环境 Pyc ...
- 手写数字识别代码,可以跑通
来源: https://github.com/caicloud/tensorflow-tutorial/tree/master/Deep_Learning_with_TensorFlow/1.0.0/ ...
- CVNLP基础6之手写数字识别代码体验
文章目录 总流程(思路)预览 x是输入的图片y是图片对应的label 关于训练数据集的说明 搭建计算网络层 计算损失值loss 优化损失值loss(minimize loss) 手写数字初体验代码 代 ...
- Python神经网络手写数字识别代码解释
使用了数据集MNIST中的部分数据. 1.读取数据集内容 #打开文件并获取其中的内容 data_file=open("mnist_train.csv",'r') #open()函数 ...
- 【人工智能实验室】第三次培训之手写数字识别代码理解
感觉把每一行代码都理解过去特别爽!!! minist_train.py import torch from torch import nn from torch.nn import functiona ...
- MindSpore实现手写数字识别
具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用. 数据的流水线处理 defdata ...
- 手写数字识别代码函数解读(MATLAB实现)
1.tf = strcmp(s1,s2) 输入参数可以是字符串数组.字符向量和字符向量元胞数组的任何组合 比较 s1 和 s2,如果二者相同,则返回 1 (true),否则返回 0 (false).如 ...
最新文章
- 《等离子体所毕业生经验分享会》观后感 2020-07-03
- PHP微信支付没有收到微信的回调怎么修改订单状态:主动查询
- 【面经】蚂蚁金服一二三面的面经总结(内推实习方面)
- h.264视频文件封装
- 甲骨文中国裁员已定,补偿为N+6;VMware联手云平台合作伙伴AsiaPac,闪耀狮城;对标英伟达,寒武纪新货曝光……...
- 【C++】【一日一练】读写文件小实例【20140510】
- java 托盘开发_java托盘开发界面记录
- php flush 逐行显示_PHP逐行输出(ob_flush与flush的组合)
- 后期处理之一:雾蒙蒙风景照片处理技巧
- PaddlePaddle常用镜像
- 【Blender】UV贴图相关学习
- 4本建模必读的书籍,每天学一点,获益匪浅
- 深井泵房无人值守系统 泵站无人值守平台 智慧水务
- 【leetcode刷题】找到需补充粉笔的学生编号
- 了解什么是枚举(enumeration)
- 图标右上角的数字小圆圈 如图 在tabBarController中设置
- 跟着老陈学嵌入式-C语言入门之类Linux编译环境搭建
- Xilinx 7系列 FPGA CLB资源介绍
- 【日常学习】【二分】【单调队列优化线性DP】codevs3342 绿色通道题解
- @ControllerAdvice 用法