用PyCharm实现MNIST手写数字识别。
本人几乎没有编程基础,完全按照《Python神经网络》一书学习,在此过程中困惑很多。希望可以相互学习,下面是我最终代码。
数据链接:
https://pjreddie.com/media/files/mnist_test.csv测试集
https://pjreddie.com/media/files/mnist_train.csv训练集
import numpy # scipy.special for the sigmoid function expit() import scipy.special # library for plotting arrays import matplotlib.pyplot# 初始化函数--设定输入层节点、隐藏节点、和输出层节点数量 # 训练——训练给定样本,优化权重。(使用训练数据集) # 查询函数——给定输入,从输出节点给出答案(最后导入测试数据集,使用查询函数获得输出,与正确结果比较,获得该神经网络的准确度)# neural network class definition class neuralNetwork:# initialise the neural networkdef __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):# set number of nodes in each input, hidden, output playerself.inodes = inputnodesself.hnodes = hiddennodesself.onodes = outputnodes# link weight matrices,wih and who# weight inside the arrays are w_i_j,where link is from node i to node j in the next layer# w11 w21# w12 w22 etcselfwih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))# learning rateself.lr = learningrate# activation function is the sigmoid functionself.activation_function = lambda x: scipy.special.expit(x)pass# train the neural workdef train(self, inputs_list, targets_list):# convert inputs list to 2d arrayinputs = numpy.array(inputs_list, ndmin=2).Ttargets = numpy.array(targets_list, ndmin=2).T# calculate signals into hidden layerhidden_inputs = numpy.dot(self.wih, inputs)# calculate the signals emerging from hidden layerhidden_outputs = self.activation_function(hidden_inputs)# calculate signals into final output layerfinal_inputs = numpy.dot(self.who, hidden_outputs)# calculate the signals emerging from final output layerfinal_outputs = self.activation_function(final_inputs)# output layer error is the (target-actual)output_errors = targets - final_outputs# hidden layer error is the output_error,split by weights,recombined at hidden nodeshidden_errors = numpy.dot(self.who.T, output_errors)# update the weights for the links between the hidden and output layersself.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs)),numpy.transpose(hidden_outputs))# update the weights for the links between the input and hidden layersself.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),numpy.transpose(inputs))pass# query the neural networkdef query(self, inputs_list):# convert inputs list to 2d arrayinputs = numpy.array(inputs_list, ndmin=2).T# calculate signals into hidden layerhidden_inputs = numpy.dot(self.wih, inputs)# calculate the signals emerging from hidden layerhidden_outputs = self.activation_function(hidden_inputs)# calculate signals into final output layerfinal_inputs = numpy.dot(self.who, hidden_outputs)# calculate the signals emerging from final output layerfinal_outputs = self.activation_function(final_inputs)return final_outputs# number of nodesinput_nodes = 784 hidden_nodes = 100 output_nodes = 10# learning rate is 0.3 learning_rate = 0.1# create instance of neural network n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)# 读入手写字体训练数据集 training_data_file = open("训练集数据存储路径", 'r') training_data_list = training_data_file.readlines() training_data_file.close()# train the neural network # epochs is the number of times the training data set is used for training epochs = 6 for e in range(epochs):# go through all record in the training data setfor record in training_data_list:# split the record by the','commasall_values = record.split(',')# scale and shift the inputsinputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01# create the target output values(all 0.01,except the desired label which is 0.99)targets = numpy.zeros(output_nodes) + 0.01# all_values[0] is the target label for this recordtargets[int(all_values[0])] = 0.99n.train(inputs, targets)passpass# 读入手写字体测试数据集test_data_file = open("测试集数据存储路径", 'r')test_data_list = test_data_file.readlines()test_data_file.close()# test the neural networkscorecard = [] # go through all record in the test data set for record in test_data_list:# split the record by the','commasall_values = record.split(',')# correct answer is first valuecorrect_label = int(all_values[0])# scale and shift the inputsinputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01# query the networkoutputs = n.query(inputs)# the index of the highest value corresponds to the labellabel = numpy.argmax(outputs)if label == correct_label:scorecard.append(1)else:scorecard.append(0)passpass scorecard_array = numpy.asfarray(scorecard) print("performance=", scorecard_array.sum() / scorecard_array.size)
用PyCharm实现MNIST手写数字识别。相关推荐
- 基于K210的MNIST手写数字识别
基于K210的MNIST手写数字识别 项目已开源链接: Github. 硬件平台 采用Maixduino开发板 在sipeed官方有售 软件平台 使用MaixPy环境进行单片机的编程 官方资源可在这里 ...
- AI常用框架和工具丨11. 基于TensorFlow(Keras)+Flask部署MNIST手写数字识别至本地web
代码实例,基于TensorFlow+Flask部署MNIST手写数字识别至本地web,希望对您有所帮助. 文章目录 环境说明 文件结构 模型训练 本地web创建 实现效果 环境说明 操作系统:Wind ...
- 用PyTorch实现MNIST手写数字识别(非常详细)
Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 (zz)
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 我想写一系列深度学习的简单实战教程,用mxnet做实现平台的实例代码简单讲解深度学习常用的一些技术方向和实战样例.这 ...
- TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类
TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类 目录 设计思路 实现代码 设计思路 更新-- 实现代码 # -*- coding:utf-8 -*- import ten ...
- 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别
一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...
- 使用PYTORCH复现ALEXNET实现MNIST手写数字识别
网络介绍: Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当年夺下了不少比赛的冠军,下面是Alexnet的网络结构: 网络结构较为简单,共有五个卷积层和三个全连接层,原 ...
- 使用tf.keras搭建mnist手写数字识别网络
使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...
- TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络
TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...
最新文章
- Beyond MySQL --Branching the popular database--转载
- android 相机路径,android – 如何获取刚从相机捕获的图像路径
- ALV程序checkbox全选及取消全选
- 需求评审五个维度框架分析及其带来的启示-总起
- 解决spark on yarn报错:File /tmp/hadoop-root/nm-local-dir/filecache does not exist
- Opencv、OpenCV2.x、Opencv3.x个版本的进化,与VS各个版本的匹配问题
- Logstash之Logstash inputs(file和redis插件)、Logstash outputs(elasticsearch 和redis插件)和Filter plugins...
- @Bean+@Component+@Configuration+@Autowired的配合使用与区别(转载+整理+完整实验)
- python api数据接口_python写数据api接口
- 【Android开发艺术探索】RemoteViews
- android sqlite 打包 xe,Delphi XE使用SQLite3
- 东秦数模美赛校赛记录——紧急车辆位置.题目
- 银河麒麟v10sp1桌面安装远程控制工具todesk
- 《iRedMail邮件服务器搭建详细过程》
- 计算机专业建设会议纪要,智能控制教研室会议纪要6号
- Navicat无法导入excel文件的异常处理
- 湖南农业大学有计算机应用,计算机应用基础复习资料–湖南农业大学.doc
- Docker Secret加密
- 【算法】第三届全国大学生算法设计与编程挑战赛(冬季赛)
- 基本Kmeans算法介绍及其实现
热门文章
- 帝吧出征FB:这李毅吧的“爆吧”文化是如何形成的
- rdkit中 logP, mr, TPSA, Labute ASA 讲解
- @Autowire注解的工作原理
- XMind 8 pro / XMind 8 Update 8 软件破解版
- SimLab Composer 9 for Mac(3D场景渲染工具)
- python3半自动爬虫,获取风暴英雄官方壁纸
- PDF如何去除水印?三种方法教你如何去除PDF文件水印
- SecuritySpy for Mac(Mac视频监控软件)
- SwiftUI——界面间的“闪转腾挪”(页面跳转的各种方法)
- 今年找工作为什么这么难?