按以下2部分写:

1 Keras常用的接口函数介绍

2 Keras代码实例


[keras] 模型保存、加载、model类方法、打印各层权重

1.模型保存
model.save_model()可以保存网络结构权重以及优化器的参数
model.save_weights() 仅仅保存权重

2.模型加载
from keras.models import load_model
load_model()只能load 由save_model保存的形式,将模型和weight全load进来

model.load_weights(self, filepath, by_name=False):
在加载权重之前,model必须编译好,即如下先执行以后。load_weights()和save_weights()配套用的

        metrics = ['accuracy']if self.nb_classes >= 10:metrics.append('top_k_categorical_accuracy')# self.input_shape = (seq_length, features_length)self.model,self.original_model = self.zf_model()optimizer = SGD(lr=1e-3)#必须先model.compile(),才能加载权重self.model.compile(loss='categorical_crossentropy', optimizer=optimizer,metrics=metrics) #

3.sequential 和functional

序列式模型只能有单输入单输出,函数式模型可以有多个输入输出

4.model类

因为是继承, model对象有 container和layer的所有方法,可以用model对象访问下面三个类的所有方法

以上的具体区别,可以参考Keras教程:https://keras.io/zh/

Container的类属性

类属性,不是函数nameinputsoutputsinput_layersoutput_layersinput_spec trainable (boolean)input_shapeoutput_shapeinbound_nodes: list of nodesoutbound_nodes: list of nodestrainable_weights (list of variables)non_trainable_weights (list of variables)

layer.get_weights返回的是没有名字的权重array,Model.get_weights() 是他们的拼接,也没有名字,利用layer.weights 可以访问到后台的变量

5.打印各层权重

for layer in model.layers:for weight in layer.weights:print weight.name,weight.shape#打印各层名字,权重的形状
block14_sepconv1/pointwise_kernel:0 (1, 1, 1024, 1536)
block14_sepconv1_bn/gamma:0 (1536,)
block14_sepconv1_bn/beta:0 (1536,)
block14_sepconv1_bn/moving_mean:0 (1536,
conv_att/bias:0 (5,)
linear_1/kernel:0 (2048, 256)
linear_1/bias:0 (256,)
linear_2/kernel:0 (2048, 256)
linear_2/bias:0 (256,)
linear_3/kernel:0 (2048, 256)
linear_3/bias:0 (256,)
linear_4/kernel:0 (2048, 256)
linear_4/bias:0 (256,)
linear_5/kernel:0 (2048, 256)
linear_5/bias:0 (256,)
rgb_softmax/kernel:0 (1280, 60)
rgb_softmax/bias:0 (60,)
from keras.applications.vgg16 import VGG16
# model.layers  ,layer.weights
model = VGG16()
names = [weight.name for layer in model.layers for weight in layer.weights]
weights = model.get_weights()
for name, weight in zip(names, weights):print(name, weight.shape)

--------------------- 案例1------------------
【Keras】保存权重以及载入,Model、Layers函数code

from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, Activation, Input
from keras.optimizers import adam, rmsprop, adadelta
import numpy as np
import matplotlib.pyplot as plt#construct model
data_input = Input((1,),dtype='float32',name='input_data')
x = Dense(100, activation = 'relu', name='layer1')(data_input)
x = Dense(32, activation = 'tanh', name='layer2')(x)
data_output = Dense(1, activation='tanh', name='output_data')(x)model = Model(inputs=data_input, outputs=data_output)
model.compile(optimizer='rmsprop', loss='mse', metrics=['accuracy'])#print model
print('models layers:',model.layers)
print('models config:',model.get_config())
print('models summary:',model.summary())#get layers by name
layer1 = model.get_layer(name='layer1')
layer1_W_pro = layer1.get_weights()
layer2 = model.get_layer(name='layer2')
layer2_W_pro = layer2.get_weights()#train data
dataX = np.linspace(-2 * np.pi,2 * np.pi, 1000)
dataX = np.reshape(dataX, [dataX.__len__(), 1])
noise = np.random.rand(dataX.__len__(), 1) * 0.1
dataY = np.sin(dataX) + noisemodel.fit(dataX, dataY, epochs=10, batch_size=10, shuffle=True, verbose = 1)
predictY = model.predict(dataX, batch_size=1)
score = model.evaluate(dataX, dataY, batch_size=10)
print(score)
#get layers1 wights
layer1_W_end = layer1.get_weights()
#layer1_W_end - layer1_W_prolayer2_W_end = layer2.get_weights()
#layer2_W_end - layer2_W_pro#plot
fig, ax = plt.subplots()
ax.plot(dataX, dataY, 'b-')
ax.plot(dataX, predictY, 'r.')
ax.set(xlabel="x", ylabel="y=f(x)", title="y = sin(x),red:predict data,bule:true data")
ax.grid(True)
plt.savefig('d:\\test.eps', format='eps', dpi=1000)
plt.show()#save weight
model.save_weights('d:\\test.hdf5')#create new model
data_input1 = Input((1,),dtype='float32',name='input_data1')
x1 = Dense(100, activation = 'relu', name='layer11')(data_input1)
x1 = Dense(32, activation = 'tanh', name='layer21')(x1)
data_output1 = Dense(1, activation='tanh', name='output_data')(x1)model1 = Model(inputs=data_input1, outputs=data_output1)
model1.load_weights('d:\\test.hdf5')

-----------------------------案例2:实验数据MNIST---------------------------------

初次训练模型并保存

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')])# 定义优化器
sgd = SGD(lr=0.2)# 定义优化器,loss function,训练过程中计算准确率
model.compile(optimizer = sgd,loss = 'mse',metrics=['accuracy'],
)# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)
print('accuracy',accuracy)# 保存模型
model.save('model.h5')   # HDF5文件,pip install h5py

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)# 载入模型
model = load_model('model.h5')# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)
print('accuracy',accuracy)# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)
print('accuracy',accuracy)# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)print(json_string)

原文:https://blog.csdn.net/u013608336/article/details/82664529

https://blog.csdn.net/cymy001/article/details/78647640

https://blog.csdn.net/xiaoxiao133/article/details/79709524

Keras 入门基础知识+完整实例相关推荐

  1. ***入门基础知识(超全)

    ***入门基础知识(超全) [sell=2]  DOS 常用命令: dir 列文件名 deltree 删除目录树 cls 清屏 cd 改变当前目录 copy 拷贝文件 diskcopy 复制磁盘 de ...

  2. 【PMP考试扫盲】超详细的PMP项目管理入门基础知识,考证必看

    我今年 6 月刚考过 PMP,发现很多小伙伴都对 PMP 还不了解,这篇文章就是对 PMP 基础知识的一个全面解答,文章有点长,先给大家上个目录,方便大家直接查看对应内容~ 目录 [PMP考试扫盲]超 ...

  3. Spark —— 闪电般快速的统一分析引擎 —— 入门基础知识

    Spark 入门基础知识 Spark 的特点 速度快 使用方便 通用 兼容 Spark 基础 下载 独立部署模式(Standalone) 弹性分布式数据集 Scala shell 1. 数组中的最值: ...

  4. 深入了解计算机的知识,电脑入门基础知识之深入理解计算机系统

    电脑入门基础知识之深入理解计算机系统 导语:计算机系统由计算机硬件和软件两部分组成.硬件包括中央处理机.存储器和外部设备等.下面就来看看小编为大家整理的资料,希望对您有所帮助! 简介 按人的要求接收和 ...

  5. C++入门基础知识[5]——判断语句

    C++入门基础知识[5]--判断语句 原创不易,路过的各位大佬请点个赞 C++入门基础知识--判断语句 C++入门基础知识[5]--判断语句 9.判断语句 9.1 判断语句 9.2 判断语句 9.3 ...

  6. PHP简单入门基础知识

    PHP简单入门基础知识 作为一个web前端开发者第一天开始学php,整理的以下笔记,笔记并不完善,只是自我觉得和html,js有差别的地方做了下入门笔记 PHP 变量规则: 变量以 $ 符号开头,其后 ...

  7. 10分钟HTML5入门基础知识(一)

    毫无疑问,对于开发人员而言, HTML5 已是一个热点话题.如果你需要快速了解HTML5的功能的基本原理,阅读本文是你最好的选择. 本文来自The Code Project的付费搜索位置,由Solut ...

  8. Python培训入门基础知识学什么?

    Python培训基础知识主要是针对一些零基础的同学安排的,虽说Python是相对比较简单的一门编程语言,但是没有基础的同学还是要进行系统的学习,那么Python培训入门基础知识学什么呢?来看看下面小编 ...

  9. NLP汉语自然语言处理入门基础知识介绍

    NLP汉语自然语言处理入门基础知识介绍 自然语言处理定义: 自然语言处理是一门计算机科学.人工智能以及语言学的交叉学科.虽然语言只是人工智能的一部分(人工智能还包括计算机视觉等),但它是非常独特的一部 ...

最新文章

  1. python 类 公有属性、私有属性、公有方法、私有方法
  2. 服务器邮箱群发,独立IP独立账号日发万封的邮件群发服务器
  3. XenDesktop5.0 Add Host使用vSphere5.1客户端注意事项
  4. Java IO流及应用(一)
  5. c#操作Xml(四)
  6. 陕西省计算机二级报名流程,计算机二级考试报名流程
  7. 记一次ArrayList产生的线上OOM问题
  8. Python模块学习
  9. 振子天线三维方向图 matlab仿真,1阵列天线方向图的MATLAB实现
  10. 弹出选择文件夹的对话框 BROWSEINFO 的用法【MFC】
  11. 现在转行前端,该怎么学习呢?怎么学好基础html、css、js
  12. SAP License:2021年:传统ERP丧钟响起
  13. 树状数组相关应用之区间更新单点查询问题
  14. C++中INT与BYTE相互转换
  15. html 图片 把绝对路径改为相对路径,html中想把图片绝对路径 改成相对路径怎么操作?...
  16. 计算机组成原理完整学习笔记(八):控制器设计
  17. C++/QT控制通过VISA控制硬件设备,超级容易学会的控制硬件方法
  18. 赴日软件工程师,据说很火
  19. 13个免费下载SVG图标网站
  20. 腾讯云代理商:共青城市与“腾讯云”举行战略合作协议远程签约仪式

热门文章

  1. 2022上海快递物流展,上海快递展,砥砺前行-移师上海新国际博览中心
  2. vue的组件库如何从0开始
  3. 南华大学计算机科学学院,【计算机科学与技术学院】南华大学计算机科学与技术学院召开2016年大运会表彰大会...
  4. ubuntu 软件源的设置
  5. Android 图形密码
  6. ghost here
  7. 欧拉-拉格朗日方程【转】
  8. 内核裁剪和部分选项的意义
  9. OpenWRT 跨网段解析 mDNS 域名
  10. 3D 坐标变换 公式 推导