1. 基础
    TensorFlow 基础
    TensorFlow 模型建立与训练
    基础示例:多层感知机(MLP)
    卷积神经网络(CNN)
    循环神经网络(RNN)
    深度强化学习(DRL)
    Keras Pipeline
    自定义层、损失函数和评估指标
    常用模块 tf.train.Checkpoint :变量的保存与恢复
    常用模块 TensorBoard:训练过程可视化
    常用模块 tf.data :数据集的构建与预处理
    常用模块 TFRecord :TensorFlow 数据集存储格式
    常用模块 tf.function :图执行模式
    常用模块 tf.TensorArray :TensorFlow 动态数组
    常用模块 tf.config:GPU 的使用与分配

  2. 部署
    TensorFlow 模型导出
    TensorFlow Serving
    TensorFlow Lite

  3. 大规模训练与加速
    TensorFlow 分布式训练
    使用 TPU 训练 TensorFlow 模型

  4. 扩展
    TensorFlow Hub 模型复用
    TensorFlow Datasets 数据集载入

  5. 附录
    强化学习基础简介


目录

  • tf.train.Checkpoint
  • 保存参数
  • 载入之前保存的参数
  • 保存变量+恢复变量
  • `tf.train.Checkpoint` VS `tf.train.Saver`
  • 实例
  • 使用 `tf.train.CheckpointManager` 删除旧的 Checkpoint 以及自定义文件编号

Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署” 章节中的 SavedModel 。

tf.train.Checkpoint

很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle 存储 model.variables。但不幸的是,TensorFlow 的变量类型 ResourceVariable 并不能被序列化。

好在 TensorFlow 提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类,可以使用其 save()restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizertf.Variabletf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:

checkpoint = tf.train.Checkpoint(model=model)

这里 tf.train.Checkpoint() 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model 的模型实例 model 和一个继承 tf.train.Optimizer 的优化器 optimizer ,我们可以这样写:

checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

这里 myAwesomeModel 是我们为待保存的模型 model 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。

保存参数

接下来,当模型训练完成需要保存的时候,使用:

checkpoint.save(save_path_with_prefix)

就可以。 save_path_with_prefix 是保存文件的目录 + 前缀

  • 例如,在源代码目录建立一个名为 save 的文件夹并调用一次 checkpoint.save('./save/model.ckpt') ,我们就可以在 save 目录下发现名为 checkpointmodel.ckpt-1.indexmodel.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个 .index 文件和 .data 文件,序号依次累加。

载入之前保存的参数

当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:

model_to_be_restored = MyModel()                                        # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)

即可恢复模型变量。 save_path_with_prefix_and_index 是之前保存的文件的目录 + 前缀 + 编号

  • 例如,调用 checkpoint.restore('./save/model.ckpt-1') 就可以载入前缀为 model.ckpt ,序号为 1 的文件来恢复模型。

当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次 checkpoint 的文件名。

  • 例如如果 save 目录下有 model.ckpt-1.indexmodel.ckpt-10.index 的 10 个保存文件, tf.train.latest_checkpoint('./save') 即返回 ./save/model.ckpt-10

保存变量+恢复变量

总体而言,恢复与保存变量的典型代码框架如下:

# train.py 模型训练阶段model = MyModel()
# 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型训练代码)
# 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
checkpoint.save('./save/model.ckpt')
# test.py 模型使用阶段model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model)             # 实例化Checkpoint,指定恢复对象为model
checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数
# 模型使用代码

tf.train.Checkpoint VS tf.train.Saver

tf.train.Checkpoint 与以前版本常用的 tf.train.Saver 相比,强大之处在于其支持在即时执行模式下 “延迟” 恢复变量

具体而言,当调用了 checkpoint.restore() ,但模型中的变量还没有被建立的时候,Checkpoint 可以等到变量被建立的时候再进行数值的恢复。即时执行模式下,模型中各个层的初始化和变量的建立是在模型第一次被调用的时候才进行的(好处在于可以根据输入的张量形状而自动确定变量形状,无需手动指定)。这意味着当模型刚刚被实例化的时候,其实里面还一个变量都没有,这时候使用以往的方式去恢复变量数值是一定会报错的。比如,你可以试试在 train.py 调用 tf.keras.Modelsave_weight() 方法保存 model 的参数,并在 test.py 中实例化 model 后立即调用 load_weight() 方法,就会出错,只有当调用了一遍 model 之后再运行 load_weight() 方法才能得到正确的结果。可见, tf.train.Checkpoint 在这种情况下可以给我们带来相当大的便利。另外, tf.train.Checkpoint 同时也支持图执行模式

实例

最后提供一个实例,以前章的 多层感知机模型 为例展示模型变量的保存和载入:

import tensorflow as tf
import numpy as np
import argparse
from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoaderparser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', default='train', help='train or test')
parser.add_argument('--num_epochs', default=1)
parser.add_argument('--batch_size', default=50)
parser.add_argument('--learning_rate', default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()def train():model = MLP()optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)checkpoint = tf.train.Checkpoint(myAwesomeModel=model)      # 实例化Checkpoint,设置保存对象为modelfor batch_index in range(1, num_batches+1):                 X, y = data_loader.get_batch(args.batch_size)with tf.GradientTape() as tape:y_pred = model(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch %d: loss %f" % (batch_index, loss.numpy()))grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))if batch_index % 100 == 0:                              # 每隔100个Batch保存一次path = checkpoint.save('./save/model.ckpt')         # 保存模型参数到文件print("model saved to %s" % path)def test():model_to_be_restored = MLP()# 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restoredcheckpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))if __name__ == '__main__':if args.mode == 'train':train()if args.mode == 'test':test()

在代码目录下建立 save 文件夹并运行代码进行训练后,save 文件夹内将会存放每隔 100 个 batch 保存一次的模型变量数据。在命令行参数中加入 --mode=test 并再次运行代码,将直接使用最后一次保存的变量值恢复模型并在测试集上测试模型性能,可以直接获得 95% 左右的准确率。

使用 tf.train.CheckpointManager 删除旧的 Checkpoint 以及自定义文件编号

在模型的训练过程中,我们往往每隔一定步数保存一个 Checkpoint 并进行编号。不过很多时候我们会有这样的需求:

  • 在长时间的训练后,程序会保存大量的 Checkpoint,但我们只想保留最后的几个 Checkpoint;
  • Checkpoint 默认从 1 开始编号,每次累加 1,但我们可能希望使用别的编号方式(例如使用当前 Batch 的编号作为文件编号)。

这时,我们可以使用 TensorFlow 的 tf.train.CheckpointManager 来实现以上需求。具体而言,在定义 Checkpoint 后接着定义一个 CheckpointManager:

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)

此处, directory 参数为文件保存的路径, checkpoint_name 为文件名前缀(不提供则默认为 ckpt ), max_to_keep 为保留的 Checkpoint 数目。

在需要保存模型的时候,我们直接使用 manager.save() 即可。如果我们希望自行指定保存的 Checkpoint 的编号,则可以在保存时加入 checkpoint_number 参数。例如 manager.save(checkpoint_number=100)

以下提供一个实例,展示使用 CheckpointManager 限制仅保留最后三个 Checkpoint 文件,并使用 batch 的编号作为 Checkpoint 的文件编号。

import tensorflow as tf
import numpy as np
import argparse
from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoaderparser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', default='train', help='train or test')
parser.add_argument('--num_epochs', default=1)
parser.add_argument('--batch_size', default=50)
parser.add_argument('--learning_rate', default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()def train():model = MLP()optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)checkpoint = tf.train.Checkpoint(myAwesomeModel=model)      # 使用tf.train.CheckpointManager管理Checkpointmanager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)for batch_index in range(1, num_batches):X, y = data_loader.get_batch(args.batch_size)with tf.GradientTape() as tape:y_pred = model(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch %d: loss %f" % (batch_index, loss.numpy()))grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))if batch_index % 100 == 0:# 使用CheckpointManager保存模型参数到文件并自定义编号path = manager.save(checkpoint_number=batch_index)         print("model saved to %s" % path)def test():model_to_be_restored = MLP()checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      checkpoint.restore(tf.train.latest_checkpoint('./save'))y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))if __name__ == '__main__':if args.mode == 'train':train()if args.mode == 'test':test()

【Tensorflow教程笔记】常用模块 tf.train.Checkpoint :变量的保存与恢复相关推荐

  1. 【Tensorflow教程笔记】常用模块 tf.function :图执行模式

    基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...

  2. 【Tensorflow教程笔记】TensorFlow Datasets 数据集载入

    Tensorflow教程笔记 基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DR ...

  3. 【Tensorflow教程笔记】深度强化学习(DRL)

    基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...

  4. TensorFlow学习笔记02:使用tf.data读取和保存数据文件

    TensorFlow学习笔记02:使用tf.data读取和保存数据文件 使用`tf.data`读取和写入数据文件 读取和写入csv文件 写入csv文件 读取csv文件 读取和保存TFRecord文件 ...

  5. 【TensorFlow】TensorFlow函数精讲之tf.train.ExponentialMovingAverage()

    tf.train.ExponentialMovingAverage来实现滑动平均模型. 格式: tf.train.ExponentialMovingAverage(decay,num_step) 参数 ...

  6. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

  7. tensorflow || 滑动平均的理解--tf.train.ExponentialMovingAverage

    1 滑动平均的理解 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变 ...

  8. 【TensorFlow】TensorFlow函数精讲之tf.train.exponential_decay()

    tf.train.exponential_decay实现指数衰减率.通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定. tf. ...

  9. Java实用教程笔记 常用实用类

    常用实用类 8.1 String类 判断引用是否为同一变量 常量池vs非常量池(动态区) equals 输出对象为内存地址的方式 "==" 运算结果为ture/false的比较方式 ...

最新文章

  1. 【好程序员笔记分享】——下拉刷新和上拉加载更多
  2. Spring MVC 测试 | 模拟提交表单
  3. 简评黑客利器——中国菜刀
  4. 【计算机网络】物理层传输介质
  5. 【Flink】Flink TimeServer 之 timerService().registerProcessingTimeTimer
  6. javaWeb(入门基础详解)
  7. [转载] 浅析Java OutOfMemoryError
  8. 如何用4K YouTube转换视频为MP3,同时设置成MP3桌面播放器?
  9. linux服务器ftp连接失败的原因,错误:无法与 SFTP 服务器建立 FTP 连接
  10. CPC23-4-K. 喵喵的神数 (数论 Lucas定理)
  11. 水彩画笔效果PS笔刷
  12. python删除表格第一行不动_Excel教程,教你如何设置表格第一行和第一列固定不动,一直显示...
  13. [数据库]-- mysql 获取昨天日期、今天日期、明天日期以及前一个小时和后一个小时的时间
  14. 本人常用软件(工具)
  15. InetAddress.isReachable的超时设置
  16. 如何用Python画奥运五环——circle()
  17. 触摸屏 服务器系统,IP网络触摸屏服务器 SK1606
  18. 【转载】Python遍历pandas数据方法总结
  19. MATLAB代码:储能参与调峰调频联合优化模型
  20. element的el-table列标题添加自定义图标

热门文章

  1. 【苹果iMessage相册推信息推】 重要用于安装背面必要安装的watchman
  2. Python 二分查找(涉及递归思想)
  3. jquery实现的折叠式菜单(手风琴式菜单)
  4. i18n 本地化_国际化与本地化(i18n与l10n)
  5. linux内核 lvs,Linux中 LVS 的介绍
  6. .NET开发框架(四)-服务器IIS安装教程
  7. 你会在终端下快速获取公网 IP 地址吗,学会这些技巧后你就游刃有余了!
  8. 汉诺塔算法python_python实现汉诺塔算法
  9. SQL语句排序、分页、多表查询
  10. javascript-由初速度和仰角求射程