keras.metrics中有两个api函数可以简化准确率acc和损失值loss的计算。其分别是metrics.Accuracy( )和metrics.Mean( )。
一、建立测量尺

#建立测量尺
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()

二、更新数据

loss_meter.update_state(loss)
acc_meter.update_state(y,pred)

三、获取数据并清除buffer

print('epoch:',epoch,' step:',step,' loss:',loss_meter.result().numpy())
loss_meter.reset_states()
print('epoch:',epoch,' acc:',acc_meter.result().numpy())
acc_meter.reset_states()

以mnist数据集训练为例:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,metrics,datasets,Sequential
import datetime
import io#建立监听例子
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)def preprocess(x,y):x = tf.cast(x,dtype=tf.float32)/255y = tf.cast(y,dtype=tf.int32)return x,y(x_train,y_train),(x_test,y_test) = datasets.mnist.load_data()
print('x_train.shape = ', x_train.shape,', y_train.shape = ',y_train.shape)
print('x_test.shape = ', x_test.shape,', y_test.shape = ',y_test.shape)batchsz = 128
epoch_num = 20
lr = 1e-3db = tf.data.Dataset.from_tensor_slices((x_train,y_train))
db = db.map(preprocess).shuffle(10000).batch(batchsz)
db_val = tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_val = db_val.map(preprocess).batch(batchsz)model = Sequential([layers.Dense(512,activation=tf.nn.relu),layers.Dense(256,activation=tf.nn.relu),layers.Dense(128,activation=tf.nn.relu),layers.Dense(64,activation=tf.nn.relu),layers.Dense(10)
])
model.build(input_shape=[None,28*28])
model.summary()
op = optimizers.Adam(lr)#建立测量尺
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()for epoch in range(epoch_num):for step,(x,y) in enumerate(db):with tf.GradientTape() as tape:x = tf.reshape(x,[-1,28*28])y_onehot = tf.one_hot(y,depth=10)logits = model(x)loss = tf.losses.categorical_crossentropy(y_true=y_onehot,y_pred=logits,from_logits=True)loss_meter.update_state(loss)grads = tape.gradient(loss,model.trainable_variables)op.apply_gradients(zip(grads,model.trainable_variables))if step%100 ==0:print('epoch:',epoch,' step:',step,' loss:',loss_meter.result().numpy())with summary_writer.as_default():tf.summary.scalar(name='loss',data=loss_meter.result().numpy(),step=epoch*600+step)loss_meter.reset_states()for step,(x,y) in enumerate(db_val):x = tf.reshape(x,[-1,28*28])out = model(x)#[b,10]pred = tf.cast(tf.argmax(out,axis=1),dtype=tf.int32)#[b,1]acc_meter.update_state(y,pred)print('epoch:',epoch,' acc:',acc_meter.result().numpy())with summary_writer.as_default():tf.summary.scalar('acc',data=acc_meter.result().numpy(),step=epoch)acc_meter.reset_states()

TensorFlow2.0: keras.metrics的使用相关推荐

  1. Tensorflow2.0(Keras)转换TFlite

    Tensorflow 2.0(Keras)转换TFlite 目录 Tensorflow 2.0(Keras)转换TFlite 1. TensorFlow Lite 指南 (1)TensorFlow L ...

  2. TensorFlow2.0 Keras多层感知器模型imdb情感分类

    # 下载 import urllib.request import os import tarfileurl = 'http://ai.stanford.edu/~amaas/data/sentime ...

  3. tensorflow2.0 Keras VGG16 VGG19 系列 代码实现

    模型介绍参看:博文 VGG16 迁移模型 先看看标准答案 import tensorflow as tf from tensorflow import kerasbase_model = keras. ...

  4. 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战

    基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕 前言 运行python环境 ...

  5. 基于TensorFlow2.0的摄像头数字识别

    import numpy as np import cv2 from skimage import data, segmentation, measure, morphology, color imp ...

  6. 【深度学习】(6) tensorflow2.0使用keras高层API

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...

  7. 【TensorFlow2.0】以后我们再也离不开Keras了?

    TensorFlow2.0 Alpha版已经发布,在2.0中最重要的API或者说到处都出现的API是谁,那无疑是Keras.因此用过2.0的人都会吐槽全世界都是Keras.今天我们就来说说Keras这 ...

  8. 官方钦定TensorFlow2.0要改这个API,用户吐槽:全世界都是keras

    郭一璞 发自 凹非寺  量子位 报道 | 公众号 QbitAI 前不久,Keras的爸爸François Chollet在GitHub上发起了一个提议: 咱们把tf.train和tf.keras.op ...

  9. Tensorflow2.0:使用Keras自定义网络实战

    tensorflow2.0建议使用tf.keras作为构建神经网络的高级API 接下来我就使用tensorflow实现VGG16去训练数据 背景介绍: 2012年 AlexNet 在 ImageNet ...

最新文章

  1. QT小例子 ---文件查找
  2. 模拟网络通信中存储转发的分组交换算法
  3. 深度学习的150多篇文章和10多个专栏推荐
  4. 【机器学习基础】数学推导+纯Python实现机器学习算法10:线性不可分支持向量机...
  5. ubuntu16.04下安装opencv出现libgtk2.0-dev配置失败问题解决方法
  6. 重新开始Java的原始字符串文字讨论
  7. linux文件名快速键入,linux修改文件名【使用模式】
  8. PHP-dede学习:common.ini.php文件
  9. 最简单的WIN7内核PE系统
  10. 微信小程序自定义拍照和H5调用摄像头拍照
  11. android 日历折叠,可折叠的日历控件Calendar
  12. 为什么蓝鸽的听力下载完还是听不了_推荐这款练习英语听力的神器级免费App
  13. 【笔记】Opencv 绘制朱利亚(Julia)集合图形
  14. R语言抽样并验证总体分别为正态分布、均匀分布、指数分布时样本均值的抽样分布
  15. 读《我喜欢生命本来的样子》记(二)
  16. 体验一个人自驾游思考人生
  17. 使用python切割图片
  18. JavaScript实现鼠标点击监听---弹出社会主义核心价值观(面向对象小练习)
  19. Pr 入门教程如何个性化“时间轴”面板?
  20. Semantic Object Segmentation in Weakly Labelled Videos via A Self-Paced Fine-Tuning Network

热门文章

  1. C++ const常量和指针
  2. 启动hadoop遇到的datanode启动不了
  3. 电力项目十--整合文本编辑器
  4. java 16 - 15 集合嵌套存储和遍历元素
  5. ASP.NET MVC学习---(一)ORM框架,EF实体数据模型简介
  6. 输入问题C++字符数组越界问题的一个案例分析
  7. 10 Ways To Suck At Programming
  8. Smart ORM v0.3发布(完全面向对象的轻量级ORM工具)
  9. PHPCMS V9.6.0 SQL注入漏洞EXP
  10. 视频采集以及播放的流程