一、复用流程

如果原始模型使用 TensorFlow 进行训练,则可以简单地将其恢复并在新任务上进行训练:

第一步:

[...] # construct the original model

第二步:

with tf.Session() as sess:
            saver.restore(sess, "./my_model_final.ckpt")

# continue training the model...

二、完整代码

n_inputs = 28 * 28  # MNIST
n_hidden1 = 300
n_hidden2 = 50
n_hidden3 = 50
n_hidden4 = 50
n_outputs = 10

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")

with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1")
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2")
    hidden3 = tf.layers.dense(hidden2, n_hidden3, activation=tf.nn.relu, name="hidden3")
    hidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4")
    hidden5 = tf.layers.dense(hidden4, n_hidden5, activation=tf.nn.relu, name="hidden5")
    logits = tf.layers.dense(hidden5, n_outputs, name="outputs")

with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")

with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

learning_rate = 0.01
threshold = 1.0

optimizer = tf.train.GradientDescentOptimizer(learning_rate)
grads_and_vars = optimizer.compute_gradients(loss)
capped_gvs = [(tf.clip_by_value(grad, -threshold, threshold), var)
              for grad, var in grads_and_vars]
training_op = optimizer.apply_gradients(capped_gvs)

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "./my_model_final.ckpt")

for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,
                                                y: mnist.test.labels})

print(epoch, "Test accuracy:", accuracy_val)

save_path = saver.save(sess, "./my_new_model_final.ckpt")

但是,一般情况下,您只需要重新使用原始模型的一部分(就像我们将要讨论的那样)。 一个简单的解决方案是将Saver配置为仅恢复原始模型中的一部分变量。 例如,下面的代码只恢复隐藏的层1,2和3:

n_inputs = 28 * 28  # MNIST
n_hidden1 = 300 # reused
n_hidden2 = 50  # reused
n_hidden3 = 50  # reused
n_hidden4 = 20  # new!
n_outputs = 10  # new!

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")

with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1")       # reused
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2") # reused
    hidden3 = tf.layers.dense(hidden2, n_hidden3, activation=tf.nn.relu, name="hidden3") # reused
    hidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4") # new!
    logits = tf.layers.dense(hidden4, n_outputs, name="outputs")                         # new!

with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")

with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    training_op = optimizer.minimize(loss)
[...] # build new model with the same definition as before for hidden layers 1-3 
reuse_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                               scope="hidden[123]") # regular expression

reuse_vars_dict = dict([(var.op.name, var) for var in reuse_vars])
restore_saver = tf.train.Saver(reuse_vars_dict) # to restore layers 1-3

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    init.run()
    restore_saver.restore(sess, "./my_model_final.ckpt")

for epoch in range(n_epochs):                                      # not shown in the book
        for iteration in range(mnist.train.num_examples // batch_size): # not shown
            X_batch, y_batch = mnist.train.next_batch(batch_size)      # not shown
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})  # not shown
        accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,  # not shown
                                                y: mnist.test.labels}) # not shown
        print(epoch, "Test accuracy:", accuracy_val)                   # not shown

save_path = saver.save(sess, "./my_new_model_final.ckpt")

首先我们建立新的模型,确保复制原始模型的隐藏层 1 到 3。我们还创建一个节点来初始化所有变量。 然后我们得到刚刚用trainable = True(这是默认值)创建的所有变量的列表,我们只保留那些范围与正则表达式hidden [123]相匹配的变量(即,我们得到所有可训练的隐藏层 1 到 3 中的变量)。 接下来,我们创建一个字典,将原始模型中每个变量的名称映射到新模型中的名称(通常需要保持完全相同的名称)。 然后,我们创建一个Saver,它将只恢复这些变量,并且创建另一个Saver来保存整个新模型,而不仅仅是第 1 层到第 3 层。然后,我们开始一个会话并初始化模型中的所有变量,然后从原始模型的层 1 到 3中恢复变量值。最后,我们在新任务上训练模型并保存。

任务越相似,您可以重复使用的层越多(从较低层开始)。 对于非常相似的任务,您可以尝试保留所有隐藏的层,只替换输出层

复用 TensorFlow 模型相关推荐

  1. TensorFlow模型的签名推荐与快速上线\n

    简介 往期文章 我们给你推荐一种TensorFlow模型格式 介绍过, TensorFlow官方推荐SavedModel格式作为在线服务的模型文件格式.近期TensorFlow SavedModel模 ...

  2. 干货 | tensorflow模型导出与OpenCV DNN中使用

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自|OpenCV学堂 OpenCV DNN模块 Deep N ...

  3. keras添加正则化全连接_收藏!改善TensorFlow模型的4种方法你需要了解的关键正则化技术(2)...

    上一篇文章和同学们分享了两种方法,今天我们继续分享另外两种方法. Batch Normalization 批处理规范化背后的主要思想是,在我们的案例中,我们通过使用几种技术(sklearn.prepr ...

  4. 简单完整地讲解tensorflow模型的保存和恢复

    http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要讲到: 1. 什么是Tensorflow模型? 2. 如何保存Tenso ...

  5. 打成jar包_keras, tensorflow模型部署通过jar包部署到spark环境攻略

    这是个我想干很久的事情了.之前研究tensorflow on spark, DL4j 都没有成功.所以这里首先讲一下我做这件事情的流程.模型的部署,首先你得有一个模型.这里假设你有了一个keras模型 ...

  6. Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

    我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...

  7. 移动端目标识别(1)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之TensorFlow Lite简介...

    平时工作就是做深度学习,但是深度学习没有落地就是比较虚,目前在移动端或嵌入式端应用的比较实际,也了解到目前主要有 caffe2,腾讯ncnn,tensorflow,因为工作用tensorflow比较多 ...

  8. 手把手教你使用TF服务将TensorFlow模型部署到生产环境

    2019独角兽企业重金招聘Python工程师标准>>> 介绍 将机器学习(ML)模型应用于生产环境已成为一个火热的的话题,许多框架提供了旨在解决此问题的不同解决方案.为解决这一问题, ...

  9. 实操将TensorFlow模型部署成Docker服务化

    背景 深度学习模型如何服务化是一个机器学习领域工程方面的热点,现在业内一个比较主流的做法是将模型和模型的服务环境做成docker image.这样做的一个好处是屏蔽了模型对环境的依赖,因为深度学习模型 ...

最新文章

  1. 红帽RHEL6.8离线环境下升级到RHEL7.3
  2. golang mysql 超时_golang中mysql建立连接超时时间timeout 测试
  3. [Unity C#教程] 游戏对象和脚本
  4. Android向本地写入一个XML文件和解析XML文件
  5. 渗透测试中dns log的使用
  6. linux那些事之gup_flags
  7. 零基础直接学Python入门IT合适吗?
  8. 遇到代码缺陷不要慌,马上教你快速检测和修复
  9. SDUT 3399 数据结构实验之排序二:交换排序
  10. .NET组件和COM组件之间的相互操作方法
  11. 单链表逆置-java(递归与非递归)
  12. 各省简称 拼音 缩写_近50个拼音/英文缩写合集 (一)
  13. 自己做量化交易软件(26)小白量化事件回测之MetaTrader5自动回测
  14. 实验五 CA的安装和使用
  15. 【C语言】请将1至7中的任意一个数字转化成对应的英文星期几的前三个字母,如1转化为Mon,7转化为Sun等。 个人解答
  16. 统计学中cv表示什么_cv是什么意思
  17. C#范例开发大全.刘丽霞李俊民(奋斗的小鸟)_PDF 电子书
  18. Linux(2)---Crtl+z与Crtl+c
  19. 如何恢复vscode的默认配置_史上最全vscode配置使用教程
  20. 鼠标移动到图片上实现图片的放大缩小

热门文章

  1. python中os.path.isdir()等函数的作用和用法
  2. 服务器异常下电文件系统,SUN服务器Solaris异常情况下恢复操作步骤(8页)-原创力文档...
  3. linux ubuntu pkg-config工具的使用(源代码编译库接口查询工具)
  4. cmake教程(为什么要用cmake?)(cmake编译opencv)(就是个跨平台的编译工具Linux、windows)(很重要,必须得学)(报错解决方案)opencv编译
  5. 图像处理中的“内插”是什么?插值、图像内插值、图像间插值、重取样(用已知数据来估计未知位置的数值的处理)(最近邻内插法、双线性内插)
  6. python PyQt5 QComboBox类(下拉列表框、组合下拉框)
  7. python hashlib模块(提供常见摘要算法)
  8. 图像的亮度和对比度区别
  9. Python 科学计算库 Numpy(一)—— 概述
  10. 计算机应用基础实训任务书,《计算机应用基础》任务书