keras提取模型中的某一层_keras K.function获取某层的输出操作
如下所示:
from keras import backend as K
from keras.models import load_model
models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]
加载训练好并保存的网络模型
加载数据(图像),并将数据处理成array形式
指定输出层
将处理后的数据输入,然后获取输出
其中,K.function有两种不同的写法:
1. 获取名为layer_name的层的输出
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
#指定输出层的名称
2. 获取第n层的输出
layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output])
#指定输出层的序号(层号从0开始)
另外,需要注意的是,书写不规范会导致报错:
报错:
TypeError: inputs to a TensorFlow backend function should be a list or tuple
将该句:
f1 = layer_1(image_arr)[0]
修改为:
f1 = layer_1([image_arr])[0]
补充知识:keras.backend.function()
如下所示:
def function(inputs, outputs, updates=None, **kwargs):
"""Instantiates a Keras function.
Arguments:
inputs: List of placeholder tensors.
outputs: List of output tensors.
updates: List of update ops.
**kwargs: Passed to `tf.Session.run`.
Returns:
Output values as Numpy arrays.
Raises:
ValueError: if invalid kwargs are passed in.
"""
if kwargs:
for key in kwargs:
if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
key not in tf_inspect.getargspec(Function.__init__)[0]):
msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
'backend') % key
raise ValueError(msg)
return Function(inputs, outputs, updates=updates, **kwargs)
这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。
我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。
class Function(object):
"""Runs a computation graph.
Arguments:
inputs: Feed placeholders to the computation graph.
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: a name to help users identify what this function does.
"""
def __init__(self, inputs, outputs, updates=None, name=None,
**session_kwargs):
updates = updates or []
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` to a TensorFlow backend function '
'should be a list or tuple.')
if not isinstance(outputs, (list, tuple)):
raise TypeError('`outputs` of a TensorFlow backend function '
'should be a list or tuple.')
if not isinstance(updates, (list, tuple)):
raise TypeError('`updates` in a TensorFlow backend function '
'should be a list or tuple.')
self.inputs = list(inputs)
self.outputs = list(outputs)
with ops.control_dependencies(self.outputs):
updates_ops = []
for update in updates:
if isinstance(update, tuple):
p, new_p = update
updates_ops.append(state_ops.assign(p, new_p))
else:
# assumed already an op
updates_ops.append(update)
self.updates_op = control_flow_ops.group(*updates_ops)
self.name = name
self.session_kwargs = session_kwargs
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` should be a list or tuple.')
feed_dict = {}
for tensor, value in zip(self.inputs, inputs):
if is_sparse(tensor):
sparse_coo = value.tocoo()
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
np.expand_dims(sparse_coo.col, 1)), 1)
value = (indices, sparse_coo.data, sparse_coo.shape)
feed_dict[tensor] = value
session = get_session()
updated = session.run(
self.outputs + [self.updates_op],
feed_dict=feed_dict,
**self.session_kwargs)
return updated[:len(self.outputs)]
keras提取模型中的某一层_keras K.function获取某层的输出操作相关推荐
- keras提取模型中的某一层_Keras做图片分类(四):迁移学习--猫狗大战实战
本项目数据集来自kaggle竞赛,地址: https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data 数据的训练集放在train文 ...
- keras提取模型中的某一层_keras获得某一层或者某层权重的输出实例
一个例子: print("Loading vgg19 weights...") vgg_model = VGG19(include_top=False, weights='imag ...
- keras提取模型中的某一层_Tensorflow笔记:高级封装——Keras
前言 之前在<Tensorflow笔记:高级封装--tf.Estimator>中介绍了Tensorflow的一种高级封装,本文介绍另一种高级封装Keras.Keras的特点就是两个字--简 ...
- python模型保存save_浅谈keras保存模型中的save()和save_weights()区别
今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别. 我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5.同样是h5文件用save ...
- SSL协议工作在OSI模型中的哪一层?
首先我们来看看什么是SSL协议(引申出TLS): SSL(Secure Sockets Layer 安全套接层),及其继任者传输层安全(Transport Layer Security,TLS)是为网 ...
- OSI的七层模型,网线,网卡,集线器,交换机,路由器分别工作在七层模型中的哪一层?
OSI七层网络模型由下至上为1至7层,分别为物理层(Physical layer),数据链路层(Data link layer),网络层(Network layer),传输层(Transport la ...
- python输入数据的维度_keras分类模型中的输入数据与标签的维度实例
在<python深度学习>这本书中. 一.21页mnist十分类 导入数据集 from keras.datasets import mnist (train_images, train_l ...
- python提取文件中特定字符串
目录 1. Python3文件操作 1.1 打开和关闭文件 1.1.1 open( ) 函数 1.1.2 close( ) 函数 1.2 读写文件 1.2.1 write( ) 函数 1.2.2 r ...
- Java 提取 PPT 中 SmartArt 图形的文本内容
(使用工具: Free Spire.Presentation for Java) JAR包导入 方法一:下载Free Spire.Presentation for Java包并解压缩,然后将lib文件 ...
最新文章
- 带你看看获得鲁班奖的数据中心工程建设的有多完美!!
- Linux内存初始化(C语言部分)
- WWW 2021有哪些值得读的图机器学习相关论文?
- 汇编语言之标志寄存器
- 如何使用txt文件实现JMeter参数化
- 程序员被沦陷!国内程序员真的饱和了?
- 计算机记录乐器声音的文件是,一、用计算机录音的过程.pptx
- mysql sql exists_数据库sql语句的exists总结
- 下载python流程-Python编写win程序的操作流程
- 汽车常识全面介绍 - 刹车系统
- 无线服务器软件,无线局域网AAA服务器的软件设计与实现
- 慕课网仿去哪儿项目笔记--(一)-初始化准备
- android widget包说明与应用
- 华为数通HCIA笔记(OSI七层)
- 好架构师都是写代码写出来的
- NetCore EF 使用scaffold-dbcontext导致deps.json] does not exist的解决办法
- 如何在Idea一个窗口打开多个项目
- dedecms调用友情链接代码
- 使用Python 绘制双Y轴和误差棒柱状图
- 我的jQuery之路(笔记)--6
热门文章
- 华为5G模块MH5000_AT命令手册
- 电脑里文件名称怎么快速重命名
- 设备管理系统的功能通常有哪些?
- java什么时候要重写equse,【音频技术】何时使用均衡(When to Use EQ)【EQ均衡器】
- 20180507记事
- java setlayout_Java布局管理器setLayout()
- Android反编译apk修改版本号重新打包签名详细教程(超详细)
- mysql查询一百万_mysql procedure-MySQL超过一百万条数据查询要用到什么技术
- iOS端内嵌H5页面 点击a标签无反应
- C语言系列必读技术书单推荐从入门到进阶+技术书阅读方法论