前言

从2019年9月底到现在,TF 2.0 已经正式发布三个多月了。但其实很多和 2.0 相关的特性,比如说 eager 模式,@tf.function 装饰器和 AutoGraph, 以及 keras 风格的模型,在 TF 1.x 的后期版本中已经可以使用了。在过去的一年里,我和我们组的成员也一直在断断续续的把我们的代码框架往 TF 2.x 的风格开始迁移。

回想我最早用 Tensorflow 是在2016年底,那时候 TF 1.0 还没有正式登场。在那之前涉及到所有神经网络的工作用的一直是 Caffe 和 MatConvNet, 更早一些是 Theano 。这三年来对 TF 的一个感受就是:模型训练越来越快,部署越来越方便,但也实现起来也越来越复杂。虽然对于 TF 也有些迷思 (比如说很长一段时间里 TF 的卷积 (tf.nn.conv2d) 都不支持 Group Convolution,最后还是来自 FB 的研究人员提交的 PR,而且直到现在 TF 的文档里也压根没提这事。。。),但整体来说我觉得它还是很好的一个框架。

我近几年的工作主要是人脸检测与识别,以及短视频分析。用 TF 主要是来训练卷积神经网络,也涉及到一些 LSTM 的工作。我会试着从模型搭建,模型训练,模型导出,分布式训练以及如何写自定义的层 (Layer) / 优化器 (Optimizer)/ 学习率 (Learing Rate Schedule) / 损失函数 Loss Function 等各个方面来看使用 TF 训练卷积神经网络的一路发展过来的变化。聊一聊在这个过程中踩过的一些坑,试着总结一下经验。

题图是去年参加 ICCV 时在会场旁边拍的一个图书馆,与内容无关。

1. Tensorflow的模型有几种写法? 从画图,到搭积木,再到写对象

1.1 计算图

在最早期的 Tensorflow 里,写模型的过程其实就是构造计算图 (Computational Graph) 的过程。计算图中的每一个节点 (Node) 定义了一个操作 (Operation),而计算图中的边 (Edge) 就是计算的数据 (Tensor) 。(严格来说还有另一种定义节点依赖关系 (Control Dependencies) 的边,它们不承载数据,但是控制计算图执行时节点的顺序关系。)

图1 计算图示例

计算图中的节点大致可以分为两类,一类是数据节点,一类是计算节点。假设我们想要构建一个两层的分类器,我们需要定义四个数据节点表示变量 [W1, b1, W2, b2],以及一个数据节点表示输入数据 x,然后通过调用 tf.nn 下的各类计算节点来完成最后计算图的构建,代码如下所示:

def nn(x, num_classes):W1 = tf.Variable(tf.random_uniform([32, 64],-1,1), name='W1')b1 = tf.Variable(tf.zeros([64]), name='b1')W2 = tf.Variable(tf.random_uniform([64, num_classes],-1,1), name='W2')b2 = tf.Variable(tf.zeros([num_classes]), name='b2')dense_1 = tf.nn.relu(tf.matmul(x, W1) + b1)dense_2 = tf.nn.relu(tf.matmul(dense_1, W2) + b2)probs = tf.nn.softmax(dense_2)return probs  x_tensor = tf.placeholder(tf.float32, shape=[None, 32], name='input')
probs = nn(x_tensor, num_classes=10)      

TF 2.x (以及后期的 TF 1.x) 的一个重大的变化是 Variable 不再是一个与计算图相关的全局变量, 而是与对象相关的一个局部变量。具体来说,在TF 1.x (早期的) 里,每当你新建一个 Variable 时,计算图中会添加一个节点,我们可以通过 get_collection 的方法来收集图中所有的 Variable。而在 TF2.x 里,每个 Variable是与创建它的对象相关的,对象可以是某个卷积层,也可以是某个模块,也可以是整个模型。更多内容可以参考 TF community 的 RFC 文档。

以上的代码完成了计算图的构建。TF 1.x 对于很多的操作提供一个更高层的 API tf.layers (TF 1.x 比较后期的版本需要访问 tf.compat.v1.layers)。layers 封装了计算和生成相关变量的过程,用 layers 来构建计算图的代码如下:

def nn(x, num_classes):dense1 = tf.compat.v1.layers.dense(x, 64, activation=tf.nn.relu, use_bias=True, name='fc1')dense2 = tf.compat.v1.layers.dense(dense1, num_classes, activation=tf.nn.relu, use_bias=True, name='fc2')probs = tf.nn.softmax(dense2)return probsx_tensor = tf.placeholder(tf.float32, shape=[None, 32], name='input')
probs = nn(x_tensor, num_classes=10)

这种方法使用起来更方便,但本质和之前的方法是一样的。

1.2 搭积木

关于 TF 2.0 的吐槽中,有一部分是关于 Keras 的。事实上从 TF 1.4 开始, Keras 就已经出现在了 TF 的 API 中。而且从 TF 1.9 开始,所有的 tf.layers 里的类,其实都是继承自 Keras.layer 相对应的类。所以不管你喜欢不喜欢,只要你用了 tf.layers, 你就是在用 Keras 相关的设计 lol

用 Keras 搭建模型主要有两种方式,分别是 Symbolic 式和 Imperative 式。(我实在不知道怎么翻译这两个词。。。) 第一种就像是搭积木,第二种是直接写一个 python 风格的对象 (pythonic) 。

搭积木的方法有两种,第一种是序列化 (Sequential) 的方法。我们从输入到输出,将定义好的层一个个的连接起来,合并到一个 Sequential 的对象中去。

from tf.keras import layers
model = tf.keras.Sequential([layers.Dense(64, activation='relu', input_shape=(32,), name='fc1'),layers.Dense(10, activation='relu', name='fc2'),layers.Activation(activation='softmax', name='sm')], name='nn') 

这种方法有一个缺陷是对于每一层,我们只能使用一个输入一个输出,所以这种方法对于大部分的现代卷积神经网络 (ResNet, DenseNet) 是没有用的。TF 提供了另一种搭积木的方法,就是使用 Functional API。在 Functional API 下,我们不仅定义层,同时直接搭建出计算的过程。

from tf.keras import layers
def get_model(num_classes):inputs = keras.Input(shape=(32,), name='input')x = layers.Dense(64, activation='relu')(inputs)x = layers.Dense(10, activation='relu')(x) outputs = layers.Activation(activation='softmax')(x)model = keras.Model(inputs=inputs, outputs=outputs, name='nn')return model

如果说 Sequential 的方式是在搭 Jenga, 那么 Function API 的方式更像是搭乐高,它更加的灵活,当我们需要使用多个输入或者多个输出的时候,只需要将将它们拼接起来即可,比如下面的这一个 Skip Connection 结构:

   x1 = layers.Dense(64, activation='relu')(inputs)x2 = layers.Dense(64, activation='relu')(x1)x3 = x2 + x1   

但是这两种方法依然会有一些问题:

  1. 训练时和推理时的不一致: 现代卷积网络中有些操作,比如说 BatchNormalization 还有 DropOut,它们在训练和推理时是不一致的。但是从 DropOut 的 API 设置上大家可以看到,在它的接口里并没有设置 是否训练时 (training=True) 的地方。这个设置往往是在 Keras 后端设置的,所以这就限制了只能使用 Build-in 的训练方法来训练模型 ( model.fit() ) 或者需要很仔细的来手动控制训练状态。
  2. 只能构建 DAG (Directed Acyclic Graph) 计算图,不能有循环的部分。这个对于 RNN/LSTM 来说是一个致命的限制。

1.3 最灵活的方法:Subclassing

为了得到最大的灵活性,TF 提供一种面对对象的方式来构造模型。具体来说,我们继承一 个 keras.model ,自己定义模型初始化以及前向传播的过程 (后向传播的过程会由 TF 自动生成)。以下是一个 Subclassing 的例子:

from tf.keras import layersclass NN(tf.keras.Model):def __init__(self, units=64, num_classes=10):super(NN, self).__init__()self.units = unitsself.num_classes = num_classesself.dense1 = layers.Dense(self.units, activation='relu')self.dense2 = layers.Dense(self.num_classes, activation='relu')self.sm = layers.Activation(activation='softmax')def call(inputs, training=True):net = self.dense1(inputs)net = self.dense2(net) probs = self.sm(net)return probs

在示例中我们可以看到,在 call 方法里我们有一个 argument 是 training (在这个例子中其实并没有用,但是当我们使用 BN 还有 DropOut 时会非常有用),当调用模型时我们可以通过该参数来控制模型的运行模态:

# model initialization
nn = NN(units=64, num_classes=10)# training mode
probs = nn(x, training=True)
# testing mode
probs = nn(x, training=False)

Keras.Model 的具体实现:keras.Model 继承于 keras.Network , keras.Network 继承于 keras.Layer 。具体来说,Layer 定义了基本的运算单元,以及 Variable 的初始化过程;Network 定义了网络的结构,这里的结构同时包含了前向传播和后向传播 (也就说 Optimizer 也包含在了 Network 里) 的网络,同时 Network 中也定义了 Save 方法,也就是说当我们导出模型时其实是从 Network 这个层次发起的;Model 在 Network 的基础上添加了训练 (training),验证 (evaluation) 以及测试 (testing) 的实现。

虽然在 TF 中较晚才引入,但其实 Subclassing 并不是一个多新的概念,事实上在其它的框架中 (pyTroch, mxNet 等)很早有就了。 keras 的作者 François Chollet 曾经发过这么一个推来比较不同框架下写一个 RNN 的代码,看上去是不是很像?

图2 不同框架 (TensorFlow, MXNet, Chainer, PyTorch) 下写出的 RNN Image Credits: François Chollet

1.4 如何选择

以上就是我使用过的一些方法,但是在实际应用中,我们到底该如何选择呢?我觉得这个问题并没有一个固定的答案,它是由你的任务,你的预期以及生产环境决定的。

  • 如果你的模型和数据量都偏小,单卡训练就可以,那么我觉得 Sequential 和 Function API 就足够了;
  • 如果你模型和数据量都比较大,训练时间对你来说最关键,需要多卡 MirroredStrategy 甚至用到 TPU,或者你有很多自定义的层,那么 Subclassing 会更灵活;
  • 如果部署问题对你来说是最关键的,后期需要做 TensorRT 或者 OpenVino 优化,那么用最底层的方式来写模型可能会更合适。

代码风格变化的背后是底层框架的变化。TF1.x 的本质是构建和执行计算图。但是从 TF1.x 的后期开始,TF 开始逐渐淡化用户与计算图的交互进而转到用户与对象图 (ObjectGraph) 的交互。Layer / Model / Optimizer 等等对象,他们的基类都是一个 Trackable 类。TF2.x 的风格提倡我们从对象的层面来收集模型的损失项(losses,比如说正则项) 和更新项 (updates,比如说 BN 里的 Moving Average),以及控制梯度的更新。同时 TF2.x 倡导使用 @tf.function 装饰器来编译计算图而不是使用 tf.session() 来执行计算图。所有的这些,都是全局化风格向局部化风格的一个演进。

---------------------------------------------------------------------------------------

下一篇会从模型训练的角度来分析一下 TF 代码风格的演变,主要会介绍和分析一下 TF 是如何从自定义训练 (Session-based Custom Training), 到 Estimator,再到 Build-in 训练,最后又回到自定义训练 (GradientTape-based Custom Training) 的。

参考文献:

  1. TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems
  2. What are Symbolic and Imperative APIs in TensorFlow 2.0?
  3. Inside Tensorflow 来自 TF team 内部的分享,非常值得一看
  4. TensorFlow Offical Document

gpxclear寄存器写0和写1_画图,搭积木,写对象 [TF 笔记 0]相关推荐

  1. 惊呆!编程就像写文档!开发神似搭积木!

    演讲台上,搭搭云创始人万斌正在指挥工作人员在Word中设计一个采访备忘录表格. "对,就在这里最后增加一个补充采访字段,好!就是这样." 大屏幕上即时的展现出了一张在Word中创建 ...

  2. lstm keras 权重 理解_Keras 作者:TF 2.0+Keras 深度学习研究你需要了解的 12 件事...

    [新智元导读]Keras 作者 François Chollet 今天发表了一系列推文,如果你使用 TensorFlow 2.0 + Keras 做深度学习研究,这里有你需要知道的一切. Tensor ...

  3. 可能是第二好的 Spring OAuth 2.0 文章,艿艿端午在家写了 3 天~

    " 本文在提供完整代码示例,可见 https://github.com/YunaiV/SpringBoot-Labs 的 lab-68-spring-security-oauth 目录. 原 ...

  4. 基于TensorFlow1.4.0的FNN全连接网络识别MNIST手写数据集

    MNIST手写数据集是所有新手入门必经的数据集,数据集比较简单,训练集为50000张手写图片,测试集为张手写图片10000,大小都为28*28,不用自己下载,直接从TensorFlow导入即可 后续随 ...

  5. 转载一个病毒programguide,过后给大家写下感想,学8086时,写过com病毒

    Billy Belceb病毒编写教程---DOS篇(基础理论)                                     翻译:onlyu[FCG][DFCG] [译者声明]       ...

  6. linux手写数字识别opencv,opencv实现KNN手写数字的识别

    人工智能是当下很热门的话题,手写识别是一个典型的应用.为了进一步了解这个领域,我阅读了大量的论文,并借助opencv完成了对28x28的数字图片(预处理后的二值图像)的识别任务. 预处理一张图片: 首 ...

  7. python写界面进度条程序_Python中如何写控制台进度条的整理

    进度条和一般的print区别在哪里呢? 答案就是print会输出一个\n,也就是换行符,这样光标移动到了下一行行首,接着输出,之前已经通过stdout输出的东西依旧保留,而且保证我们在下面看到最新的输 ...

  8. java jxl 写 excel_Java 操作Excel(jxl读和写)

    一.读操作: package com.jxl.opr; import java.io.FileInputStream; import java.io.FileNotFoundException; im ...

  9. wx.checkjsapi是写在config里面吗_用Python写一个程序,解密游戏内抽奖的秘密

    前言 本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理. 作者: 极客挖掘机 PS:如有需要Python学习资料的小伙伴可以加点击下 ...

最新文章

  1. 曙光计算机系统,曙光1000大规模并行计算机系统
  2. asp.net 后台任务作业框架收集
  3. 【MM模块】 Blanket PO 框架订单
  4. config中自定义配置
  5. 内核对象管理:Slab,Slub
  6. linux权限介绍,Linux的权限介绍
  7. BZOJ 2969 期望
  8. Java 泛型List clone
  9. _软件园三期西片区F地块举行招商推介会 超300家企业意向落户 - 本网原创
  10. requests模块中使用代理proxy发送请求
  11. 要毕业了,兄弟也签了工作。。。
  12. Neuralog.v2013.06 1CD (测井曲线智能矢量化软件)
  13. 科研工作者要会的技能----查找顶刊会议或期刊的方法
  14. Linux虚拟网络基础 — Bridge
  15. Vue打包出现Browserslist: caniuse-lite is outdated
  16. opencv cvhog详解
  17. 计算机等级考试考几级才能成为数据库工程师?
  18. sun服务器多磁盘配置信息,配置 Solaris iSCSI initiator
  19. 算法(Algorithm)
  20. 打造个人版微信小程序(1)——本地开发api接口调用

热门文章

  1. c#子线程中打开系统文件操作对话框
  2. 一维傅里叶变换后的复数怎样理解?
  3. Java连接Mysql数据库增删改查实现
  4. play2框架 jpa mysql_单元测试 – Playframework 2.2.x Java JPA – 用于单元测试和生产的独立数据库...
  5. 友盟分享成功之后没有提示信息的解决
  6. 计算机应用基础第2版在线作业1,计算机应用基础(第2版)在线作业(1)
  7. python的自带数据集_盘点 | Python自带的那些数据集
  8. Qt Example各例子演示功能说明
  9. java.lang.UnsatisfiedLinkError: No implementation found for void com.mchsdk.paysdk.net.MCHKeyTools.n
  10. Android开发之设置DialogFragment的窗体背景色的方法亲测可用