TensorFlow 笔记6–迁移学习


参考文档:https://github.com/ageron/handson-ml/blob/master/11_deep_learning.ipynb


一、冻结部分层权重

法一:

with tf.name_scope("train"):                                        optimizer = tf.train.GradientDescentOptimizer(learning_rate)# 指定要训练的那部分层train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="hidden[34]|outputs")training_op = optimizer.minimize(loss, var_list=train_vars)# 恢复冻结层的数据,其实也可以全部恢复
reuse_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="hidden[12]")
restore_saver = tf.train.Saver(reuse_vars) with tf.Session() as sess:restore_saver.restore(sess, "./my_model_final.ckpt")

法二:

with tf.name_scope("dnn"):hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1") # reused frozenhidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2") # reused frozen# 在此之前的层不会进行梯度更新hidden2_stop = tf.stop_gradient(hidden2)# 注意以下的层要相应的修改为hidden2_stophidden3 = tf.layers.dense(hidden2_stop, n_hidden3, activation=tf.nn.relu, name="hidden3") # reused, not frozenhidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4") # new!logits = tf.layers.dense(hidden4, n_outputs, name="outputs") # new!# 剩下的和正常的一样

二、缓存冻结层结果

# 先设置冻结层,再进行以下操作with tf.Session() as sess:init.run()restore_saver.restore(sess, "./my_model_final.ckpt")# 缓存冻结层的结果,即训练期间只计算一次h2_cache = sess.run(hidden2, feed_dict={X: X_train})h2_cache_valid = sess.run(hidden2, feed_dict={X: X_valid}) for epoch in range(n_epochs):shuffled_idx = np.random.permutation(len(X_train))# feed的数据应该相应的改为冻结层的结果hidden2_batches = np.array_split(h2_cache[shuffled_idx], n_batches)y_batches = np.array_split(y_train[shuffled_idx], n_batches)for hidden2_batch, y_batch in zip(hidden2_batches, y_batches):sess.run(training_op, feed_dict={hidden2:hidden2_batch, y:y_batch})accuracy_val = accuracy.eval(feed_dict={hidden2: h2_cache_valid,  y: y_valid})      print(epoch, "Validation accuracy:", accuracy_val)               save_path = saver.save(sess, "./my_new_model_final.ckpt")

TensorFlow 笔记6--迁移学习相关推荐

  1. 吴恩达深度学习笔记(67)-迁移学习(Transfer learning)

    https://www.toutiao.com/a6644868806923518471/ 2019-01-11 07:36:41 迁移学习(Transfer learning) 深度学习中,最强大的 ...

  2. Tensorflow 2.1 迁移学习 基于VGG

    1. 什么是迁移学习 迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相关任务 ...

  3. 【学习笔记】迁移学习分类

    什么是迁移学习 通俗来讲,就是运用已有的知识来学习新的知识,核心是找到已有知识和新知识之间的相似性,用成语来说就是举一反三.由于直接对目标域从头开始学习成本太高,我们故而转向运用已有的相关知识来辅助尽 ...

  4. Tensorflow官网——迁移学习和微调部分解读

    import matplotlib.pyplot as plt import numpy as np import os import tensorflow as tf#数据预处理 #数据下载 fro ...

  5. 绒毛动物探测器:通过TensorFlow.js中的迁移学习识别浏览器中的自定义对象

    目录 起点 MobileNet v1体系结构上的迁移学习 修改模型 训练新模式 运行物体识别 终点线 下一步是什么?我们可以检测到脸部吗? 下载TensorFlowJS-Examples-master ...

  6. 2019年上半年收集到的人工智能迁移学习干货文章

    2019年上半年收集到的人工智能迁移学习干货文章 迁移学习全面指南:概念.项目实战.优势.挑战 迁移学习:该做的和不该做的事 深度学习不得不会的迁移学习Transfer Learning 谷歌最新的P ...

  7. 宠物狗图片分类之迁移学习代码笔记

    五月两场 | NVIDIA DLI 深度学习入门课程 5月19日/5月26日一天密集式学习  快速带你入门阅读全文> 正文共3152个字,预计阅读时间8分钟. 本文主要是总结之前零零散散抽出时间 ...

  8. 迁移学习笔记3: TCA, Finetune, 与Triplet Network(元学习)

    主要想讲的内容有: TCA, Finetune, Triplet Network 迁移学习与元学习有哪几类方法 想讲的目标(但不一定完全能写完, 下一次笔记补充): 分别属于什么方法, 处于什么位置, ...

  9. (转)Tensorflow 实战Google深度学习框架 读书笔记

    本文大致脉络: 读书笔记的自我说明 对读书笔记的摘要 具体章节的摘要: 第一章 深度学习简介 第二章 TensorFlow环境搭建 第三章 TensorFlow入门 第四章 深层神经网络 第五章 MN ...

最新文章

  1. 谷歌AutoML鼻祖Quoc Le新作AutoML-Zero:从零开始构建机器学习算法
  2. python写出的程序如何给别人使用-利用这10个工具,你可以写出更好的Python代码...
  3. (How to)Windows Live Writer插入Latex公式
  4. Python补充01 序列的方法
  5. python利用win32com读取doc和pdf内容,并保存到文件
  6. 大话数据结构:拓扑排序
  7. 如何使用CNN进行物体识别和分类_RCNN物体识别
  8. linux鼠标键盘被禁用了,debian squeeze下鼠标、键盘突然被系统禁用
  9. 留住用户的APP弹窗设计素材模板
  10. java 下载视频文件
  11. oracle11g数据库导入导出方法教程
  12. 解决资源监视器不显示的问题。
  13. 左神讲算法——超级水王问题(详解)
  14. 路由器显示dns服务器异常怎么办,手机显示DNS异常解决方法(图文)
  15. ajax请求后状态码200却无法进入success解决方案
  16. 物联网NB-IoT技术商用正全面铺开 竞争日趋激烈
  17. 蛋白质二级结构预测工具psipred安装使用
  18. 教大家pr如何新建工程文件
  19. VO,PO,BO,QO, DAO ,POJO,的概念
  20. 以太坊中metamask、imtoken等钱包签名的php验证

热门文章

  1. [SDOI2010]外星千足虫 题解 高斯消元+bitset简介
  2. 【数据算法】Java实现二叉树存储以及遍历
  3. [转]MyBatis的foreach语句详解
  4. 29.怎样扩展现有类功能?
  5. How to uninstall git
  6. Response.ContentType所有类型例举
  7. 实训笔记(一) 创建文件夹(SDCard)
  8. 浅析SQL Server数据库中的伪列以及伪列的含义
  9. Java实战之04JavaWeb-02Request和Response
  10. nodejs mysql 创建连接池