第一篇学习了CNN网络的构建以及代码的基础结构,第二篇则是实际项目过程中需要的网络模型的存储

先放上存储的代码:

#tf可以认为是全局变量,从该变量为类,从中取input_data变量
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
import sys
#读取数据集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
"""
#softmax方法进行训练
#这里是变量的占位符,一般是输入输出使用该部分
x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder("float",[None,10])#定义参数变量
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
y=tf.nn.softmax(tf.matmul(x,W)+b)#评价函数
cross_entropy=-tf.reduce_sum(y_*tf.log(y))
train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#启动模型,Session建立这样一个对象,然后指定某种操作,并实际进行该步
init=tf.initialize_all_variables()
sess=tf.Session()
sess.run(init)#数据读取部分
for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(50)#run第一个参数是fetch,可以是tensor也可以是Operation,第二个feed_dict是替换tensor的值sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})print(batch_xs,batch_ys,i)correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))"""#这里用CNN方法进行训练
#函数定义部分
def weight_variable(shape):initial=tf.truncated_normal(shape,stddev=0.1)#随机权重赋值,不过truncated_normal代表如果是2倍标准差之外的结果重新选取该值return tf.Variable(initial)def bias_variable(shape):initial=tf.constant(0.1,shape=shape)#偏置项return tf.Variable(initial)def conv2d(x,W):return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#SAME表示输出补边,这里输出与输入尺寸一致def max_pool_2x2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')#ksize代表池化范围的大小,stride为扫描步长# 这里是变量的占位符,一般是输入输出使用该部分
x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder("float",[None,10])
x_image=tf.reshape(x,[-1,28,28,1])#-1表示自动计算该维度
#建立第一层
W_conv1=weight_variable([5,5,1,32])
b_conv1=bias_variable([32])
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1)
#第二层
W_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2)#第三层,而且这里是全连接层
W_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
#dropout,注意这里也是有一个输入参数的,和x以及y一样
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)# 评价函数
cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv))
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# 启动模型,Session建立这样一个对象,然后指定某种操作,并实际进行该步
init=tf.initialize_all_variables()
sess=tf.Session()
sess.run(init)#数据读取部分
for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(50)#这里貌似是代表读取50张图像数据#run第一个参数是fetch,可以是tensor也可以是Operation,第二个feed_dict是替换tensor的值'''if i % 10 == 0:train_accuracy = accuracy.eval(feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})print("step:%d,accuracy:%g" % (i, train_accuracy))'''sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})#sess.run第一个参数是想要运行的位置,一般有train,accuracy,initdeng#第二个参数feed_dict,一般是输入参数,该代码里有x,y以及drop的参数if i%20==0 :print(i)print("train accuracy:%g"%sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5}))
print("test accuracy:%g"%sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1}))#保存模型
model_path='MNIST_model/simple_mnist.ckpt'
saver=tf.train.Saver()
saver_path=saver.save(sess,model_path)
print("model saved in file:", saver_path)

保存的代码实际上只有后半部分,前面的代码是第一篇中讲到的。第一篇链接:https://blog.csdn.net/qq_26499769/article/details/82896046

运行代码的结果如下:

读取的代码如下:

import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
import sys
#读取数据集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)#前半部分主要以网络的构建为主
#这里用CNN方法进行训练
#函数定义部分
def weight_variable(shape):initial=tf.truncated_normal(shape,stddev=0.1)#随机权重赋值,不过truncated_normal代表如果是2倍标准差之外的结果重新选取该值return tf.Variable(initial)def bias_variable(shape):initial=tf.constant(0.1,shape=shape)#偏置项return tf.Variable(initial)def conv2d(x,W):return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#SAME表示输出补边,这里输出与输入尺寸一致def max_pool_2x2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')#ksize代表池化范围的大小,stride为扫描步长# 这里是变量的占位符,一般是输入输出使用该部分
x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder("float",[None,10])
x_image=tf.reshape(x,[-1,28,28,1])#-1表示自动计算该维度
#建立第一层
W_conv1=weight_variable([5,5,1,32])
b_conv1=bias_variable([32])
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1)
#第二层
W_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2)#第三层,而且这里是全连接层
W_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
#dropout,注意这里也是有一个输入参数的,和x以及y一样
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)# 评价函数
cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv))
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# 启动模型,Session建立这样一个对象,然后指定某种操作,并实际进行该步sess=tf.Session()#后半部分,进行参数的下载
#模型下载,(新人,可能理解错误,网络还是需要先定义好,然后进行参数的下载,对于自己的网络这样的方法没有问题,但是他人的网络在不知道具体的网络时,没办法通过下载去复现网络模型)
saver=tf.train.Saver()
saver.restore(sess,"MNIST_model/simple_mnist.ckpt")
print("test accuracy:%g"%sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1}))

读取代码也只有最后的一部分,运行结果如下:

这是之前CNN的训练结果,可以看到完整的下载了参数。

有一个说明的很好的教程:https://www.bilibili.com/video/av16001891/?p=29

tensorflow学习(2.网络模型的存储以及提取)相关推荐

  1. tensorflow 1.0 学习:参数和特征的提取

    tensorflow 1.0 学习:参数和特征的提取 在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如: #取出所有参与训练的参数 params=tf.tra ...

  2. tensorflow学习入门笔记

    <div class="note"><div class="post"><div class="article" ...

  3. Win10:tensorflow学习笔记(4)

    前言 学以致用,以学促用.输出检验,完整闭环. 经过前段时间的努力,已经在电脑上搭好了深度学习系统,接下来就要开始跑程序了,将AI落地了. 安装win10下tensforlow 可以参照之前的例子:w ...

  4. 深度学习框架tensorflow学习与应用——代码笔记11(未完成)

    11-1 第十周作业-验证码识别(未完成) #!/usr/bin/env python # coding: utf-8# In[1]:import os import tensorflow as tf ...

  5. 【tensorflow学习】Ftrl学习

    [tensorflow学习]处理MNISTS数据集 理论 应用 To Do 理论 理论知识 交叉熵理解 应用 #encoding=utf8 import tensorflow as tf import ...

  6. 炼数成金Tensorflow学习笔记之2.2_变量

    炼数成金Tensorflow学习笔记之2.2_变量 代码及分析 代码及分析 import tensorflow as tfx = tf.Variable([1, 2]) a = tf.constant ...

  7. TensorFlow学习笔记02:使用tf.data读取和保存数据文件

    TensorFlow学习笔记02:使用tf.data读取和保存数据文件 使用`tf.data`读取和写入数据文件 读取和写入csv文件 写入csv文件 读取csv文件 读取和保存TFRecord文件 ...

  8. TensorFlow学习之LSTM ---机器翻译(Seq2Seq + Attention模型)

    一.理论知识 Seq2Seq模型的基本思想:使用一个循环神经网络读取输入句子,将这个句子的信息压缩到一个固定维度的编码中:再使用另一个循环神经网络读取这个编码,将其"解压"为目标语 ...

  9. Tensorflow学习五---全连接

    Tensorflow学习五-全连接 什么是全连接 今天开始我们以神经层的形式开始介绍,首先介绍全连接层. 那么什么是全连接? 全连接层就是对n-1层和n层而言,n-1层的任意一个节点,都和第n层所有节 ...

最新文章

  1. vue 按需加载,换存,导航守卫
  2. SAP的标准对话框函数
  3. linux内核启动时间优化
  4. QT的QSGGeometry类的使用
  5. Mysql5.7后的password加密和md5
  6. node 测试生成模拟用户数据
  7. 对话框应用程序的DoModal()源代码
  8. cmake cache变量_反复研究好几遍,我才发现关于 CMake 变量还可以这样理解!
  9. Log-Polar——关于对数极坐标
  10. php-rabbitmq结合rabbitmq_delayed_message_exchange实现延时队列
  11. 数据库课程设计:图书借阅管理系统(控制台交互)
  12. Dreamweaver cs6 网页设计教程笔记
  13. 众里寻他千百度,蓦然回首,那人却在灯火阑珊处
  14. oracle redo查询,ORACLE UNDO REDO查看
  15. win7 计算机 只有硬盘分区,电脑只有一个C盘怎么办?一招教你正确分区!-win7磁盘分区...
  16. JVM类加载、验证、准备、解析、初始化、卸载过程详解
  17. 佳能canon e510打印机驱动 1.0 官E510 series XPS 打印机驱动程序 v. 5.62 (Windows)
  18. 转载:揭秘内容付费的三种商业模式(原作者:小马宋)
  19. Servlet本身的init,service,destory生命周期方法
  20. 打听nofollow标签能力做好网站seo优化

热门文章

  1. 线上 4 台机器同一时间全部 OOM,到底发生了什么?
  2. SpringBoot 项目模板:摆脱步步搭建
  3. 你值得拥有!一个基于 Spring Boot 的API、RESTful API 的项目
  4. 十个经典Java 集合面试题!
  5. org.activiti.engine.ActivitiException: Couldn‘t deserialize object in variable ‘application‘
  6. 04-JDBC学习手册:JDBC中使用transaction(事务)编程和Javabean定义
  7. (Spring)静态/动态代理模式(AOP底层)
  8. (Mybatis)增删改查实现
  9. (JavaWeb)ServletContext对象
  10. Eclipse中spring boot的安装和创建简单的Web应用