目录

前言

1. 用tensorflow自带的工具

2. 用tensorflow.contrib.slim。

3. 从保存的model中提取var_list

4. 其他


前言

在加载预训练的网络模型时,有时仅需要利用保存的网络模型中的部分变量,因此我们需要提取这些变量的变量列表,用于变量加载。如以下示例所示,当用tensorflow.train.Saver(var_list)加载模型参数时,需要传入的var_list需要一种特定的数据格式,而直接使用str组成的变量名列表是无效的。

var_list = saver._var_list
print(type(var_list[0]))
#output : <class 'tensorflow.python.ops.variables.RefVariable'>
print(var_list[0])
#output : <tf.Variable 'g_b1:0' shape=(32768,) dtype=float32_ref>

那么,如何获取指定格式的变量列表呢?主要有以下方式:

1. 用tensorflow自带的工具

  • (1) 获取图(graph)中所有变量,包括可训练的与不可训练(training=False)的变量
all_vars = tf.global_variables()
  • (2)获取图中可训练的变量。
train_vars = tf.trainable_variables()
  • (3)根据变量名在已有变量列表中筛选变量。在(1)或(2)中获取的变量列表中进行筛选
# 从all_vars列表中筛选以”generator“开头的变量,组成一个新的变量列表。
g_vars =[var for var in all_vars for var.name.startswith("generator")]# 从all_vars列表中筛选name中包含”batch_normalization“的变量
bn_var_list = [v for v in all_vars if "batch_normalization" in v.name ] 
  • (4)提取某个特定变量空间中的变量列表
# 提取variable_scope或name_scope ”generator“下的所有变量的变量列表。
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="generator")# 提取variable_scope或name_scope ”generator“下的所有可训练变量的变量列表。
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="generator")

2. 用tensorflow.contrib.slim。

tensorflow.contrib.slim是一个强大的工具,也可以用来方便的构建变量列表

  • (1)返回常规变量,应该是返回所有变量(没测试)
import tensorflow.contrib.slim as slim
#下面这行代码返回常规变量
#常规变量是slim里面与model变量对应的一个类型
regular_variables = slim.get_variables()
#你也可以直接
vars = slim.get_variables_to_restore()
  • (2)根据条件,筛选变量,构建变量列表
#通过name或前缀筛选
variables = slim.get_variables_by_name("d_")#通过name后缀筛选
variables = slim.get_variables_by_suffix("_b")#通过namespace筛选
variables = slim.get_variables(scope="layer1")#通过include和exclude特定字符串筛选
variables_to_restore = slim.get_variables_to_restore(include=["d_"])
variables_to_restore = slim.get_variables_to_restore(exclude=["_w"])

3. 从保存的model中提取var_list

方法,将离线文件载入当前环境(加载保存的model的图结构),然后使用以上的方法。

注意:在加载model中的图时,要先清空现有的图,否则,import_meta_graph会把原model里面的数据追加到现有的model中,导致一片混乱

# 清空图
tf.reset_default_graph()with tf.Session(graph=tf.get_default_graph()) as sess:# 从model.meta文件中加载保存的图结构new_saver = tf.train.import_meta_graph('./model.meta')# 从model中加载保存的变量数据new_saver.restore(sess, './model')# 然后从加载得到的图中获取变量列表,如var_list=tf.global_variables()print(var_list)

下面是一个从保存的模型中加载图结构,加载变量数值,并打印特定变量的变量值的示例。

import tensorflow as tfmodel_path = "/media/***/model.ckpt-18000"tf.reset_default_graph() # 清空图,防止图上存在干扰节点
with tf.Session(graph=tf.get_default_graph()) as sess:saver = tf.train.import_meta_graph(model_path+".meta")# 从model.meta文件中加载保存的图结构saver.restore(sess, model_path)# 从model中加载保存的变量数据var_list = tf.global_variables() # 获取图上的所有变量bn_var_list = [v for v in var_list if "batch_normalization" in v.name ] # 筛选与BatchNormalization相关的变量for v in bn_var_list: print(v,sess.run(v)[1]) # 打印变量名,与对应的变量值。为了便于观察,打印变量张量的第一个数值

4. 其他

  • (1)用pywrap_tensorflow 直接读取离线文件中的变量信息,但是,这种方法获取到的var_list与要求的格式不一样。

我还不知道怎么转换成saver要求的类型,等一个有缘人相告!

    import tensorflow as tffrom tensorflow.python import pywrap_tensorflow#文件夹地址改成自己的model_dir="./model"ckpt = tf.train.get_checkpoint_state(model_dir)reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)#返回一个dict= {'name':[shape] }#例如 'd_w2/Adam':[4, 4, 32, 64]var_to_shape_map = reader.get_variable_to_shape_map()#我们可以用遍历的方式,取出字典里所有的keyfor key in var_to_shape_map:print(key)        #key是str类型的#再用key去找这个tensor的值a=reader.get_tensor(key)print(type(a))    #输出: <class 'numpy.ndarray'>
  • (2)把获得的list输出到一个外部txt文件中,以后使用只需要读取txt即可

PS:目前这种方法只支持str类型保存变量name!不支持saver!

代码示例:

 with open("var_list.txt","w") as f:f.write( ','.join(str(var_list)))   #write只能打印str类型,需要强制转换

效果如图:

要使用的时候,只需要按如下代码读取

 with open("var_list.txt","r") as f:str1 = f.readlines()[0]var_list=str1.split(',')print(var_list)

Tensorflow 获取model中的变量列表,用于模型加载等相关推荐

  1. tensorflow 获取checkpoint中的变量列表

    方式1:静态获取,通过直接解析checkpoint文件获取变量名及变量值 通过 reader = tf.train.NewCheckpointReader(model_path) 或者通过: from ...

  2. 使用gensim.models.Word2Vec.load(‘model.txt‘)报错,导致模型加载不了的解决办法之一

    背景: 想做一个基于Word2Vec的分析标题与标题之间相关性的模型,训练完之后保存模型.再加载发生了如下错误: 在这里插入代码片 Traceback (most recent call last): ...

  3. 【matlab-1】工具箱、窗口、变量的存储与加载、帮助

    1. MatLab的工具箱子(Toolboxes)     (1) 应用数学类     (2) 电子技术类     (3) 图形图象技术     (4) 通讯     (5) 财经与金融     (6 ...

  4. 进程handle获取线程_获取进程中的线程列表

    进程handle获取线程 The System.Diagnostics namespace contains functions that allow you to manage processes, ...

  5. java获取文件目录列表_获取目录中的文件列表

    我正在开发一个C项目,我需要获取目录中的文件列表 . 我正在使用dirent.h但是在使用它时遇到了一些问题,我正在Linux下构建程序 . 当我尝试构建程序时,我收到以下错误 myClass:err ...

  6. 通过js获取Model中数据

    通过js获取Model中数据 前端js获取model 1.获取model的js代码必须写在html中 2.script中添加 th:inline="javascript" < ...

  7. python获取url列表参数_python 获取url中的参数列表实例

    Python的urlparse有对url的解析,从而获得url中的参数列表 import urlparse urldata = "http://en.wikipedia.org/w/api. ...

  8. Thymeleaf-如何获取model中的值

    后台的实现: @RequestMapping("/adds") public String ProtaskAdd(Model model){model.addAttribute(& ...

  9. html读取model的值,Js和Thymeleaf如何获取model中的值

    简述 在大多数的项目架构中,使用SPringBoot发布微服务,前端采用Thymeleaf做为Html模版,使用Jquery做为动态脚本,那么Thymeleaf和Jquery是如何获取Model中的数 ...

最新文章

  1. JavaScript 利用location对象实现跨页面传参
  2. 给热爱学习的同学们推荐一些顶级的c# Blogs链接
  3. 初识jvm-1.Java类的加载机制
  4. linux route命令深入浅出与实战案例精讲
  5. chrome使用技巧
  6. hdu 4279 Number
  7. 输入3个双精度实数,分别求出它们的和,平均值,平方和以及平方和的开方
  8. 鼠标侧键能改为ctrl吗_200元档次又一高竞争力外设 雷柏V30鼠标评测
  9. [React Native Android 安利系列]样式与布局的书写
  10. 自定义srv消息之ros
  11. python爬取学校题库_Python爬虫面试题
  12. printf()语句
  13. MyBatis框架的基本使用
  14. 阶段3 1.Mybatis_09.Mybatis的多表操作_8 mybatis多对多操作-查询角色获取角色下所属用户信息...
  15. seq()函数--R语言
  16. win7网络里计算机登录失败,Win7系统访问网络时提示“登陆失败”的解决方法
  17. centos编译安装vim7.4
  18. 关于GitHub如何转为中文问题——Google举例
  19. Flash游戏开发框架Flixel介绍
  20. 中国大学MOOC 程序设计入门——C语言 翁凯 编程测试题汇总

热门文章

  1. flink读取不到文件_Flink流处理API——Source
  2. python是交互式语言吗_什么是Python交互式解释器
  3. linux脚本ipddr.sh 是什么,MTK DDR调试
  4. mysql 不同的记录_Mysql通过一个限制条件,查出多条不同的记录
  5. xp怎么删除计算机用户,WinXp系统如何删除用户账户?Xp系统删除用户账号的方法...
  6. 系统动力学建模工具_多体动力学:ANSYS Motion 2020R2
  7. 1977年发生事件_大金蛇:千年银蛇,万年金蛇:1977年【蛇蛇人】11月上旬家里有“爆炸性”事件发生!...
  8. jsf如何与数据库连接_JSF数据库示例– MySQL JDBC
  9. 亚马逊CloudFront
  10. 使用Adobe Acrobat为PDF文件添加签名(图片+签名)