本文使用tensorflow实现在mnist数据集上的图片训练和测试过程,使用了简单的两层神经网络,代码中涉及到的内容,均以备注的形式标出。
关于文中的数据集,大家如果没有下载下来,可以到我的网盘去下载,链接如下:
https://pan.baidu.com/s/1KU_YZhouwk0h9MK0xVZ_QQ
下载下来后解压到F盘mnist文件夹下,或者自己选择文件存储位置,然后在下面代码的相应位置改过来即可。

直接上代码:

    import tensorflow as tfimport numpy as np#引入input_mnistfrom tensorflow.examples.tutorials.mnist import input_data#加载mnist信息,获得训练和测试图片以及对应标签mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)trainimg = mnist.train.imagestrainlabel = mnist.train.labelstestimg = mnist.test.imagestestlabel = mnist.test.labelsprint("MNIST LOAD READY")#输入图片尺寸28*28n_input = 784#输出类别数n_output = 10#初始化权重weights = {#卷积层参数,采用高斯初始化'wc1':tf.Variable(tf.random_normal([3,3,1,64],stddev = 0.1)),'wc2':tf.Variable(tf.random_normal([3,3,64,128],stddev=0.1)),#全连接层参数'wd1':tf.Variable(tf.random_normal([7*7*128,1024],stddev=0.1)),       'wd2':tf.Variable(tf.random_normal([1024,n_output],stddev=0.1))        }#初始化偏置biases = {'bc1':tf.Variable(tf.random_normal([64],stddev = 0.1)),'bc2':tf.Variable(tf.random_normal([128],stddev=0.1)),'bd1':tf.Variable(tf.random_normal([1024],stddev=0.1)),       'bd2':tf.Variable(tf.random_normal([n_output],stddev=0.1))        }#定义前向传播函数def conv_basic(_input,_w,_b,_keepratio):#输入#reshape()中的-1表示不用我们指定,让函数自己计算_input_r = tf.reshape(_input,shape = [-1,28,28,1])#CONV1_conv1 = tf.nn.conv2d(_input_r,_w['wc1'],strides=[1,1,1,1],padding='SAME')_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1,_b['bc1']))_pool1 = tf.nn.max_pool(_conv1,ksize = [1,2,2,1],strides = [1,2,2,1],padding='SAME')#Dropout层既可以使用在全连接层之后,也可以使用在每层之后,这里在每层之后都加了Dropout_pool_dr1 = tf.nn.dropout(_pool1,_keepratio)#CONV2#conv2d计算二维卷积_conv2 = tf.nn.conv2d(_pool_dr1,_w['wc2'],strides=[1,1,1,1],padding='SAME')_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2,_b['bc2']))_pool2 = tf.nn.max_pool(_conv2,ksize = [1,2,2,1],strides = [1,2,2,1],padding='SAME')_pool_dr2 = tf.nn.dropout(_pool2,_keepratio)#向量化 全连接层输入 得到wd1层的7*7*128的shape 然后转化为向量_dense1 = tf.reshape(_pool_dr2,[-1,_w['wd1'].get_shape().as_list()[0]])#FULL CONNECTION1_fc1 = tf.nn.relu(tf.add(tf.matmul(_dense1,_w['wd1']),_b['bd1']))_fc_dr1 = tf.nn.dropout(_fc1,_keepratio)#FULL CONNECTION2_out = tf.add(tf.matmul(_fc_dr1,_w['wd2']),_b['bd2'])#输出字典out = {'input_r':_input_r,'conv1':_conv1,'pool1':_pool1,'pool1_dr1':_pool_dr1,'conv2':_conv2,'pool2':_pool2,'pool_dr2':_pool_dr2,'dense1':_dense1,'fc1':_fc1,'fc_dr1':_fc_dr1,'out':_out}return outprint("CNN READY")a = tf.Variable(tf.random_normal([3,3,1,64],stddev=0.1))print(a)a = tf.Print(a,[a],"a: ")init = tf.global_variables_initializer()sess = tf.Session()sess.run(init)#填充x = tf.placeholder(tf.float32,[None,n_input])y = tf.placeholder(tf.float32,[None,n_output])keepratio = tf.placeholder(tf.float32)#进行一次前向传播_pred = conv_basic(x,weights,biases,keepratio)['out']#计算损失cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = _pred,labels=y))#定义优化器optm = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(cost)#比较预测的标签和真实标签是否一致,一致返回True,不一致返回False#argmax找到给定的张量tensor中在指定轴axis上的最大值/最小值的位置,0为每一列,1为每一行_corr = tf.equal(tf.argmax(_pred,1),tf.argmax(y,1))#True转化为1 False为0accr = tf.reduce_mean(tf.cast(_corr,tf.float32)) #每1个epoch保存一次save_step = 1#max_to_keep最终只保留三组模型,即(12 13 14)saver = tf.train.Saver(max_to_keep=3)#控制训练还是测试do_train=1init = tf.global_variables_initializer()  sess = tf.Session()sess.run(init)      #训练15个epochtraining_epochs = 15batch_size = 16display_step = 1#训练过程if do_train == 1:for epoch in range(training_epochs):avg_cost=0.total_batch = 10#迭代优化for i in range(total_batch):batch_xs,batch_ys = mnist.train.next_batch(batch_size)       sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys,keepratio:0.7})     avg_cost += sess.run(cost,feed_dict={x:batch_xs,y:batch_ys,keepratio:1.})/total_batch#打印信息if (epoch+1) % display_step ==0:print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))train_acc = sess.run(accr,feed_dict = {x:batch_xs,y:batch_ys,keepratio:1.})print("Train accuracy:%.3f"%(train_acc))#保存模型if epoch % save_step == 0:saver.save(sess,"F:/mnist/data/model.ckpt-"+str(epoch))#测试(cpu版本慢的要死 电脑都快要被卡死了...)if do_train == 0:#epoch = 15 减1之后即加载第14个模型epoch = training_epochs-1#读取模型saver.restore(sess,"F:/mnist/data/model.ckpt-"+str(epoch))#打印测试精度test_acc = sess.run(accr,feed_dict={x:testimg,y:testlabel,keepratio:1.})print("test accr is:%.3f"%(test_acc))print("Optimization Finished")

训练的部分过程如下:

测试过程如下:

测试时只需修改do_train==0 即可。如果使用Anaconda的spyder的话,记得测试之前先restart kennel一下。

tensorflow(七)实现mnist数据集上图片的训练和测试相关推荐

  1. keras笔记-mnist数据集上的简单训练

    学习了keras已经好几天了,之前一直拒绝使用keras,但是现在感觉keras是真的好用啊,可以去尝试一下啊. 首先展示的第一个代码,还是mnist数据集的训练和测试,以下是代码: from ker ...

  2. autoencoder自编码器原理以及在mnist数据集上的实现

    Autoencoder是常见的一种非监督学习的神经网络.它实际由一组相对应的神经网络组成(可以是普通的全连接层,或者是卷积层,亦或者是LSTMRNN等等,取决于项目目的),其目的是将输入数据降维成一个 ...

  3. 使用mnist数据集_使用MNIST数据集上的t分布随机邻居嵌入(t-SNE)进行降维

    使用mnist数据集 It is easy for us to visualize two or three dimensional data, but once it goes beyond thr ...

  4. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  5. MNIST数据集转为图片形式输出

    前期工作 1.请注意运行代码存入的文件夹的名称,要与代码中的path路径对应一致: 2.下载MNIST数据集(四个压缩包),并将四个压缩包的内容解压出来,如下图①: 3.在运行代码目录下,建立data ...

  6. DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本

    DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本 目录 输出结果 设计思路 实现部分代码 说明:所有图片文件丢失 输出结果 更新-- 设计思路 更新-- 实现部分代码 更 ...

  7. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  8. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  9. NLP之词向量:利用word2vec对20类新闻文本数据集进行词向量训练、测试(某个单词的相关词汇)

    NLP之词向量:利用word2vec对20类新闻文本数据集进行词向量训练.测试(某个单词的相关词汇) 目录 输出结果 设计思路 核心代码 输出结果 寻找训练文本中与morning最相关的10个词汇: ...

最新文章

  1. 中gcd函数_欧拉函数φ(n)的计算及欧拉定理
  2. Android移动开发之【Android实战项目】activity生命周期与Java中@Override的作用
  3. MySQL 讨厌哪种类型的查询
  4. python-第二块:time模块和datatime模块
  5. linux系统牵引程序设置,Linux上安装Wine运行AutoCAD实例[多图]
  6. netty系列之:netty中的Channel详解
  7. 怎么复制远程服务器上的文件夹,Linux系统复制文件/文件夹到远程服务器
  8. 如何在CentOS 5.5上安装Kippo蜜罐
  9. [设计模式]单例模式(懒汉式,饿汉式)
  10. python从入门到实践答案博客园_《Python从入门到实践》--第八章 函数 课后练习4...
  11. 全数字FM接收机 --(1)结构
  12. Python 算法交易实验30 退而结网7-交易策略思考
  13. 无感支付及相应技术规范
  14. 2022联想创新科技大会--智能为变革赋能
  15. 腾讯服务器微信分身,腾讯禁止微信双开应用,为什么还有那么多人“冒死”双开微信呢?...
  16. 上dnf一直连接服务器中,Win7系统下玩dnf提示正在连接服务器如何解决
  17. 大数据薪水大概多少_大数据就业岗位有哪些?薪资多少?
  18. 如何在PS中将一张图片一次性裁剪保存成多张图片
  19. 二十个经典管理学定律
  20. 国产高性能车载应用DCDC电源芯片SCT2432、SCT2432Q

热门文章

  1. 【前端初级项目】学成在线网站首页,HTML+CSS,附PSD设计稿!!
  2. Apollo学习笔记(4)坐标系
  3. Relu激活函数、sigmoid
  4. 编译原理学习笔记·语法分析(LL(1)分析法/算符优先分析法OPG)及例子详解
  5. 从“小霸王”到“Kinect”
  6. proxmox ve win7/windows7安装过程分享
  7. css设置宽高相等,高度自适应
  8. IE6浏览器的bug问题及相关解决的方法
  9. hive桌游中文规则_桌面游戏-中文规则-Burn Rate 烧钱
  10. OpenCV图像明度