作者丨苏剑林

单位丨追一科技

研究方向丨NLP,神经网络

个人主页丨kexue.fm

今天我们继续来深挖 Keras,再次体验 Keras 那无与伦比的优雅设计。这一次我们的焦点是“重用”,主要是层与模型的重复使用。

所谓重用,一般就是奔着两个目标去:一是为了共享权重,也就是说要两个层不仅作用一样,还要共享权重,同步更新;二是避免重写代码,比如我们已经搭建好了一个模型,然后我们想拆解这个模型,构建一些子模型等。

基础

事实上,Keras 已经为我们考虑好了很多,所以很多情况下,掌握好基本用法,就已经能满足我们很多需求了。

层的重用

层的重用是最简单的,将层初始化好,存起来,然后反复调用即可:

x_in = Input(shape=(784,))
x = x_inlayer = Dense(784, activation='relu') # 初始化一个层,并存起来x = layer(x) # 第一次调用
x = layer(x) # 再次调用
x = layer(x) # 再次调用

要注意的是,必须先初始化好一个层,存为一个变量好再调用,才能保证重复调用的层是共享权重的。反之,如果是下述形式的代码,则是非共享权重的:

x = Dense(784, activation='relu')(x)
x = Dense(784, activation='relu')(x) # 跟前面的不共享权重
x = Dense(784, activation='relu')(x) # 跟前面的不共享权重

模型重用

Keras 的模型有着类似层的表现,在调用时可以用跟层一样的方式,比如:

x_in = Input(shape=(784,))
x = x_inx = Dense(10, activation='softmax')(x)model = Model(x_in, x) # 建立模型x_in = Input(shape=(100,))
x = x_inx = Dense(784, activation='relu')(x)
x = model(x) # 将模型当层一样用model2 = Model(x_in, x)

读过 Keras 源码的朋友就会明白,之所以可以将模型当层那样用,是因为 Model 本身就是继承 Layer 类来写的,所以模型自然也包含了层的一些相同特性。

模型克隆

模型克隆跟模型重用类似,只不过得到的新模型跟原模型不共享权重了,也就是说,仅仅保留完全一样的模型结构,两个模型之间的更新是独立的。Keras 提供了模型可用专用的函数,直接调用即可:

from keras.models import clone_modelmodel2 = clone_model(model1)

注意,clone_model 完全复制了原模型模型的结构,并重新构建了一个模型,但没有复制原模型的权重的值。也就是说,对于同样的输入,model1.predict 和 model2.predict 的结果是不一样的。

如果要把权重也搬过来,需要手动 set_weights 一下:

model2.set_weights(K.batch_get_value(model1.weights))

进阶

上述谈到的是原封不等的调用原来的层或模型,所以比较简单,Keras 都准备好了。下面介绍一些复杂一些的例子。

交叉引用

这里的交叉引用是指在定义一个新层的时候,沿用已有的某个层的权重,注意这个自定义层可能跟旧层的功能完全不一样,它们之间纯粹是共享了某个权重而已。比如,Bert 在训练 MLM 的时候,最后预测字词概率的全连接层,权重就是跟 Embedding 层共享的。

参考写法如下:

class EmbeddingDense(Layer):"""运算跟Dense一致,只不过kernel用Embedding层的embedding矩阵"""def __init__(self, embedding_layer, activation='softmax', **kwargs):super(EmbeddingDense, self).__init__(**kwargs)self.kernel = K.transpose(embedding_layer.embeddings)self.activation = activationself.units = K.int_shape(self.kernel)[1]def build(self, input_shape):super(EmbeddingDense, self).build(input_shape)self.bias = self.add_weight(name='bias',shape=(self.units,),initializer='zeros')def call(self, inputs):outputs = K.dot(inputs, self.kernel)outputs = K.bias_add(outputs, self.bias)outputs = Activation(self.activation).call(outputs)return outputsdef compute_output_shape(self, input_shape):return input_shape[:-1] + (self.units,)# 用法
embedding_layer = Embedding(10000, 128)
x = embedding_layer(x) # 调用Embedding层
x = EmbeddingDense(embedding_layer)(x) # 调用EmbeddingDense层

提取中间层

有时候我们需要从搭建好的模型中提取中间层的特征,并且构建一个新模型,在 Keras 中这同样是很简单的操作:

from keras.applications.resnet50 import ResNet50
model = ResNet50(weights='imagenet')Model(inputs=model.input,outputs=[model.get_layer('res5a_branch1').output,model.get_layer('activation_47').output,]
)

从中间拆开

最后,来到本文最有难度的地方了,我们要将模型从中间拆开,搞懂之后也可以实现往已有模型插入或替换新层的操作。这个需求看上去比较奇葩,但是还别说,stackoverflow 上面还有人提问过,说明这确实是有价值的。

https://stackoverflow.com/questions/49492255/how-to-replace-or-insert-intermediate-layer-in-keras-model

假设我们有一个现成的模型,它可以分解为:

那可能我们需要将 h2 替换成一个新的输入,然后接上后面的层,来构建一个新模型,即新模型的功能是:

如果是 Sequential 类模型,那比较简单,直接把 model.layers 都遍历一边,就可以构建新模型了:

x_in = Input(shape=(100,))
x = x_infor layer in model.layers[2:]:x = layer(x)model2 = Model(x_in, x)

但是,如果模型是比较复杂的结构,比如残差结构这种不是一条路走到底的,就没有这么简单了。事实上,这个需求本来没什么难度,该写的 Keras 本身已经写好了,只不过没有提供现成的接口罢了。为什么这么说,因为我们通过 model(x) 这样的代码调用已有模型的时候,

实际上 Keras 就相当于把这个已有的这个 model 从头到尾重新搭建了一遍,既然可以重建整个模型,那搭建“半个”模型原则上也是没有任技术难度的,只不过没有现成的接口。具体可以参考 Keras 源码的 keras/engine/network.py 的 run_internal_graph 函数:

https://github.com/keras-team/keras/blob/master/keras/engine/network.py

完整重建一个模型的逻辑在 run_internal_graph 函数里边,并且可以看到它还不算简单,所以如无必要我们最好不要重写这个代码。但如果不重写这个代码,又想调用这个代码,实现从中间层拆解模型的功能,唯一的办法是“移花接木”了:通过修改已有模型的一些属性,欺骗一下 run_internal_graph 函数,使得它以为模型的输入层是中间层,而不是原始的输入层。有了这个思想,再认真读读 run_internal_graph 函数的代码,就不难得到下述参考代码:

def get_outputs_of(model, start_tensors, input_layers=None):"""start_tensors为开始拆开的位置"""# 为此操作建立新模型model = Model(inputs=model.input,outputs=model.output,name='outputs_of_' + model.name)# 适配工作,方便使用if not isinstance(start_tensors, list):start_tensors = [start_tensors]if input_layers is None:input_layers = [Input(shape=K.int_shape(x)[1:], dtype=K.dtype(x))for x in start_tensors]elif not isinstance(input_layers, list):input_layers = [input_layers]# 核心:覆盖模型的输入model.inputs = start_tensorsmodel._input_layers = [x._keras_history[0] for x in input_layers]# 适配工作,方便使用if len(input_layers) == 1:input_layers = input_layers[0]# 整理层,参考自 Model 的 run_internal_graph 函数layers, tensor_map = [], set()for x in model.inputs:tensor_map.add(str(id(x)))depth_keys = list(model._nodes_by_depth.keys())depth_keys.sort(reverse=True)for depth in depth_keys:nodes = model._nodes_by_depth[depth]for node in nodes:n = 0for x in node.input_tensors:if str(id(x)) in tensor_map:n += 1if n == len(node.input_tensors):if node.outbound_layer not in layers:layers.append(node.outbound_layer)for x in node.output_tensors:tensor_map.add(str(id(x)))model._layers = layers # 只保留用到的层# 计算输出outputs = model(input_layers)return input_layers, outputs

用法:

from keras.applications.resnet50 import ResNet50
model = ResNet50(weights='imagenet')x, y = get_outputs_of(model,model.get_layer('add_15').output
)model2 = Model(x, y)

代码有点长,但其实逻辑很简单,真正核心的代码只有三行:

model.inputs = start_tensors
model._input_layers = [x._keras_history[0] for x in input_layers]
outputs = model(input_layers)

也就是覆盖模型的 model.inputs 和 model._input_layers 就可以实现欺骗模型从中间层开始构建的效果了,其余的多数是适配工作,不是技术上的,而 model._layers = layers 这一句是只保留了从中间层开始所用到的层,只是为了统计模型参数量的准确性,如果去掉这一部分,模型的参数量依然是原来整个 model 那么多。

小结

Keras 是最让人赏心悦目的深度学习框架,至少到目前为止,就模型代码的可读性而言,没有之一。可能读者会提到 PyTorch,诚然 PyTorch 也有不少可取之处,但就可读性而言,我认为是比不上 Keras 的。

在深究 Keras 的过程中,我不仅惊叹于 Keras 作者们的深厚而优雅的编程功底,甚至感觉自己的编程技能也提高了不少。不错,我的很多 Python 编程技巧,都是从读 Keras 源码中学习到的。

点击以下标题查看作者其他文章:

  • 基于DGCNN和概率图的轻量级信息抽取模型

#投 稿 通 道#

 让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。

来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

? 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site

• 所有文章配图,请单独在附件中发送

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

?

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

▽ 点击 | 阅读原文 | 查看作者博客

“让Keras更酷一些!”:层与模型的重用技巧相关推荐

  1. “让Keras更酷一些!”:层中层与mask

    这一篇"让 Keras 更酷一些!"将和读者分享两部分内容:第一部分是"层中层",顾名思义,是在 Keras 中自定义层的时候,重用已有的层,这将大大减少自定义 ...

  2. 让Keras更酷一些:中间变量、权重滑动和安全生成器

    作者丨苏剑林 单位丨追一科技 研究方向丨NLP,神经网络 个人主页丨kexue.fm 继续"让Keras更酷一些"之旅. 今天我们会用 Keras 实现灵活地输出任意中间变量,还有 ...

  3. “让Keras更酷一些!”:分层的学习率和自由的梯度

    作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 高举"让 Keras 更酷一些!"大旗,让 Keras 无限可能. 今天我们会 ...

  4. 比Keras更好用的机器学习“模型包”:无需预处理,0代码上手做模型

    萧箫 发自 凹非寺 量子位 报道 | 公众号 QbitAI 做机器学习模型时,只是融合各种算法,就已经用光了脑细胞? 又或者觉得,数据预处理就是在"浪费时间"? 一位毕业于哥廷根大 ...

  5. 比Keras更好用的机器学习“模型包”:0代码上手做模型

    做机器学习模型时,只是融合各种算法,就已经用光了脑细胞? 又或者觉得,数据预处理就是在"浪费时间"? 一位毕业于哥廷根大学.做机器学习的小哥也发现了这个问题:原本只是想设计个模型, ...

  6. keras dense sigmoid_tf.keras一个存在自定义层时加载模型时的小坑

    前言 Tensorflow在现在的doc里强推Keras,用过之后感觉真的很爽,搭模型简单,模型结构可打印,瞬间就能train起来不用自己写get_batch和evaluate啥的,跟用原生tenso ...

  7. GitChat · 人工智能 | 如何零基础用 Keras 快速搭建实用深度学习模型

    GitChat 作者:谢梁 原文: 如何零基础用 Keras 快速搭建实用深度学习模型 关注微信公众号:GitChat 技术杂谈 ,一本正经的讲技术 [不要错过文末活动] 前言 在这篇小文章中,我们将 ...

  8. 如何使用Keras和TensorFlow建立深度学习模型以预测员工留任率

    The author selected Girls Who Code to receive a donation as part of the Write for DOnations program. ...

  9. Keras读书笔记----卷积层、池化层

    1. 卷积层 1.1. Convolution1D层 一维卷积层,用以在一维输入信号上进行邻域滤波.当使用该层作为首层时,需要提供关键字参数 input_dim 或 input_shape . ker ...

最新文章

  1. Python3学习笔记----环境安装及文本编辑器的选择
  2. mysql 5.7临时表空间_深度解析MySQL 5.7之临时表空间
  3. oracle数据源的报表sql计算慢解决
  4. .NET Framework 4.5 五个很棒的特性
  5. Weblogic EJB 学习笔记(3)精
  6. Python脚本做接口测试,抛弃接口测试工具是否可行?(二)
  7. word转pdf出现空白页||去除PDF中的指定页
  8. 2020 安装 nacos
  9. Android入门项目(校园软件)
  10. Facebook全新数字货币Libra引发关注 数字货币国际化逐渐发展
  11. 计算机二级python刷题软件排行榜_计算机二级office刷题软件求推荐?
  12. 回声状态网络(ESN)对MNIST手写数字集识别
  13. 剑指offer题目记录
  14. WIFI 6有哪些新特征
  15. springmvc 升级到5.2.15版本,前台时间显示时间戳全局处理
  16. 算法题——判断四边形是否为凸四边形
  17. pytorch手写VGG16网络,两种写法,低阶基础写法
  18. Fork/Join框架之双端队列
  19. 操作无法完成(0x000006ba)。本地后台打印程序服务没有运行。请重新启动后台打印程序或重新启动计算机。
  20. PNAS:人类头皮记录电位的时间尺度

热门文章

  1. e语言html显示框,html marguee标签
  2. 安卓访问mysql的源码_【原创源码】安卓数据库简单操作demo
  3. python与数据库交互的模块pymysql
  4. CodeForces509F Progress Monitoring
  5. js-ajax-04
  6. HPU-- 1190 確率
  7. 前端开发中通过js设置cookie的一组方法
  8. 封装各种生成唯一性ID算法的工具类
  9. [Spring Framework]学习笔记--Dependency injection(DI)
  10. python pip安装指定版本unittest_你们想要的unittest用例失败重运行,解决方案来啦!...