tf.keras 是 tensorflow2 引入的高封装度的框架,可以用于快速搭建神经网络模型,keras 为支持快速实验而生,能够把想法迅速转换为结果,是深度学习框架之中最终易上手的一个,它提供了一致而简洁的 API,能够极大地减少一般 应用下的工作量,提高代码地封装程度和复用性

本文以FasionMNIST/加州房价数据集为例,介绍KerasAPI进行分类问题/回归问题模型训练的方法

Tensorflow版本

Tensorflow和keara都需要2.0及以上版本

import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(keras.__version__)

分类MLP构建

数据集

fashion MNIST dataset,数据集有60000张衣服鞋子的图片,大小为28X28。

Keras可以通过keras.datasets方法来下载主流的数据集,fashion MNIST dataset已经区分了训练集(50000张)和测试集(10000张), 但最好从训练集中划出一部分作为验证集(Validation set)

import tensorflow as tf
from tensorflow import keras
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()# 由于数据是最大值为255的灰度图像,除于255进行归一化
X_valid, X_train = X_train_full[:5000] / 255., X_train_full[5000:] / 255.
y_valid, y_train = y_train_full[:5000], y_train_full[5000:]
X_test = X_test / 255.

模型训练

定义网络层

Sequential API方法

Sequential网络层定义有两种方法,一种是add layers方法,有点笨拙

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))

另外一种整体括号内定义网络层,推荐这种方法

model = keras.models.Sequential([keras.layers.Flatten(input_shape=[28, 28]),keras.layers.Dense(300, activation="relu"),keras.layers.Dense(100, activation="relu"),keras.layers.Dense(10, activation="softmax")
])

解释:

  1. Sequential是Keras最简单的网络结构定义方式,按照从上到下(代码文本)将单层的网络连接起来

  1. Flatten称为“拉直层”,作用是将28X28的数组转化为1D的格式

  1. 第3/4行分别定义了两个全联接层(隐藏层),大小为300/100个神经元,激活函数为relu,其格式通常为:

tf.keras.layers.Dense( 神经元个数,activation=”激活函数”, kernel_regularizer=”正则化方式”)    
  1. 第5行定义输出层,由于是多分类问题,使用了“softmax”作为激活函数。

  • 使用model_summary()可以查看模型的结构

  • 使用model.get_layer(), layer.get_weight()可以查看模型的参数

定义损失函数、优化器等

model.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])

Metrics选择:

  • ‘accuracy’:y_和 y 都是数值,如 y_=[1] y=[1]

  • ‘categorical_accuracy’:y_和 y 都是以独热码和概率分布表示。如y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。

  • ‘sparse_ categorical_accuracy’:y_是以数值形式给出,y 是以独热码形式 给出。如 y_=[1],y=[0.256, 0.695, 0.048]。

模型训练

和sk-learn一样调用fit()方法

history = model.fit(X_train, y_train, epochs=30,validation_data=(X_valid, y_valid))'''
model.fit(训练集的输入特征, 训练集的标签, batch_size, epochs, validation_data = (测试集的输入特征,测试集的标签),validataion_split = 从测试集划分多少比例给训练集,validation_freq = 测试的 epoch 间隔次数)
'''

可以方便的运用history数据进行训练过程的可视化

import pandas as pd
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)

测试集评估

model.evaluate(X_test, y_test)

模型保存

测试集上能达到要求就可以进行模型保存了,也可以将其部署到生产环境。

# kears模型的格式为h5
model.save("my_keras_FasionMNist_model.h5")

模型预测

# 载入模型
model = keras.models.load_model("my_keras_FasionMNist_model.h5")X_new = X_test[:1] # 以一个样本为例子,输入是28X28的数组
y_proba = model.predict(X_new) # 由于是10分类问题,softmax输出结果是一个在10分类上概率和为1的10维数组
y_proba.round(2)y_pred = np.argmax(model.predict(X_new), axis=-1)# 求概率和最大,axis=-1为在每个10维度数组内求
y_pred

回归MLP构建

以加州房价预测数据集为例,模型开发和分类问题类似,就不分模块一一介绍了。

主要区别:

  1. 由于输出房价是1维数据,只需要一个神经元,且训练label也是房价,因此不需要激活函数

  1. loss function为MSE

  1. 由于数据集比较简单,为防止过拟合,仅适用一个hidden layer,且神经元的个数设置为更低的值

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# load数据集
housing = fetch_california_housing()X_train_full, X_test, y_train_full, y_test = train_test_split(housing.data, housing.target, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full, random_state=42)scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_valid = scaler.transform(X_valid)
X_test = scaler.transform(X_test)# 模型定义和训练
model = keras.models.Sequential([keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]),keras.layers.Dense(1)
])
model.compile(loss="mean_squared_error", optimizer=keras.optimizers.SGD(lr=1e-3))
history = model.fit(X_train, y_train, epochs=20, validation_data=(X_valid, y_valid))#模型训练epoch过程MSE可视化
plt.plot(pd.DataFrame(history.history))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()mse_test = model.evaluate(X_test, y_test)X_new = X_test[:3]
y_pred = model.predict(X_new)

深度学习-Tensorflow使用Keras进行模型训练相关推荐

  1. 【深度学习】基于PyTorch的模型训练实用教程之数据处理

    [深度学习]基于PyTorch的模型训练实用教程之数据处理 文章目录 1 transforms 的二十二个方法 2 数据加载和预处理教程 3 torchvision 4 如何用Pytorch进行文本预 ...

  2. 文本深度表示模型Word2Vec 简介 Word2vec 是 Google 在 2013 年年中开源的一款将词表征为实数值向量的高效工具, 其利用深度学习的思想,可以通过训练,把对文本内容的处理简

    文本深度表示模型Word2Vec 简介 Word2vec 是 Google 在 2013 年年中开源的一款将词表征为实数值向量的高效工具, 其利用深度学习的思想,可以通过训练,把对文本内容的处理简化为 ...

  3. 深度学习工程实践 6. 使用pytorch训练自己的眼球分割模型

    深度学习工程实践 6. 使用pytorch训练自己的眼球分割模型 1. 概述 2. 目标 3. 工程实践 3.1 数据寻找,数据标注 3.2 训练 3.3 部署应用到桌面程序 4. 总结 1. 概述 ...

  4. 【神经网络与深度学习-TensorFlow实践】-中国大学MOOC课程(八)(TensorFlow基础))

    [神经网络与深度学习-TensorFlow实践]-中国大学MOOC课程(八)(TensorFlow基础)) 8 TensorFlow基础 8.1 TensorFlow2.0特性 8.1.1 Tenso ...

  5. 【神经网络与深度学习-TensorFlow实践】-中国大学MOOC课程(十二)(人工神经网络(1)))

    [神经网络与深度学习-TensorFlow实践]-中国大学MOOC课程(十二)(人工神经网络(1))) 12 人工神经网络(1) 12.1 神经元与感知机 12.1.1 感知机 12.1.2 Delt ...

  6. 3月13日云栖精选夜读:通过阿里云容器服务深度学习解决方案上手Caffe+多GPU训练

    阿里云容器服务提供的深度学习解决方案内置了对Tensorflow, Keras, MXnet框架的环境,并支持基于它们的深度学习模型开发.模型训练和模型预测.同时,对于模型训练和预测,用户还可以通过指 ...

  7. 3月13日云栖精选夜读:通过阿里云容器服务深度学习解决方案上手Caffe+多GPU训练...

    阿里云容器服务提供的深度学习解决方案内置了对Tensorflow, Keras, MXnet框架的环境,并支持基于它们的深度学习模型开发.模型训练和模型预测.同时,对于模型训练和预测,用户还可以通过指 ...

  8. 【神经网络与深度学习-TensorFlow实践】-中国大学MOOC课程(十四)(卷积神经网络))

    [神经网络与深度学习-TensorFlow实践]-中国大学MOOC课程(十四)(卷积神经网络)) 14 卷积神经网络 14.1 深度学习基础 14.1.1 深度学习的基本思想 14.1.2 深度学习三 ...

  9. 使用深度学习TensorFlow框架进行图片识别

    Apsara Clouder大数据专项技能认证:使用深度学习TensorFlow框架进行图片识别 本认证系统的介绍了深度学习的一些基础知识,以及Tensorflow的工作原理.通过阿里云机器学习PAI ...

最新文章

  1. springboot 集成mybatis时日志输出
  2. 立足GitHub学编程:13个不容错过的Java项目
  3. 计算机基础知识作业,第一章计算机基础知识作业
  4. 网易云 6 亿用户音乐推荐算法
  5. mysql封装 javabean,利用Java针对MySql封装的jdbc框架类JdbcUtils完整实现(包含增删改查、JavaBean反射原理,附源码)...
  6. python 返回列表中的偶数
  7. Spark DataFrames DataSet
  8. 计算机等级考试c 试题及答案,3月计算机等级考试级C笔试试题及答案解析.doc
  9. 计算机软件怎么共享使用,局域网共享软件,教您局域网共享软件如何使用
  10. SpringBoot AOP切面实现
  11. spring mvc 配置使用定时任务
  12. CSS 字体大小 font-size属性
  13. iNFTnews | Yuga Labs收购Meebits,NFT IP市场操纵存在担忧
  14. 2023第八届少儿模特明星盛典 福州赛区 初赛圆满收官
  15. 《Cocos Creator游戏实战》滚动数字
  16. c语言小数除于整数怎么运算,C#:将int除以100
  17. 游戏开发中如何设计一个撤销重做系统DoUnDo
  18. linux 跳过overwrite确认
  19. 利用python操作word文档
  20. JavaWeb - 仿小米商城网(2) 用户注册

热门文章

  1. 2022年武汉市助理工程师职称评定流程和评定条件是什么?甘建二
  2. 酷开系统,打造更多可能
  3. 【达内课程】ListView使用
  4. Python的异或(‘^’)运算和程序控制流程题
  5. java设置text字体颜色_java itext添加中文文字和设置文字颜色
  6. win7右键计算机管理参数错误,win7 64位旗舰版系统右键无法打开属性窗口的解决方法...
  7. 三层交换与路由器之间连接配置
  8. [NLP]如何安装繁简转换工具:opencc
  9. python3内置集成开发工具_python应用(3):启用集成开发工具pycharm
  10. 2019最新迅为-i.MX6Q开发板资料目录