every blog every motto:

0. 前言

以fashion_mnist 为例,保存模型

1. 代码部分

1. 导入模块

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import kerasprint(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:print(module.__name__,module.__version__)

2. 读取数据

fashion_mnist = keras.datasets.fashion_mnist
# print(fashion_mnist)
(x_train_all,y_train_all),(x_test,y_test) = fashion_mnist.load_data()
x_valid,x_train = x_train_all[:5000],x_train_all[5000:]
y_valid,y_train = y_train_all[:5000],y_train_all[5000:]
# 打印格式
print(x_valid.shape,y_valid.shape)
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape)

3. 数据归一化

# 数据归一化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
# x_train:[None,28,28] -> [None,784]
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)

4. 构建模型

# tf.keras.models.Sequential()
# 构建模型# 创建对象
"""model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28,28]))
model.add(keras.layers.Dense(300,activation='sigmoid'))
model.add(keras.layers.Dense(100,activation='sigmoid'))
model.add(keras.layers.Dense(10,activation='softmax'))"""# 另一种写法
model = keras.models.Sequential([keras.layers.Flatten(input_shape=[28,28]),keras.layers.Dense(300,activation='sigmoid'),keras.layers.Dense(100,activation='sigmoid'),keras.layers.Dense(10,activation='softmax')
])#
model.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

5. 保存模型/参数

5.1 回调函数,保存模型

save_wegiths_only=True 只保存参数
save_weights_only=False 保存模型+参数

  1. 保存模型+参数
# 回调函数 Tensorboard(文件夹)\earylystopping\ModelCheckpoint(文件名)
logdir = os.path.join("graph_def_and_weigths")
print(logdir)
if not os.path.exists(logdir):os.mkdir(logdir)
# 文件名
output_model_file = os.path.join(logdir,"fashion_mnist_model.h5")callbacks = [keras.callbacks.TensorBoard(logdir),# save_weights_only=True 只保存参数,=False 保存参数加模型keras.callbacks.ModelCheckpoint(output_model_file,save_best_only=True,save_weights_only=False),keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3),
]
# 开始训练
history = model.fit(x_train_scaled,y_train,epochs=10,validation_data=(x_valid_scaled,y_valid),callbacks=callbacks)
  1. 只保存参数
# 回调函数 Tensorboard(文件夹)\earylystopping\ModelCheckpoint(文件名)
logdir = os.path.join("graph_def_and_weigths")
print(logdir)
if not os.path.exists(logdir):os.mkdir(logdir)
# 文件名
output_model_file = os.path.join(logdir,"fashion_mnist_weights.h5")callbacks = [keras.callbacks.TensorBoard(logdir),# save_weights_only=True 只保存参数,=False 保存参数加模型keras.callbacks.ModelCheckpoint(output_model_file,save_best_only=True,save_weights_only=True),keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3),
]
# 开始训练
history = model.fit(x_train_scaled,y_train,epochs=10,validation_data=(x_valid_scaled,y_valid),callbacks=callbacks)

5.2 保存参数

另一种方法

model.save_weights(os.path.join(logdir,"fashion_mnist_weights_2.h5"))

6. 学习曲线

# 画图
def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize=(8,5))plt.grid(True)plt.gca().set_ylim(0,1)plt.show()
plot_learning_curves(history)

7. 模型评估

model.evaluate(x_test_scaled,y_test,verbose=0)

8. 载入保存的模型

# 载入模型
loaded_model = keras.models.load_model(output_model_file)
loaded_model.evaluate(x_test_scaled,y_test,verbose=0)

从零基础入门Tensorflow2.0 ----九、44.1 keras 保存模型、参数相关推荐

  1. 视频编码零基础入门(0):零基础,史上最通俗视频编码技术入门

    [来源申明]本文引用了微信公众号"鲜枣课堂"的<视频编码零基础入门>文章内容.为了更好的内容呈现,即时通讯网在引用和收录时内容有改动,转载时请注明原文来源信息,尊重原作 ...

  2. SQL零基础入门学习(九)

    SQL零基础入门学习(八) SQL UNION 操作符 UNION 操作符用于合并两个或多个 SELECT 语句的结果集. 请注意,UNION 内部的每个 SELECT 语句必须拥有相同数量的列.列也 ...

  3. 【天池赛事】零基础入门语义分割-地表建筑物识别 Task5:模型训练与验证

    [天池赛事]零基础入门语义分割-地表建筑物识别 Task1:赛题理解与 baseline(3 天) – 学习主题:理解赛题内容解题流程 – 学习内容:赛题理解.数据读取.比赛 baseline 构建 ...

  4. Apache Flink 零基础入门(十九)Flink windows和Time操作

    Time类型 在Flink中常用的Time类型: 处理时间 摄取时间 事件时间 处理时间 是上图中,最后一步的处理时间,表示服务器中执行相关操作的处理时间.例如一些算子操作时间,在服务器上面的时间. ...

  5. 指针04 - 零基础入门学习C语言44

    第八章:指针04 让编程改变世界 Change the world by program 小结 归纳起来, 如果有一个实参数组, 想在函数中改变此数组中的元素的值, 实参与形参的对应关系有以下4种情况 ...

  6. 零基础入门--中文命名实体识别(BiLSTM+CRF模型,含代码)

    https://github.com/mali19064/LSTM-CRF-pytorch-faster 中文分词 说到命名实体抽取,先要了解一下基于字标注的中文分词. 比如一句话 "我爱北 ...

  7. SQL零基础入门学习(十)

    SQL零基础入门学习(九) SQL CREATE DATABASE 语句 CREATE DATABASE 语句用于创建数据库. SQL CREATE DATABASE 语法 CREATE DATABA ...

  8. 【题解】《算法零基础100讲》(第44讲) 位运算 (位或) 入门

    文章目录 一. 概念定义 1.1 位或定义 1.2 位与定义 二. 推荐专栏 三. 相关练习 3.1 根据数字二进制下 1 的数目排序 3.2 二进制表示中质数个计算置位 3.3 2 的幂 一. 概念 ...

  9. 0基础能学漫画么?漫画零基础入门教程!

    漫画零基础入门教程!很多人都喜欢看动漫,同时也会幻想成为动漫里的主角,与此同时也会诞生学漫画的想法.不论是你真的想学习漫画,又或出于个人爱好,或职业需要,或为了具备一项自己喜欢的看家本领.我们都要先清 ...

  10. 新版思科CCNA认证1.0 零基础入门技术VTP协议解析-ielab网络实验室

     新版思科CCNA认证1.0 零基础入门技术VTP协议解析-ielab网络实验室 VTP(VLAN Trunking Protocol):VLAN中继协议,是Cisco专用协议.也被称为虚拟局域网干道 ...

最新文章

  1. 欧拉公式——真正的宇宙第一公式
  2. 【Rollo的Python之路】Python 同步条件 学习笔记 Event
  3. vs的资源管理器中一次性添加整个文件夹
  4. 【存储过程】MySQL存储过程/存储过程与自定义函数的区别
  5. java线程池之一:创建线程池的方法
  6. 【django】三、常用的模板标签和过滤器
  7. System.Net.Http.Formatting的nuget版本冲突问题
  8. 分享一个好用的网页pdf打印插件
  9. 一文搞懂深度学习所有工具——Anaconda、CUDA、cuDNN
  10. 爬取当当网评论(1)
  11. 当下的力量(解读版)
  12. JAVA基础 网络编程
  13. Motion planning for self-driving cars课程笔记1:应用雷达数据生成占用栅格地图(Occupancy Grid Map)
  14. AI一键图文生成短视频工具,文章AI自动生成视频,傻瓜式操作。
  15. 配一副适合程序员的眼镜
  16. 打印圆周率指定位数之python
  17. vue-pdf实现pdf文件在线预览
  18. 正则表达式不区分大小写以及解决思路的探索
  19. Java - SpringBoot 框架详解(一)
  20. 三菱5uplc伺服电机指令_三菱FX3U PLC如何控制松下伺服

热门文章

  1. linux查看ps进程命令,linux ps查看进程命令
  2. (秒杀项目) 4.10 项目面试项目常见问题
  3. python策略模式包含角色_Python 之策略模式
  4. php中fgetss函数,fgetss-函数用法_PHP教程
  5. python 持续集成 教程_dotnet 部署 github 的 Action 进行持续集成|简明python教程|python入门|python教程...
  6. mysql连网安装和断网安装的区别_Linux 断网安装MySQL5.x操作步骤
  7. java 反射集合_Java反射的理解(六)-- 通过反射了解集合泛型的本质
  8. sql 时间字符串转换
  9. 「学习路线分享」SLAM/深度估计/三维重建/相机标定/传感器融合(目录)
  10. 汇总|实时性语义分割算法(共24篇)