【Tensorflow教程笔记】常用模块 tf.train.Checkpoint :变量的保存与恢复
基础
TensorFlow 基础
TensorFlow 模型建立与训练
基础示例:多层感知机(MLP)
卷积神经网络(CNN)
循环神经网络(RNN)
深度强化学习(DRL)
Keras Pipeline
自定义层、损失函数和评估指标
常用模块 tf.train.Checkpoint :变量的保存与恢复
常用模块 TensorBoard:训练过程可视化
常用模块 tf.data :数据集的构建与预处理
常用模块 TFRecord :TensorFlow 数据集存储格式
常用模块 tf.function :图执行模式
常用模块 tf.TensorArray :TensorFlow 动态数组
常用模块 tf.config:GPU 的使用与分配部署
TensorFlow 模型导出
TensorFlow Serving
TensorFlow Lite大规模训练与加速
TensorFlow 分布式训练
使用 TPU 训练 TensorFlow 模型扩展
TensorFlow Hub 模型复用
TensorFlow Datasets 数据集载入附录
强化学习基础简介
目录
- tf.train.Checkpoint
- 保存参数
- 载入之前保存的参数
- 保存变量+恢复变量
- `tf.train.Checkpoint` VS `tf.train.Saver`
- 实例
- 使用 `tf.train.CheckpointManager` 删除旧的 Checkpoint 以及自定义文件编号
Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署” 章节中的 SavedModel 。
tf.train.Checkpoint
很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle
存储 model.variables
。但不幸的是,TensorFlow 的变量类型 ResourceVariable
并不能被序列化。
好在 TensorFlow 提供了 tf.train.Checkpoint
这一强大的变量保存与恢复类,可以使用其 save()
和 restore()
方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizer
、 tf.Variable
、 tf.keras.Layer
或者 tf.keras.Model
实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:
checkpoint = tf.train.Checkpoint(model=model)
这里 tf.train.Checkpoint()
接受的初始化参数比较特殊,是一个 **kwargs
。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model
的模型实例 model
和一个继承 tf.train.Optimizer
的优化器 optimizer
,我们可以这样写:
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
这里 myAwesomeModel
是我们为待保存的模型 model
所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。
保存参数
接下来,当模型训练完成需要保存的时候,使用:
checkpoint.save(save_path_with_prefix)
就可以。 save_path_with_prefix
是保存文件的目录 + 前缀。
- 例如,在源代码目录建立一个名为 save 的文件夹并调用一次
checkpoint.save('./save/model.ckpt')
,我们就可以在 save 目录下发现名为checkpoint
、model.ckpt-1.index
、model.ckpt-1.data-00000-of-00001
的三个文件,这些文件就记录了变量信息。checkpoint.save()
方法可以运行多次,每运行一次都会得到一个 .index 文件和 .data 文件,序号依次累加。
载入之前保存的参数
当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:
model_to_be_restored = MyModel() # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)
即可恢复模型变量。 save_path_with_prefix_and_index
是之前保存的文件的目录 + 前缀 + 编号。
- 例如,调用
checkpoint.restore('./save/model.ckpt-1')
就可以载入前缀为 model.ckpt ,序号为 1 的文件来恢复模型。
当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path)
这个辅助函数返回目录下最近一次 checkpoint 的文件名。
- 例如如果 save 目录下有
model.ckpt-1.index
到model.ckpt-10.index
的 10 个保存文件,tf.train.latest_checkpoint('./save')
即返回./save/model.ckpt-10
。
保存变量+恢复变量
总体而言,恢复与保存变量的典型代码框架如下:
# train.py 模型训练阶段model = MyModel()
# 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型训练代码)
# 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
checkpoint.save('./save/model.ckpt')
# test.py 模型使用阶段model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model) # 实例化Checkpoint,指定恢复对象为model
checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数
# 模型使用代码
tf.train.Checkpoint
VS tf.train.Saver
tf.train.Checkpoint
与以前版本常用的 tf.train.Saver
相比,强大之处在于其支持在即时执行模式下 “延迟” 恢复变量。
具体而言,当调用了 checkpoint.restore()
,但模型中的变量还没有被建立的时候,Checkpoint 可以等到变量被建立的时候再进行数值的恢复。即时执行模式下,模型中各个层的初始化和变量的建立是在模型第一次被调用的时候才进行的(好处在于可以根据输入的张量形状而自动确定变量形状,无需手动指定)。这意味着当模型刚刚被实例化的时候,其实里面还一个变量都没有,这时候使用以往的方式去恢复变量数值是一定会报错的。比如,你可以试试在 train.py 调用 tf.keras.Model
的 save_weight()
方法保存 model 的参数,并在 test.py 中实例化 model 后立即调用 load_weight()
方法,就会出错,只有当调用了一遍 model 之后再运行 load_weight()
方法才能得到正确的结果。可见, tf.train.Checkpoint
在这种情况下可以给我们带来相当大的便利。另外, tf.train.Checkpoint
同时也支持图执行模式。
实例
最后提供一个实例,以前章的 多层感知机模型 为例展示模型变量的保存和载入:
import tensorflow as tf
import numpy as np
import argparse
from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoaderparser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', default='train', help='train or test')
parser.add_argument('--num_epochs', default=1)
parser.add_argument('--batch_size', default=50)
parser.add_argument('--learning_rate', default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()def train():model = MLP()optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # 实例化Checkpoint,设置保存对象为modelfor batch_index in range(1, num_batches+1): X, y = data_loader.get_batch(args.batch_size)with tf.GradientTape() as tape:y_pred = model(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch %d: loss %f" % (batch_index, loss.numpy()))grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))if batch_index % 100 == 0: # 每隔100个Batch保存一次path = checkpoint.save('./save/model.ckpt') # 保存模型参数到文件print("model saved to %s" % path)def test():model_to_be_restored = MLP()# 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restoredcheckpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))if __name__ == '__main__':if args.mode == 'train':train()if args.mode == 'test':test()
在代码目录下建立 save 文件夹并运行代码进行训练后,save 文件夹内将会存放每隔 100 个 batch 保存一次的模型变量数据。在命令行参数中加入 --mode=test
并再次运行代码,将直接使用最后一次保存的变量值恢复模型并在测试集上测试模型性能,可以直接获得 95% 左右的准确率。
使用 tf.train.CheckpointManager
删除旧的 Checkpoint 以及自定义文件编号
在模型的训练过程中,我们往往每隔一定步数保存一个 Checkpoint 并进行编号。不过很多时候我们会有这样的需求:
- 在长时间的训练后,程序会保存大量的 Checkpoint,但我们只想保留最后的几个 Checkpoint;
- Checkpoint 默认从 1 开始编号,每次累加 1,但我们可能希望使用别的编号方式(例如使用当前 Batch 的编号作为文件编号)。
这时,我们可以使用 TensorFlow 的 tf.train.CheckpointManager
来实现以上需求。具体而言,在定义 Checkpoint 后接着定义一个 CheckpointManager:
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)
此处, directory
参数为文件保存的路径, checkpoint_name
为文件名前缀(不提供则默认为 ckpt ), max_to_keep
为保留的 Checkpoint 数目。
在需要保存模型的时候,我们直接使用 manager.save()
即可。如果我们希望自行指定保存的 Checkpoint 的编号,则可以在保存时加入 checkpoint_number
参数。例如 manager.save(checkpoint_number=100)
。
以下提供一个实例,展示使用 CheckpointManager 限制仅保留最后三个 Checkpoint 文件,并使用 batch 的编号作为 Checkpoint 的文件编号。
import tensorflow as tf
import numpy as np
import argparse
from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoaderparser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', default='train', help='train or test')
parser.add_argument('--num_epochs', default=1)
parser.add_argument('--batch_size', default=50)
parser.add_argument('--learning_rate', default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()def train():model = MLP()optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # 使用tf.train.CheckpointManager管理Checkpointmanager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)for batch_index in range(1, num_batches):X, y = data_loader.get_batch(args.batch_size)with tf.GradientTape() as tape:y_pred = model(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch %d: loss %f" % (batch_index, loss.numpy()))grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))if batch_index % 100 == 0:# 使用CheckpointManager保存模型参数到文件并自定义编号path = manager.save(checkpoint_number=batch_index) print("model saved to %s" % path)def test():model_to_be_restored = MLP()checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) checkpoint.restore(tf.train.latest_checkpoint('./save'))y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))if __name__ == '__main__':if args.mode == 'train':train()if args.mode == 'test':test()
【Tensorflow教程笔记】常用模块 tf.train.Checkpoint :变量的保存与恢复相关推荐
- 【Tensorflow教程笔记】常用模块 tf.function :图执行模式
基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...
- 【Tensorflow教程笔记】TensorFlow Datasets 数据集载入
Tensorflow教程笔记 基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DR ...
- 【Tensorflow教程笔记】深度强化学习(DRL)
基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...
- TensorFlow学习笔记02:使用tf.data读取和保存数据文件
TensorFlow学习笔记02:使用tf.data读取和保存数据文件 使用`tf.data`读取和写入数据文件 读取和写入csv文件 写入csv文件 读取csv文件 读取和保存TFRecord文件 ...
- 【TensorFlow】TensorFlow函数精讲之tf.train.ExponentialMovingAverage()
tf.train.ExponentialMovingAverage来实现滑动平均模型. 格式: tf.train.ExponentialMovingAverage(decay,num_step) 参数 ...
- TensorFlow 实战(二)—— tf.train(优化算法)
Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...
- tensorflow || 滑动平均的理解--tf.train.ExponentialMovingAverage
1 滑动平均的理解 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变 ...
- 【TensorFlow】TensorFlow函数精讲之tf.train.exponential_decay()
tf.train.exponential_decay实现指数衰减率.通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定. tf. ...
- Java实用教程笔记 常用实用类
常用实用类 8.1 String类 判断引用是否为同一变量 常量池vs非常量池(动态区) equals 输出对象为内存地址的方式 "==" 运算结果为ture/false的比较方式 ...
最新文章
- 【好程序员笔记分享】——下拉刷新和上拉加载更多
- Spring MVC 测试 | 模拟提交表单
- 简评黑客利器——中国菜刀
- 【计算机网络】物理层传输介质
- 【Flink】Flink TimeServer 之 timerService().registerProcessingTimeTimer
- javaWeb(入门基础详解)
- [转载] 浅析Java OutOfMemoryError
- 如何用4K YouTube转换视频为MP3,同时设置成MP3桌面播放器?
- linux服务器ftp连接失败的原因,错误:无法与 SFTP 服务器建立 FTP 连接
- CPC23-4-K. 喵喵的神数 (数论 Lucas定理)
- 水彩画笔效果PS笔刷
- python删除表格第一行不动_Excel教程,教你如何设置表格第一行和第一列固定不动,一直显示...
- [数据库]-- mysql 获取昨天日期、今天日期、明天日期以及前一个小时和后一个小时的时间
- 本人常用软件(工具)
- InetAddress.isReachable的超时设置
- 如何用Python画奥运五环——circle()
- 触摸屏 服务器系统,IP网络触摸屏服务器 SK1606
- 【转载】Python遍历pandas数据方法总结
- MATLAB代码:储能参与调峰调频联合优化模型
- element的el-table列标题添加自定义图标
热门文章
- 【苹果iMessage相册推信息推】 重要用于安装背面必要安装的watchman
- Python 二分查找(涉及递归思想)
- jquery实现的折叠式菜单(手风琴式菜单)
- i18n 本地化_国际化与本地化(i18n与l10n)
- linux内核 lvs,Linux中 LVS 的介绍
- .NET开发框架(四)-服务器IIS安装教程
- 你会在终端下快速获取公网 IP 地址吗,学会这些技巧后你就游刃有余了!
- 汉诺塔算法python_python实现汉诺塔算法
- SQL语句排序、分页、多表查询
- javascript-由初速度和仰角求射程