TensorFlow2.0: keras.metrics的使用
keras.metrics中有两个api函数可以简化准确率acc和损失值loss的计算。其分别是metrics.Accuracy( )和metrics.Mean( )。
一、建立测量尺
#建立测量尺
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()
二、更新数据
loss_meter.update_state(loss)
acc_meter.update_state(y,pred)
三、获取数据并清除buffer
print('epoch:',epoch,' step:',step,' loss:',loss_meter.result().numpy())
loss_meter.reset_states()
print('epoch:',epoch,' acc:',acc_meter.result().numpy())
acc_meter.reset_states()
以mnist数据集训练为例:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,metrics,datasets,Sequential
import datetime
import io#建立监听例子
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)def preprocess(x,y):x = tf.cast(x,dtype=tf.float32)/255y = tf.cast(y,dtype=tf.int32)return x,y(x_train,y_train),(x_test,y_test) = datasets.mnist.load_data()
print('x_train.shape = ', x_train.shape,', y_train.shape = ',y_train.shape)
print('x_test.shape = ', x_test.shape,', y_test.shape = ',y_test.shape)batchsz = 128
epoch_num = 20
lr = 1e-3db = tf.data.Dataset.from_tensor_slices((x_train,y_train))
db = db.map(preprocess).shuffle(10000).batch(batchsz)
db_val = tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_val = db_val.map(preprocess).batch(batchsz)model = Sequential([layers.Dense(512,activation=tf.nn.relu),layers.Dense(256,activation=tf.nn.relu),layers.Dense(128,activation=tf.nn.relu),layers.Dense(64,activation=tf.nn.relu),layers.Dense(10)
])
model.build(input_shape=[None,28*28])
model.summary()
op = optimizers.Adam(lr)#建立测量尺
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()for epoch in range(epoch_num):for step,(x,y) in enumerate(db):with tf.GradientTape() as tape:x = tf.reshape(x,[-1,28*28])y_onehot = tf.one_hot(y,depth=10)logits = model(x)loss = tf.losses.categorical_crossentropy(y_true=y_onehot,y_pred=logits,from_logits=True)loss_meter.update_state(loss)grads = tape.gradient(loss,model.trainable_variables)op.apply_gradients(zip(grads,model.trainable_variables))if step%100 ==0:print('epoch:',epoch,' step:',step,' loss:',loss_meter.result().numpy())with summary_writer.as_default():tf.summary.scalar(name='loss',data=loss_meter.result().numpy(),step=epoch*600+step)loss_meter.reset_states()for step,(x,y) in enumerate(db_val):x = tf.reshape(x,[-1,28*28])out = model(x)#[b,10]pred = tf.cast(tf.argmax(out,axis=1),dtype=tf.int32)#[b,1]acc_meter.update_state(y,pred)print('epoch:',epoch,' acc:',acc_meter.result().numpy())with summary_writer.as_default():tf.summary.scalar('acc',data=acc_meter.result().numpy(),step=epoch)acc_meter.reset_states()
TensorFlow2.0: keras.metrics的使用相关推荐
- Tensorflow2.0(Keras)转换TFlite
Tensorflow 2.0(Keras)转换TFlite 目录 Tensorflow 2.0(Keras)转换TFlite 1. TensorFlow Lite 指南 (1)TensorFlow L ...
- TensorFlow2.0 Keras多层感知器模型imdb情感分类
# 下载 import urllib.request import os import tarfileurl = 'http://ai.stanford.edu/~amaas/data/sentime ...
- tensorflow2.0 Keras VGG16 VGG19 系列 代码实现
模型介绍参看:博文 VGG16 迁移模型 先看看标准答案 import tensorflow as tf from tensorflow import kerasbase_model = keras. ...
- 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战
基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕 前言 运行python环境 ...
- 基于TensorFlow2.0的摄像头数字识别
import numpy as np import cv2 from skimage import data, segmentation, measure, morphology, color imp ...
- 【深度学习】(6) tensorflow2.0使用keras高层API
各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...
- 【TensorFlow2.0】以后我们再也离不开Keras了?
TensorFlow2.0 Alpha版已经发布,在2.0中最重要的API或者说到处都出现的API是谁,那无疑是Keras.因此用过2.0的人都会吐槽全世界都是Keras.今天我们就来说说Keras这 ...
- 官方钦定TensorFlow2.0要改这个API,用户吐槽:全世界都是keras
郭一璞 发自 凹非寺 量子位 报道 | 公众号 QbitAI 前不久,Keras的爸爸François Chollet在GitHub上发起了一个提议: 咱们把tf.train和tf.keras.op ...
- Tensorflow2.0:使用Keras自定义网络实战
tensorflow2.0建议使用tf.keras作为构建神经网络的高级API 接下来我就使用tensorflow实现VGG16去训练数据 背景介绍: 2012年 AlexNet 在 ImageNet ...
最新文章
- QT小例子 ---文件查找
- 模拟网络通信中存储转发的分组交换算法
- 深度学习的150多篇文章和10多个专栏推荐
- 【机器学习基础】数学推导+纯Python实现机器学习算法10:线性不可分支持向量机...
- ubuntu16.04下安装opencv出现libgtk2.0-dev配置失败问题解决方法
- 重新开始Java的原始字符串文字讨论
- linux文件名快速键入,linux修改文件名【使用模式】
- PHP-dede学习:common.ini.php文件
- 最简单的WIN7内核PE系统
- 微信小程序自定义拍照和H5调用摄像头拍照
- android 日历折叠,可折叠的日历控件Calendar
- 为什么蓝鸽的听力下载完还是听不了_推荐这款练习英语听力的神器级免费App
- 【笔记】Opencv 绘制朱利亚(Julia)集合图形
- R语言抽样并验证总体分别为正态分布、均匀分布、指数分布时样本均值的抽样分布
- 读《我喜欢生命本来的样子》记(二)
- 体验一个人自驾游思考人生
- 使用python切割图片
- JavaScript实现鼠标点击监听---弹出社会主义核心价值观(面向对象小练习)
- Pr 入门教程如何个性化“时间轴”面板?
- Semantic Object Segmentation in Weakly Labelled Videos via A Self-Paced Fine-Tuning Network