文章目录

  • 1 代码实现
  • 2 输出:
  • 3 过程讲解
    • 3.1 训练模型
    • 3.2 保存模型
    • 3.3 导入模型并应用

1 代码实现

import numpy as np
np.random.seed(1337)  # for reproducibilityfrom keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model# create some data
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)    # randomize the data
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, ))
X_train, Y_train = X[:160], Y[:160]     # first 160 data points
X_test, Y_test = X[160:], Y[160:]       # last 40 data points
model = Sequential()
model.add(Dense(output_dim=1, input_dim=1))
model.compile(loss='mse', optimizer='sgd')
for step in range(301):cost = model.train_on_batch(X_train, Y_train)# save
print('test before save: ', model.predict(X_test[0:2]))
model.save('my_model.h5')   # HDF5 file, you have to pip3 install h5py if don't have it
del model  # deletes the existing model# load
model = load_model('my_model.h5')
print('test after load: ', model.predict(X_test[0:2]))# save and load weights
#model.save_weights('my_model_weights.h5')
#model.load_weights('my_model_weights.h5')# save and load fresh network without trained weights
#from keras.models import model_from_json
#json_string = model.to_json()
#model = model_from_json(json_string)

2 输出:

3 过程讲解

3.1 训练模型

下面的导入数据和训练模型用的是之前讲过的回归模型的例子,今天要做的是如何保存这个模型。

import numpy as np
np.random.seed(1337)  # for reproducibilityfrom keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model# create some data
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)    # randomize the data
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, ))
X_train, Y_train = X[:160], Y[:160]     # first 160 data points
X_test, Y_test = X[160:], Y[160:]       # last 40 data points
model = Sequential()
model.add(Dense(output_dim=1, input_dim=1))
model.compile(loss='mse', optimizer='sgd')
for step in range(301):cost = model.train_on_batch(X_train, Y_train)

3.2 保存模型

训练完模型之后,可以打印一下预测的结果,接下来就保存模型。

保存的时候只需要一行代码 model.save,再给它加一个名字就可以用 h5 的格式保存起来。

这里注意,需要已经安装了 HDF5 这个模块。

保存完模型之后,删掉它,后面可以来比较是否成功的保存。

# save
print('test before save: ', model.predict(X_test[0:2]))
model.save('my_model.h5')   # HDF5 file, you have to pip3 install h5py if don't have it
del model  # deletes the existing model"""
test before save:  [[ 1.87243938] [ 2.20500779]]
"""

3.3 导入模型并应用

导入保存好的模型,再执行一遍预测,与之前预测的结果比较,可以发现结果是一样的。#load
model = load_model('my_model.h5')
print('test after load: ', model.predict(X_test[0:2]))"""
test after load:  [[ 1.87243938] [ 2.20500779]]
"""
另外还有其他保存模型并调用的方式,第一种是只保存权重而不保存模型的结构。#save and load weights
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
第二种是用 model.to_json 保存完结构之后,然后再去加载这个json_string。#save and load fresh network without trained weights
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)

Keras【Deep Learning With Python】Save reload 保存提取模型相关推荐

  1. Keras Tutorial: Deep Learning in Python

    This Keras tutorial introduces you to deep learning in Python: learn to preprocess your data, model, ...

  2. Deep learning with Python 学习笔记(6)

    本节介绍循环神经网络及其优化 循环神经网络(RNN,recurrent neural network)处理序列的方式是,遍历所有序列元素,并保存一个状态(state),其中包含与已查看内容相关的信息. ...

  3. Deep learning with Python 学习笔记(9)

    神经网络模型的优化 使用 Keras 回调函数 使用 model.fit()或 model.fit_generator() 在一个大型数据集上启动数十轮的训练,有点类似于扔一架纸飞机,一开始给它一点推 ...

  4. python思想读后感_《Deep Learning with Python》读后感精选

    <Deep Learning with Python>是一本由Francois Chollet著作,Manning Publications出版的Paperback图书,本书定价:USD ...

  5. Python深度学习:基于PyTorch [Deep Learning with Python and PyTorch]

    作者:吴茂贵,郁明敏,杨本法,李涛,张粤磊 著 出版社:机械工业出版社 品牌:机工出版 出版时间:2019-11-01 Python深度学习:基于PyTorch [Deep Learning with ...

  6. Deep Learning with Python

    1.学习地址 Deep Learning with Python(wang@123) 2.大神的twitter 大神的twitter

  7. Keras【Deep Learning With Python】MNIST数据集识别优化

    文章目录 前言 1 线性回归预测 2 手写数字识别 3 模型优化 前言 本文分为三部分: a.线性回归 b.手写数字识别 c.手写数字识别模型优化. 1 线性回归预测 import keras Usi ...

  8. Keras【Deep Learning With Python】—Keras基础

    文章目录 1.关于Keras 2.Keras的模块结构 3.使用Keras搭建一个神经网络 4. 主要概念 5.第一个示例 下载网站数据注意 1.关于Keras 1)简介 Keras是由纯python ...

  9. 《Deep Learning with Python》(中文版)—读书笔记

    深度学习 人工智能>机器学习>深度学习 人工智能:努力将通常由人类完成的智力任务自动化.(人们输入的是规则(即程序)和需要根据这些规则进行处理的数据,系统输出的是答案) 机器学习:机器学习 ...

最新文章

  1. anaconda安装scrapy_Scrapy框架的安装
  2. 使用log4j监视和筛选应用程序日志到邮件
  3. Facebook 开源 M2M-100,不依赖英语互译百种语言
  4. 平安银行支付接口 PHP ECSHOP
  5. 气象netCDF数据可视化分析
  6. vs离线安装Qt开发插件vsix
  7. 陈进: 创业维艰吗? 换个皮肤就能获批一亿经费!
  8. Pr学习笔记——添加字幕流
  9. C语言函数指针的几种用法【转】+gyy修改
  10. 【优化求解】基于收敛因子和黄金正弦指引机制的蝴蝶优化算法求解单目标优化问题matlab代码(AGSABOA)
  11. VS2019下编译x264.dll
  12. 玫琳凯携手联合国机构推出女性创业加速器计划
  13. 无创脑刺激对不同神经和神经精神疾病睡眠障碍的影响
  14. Ceph配置——5.Ceph-MON设置
  15. 小米物联网世界第一_雷军:小米智能设备连接数世界第一 AI+IoT是核心战略
  16. Android 10.0 蓝牙去掉传输文件的功能
  17. 使用图神经网络预测药物-药物相互作用
  18. 腾讯手游助手android文件夹,腾讯手游助手安装的apk在哪个文件夹?腾讯手游助手游戏安装目录介绍...
  19. 非标准包 game.rgss3a 的打开方法 | 2023 年实测
  20. 批量字符替换(利用好压的文本替换工具)

热门文章

  1. 任意角度人脸检测pcn
  2. Run-Time Check Failure #3
  3. 9行Python代码搭建神经网络
  4. it 脑裂_脑裂是什么?Zookeeper是如何解决的?
  5. html+li标签+高度,有时在使用Jquery插入LI元素时,JavaScript不会调整UL元素的高度
  6. linux mint输入法托盘,linux mint12安装ibus之后,语言栏不跟随光标和系统托盘输入法图标不能显示问题解决...
  7. etl常用的三种工具介绍_Adobe Photoshop常用修图插件+屏幕模式+内容感知移动工具介绍...
  8. 阿里nacos安装及使用指南
  9. java中synchronized修饰静态方法和非静态方法有什么区别?
  10. maven中servlet报错:不识别此servlet问题的解决办法