前言

在TensorFlow中提供了挺多损失函数的,这里主要测试一下均方差与交叉熵相关的几个函数的计算流程。主要是测试来自于tf.nntf.lossesmean_square_errorsigmoid_cross_entrysoftmax_cross_entrysparse_softmax_cross_entry

国际惯例,参考博客:

官方文档

一文搞懂交叉熵在机器学习中的使用,透彻理解交叉熵背后的直觉

TensorFlow中多标签分类

预备

单热度编码one-hot

先复习一下one_hot编码,就是将真实标签转换为01标签,需要注意的是tf的one_hot编码中标签0代表的是1,0,0...而非0,0,0...

labels_n=np.array([0,1,2])
labels_oh=tf.one_hot(labels_n,depth=3)
with tf.Session() as sess:print(sess.run(labels_oh))'''[[1. 0. 0.][0. 1. 0.][0. 0. 1.]]'''

softmax

通常将最后的输出规整到和为1的形式:

softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

设输出为z=(z1,z2,⋯ ,zn)z=(z_1,z_2,\cdots,z_n)z=(z1​,z2​,⋯,zn​),则
σ(z)j=ezj∑i=1nezk\sigma(z)_j=\frac{e^{z_j}}{\sum_{i=1}^n e^{z_k}} σ(z)j​=∑i=1n​ezk​ezj​​

sigmoid

激活函数:
f(x)=11+e−xf(x)=\frac{1}{1+e^{-x}} f(x)=1+e−x1​

交叉熵

多标签分类(每个样本可能属于多个标签),最后一层使用sigmoid激活:

−ylog⁡(P(y))−(1−y)log⁡(1−P(y))-y\log(P(y))-(1-y)\log(1-P(y)) −ylog(P(y))−(1−y)log(1−P(y))

单标签分类(每个样本只可能属于一个标签),最后一层使用softmax激活:
−∑i=1nyilog⁡(P(yi))-\sum_{i=1}^n y_i\log(P(y_i)) −i=1∑n​yi​log(P(yi​))

准备测试

进入测试之前,需要先引入相关的包

import numpy as np
import tensorflow as tf

交叉熵相关函数的测试,使用的变量是

labels=np.array([[1,0,0],[0,1,0],[0,0,1]],dtype='float32')
preds=np.array([[5,6,3],[7,5,1],[1,2,8]],dtype='float32')

均方差损失-MSE

原理

对应项相减的平方和的均值,通常用来做回归,计算预测值与真实值的误差

代码测试

定义相关变量:

ori_labels=np.array([[1,2,3]],dtype='float32')
pred_labels=np.array([[5,3,3]],dtype='float32')

调用原本函数测试:

mse_op=tf.losses.mean_squared_error(labels=ori_labels,predictions=pred_labels)
with tf.Session() as sess:print(sess.run(mse_op))'''5.6666665'''

手动实现过程:

with tf.Session() as sess:print(sess.run(tf.reduce_mean(tf.square(ori_labels-pred_labels))))
'''
5.6666665
'''

总结

原理就是求原标签与预测标签的平方和损失的均值。

sigmoid_cross_entry

原理

使用sigmoid激活的交叉熵,毫无疑问,玩得多标签分类,流程是:

  • 将输出用sigmoid激活
  • 使用多标签分类的交叉熵计算损失

代码测试

使用tf.losses中的交叉熵损失

tf_sce=tf.losses.sigmoid_cross_entropy(labels,preds)
with tf.Session() as sess:print(sess.run(tf_sce))
#2.3132434

使用tf.nn中的交叉熵损失:

tf_sce1=tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=preds)
with tf.Session() as sess:print(sess.run(tf_sce1))
'''[[6.7153485e-03 6.0024757e+00 3.0485873e+00][7.0009112e+00 6.7153485e-03 1.3132617e+00][1.3132617e+00 2.1269281e+00 3.3540637e-04]]
'''

使用流程实现:

#先计算sigmoid,再计算交叉熵
preds_sigmoid=tf.sigmoid(preds)
ce=-labels*tf.log(preds_sigmoid)-(1-labels)*(tf.log(1-preds_sigmoid))
# ce= - tf.reduce_sum(labels*tf.log(preds_sigmoid),-1)
with tf.Session() as sess:print(sess.run(ce))print(sess.run(tf.reduce_mean(ce)))
'''
[[6.7153242e-03 6.0024934e+00 3.0485876e+00][7.0009704e+00 6.7153242e-03 1.3132617e+00][1.3132617e+00 2.1269276e+00 3.3539196e-04]]
2.3132522
'''

总结

  • 多标签分类,输入是原始和预测标签的编码

  • tf.losses中的计算结果是tf.nn中计算结果的均值

softmax_cross_entry

原理

使用softmax激活,显然就是单标签分类的情况,流程是:

  • 将输出用softmax激活
  • 计算单标签分类的交叉熵损失

代码测试

使用tf.losses中的函数:

tf_sce=tf.losses.softmax_cross_entropy(labels,preds)
with tf.Session() as sess:print(sess.run(tf_sce))
#1.160502

使用tf.nn中的函数:

tf_sce1=tf.nn.softmax_cross_entropy_with_logits(labels=labels,logits=preds)
with tf.Session() as sess:print(sess.run(tf_sce1))
#[1.3490121  2.129109   0.00338493]

使用流程计算:

#先计算softmax,再计算交叉熵
preds_sigmoid=tf.nn.softmax(preds)
ce= - tf.reduce_sum(labels*tf.log(preds_sigmoid),-1)
# ce=-labels*tf.log(preds_sigmoid)-(1-labels)*(tf.log(1-preds_sigmoid))
with tf.Session() as sess:print(sess.run(ce))print(sess.run(tf.reduce_mean(ce)))
'''
[1.3490121  2.129109   0.00338495]
1.1605021
'''

总结

  • 用于单标签分类,输入是真实和预测标签的单热度编码
  • tf.losses中的计算结果是tf.nn中计算结果的均值

sparse_softmax_cross_entry

原理

还是看到softmax,依旧是单标签分类,但是多了个sparse,代表输入标签可以是非单热度标签,流程:

  • 将原标签转为单热度编码
  • 将输出用softmax激活
  • 计算单标签分类的交叉熵

代码测试

假设原始标签的非单热度编码是:

labels_n=np.array([0,1,2])

利用tf.losses中的损失函数:

tf_scen=tf.losses.sparse_softmax_cross_entropy(labels=labels_n,logits=preds)
with tf.Session() as sess:print(sess.run(tf_sce))
#1.160502

利用tf.nn中的损失函数:

tf_sce1=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_n,logits=preds)
with tf.Session() as sess:print(sess.run(tf_sce1))print(sess.run(tf.reduce_mean(tf_sce1)))
'''
[1.3490121  2.129109   0.00338493]
1.160502
'''

利用流程实现:

labels_onehot=tf.one_hot(labels_n,depth=3)
preds_sigmoid=tf.nn.softmax(preds)
ce= - tf.reduce_sum(labels_onehot*tf.log(preds_sigmoid),-1)
# ce=-labels*tf.log(preds_sigmoid)-(1-labels)*(tf.log(1-preds_sigmoid))
with tf.Session() as sess:print(sess.run(labels_onehot))print(sess.run(ce))print(sess.run(tf.reduce_mean(ce)))  '''
[[1. 0. 0.][0. 1. 0.][0. 0. 1.]]
[1.3490121  2.129109   0.00338495]
1.1605021
'''

总结

  • sparse代表原始标签不用转成单热度编码
  • 适用于单标签分类
  • tf.lossestf.nn中函数的均值

总结

本文主要对比了:

  • tf.nntf.losses中同一类损失函数的使用方法与区别
  • 分析计算流程,并实现验证
  • 了解TensorFlow中回归、单标签分类、多标签分类的损失函数的选择

博客代码:

链接:https://pan.baidu.com/s/1b40rNxjdOIIE2g7_Afctiw
提取码:0sb0

【TensorFlow-windows】部分损失函数测试相关推荐

  1. 8.2 TensorFlow实现KNN与TensorFlow中的损失函数,优化函数

    前言 8.1 mnist_soft,TensorFlow构建回归模型中对主要对计算图的概念与公式与计算图的转化进行了介绍,8.2则主要介绍一下TensorFlow中自带的几个算子,与优化函数,损失函数 ...

  2. xmpp 服务器配置 open fire for windows 及 spark 测试

    xmpp 服务器配置 open fire for windows 此文章为 XMPP windows服务器配置,使用的是 open fire 3.9.1.exe 1: 下载 open fire ope ...

  3. 在cmd指令看计算机位数,在.cmd中使用Windows命令来测试32位或64位并运行命令

    我正在编写一个脚本,用于查找注册表值并将该值返回给Windows命令提示符屏幕,并将其添加到.txt文件中.我到了需要测试的位置,看看机器是32位还是64位,这样我才知道使用哪个命令来查找我在注册表中 ...

  4. Windows Mobile logo测试介绍

    首先声明本文转自:http://softtest.chinaitlab.com/sji/744369.html 一.Windows Mobile简介 Windows Mobile是微软主要针对手机市场 ...

  5. tensorflow windows

    conda create -n py35 python=3.5 activate py35 pip install --ignore-installed --upgrade https://stora ...

  6. TensorFlow windows之Tensorboard使用

    TensorBoard是什么 TensorBoard:Tensorflow自带的可视化工具,TensorBoard 来展现 TensorFlow 图像,绘制图像生成的定量指标图以及附加数据 关于如何在 ...

  7. 江西省省赛中职网络安全-Windows操作系统渗透测试

    9:Windows操作系统渗透测试 具体渗透方法观看视频即可 任务环境说明:  服务器场景:Server05  服务器场景操作系统:Windows(版本不详)(封闭靶机) 通过本地PC中渗透测试平 ...

  8. Windows端 USBIP测试

    Windows端 USBIP测试 注:-- 写在前面: 其实USBIP最开始的时候只是用在Linux端的, 具体是如下: Windows端只可以作为客户端, 不可以做为服务端 Linux端既可以作为客 ...

  9. wap2.0有关windows mobile模拟器测试环境的搭建

    wap2.0有关windows mobile模拟器测试环境的搭建 2009年08月01日 星期六 19:38 理论上测试只需要支持wap2.0的模拟器即可,但是各款模拟器不尽相同,起初我用openwa ...

  10. DMU在windows下安装测试--外篇1

    DMU在windows下安装测试–外篇1 1. 下载地址 下载地址:http://dmu.agrsci.dk/ 64为电脑安装DMUv6-R5-2-EM64T.msi, 32为电脑安装DMUv6-R5 ...

最新文章

  1. 中间画一条短竖线_许愿孔明灯怎么画,简约好看的孔明灯简笔画教程
  2. 1.6-1.7配置IP1.8网络问题排查
  3. Leetcode 950. Reveal Cards In Increasing Order
  4. 【差分隐私组合定理,直方图,列联表代码实现】差分隐私代码实现系列(五)
  5. Flask 源代码阅读笔记
  6. Hash函数加密算法(一)
  7. 文件系统功能 os模块 子模块os.path pickle
  8. 记号, 函数空间及不等式
  9. mobi 直接转化为 html,MobiCreator--pdf文档转化为kindle可阅读的格式
  10. ICP备案线下注销 网站域名备案注销
  11. Ping32文档透明加密软件基础概念
  12. 知识图谱导论----相关笔记
  13. ad如何计算电路板的pin数量_各类EDA软件统计pin数方法
  14. C++游戏——小胎大乱斗
  15. Win11电脑怎么让两个屏幕任务栏都显示时间?
  16. 基于云效Flow配置 Jenkins 源
  17. 系统之家xp服务器系统怎么安装,windowsxp系统之家系统详细安装步骤
  18. LeetCode:322. 零钱兑换(python)
  19. 宋宝华:Linux设备与驱动的手动解绑与手动绑定
  20. python的print输出居中对齐

热门文章

  1. gihosoft android 教程,Gihosoft Free Android Data Recovery
  2. 怎么提取html的数据,如何提取网页数据
  3. Greg and Array CodeForces - 296C(差分数组+线段树)
  4. CCF之地铁修建(100分)
  5. node 加密解密模块_NODE.JS加密模块CRYPTO常用方法介绍
  6. 远程电脑桌面控制怎么看计算机,计算机如何通过远程控制,可以查看他人电脑屏幕...
  7. 图神经网络(一)图信号处理与图卷积神经网络(3)图傅里叶变换
  8. oracle+标记要,oracle ORA-00031:session marked for kill(标记要终止的会话)解决方法
  9. 什么是宇宙安全声明_《三体》三体人是否知道如何向宇宙发表安全声明?
  10. Leetcode 1. 两数之和 (Python版)