文章目录

  • 0.函数介绍
  • 1.区别联系
    • 1.1 tf.nn.softmax_cross_entropy_with_logits
    • 1.2 tf.nn.sparse_softmax_cross_entropy_with_logits
    • 1.3 tf.contrib.legacy_seq2seq.sequence_loss_by_example
  • 2.代码呈现
  • 3.References

Author: Cao Shengming
Email: caoshengming@trio.ai ? checkmate.ming@gmail.com
Company: Trio 北京(三角兽)科技有限公司


0.函数介绍

这两个函数是在 model 中非常常用的两个损失函数,不管是序列标注还是语言模型中都会见到他们两个的身影,总的来说tf.nn.softmax_cross_entropy_with_logits 是 tf.contrib.legacy_seq2seq.sequence_loss_by_example 的特殊情况,而且在代码处理中也有一定的技巧。
(注:查东西的时候先看 api 和源码再去翻各种乱七八糟的博客,效率会更高。)

1.区别联系

1.1 tf.nn.softmax_cross_entropy_with_logits

**函数实现:**传统实现不赘述
函数输入:

logits: [batch_size, num_classes]
labels: [batch_size, num_classes]
logits和 labels 拥有相同的shape

代码示例:

import tensorflow as tf
import numpy as np
y = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]])  # onestep vector
logits = np.array([[12, 3, 2], [3, 10, 1], [1, 2, 5], [4, 6.5, 1.2], [3, 6, 1]])
y_ = tf.nn.softmax(logits)
e1 = -np.sum(y * np.log(y_), -1)  # reduce_sum 所有样本 loss 求和sess = tf.Session()
y = np.array(y).astype(np.float64)
e2 = sess.run(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))  # labels 和 logtis shape 相同print("公式计算的结果:\n", e1)
print("tf api 计算的结果:\n", e2)

1.2 tf.nn.sparse_softmax_cross_entropy_with_logits

主要区别:与上边函数不同,输入 labels 不是 one-hot 格式所以会少一维
函数输入:

logits: [batch_size, num_classes]
labels: [batch_size]
logits和 labels 拥有相同的shape

代码示例:

import tensorflow as tf
labels = [0,1,2] #只需给类的编号,从 0 开始logits = [[2,0.5,1],[0.1,1,3],[3.1,4,2]]logits_scaled = tf.nn.softmax(logits)
result = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)with tf.Session() as sess:print(sess.run(result))

1.3 tf.contrib.legacy_seq2seq.sequence_loss_by_example

函数实现:

def sequence_loss_by_example(logits, targets, weights,average_across_timesteps=True,softmax_loss_function=None, name=None):
#logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
#targets: List of 1D batch-sized int32 Tensors of the same length as logits.
#weights: List of 1D batch-sized float-Tensors of the same length as logits.
#return:log_pers 形状是 [batch_size].for logit, target, weight in zip(logits, targets, weights):if softmax_loss_function is None:# TODO(irving,ebrevdo): This reshape is needed because# sequence_loss_by_example is called with scalars sometimes, which# violates our general scalar strictness policy.target = array_ops.reshape(target, [-1])crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(logit, target)else:crossent = softmax_loss_function(logit, target)log_perp_list.append(crossent * weight)log_perps = math_ops.add_n(log_perp_list)if average_across_timesteps:total_size = math_ops.add_n(weights) total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.log_perps /= total_size

函数说明: 可以发现通过 zip 操作对 list 的每个元素执行一次 sparse 操作,其他的都与 sparse 是相同的,所以使用这个函数的关键在于如何进行输入的 list 的构造。最容易想到的是安 sequence_length进行 unstack,但是这样会给输入的构造带来很多额外的工作量。具体代码使用时是由一定技巧的,请参见下一部分代码呈现

2.代码呈现

此处我们将展示在训练 char-level语言模型的时候,两种损失函数的处理,就可以搞清这两个函数到底是怎么使用的。

case1:使用 softmax…的情况

def build_loss(self):with tf.name_scope('loss'):y_one_hot = tf.one_hot(self.targets, self.num_classes)y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())loss =tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=y_reshaped)self.loss = tf.reduce_mean(loss)

case1:使用 sequece…的情况

with tf.name_scope('loss'):output = tf.reshape(outputs, [-1, args.state_size])self.logits = tf.matmul(output, w) + bself.probs = tf.nn.softmax(self.logits)self.last_state = last_statetargets = tf.reshape(self.target_data, [-1])loss = seq2seq.sequence_loss_by_example([self.logits],[targets],[tf.ones_like(targets, dtype=tf.float32)])self.cost = tf.reduce_sum(loss) / args.batch_sizetf.summary.scalar('loss', self.cost)

上述代码总结:
我们可以清晰地看到两者接受的原始的输入竟然都是一样的,所以这两者在处理时都用到了一定的技巧,前者的技巧是是将 [B*T] 作为 [B] ,后者的技巧是在外边包一层 “[ ]” 留给函数内部的 zip 来使用,当然同样是将 [B*T] 作为 [B] ,这些操作都要考虑清楚,才能搞清楚函数的差异到底在哪。
另外需要注意的一点是 sequence 函数中的 weight 参数,可以替代手工的 loss mask 对不需要的 padding 位置不进行 loss 的计算。

3.References

  • cross_entropy 代码
  • sequence_loss 代码
  • tensorflow中sequence_loss_by_example()函数的计算过程(结合TF的ptb构建语言模型例子)
  • 新老损失函数 api 介绍

tf.nn.softmax_cross_entropy_with_logits 和 tf.contrib.legacy_seq2seq.sequence_loss_by_example 的联系与区别相关推荐

  1. tf.nn.softmax_cross_entropy_with_logits()笔记及交叉熵

    交叉熵 交叉熵可在神经网络(机器学习)中作为损失函数,p表示真实标记的分布,q则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量p与q的相似性.交叉熵作为损失函数还有一个好处是使用sigmoid函 ...

  2. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法

    [TensorFlow]tf.nn.softmax_cross_entropy_with_logits的用法 from:https://blog.csdn.net/mao_xiao_feng/arti ...

  3. tf.nn.sparse_softmax_cross_entropy_with_logits()与tf.nn.softmax_cross_entropy_with_logits的差别

    这两个函数的用法类似 sparse_softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=N ...

  4. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits中的“logits”到底是个什么意思?

    tf.nn.softmax_cross_entropy_with_logits中的"logits"到底是个什么意思?_玉来愈宏的随笔-CSDN博客 https://blog.csd ...

  5. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits 函数:求交叉熵损失

    [TensorFlow]tf.nn.softmax_cross_entropy_with_logits的用法_xf__mao的博客-CSDN博客 https://blog.csdn.net/mao_x ...

  6. 【TensorFlow】TensorFlow函数精讲之tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax_cross_entropy_with_logits()函数是TensorFlow中计算交叉熵常用的函数. 后续版本中,TensorFlow更新为:tf.nn.softmax ...

  7. TensorFlow基础篇(三)——tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax_cross_entropy_with_logits()函数是TensorFlow中计算交叉熵常用的函数. 后续版本中,TensorFlow更新为:tf.nn.softmax ...

  8. tf.nn.dropout和tf.keras.layers.Dropout的区别(TensorFlow2.3)与实验

    这里写目录标题 场景:dropout和Dropout区别 问题描述: 结论: 深层次原因:dropout是底层API,Dropout是高层API 场景:dropout和Dropout区别 全网搜索tf ...

  9. 【TensorFlow】TensorFlow函数精讲之tf.nn.max_pool()和tf.nn.avg_pool()

    tf.nn.max_pool()和tf.nn.avg_pool()是TensorFlow中实现最大池化和平均池化的函数,在卷积神经网络中比较核心的方法. 有些和卷积很相似,可以参考TensorFlow ...

最新文章

  1. 将一个对象拆开拼接成URL
  2. Angular NgRx MemoizedSelector的类型定义学习
  3. 学一下Unix/C啊
  4. php在四线城市待遇如何,月薪5000元在四线城市算什么水平,丢人吗?
  5. 移动端高清适配方案(解决图片模糊问题、1px细线问题)
  6. try catch与异常的说明
  7. 中国ai人工智能发展太快_中国的AI:开放采购和幕后玩家
  8. 201912-3 化学方程式 的一种解法
  9. Oracle DBA日常工作手册
  10. 代码随想录算法训练营第二十二天
  11. Oracle多个数据库备份和还原,oracle 多数据库还原
  12. 步进电机驱动器驱动不了电机的一种情况及解决方案
  13. Linux服务器下Matlab的安装
  14. AntiVir UNIX 在Ubuntu 8.04下的安装
  15. [附源码]Nodejs计算机毕业设计敬老院信息管理系统Express(程序+LW)
  16. 函数的单调性和曲线的凹凸性
  17. 传智播客 HTTP协议详解
  18. python3编程实战_生信编程实战第3题(python)
  19. MOOC编程判定身高是否与预计的相符合。(10分)
  20. 图形数据库之Neo4j学习(一)

热门文章

  1. 提示找不到include/common.h 提示No package 'minigui' found
  2. Ubunt_配置_start
  3. 面向过程与面向对象编程的区别和优缺点
  4. Arcgis10.3在添加XY数据时出现问题
  5. fastRPC的数据库服务
  6. spring security 允许 iframe 嵌套
  7. Java——Java封装
  8. 线程间通信的三种方法 (转)
  9. 关于HtmlParser中Parser【org.htmlparser.Parser】这个类奇怪的地方...求解释【已获得解释】...
  10. 加密和解密.net配置节