7. keras - 模型的保存与载入
- 需要安装h5py:
pip install h5py
1.模型的保存与加载
- 方法1:model.save()
model.save('model_path.h5')
# 保存模型
model = load_model('model_path.h5')
# 加载模型, 需要from keras.models import load_model
说明:- 优点:保存整个模型(包括结构和权重),加载时只需要直接加载指定路径下的model即可。
- 缺点:①如果有自定义的层(如:MyLayer1),则需要使用代码
model = load_model('model.h5', custom_objects={'MyLayer1': MyLayer1, 'tf': tf} )
来加载model,即还是需要重新定义层; ②相比于save_weights
而言,由于还要保存结构信息,所以保存下载的模型占的存储空间比较大 - 总之,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。
- 方法2:model.save_weights()
model.save_weights('model_path.h5')
# 保存模型权重
model.load_weights('model_path.h5')
# 加载模型权重
说明:需要先把模型定义出来,然后再加载权值。 - 方法3和方法4:model.to_json()和model.to_yaml()
本人使用的不多,可以参考:Keras模型保存的几个方法和它们的区别
或者keras的官网:Model saving & serialization APIs
2.代码案例
保存模型
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')])# 定义优化器
sgd = SGD(lr=0.2)# 定义优化器,loss function,训练过程中计算准确率
model.compile(optimizer = sgd,loss = 'mse',metrics=['accuracy'],
)# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)
print('accuracy',accuracy)# 保存模型
model.save('model.h5') # HDF5文件,pip install h5py
加载模型
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)# 载入模型
model = load_model('model.h5')# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)
print('accuracy',accuracy)
在原有模型的基础上,再迭代训练两个周期,以提高精确度
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)
print('accuracy',accuracy)
- 保存模型的参数或者保存模型的网络结构
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)print(json_string)
参考:
视频: 覃秉丰老师的“Keras入门”:http://www.ai-xlab.com/course/32
博客参考:https://www.cnblogs.com/XUEYEYU/tag/keras%E5%AD%A6%E4%B9%A0/
7. keras - 模型的保存与载入相关推荐
- TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)
TF:利用TF的train.Saver将训练好的W.b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据) 目录 输出结果 代码设计 输出结果 代码设计 import tensorflow as ...
- pytorch多卡并行模型的保存与载入
pytorch多卡并行模型的保存与载入 当模型是在数据并行方式在多卡上进行训练的训练和保存,那么载入的时候也是一样需要是多卡.并且,load_state_dict()函数的调用要放在DataParal ...
- Keras框架训练模型保存及载入继续训练
Keras框架训练模型保存及再载入 实验数据MNIST 初次训练模型并保存 import numpy as np from keras.datasets import mnist from keras ...
- keras的model保存和载入
保存keras的model文件和载入keras文件的方法有很多. keras中的模型主要包括model和weight两个部分. 保存模型结构 保存model部分的主要方法:一是通过json文件,二是通 ...
- Keras模型的保存与调用
一.模型的保存(结构 + 权重 + 优化器状态) 1.model.save('model.h5')#保存名为model的h5文件到程序所在目录 你可以使用 model.save(filepath) 将 ...
- Keras——模型的保存、读取及加载
本文将会介绍如何利用Keras来实现模型的保存.读取以及加载. 本文使用的模型为解决IRIS数据集的多分类问题而设计的深度神经网络(DNN)模型,模型的结构示意图如下: 具体的模型参数可以参考文章 ...
- PyTorch模型的保存加载以及数据的可视化
文章目录 PyTorch模型的保存和加载 模块和张量的序列化和反序列化 模块状态字典的保存和载入 PyTorch数据的可视化 TensorBoard的使用 总结 PyTorch模型的保存和加载 在深度 ...
- 实践教程 | Pytorch 模型的保存与迁移
实践教程 | Pytorch 模型的保存与迁移 在本篇文章中,笔者首先介绍了模型复用的几种典型场景:然后介绍了如何查看Pytorch模型中的相关参数信息:接着介绍了如何载入模型.如何进行追加训练以及进 ...
- Keras保存和载入训练好的模型和参数
1.保存模型 my_model = create_model_function( ...... )my_model.compile( ...... )my_model.fit( ...... )mod ...
- Keras学习笔记---保存model文件和载入model文件
Keras学习笔记---保存model文件和载入model文件 保存keras的model文件和载入keras文件的方法有很多.现在分别列出,以便后面查询. keras中的模型主要包括model和we ...
最新文章
- 2009年总结-爱与快乐着
- 08年美国最值得信赖20大公司排行 谷歌落榜
- Objective-C学习笔记--NSLog用法及例子
- leetcode mysql 排名_(LeetCode:数据库)分数排名
- Android应用程序组件Content Provider的共享数据更新通知机制分析
- java stric_Java中的strictfp关键字
- ZOJ3826 Hierarchical Notation(14牡丹江 H) 树套树
- python批量读取csv并入库pg_如何通读CSV然后在Python中发布批量API调用
- 「镁客早报」传SpaceX计划展开7.5亿美元贷款融资;LG开始为苹果生产OLED面板
- PHP设计模式——单例模式
- UOS设置屏幕缩放后的配置文件研究
- ktv服务器管理系统,小型KTV综合解决方案
- 基坑计算理论m法弹性支点法_建筑基坑支护考题汇总.doc
- [1160]C语言实验——某年某月的天数
- CDN是什么意思 CDN加速服务有什么功能和作用?
- dataframe如何替换某列元素值_dataframe 按条件替换某一列中的值方法
- 斯坦福公布3D街景数据集:2500万张图像,8个城市模型 | 下载
- Visual SLAM 笔记——李群和李代数详解
- 华为云存储空间图库占比太大_华为手机照片太多?放这里既安全又不占内存,瞬间腾出50G空间...
- jquery基础学习记录
热门文章
- RxJS修炼之 用弹珠测试学习RxJS
- 如何从一个USB上安装Windows Vista
- [linux] 查看目录/文件字节数
- 《专业嵌入式软件开发》的样章、建议和勘误
- error: Unexpected trailing comma (comma-dangle) at src\components\Login.vue:99:4:
- 面向对象的超级面试题,涉及封装多态继承等多方面考核,异常烧脑,90%的面试官必问题目,不会这个的,只是会搬砖的码农
- 计算机与科学 研究生考试内容,计算机科学与技术考研考哪些科目 备考技巧有哪些...
- python 进程池阻塞和非阻塞_Python协程还不理解?请收下这份超详细的异步编程教程!还没学会来找我!...
- python图形编程环境环境_Python开发环境Wing IDE matplotlib 2D绘图库代码调试技巧小结...
- 改写反话技巧_2021考研唐迟阅读技巧总结