MindSpore是华为自研的一套AI框架,最佳匹配昇腾处理器,最大程度地发挥硬件能力。作为AI入门的LeNet手写字体识别网络,网络大小和数据集都不大,可以在CPU上面进行训练和推理。下面是基于MindSpore的LeNet手写字体识别代码,直接复制到ubuntu的Jupyter即可以运行,但是要确保安装了Mindspore包哦~

MNIST数据集需要提前准备好放在目录中。

import os
import argparse
from mindspore import context

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

"""!mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte --no-check-certificate
!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte --no-check-certificate
!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte --no-check-certificate
!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte --no-check-certificate
!tree ./datasets/MNIST_Data
"""

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    # 定义数据集
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

# 定义所需要操作的map映射
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

# 使用map映射函数,将数据操作应用到数据集
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

# 进行shuffle、batch操作
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)

return mnist_ds

import mindspore.nn as nn
from mindspore.common.initializer import Normal

class LeNet5(nn.Cell):
    """
    Lenet网络结构
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # 定义所需要的运算
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

def construct(self, x):
        # 使用定义好的运算构建前向网络
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

# 实例化网络
net = LeNet5()

# 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# 定义优化器
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """定义训练的方法"""
    # 加载训练数据集
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)

def test_net(network, model, data_path):
    """定义验证的方法"""
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))

#定义超参数
from mindspore import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor

train_epoch = 2
mnist_path = "./datasets/MNIST_Data"
dataset_size = 1

config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# save the network model and parameters for subsequence fine-tuning# 设置模型保存参数
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
# group layers into an object with training and evaluation features# 应用模型保存参数
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

#定义训练、测试方法
train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint_cb, False)
test_net(net, model, mnist_path)

PS:华为昇腾CANN训练营第三期正在进行中,感兴趣的小伙伴可以报名看一下,免费参加。

报名地址:昇腾CANN训练营第三期_开发者-华为云

MindSpore实现手写数字识别代码相关推荐

  1. 【mindspore】mindspore实现手写数字识别

    mindspore实现手写数字识别 具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用 ...

  2. Pytorch入门——MNIST手写数字识别代码

    MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...

  3. 卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...

    LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧. 环境 Pyc ...

  4. 手写数字识别代码,可以跑通

    来源: https://github.com/caicloud/tensorflow-tutorial/tree/master/Deep_Learning_with_TensorFlow/1.0.0/ ...

  5. CVNLP基础6之手写数字识别代码体验

    文章目录 总流程(思路)预览 x是输入的图片y是图片对应的label 关于训练数据集的说明 搭建计算网络层 计算损失值loss 优化损失值loss(minimize loss) 手写数字初体验代码 代 ...

  6. Python神经网络手写数字识别代码解释

    使用了数据集MNIST中的部分数据. 1.读取数据集内容 #打开文件并获取其中的内容 data_file=open("mnist_train.csv",'r') #open()函数 ...

  7. 【人工智能实验室】第三次培训之手写数字识别代码理解

    感觉把每一行代码都理解过去特别爽!!! minist_train.py import torch from torch import nn from torch.nn import functiona ...

  8. MindSpore实现手写数字识别

    具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用. 数据的流水线处理 defdata ...

  9. 手写数字识别代码函数解读(MATLAB实现)

    1.tf = strcmp(s1,s2) 输入参数可以是字符串数组.字符向量和字符向量元胞数组的任何组合 比较 s1 和 s2,如果二者相同,则返回 1 (true),否则返回 0 (false).如 ...

最新文章

  1. 《等离子体所毕业生经验分享会》观后感 2020-07-03
  2. PHP微信支付没有收到微信的回调怎么修改订单状态:主动查询
  3. 【面经】蚂蚁金服一二三面的面经总结(内推实习方面)
  4. h.264视频文件封装
  5. 甲骨文中国裁员已定,补偿为N+6;VMware联手云平台合作伙伴AsiaPac,闪耀狮城;对标英伟达,寒武纪新货曝光……...
  6. 【C++】【一日一练】读写文件小实例【20140510】
  7. java 托盘开发_java托盘开发界面记录
  8. php flush 逐行显示_PHP逐行输出(ob_flush与flush的组合)
  9. 后期处理之一:雾蒙蒙风景照片处理技巧
  10. PaddlePaddle常用镜像
  11. 【Blender】UV贴图相关学习
  12. 4本建模必读的书籍,每天学一点,获益匪浅
  13. 深井泵房无人值守系统 泵站无人值守平台 智慧水务
  14. 【leetcode刷题】找到需补充粉笔的学生编号
  15. 了解什么是枚举(enumeration)
  16. 图标右上角的数字小圆圈 如图 在tabBarController中设置
  17. 跟着老陈学嵌入式-C语言入门之类Linux编译环境搭建
  18. Xilinx 7系列 FPGA CLB资源介绍
  19. 【日常学习】【二分】【单调队列优化线性DP】codevs3342 绿色通道题解
  20. @ControllerAdvice 用法

热门文章

  1. XenServer虚拟化—介绍、部署、测试
  2. 明修栈道,暗渡陈仓----之私募一哥徐翔新玩法 z
  3. 日语动词活用之连用形
  4. PDF编辑器怎么使用?PDF编辑器的操作方法
  5. 禅宗思想追求以有为求无为
  6. 注册宝网络验证系统,安全免费的网络验证系统
  7. 一个网工获得CCNP认证后的成功求职记
  8. Android实时模糊
  9. Zemax 2023安装教程
  10. Qt按键值与Windows Virtual-Key Codes映射表