各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量。主要内容有:

1. metrics指标;2. compile 模型配置;3. fit 模型训练;4. evaluate 模型评估;5. predict 预测;6. 自定义网络


1. metrics 性能指标

加权平均值: tf.keras.metrics.Mean

预测值和真实值的准确度: tf.keras.metrics.Accuracy


1.1 新建一个metrics指标

准确度指标 metrics.Accuracy() 一般用于训练集,加权平均值 metrics.Mean()  一般用于测试集

# 新建准确度指标
acc_meter = metrics.Accuracy()
# 新建平均值指标
mean_meter = metrics.Mean()

1.2 向metrics添加数据

添加数据:update_state()。每一次迭代,都向准确率指标中添加测试数据的真实值和测试数据的预测值,将准确率保存在缓存区,需要时取出来。向平均损失指标中添加每一次训练产生的损失,每添加进来一个值就计算加权平均值sample_weight指定每一项的权重,将结果保存在缓存区,需要时取出来。

# 计算真实值和预测值之间的准确度
acc_meter.update_state(y_true, predict)
# 计算平均损失
mean_meter = mean_meter.update_state(loss, sample_weight=None)

1.3 从metrics中取出数据

取出数据:result().numpy()。result()返回tensor类型数据,转换成numpy()类型的数据。

# 取出准确率
acc_meter.result().numpy()
# 取出训练集的损失值的均值
mean_meter.result().numpy()

1.4 清空缓存

清空缓存:reset_states()。每一次循环缓存区都会将之前的数据保存,在开始第二次循环之前,应该把缓存区清空,重新读入数据。

# 清空准确率的缓存
acc_meter.reset_states()
# 清空加权均值的缓存
mean_meter.reset_states()

2. compile 模型配置

compile(optimizer, loss, metrics, loss_weights)

参数设置:

optimizer: 用来配置模型的优化器,可以调用tf.keras.optimizers API配置模型所需要的优化器。

loss: 用来配置模型的损失函数,可以通过名称调用tf.losses API中已经定义好的loss函数。

metrics: 用来配置模型评价的方法,模型训练和测试过程中的度量指标,如accuracy、mse等

loss_weights: float类型,损失加权系数,总损失是所有损失的加权和,它的元素个数和模型的输出数量是1比1的关系。

# 选择优化器Adam,loss为交叉熵损失,测试集评价指标accurancy
network.compile(optimizer=optimizers.Adam(lr=0.01), #学习率0.01loss = tf.losses.CategoricalCrossentropy(from_logits=True),metrics = ['accuracy'])

3. fit 模型训练

fit(x, y, batch_size, epochs, validation_split, validation_data, shuffle, validation_freq)

参数:

x: 训练集的输入数据,可以是array或者tensor类型。

y: 训练集的目标数据,可以是array或者tensor类型。

batch_size:每一个batch的大小,默认32

epochs: 迭代次数

validation_split:配置测试集数据占训练数据集的比例,取值范围为0~1。

validation_data: 配置测试集数据(输入特征及目标)。如果已经配置validation_split参数,则可以不配置该参数。如果同时配置validation_split和validation_data参数,那么validation_split参数的配置将会失效。

shuffle:配置是否随机打乱训练数据。当配置steps_per_epoch为None时,本参数的配置失效。

validation_freq: 每多少次循环做一次测试

# ds为包含输入特征及目标的数据集
network.fit(ds, eopchs=20, validation_data=ds_val, validation_freq=2)
# validation_data给定测试集,validation_freq每多少次大循环做一次测试,测试时自动计算准确率

4. evaluate 模型评估

evaluate(x, y, batch_size, sample_weight, steps)

返回模型的损失及准确率等相关指标

参数:

x: 输入测试集特征数据

y:测试集的目标数据

batch_size: 整数或None。每个梯度更新的样本数。如果未指定,batch_size将默认为32。如果数据采用数据集,生成器形式,则不要指定batch_size。

sample_weight: 测试样本的可选Numpy权重数组,用于加权损失函数。

steps: 整数或None。宣布评估阶段结束之前的步骤总数。


5. predict 预测

predict(x, batch_size, steps)

参数:

x: numpy类型,tensor类型。预测所需的特征数据

batch_size: 每个梯度更新的样本数。如果未指定,batch_size将默认为32

steps: 整数或None,宣布预测回合完成之前的步骤总数(样本批次)。

等同于:

sample = next(iter(ds_pred)) # 每次从验证数据中取出一组batch
x = sample[0] # x 保存第0组验证集特征值
pred = network.predict(x)  # 获取每一个分类的预测结果
pred = tf.argmax(pred, axis=1) # 获取值最大的所在的下标即预测分类的结果
print(pred)

6. sequential

Sequential模型适用于简单堆叠网络层,即每一层只有一个输入和一个输出。

# ==1== 设置全连接层
# [b,784]=>[b,256]=>[b,128]=>[b,64]=>[b,32]=>[b,10],中间层一般从大到小降维
network = Sequential([layers.Dense(256, activation='relu'), #第一个连接层,输出256个特征layers.Dense(128, activation='relu'), #第二个连接层layers.Dense(64, activation='relu'), #第三个连接层layers.Dense(32, activation='relu'), #第四个连接层layers.Dense(10), #最后一层不需要激活函数,输出10个分类])
# ==2== 设置输入层维度
network.build(input_shape=[None, 28*28])
# ==3== 查看网络结构
network.summary()
# ==4== 查看网络的所有权重和偏置
network.trainable_variables
# ==5== 自动把x从第一层传到最后一层
network.call()

7. 自定义层构建网络

通过对 tf.keras.Model 进行子类化并定义自己的前向传播模型。在 __init__ 方法中创建层并将它们设置为类实例的属性。在 call 方法中定义前向传播。

# 自定义Dense层
class MyDense(layers.Layer): #必须继承layers.Layer层,放到sequential容器中# 初始化方法def __int__(self, input_dim, output_dim):super(MyDense, self).__init__() # 调用母类初始化,必须# 自己发挥'w''b'指定名字没什么用,创建shape为[input_dim, output_dim的权重# 使用add_variable创建变量self.kernel = self.add_variable('w', [input_dim, output_dim])self.bias = self.add_variable('b', [output_dim])# call方法,training来指示现在是训练还是测试def call(self, inputs, training=None):out = inputs @ self.kernel + self.biasreturn out# 自定义层来创建网络
class MyModel(keras.Model):  # 必须继承keras.Model大类,才能使用complie、fit等功能# def __init__(self):super(MyModel, self).__init__() # 调用父类Mymodel# 使用自定义层创建5层self.fc1 = MyDense(28*28,256) #input_dim=784,output_dim=256self.fc2 = MyDense(256,128)self.fc3 = MyDense(128,64)self.fc4 = MyDense(64,32)self.fc5 = MyDense(32,10)def call(self, inputs, training=None):# x从输入层到输出层x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)        x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x) #logits层return x

【深度学习】(6) tensorflow2.0使用keras高层API相关推荐

  1. 吴恩达深度学习之tensorflow2.0 课程

    课链接 吴恩达深度学习之tensorflow2.0入门到实战 2019年最新课程 最佳配合吴恩达实战的教程 代码资料 自己取 链接:https://pan.baidu.com/s/1QrTV3KvKv ...

  2. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization

    <<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...

  3. TensorFlow2.0教程-keras 函数api

    TensorFlow2.0教程-keras 函数api Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details ...

  4. 第1章【深度学习简介】--动手学深度学习【Tensorflow2.0版本】

    项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...

  5. 第3章(3.11~3.16节)模型细节/Kaggle实战【深度学习基础】--动手学深度学习【Tensorflow2.0版本】

    项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...

  6. 第0章【序】--动手学深度学习【Tensorflow2.0版本】

    项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 这个项目将<动手学深度学习> 原书中MXNet代码实现改为Tenso ...

  7. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...

  8. 降低深度学习开发门槛,“动态图+高层API”能带来多大的便利?

    "在深度学习技术面前,我感觉到深深的危机感." 已经有无数人曾经吐槽过这一点,因为深度学习趋势不可阻挡,但其入门门槛之高,落地难度之大,又往往会把开发者挡于门外. 如何降低深度学习 ...

  9. 【深度学习与tensorflow2.0实战】(网易云课堂)13-GAN

    本文目录 GAN原理 纳什均衡-D.G EM距离 GAN实战 **gan.py** dataset.py GAN原理 Having Fun ▪ https://reiinakano.github.io ...

最新文章

  1. sql优化ppt_一款跨平台免费的开源 SQL 编辑器和数据库管理器!
  2. java中封装日期加时间_java日期处理简单封装
  3. 鸢尾花识别问题,萼片有什么用?
  4. 04 列表的增删改查 常用方法 元祖 range
  5. Ubuntu将python2.7默认更改为python3.X版本
  6. REVERSE-PRACTICE-BUUCTF-5
  7. 经典排序算法总结与Python实现(下)
  8. adb shell am 的用法
  9. 【工具】开发环境之vagrant
  10. 毕设题目:Matlab图像评价
  11. wps PPT 中提取视频
  12. 2017年5月—信息安全工程师—上午综合知识(11-15)
  13. ae合成设置快捷键_【教程】你不知道的全网最全ae快捷键【基础篇】
  14. 【机器学习理论】换底公式--以e,2,10为底的对数关系转化
  15. python实战演练(二)三级菜单
  16. 适合iPhone13的蓝牙耳机音质比较好有哪些?音质好的蓝牙耳机推荐
  17. win10常用快捷键和常用DOS命令
  18. 未连接到互联网,检查代理服务器地址
  19. Python实现文字转语音功能
  20. 完整的渗透测试靶场通关

热门文章

  1. Android XML: unbound prefix
  2. Android 相对布局别自己快遗忘的属性layout_alignRight,layout_alignBottom,layout_alignTop,layout_alignLeft
  3. Web Service 安全性解决方案(SOAP篇)
  4. SQL*Plus 系统变量之15 - DESC[RIBE]
  5. 博客非100%原创,在学习道路上,我一直站在别人肩上
  6. 深入浅出的webpack构建工具---DllPlugin DllReferencePlugin提高构建速度(七)
  7. 2022-2028年中国反射偏光膜行业市场研究及前瞻分析报告
  8. emacs 搭建racket开发环境
  9. 第一段冲刺_个人总结_5.2
  10. 前端页面——Cookie与Session有什么区别