文章目录

  • 1. Checkpoint 保存变量
  • 2. TensorBoard 训练过程可视化

学习于:简单粗暴 TensorFlow 2

1. Checkpoint 保存变量

  • tf.train.Checkpoint 可以保存 tf.keras.optimizertf.Variabletf.keras.Layertf.keras.Model
path = "./checkp.ckpt"
# 建立一个 checkpoint
mycheckpoint = tf.train.Checkpoint(mybestmodel=mymodel) # 接受 **kwargs 键值对
mycheckpoint.save(path)

  • 恢复指定模型变量
# 待恢复参数的模型
restored_model = LinearModel()
# mybestmodel 名字任意写,跟下面恢复时保持一致
mycheckpoint = tf.train.Checkpoint(mybestmodel=restored_model)
# 恢复指定的变量
path = "./checkp.ckpt-1"
mycheckpoint.restore(path)X_test = tf.constant([[5.1], [6.1]])
res = restored_model.predict(X_test)
print(res)
# [[10.182168] 前一节的线性回归模型
#  [12.176777]]
  • 恢复最近的模型,自动选定目录下最新的存档(后缀数字最大的)
mycheckpoint.restore(tf.train.latest_checkpoint("./"))
  • 管理保存的参数,有时不需要保存太多,占空间
mycheckpoint = tf.train.Checkpoint(mybestmodel=mymodel)  # 接受 **kwargs 键值对
manager = tf.train.CheckpointManager(mycheckpoint, directory="./",checkpoint_name='checkp.ckpt',max_to_keep=2) # 最多保存k个最新的for loop:manager.save() # 自动递增编号manager.save(checkpoint_number=idx) # 指定编号

2. TensorBoard 训练过程可视化

  • summary_writer = tf.summary.create_file_writer(logdir=log_dir)
  • tf.summary.scalar(name='loss', data=loss, step=idx)
  • tf.summary.trace_on(profiler=True)
for loop:with summary_writer.as_default():tf.summary.scalar(name='loss', data=loss, step=idx)
with summary_writer.as_default():tf.summary.trace_export(name='model_trace', step=0,profiler_outdir=log_dir)
  • 示例
import tensorflow as tf
import numpy as npclass MNistLoader():def __init__(self):data = tf.keras.datasets.mnist# 加载数据(self.train_data, self.train_label), (self.test_data, self.test_label) = data.load_data()# 扩展维度,灰度图1通道 [batch_size, 28, 28, chanels=1]self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)self.train_label = self.train_label.astype(np.int32)self.test_label = self.test_label.astype(np.int32)# 样本个数self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从训练集里随机取出 batch_size 个样本idx = np.random.randint(0, self.num_train_data, batch_size)return self.train_data[idx, :], self.train_label[idx]# 自定义多层感知机模型
class MLPmodel(tf.keras.Model):def __init__(self):super().__init__()# 除第一维以外的维度展平self.flatten = tf.keras.layers.Flatten()self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')self.dense2 = tf.keras.layers.Dense(units=10)def call(self, input):x = self.flatten(input)x = self.dense1(x)x = self.dense2(x)output = tf.nn.softmax(x)return output# %%num_epochs = 5
batch_size = 50
learning_rate = 1e-4
log_dir = './log' # 日志目录
mymodel = MLPmodel()# %%
data_loader = MNistLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)# 实例化记录器
summary_writer = tf.summary.create_file_writer(logdir=log_dir)
# 开启 trace,(可选),记录训练时的大量信息(图的结构,耗时等)
tf.summary.trace_on(profiler=True)for idx in range(num_batches):X, y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape:y_pred = mymodel(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch {}, loss {}".format(idx, loss.numpy()))# 记录器记录losswith summary_writer.as_default():tf.summary.scalar(name='loss', data=loss, step=idx)grads = tape.gradient(loss, mymodel.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, mymodel.variables))with summary_writer.as_default():tf.summary.trace_export(name='model_trace', step=0,profiler_outdir=log_dir)
  • 开始训练,命令行进入 可视化界面 tensorboard --logdir=./log
  • 点击命令行中的链接,打开浏览器,查看训练曲线
  • 若重新训练,请删除 log 文件,或设置别的 log 路径,重新 cmd 开启 浏览器

TensorFlow 2.0 - Checkpoint 保存变量、TensorBoard 训练可视化相关推荐

  1. 使用TensorFlow 2.0+和Keras实现AlexNet CNN架构

    技术 (Technical) 介绍 (Introduction) The main content of this article will present how the AlexNet Convo ...

  2. Keras与Tensorflow2.0入门(6)模型可视化与tensorboard的使用

    文章目录 1. 前言 1.1 Plot_model 1.2 History 1.3 自定义评估函数 PRF值的计算方法 AUC的计算方法 2. tensorboard 2.1 tensorboard是 ...

  3. Tensorflow 2.0的这些新设计,你适应好了吗?

    [新智元导读]几天前,Tensorflow刚度过自己的3岁生日,作为当前最受欢迎的机器学习框架,Tensorflow在这个宝座上已经盘踞了近三年.无论是成熟的Keras,还是风头正盛的pytorch, ...

  4. Tensorflow 2.0的新特性

    Tensorflow 2.0的新特性 几天前,Tensorflow刚度过自己的3岁生日,作为当前最受欢迎的机器学习框架,Tensorflow在这个宝座上已经盘踞了近三年.无论是成熟的Keras,还是风 ...

  5. TensorFlow 2.0中的tf.keras和Keras有何区别?为什么以后一定要用tf.keras?

    选自pyimagesearch 作者:Adrian Rosebrock 参与:王子嘉.张倩 本文经机器之心授权转载,禁止二次转载 随着 TensorFlow 2.0 的发布,不少开发者产生了一些疑惑: ...

  6. 【Tensorflow教程笔记】常用模块 tf.train.Checkpoint :变量的保存与恢复

    基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...

  7. TensorFlow 2.0 —— 模型训练

    目录 1.Keras版本模型训练 1.1 构造模型(顺序模型.函数式模型.子类模型) 1.2 模型训练:model.fit() 1.3 模型验证:model.evaluate() 1.4 模型预测:m ...

  8. Tensorflow详解保存模型(进阶版一):如何有选择的保存变量

    当然掌握了基础版还不够,我们来看一下进阶版一:如何有选择的保存变量: 这里还要另外涉及两个函数: tf.variable_scope("xxx") 和 tf.get_variabl ...

  9. TensorFlow中查看checkpoint文件中的变量名和对应值

    在加载模型时, 需要知道checkpoint中变量名称,一下代码可以查看TensorFlow中checkpoint文件中的变量名: #!/usr/bin/env python # -*- coding ...

最新文章

  1. 【C++】Google C++编码规范(三):智能指针
  2. 用java调用oracle存储过程总结
  3. 2021聊城二中高考成绩查询,聊城高中成绩排名2021,聊城中考分数线排行榜
  4. Caffe学习系列(12):训练和测试自己的图片
  5. iOS-UICollectionView
  6. cpucores_CPUCores怎么用 CPUCores使用方法指南_3DM单机
  7. 字符串(AC自动机(fail tree))
  8. 计算机虚拟现实技术论文好写吗,虚拟现实技术的论文
  9. 随机数插入排序c 语言,C语言程序设计100例之(22):插入排序
  10. 计算机组成原理关键路径,2020年834数据结构、计算机组成原理大纲(13页)-原创力文档...
  11. 无法识别的配置节“connectionStrings”的解决方法
  12. 迅捷pdf文档转换器注册码
  13. docker安装prestodb大数据查询引擎
  14. 大数元科技牵手中央财经大学 助力财税金融体制改革
  15. Excel表格中重要的数据如何隐藏不显示
  16. 缺少编译器要求的成员“System.Runtime.CompilerServices.ExtensionAttribute..ctor” 解决方案
  17. Unity3D Texture2D转换成Sprite格式
  18. c语言fgetc函数_C语言中的fgetc()函数与示例
  19. 量化选股策略模型大全
  20. Exception in thread “main“ java.lang.ArrayStoreException解决方案(记录一下)

热门文章

  1. python 管道队列_关于python:Multiprocessing-管道与队列
  2. python 爬虫 包_python爬虫学习之路-抓包分析
  3. 【R】语言第四课----读取文件
  4. c语言迷宫游戏怎么存放坐标,求解迷宫问题(c语言,很详细哦
  5. android pss内存,如何释放android系统中pss cache住的内存
  6. linux软件依赖库,【Linux】ubuntu系统安装及软件依赖库
  7. 第一季6:海思方案中uboot、kernel和rootfs的烧写方法
  8. linux pmap was java,jvm 《九》pmap linux 命令介绍 及使用
  9. hdu-1877(大数+进制转换)
  10. composer不成功的原因