Introduction to Neural Networks

A neural network is essentially a multilayer perceptron (MLP).

If you are not familiar with the perceptron, see my earlier post, "The Perceptron and Its Python Implementation".
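To make "multilayer perceptron" concrete: each layer is a perceptron-style affine map (weights times input plus bias), and a nonlinearity sits between layers. The sketch below is a minimal NumPy forward pass with the same shape as the network used later in this post (784 inputs, 200 hidden units with ReLU, 10 outputs). The function name mlp_forward and the random weights are illustrative assumptions, not part of the original code.

```python
import numpy as np


def mlp_forward(x, w1, b1, w2, b2):
    """Forward pass of a two-layer MLP (hypothetical helper for illustration).

    x:  (batch, in_dim) input
    w1: (in_dim, n_hidden), b1: (n_hidden,)
    w2: (n_hidden, out_dim), b2: (out_dim,)
    Returns raw scores (logits) of shape (batch, out_dim).
    """
    hidden = np.maximum(0.0, x @ w1 + b1)  # layer 1: affine map + ReLU
    return hidden @ w2 + b2                # layer 2: affine map, no activation


rng = np.random.default_rng(0)
in_dim, n_hidden, out_dim = 28 * 28, 200, 10
w1 = rng.normal(0.0, 0.01, size=(in_dim, n_hidden))
b1 = np.zeros(n_hidden)
w2 = rng.normal(0.0, 0.01, size=(n_hidden, out_dim))
b2 = np.zeros(out_dim)

x = rng.random((4, in_dim))  # four fake "images" with pixel values in [0, 1)
logits = mlp_forward(x, w1, b1, w2, b2)
print(logits.shape)  # (4, 10)
```

Without the nonlinearity between them, the two affine layers would collapse into a single linear map, which is why the hidden layer needs an activation such as ReLU.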

Implementing a Neural Network for Handwritten Digit Recognition

About the dataset

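  • Download the MNIST files from http://yann.lecun.com/exdb/mnist/ and decompress them.
  • In the main block, set path to the directory where the decompressed files are stored.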

If you run into problems with the dataset, feel free to send me a private message.
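The MNIST files use the IDX format: a big-endian header (a 4-byte magic number whose last two bytes give the data type and the number of dimensions, then one 4-byte size per dimension) followed by raw unsigned bytes. That is why the code below starts reading image pixels at byte 16 (magic number, count, rows, columns) and labels at byte 8 (magic number, count). Here is a minimal stdlib-only sketch of a header-aware reader, assuming uncompressed files of unsigned bytes; the function name read_idx is hypothetical.

```python
import struct


def read_idx(raw: bytes):
    """Parse an uncompressed IDX (MNIST) byte string of unsigned bytes.

    Returns (dims, values): dims is a tuple such as (N, 28, 28),
    values is a flat list of ints in 0..255.
    """
    zero, dtype, ndim = struct.unpack('>HBB', raw[:4])  # magic number, big-endian
    if zero != 0 or dtype != 0x08:
        raise ValueError('expected an IDX file of unsigned bytes')
    dims = struct.unpack('>' + 'I' * ndim, raw[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim  # 16 for images (3 dims), 8 for labels (1 dim)
    values = list(raw[offset:])
    expected = 1
    for d in dims:
        expected *= d
    if len(values) != expected:
        raise ValueError('payload size does not match header')
    return dims, values


# Tiny synthetic "label file": magic 0x00000801, 3 items, labels 7, 2, 1
fake_labels = struct.pack('>HBBI', 0, 0x08, 1, 3) + bytes([7, 2, 1])
print(read_idx(fake_labels))  # ((3,), [7, 2, 1])
```

On a real file you would pass the bytes from open(path + 't10k-images.idx3-ubyte', 'rb').read(); checking the header this way catches a wrong or still-compressed file early instead of silently producing garbage pixels.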

About the implementation

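  • Implemented in PyTorch, covering how the network is built and which activation functions are used (ReLU for the hidden layer, sigmoid for the output layer).
  • Normalization divides each pixel value by 255 so that inputs fall in [0, 1]; other normalization schemes are worth trying.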
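Dividing by 255 maps pixels into [0, 1]. One common alternative is standardization: subtract the mean and divide by the standard deviation, computing both statistics on the training split only so that no validation information leaks in. A minimal NumPy sketch follows; the helper names are hypothetical, and whether standardization helps this particular model has not been tested here.

```python
import numpy as np


def scale_to_unit(pixels):
    """Scaling used in this post: 0..255 -> 0..1."""
    return np.asarray(pixels, dtype=np.float32) / 255.0


def standardize(train, other):
    """Zero-mean, unit-variance scaling with statistics from `train` only."""
    train = np.asarray(train, dtype=np.float32)
    other = np.asarray(other, dtype=np.float32)
    mean, std = train.mean(), train.std()
    std = std if std > 0 else 1.0  # guard against constant input
    return (train - mean) / std, (other - mean) / std


train = np.array([[0, 128, 255], [255, 0, 64]])
val = np.array([[10, 20, 30]])
print(scale_to_unit(train).max())  # 1.0
train_std, val_std = standardize(train, val)
```

The sketch uses a single global mean and standard deviation; per-pixel statistics are another option.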
```python
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn


class Data:
    '''this class is about data module'''

    def __init__(self):
        self.start_loc_image = 16  # image data starts after a 16-byte header
        self.start_loc_label = 8   # label data starts after an 8-byte header
        self.num_pixel = 28 * 28   # the number of pixels per image
        # the specific file names; adjust them if your decompressed files are named differently
        self.choice = {'train-image': 'train-images.idx3-ubyte',
                       'train-label': 'train-labels.idx1-ubyte',
                       'test-image': 't10k-images.idx3-ubyte',
                       'test-label': 't10k-labels.idx1-ubyte'}

    def get(self, path, train_test='', image_label=''):
        '''get the data from the file whose path is "path"
        :param path: the saving path of given files, default is "./file/"
        :param train_test: the data category ("train" or "test")
        :param image_label: the data information ("image" or "label")
        :return: a list of 28*28 images, or a list of int labels
        '''
        if train_test not in ['train', 'test'] or image_label not in ['image', 'label']:
            raise NameError('please check your spelling, "train_test" can be "train/test", '
                            '"image_label" can be "image/label"')
        ch = train_test + '-' + image_label
        print('loading images ...' if image_label == 'image' else 'load labels ...')
        with open(path + self.choice[ch], 'rb') as f:
            file = f.read()
        if image_label == 'image':
            # each byte is one pixel value (0-255); every 784 bytes form one image
            data = [self.transform2image(list(file[i:i + self.num_pixel]))
                    for i in range(self.start_loc_image, len(file), self.num_pixel)]
        else:
            data = list(file[self.start_loc_label:])  # each byte is one label (0-9)
        return data

    def transform2image(self, data: list):
        '''transform 1D pixel points (784 values) to a 28*28 image'''
        assert len(data) == 784
        return np.reshape(data, (28, -1))

    def transfer_tensor(self, data):
        '''transfer data to tensor format'''
        # stack into one ndarray first: torch.tensor on a list of ndarrays is very slow
        return torch.tensor(np.array(data))

    def normalize(self, data, maximum=255):
        '''normalize the data with maximum (the maximum pixel value is 255)'''
        return torch.div(data.float(), maximum)


class Network(nn.Module):
    '''this class is about neural network'''

    def __init__(self, in_dim, n_hidden, out_dim):
        '''define the network
        :param in_dim: the input dimension
        :param n_hidden: the hidden layer dimension
        :param out_dim: the output dimension
        '''
        super().__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden, out_dim))
        self.size_pixel = 28 * 28  # the number of pixels in each picture
        self.learning_rate = 0.001
        self.optimizer = optim.SGD(self.parameters(), lr=self.learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()
        self.dat = Data()  # helper for tensor conversion and normalization

    def forward(self, data):
        '''forward the input data and return the output (predicted scores)'''
        hidden = F.relu(self.layer1(data))
        # nn.CrossEntropyLoss applies log-softmax itself and normally takes raw logits;
        # squashing the outputs into (0, 1) with sigmoid keeps the loss from dropping
        # below about 1.46 (returning self.layer2(hidden) directly avoids this)
        return torch.sigmoid(self.layer2(hidden))

    def accuracy(self, act, pre):
        '''calculate the accuracy from actual and predicted labels'''
        assert len(act) == len(pre)
        return round((act == pre).sum().item() / len(act), 3)

    def pre_process(self, feature, label):
        '''convert to tensors, normalize pixels and flatten each image to 784 values'''
        feature = self.dat.normalize(self.dat.transfer_tensor(feature))
        feature = feature.view(-1, self.size_pixel)
        label = self.dat.transfer_tensor(label).long()  # class indices for CrossEntropyLoss
        return feature, label

    def per_train(self, epoch, feature, label, validation=0.2, batch=50, verbose=True, num_view=50):
        '''train neural network
        :param epoch: number of training epochs
        :param validation: share of the data held out for validation
        :param batch: batch size
        :param verbose: whether to print the training process
        :param num_view: print every "num_view" epochs
        '''
        assert len(feature) == len(label)
        print('training neural network ...')
        fea, lab = self.pre_process(feature, label)
        len_train = int(len(feature) * (1 - validation))
        data_train, label_train = fea[:len_train], lab[:len_train]
        data_train = [data_train[i:i + batch] for i in range(0, len_train, batch)]
        label_train = [label_train[i:i + batch] for i in range(0, len_train, batch)]
        data_val, label_val = fea[len_train:], lab[len_train:]
        for e in range(1, epoch + 1):
            self.train()
            for img, target in zip(data_train, label_train):
                loss_train = self.loss_fn(self(img), target)
                loss_train.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            if verbose and e % num_view == 0:
                self.eval()
                with torch.no_grad():  # no gradients needed for validation
                    pre = self(data_val)
                    loss_val = self.loss_fn(pre, label_val)
                _, pre_view = pre.max(1)
                acc = self.accuracy(label_val, pre_view)
                print('epoch: ' + str(e) + '/' + str(epoch) + ' --> training loss:', loss_train.item(),
                      'validation loss:', loss_val.item(), 'validation accuracy:', acc)


def set_seed(seed):
    '''set random seeds so that results can be reproduced'''
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)


Seed = 0
Size_pixel = 28 * 28
Hidden = 200
Output = 10
set_seed(Seed)

if __name__ == '__main__':
    path = './file/'  # the path of the saved files
    dat = Data()  # for data-related work
    net = Network(Size_pixel, Hidden, Output)  # build network
    # the full training set is not loaded here, so the run below trains on 80%
    # of the 10,000 test images and validates on the remaining 20%
    # train_image = dat.get(path, 'train', 'image')
    # train_label = dat.get(path, 'train', 'label')
    test_image = dat.get(path, 'test', 'image')
    test_label = dat.get(path, 'test', 'label')
    epoch = 1000  # training epochs
    net.per_train(epoch, test_image, test_label)
```
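The 80/20 split and mini-batching inside per_train can be isolated into a small pure-Python function, which makes the indexing easy to check by hand. The name split_and_batch is hypothetical and the sketch works on plain lists rather than tensors, but the slicing follows the same pattern as per_train.

```python
def split_and_batch(features, labels, validation=0.2, batch=50):
    """Split samples into training batches and a validation set.

    The first (1 - validation) share is used for training and cut into
    batches of `batch` items (the last batch may be smaller); the rest
    is returned whole for validation.
    """
    assert len(features) == len(labels)
    len_train = int(len(features) * (1 - validation))
    train_x, train_y = features[:len_train], labels[:len_train]
    batches = [(train_x[i:i + batch], train_y[i:i + batch])
               for i in range(0, len_train, batch)]
    val = (features[len_train:], labels[len_train:])
    return batches, val


xs = list(range(10))
ys = [x % 2 for x in xs]
batches, (val_x, val_y) = split_and_batch(xs, ys, validation=0.2, batch=3)
print(len(batches), val_x)  # 3 [8, 9]
```

No shuffling is done, so every epoch visits the batches in the same order; reshuffling the training indices at the start of each epoch is a common variation.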

Results

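As training goes on, the accuracy on handwritten digits keeps improving: validation accuracy rises from 0.716 at epoch 50 to 0.92 at epoch 1000.
Output: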

loading images ...
load labels ...
training neural network ...
epoch: 50/1000 --> training loss: 2.1074414253234863 validation loss: 2.1338484287261963 validation accuracy: 0.716
epoch: 100/1000 --> training loss: 1.8810529708862305 validation loss: 1.9262853860855103 validation accuracy: 0.814
epoch: 150/1000 --> training loss: 1.7580647468566895 validation loss: 1.811558723449707 validation accuracy: 0.844
epoch: 200/1000 --> training loss: 1.6905272006988525 validation loss: 1.7511157989501953 validation accuracy: 0.858
epoch: 250/1000 --> training loss: 1.6461492776870728 validation loss: 1.7131295204162598 validation accuracy: 0.87
epoch: 300/1000 --> training loss: 1.6150131225585938 validation loss: 1.6869707107543945 validation accuracy: 0.877
epoch: 350/1000 --> training loss: 1.5923327207565308 validation loss: 1.6678192615509033 validation accuracy: 0.886
epoch: 400/1000 --> training loss: 1.5752110481262207 validation loss: 1.653057336807251 validation accuracy: 0.892
epoch: 450/1000 --> training loss: 1.5620046854019165 validation loss: 1.6411659717559814 validation accuracy: 0.896
epoch: 500/1000 --> training loss: 1.5515371561050415 validation loss: 1.6312130689620972 validation accuracy: 0.897
epoch: 550/1000 --> training loss: 1.5430898666381836 validation loss: 1.6226950883865356 validation accuracy: 0.903
epoch: 600/1000 --> training loss: 1.5360981225967407 validation loss: 1.615297794342041 validation accuracy: 0.905
epoch: 650/1000 --> training loss: 1.5302610397338867 validation loss: 1.608830451965332 validation accuracy: 0.907
epoch: 700/1000 --> training loss: 1.5252583026885986 validation loss: 1.603135108947754 validation accuracy: 0.908
epoch: 750/1000 --> training loss: 1.5209182500839233 validation loss: 1.5980955362319946 validation accuracy: 0.91
epoch: 800/1000 --> training loss: 1.5170769691467285 validation loss: 1.5936044454574585 validation accuracy: 0.911
epoch: 850/1000 --> training loss: 1.513597846031189 validation loss: 1.5895581245422363 validation accuracy: 0.915
epoch: 900/1000 --> training loss: 1.5103877782821655 validation loss: 1.5858769416809082 validation accuracy: 0.917
epoch: 950/1000 --> training loss: 1.5073943138122559 validation loss: 1.5824891328811646 validation accuracy: 0.918
epoch: 1000/1000 --> training loss: 1.5045477151870728 validation loss: 1.5793448686599731 validation accuracy: 0.92
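The losses in this log level off around 1.5 instead of approaching 0. That is consistent with forward applying a sigmoid before nn.CrossEntropyLoss: the loss applies log-softmax to its input, so when every output is confined to [0, 1], even the best case (correct class at 1, the other nine at 0) still costs log(1 + 9/e) ≈ 1.461. A quick check of that bound in plain Python; cross_entropy here is a hand-written helper that mirrors the per-sample formula, not the PyTorch function itself.

```python
import math


def cross_entropy(scores, target):
    """Softmax cross-entropy for one sample: log(sum(exp(s))) - s[target]."""
    log_sum = math.log(sum(math.exp(s) for s in scores))
    return log_sum - scores[target]


# Best case with sigmoid outputs: correct class saturates at 1, the other nine at 0
best = [1.0] + [0.0] * 9
floor = cross_entropy(best, target=0)
print(round(floor, 4))                     # 1.4612
print(round(math.log(1 + 9 / math.e), 4))  # 1.4612

# With unbounded logits the loss can get close to 0
print(cross_entropy([10.0] + [0.0] * 9, 0) < 0.001)  # True
```

The final training loss of about 1.505 sits just above this floor, so feeding the raw output of the last linear layer into the loss would likely let it keep decreasing.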
