简介


可以在训练过程中和训练完成后保存模型,这样就可以很方便地恢复和重用模型,节省模型训练时间。

这样也便于别人使用你的模型,一般有两种方式共享模型:

  • 创建模型的源码
  • 训练好的模型(包括权重、参数等)

这里主要使用第二种方式。

使用的框架是TensorFlow2.4的高阶API:Keras进行模型训练。

验证环境


假设你已经安装好了TensorFlow2.4的运行环境。

如未安装,请稳步 install

安装依赖:

pip install -q pyyaml h5py # Required to save models in HDF5 format

运行以下代码验证:

import osimport tensorflow as tf
from tensorflow import kerasprint(tf.version.VERSION)

输出:

2.4.1

训练模型


使用mnist数据集进行数字分类,代码如下:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()train_labels = train_labels[:1000]
test_labels = test_labels[:1000]train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0# Define a simple sequential model
def create_model():model = tf.keras.models.Sequential([keras.layers.Dense(512, activation='relu', input_shape=(784,)),keras.layers.Dropout(0.2),keras.layers.Dense(10)])model.compile(optimizer='adam',loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[tf.metrics.SparseCategoricalAccuracy()])return model# Create a basic model instance
model = create_model()# Display the model's architecture
model.summary()

以上代码即可完成模型的训练(为了说明问题,只使用了前1000个元素,节省时间),可以从输出中看到训练过程和结果。

训练中保存快照


可以在训练过程中保存模型,以便后续继续执行。

这需要使用回调函数:tf.keras.callbacks.ModelCheckpoint。

创建回调的代码如下:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1)# Train the model with the new callback
model.fit(train_images, train_labels,  epochs=10,validation_data=(test_images, test_labels),callbacks=[cp_callback])  # Pass callback to training# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.

这会创建一个TensorFlow节点文件,并在每轮训练后更新。

只要两个模型共用相同的网络结构,它们就可以共用权重。所以仅从权重恢复模型的时候,需要以原始模型相同的网络结构创建这个模型,再设置权重。

重建一个新的、未训练的模型,它的精度约10%:

# Create a basic model instance
model = create_model()# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))

加载刚才已经训练的权重并重新评估精确度,精度可恢复到原来的水平:

# Loads the weights
model.load_weights(checkpoint_path)# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

也可以手动保存权重:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')# Create a new model instance
model = create_model()# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

保存整个模型


这个是使用的比较多的方式。

一般会使用Keras训练好模型,保存为文件,再使用c++等方式加载模型,用于生产环境。

使用 model.save 就可以把模型的结构、权重和训练配置保存到单个文件或文件夹。

完整的模型可以保存为两种格式:

  • SavedModel,这是TF2.x的默认存储格式
  • HDF5
SavedModel格式

保存代码如下:

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)# Save the entire model as a SavedModel.
os.system('mkdir -p saved_model')
model.save('saved_model/my_model')

SavedModel格式生成的文件夹包括pb文件和TensorFlow节点文件。

加载也很简单:

new_model = tf.keras.models.load_model('saved_model/my_model')# Check its architecture
new_model.summary()# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))print(new_model.predict(test_images).shape)
HDF5格式

这是Keras的基础格式,保存代码如下:

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')

加载:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')# Show the model architecture
new_model.summary()loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

Keras会检查网络结构,并保存模型相关的所有内容:

  • 权重值
  • 模型结构
  • 模型训练配置,也就是传入compile的参数
  • 优化器和状态

SavedModel和HDF5格式的关键区别在于:

  • HDF5使用对象配置保存模型框架
  • SavedModel保存的是可执行图

因此,SavedModel不用查询源码即可保存诸如子类模型和定制层等定制化的对象。而HDF5就要复杂一些。

具体的不再详述,可参考相关资料。

参考资料

https://tensorflow.google.cn/tutorials/keras/save_and_load

保存与加载Keras训练好的模型相关推荐

  1. C++ 和 OpenCV 实现卷积神经网络并加载 Keras 训练好的参数进行预测

    C++ 和 OpenCV 实现卷积神经网络并加载 Keras 训练好的参数进行预测 一. 背景 二. Keras 定义神经网络结构 channels_first 与 channels_last cha ...

  2. java加载tensorflow训练的PB模型记录

    java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...

  3. OpenCvSharp (C# OpenCV) DNN模块加载自己训练的TensorFlow模型做目标检测(含手势识别、骰子识别、菜品识别)(附源码)

    本文作者Color Space,文章未经作者允许禁止转载! 本文将介绍OpenCVSharp DNN模块加载自己训练的TensorFlow模型做目标检测(含手势识别.骰子识别.菜品识别)! 前言: 下 ...

  4. Keras如何保存、加载Keras模型

    链接 Keras中文文档 一.如何保存 Keras 模型? 1.保存/加载整个模型(结构 + 权重 + 优化器状态) 不建议使用 pickle 或 cPickle 来保存 Keras 模型. 你可以使 ...

  5. tensorflow加载预训练好的模型图(.pb文件)

    千万不要试图在jupyter notebook中打开.pb模型文件,否则你会得到: 这时候我以为shi编码的问题,开始转换编码,转换完成后发现shi乱ma. 后来网上查了,.pb文件里面存储的shi模 ...

  6. Java加载sklearn训练好的模型进行预测(无法搞定)

    参考文献主要是[1][2] [2]中代码各种类函数都是自定义的,放弃吧 转攻向[1] --------------------------------------------------------- ...

  7. python如何保存训练好的模型_Python机器学习7:如何保存、加载训练好的机器学习模型...

    本文将介绍如何使用scikit-learn机器学习库保存Python机器学习模型.加载已经训练好的模型.学会了这个,你才能够用已有的模型做预测,而不需要每次都重新训练模型. 本文将使用两种方法来实现模 ...

  8. Tensorflow 2.x(keras)源码详解之第十章:keras中的模型保存与加载(详解Checkpointmd5模型序列化)

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

  9. TensorFlow2.0 —— 模型保存与加载

    目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...

最新文章

  1. spring 事务 对象保存之后,修改的时候,自动更新提交?
  2. 建立linux两用户之间的信任关系
  3. C指针原理(39)-GLIB
  4. springMVC-配置Bean
  5. v-slot vue2.6新增指令使用指南
  6. c++十进制转二进制_二进制与十进制如何互相转换?
  7. 在运行时打开GC日志记录
  8. 在XML文件中定义动画(1)
  9. Unity3D(五)渲染管线
  10. 四种超实用的超级记忆法-数字定桩法,借助语句定桩法,标题定桩法,记忆宫殿法
  11. 抖音特效转场模板预设 Premiere调色预设 PR光效转场等900个特效包(含教程及转场音效)
  12. gitgub常用按钮说明
  13. java打造手机远程控制电脑之详细教程
  14. java编写投票功能需求分析
  15. 用Python自制随机点名程序,压迫感来了~
  16. 极限编程-拥抱变化阅读感想(二)
  17. 记录下我磕磕碰碰的三个月找工作经历,offer拿到手软
  18. git did not exit cleanly (exit code 1)
  19. UCOS III 任务堆栈理解
  20. 从F型网页浏览看用户对网页的浏览习惯

热门文章

  1. FIFO(二):FIFO工作原理
  2. 图像风格迁移与快速风格迁移的对比(感知损失)
  3. Win7启动无桌面(explorer.exe)
  4. 新思路计算机等级考试50套,新思路计算机一级选择题50套(含答案)解析.doc
  5. 2020A证(安全员)实操考试视频及A证(安全员)操作证考试
  6. 获取汉字串的拼音助记码
  7. 【论文研读】-DiPETrans: A Framework for Distributed Parallel Execution of Transactions of Blocks in BLC
  8. 长虹变频空调整机不工作维修案例分析
  9. MyBatis学习(1)
  10. 飓鼎玩笑傲江湖服务器维护,6月25日《问道》笑傲、华山更新维护