全连接网络实现手写数字识别

程序分为三个部分,分别是

mnist_forward.py:前向传播。

mnist_backward.py:反向传播。

mnist_test.py:模型测试。

前向传播

这里搭建了全连接网络。我使用了一个三层的网络。输入的是784个神经元(mnist中一张图片的大小);隐藏层的神经元的个数分别是500,200;输出层是10个神经元,是预测的结果。

下面是mnist_forward.py文件

import tensorflow as tf
import osos.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # 忽略tensorflow警告信息def get_weight(shape,regularizer):w = tf.Variable(tf.truncated_normal(shape,stddev=0.1))if regularizer!=None:tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))return wdef get_bias(shape):b = tf.Variable(tf.zeros(shape))return bdef forward(x,regularizer):'''这里定义一个3层的全连接网络第0层:输入层,784个神经元第1层:隐藏层,500个神经元第2层:隐藏层,200个神经元第3层:输出层,10个神经元'''w1 = get_weight([784,500],regularizer)b1 = get_bias([500])y1 = tf.nn.leaky_relu(tf.matmul(x,w1) + b1)w2 = get_weight([500,200],regularizer)b2 = get_bias([200])y2 = tf.nn.leaky_relu(tf.matmul(y1,w2) + b2)w3 = get_weight([200,10],regularizer)b3 = get_bias([10])y = tf.matmul(y2,w3) + b3return y

反向传播

1)参数使用了滑动平均

2)使用指数衰减学习率

3)在训练过程中,进行了模型的保存

下面是mnist_backward.py

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # 忽略tensorflow警告信息BATCH_SIZE = 200
LEARNING_TARE_BASE = 0.1
LEARNING_TARE_DECAY = 0.99
STEPS = 50000
regularizer = 0.0001
moving_average_decay = 0.99
model_saver_path = './model/' #模型保存的路径
model_name = 'mnist_model' #模型的名字def backward(mnist):x = tf.placeholder(tf.float32, [None, 784]) # 输入y_ = tf.placeholder(tf.float32, [None, 10]) # 标签y = mnist_forward.forward(x, regularizer) # 前向传播的输出global_step = tf.Variable(0, trainable=False) # 计数器#加入滑动平均ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))cem = tf.reduce_mean(ce)loss = cem + tf.add_n(tf.get_collection('losses'))#指数衰减学习率learning_rate = tf.train.exponential_decay(LEARNING_TARE_BASE,global_step,mnist.train.num_examples / BATCH_SIZE,LEARNING_TARE_DECAY,staircase=True)#定义训练过程train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)#如果有滑动平均ema = tf.train.ExponentialMovingAverage(moving_average_decay,global_step)ema_op = ema.apply(tf.trainable_variables())with tf.control_dependencies([train_step,ema_op]):train_op = tf.no_op('train')#实例化saversaver = tf.train.Saver()#建立会话with tf.Session() as sess:init_op = tf.global_variables_initializer()sess.run(init_op)for i in range(STEPS):xs, ys = mnist.train.next_batch(BATCH_SIZE) #加载BATCH_SIZE个mnist中的图片和标签_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})if i%1000 ==0:print('After %d training step, loss on training batch is %g' %(step,loss_value))# 保存模型到当前会话saver.save(sess,os.path.join(model_saver_path,model_name),global_step=global_step)def main():mnist =input_data.read_data_sets('./data/',one_hot=True)backward(mnist)if __name__ == '__main__':main()

模型测试

因为在训练时使用滑动平均,所以在测试时,需要恢复参数的滑动平均值。

 with tf.Graph().as_default() as g:  # 其内定义的节点在计算图g中

用这种方法,将神经网络复现到计算图中。

下面是:mnist_test.py

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backwardos.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # 忽略tensorflow警告信息def test(mnist):with tf.Graph().as_default() as g:x = tf.placeholder(tf.float32,[None,784])y_ = tf.placeholder(tf.float32,[None,10])y = mnist_forward.forward(x,None) # 前向传播获得的输出值#实例化带滑动平均的saver对象ema = tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)ema_restore = ema.variables_to_restore()saver = tf.train.Saver(ema_restore)# 准确率计算correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) # 将输出结果和标签答案进行比较accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))with tf.Session( ) as sess:ckpt = tf.train.get_checkpoint_state(mnist_backward.model_saver_path)#判断,如果有模型,恢复模型到当前会话if ckpt and ckpt.model_checkpoint_path:saver.restore(sess,ckpt.model_checkpoint_path) # 恢复模型global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] # 获取当前轮数# 喂入的是数据集中的,测试用的图片和标签accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('After %s training step, test accuracy = %g' %(global_step,accuracy_score))else:print('No checkpoint file found')def main():mnist = input_data.read_data_sets('./data/',one_hot=True)test(mnist)if __name__ == '__main__':main()

结果

训练结果

After 1 training step, loss on training batch is 2.7423
After 1001 training step, loss on training batch is 0.377551
After 2001 training step, loss on training batch is 0.243816

......

fter 36001 training step, loss on training batch is 0.144181
After 37001 training step, loss on training batch is 0.143775
After 38001 training step, loss on training batch is 0.14224
After 39001 training step, loss on training batch is 0.144268
After 40001 training step, loss on training batch is 0.142971
After 41001 training step, loss on training batch is 0.14316
After 42001 training step, loss on training batch is 0.142349
After 43001 training step, loss on training batch is 0.142429
After 44001 training step, loss on training batch is 0.140229
After 45001 training step, loss on training batch is 0.140057
After 46001 training step, loss on training batch is 0.138865
After 47001 training step, loss on training batch is 0.137823
After 48001 training step, loss on training batch is 0.138767
After 49001 training step, loss on training batch is 0.139221

测试结果

After 49001 training step, test accuracy = 0.98

如何实现断点续训?

实现断点续训,这样就可以在出现意外的情况下保存训练好的模型,下次训练,在此基础上进行。

上面的案例,若想要实现断点续训的功能,只需要在“反向传播”文件中,添加恢复模型的操作即可。

......
#建立会话with tf.Session() as sess:init_op = tf.global_variables_initializer()sess.run(init_op)#实现断点续训,只需要加入下面三句话ckpt = tf.train.get_checkpoint_state(model_saver_path)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)  # 恢复模型for i in range(STEPS):xs, ys = mnist.train.next_batch(BATCH_SIZE) #加载BATCH_SIZE个mnist中的图片和标签_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})if i%1000 ==0:
......

全连接网络:实现第一个全连接网络相关推荐

  1. 第二课.图卷积神经网络

    目录 卷积神经网络 图卷积神经网络 GNN数据集 图的表示 GCN GNN的基准化:Benchmarking Graph Neural Networks 卷积神经网络 在计算机视觉中,卷积网络是一种高 ...

  2. 第七课.简单的图像分类(一)

    第七课目录 图像分类基础 卷积神经网络 Pooling layer BatchNormalization BatchNormalization与归一化 torch.nn.BatchNorm2d MNI ...

  3. Paper:《Spatial Transformer Networks》的翻译与解读

    Paper:<Spatial Transformer Networks>的翻译与解读 目录 <Spatial Transformer Networks>的翻译与解读 Abstr ...

  4. CNN基本步骤以及经典卷积(LeNet、AlexNet、VGGNet、InceptionNet 和 ResNet)网络讲解以及tensorflow代码实现

    课程来源:人工智能实践:Tensorflow笔记2 文章目录 前言 1.卷积神经网络的基本步骤 1.卷积神经网络计算convolution 2.感受野以及卷积核的选取 3.全零填充Padding 4. ...

  5. 【Pytorch神经网络理论篇】 15 过拟合问题的优化技巧(二):Dropout()方法

    1 Dropout方法 2.1 Dropout原理 在训练过程中,每次随机选择一部分节点不去进行学习. 2.1.1 从Dropout原理来看过拟合的原因 任何一个模型不能完全把数据分开,在某一类中一定 ...

  6. 【Pytorch神经网络实战案例】08 识别黑白图中的服装图案(Fashion-MNIST)

    1 Fashion-MNIST简介 FashionMNIST 是一个替代 MNIST 手写数字集 的图像数据集. 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供.其涵盖了来自 10 ...

  7. PyTorch框架学习十七——Batch Normalization

    PyTorch框架学习十七--Batch Normalization 一.BN的概念 二.Internal Covariate Shift(ICS) 三.BN的一个应用案例 四.PyTorch中BN的 ...

  8. 深度学习知识点全面总结

    神经网络与深度学习结构(图片选自<神经网络与深度学习>一邱锡鹏) 目录 常见的分类算法 一.深度学习概念 1.深度学习定义 2.深度学习应用 3.深度学习主要术语 二.神经网络基础 1. ...

  9. 人工智能实践:tensorflow笔记

    tensorflow2.1安装教程,遇到的问题及解决办法 一.神经网络计算过程及模型搭建 (一)人工智能三学派: ​ 我们常说的人工智能,就是让机器具备人的思维和意识.人工智能主要有三个学派,即行为主 ...

最新文章

  1. 图灵奖得主杨立昆:人工智能比你更聪明吗?
  2. opencv 阈值分割 — threshold()
  3. Windows Phone 7 开发 31 日谈——第6日:工具栏
  4. 关于TCP和MQTT之间的转换
  5. NLP-基础知识-001
  6. php输出的数组如何存入表单,jquery:如何在jquery中将数组附加到表单请求并将其发送到php...
  7. 云效故障定位研究论文被ICSE 2021 SEIP track收录
  8. C++中流状态badbit, failbit, eofbit
  9. EntityFramework6.X 之 Fulent
  10. 使用Speedment 3.0.17及更高版本简化了事务
  11. java集合合并_【Java必修课】各种集合类的合并(数组、List、Set、Map)
  12. Emacs一个键绑定多个命令
  13. OpenCV(VS2019)——无法打开“opencv2/opencv.hpp”文件
  14. ubuntu Nvidia 显卡驱动失效问题
  15. php结合phantomjs实现网页截屏、抓取js渲染的页面
  16. 同时删除多个 Word 文档空白行
  17. 论文阅读笔记——野外和非侵入性遗传方法评估棕熊种群规模
  18. 快让你的App分20亿吧!
  19. 华硕无双+2022款笔记本重装系统笔记
  20. 你看到的就是真实的吗?

热门文章

  1. 如何开启红米手机4X的ROOT超级权限
  2. win10配置Sublime Text 3作为latex的编辑器
  3. 邮箱安全再成热点 金笛企业邮件系统保障企业用户通信安全
  4. 服务器保存时提示文档未保存,Word文档保存时常遇到的问题及其解决方法
  5. 彻底卸载Tomcat
  6. (最新整理)国内网页设计网站网址大全(转)
  7. 实现短信验证码有效时间
  8. CTF Crypto简单题学习思路总结(持续更新)
  9. 仿Android端饿了么外卖的效果
  10. 程序设计第二十二题 空心三角形