假如我定义了一个网络进行训练:

import tensorflow as tf
import numpy as npclass MNISTLoader():def __init__(self):mnist = tf.keras.datasets.mnist(self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()# MNIST中的图像默认为uint8(0-255的数字)。以下代码将其归一化到0-1之间的浮点数,并在最后增加一维作为颜色通道self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]self.train_label = self.train_label.astype(np.int32)    # [60000]self.test_label = self.test_label.astype(np.int32)      # [10000]self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从数据集中随机取出batch_size个元素并返回index = np.random.randint(0, self.num_train_data, batch_size)return self.train_data[index, :], self.train_label[index]class MLP(tf.keras.Model):def __init__(self):super().__init__()self.flatten = tf.keras.layers.Flatten()    # Flatten层将除第一维(batch_size)以外的维度展平self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(units=10)def call(self, inputs):         # [batch_size, 28, 28, 1]x = self.flatten(inputs)    # [batch_size, 784]x = self.dense1(x)          # [batch_size, 100]x = self.dense2(x)          # [batch_size, 10]output = tf.nn.softmax(x)return outputnum_epochs = 5
batch_size = 50
learning_rate = 0.001model = MLP()
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # 实例化Checkpoint,取名为myAwesomeModel,设置保存对象为model
for batch_index in range(num_batches):X, y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape:y_pred = model(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch %d: loss %f" % (batch_index, loss.numpy()))grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))if batch_index % 100 == 0:                              # 每隔100个Batch保存一次path = checkpoint.save('./save/model.ckpt')         # 保存模型参数到文件print("model saved to %s" % path)
batch 0: loss 2.399996
model saved to ./save/model.ckpt-1
batch 1: loss 2.191273
batch 2: loss 2.172761
batch 3: loss 2.080019
batch 4: loss 1.949049
batch 5: loss 1.927595
batch 6: loss 1.862676
batch 7: loss 1.937338
batch 8: loss 1.783551
batch 9: loss 1.693900
batch 10: loss 1.608254
batch 11: loss 1.595733
batch 12: loss 1.483589
batch 13: loss 1.745229
batch 14: loss 1.605927
batch 15: loss 1.411374
batch 16: loss 1.414417......

这个时候每隔100个batch就报存了一次参数。假如说我们的电脑突然遇到故障了,下一次我不想再重头训练怎么办?这个时候就可以导入原先保存的最新的checkpoint再训练:

import tensorflow as tf
import numpy as npclass MNISTLoader():def __init__(self):mnist = tf.keras.datasets.mnist(self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()# MNIST中的图像默认为uint8(0-255的数字)。以下代码将其归一化到0-1之间的浮点数,并在最后增加一维作为颜色通道self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]self.train_label = self.train_label.astype(np.int32)    # [60000]self.test_label = self.test_label.astype(np.int32)      # [10000]self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从数据集中随机取出batch_size个元素并返回index = np.random.randint(0, self.num_train_data, batch_size)return self.train_data[index, :], self.train_label[index]class MLP(tf.keras.Model):def __init__(self):super().__init__()self.flatten = tf.keras.layers.Flatten()    # Flatten层将除第一维(batch_size)以外的维度展平self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(units=10)def call(self, inputs):         # [batch_size, 28, 28, 1]x = self.flatten(inputs)    # [batch_size, 784]x = self.dense1(x)          # [batch_size, 100]x = self.dense2(x)          # [batch_size, 10]output = tf.nn.softmax(x)return outputnum_epochs = 3
batch_size = 5
learning_rate = 0.001
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)model = MLP() # 实例化模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # myAwesomeModel,这是你原来保存的checkpoint时的model名字
checkpoint.restore(tf.train.latest_checkpoint('./save')) # 恢复最新的checkpointfor batch_index in range(num_batches):X, y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape:y_pred = model(X)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch %d: loss %f" % (batch_index, loss.numpy()))grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense2.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
batch 0: loss 0.023930
batch 1: loss 0.014747
batch 2: loss 0.005468
batch 3: loss 0.000008
batch 4: loss 0.000106
batch 5: loss 0.000138
batch 6: loss 0.000322
batch 7: loss 0.000636
batch 8: loss 0.000061......

可以看到,它不是从头开始训练,loss的初始值就只有0.02了

当你要在测试集上测试的时候,也可以直接恢复之后使用:

model_to_be_restored = MLP()
# 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restored
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数
y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)
print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))
test accuracy: 0.975400

好了,就是这样,如果对您有帮助就点个赞吧。

tensorflow2的checkpoint恢复训练相关推荐

  1. 【慕课网】人工智能-语音入门|公开课知识整理

    人工智能-语音入门 该博客是慕课网视频教程的笔者自我小结,原视频传送门 References: 语音增强理论与实践-[美]罗艾洲等 [译]高毅等 WAV和PCM的关系和区别 AudioSet数据集 知 ...

  2. MMEngine理解

    MMEngine理解 1 简介 1.1 架构 1.2 模块介绍 1.2.1 核心模块与相关组件 1.2.1 公共基础模块 2 上手示例 2.1 构建模型 2.2 构建数据集和数据加载器 2.3 构建评 ...

  3. TensorFlow2 -官方教程 :保存和恢复模型

    文章目录 准备工作:安装,导入,获取数据集,定义model 在训练期间保存模型(以 checkpoints 形式保存) Checkpoint 回调用法 checkpoint 回调选项 这些文件是什么? ...

  4. 【NLP】NLP实战篇之tensorflow2.0快速入门

    修改上版代码格式问题.Tensorflow2.0跟Keras紧密结合,相比于1.0版本,2.0可以更快上手,并且能更方便找到需要的api.本文中以IMDB文本分类为例,简单介绍了从数据下载.预处理.建 ...

  5. 深度学习-Tensorflow2.2-模型保存与恢复{9}-保存与恢复-21

    模型保存(tf.keras保存模型) 保存 Tf.Keras 模型保存为 HDF5 文件 Keras 使用了 h5py Python 包. h5py 是 Keras 的依赖项,应默认被安装 保存/加载 ...

  6. tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

    最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...

  7. Tensorflow2.0

    Tensorflow2.0 Tensorflow 简介 Tensorflow是什么 Google开源软件库 采用数据流图,用于数值计算 支持多平台 GPU CPU 移动设备 最初用于深度学习,变得通用 ...

  8. TensorFlow2.0教程-使用keras训练模型

    TensorFlow2.0教程-使用keras训练模型 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details ...

  9. Tensorflow2.0的简单GCN代码(使用cora数据集)

    废话不多说先放代码. 本文的代码需要两个部分组成--自定义的GCN层的GCN_layer和训练代码train. 首先是自定义的GCN层的GCN_layer.py: import tensorflow ...

最新文章

  1. 检查传递给Bash脚本的参数数量
  2. Java对象的实例化
  3. Lua require 相对路径
  4. 开源Java反编译工具
  5. 吉林大学计算机与科学专业排名,吉林大学专业排名 哪些王牌专业推荐就读
  6. BZOJ1298:[SCOI2009]骰子的学问
  7. 精通SpringBoot---整合RabbitMQ消息队列
  8. php爆数据库,php+MySql注入非暴力爆数据库表段
  9. wubiuefi-支持新版本ubuntu的wubi
  10. eclipse 中 project facet 的作用
  11. bim 水利枢纽 运维_BIM技术与现代化建筑运维管理
  12. ASO优化方法_获取ASO关键词指数接口
  13. 096: 复习习题 求导题型 Case4:变积分限函数导数;Case5:高阶导数;Case6:分段函数求导
  14. php微信公众号报修系统,微信公众号如何实现在线报修系统?
  15. 2022年劳务员-通用基础(劳务员)考试题库及答案
  16. ubuntu18.04+cuda9.0+lenovo y430p(GTX850M)亲测可用
  17. 如何用 canvas 画出分形图
  18. 音频可视化图形引擎—Specinker
  19. 被误解的tinyint(1)
  20. Golang 函数定义 不定数目参数定义 多个输入参数函数 (...) 不定参数 可变参数 定义

热门文章

  1. Led智慧照明系统功能
  2. 详解计算机内存及基于内存理解的几种数据结构
  3. python编程代码-python编程代码
  4. AGC012B Splatter Painting
  5. 新手也能每天挣300,今日头条的这5个隐藏玩法,你知道吗?
  6. 可以用购买的专利做高新技术企业申请吗?
  7. 请求https接口时报错:Caused by SSLError(SSLError(1, u‘[SSL: CERTIFICATE_VERIFY_FAILED] certificat,安装certifi
  8. RabbitMQ学习笔记
  9. roc_auc_score()、auc()和roc_curve()
  10. 求职经历--ThoughtWorks