目录

  • 1、Keras版本模型保存与加载
  • 2、自定义版本模型保存与加载
  • 3、总结

1、Keras版本模型保存与加载

  • 保存模型权重(model.save_weights)
  • 保存HDF5文件(model.save)
  • 保存pb文件(tf.saved_model)

tf.saved_model和model.save的区别在于,tf.saved_model格式的模型可以直接用来预测,但是tf.saved_model没有保存优化器配置,而model.save保存了优化器配置,所以整体更大。

保存模型权重方法仅仅保存了模型中的权重,而保存模型文件的model.save可以将模型和优化器一起保存,包括权重(weights)、模型配置(architecture)和优化器配置(optimizer configuration)。这样做的好处是,当你恢复模型时,完全不依赖于原来搭建模型的代码。

保存完整的模型有很多应用场景,比如在浏览器中使用TensorFlow.js加载运行,比如在移动设备上使用TensorFlow Lite加载运行。

具体示例:

import numpy as np
import tensorflow as tf
# 训练数据
x_train = np.random.random((1000, 32))
y_train = np.random.randint(10, size=(1000, ))
# 验证数据
x_val = np.random.random((200, 32))
y_val = np.random.randint(10, size=(200, ))
# 测试数据
x_test = np.random.random((200, 32))
y_test = np.random.randint(10, size=(200, ))# 构造模型
def get_uncompiled_model():inputs = tf.keras.Input(shape=(32,), name='digits')x = tf.keras.layers.Dense(64, activation='relu', name='dense_1')(inputs)x = tf.keras.layers.Dense(64, activation='relu', name='dense_2')(x)outputs = tf.keras.layers.Dense(10, name='predictions')(x)model = tf.keras.Model(inputs=inputs, outputs=outputs)return model# 设置优化器和损失函数
def get_compiled_model():model = get_uncompiled_model()model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['sparse_categorical_accuracy'])return modelmodel = get_compiled_model()
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_val, y_val))
# 可以通过model.summary()查看模型结构

方法一:保存模型权重(model.save_weights)

  • 有两种保存模型权重的方法,一种是保存.h5形式。
model.save_weights("adasd.h5") # 模型权重保存
model.load_weights("adasd.h5") # 模型权重加载,需要有model
model.predict(x_test)  # 模型预测
  • 另一种是保存checkpoint形式。
model.save_weights('./checkpoints/mannul_checkpoint')
model.load_weights('./checkpoints/mannul_checkpoint')
model.predict(x_test)

方法二:保存整个模型(model.save)

  • 有两种保存模型的方法,一种是保存pb形式。
# Export the model to a SavedModel
model.save('keras_model_tf_version', save_format='tf')# Recreate the exact same model
new_model = tf.keras.models.load_model('keras_model_tf_version')
new_model.predict(x_test)
  • 保存.h5格式。
model.save('keras_model_hdf5_version.h5')new_model = tf.keras.models.load_model('keras_model_hdf5_version.h5')
new_model.predict(x_test)

方法三:保存整个模型(tf.saved_model)

tf.saved_model.save(model,'tf_saved_model_version')  # 模型保存
restored_saved_model = tf.saved_model.load('tf_saved_model_version')  # 模型加载
f = restored_saved_model.signatures["serving_default"]
f(digits = tf.constant(x_test.tolist()))  # 模型预测

tf.saved_model.load加载的模型不是keras的模型,所以不能用model.predict()对测试数据进行预测。

我们可以通过!saved_model_cli show --dir tf_saved_model_version --all查看保存的模型,这种模型保存常用于模型的部署。通过上述指令可以看到模型保存的整体形势:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:signature_def['__saved_model_init_op']:The given SavedModel SignatureDef contains the following input(s):The given SavedModel SignatureDef contains the following output(s):outputs['__saved_model_init_op'] tensor_info:dtype: DT_INVALIDshape: unknown_rankname: NoOpMethod name is: signature_def['serving_default']:The given SavedModel SignatureDef contains the following input(s):inputs['digits'] tensor_info:dtype: DT_FLOATshape: (-1, 32)name: serving_default_digits:0The given SavedModel SignatureDef contains the following output(s):outputs['predictions'] tensor_info:dtype: DT_FLOATshape: (-1, 10)name: StatefulPartitionedCall:0Method name is: tensorflow/serving/predict

在模型结构中有一个signature_def['serving_default'],这个结构中有inputsoutputs,所以我们需要通过代码restored_saved_model.signatures["serving_default"]选择signature_def['serving_default'],加载好模型之后该部分就变为一个函数,我们可以通过f(digits = tf.constant(x_test.tolist()))对数据进行预测。

2、自定义版本模型保存与加载

  • 保存模型权重;
  • 保存HDF5文件(model.save);
  • 保存pb文件(tf.save_model)

通过具体示例进行模型保存:

1、模型构建

import tensorflow as tfclass MyModel(tf.keras.Model):def __init__(self, num_classes=10):super(MyModel, self).__init__(name='my_model')self.num_classes = num_classes# 定义自己需要的层self.dense_1 = tf.keras.layers.Dense(32, activation='relu')self.dense_2 = tf.keras.layers.Dense(num_classes)# 将动态图转换为静态图# 在静态图模型中,输入数据的数据维度不对、数据类型不对、数据名称不对都会报错@tf.function(input_signature=[tf.TensorSpec([None,32], tf.float32,name='digits')])def call(self, inputs):#定义前向传播# 使用在 (in `__init__`)定义的层x = self.dense_1(inputs)return self.dense_2(x)

2、数据准备、创建优化器与损失函数

import numpy as np
x_train = np.random.random((1000, 32))
y_train = np.random.random((1000, 10))
x_val = np.random.random((200, 32))
y_val = np.random.random((200, 10))
x_test = np.random.random((200, 32))
y_test = np.random.random((200, 10))# 优化器
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# 损失函数
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)# 准备metrics函数
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()# 准备训练数据集
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)# 准备测试数据集
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

3、模型训练

model = MyModel(num_classes=10)
epochs = 3
for epoch in range(epochs):print('Start of epoch %d' % (epoch,))# 遍历数据集的batch_sizefor step, (x_batch_train, y_batch_train) in enumerate(train_dataset):with tf.GradientTape() as tape:logits = model(x_batch_train)loss_value = loss_fn(y_batch_train, logits)grads = tape.gradient(loss_value, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))# 更新训练集的metricstrain_acc_metric(y_batch_train, logits)# 每200 batches打印一次.if step % 200 == 0:print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))print('Seen so far: %s samples' % ((step + 1) * 64))# 在每个epoch结束时显示metrics。train_acc = train_acc_metric.result()print('Training acc over epoch: %s' % (float(train_acc),))# 在每个epoch结束时重置训练指标train_acc_metric.reset_states()# 在每个epoch结束时运行一个验证集。for x_batch_val, y_batch_val in val_dataset:val_logits = model(x_batch_val)# 更新验证集mericsval_acc_metric(y_batch_val, val_logits)val_acc = val_acc_metric.result()val_acc_metric.reset_states()print('Validation acc: %s' % (float(val_acc),))

4、模型保存

  • 保存模型参数,两种方法:
model.save_weights("adasd.h5")
model.load_weights("adasd.h5")
model.predict(x_test)model.save_weights('./checkpoints/mannul_checkpoint')
model.load_weights('./checkpoints/mannul_checkpoint')
model.predict(x_test)
  • 保存模型和优化器参数:
model.save('path_to_my_model',save_format='tf')
new_model = tf.keras.models.load_model('path_to_my_model')
new_model.predict(x_test)
  • 保存模型用于部署:
tf.saved_model.save(model,'my_saved_model')
restored_saved_model = tf.saved_model.load('my_saved_model')
f = restored_saved_model.signatures["serving_default"]f(digits = tf.constant(x_test.tolist()) )

3、总结

Keras版本保存模型:

  • model.save_weights(保存模型权重);
  • model.save(保存模型,本地加载,可以保存为h5或者pb格式文件);
  • tf.saved_model.save(模型部署);

Keras版本加载模型:

  • model.load_weights(加载模型权重);
  • tf.keras.models.load_model(加载h5文件或者pb文件);
  • tf.saved_model.load(加载模型部署文件);

自定义模型版本保存模型:

  • model.save_weights(保存模型权重);
  • tf.saved_model.save(模型部署);
  • model.save(保存模型,本地加载,可保存为pb格式文件);

自定义模型版本加载模型:

  • model.load_weights(加载模型权重);
  • tf.saved_model.load(加载模型部署文件);
  • tf.keras.models.load_model(加载pb文件);

TensorFlow2.0 —— 模型保存与加载相关推荐

  1. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

  2. 飞桨框架2.0RC新增模型保存、加载方案,与用户场景完美匹配,更全面、更易用

    通过一段时间系统的课程学习,算法攻城狮张同学对于飞桨框架的使用越来越顺手,于是他打算在企业内尝试使用飞桨进行AI产业落地. 但是AI产业落地并不是分秒钟的事情,除了专业技能过硬,熟悉飞桨的使用外,在落 ...

  3. Pytorch —— 模型保存与加载

    1.序列化与反序列化 模型的保存与加载就是序列化与反序列化,序列化与反序列化主要将内存与硬盘之间的数据转换关系,模型在内存中以对象的形式存储,在内存中对象不能长久地保存,所以需要将训练好的模型保存到硬 ...

  4. pytorch模型保存与加载总结

    pytorch模型保存与加载总结 模型保存与加载方式 模型保存 方式一 只存储模型中的参数,该方法速度快,占用空间少(官方推荐使用) model = VGGNet() torch.save(model ...

  5. [tensorflow] 模型保存、加载与转换详解

    TensorFlow模型加载与转换详解 本次讲解主要涉及到TensorFlow框架训练时候模型文件的管理以及转换. 首先我们需要明确TensorFlow模型文件的存储格式以及文件个数: model_f ...

  6. 机器学习之模型——保存与加载

    机器学习之模型--保存与加载 知识点 fit() transform() fit_transform() 目的 API 流程 获取数据 划分数据集 标准化 预估器 保存模型 加载模型 得出模型 模型评 ...

  7. PyTorch系列入门到精通——模型保存与加载

    PyTorch系列入门到精通--模型保存与加载

  8. gensim bm25模型保存与加载

    gensim bm25模型保存与加载 1. 模型保存 2. 模型加载 20210719修改: python version:3.6.12 gensim version:3.8.3 使用bm25模型计算 ...

  9. tf第七讲:模型保存与加载(tf.train.Saver()tf.saved_model)及fine_tune(梯度冻结)

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

最新文章

  1. python 100题_python3.0练习100题——001
  2. 【Hankson 的趣味题】
  3. 【联盛德W806上手笔记】七、I2C
  4. php int最大值探究
  5. eclipse的editor does not contain a main type错误
  6. Django rest framework(7)----分页
  7. Error creating bean with name错误,spring-boot报错
  8. revit二次开发 材质类别分析
  9. SAS数据集随机抽样方法
  10. ArcGIS裁剪影像如何保持裁剪完全一致
  11. JavaScript 事件
  12. Android NFC智能卡介绍
  13. Python条件语句和循环
  14. java下载basic_Java-basic(1)
  15. HDFS 的深入了解,深入浅出,面试必备(Hadoop的三部曲——上)
  16. Java面向对象知识点总结(全)
  17. mysql勒索_记一次mysql数据库被勒索(下)
  18. 保暖防风又抗冻 春节出游当然要选头戴式耳机
  19. 使用C4D时,隔一会就卡死的解决办法
  20. Markdown mermaid种草(3)_ 流程图

热门文章

  1. [转]C++的Json解析库:jsoncpp和boost
  2. 面试精讲之面试考点及大厂真题 - 分布式专栏 04 谈谈你对分布式的理解,为什么引入分布式?
  3. Oracle容灾数据库-恢复演练方案
  4. Windows11 安装Docker,安装至D盘(其他非C盘皆可)
  5. 1,2,3……,9组成3个三位数abc,def和ghi,每个数字恰好使用一次,要求abc:def:ghi=1:2:3.输出所有解。
  6. postman使用指南
  7. linux下的磁盘空间使用
  8. 学python分析双色球_我通过使用Python分析了80多个工作拒绝而学到的东西
  9. 亚马逊云服务开通指南_亚马逊弹性容器服务初学者指南
  10. C++中头文件和实现文件分离进行编译