目录

  • 1 目的
  • 2 tensorflow源码解读
    • 2.1函数的输入和输出
    • 2.2 代码讲解
    • 2.3 缺点
  • 3 正label个数不一致解决方案
    • 3.1 增加一个pad标签作为负label
  • 4 参考

1 目的

在分类任务中,如果分类类别量级较大,几十万甚至上百万的量级,那么最后一层分类层计算将会十分耗时,为了降低模型计算复杂度,每次前向计算,采样部分label参数www参与计算,反向计算梯度的时候也只更新部分参与计算的部分www,而不需要每次更新全部的权重参数www,这样可以大大提高模型的训练速度。

2 tensorflow源码解读

2.1函数的输入和输出

nn_impl.py这是NCE loss的tensorflow源代码,接下来我们对源代码进行一个梳理和讲解,首先我们来看下nce_loss在tensorflow源码中的实现,如下代码所示:

def nce_loss(weights,biases,labels,inputs,num_sampled,num_classes,num_true=1,sampled_values=None,remove_accidental_hits=False,partition_strategy="mod",name="nce_loss"):#计算采样的labels和对应的logits(wx+b)值logits, labels = _compute_sampled_logits(weights=weights,biases=biases,labels=labels,inputs=inputs,num_sampled=num_sampled,num_classes=num_classes,num_true=num_true,sampled_values=sampled_values,subtract_log_q=True,remove_accidental_hits=remove_accidental_hits,partition_strategy=partition_strategy,name=name)# 交叉熵losssampled_losses = sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, name="sampled_losses")# 返回loss求和return _sum_rows(sampled_losses)

函数输入参数

  • weights: [num_classes, dim],最后一层分类层的权重参数www
  • biases: [num_classes],最后一层分类层的偏移bbb
  • labels: [batch_size, num_true],int64类型,每个batch的正label索引idex,要求每个样本的正label数量必须一致,值为num_true (这也是在实际应用中不灵活的部分,后面会有改进的方案)
  • inputs: [batch_size, dim],输入分类层特征向量
  • num_sampled: int类型,每个样本随机采样的负样本个数
  • num_classes: int类型,分类的label总数量
  • num_true: int类型,每个样本的正label数量(一个batch里的所有样本的正label必须一致)
  • sampled_values: 自定义采样的候选集,是个三元组 (采样的候选集,正label数量,采样的label数量),默认是None,采用log_uniform_candidate_sampler采用器
  • remove_accidental_hits: bool类型,是否去除采样到的label有在正label集合里的,设置为“True"则会用负采样loss而不是NCE。
  • partition_strategy: 两种模式“mod"和”div",默认”mod",详情可以参考tf.nn.embedding_lookup
  • name: operation的名称

函数返回值
一维的向量,长度大小为[batch_size],对应每个样本的loss值。

2.2 代码讲解

接下来我们对每个函数的实现做一个深入分析,由上可知,nce_loss函数下主要有三个函数组成,_compute_sampled_logits,sigmoid_cross_entropy_with_logits和_sum_rows。

_compute_sampled_logits函数


def _compute_sampled_logits(weights,biases,labels,inputs,num_sampled,num_classes,num_true=1,sampled_values=None,subtract_log_q=True,div_flag=True,remove_accidental_hits=False,partition_strategy="mod",name=None,seed=None):#数据格式转if isinstance(weights, variables.PartitionedVariable):weights = list(weights)if not isinstance(weights, list):weights = [weights]# 数据格式转换,将label [batch_size, num_ture] 展开,得到一维的sizewith ops.name_scope(name, "compute_sampled_logits",weights + [biases, inputs, labels]):if labels.dtype != dtypes.int64:labels = math_ops.cast(labels, dtypes.int64)labels_flat = array_ops.reshape(labels, [-1])#如果采样label不传入,则默认用log_unifrom_candidate_sampler采样器,生成采用的labelif sampled_values is None:sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(true_classes=labels,num_true=num_true,num_sampled=num_sampled,unique=True,range_max=num_classes,seed=seed)# sampled:[num_sampled],true_expected_count:[batch_size,1]# sampled_expected_count: [num_sampled]# 采样的值不参与梯度更新,所以用stop_gradient标明sampled, true_expected_count, sampled_expected_count = (array_ops.stop_gradient(s) for s in sampled_values)sampled = math_ops.cast(sampled, dtypes.int64)# labels_flat:[batch_size * num_true],sampled: [num_sampled]#将正label和负label对应的索引合并到一起all_ids = array_ops.concat([labels_flat, sampled], 0)#通过索引all_ids从权重参数矩阵weights:[num_classes, dim],取出对应的权重参数,得到all_wall_w = embedding_ops.embedding_lookup(weights, all_ids,partition_strategy=partition_strategy)if all_w.dtype != inputs.dtype:all_w = math_ops.cast(all_w, inputs.dtype)# 抽离出正label w权重参数 true_w,和负label权重参数sampled_w#true_w :[batch_size * num_true, dim]# sampled_w: [num_sampled, dim], 一个batch里,每个样本的负label都是一样的true_w = array_ops.slice(all_w, [0, 0],array_ops.stack([array_ops.shape(labels_flat)[0], -1]))sampled_w = array_ops.slice(all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])#在对应的负label上,计算wx+b,inputs: [batch_size, dim]# sampled_logits: [batch_size, num_sampled]sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)# 与计算all_w一样,抽取偏移all_ball_b = embedding_ops.embedding_lookup(biases, all_ids, partition_strategy=partition_strategy)if all_b.dtype != inputs.dtype:all_b = math_ops.cast(all_b, inputs.dtype)# 抽离出正,负偏移btrue_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])# inputs: [batch_size, dim]# true_w: [batch_size * num_true, dim]# 计算wx+b,得到true_logits:[ batch_size, num_true]dim = array_ops.shape(true_w)[1:2]new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)# 做点乘,得到row_wise_dots: [batch_size, num_true, dim]row_wise_dots = math_ops.multiply(array_ops.expand_dims(inputs, 1),array_ops.reshape(true_w, new_true_w_shape))#reshapedots_as_matrix = array_ops.reshape(row_wise_dots,array_ops.concat([[-1], dim], 0))# 得到正label对应的logits值 [batch_size, num_true]true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])# +btrue_b = array_ops.reshape(true_b, [-1, num_true])true_logits += true_bsampled_logits += sampled_b
################## 此段代码去掉采样的label在正label里  ###########if remove_accidental_hits:acc_hits = candidate_sampling_ops.compute_accidental_hits(labels, sampled, num_true=num_true)acc_indices, acc_ids, acc_weights = acc_hits# This is how SparseToDense expects the indices.acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])acc_ids_2d_int32 = array_ops.reshape(math_ops.cast(acc_ids, dtypes.int32), [-1, 1])sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,"sparse_indices")# Create sampled_logits_shape = [batch_size, num_sampled]sampled_logits_shape = array_ops.concat([array_ops.shape(labels)[:1],array_ops.expand_dims(num_sampled, 0)], 0)if sampled_logits.dtype != acc_weights.dtype:acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)sampled_logits += gen_sparse_ops.sparse_to_dense(sparse_indices,sampled_logits_shape,acc_weights,default_value=0.0,validate_indices=False)
############# 此段代码表示是logits否减去-log(true_expected_count) #######if subtract_log_q:# Subtract log of Q(l), prior probability that l appears in sampled.true_logits -= math_ops.log(true_expected_count)sampled_logits -= math_ops.log(sampled_expected_count)#将正负logits concat到一起,得到out_logits: [batch_size, num_true+num_sampled]out_logits = array_ops.concat([true_logits, sampled_logits], 1)# 标签labels,正label为1/num_true,保证总和为1,负label标签为0# out_labels: [batch_size, num_ture+num_sampled]out_labels = array_ops.concat([array_ops.ones_like(true_logits) / num_true,array_ops.zeros_like(sampled_logits) ], 1)return out_logits, out_labels

sigmoid_cross_entropy_with_logits函数
这就不多说了,交叉熵loss,因为label有多个,是multi-label分类,所以用sigmoid,要注意的一点是,函数参数logits传入的值是原始的wx+b值,sigmoid计算在函数里面操作。

_sum_rows函数

def _sum_rows(x):#该函数的类似tf.reduce_sum(x,1)操作#官方给出用这样计算的理由是,计算梯度效率更高cols = array_ops.shape(x)[1]ones_shape = array_ops.stack([cols, 1])ones = array_ops.ones(ones_shape, x.dtype)# x:[batch_size, num_true+num_sampled]# ones: [num_true+num_sampled, 1]#x和ones两个矩阵相乘,得到[batch_size,1],再reshape [batch_size]return array_ops.reshape(math_ops.matmul(x, ones), [-1])

2.3 缺点

从tensorflow源代码知道,要求每个输入的batch的正label个数必须一致,个数为num_true,所以正常训练模型的时候,必须每一个batch的样本正label一样,但是在实际应用中,特别是multi-label分类,每个样本的正label个数很多是不一致的,在multi task任务中,更不能保证一个batch在多个任务的label标签上都是一致的。

3 正label个数不一致解决方案

针对上述缺陷,尝试如下方案 ,已试验可行。

3.1 增加一个pad标签作为负label

核心思想生成样本的时候,将每个样本的label统一长度为num_true,不足的,用索引为0 (代表pad) 的标签填充,在计算loss的时候,让pad类别对应为负label
修改代码主要如下:

#修改前函数
def _compute_sampled_logits(...):...out_logits = array_ops.concat([true_logits, sampled_logits], 1)# 对应的源代码生成label过程out_labels = array_ops.concat([array_ops.ones_like(true_logits) / num_true,array_ops.zeros_like(sampled_logits)], 1)return out_logits, out_labels#修改后函数
def _compute_sampled_logits(...):...out_logits = array_ops.concat([true_logits, sampled_logits], 1)# 生成mask矩阵,其中真实的正label元素为1, 填充pad label为0mask = tf.cast(tf.not_equal(labels, 0), tf.float32)# 将pad的label都为负label 0true_y = array_ops.ones_like(true_logits) * mask # 然后用div_flag控制是否需要对每个样本的label除以每个样本的个数# 这里动态的计算每个样本的真实label数量,因为每个样本pad的个数不一致if div_flag:dynamic_num_true = tf.reduce_sum(tf.sign(labels), reduction_indices=1)dynamic_num_true = tf.cast(tf.expand_dims(dynamic_num_true, -1), tf.float32)true_y = tf.div(true_y, dynamic_num_true)# 将正label和负label组合,得到out_labels返回out_labels = array_ops.concat([true_y,array_ops.zeros_like(sampled_logits)], 1)return out_logits, out_labels

4 参考

Noise-contrastive estimation: A new estimation principle for
unnormalized statistical models

噪声对比估计NCE (Noise-contrastive estimation)采样方法,提高训练速度,解决源码中正label个数必须相等问题相关推荐

  1. NCE(Noise Contrastive Estimation) 与negative sampling

    NCE Noise Contrastive Estimation与negative sampling负例采样 背景 NCE(Noise Contrastive Estimation) Negative ...

  2. Noise Contrastive Estimation 前世今生——从 NCE 到 InfoNCE

    文章首发于:https://zhuanlan.zhihu.com/p/334772391 0 前言 作为刚入门自监督学习的小白,在阅读其中 Contrastive Based 方法的论文时,经常会看到 ...

  3. “噪声对比估计”杂谈:曲径通幽之妙

    作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 说到噪声对比估计,或者"负采样",大家可能立马就想到了 Word2Vec.事实 ...

  4. Noise Contrastive Estimation (NCE) 、负采样(NEG)和InfoNCE

    总结一下自己的理解: NCE需要根据频次分布等进行采样,NEG不考虑,InfoNCE是一种通过无监督任务来学习(编码)高维数据的特征表示(representation),而通常采取的无监督策略就是根据 ...

  5. 2020-4-12 深度学习笔记18 - 直面配分函数 5 ( 去噪得分匹配,噪声对比估计NCE--绕开配分函数,估计配分函数)

    第十八章 直面配分函数 Confronting the Partition Function 中文 英文 2020-4-8 深度学习笔记18 - 直面配分函数 1 ( 配分函数概念,对数似然梯度) 2 ...

  6. Noise Contrastive Estimation

    熵 统计机器学习中经常遇到熵的概念,在介绍NCE和InfoNCE之前,对熵以及相关的概念做简单的梳理.信息量用于度量不确定性的大小,熵可以看作信息量的期望,香农信息熵的定义:对于随机遍历 X X X, ...

  7. [转] Noise Contrastive Estimation 噪声对比估计 资料

    有个视频讲的不错,mark一下 https://vimeo.com/306156327 转载于:https://www.cnblogs.com/Arborday/p/10903065.html

  8. 简单对比四台电脑对相同模型的训练速度

    型号为 i5-7200 CPU @ 2.50GHz 的 CPU 进行训练,训练了 25 次,速度大约为 791ms/step,每 epoch 平均 710s,训练缓慢. 第二次使用型号为 i7-77- ...

  9. 噪音对比估计(NCE)

    噪音对比估计(NCE, Noise Contrastive Estimation)是一种新的统计模型估计方法,由Gutmann和Hyv¨arinen提出来,能够用来解决神经网络的复杂计算问题,因此在图 ...

最新文章

  1. Input type=“file“上传文件change事件只触发一次解决方案
  2. 2016.7.14最新cocoapods最新安装教程
  3. ux设计中的各种地图_如何在UX设计中使用颜色
  4. C++服务器设计(七):聊天系统服务端实现
  5. 计算机网络原理关于实验中几个指令使用的复习——网络层
  6. 6、Django模板语法
  7. android webview网页显示不完整,【报Bug】webview页面内容显示不全
  8. linux x64下安装oracle 11g
  9. Attention is all you need注意力机制代码解析
  10. iphone小圆点在哪儿设置_iPhone终于自带长截屏了?苹果手机这些截图方式,你用过几种?...
  11. Linux fstab文件详解
  12. Android设置拍照或者上传本地图片
  13. 给学计算机男生起外号,如何给男生起外号
  14. matlab锯齿交换,MATLAB折线消除锯齿平滑
  15. 笔记本无线和有线的MAC地址修改
  16. 虚拟机将ip地址修改成静态的
  17. 【Pytorch】复现FCN for Left Ventricle(LV) segmentation记录
  18. 我看国内地理信息产业
  19. 【pytorch】Rosenbrock 函数的 梯度下降法 和 牛顿法 求解
  20. 东南大学计算机考研经验

热门文章

  1. Ultimate Retouch 3.7.59汉化版|影楼终极人像精修磨皮扩展支持2019
  2. 开关电源设计中电感的选择
  3. iOS 歌词解析(lrc, 非谓词, 仿QQ音乐, 仿卡拉ok模式)
  4. 安卓解决小米,魅族状态栏全白的问题
  5. GIS与地质灾害评价——坡度分析
  6. 新书推荐 | Flutter技术入门与实战(第2版)
  7. 354. 俄罗斯套娃信封问题(良心注释)
  8. 搜索引擎优化提示.对关键词的选择应该学会对它的取舍
  9. axios get请求下载后端文件流xlsx文件
  10. ExpressGridPack 21,PivotGrid 控件