前向传播过程mnist_inference.py

import tensorflow as tf# 定义神经网络相关的参数
INPUT_NODE = 784
OUTPUT_NODE = 10def inference(inputs, dropout_keep_prob):x_image = tf.reshape(inputs, [-1, 28, 28, 1])# 第一层:卷积层conv1_weights = tf.get_variable("conv1_weights", [5, 5, 1, 32], initializer=tf.truncated_normal_initializer(stddev=0.1))  # 过滤器大小为5*5, 当前层深度为1, 过滤器的深度为32conv1 = tf.nn.conv2d(x_image, filter=conv1_weights, strides=[1, 1, 1, 1], padding='SAME')  # 移动步长为1, 使用全0填充conv1_biases = tf.get_variable("conv1_biases", [32], initializer=tf.constant_initializer(0.0))relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))  # 激活函数Relu去线性化# 第二层:最大池化层# 池化层过滤器的大小为2*2, 移动步长为2,使用全0填充pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') #输出14*14*32# 第三层:卷积层conv2_weights = tf.get_variable("conv2_weights", [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.1))  # 过滤器大小为5*5, 当前层深度为32, 过滤器的深度为64conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')  # 移动步长为1, 使用全0填充conv2_biases = tf.get_variable("conv2_biases", [64], initializer=tf.constant_initializer(0.0))relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))# 第四层:最大池化层# 池化层过滤器的大小为2*2, 移动步长为2,使用全0填充pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') #输出7*7*64# 第五层:全连接层pool2_vector = tf.reshape(pool2, [-1, 7 * 7 * 64])fc1_weights = tf.get_variable("fc1_weights", [7 * 7 * 64, 1024], initializer=tf.truncated_normal_initializer(stddev=0.1))  # 7*7*64=3136把前一层的输出变成特征向量fc1_baises = tf.get_variable("fc1_baises", [1024], initializer=tf.constant_initializer(0.1))fc1 = tf.nn.relu(tf.matmul(pool2_vector, fc1_weights) + fc1_baises)# 为了减少过拟合,加入Dropout层fc1_dropout = tf.nn.dropout(fc1, dropout_keep_prob)# 第六层:全连接层fc2_weights = tf.get_variable("fc2_weights", [1024, 10], initializer=tf.truncated_normal_initializer(stddev=0.1))  # 神经元节点数1024, 分类节点10fc2_biases = tf.get_variable("fc2_biases", [10], initializer=tf.constant_initializer(0.1))fc2 = tf.matmul(fc1_dropout, fc2_weights) + fc2_biasesreturn fc2

训练mnist_train.py

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference#
BATCH_SIZE = 100
#学习率
LEARN_RATE = 0.001
MODEL_SAVE_PATH = "model/"
MODEL_NAME = "model.ckpt"
EPOCH = 2def train(mnist):inputs = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE])labels = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE])dropout_keep_prob = tf.placeholder(tf.float32)logits = mnist_inference.inference(inputs, dropout_keep_prob)global_step = tf.Variable(0, trainable=False)cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)#tf.nn.sparse_softmax_cross_entropy_with_logitscost = tf.reduce_mean(cross_entropy)train_op = tf.train.AdamOptimizer(LEARN_RATE).minimize(cost, global_step=global_step)saver = tf.train.Saver()with tf.Session() as sess:tf.global_variables_initializer().run()print(mnist.train.images.shape)for i in range(20000):batch_inputs, batch_labels = mnist.train.next_batch(BATCH_SIZE)_, cost_value, step = sess.run([train_op, cost, global_step], feed_dict={inputs: batch_inputs, labels: batch_labels, dropout_keep_prob:0.5})if i % 1000 == 0:print("After %d training step(s), loss on training batch is %f." % (step, cost_value))saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)def main(argv=None):mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)train(mnist)if __name__ == '__main__':tf.app.run()

评估mnis_eval.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_traindef evaluate(mnist):inputs = tf.placeholder(tf.float32, [None, 784])labels = tf.placeholder(tf.float32, [None, 10])dropout_keep_prob = tf.placeholder(tf.float32)logits = mnist_inference.inference(inputs, dropout_keep_prob)print(logits)correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))saver = tf.train.Saver()with tf.Session() as sess:ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]accuracy_score = sess.run(accuracy, feed_dict={inputs: mnist.test.images, labels: mnist.test.labels, dropout_keep_prob:1.0})print("After %s training step(s), validation accuracy = %f" % (global_step, accuracy_score))else:print("No checkpoint file found")returndef main(argv=None):mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)evaluate(mnist)if __name__ == '__main__':tf.app.run()

TensorFlow MNIST LeNet 模型持久化相关推荐

  1. python实现lenet_吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集...

    importtensorflow as tf tf.reset_default_graph()#配置神经网络的参数 INPUT_NODE = 784OUTPUT_NODE= 10IMAGE_SIZE= ...

  2. TensorFlow模型持久化

    模型持久化的目的在于可以使模型训练后的结果重复使用,节省重复训练模型的时间. 模型保存 train.Saver类是TensorFlow提供的用于保存和还原模型的API,使用非常简单. import t ...

  3. python与机器学习(七)上——PyTorch搭建LeNet模型进行MNIST分类

    任务要求:利用PyTorch框架搭建一个LeNet模型,并针对MNIST数据集进行训练和测试. 数据集:MNIST 导入: import torch from torch import nn, opt ...

  4. tensorflow模型持久化方法

    #测试模型持久化 v1 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v1') v2 = tf.Variable(tf.constant(1.,sha ...

  5. 奔跑吧Caffe(在MNIST手写体数字集上用Caffe框架训练LeNet模型)

    数据集背景: MNIST 是一个大型的手写体数字数据库,广泛应用于机器学习领域的训练和测试,由纽约大学Yann LeCun教授整理. MNIST包括60000个训练集和10000测试集,图片固定尺寸为 ...

  6. tensorflow随笔——LeNet网络

    最近总想写点东西,把以前的这些网络都翻出来自己实现一遍.计划上从经典的分类网络到目标检测再到目标分割的都过一下.这篇从最简单的LeNet写起. 先上一张经典的LeNet模型结果图: 该网络结构包含2个 ...

  7. Tensorflow MNIST for Android

    本篇博客主要介绍如何使用 tensorflow 通过 CNN 实现 MNIST 手写数字识别问题,并将模型持久化在Android端运行. 整体介绍 主要需要通过以下几步: 模型生成过程:使用 tens ...

  8. Apache Spark 2.0预览: 机器学习模型持久化

    在即将发布的Apache Spark 2.0中将会提供机器学习模型持久化能力.机器学习模型持久化(机器学习模型的保存和加载)使得以下三类机器学习场景变得容易: \\ 数据科学家开发ML模型并移交给工程 ...

  9. 训练MNIST数据集模型

    1. 数据集准备 详细信息见: Caffe: LMDB 及其数据转换 mnist是一个手写数字库,由DL大牛Yan LeCun进行维护.mnist最初用于支票上的手写数字识别, 现在成了DL的入门练习 ...

最新文章

  1. AI学习与进阶实践-基于行业价值的AI学习与进阶路径
  2. python自动下载邮件_python实现邮件自动化
  3. 迪美特TVZ8双核智能高清播放器 在电视上编程不是梦
  4. 爱因斯坦耗费近十年的最伟大研究,推导出什么神预言?
  5. 文本以大写字母html,如何强制EditText以大写字母开始文本?
  6. python超级关系_不可阻挡的超级语言--python
  7. 使用 Lightbox 2 和 JavaScript 构建出色的图片库
  8. 知其所以然技术论坛VC++资源下载
  9. Excel数据透视表:查看数据的频率分布
  10. 华为手机自带浏览器的显示问题
  11. linux防火墙reject,linux 防火墙配置与REJECT导致没有生效问题(示例代码)
  12. BAT自动校对时间脚本,让WINDOWS系统自动校对时间
  13. @RunWith(SpringRunner.class)和@RunWith(SpringJUnit4ClassRunner.class)的区别
  14. 【论文笔记】CS会议论文书写注意点
  15. 准标准模式和标准模式之间的差别-1(旧文首发)
  16. appium报错:Original error: socket hang up
  17. 云渲染一张图贵吗?渲染问题详解
  18. Gate联合NFTBomb七大活动,NBP“holder”的福音
  19. 金蝶KIS记账王光盘版 双12五折特惠
  20. 一分钟让你了解什么是CYN

热门文章

  1. lvs dr模式安装
  2. HDU - 7091 重叠的子串(后缀自动机+set启发式合并+树上倍增)
  3. 中石油训练赛 - Watch Later(状压dp)
  4. POJ - 1459 Power Network(网络流-最大流)
  5. linux eclipse报错日志,centos6.8命令行启动eclipse报org.eclipse.swt.SWTError错误
  6. 基于TCP的Socket通讯
  7. 内核隐藏进程(源码)
  8. python操作redis用法详解
  9. 用Python实现冒泡排序
  10. C++学习路线和参考资料