正文共3565个字,预计阅读时间10分钟。

上海站 | 高性能计算之GPU CUDA培训

4月13-15日

三天密集式训练 带你快速晋级
阅读全文
>

训练一个神经网络的目的是啥?不就是有朝一日让它有用武之地吗?可是,在别处使用训练好的网络,得先把网络的参数(就是那些variables)保存下来,怎么保存呢?其实,tensorflow已经给我们提供了很方便的API,来帮助我们实现训练参数的存储与读取,如果想了解详情,请看晦涩难懂的官方API,接下来我简单介绍一下我的理解。

保存与读取数据全靠下面这个类实现:

class tf.train.Saver

当我们需要存储数据时,下面2条指令就够了

saver = tf.train.Saver() save_path = saver.save(sess, model_path)

解释一下,首先创建一个saver类,然后调用saver的save方法(函数),save需要传递两个参数,一个是你的训练session,另一个是文件存储路径,例如“/tmp/superNet.ckpt”,这个存储路径是可以包含文件名的。save方法会返回一个存储路径。当然,save方法还有别的参数可以传递,这里不再介绍。

然后怎么读取数据呢?看下面

saver = tf.train.Saver() load_path = saver.restore(sess, model_path)

和存储数据神似啊!不再赘述。

下面是重点!关于tf.train.Saver()使用的几点小心得!

1、save方法在实现数据读取时,它仅仅读数据,关键是得有一些提前声明好的variables来接受这些数据,因此,当save读取数据到sess时,需要提前声明与数据匹配的variables,否则程序就报错了。

2、save读取的数据不需要initialize。

3、目前想到的就这么多,随时补充。

为了对数据存储和读取有更直观的认识,我自己写了两个实验小程序,下面是第一个,训练网络并存储数据,用的MNIST数据集

import tensorflow as tf

import sys

# load MNIST data

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data', one_hot=True)

# 一些 hyper parameters

activation = tf.nn.relu batch_size = 100

iteration = 20000

hidden1_units = 30

# 注意!这里是存储路径!

model_path = sys.path[0] + '/simple_mnist.ckpt'

X = tf.placeholder(tf.float32, [None, 784]) y_ = tf.placeholder(tf.float32, [None, 10]) W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2)) b_fc1 = tf.Variable(tf.zeros([hidden1_units])) W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2)) b_fc2 = tf.Variable(tf.zeros([10]))

def inference(img):    fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))    logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)

return logits

def loss(logits, labels):    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, labels)    loss = tf.reduce_mean(cross_entropy)

return loss

def evaluation(logits, labels):

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

return accuracy logits = inference(X) loss = loss(logits, y_) train_op = tf.train.AdamOptimizer(1e-4).minimize(loss) accuracy = evaluation(logits, y_)

# 先实例化一个Saver()类saver = tf.train.Saver() init = tf.initialize_all_variables()

with tf.Session() as sess:    sess.run(init)    for i in xrange(iteration):        batch = mnist.train.next_batch(batch_size)

if i%1000 == 0 and i:            train_accuracy = sess.run(accuracy, feed_dict={X: batch[0], y_: batch[1]})

print "step %d, train accuracy %g" %(i, train_accuracy)

sess.run(train_op, feed_dict={X: batch[0], y_: batch[1]})

print '[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels})

# 存储训练好的variables

save_path = saver.save(sess, model_path)

print "[+] Model saved in file: %s" % save_path

接下来是读取数据并做测试!

import tensorflow as tf

import sys

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data', one_hot=True)

activation = tf.nn.relu hidden1_units = 30

model_path = sys.path[0] + '/simple_mnist.ckpt'

X = tf.placeholder(tf.float32, [None, 784]) y_ = tf.placeholder(tf.float32, [None, 10]) W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2)) b_fc1 = tf.Variable(tf.zeros([hidden1_units])) W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2)) b_fc2 = tf.Variable(tf.zeros([10]))

def inference(img):    fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))    logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)

return logits

def evaluation(logits, labels):    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

return accuracy logits = inference(X) accuracy = evaluation(logits, y_) saver = tf.train.Saver()with tf.Session() as sess:

# 读取之前训练好的数据    load_path = saver.restore(sess, model_path)

print "[+] Model restored from %s" % load_path

print '[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.

原文链接:https://www.jianshu.com/p/83fa3aa2d0e9

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org

请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看


LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础

存储Tensorflow训练网络的参数相关推荐

  1. tensorflow 模型预训练后的参数restore finetuning

    之前训练的网络中有一部分可以用到一个新的网络中,但是不知道存储的参数如何部分恢复到新的网络中,也了解到有许多网络是通过利用一些现有的网络结构,通过finetuning进行改造实现的,因此了解了一下关于 ...

  2. 深度网络的训练经验总结(参数篇)

      (续我的上一篇博客)最早训练神经网络的时候完全什么技巧都不懂,能成功运行开源代码,并且看到loss下降就放心跑着了.随着对网络越来越多的接触,发现从数据集(train/val/test)的准备到训 ...

  3. 使用预训练网络训练的两种方式:Keras Applications、TensorFlow Hub

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) tensorflow 2.0 画出model网络模型的拓扑图 ...

  4. 基于FPGA的一维卷积神经网络CNN的实现(三)训练网络搭建及参数导出(附代码)

    训练网络搭建 环境:Pytorch,Pycham,Matlab. 说明:该网络反向传播是通过软件方式生成,FPGA内部不进行反向传播计算. 该节通过Python获取训练数据集,并通过Pytorch框架 ...

  5. caffe框架训练网络参数详解

    建立train_sh脚本文件 #!/sur/bin/env sh set -e /home/yourname/caffe/build/tools/caffe train --solver=/home/ ...

  6. 【深度学习】Weight Normalization: 一种简单的加速深度网络训练的重参数方法

    前言:为什么要Normalization 深度学习是一种在给定数据的情况下,学习求解目标函数最小化或者最大化的模型.在深度网络中,模型参数往往包含了大量的weights和biases.在求解优化模型的 ...

  7. CVPR 2023 点云系列 | Point-NN无需训练的非参数、即插即用网络

    CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 我们提出了一种用于 3D 点云分析的非参数网络 Point-NN,它由纯不可学习的组件组成:最远点采样 ...

  8. 使用PaddleFluid和TensorFlow训练RNN语言模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  9. 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...

最新文章

  1. 《结对-贪吃蛇-设计文档》
  2. 求字典key的和python_python怎么将字典key相同的value值, 合并
  3. Hive 元数据库表信息
  4. ZZULIOJ 1128: 课程平均分
  5. TypeScript入坑
  6. 中国条码解码器市场趋势报告、技术动态创新及市场预测
  7. 线性回归的简洁实现(pytorch框架)
  8. 白鸦:我印象中的Keso
  9. 1.MATLAB简要介绍
  10. html关机命令,shutdown关机命令不起作用
  11. 昨天面了一位,见识到了Spring的天花板~
  12. 抖音自媒体是如何赚钱的,怎么做才能挣到更多的钱?
  13. 2020软件构造实验三
  14. 将ipad作为Windows10系统的的扩展显示屏
  15. Android 应用程序开发
  16. 深克隆和浅克隆的区别
  17. vivo手机便签如何快速彻底一键换机使用?
  18. css3实现向一个方向无缝连接滚动
  19. 2010-2019考研英语二 阅读真题+答案
  20. ROS机器人实践---小乌龟画圆

热门文章

  1. 北航 计算机学院 2011级学生会,北航学生会主席在2011级新生开学典礼发言稿.doc...
  2. mysql数据库搜索引擎要先进入_Mysql搜索引擎都有哪些区别
  3. 各纬度气候分布图_印度和中国都是季风气候显著的国家,但冬夏季风的强弱却完全不同...
  4. 批量修改linux换行格式,linux中sed命令批量修改
  5. poi生成word不可以修改_操作不懂技术就可以做小程序无限生成平台的创业项目实操教程...
  6. java.nio教程_Java NIO系列教程(三) Buffer
  7. python增量爬虫_python增量爬虫pyspider
  8. 大数据:技术与应用实践指南_大数据技术与应用社团 社会实践总结篇
  9. tshark mysql_使用tshark抓包分析http请求
  10. 中采购订单批导的bapi_五:认识SAP SD销售模式之第三方销售和单独采购