之前训练的网络中有一部分可以用到一个新的网络中,但是不知道存储的参数如何部分恢复到新的网络中,也了解到有许多网络是通过利用一些现有的网络结构,通过finetuning进行改造实现的,因此了解了一下关于模型预训练后部分参数restore和finetuning的内容

更多内容参见:

https://blog.csdn.net/mieleizhi0522/article/details/80535189

https://blog.csdn.net/leo_xu06/article/details/79200634

https://blog.csdn.net/b876144622/article/details/79962727

https://blog.csdn.net/ying86615791/article/details/76215363

首先了解一下变量(tf.Variable),变量是tf框架中用于存储参数的对象,我们这里要恢复的参数也是variable类型的。训练的参数是放在不同名字下的variable中的,checkpoint中存储的变量也是通过不同的名字进行区分的,这里如果要恢复指定的参数可以使用

with tf.variable_scope('', reuse = True):sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))

Saver是用于保存变量的对象。下面是saver对象的创建和调用

saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")

如果仅在session开始时恢复模型变量的一个子集,需要对剩下的变量执行初始化op。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})

对已有checkpoint内容进行查看,可以使用一下代码(来自https://blog.csdn.net/mieleizhi0522/article/details/80535189),然后就可以结合之前的指定变量名的方法对参数进行restore了。注意,在完成部分参数的restore后要记得对没有初始化的变量进行初始化,否则报错。

import tensorflow as tfimport osfrom tensorflow.python import pywrap_tensorflowmodel_dir=r'G:\KeTi\C3D'checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")# 从checkpoint中读出数据reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法var_to_shape_map = reader.get_variable_to_shape_map()# 输出权重tensor名字和值for key in var_to_shape_map:print("tensor_name: ", key,reader.get_tensor(key).shape)

输出

tensor_name: var_name/wc4a (3, 3, 3, 256, 512)tensor_name: var_name/wc3a (3, 3, 3, 128, 256)tensor_name: var_name/wd1 (8192, 4096)tensor_name: var_name/wc5b (3, 3, 3, 512, 512)tensor_name: var_name/bd1 (4096,)tensor_name: var_name/wd2 (4096, 4096)tensor_name: var_name/wout (4096, 101)tensor_name: var_name/wc1 (3, 3, 3, 3, 64)tensor_name: var_name/bc4b (512,)tensor_name: var_name/wc2 (3, 3, 3, 64, 128)tensor_name: var_name/bc3a (256,)tensor_name: var_name/bd2 (4096,)tensor_name: var_name/bc5a (512,)tensor_name: var_name/bc2 (128,)tensor_name: var_name/bc5b (512,)tensor_name: var_name/bout (101,)tensor_name: var_name/bc4a (512,)tensor_name: var_name/bc3b (256,)tensor_name: var_name/wc4b (3, 3, 3, 512, 512)tensor_name: var_name/bc1 (64,)tensor_name: var_name/wc3b (3, 3, 3, 256, 256)tensor_name: var_name/wc5a (3, 3, 3, 512, 512)

tensorflow 模型预训练后的参数restore finetuning相关推荐

  1. 解密万亿参数M6模型预训练背后的分布式框架Whale

    简介: 最近,阿里云PAI团队和达摩院智能计算实验室一起发布"低碳版"巨模型M6,大幅降低万亿参数超大模型训练能耗.借助我们自研的Whale框架仅使用480卡GPU,即训练出了规模 ...

  2. PTMs:大模型预训练技巧之ZeRO训练优化技术(DeepS库-减少参数的冗余+优化通信)的简介(四大核心(模型分片/梯度累积/内存优化/分布式训练)、两大优化(非精度/冗余消除))、ZeRO3三个版

    PTMs:大模型预训练技巧之ZeRO训练优化技术(DeepSpeed库-减少参数的冗余+优化通信)的简介(四大核心技术(模型分片/梯度累积/内存优化/分布式训练).两大优化技术(ZeRO-Offloa ...

  3. 清华研究登Nature子刊:面向大规模预训练语言模型的参数高效微调

    ©作者 | 机器之心编辑部 来源 | 机器之心 近年来,清华大学计算机系孙茂松团队深入探索语言大模型参数高效微调方法的机理与特性,与校内其他相关团队合作完成的研究成果"面向大规模预训练语言模 ...

  4. Tensorflow模型优化训练思路

    问题现状 随着深度学习模型越来越大,数据集越来越大,模型的训练变得越来越慢.这对于想要快速验证算法的研究人员来说,是个比较麻烦的问题. 那么一般来说,我们会想要优化模型训练,以期更快验证模型效果. 无 ...

  5. TensorFlow 调用预训练好的模型—— Python 实现

    1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...

  6. 预训练后性能反而变差,自训练要取代预训练了吗?

    2020-07-18 13:53:03 编译 | JocelynWang 编辑 | 丛 末 早在2018年底,FAIR的研究人员就发布了一篇名为<Rethinking ImageNet Pre- ...

  7. CVPR 2022 | 清华提出Point-BERT: 基于掩码建模的点云自注意力模型预训练

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:于旭敏   |  已授权转载(源:知乎)编辑:CVer https://zhuanlan.zhihu. ...

  8. [深度学习] 自然语言处理 --- Huggingface-Pytorch中文语言Bert模型预训练

    Hugging face 是一家总部位于纽约的聊天机器人初创服务商,开发的应用在青少年中颇受欢迎,相比于其他公司,Hugging Face更加注重产品带来的情感以及环境因素.官网链接在此 https: ...

  9. 【无标题】tensorflow hub 预训练模型库

    TensorFlow Hub 是一个包含经过训练的机器学习模型的代码库,这些模型稍作调整便可部署到任何设备上.您只需几行代码即可重复使用经过训练的模型,例如 BERT 和 Faster R-CNN. ...

最新文章

  1. 【AI实战】快速掌握TensorFlow(二):计算图、会话
  2. 把一个dataset的表放在另一个dataset里面_视频自监督一. STCR: 一个基于数据增强的简单有效正则项 (降低静态信息的影响)...
  3. 一次OutOfMemoryError: GC overhead limit exceeded
  4. spring-xml实现aop-通知的种类
  5. TypeError: ‘int‘ object is not callable
  6. 静态html js文件上传,js实现动态添加上传文件页面
  7. Web前端开发——BAT面试题汇总及答案01
  8. error: Microsoft Visual C++ 14.0 is required. Get it with “Microsoft Visual C++ Build Tools“:解决方案
  9. 进程+协程 计算操作
  10. 如何使用Java代码获取文件、文件流或字符串的编码方式
  11. STM32标准外设库(标准库)官网下载方法,附带2021最新标准固件库下载链接
  12. 探讨基于球谐函数的全局光照
  13. 方立勋JavaWeb学习地址
  14. 站长必备-伪原创原创度检测软件v1.3 (支持百度/谷歌/360/搜狗/神马/微信)
  15. 冯扬文:基于数据仓库的集装箱运价信息集成研究
  16. win10时间线时间轴(Timeline)如何关闭隐藏?
  17. 线下沙龙:靠谱的区块链应用到底是啥样?
  18. 我的HTML学习------表格的基本使用
  19. 至多包含 K 个不同字符的最长子串
  20. VMware 只能打开一个.vmx,无法打开第二个

热门文章

  1. 网络基础之 Nping 命令
  2. 重构改善既有代码设计--重构手法11:Move Field (搬移字段)
  3. 函数进阶学习之二 声明 定义
  4. 辨异 —— 冠词(定冠词、不定冠词、零冠词)
  5. storm流式大数据处理流行吗
  6. WinForm 实现拖拽功能
  7. 两道与二进制有关的sequence
  8. Linux驱动开发环境配置(内核源码树构造)
  9. 这是时间的推移 不是系统的分类
  10. Ubuntu 下 使用 adb logcat 显示 Android 日志