Mindspore实现手写字体识别

一、实验目的

加深对神经网络原理的理解
熟悉Minspore平台
掌握训练过程

二、实验环境

Windows + Python3+
一台装有集成开发环境(IDE)—— PyCharm的计算机

三、实验内容

1.下载数据集放置目录如下

四、代码填写

#encoding=utf-8
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import mindspore.dataset as ds
train_data_path = r"\datasets\MNIST_Data\train"
test_data_path = r"\datasets\MNIST_Data\test"
mnist_ds = ds.MnistDataset(train_data_path)#加载数据集
print('The type of mnist_ds:', type(mnist_ds))
print("Number of pictures contained in the mnist_ds:",mnist_ds.get_dataset_size())
#迭代器读取数据
dic_ds = mnist_ds.create_dict_iterator()
item = next(dic_ds)
img = item["image"].asnumpy()
label = item["label"].asnumpy()
#打印数据集信息 并可视化
print("The item of mnist_ds:", item.keys())
print("Tensor of image in item:", img.shape)
print("The label of item:", label)
plt.imshow(np.squeeze(img))
plt.title("number:%s"% item["label"].asnumpy())
plt.show()
"""
-------定义dataset(dataloader)-----
"""
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstypedef create_dataset(data_path, batch_size=32, repeat_size=1,num_parallel_workers=1):#调用API读取MNIST数据集合mnist_ds = ds.MnistDataset(data_path)
"""
-------对数据增强-----
"""resize_height, resize_width = 32, 32rescale = 1.0 / 255.0shift = 0.0rescale_nml = 1 / 0.3081shift_nml = -1 * 0.1307 / 0.3081#根据上面设置的参数阐释增强数据过程resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)rescale_op = CV.Rescale(rescale, shift)hwc2chw_op = CV.HWC2CHW()type_cast_op = C.TypeCast(mstype.int32)#使用map函数对数据集进行操作mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image",num_parallel_workers=num_parallel_workers)# 设置数据读取,比如是否随机,批次量多少,数据量加倍buffer_size = 10000mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds#初始化dataset并查看内容
ms_dataset = create_dataset(train_data_path)
print('Number of groups in the dataset:', ms_dataset.get_dataset_size())"""
-------利用next获取样本并查看单个样本格式------
"""
data =next(ms_dataset.create_dict_iterator(output_numpy=True))#填写
images = data['image']#填写
labels =data['label']#填写
print('Tensor of image:', images.shape)
print('Labels:', labels)"""
-------可视化数据集------
"""
count = 1
for i in images:plt.subplot(4, 8, count)plt.imshow(np.squeeze(i))plt.title('num:%s'%labels[count-1])plt.xticks([])count += 1plt.axis("off")
plt.show()
"""
-------定义LeNet5模型-----
"""
import mindspore.nn as nn
from mindspore.common.initializer import Normalclass LeNet5(nn.Cell):"""Lenet network structure."""# define the operator requireddef __init__(self, num_class=10, num_channel=1):super(LeNet5, self).__init__()self.conv1=nn.Conv2d(num_channel,6,5,pad_mode='valid')
self.conv2=nn.Conv2d(6,16,5,pad_mode='valid')
self.fc1=nn.Dense(16*5*5,120,weight_init=Normal(0.02))
self.fc2=nn.Dense(120,84,weight_init=Normal(0.02))
self.fc3=nn.Dense(84,num_class,weight_init=Normal(0.02))
self.relu=nn.ReLU()
self.max_pool2d=nn.MaxPool2d(kernel_size=2,stride=2)
self.flatten=nn.Flatten()# use the preceding operators to construct networks
def construct(self, x):x=self.max_pool2d(self.relu(self.conv1(x)))
x=self.max_pool2d(self.relu(self.conv2(x)))
x=self.flatten(x)
x=self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x=self.fc3(x)return x
network = LeNet5()"""
-------定义CALLBACK函数-----
"""
from mindspore.train.callback import Callback
#自定义CallBlack函数
# 记录损失和精度
class StepLossAccInfo(Callback):def __init__(self, model, eval_dataset, steps_loss, steps_eval):self.model = modelself.eval_dataset = eval_datasetself.steps_loss = steps_lossself.steps_eval = steps_evaldef step_end(self, run_context):cb_params = run_context.original_args()cur_epoch = cb_params.cur_epoch_numcur_step = (cur_epoch-1)*1875 + cb_params.cur_step_numself.steps_loss["loss_value"].append(str(cb_params.net_outputs))self.steps_loss["step"].append(str(cur_step))if cur_step % 125 == 0:acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)self.steps_eval["step"].append(cur_step)self.steps_eval["acc"].append(acc["Accuracy"])"""
-------开始训练-----
"""
import os
from mindspore import Tensor, Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn import Accuracy
network = LeNet5()
epoch_size = 1
momentum=0.9
lr=0.01
mnist_path =r"\datasets\MNIST_Data" #这里填写你的数据集路径
model_path =r"\datasets\models\ckpt\mindspore_quick_start"#模型保存路径
train_data_path = r"\datasets\MNIST_Data\train"
test_data_path = r"\datasets\MNIST_Data\test"
net_loss=SoftmaxCrossEntropyWithLogits(sparse=True,reduction='mean')
net_opt=nn.Momentum(network.trainable_params(),lr,momentum)repeat_size = 1
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
eval_dataset = create_dataset(os.path.join(mnist_path, "test"), 32)
# 使用Model定义模型,这个模型包括损失函数,优化器,网络结构,
model =Model(network,net_loss,net_opt,metrics={'Accuracy':Accuracy()})#填写
# 保存模型和参数
config_ck =CheckpointConfig(save_checkpoint_steps=375,keep_checkpoint_max=16)#使用Checkpoint设置保存模型
ckpoint_cb =ModelCheckpoint(prefix="checkpoint_lenet",directory=model_path,config=config_ck)#使用ModelCheckpoint设置保存模型的名称地址等信息
steps_loss = {"step": [], "loss_value": []}
steps_eval = {"step": [], "acc": []}
# 保存每一步step,以及对应的损失和准确率信息
step_loss_acc_info = StepLossAccInfo(model,eval_dataset,steps_loss,steps_eval)#使用StepLossAccInfo类
#填写训练模型
model.train(epoch_size,ds_train,callbacks=[ckpoint_cb,LossMonitor(125),step_loss_acc_info],dataset_sink_mode=False)
"""
-------打印想训练过程-----
"""steps = steps_loss["step"]
loss_value = steps_loss["loss_value"]
steps = list(map(int, steps))
loss_value = list(map(float, loss_value))
plt.plot(steps, loss_value, color="red")
plt.xlabel("Steps")
plt.ylabel("Loss_value")
plt.title("Change chart of model loss value")
plt.show()
"""
------在测试集上验证模型-----
"""
from mindspore import load_checkpoint, load_param_into_net
#定义验证函数
def test_net(network, model, mnist_path):print("============== Starting Testing ==============")#填写  加载保存的模型param_dict = load_checkpoint(mnist_path)#填写
load_param_into_net(network,param_dict)
ds_eval =rd.create_dataset(os.path.join(mnist_path,"test")) #填写  创建测试集dataloader
acc =model.eval(ds_eval,dataset_sink_mode=False)#填写  输入模型获取精度print("============== Accuracy:{} ==============".format(acc))
test_net(network, model, mnist_path)

五、实验结果
读取数据集

数据集测试查看

数据集训练

预测

苏苏

Mindspore实现手写字体识别相关推荐

  1. MindSpore实现手写数字识别代码

    MindSpore是华为自研的一套AI框架,最佳匹配昇腾处理器,最大程度地发挥硬件能力.作为AI入门的LeNet手写字体识别网络,网络大小和数据集都不大,可以在CPU上面进行训练和推理.下面是基于Mi ...

  2. pytorch CNN手写字体识别

    ## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...

  3. 第六讲 Keras实现手写字体识别分类

    一 本节课程介绍 1.1 知识点 1.图像识别分类相关介绍: 2.Mnist手写数据集介绍: 3.标准化数据预处理: 4.实验手写字体识别 二 课程内容 2.1 图像识别分类基本介绍 计算机的图像识别 ...

  4. Android Studio编写一个手写字体识别程序

    1.activity_main.xml 的代码 <?xml version="1.0" encoding="utf-8"?> <LinearL ...

  5. 【mindspore】mindspore实现手写数字识别

    mindspore实现手写数字识别 具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用 ...

  6. 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)

    人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...

  7. python手写字体程序_深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...

  8. pytorch rnn 实现手写字体识别

    pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...

  9. 《MATLAB 神经网络43个案例分析》:第19章 基于SVM的手写字体识别

    <MATLAB 神经网络43个案例分析>:第19章 基于SVM的手写字体识别 1. 前言 2. MATLAB 仿真示例 3. 小结 1. 前言 <MATLAB 神经网络43个案例分析 ...

  10. 手写字体识别 --MNIST数据集

    Matlab 手写字体识别 忙过这段时间后,对于上次读取的Matlab内部数据实现的识别,我回味了一番,觉得那个实在太小.所以打算把数据换成[MNIST数据集][1]. 基础思想还是相同的,使用Tre ...

最新文章

  1. C#结构体中数组的分配
  2. STM32F10x_StdPeriph_Lib_V3.5.0库与系统滴答定时器(Systick)
  3. 机器人编程与python语言的区别_儿童编程和机器人编程有啥区别?
  4. Qt简单的解析Json数据例子(一)
  5. 详解 URLLC 前世今生,你 Get 了吗?
  6. ad导入pcd后网络标号消失_如何将后端BaaS化:业务逻辑的拆与合
  7. vue3.0项目创建
  8. Day02 目录和文件的管理(ADMIN02)
  9. 浅谈静态方法与静态变量
  10. [Offer收割]编程练习赛48
  11. scala ip转换器
  12. 线性代数学习笔记——矩阵主要公式
  13. iBase4J项目笔记
  14. KEIL4文件无法正常使用
  15. String的getBytes()方法
  16. org.postgresql.util.PSQLException: ERROR: column loginid of relation userinfo does not exist
  17. python爬取京东商品价格走势_用python编写的抓京东商品价格的爬虫
  18. 【刘晓燕长难句分析】2.并列句
  19. 2018年支付行业回顾
  20. 数学在计算机方面的应用论文参考文献,数学论文参考文献

热门文章

  1. 北京市通州区谷歌卫星地图下载
  2. JavaScript 学习-42.jQuery 提交表单 submit() 方法
  3. 基于SSM的企业人事人员管理系统
  4. 爬虫中国天气网数据并可视化
  5. Matlab经纬度坐标转换xy坐标,经纬度坐标系转换为UTM坐标系(matlab)
  6. c#中 utm坐标转换经纬度坐标
  7. html怎么设置华文行楷,css如何修改字体为华文行楷
  8. 拜耳2020年10个新植保制剂商业化,3个生物技术性状项目推进至上市阶段
  9. Python怎么安装jieba库?
  10. 陕西2020行政区划调整_陕西行政区划调整畅想:西安咸阳合并可行,但成立直辖市不太现实...