今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))

x = Dense(64, activation='relu')(inputs)

x = Dense(64, activation='relu')(x)

y = Dense(10, activation='softmax')(x)

model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model

from keras.layers import Input, Dense

from keras.datasets import mnist

from keras.utils import np_utils

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train=x_train.reshape(x_train.shape[0],-1)/255.0

x_test=x_test.reshape(x_test.shape[0],-1)/255.0

y_train=np_utils.to_categorical(y_train,num_classes=10)

y_test=np_utils.to_categorical(y_test,num_classes=10)

inputs = Input(shape=(784, ))

x = Dense(64, activation='relu')(inputs)

x = Dense(64, activation='relu')(x)

y = Dense(10, activation='softmax')(x)

model = Model(inputs=inputs, outputs=y)

model.save('m1.h5')

model.summary()

model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=10)

#loss,accuracy=model.evaluate(x_test,y_test)

model.save('m2.h5')

model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model

model = load_model('m1.h5')

#model = load_model('m2.h5')

#model = load_model('m3.h5')

model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model

from keras.layers import Input, Dense

inputs = Input(shape=(784, ))

x = Dense(64, activation='relu')(inputs)

x = Dense(64, activation='relu')(x)

y = Dense(10, activation='softmax')(x)

model = Model(inputs=inputs, outputs=y)

model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

python模型保存save_浅谈keras保存模型中的save()和save_weights()区别相关推荐

  1. java内存模型浅析_浅谈java内存模型

    不同的平台,内存模型是不一样的,但是jvm的内存模型规范是统一的.其实java的多线程并发问题最终都会反映在java的内存模型上,所谓线程安全无非是要控制多个线程对某个资源的有序访问或修改.总结jav ...

  2. 浅谈在基本数据包装类中使用'=='与equals的区别

    当"=="两边的数据都是包装类型时,比较的是两对象是否为同一对象;当等式一边有逻辑运算时,会触发自动拆箱,则比较的是数值. 而equals则不会进行类型转换. 输出结果为:true ...

  3. 浅谈List保存的数据是引用数据类型的地址

    浅谈List保存的数据是引用数据类型的地址 今天一个初学javaweb的朋友问我一个bug,经过和别人 讨论分析了解到List对象细节上的一些问题,我将代码重新简化构造了一下做成了一个例子.上代码: ...

  4. python的matmul_浅谈keras中的batch_dot,dot方法和TensorFlow的matmul

    概述 在使用keras中的keras.backend.batch_dot和tf.matmul实现功能其实是一样的智能矩阵乘法,比如A,B,C,D,E,F,G,H,I,J,K,L都是二维矩阵,中间点表示 ...

  5. 浅谈linux线程模型和线程切换

    本文从linux中的进程.线程实现原理开始,扩展到linux线程模型,最后简单解释线程切换的成本. 刚开始学习,不一定对,好心人们快来指正我啊啊啊!!! linux中的进程与线程 首先明确进程与进程的 ...

  6. 浅谈DirectX的模型加载

    浅谈DirectX的模型加载 xanxus - 2010年10月3日 - DirectX - 0 Comments 喜欢这篇文章吗?分享给你的朋友吧~  基于DirectX的游戏开发中,人物和模型由针 ...

  7. python语法中infile语句_浅谈pymysql查询语句中带有in时传递参数的问题

    直接给出例子说明: cs = conn.cursor() img_ids = [1,2,3] sql = "select img_url from img_url_table where i ...

  8. python查询数据库带逗号_浅谈pymysql查询语句中带有in时传递参数的问题

    直接给出例子说明: cs = conn.cursor() img_ids = [1,2,3] sql = "select img_url from img_url_table where i ...

  9. 浅谈Java内存模型、并发、多线程

    浅谈Java内存模型.并发.多线程 Java内存模型(Java Memory Model)是围绕着在并发编程中如何处理原子性,可见性,有序性三个特性而建立的模型. 下面我简单描述一下这三个特性: 原子 ...

最新文章

  1. 解决Linux中java.net.UnknownHostException: oracledb.sys.iflashbuy.com问题
  2. feign post 传递空值_HTTP中GET与POST的区别,99 %的人都理解错了
  3. css如何设置图转30度,使用CSS实现左右30度的摆钟
  4. jQuery Vue的CDN
  5. 查看某个端口的进程 lsof -i:端口号
  6. arrayrand php,php中array_rand函数的功能起什么作用呢?
  7. 终于有人把可解释机器学习讲明白了
  8. Python_argparse
  9. 接口测试(apipost、jmeter和python脚本)
  10. Activity之间的数据传递—实现Parcelable接口
  11. 让Office无处不在——Office Web App初体验
  12. python窗口大小动态变化_python – 如何让tkinter画布动态调整窗口宽度?
  13. 国内机场代码(IATA)
  14. noip2017提高组初赛(答案+选择题题目+个人分析)
  15. 《Go程序设计语言》- 第3章:基本数据
  16. 土地利用转移矩阵图怎么做_如何用Arcgis做土地利用转移矩阵?求教各位..._土地估价师_帮考网...
  17. 线性回归中一次性实现所有自变量的单因素分析
  18. 计算机软件系统两大类,详解计算机软件系统包括哪两大类
  19. python版本的flapy bird_python实现简单flappy bird
  20. 克里斯蒂安贝尔_克里斯蒂安贝尔解释为何只演3次蝙蝠侠

热门文章

  1. java音频实时传输_会议室智能系统建设方案,实时远程视频协作
  2. python和c混合编程 gil,如何在python中使用C扩展来解决GIL
  3. PHP 与go 通讯,Golang和php通信
  4. c语言什么时候需要加分号,归纳一下html中什么时候需要分号什么时候需要冒
  5. Oracle关联查询-数据类型不一致问题 ORA-01722: 无效数字
  6. rsync服务扩展应用
  7. MySQL Index Condition Pushdown
  8. 模拟UIWebView
  9. jQuery的ajaxFileUpload上传文件插件刷新一次才能再次调用触发change
  10. 适用响应式 Web UI 框架