every blog every motto: You can do more than you think.

0. 前言

对于构建深度学习网络模型,我们通常有三种方法,分别是:

  • Sequential API
  • Functional API
  • Subclassing API

说明: 推荐使用functional API.


本文主要对有关子类API(tf.keras.Model)构建模型时“两种方法”进行比较分析。
注: 为保持文章的完整性,本文仅就部分问题进行探讨,后续问题见下一篇博文。

1. 正文

  1. 通过继承tf.keras.Mdoel 这个Python类来定义自己的模型。
  2. 在继承类中,我们需要重写__init__()(构造函数,初始化)和call(input)(模型调用)两个方法,同时也可以根据需要增加自定义的 方法。
  3. init方法用于定义/初始化用到的层(如:卷积层、池化层等);call方法用于神经网络的正向传递(自动生成反向传递)

1.1 模板

以下两种方法结果类似,主要区别在于:

  • 一种再init方法中调用已有层
  • 另一种重写Layer,然后再init中对自定义的层进行实例化

1.1.1 方法一:调用已有层

class MyModel(tf.keras.Model):def __init__(self):super().__init__()     # Python 2 下使用 super(MyModel, self).__init__()# 此处添加初始化代码(包含 call 方法中会用到的层),例如# layer1 = tf.keras.layers.BuiltInLayer(...)# layer2 = MyCustomLayer(...)def call(self, input):# 此处添加模型调用的代码(处理输入并返回输出),例如# x = layer1(input)# output = layer2(x)return output# 还可以添加自定义的方法

1.1.2 方法二:自定义层

class DoubleConv(layers.Layer):"""自定义层"""def __init__(self):super().__init__()def call(self, input):passclass MyModel(tf.keras.Model):def __init__(self):super().__init__()  # Python 2 下使用 super(MyModel, self).__init__()# 此处添加初始化代码(包含 call 方法中会用到的层),例如doub_block = DoubleConv()# layer1 = tf.keras.layers.BuiltInLayer(...)# layer2 = MyCustomLayer(...)def call(self, input):# 此处添加模型调用的代码(处理输入并返回输出),例如# x = layer1(input)x = DoubleConv()# output = layer2(x)# return output# 还可以添加自定义的方法

1.2 实例演示

1.2.1 调用已有层

1.2.1.1 常规代码

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLUclass Models(tf.keras.Model):def __init__(self):super().__init__()self.conv = Conv2D(16, (3, 3), padding='same')self.bn = BatchNormalization()self.ac = ReLU()self.conv2 = Conv2D(32, (3, 3), padding='same')self.bn2 = BatchNormalization()self.ac2 = ReLU()def call(self, x, **kwargs):x = self.conv(x)x = self.bn(x)x = self.ac(x)x = self.conv2(x)x = self.bn2(x)x = self.ac2(x)return xm = Models()
m.build(input_shape=(2, 8, 8, 3))
m.summary()

模型结构:

1.2.1.2 调整后代码

1. 共用批归一化
import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLUclass Models(tf.keras.Model):def __init__(self):super().__init__()self.conv = Conv2D(16, (3, 3), padding='same')self.bn = BatchNormalization()self.ac = ReLU()self.conv2 = Conv2D(32, (3, 3), padding='same')self.bn2 = BatchNormalization()self.ac2 = ReLU()def call(self, x, **kwargs):x = self.conv(x)x = self.bn(x)x = self.ac(x)x = self.conv2(x)# ==========================# 此处共用一个BatchNormalization# ===========================x = self.bn(x)x = self.ac2(x)return xm = Models()
m.build(input_shape=(2, 8, 8, 3))
m.summary()
  • 上面两处的批归一化(BatchNormalization)共用了一个BatchNormalization,出现如下报错。
  • 如果共用一个卷积/激活函数,同样会出现报错。(读者可自行验证)

共用一个批归一化,报错如下:

ValueError: Input 0 of layer batch_normalization is incompatible with the layer: expected axis 3 of input shape to have value 16 but received input with shape [2, 8, 8, 32]

共用一个卷积,报错如下:

ValueError: Input 0 of layer conv2d is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape [2, 8, 8, 16]

共用一个激活函数,报错如下:

ValueError: You tried to call `count_params` on re_lu_1, but the layer isn't built. You can build it manually via: `re_lu_1.build(batch_input_shape)`.

1.2.2 调用自定义层

1.2.2.1 常规代码

说明: 代码较长,此处分开写,打消读者的畏难情绪,便于阅读。
导入模块

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'import tensorflow as tf
from tensorflow.keras import layers

自定义层

class DoubleConv(layers.Layer):def __init__(self, mid_kernel_numbers, out_kernel_number):"""初始化含有两个卷积的卷积块:param mid_kernel_numbers: 中间特征图的通道数:param out_kernel_number: 输出特征图的通道数"""super().__init__()self.conv1 = layers.Conv2D(mid_kernel_numbers, (3, 3), padding='same')self.conv2 = layers.Conv2D(out_kernel_number, (3, 3), padding='same')self.bn = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()self.ac = layers.ReLU()self.ac2 = layers.ReLU()def call(self, input, **kwargs):"""正向传播"""x = self.conv1(input)x = self.bn(x)x = self.ac(x)x = self.conv2(x)x = self.bn2(x)x = self.ac2(x)return x

模型类

class Model(tf.keras.Model):def __init__(self):"""构建模型的类"""super().__init__()# 初始化卷积块self.block = DoubleConv(16, 32)def call(self, x, **kwargs):x = self.block(x)return x

打印模型结构图

m = Model()
m.build(input_shape=(2, 8, 8, 3))
m.summary()

1.2.2.2 调整后代码

说明: 因代码较长,本部分仅展示调整部分,其余代码同上文(1.2.2.1)

1. 共用卷积
class DoubleConv(layers.Layer):def __init__(self, mid_kernel_numbers, out_kernel_number):"""初始化含有两个卷积的卷积块:param mid_kernel_numbers: 中间特征图的通道数:param out_kernel_number: 输出特征图的通道数"""super().__init__()self.conv1 = layers.Conv2D(mid_kernel_numbers, (3, 3), padding='same')self.conv2 = layers.Conv2D(out_kernel_number, (3, 3), padding='same')self.bn = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()self.ac = layers.ReLU()self.ac2 = layers.ReLU()def call(self, input, **kwargs):"""正向传播"""x = self.conv1(input)x = self.bn(x)x = self.ac(x)# =======================#   此处共用卷积# =======================x = self.conv(x)x = self.bn2(x)x = self.ac2(x)return x

报错:

AttributeError: 'DoubleConv' object has no attribute 'conv'
2. 共用批归一化
class DoubleConv(layers.Layer):def __init__(self, mid_kernel_numbers, out_kernel_number):"""初始化含有两个卷积的卷积块:param mid_kernel_numbers: 中间特征图的通道数:param out_kernel_number: 输出特征图的通道数"""super().__init__()self.conv1 = layers.Conv2D(mid_kernel_numbers, (3, 3), padding='same')self.conv2 = layers.Conv2D(out_kernel_number, (3, 3), padding='same')self.bn = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()self.ac = layers.ReLU()self.ac2 = layers.ReLU()def call(self, input, **kwargs):"""正向传播"""x = self.conv1(input)x = self.bn(x)x = self.ac(x)x = self.conv2(x)# =======================#   此处公用批归一化# =======================x = self.bn(x)x = self.ac2(x)return x

报错:

ValueError: Input 0 of layer batch_normalization is incompatible with the layer: expected axis 3 of input shape to have value 16 but received input with shape [2, 8, 8, 32]
3. 共用激活函数
class DoubleConv(layers.Layer):def __init__(self, mid_kernel_numbers, out_kernel_number):"""初始化含有两个卷积的卷积块:param mid_kernel_numbers: 中间特征图的通道数:param out_kernel_number: 输出特征图的通道数"""super().__init__()self.conv1 = layers.Conv2D(mid_kernel_numbers, (3, 3), padding='same')self.conv2 = layers.Conv2D(out_kernel_number, (3, 3), padding='same')self.bn = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()self.ac = layers.ReLU()self.ac2 = layers.ReLU()def call(self, input, **kwargs):"""正向传播"""x = self.conv1(input)x = self.bn(x)x = self.ac(x)x = self.conv2(x)x = self.bn2(x)# =======================#   此处公用批归一化# =======================x = self.ac(x)return x

4. 自定义层的多次调用(附)
import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'import tensorflow as tf
from tensorflow.keras import layersclass DoubleConv(layers.Layer):def __init__(self, mid_kernel_numbers, out_kernel_number):"""初始化含有两个卷积的卷积块:param mid_kernel_numbers: 中间特征图的通道数:param out_kernel_number: 输出特征图的通道数"""super().__init__()self.conv1 = layers.Conv2D(mid_kernel_numbers, (3, 3), padding='same')self.conv2 = layers.Conv2D(out_kernel_number, (3, 3), padding='same')self.bn = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()self.ac = layers.ReLU()self.ac2 = layers.ReLU()def call(self, input, **kwargs):"""正向传播"""x = self.conv1(input)x = self.bn(x)x = self.ac(x)x = self.conv2(x)x = self.bn2(x)x = self.ac2(x)return xclass Model(tf.keras.Model):def __init__(self):"""构建模型的类"""super().__init__()# 初始化卷积块self.block = DoubleConv(16, 32)self.block2 = DoubleConv(32, 64)def call(self, x, **kwargs):x = self.block(x)x = self.block2(x)return xm = Model()
m.build(input_shape=(2, 8, 8, 3))
m.summary()

1.3 总结

1.3.1 一般性总结

  1. 两种方法归根到底是一种方法,即对tf.keras.Model的继承,即,我们所说的子类API(Subclassing API)
  2. 对于三种构建模型方法( Sequential API / Functional API / Subclassing API),入门难度和灵活性依次增大
  3. 推荐使用函数式(Functional API),一般够用,且较为灵活。
  4. 对于子类API
    • 类的init方法,初始化要用到的层,如:卷积、池化、批归一化、激活函数等
    • 类的call方法,定义正向传递过程,即模型的图,其中反向传递自动完成。
  5. 模型结构图:
    • 调用已有层,打印模型结构时,我们能看到其中的每一层的信息(如:特征图大小)
    • 调用自定义层会将自定义层当做一个整体,打印模型结构时,我们看不到内部信息(具体见1.2.2)

1.3.2 (针对)错误性总结

  1. 调用已有层: 模型内各层不能重复使用!
  2. 调用自定义层: 模型内卷积、批归一化不能重复使用,激活函数可以重复使用

针对二者的区别,笔者有两种猜想:

  • 卷积和批归一化均需要参数,所以不能重复使用,否则,前一个用到的参数无法保留;激活函数不需要参数,所以可以重复使用。
  • 每一个层都有自已的名称,所以不能重复使用。(tf1.x 好像要对用到的层的名称进行指定才可使用,tf2.x并无此要求,笔者对tf1.x并不熟悉,此猜想不牢靠)

小结: 无论哪种猜想都不能解释二者为何会有区别,此点待解!

参考文献

[1] https://blog.csdn.net/weixin_39190382/article/details/104130782
[2] https://blog.csdn.net/weixin_39190382/article/details/104130995
[3] https://blog.csdn.net/weixin_42264234/article/details/103946960
[4] https://www.cnblogs.com/xiximayou/p/12690353.html#_label2
[5] https://tf.wiki/zh_hans/basic/models.html
[6] https://stackoverflow.com/questions/55908188/this-model-has-not-yet-been-built-error-on-model-summary#comment104868791_55909624
[7] https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/layers/Conv2D

【tf.keras.Model】构建模型小结(部分问题未解决)相关推荐

  1. tf.Keras.Model类总结

    文章目录 tf.keras.Model类 1. 创建一个tf.keras.Model类实例的方法 1.1 通过指定输入输出进行实例化 1.2 通过继承Model类进行实例化 2. tf.Keras.M ...

  2. tf.keras.Model之model.compile

    目录 model.compile的作用 model.compile的示例 tf.keras.Model类可能属于tf中拥有最多方法的类了,也最为常用.为啥?tensorflow就是一种机器学习框架,用 ...

  3. 解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题

    错误描述: 1.保存模型:model.save_weights('./model.h5') 2.脚本重启 3.加载模型:model.load_weights('./model.h5') 4.模型报错: ...

  4. tensorflow从入门到精通100讲(七)-TensorFlow房价预估使用Keras快速构建模型

    前言 这篇文章承接上一篇tensorflow从入门到精通100讲(二)-IRIS数据集应用实战 https://wenyusuran.blog.csdn.net/article/details/107 ...

  5. tf.keras.Model之model.fit

    目录 model.fit的作用 model.fit的示例 model.fit的作用 model.fit可用于以指定的迭代次数训练模型.可以设置的参数很多,重点理解黄色标注的参数,这些比较常用. x=N ...

  6. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  7. 【tf.keras】官方教程一 Keras overview

    目录 Sequential Model:(the simplest type of model) Getting started with the Keras Sequential model Spe ...

  8. Tensorflow学习之tf.keras(一) tf.keras.layers.Model(另附compile,fit)

    模型将层分组为具有训练和推理特征的对象. 继承自:Layer, Module tf.keras.Model(*args, **kwargs ) 参数 inputs 模型的输入:keras.Input ...

  9. 使用估算器、tf.keras 和 tf.data 进行多 GPU 训练

    文 / Zalando Research 研究科学家 Kashif Rasul 来源 | TensorFlow 公众号 与大多数 AI 研究部门一样,Zalando Research 也意识到了对创意 ...

  10. keras多输出模型

    Keras多输入多输出模型构建 1. 多输出模型构建 多输出模型构建 自定义loss函数 批量训练 调试 2. 多输入多输出模型(上) 多输入多输出模型 (关键)定义这个具有两个输入和输出的模型: 编 ...

最新文章

  1. 转:linux svn常用命令
  2. arXiv热文解读 | 不懂Photoshop如何P图?交给深度学习吧
  3. FireBug实用指南
  4. Java中利用MessageFormat对象实现类似C# string.Format方法格式化
  5. 初来乍到!各位博客朋友多多支持!
  6. 编程题【System类】计算一千万个数添加到集合的时间
  7. 做自适应网站专业乐云seo_什么叫网站优化-网站建设-SEO优化
  8. jQuery操作radio、checkbox、select 集合
  9. 史上最失败系统!微软正式终止对Vista支持
  10. 【在线分享】考研数学思维导图+高数思维导图+汤家凤重点笔记+武忠祥重点笔记以及高数Xmind思维导图
  11. Msfconsole爆破ssh
  12. 摄像机标定的简单理解与纪要
  13. Oracle导出部分表 par,Oracle使用par文件进行全库导入导出
  14. ubuntu网页邮箱服务器设置,ubuntu配置邮件服务器
  15. erlang ets写入mysql_Erlang 进程字典 VS ETS
  16. 2012年09月12日-13日
  17. linux红帽子认证费用RHCT,关于RHCE和RHCT认证
  18. 50本永不过时的经典计算机书籍
  19. BRL_CAD 教程
  20. 【HAL库】HAL库STM32cubemx快速使用

热门文章

  1. Centos7---1708 Linux上从零开始安装mysql
  2. eureka默认端口号是多少_微服务技术系列教程 - SpringCloud- 服务治理Eureka(集群搭建)...
  3. stn专线和otn有什么区别_专线网络和家庭宽带有什么区别?
  4. 像素生存者2为什么显示服务器不可用,像素生存者2为什么更新了玩不了 | 手游网游页游攻略大全...
  5. mysql in 临时表_什么时候会用到临时表?MySQL临时表的使用总结
  6. qt 日历类 不可输入当前日期之后的日期_UI设计组件时间选择器,日历设计从未如此简单!...
  7. (day 52 - DFS) 剑指 Offer 68 - II. 二叉树的最近公共祖先
  8. (day 33 - 位运算 )剑指 Offer 56 - II. 数组中数字出现的次数 II
  9. 小程序素材抓取软件_小程序上新丨2020冬季产品图库更新,海量素材随你用!...
  10. Vagrant:将装在C盘的虚拟机移动到别的目录