实验环境:

python3.6.3  pip 9.0.1  tensorflow 1.10.0  window 10  oracle vm virtualbox  ubuntu 16.0.1

1.基于tensorflow对mnist预测,需要连接外网

下面代码可以直接复制去调试,识别率高达98%,最低也在91%。python对代码格式有非常高的要求。行头不能同时存在tab和空格。函数内行头对齐。大概有3/40分钟左右.不过我在8月20号训练结果不是这样,最高也就97%,最低89%。不清楚其中原因

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("/dataset",one_hot=True)
#每一批数据大小
batch_size = 100
#计算多少批数据
n_batch = mnist.train.num_examples
#定义两个placeholder,None=100,28*28=784,即100行784列
x = tf.placeholder(tf.float32,[None,784])
#0-9个输出标签
y = tf.placeholder(tf.float32,[None,10])
#创建一个简单的神经网络
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([1,10]))
#softmax函数转化为概率值
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#tf.equal()比较函数大小是否相同,相同为True,不同为false;tf.argmax():求y=1在哪个位置,求概率最#大在哪个位置
#argmax返回一维张量中最大的值所在的位置,结果存放在一个布尔型列表中
correct_prediction= tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#cast转化类型,将布尔型转化为32位浮点型,true=1.0,false=0.0再求平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:sess.run(init)#将所有图片训练21次for epoch in range(21):#每次训练所有图片for batch in range(n_batch):batch_xs,batch_ys = mnist.train.next_batch(batch_size)#feed_dict传入训练集的图片和标签sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})#传入测试集的图片和标签acc = sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys})print("Iter"+str(epoch)+",Testing Accuracy:"+str(acc))

2.数据模型对mnist预测,预测结果在代码后面截图,几分钟就可看到结果

import struct
from sklearn import cross_validation,svm,metrics
#将mnist 数据集转换成csv格式
def to_csv(name,maxdata):#打开标签数据集lbl_f = open("./dataset/"+name+"-labels.idx1-ubyte","rb")#打开图像数据集img_f = open("./dataset/"+name+"-images.idx3-ubyte","rb")#写入csv文件csv_f = open("./dataset/"+name+",csv","w",encoding="utf-8")#将字节流转换成python数据类型复制给标签mag,lbl_count=struct.unpack(">II",lbl_f.read(8))#将字节流转换成python数据类型复制给图像mag,img_count=struct.unpack(">II",img_f.read(8))#将字节流转换成python数据类型复制给行列rows,cols=struct.unpack(">II",img_f.read(8))#计算数据总量pixels=rows*colsres=[]for idx in range(lbl_count):#设置计数器,大于数据个数总量就跳出循环if idx > maxdata:breaklabel=struct.unpack("B",lbl_f.read(1))[0]bdata=img_f.read(pixels)sdata=list(map(lambda n:str(n),bdata))#写入标签csv_f.write(str(labek)+",")#写入数据csv_f.write(",".join(sdata)+"\r\n")if idx < 10:s="P2 28 28 255\n"s+=" ".join(sdata)iname="./dataset/{0}-{1}-{2}.pgm".format(name.idx,label)with open(iname,"w",encoding="utf-8")as f:f.write(s)#关闭数据流,释放资源csv_f.close()lbl_f.close()img_f.close()
#转换到train.csv 1000个数据
to_csv("train",1000)
#转换到t10k.csv 1000个数据
to_csv("t10k",1000)
#通过sklearn的交叉验证处理数据,svm训练数据预测结果,metrics生成分类报告和准确率
def load_csv(fname):labels=[]images=[]with open(fname,"r")as f:for line in f:cols=line.split(",")if len(cols)<2:continuelabels.append(int(cols.pop(0)))vals=list(map(lambda n:int(n)/256,cols))images.append(vals)return{"labels":labels,"images":images}
data=load_csv("./dataset/train.csv")
test=load_csv("./dataset/t10k.csv")
clf=svm.SVC()
#训练数据集
clf.fit(data["images"],data["labels"])
#预测数据集
predict=clf.predict(test["images"])
#生成测试精度
sore=metrics.accuracy_score(test["labels"],predict)
#生成交叉验证的报告
report=metrics.classification_report(test["labels"],predict)
print(score)
ptrint(report)

该结果是训练1000次的

该结果是训练10000次的

MNIST手写数字识别程序相关推荐

  1. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

  2. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  3. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  4. MNIST 手写数字识别(一)

    MNIST 手写数字识别模型建立与优化 本篇的主要内容有: TensorFlow 处理MNIST数据集的基本操作 建立一个基础的识别模型 介绍 S o f t m a x Softmax Softma ...

  5. 基于K210的MNIST手写数字识别

    基于K210的MNIST手写数字识别 项目已开源链接: Github. 硬件平台 采用Maixduino开发板 在sipeed官方有售 软件平台 使用MaixPy环境进行单片机的编程 官方资源可在这里 ...

  6. Caffe MNIST 手写数字识别(全面流程)

    目录 1.下载MNIST数据集 2.生成MNIST图片训练.验证.测试数据集 3.制作LMDB数据库文件 4.准备LeNet-5网络结构定义模型.prototxt文件 5.准备模型求解配置文件_sol ...

  7. matlab朴素贝叶斯手写数字识别_机器学习系列四:MNIST 手写数字识别

    4. MNIST 手写数字识别 机器学习中另外一个相当经典的例子就是MNIST的手写数字学习.通过海量标定过的手写数字训练,可以让计算机认得0~9的手写数字.相关的实现方法和论文也很多,我们这一篇教程 ...

  8. C语言底层搭建CNN实现MNIST手写数字识别

    手写数字识别 手写数字识别是指使用计算机自动识别手写体阿拉伯数字的技术.作为光学字符识别OCR的一个分支,它可以被广泛应用到手写数据的自动录入场景中.传统的识别方法如最近邻算法k-NN.支持向量机SV ...

  9. ANN原来如此简单!——用Excel实现的MNIST手写数字识别(之一)

    ANN原来如此简单 人工神经网络目前仍然是一个火热的话题,许多人都对它充满了兴趣.然而,对于想了解ANN具体是怎么回事的同学来说,往往缺乏一个足够简单可视化的方法去了解神经网络的内部构造.网络上的各种 ...

最新文章

  1. 机器人抓取汇总|涉及目标检测、分割、姿态识别、抓取点检测、路径规划
  2. android动态创建arraylist,Android:二维ArrayList帮助
  3. (4)打鸡儿教你Vue.js
  4. JAVA基础:Hibernate外键关联与HQL语法
  5. SAP UI5 walkthrough 3 - sapUiBody
  6. 谷歌开源的 GAN 库--TFGAN
  7. quartus仿真27:JK触发器构成的同步十进制可逆计数器(分析)
  8. 通过Spring Boot中的手动Bean定义提高启动性能
  9. 使用uniapp获取当前位置
  10. H5链接调起支付宝APP支付(个人收款)
  11. 7-8 哈利·波特的考试 (20 分)
  12. java 匿名邮件_java开发邮件发送(匿名)
  13. 济南计算机专业职业学校排名,济南计算机专业学校排名
  14. Hybrid App开发模式
  15. 一行一行读取文件的两种方式
  16. 计算机故障代码ff,电脑开机时主板上只显示FF怎么回事?
  17. ElasticSearch DSL语言高级查询+SpringBoot
  18. linux命令的使用:配置静态ip,查看网关,dns服务器ip,关闭防火墙,selinux
  19. 一阶电路中的时间常数_一阶RC电路的时间常数为 ;一阶RL电路的时间常数为
  20. matlab图像处理ppt,数字图像处理(MATLAB版).ppt

热门文章

  1. 二维码生成 API数据接口
  2. Python -- 限流 throttle
  3. python3 中解决\u8bf7\u6c42\u6210\u529f“格式编码问题
  4. ROS2机器人笔记20-11-22
  5. 网页中的png图片无法显示?
  6. Redis 布隆过滤器
  7. FastFlow: Unsupervised Anomaly Detection and Localization via 2D Normalizing Flows
  8. oracle 10092,Oracle诊断事件列表
  9. Apache Thrift 官网学习 一 基本概述与入门
  10. 活体检测论文研读五:Face De-Spoofing: Anti-Spoofing via Noise Modeling