Keras构建分类模型

  • 1. tf.keras简介
  • 2. 利用tf.keras构建神经网络分类模型
    • 2.1 导入相应的库
    • 2.2 数据读取与展示
    • 2.3 数据归一化
    • 2.4 构建模型
    • 2.5 模型的编译与训练
    • 2.6 绘制训练曲线
    • 2.7 增加回调函数

1. tf.keras简介

keras是什么:

  • 基于python的高级神经网络API
  • 以TensorFlow, CNTK或者Theano后端运行,keras必须有后端才可以运行
  • 后端可以切换,现在多用于TensorFlow
  • 非常方便用于快速实验,帮助用户以最少的时间验证自己的想法

TensorFlow-keras是什么:

  • TensorFlow对keras API规范的实现
  • 相对于以TensorFlow为后端的keras, TensorFlow-keras与TensorFlow结合更加紧密
  • 实现在tf.keras空间下

Tf-keras和keras的联系与区别:

  • 联系

    • 二者是基于同一套API,keras程序可以通过更改导入方式轻松转为tf.keras程序,反之可能就不成立
  • 区别
    • tf.keras全面支持eager mode
    • 只用keras.Sequential和keras.Model的时候是没有影响的
    • tf.keras支持基于tf.data的模型训练,支持TPU训练,支持分布式策略

2. 利用tf.keras构建神经网络分类模型

2.1 导入相应的库

首先我们要导入要用到的python库

# matplotlib 用于绘图
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
# 处理数据的库
import numpy as np
import sklearn
import pandas as pd
# 系统库
import os
import sys
import time
# TensorFlow的库
import tensorflow as tf
from tensorflow import keras
# 打印我们导入库的版本号
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:print(module.__name__, module.__version__)

输出结果为:

2.2 数据读取与展示

分类问题的数据集我们采用 fashion_mnist 的数据集,里面是像素为28*28的黑白图像:

# 下载数据集
fashion_mnist = keras.datasets.fashion_mnist
# 拆分训练集与测试集
(x_train_all, y_train_all),(x_test, y_test) = fashion_mnist.load_data()
# 对训练集进行拆分,前5000个数据集作为验证集,其余的作为数据集
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
# 打印出数据维度
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

运行结果为:

由此可见,验证集的数量为5000, 训练集的数量为55000,测试集的数量为10000,所有数据都是有28*28的图像组成。
我们可以看一下这些数据集的前十五个数据是什么样子的:

def show_imgs(n_rows, n_cols, x_data, y_data, class_names):assert len(x_data) == len(y_data)assert n_rows * n_cols < len(x_data)plt.figure(figsize = (n_cols * 1.4, n_rows * 1.6))for row in range(n_rows):for col in range(n_cols):index = n_cols * row + col plt.subplot(n_rows, n_cols, index+1)plt.imshow(x_data[index], cmap="binary",interpolation = 'nearest')plt.axis('off')plt.title(class_names[y_data[index]])plt.show()class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress','Coat', 'Sandal', 'Shirt', 'Sneaker','Bag', 'Ankle boot']
show_imgs(3, 5, x_train, y_train, class_names)

输出结果为:

2.3 数据归一化

数据归一化可以减少模型的过拟合现象,从而可以提高模型的分类准确率,这里我们使用sklearn中的Standardscaler库对数据进行归一化处理:

from sklearn.preprocessing import StandardScaler
# 初始化scaler对象
scaler = StandardScaler()
# x_train: [None, 28, 28] -> [None, 784]
# 因为数据是int型,但是归一化要做除法,所以先转化为float32型
# 训练集数据使用的是 fit_transform,和验证集与测试集中使用的 transform 是不一样的
# fit_transform 可以计算数据的均值和方差并记录下来
# 验证集和测试集用到的均值和方差都是训练集数据的,所以二者的归一化使用 transform 即可
# 归一化只针对输入数据, 标签不变
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
2.4 构建模型

在归一化数据之后,我们就可以构建我们的神经网络分类模型了。这里我们只是简单的构建了一个4层网络,一层输入层, 两个隐藏层, 以及最后的输出层

# tf.keras.models.Sequential()用于将各个层连接起来
model = keras.models.Sequential()
# flatten层的作用是将28*28维度的输入数据展平成一层
model.add(keras.layers.Flatten(input_shape = [28, 28]))
# 输出为300的全连接层, 激活函数为 “relu”
model.add(keras.layers.Dense(300, activation = "relu"))
# 输出为100的全连接层, 激活函数为 “relu”
model.add(keras.layers.Dense(100, activation = "relu")) # 输出为 10的全连接层, 激活函数为 “softmax”
model.add(keras.layers.Dense(10, activation = "softmax"))
#关于激活函数:
# relu: y = max(0, x)
# softmax: 将向量变成概率分布  Ex. x = [x1, x2, x3],
#         y = [e^x1/sum, e^x2/sum, s^x3/sum], sum = e^x1 + e^x2 +e^x3

利用keras.Sequential构建网络模型可以利用add进行各种层结构的增加,也可以这么写,二者是等价的:

model = keras.models.Sequential([keras.layers.Flatten(input_shape=[28, 28]),keras.layers.Dense(300, activation='relu'),keras.layers.Dense(100, activation='relu'),keras.layers.Dense(10, activation='softmax')
])
2.5 模型的编译与训练

我们在构建好神经网络模型之后需要对模型进行编译:

model.compile(loss = "sparse_categorical_crossentropy",  # 稀疏分类交叉熵损失函数optimizer = keras.optimizers.SGD(0.01),     # 优化函数为随机梯度下降 ,学习率为0.01metrics = ["accuracy"])                     # 优化指标为准确度

然后利用model.fit()进行网络结构的训练:

# 将训练数据与标签,训练周期以及验证集数据传递给model.fit(),其输出为训练时各个指标的变化
# 我们另 history 存放模型的训练过程
history = model.fit(x_train_scaled, y_train,  # 训练数据epochs = 10,     # 训练周期,数据分为10次进行训练validation_data = (x_valid_scaled, y_valid))  # 验证集

训练过程为:

我们可以使用history.history来观察训练时的参数变化:

history.history

输出结果为:

2.6 绘制训练曲线

我们使用pandas的DataFram库对学习曲线进行绘制:

def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize = (8, 5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()plot_learning_curves(history)

输出结果为:

2.7 增加回调函数

回调函数callbacks可以在我们训练的过程中时刻监测与记录, 比如当我们的训练满足一定的条件时(精度达到要求),便停止训练。TensorFlow中常用的回调函数有一下几个:

  • EarlyStopping
    当我们的模型在训练过程中的loss不在降低时,便停止训练,这样可以防止过拟合现象它的参数为:

    • monitor: 需要监视的量,val_loss,val_acc
    • patience: 当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练
    • verbose: 信息展示模式
    • mode: ‘auto’,‘min’,'max’之一,在min模式训练,如果检测值停止下降则终止训练。在max模式下,当检测值不再上升的时候则停止训练。
  • ModelChechpoint
    该回调函数将在每个epoch后保存模型到filepath,其参数为:

    • filename:字符串,保存模型的路径
    • monitor:需要监视的值,通常为:val_acc 或 val_loss 或 acc 或 loss
    • verbose:信息展示模式,0或1。为1表示输出epoch模型保存信息,默认为0表示不输出该信息
    • save_best_only:当设置为True时,将只保存在验证集上性能最好的模型
    • mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
    • save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
    • period:CheckPoint之间的间隔的epoch数
  • TensorBoard
    TensorBoard是一个可视化工具,它可以用来展示网络图、张量的指标变化、张量的分布情况等。特别是在训练网络的时候,我们可以设置不同的参数(比如:权重W、偏置B、卷积层数、全连接层数等),使用TensorBoader可以很直观的帮我们进行参数的选择。

使用回调函数需要在训练之前添加callbacks:

# 创建存放文件
logdir = os.path.join("callbacks")
if not os.path.exists(logdir):os.mkdir(logdir)
output_model_file = os.path.join(logdir,"fashion_mnist_model.h5")
# 创建callbacks
callbacks = [keras.callbacks.TensorBoard(logdir),keras.callbacks.ModelCheckpoint(output_model_file,save_best_only = True),keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]

在训练模型的时候要加上callbacks:

history = model.fit(x_train_scaled, y_train,                     epochs = 10,                          validation_data = (x_valid_scaled, y_valid), callbacks = callbacks)

TensorFlow2.0(二)--Keras构建神经网络分类模型相关推荐

  1. TensorFlow2.0(三)--Keras构建神经网络回归模型

    Keras构建神经网络回归模型 1. 前言 1. 导入相应的库 2. 数据导入与处理 2.1 加载数据集 2.2 划分数据集 2.3 数据归一化 3. 模型构建与训练 3.1 神经网络回归模型的构建 ...

  2. TensorFlow2.0(五)--Keras构建Wide Deep模型

    Keras构建Wide & Deep模型 1. Wide & Deep模型简介 2. Keras实现Wide & Deep模型 2.1 导入相应的库 2.2 数据集加载与处理 ...

  3. TensorFlow2.0(四)--Keras构建深度神经网络(DNN)

    Keras构建深度神经网络(DNN) 1. 深度神经网络简介 2. Kerase搭建DNN模型 2.1 导入相应的库 2.2 数据加载与归一化 2.3 网络模型的构建 2.4 批归一化,dropout ...

  4. tensorflow2.0教程- Keras 快速入门

    tensorflow2.0教程-tensorflow.keras 快速入门 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/artic ...

  5. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization

    <<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...

  6. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...

  7. Keras深度学习实战(2)——使用Keras构建神经网络

    Keras深度学习实战(2)--使用Keras构建神经网络 0 前言 1. Keras 简介与安装 2. Keras 构建神经网络初体验 3. 训练香草神经网络 3.1 香草神经网络与 MNIST 数 ...

  8. TensorFlow2.0教程-keras 函数api

    TensorFlow2.0教程-keras 函数api Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details ...

  9. tensorflow2.0莺尾花iris数据集分类|超详细

    tensorflow2.0莺尾花iris数据集分类 超详细 直接上代码 #导入模块 import tensorflow as tf #导入tensorflow模块from sklearn import ...

最新文章

  1. matlab与acess连接问题
  2. java 判断一个数字是2倍数_如何判断语言发育迟缓的原因|一个2岁半不会说话的案例...
  3. linux c 删除文件,Linux C ftruncate 函数清空文件注意事项(要使用 lseek 重置偏移量)...
  4. 软件测试面试之登录界面
  5. MFC提供的集合类CStringArray类和CPtrArray类
  6. s4-4 以太网概述
  7. Variables多种表达
  8. linux路由内核实现分析(四)---路由缓存机制(2)
  9. 比赛现场打分管理平台的前后台安装配置和使用疑难问题汇编
  10. 小米洗手机拆解自动关机
  11. 地图学:专题地图制作详细步骤
  12. 机器学习:房价预测项目实战
  13. 【MATLAB】绘制矢量场图
  14. 通过mqtt再利用移动oneNet平台的连接与数据收发
  15. EXFS的块分配策略
  16. android下运行时动态链接dlopen()和dlsym()的实现
  17. 2008年5月12日四川汶川8.0级地震4级以上余震目录
  18. JDK17 ReentrantLock 简述 lock()、unLock()
  19. JTAG unlock
  20. GitHub 2020 报告:全球开发者工作与生活平衡情况年度分析

热门文章

  1. centos 搭建Jenkins
  2. linux vim (your system doesn't appear to have the zip pgm)
  3. MacBookPro安装Kali
  4. 软考信息安全工程师考试历年真题汇总及试题分布统计
  5. python中怎么比较两个列表的大小_python中对列表元素大小排序(冒泡排序法,选择排序法和插入排序法)—排序算法...
  6. C++ 梳理(一):跑通简单程序
  7. Cookie 详解
  8. UML和模式应用5:细化阶段(5)---系统顺序图
  9. Oracle性能优化
  10. 详解 height 和 width 属性