Mindspore实现手写字体识别
Mindspore实现手写字体识别
一、实验目的
加深对神经网络原理的理解
熟悉Minspore平台
掌握训练过程
二、实验环境
Windows + Python3+
一台装有集成开发环境(IDE)—— PyCharm的计算机
三、实验内容
1.下载数据集放置目录如下
四、代码填写
#encoding=utf-8
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import mindspore.dataset as ds
train_data_path = r"\datasets\MNIST_Data\train"
test_data_path = r"\datasets\MNIST_Data\test"
mnist_ds = ds.MnistDataset(train_data_path)#加载数据集
print('The type of mnist_ds:', type(mnist_ds))
print("Number of pictures contained in the mnist_ds:",mnist_ds.get_dataset_size())
#迭代器读取数据
dic_ds = mnist_ds.create_dict_iterator()
item = next(dic_ds)
img = item["image"].asnumpy()
label = item["label"].asnumpy()
#打印数据集信息 并可视化
print("The item of mnist_ds:", item.keys())
print("Tensor of image in item:", img.shape)
print("The label of item:", label)
plt.imshow(np.squeeze(img))
plt.title("number:%s"% item["label"].asnumpy())
plt.show()
"""
-------定义dataset(dataloader)-----
"""
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstypedef create_dataset(data_path, batch_size=32, repeat_size=1,num_parallel_workers=1):#调用API读取MNIST数据集合mnist_ds = ds.MnistDataset(data_path)
"""
-------对数据增强-----
"""resize_height, resize_width = 32, 32rescale = 1.0 / 255.0shift = 0.0rescale_nml = 1 / 0.3081shift_nml = -1 * 0.1307 / 0.3081#根据上面设置的参数阐释增强数据过程resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)rescale_op = CV.Rescale(rescale, shift)hwc2chw_op = CV.HWC2CHW()type_cast_op = C.TypeCast(mstype.int32)#使用map函数对数据集进行操作mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image",num_parallel_workers=num_parallel_workers)# 设置数据读取,比如是否随机,批次量多少,数据量加倍buffer_size = 10000mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds#初始化dataset并查看内容
ms_dataset = create_dataset(train_data_path)
print('Number of groups in the dataset:', ms_dataset.get_dataset_size())"""
-------利用next获取样本并查看单个样本格式------
"""
data =next(ms_dataset.create_dict_iterator(output_numpy=True))#填写
images = data['image']#填写
labels =data['label']#填写
print('Tensor of image:', images.shape)
print('Labels:', labels)"""
-------可视化数据集------
"""
count = 1
for i in images:plt.subplot(4, 8, count)plt.imshow(np.squeeze(i))plt.title('num:%s'%labels[count-1])plt.xticks([])count += 1plt.axis("off")
plt.show()
"""
-------定义LeNet5模型-----
"""
import mindspore.nn as nn
from mindspore.common.initializer import Normalclass LeNet5(nn.Cell):"""Lenet network structure."""# define the operator requireddef __init__(self, num_class=10, num_channel=1):super(LeNet5, self).__init__()self.conv1=nn.Conv2d(num_channel,6,5,pad_mode='valid')
self.conv2=nn.Conv2d(6,16,5,pad_mode='valid')
self.fc1=nn.Dense(16*5*5,120,weight_init=Normal(0.02))
self.fc2=nn.Dense(120,84,weight_init=Normal(0.02))
self.fc3=nn.Dense(84,num_class,weight_init=Normal(0.02))
self.relu=nn.ReLU()
self.max_pool2d=nn.MaxPool2d(kernel_size=2,stride=2)
self.flatten=nn.Flatten()# use the preceding operators to construct networks
def construct(self, x):x=self.max_pool2d(self.relu(self.conv1(x)))
x=self.max_pool2d(self.relu(self.conv2(x)))
x=self.flatten(x)
x=self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x=self.fc3(x)return x
network = LeNet5()"""
-------定义CALLBACK函数-----
"""
from mindspore.train.callback import Callback
#自定义CallBlack函数
# 记录损失和精度
class StepLossAccInfo(Callback):def __init__(self, model, eval_dataset, steps_loss, steps_eval):self.model = modelself.eval_dataset = eval_datasetself.steps_loss = steps_lossself.steps_eval = steps_evaldef step_end(self, run_context):cb_params = run_context.original_args()cur_epoch = cb_params.cur_epoch_numcur_step = (cur_epoch-1)*1875 + cb_params.cur_step_numself.steps_loss["loss_value"].append(str(cb_params.net_outputs))self.steps_loss["step"].append(str(cur_step))if cur_step % 125 == 0:acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)self.steps_eval["step"].append(cur_step)self.steps_eval["acc"].append(acc["Accuracy"])"""
-------开始训练-----
"""
import os
from mindspore import Tensor, Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn import Accuracy
network = LeNet5()
epoch_size = 1
momentum=0.9
lr=0.01
mnist_path =r"\datasets\MNIST_Data" #这里填写你的数据集路径
model_path =r"\datasets\models\ckpt\mindspore_quick_start"#模型保存路径
train_data_path = r"\datasets\MNIST_Data\train"
test_data_path = r"\datasets\MNIST_Data\test"
net_loss=SoftmaxCrossEntropyWithLogits(sparse=True,reduction='mean')
net_opt=nn.Momentum(network.trainable_params(),lr,momentum)repeat_size = 1
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
eval_dataset = create_dataset(os.path.join(mnist_path, "test"), 32)
# 使用Model定义模型,这个模型包括损失函数,优化器,网络结构,
model =Model(network,net_loss,net_opt,metrics={'Accuracy':Accuracy()})#填写
# 保存模型和参数
config_ck =CheckpointConfig(save_checkpoint_steps=375,keep_checkpoint_max=16)#使用Checkpoint设置保存模型
ckpoint_cb =ModelCheckpoint(prefix="checkpoint_lenet",directory=model_path,config=config_ck)#使用ModelCheckpoint设置保存模型的名称地址等信息
steps_loss = {"step": [], "loss_value": []}
steps_eval = {"step": [], "acc": []}
# 保存每一步step,以及对应的损失和准确率信息
step_loss_acc_info = StepLossAccInfo(model,eval_dataset,steps_loss,steps_eval)#使用StepLossAccInfo类
#填写训练模型
model.train(epoch_size,ds_train,callbacks=[ckpoint_cb,LossMonitor(125),step_loss_acc_info],dataset_sink_mode=False)
"""
-------打印想训练过程-----
"""steps = steps_loss["step"]
loss_value = steps_loss["loss_value"]
steps = list(map(int, steps))
loss_value = list(map(float, loss_value))
plt.plot(steps, loss_value, color="red")
plt.xlabel("Steps")
plt.ylabel("Loss_value")
plt.title("Change chart of model loss value")
plt.show()
"""
------在测试集上验证模型-----
"""
from mindspore import load_checkpoint, load_param_into_net
#定义验证函数
def test_net(network, model, mnist_path):print("============== Starting Testing ==============")#填写 加载保存的模型param_dict = load_checkpoint(mnist_path)#填写
load_param_into_net(network,param_dict)
ds_eval =rd.create_dataset(os.path.join(mnist_path,"test")) #填写 创建测试集dataloader
acc =model.eval(ds_eval,dataset_sink_mode=False)#填写 输入模型获取精度print("============== Accuracy:{} ==============".format(acc))
test_net(network, model, mnist_path)
五、实验结果
读取数据集
数据集测试查看
数据集训练
预测
苏苏
Mindspore实现手写字体识别相关推荐
- MindSpore实现手写数字识别代码
MindSpore是华为自研的一套AI框架,最佳匹配昇腾处理器,最大程度地发挥硬件能力.作为AI入门的LeNet手写字体识别网络,网络大小和数据集都不大,可以在CPU上面进行训练和推理.下面是基于Mi ...
- pytorch CNN手写字体识别
## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...
- 第六讲 Keras实现手写字体识别分类
一 本节课程介绍 1.1 知识点 1.图像识别分类相关介绍: 2.Mnist手写数据集介绍: 3.标准化数据预处理: 4.实验手写字体识别 二 课程内容 2.1 图像识别分类基本介绍 计算机的图像识别 ...
- Android Studio编写一个手写字体识别程序
1.activity_main.xml 的代码 <?xml version="1.0" encoding="utf-8"?> <LinearL ...
- 【mindspore】mindspore实现手写数字识别
mindspore实现手写数字识别 具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用 ...
- 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)
人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...
- python手写字体程序_深度学习---手写字体识别程序分析(python)
我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...
- pytorch rnn 实现手写字体识别
pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...
- 《MATLAB 神经网络43个案例分析》:第19章 基于SVM的手写字体识别
<MATLAB 神经网络43个案例分析>:第19章 基于SVM的手写字体识别 1. 前言 2. MATLAB 仿真示例 3. 小结 1. 前言 <MATLAB 神经网络43个案例分析 ...
- 手写字体识别 --MNIST数据集
Matlab 手写字体识别 忙过这段时间后,对于上次读取的Matlab内部数据实现的识别,我回味了一番,觉得那个实在太小.所以打算把数据换成[MNIST数据集][1]. 基础思想还是相同的,使用Tre ...
最新文章
- C#结构体中数组的分配
- STM32F10x_StdPeriph_Lib_V3.5.0库与系统滴答定时器(Systick)
- 机器人编程与python语言的区别_儿童编程和机器人编程有啥区别?
- Qt简单的解析Json数据例子(一)
- 详解 URLLC 前世今生,你 Get 了吗?
- ad导入pcd后网络标号消失_如何将后端BaaS化:业务逻辑的拆与合
- vue3.0项目创建
- Day02 目录和文件的管理(ADMIN02)
- 浅谈静态方法与静态变量
- [Offer收割]编程练习赛48
- scala ip转换器
- 线性代数学习笔记——矩阵主要公式
- iBase4J项目笔记
- KEIL4文件无法正常使用
- String的getBytes()方法
- org.postgresql.util.PSQLException: ERROR: column loginid of relation userinfo does not exist
- python爬取京东商品价格走势_用python编写的抓京东商品价格的爬虫
- 【刘晓燕长难句分析】2.并列句
- 2018年支付行业回顾
- 数学在计算机方面的应用论文参考文献,数学论文参考文献
热门文章
- 北京市通州区谷歌卫星地图下载
- JavaScript 学习-42.jQuery 提交表单 submit() 方法
- 基于SSM的企业人事人员管理系统
- 爬虫中国天气网数据并可视化
- Matlab经纬度坐标转换xy坐标,经纬度坐标系转换为UTM坐标系(matlab)
- c#中 utm坐标转换经纬度坐标
- html怎么设置华文行楷,css如何修改字体为华文行楷
- 拜耳2020年10个新植保制剂商业化,3个生物技术性状项目推进至上市阶段
- Python怎么安装jieba库?
- 陕西2020行政区划调整_陕西行政区划调整畅想:西安咸阳合并可行,但成立直辖市不太现实...