转自:http://www.jianshu.com/p/fab82fa53e16

1.tf中的nce_loss的API

defnce_loss(weights, biases, inputs, labels, num_sampled, num_classes,num_true=1,sampled_values=None,remove_accidental_hits=False,partition_strategy="mod",name="nce_loss")

假设nce_loss之前的输入数据是K维的,一共有N个类,那么

  • weight.shape = (N, K)
  • bias.shape = (N)
  • inputs.shape = (batch_size, K)
  • labels.shape = (batch_size, num_true)
  • num_true : 实际的正样本个数
  • num_sampled: 采样出多少个负样本
  • num_classes = N
  • sampled_values: 采样出的负样本,如果是None,就会用不同的sampler去采样。待会儿说sampler是什么。
  • remove_accidental_hits: 如果采样时不小心采样到的负样本刚好是正样本,要不要干掉
  • partition_strategy:对weights进行embedding_lookup时并行查表时的策略。TF的embeding_lookup是在CPU里实现的,这里需要考虑多线程查表时的锁的问题。

nce_loss的实现逻辑如下:

  • _compute_sampled_logits: 通过这个函数计算出正样本和采样出的负样本对应的output和label
  • sigmoid_cross_entropy_with_logits: 通过 sigmoid cross entropy来计算output和label的loss,从而进行反向传播。这个函数把最后的问题转化为了num_sampled+num_real个两类分类问题,然后每个分类问题用了交叉熵的损伤函数,也就是logistic regression常用的损失函数。TF里还提供了一个softmax_cross_entropy_with_logits的函数,和这个有所区别。

2.tf中word2vec实现

loss =tf.reduce_mean(tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,num_sampled, vocabulary_size))

它这里并没有传sampled_values,那么它的负样本是怎么得到的呢?继续看nce_loss的实现,可以看到里面处理sampled_values=None的代码如下:

if sampled_values isNone:sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(true_classes=labels,num_true=num_true,num_sampled=num_sampled,unique=True,range_max=num_classes)

所以,默认情况下,他会用log_uniform_candidate_sampler去采样。那么log_uniform_candidate_sampler是怎么采样的呢?他的实现在这里:

  • 他会在[0, range_max)中采样出一个整数k
  • P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1)

可以看到,k越大,被采样到的概率越小。那么在TF的word2vec里,类别的编号有什么含义吗?看下面的代码:

defbuild_dataset(words):count= [['UNK', -1]]count.extend(collections.Counter(words).most_common(vocabulary_size- 1))dictionary=dict()for word, _ incount:dictionary[word]=len(dictionary)data=list()unk_count=0for word inwords:if word indictionary:index=dictionary[word]else:index= 0  #dictionary['UNK']unk_count += 1data.append(index)count[0][1] =unk_countreverse_dictionary=dict(zip(dictionary.values(), dictionary.keys()))return data, count, dictionary, reverse_dictionary

可以看到,TF的word2vec实现里,词频越大,词的类别编号也就越小。因此,在TF的word2vec里,负采样的过程其实就是优先采词频高的词作为负样本。

在提出负采样的原始论文中, 包括word2vec的原始C++实现中。是按照热门度的0.75次方采样的,这个和TF的实现有所区别。但大概的意思差不多,就是越热门,越有可能成为负样本。

转载于:https://www.cnblogs.com/BlueBlueSea/p/10615766.html

Tf中的NCE-loss实现学习【转载】相关推荐

  1. 一文讲懂召回中的 NCE NEG sampled softmax loss

    深度学习中与分类相关的问题都会涉及到softmax的计算.当目标类别较少时,直接用标准的softmax公式进行计算没问题,当目标类别特别多时,则需采用估算近似的方法简化softmax中归一化的计算. ...

  2. Debug深度学习中的NAN Loss

    深度学习中遇到NAN loss 什么都不改,重新训练一下,有时也能解决问题 学习率减小 检查输入数据(x和y),如果是正常突然变为NAN,有可能是学习率策略导致,也可能是脏数据导致 If using ...

  3. 【对比学习】CUT模型论文解读与NCE loss代码解析

    标题:Contrastive Learning for Unpaired Image-to-Image Translation(基于对比学习的非配对图像转换) 作者:Taesung Park, Ale ...

  4. 工作中如何做好技术积累『转载-保持学习的空杯心态』

    引言 古人云:"活到老,学到老."互联网算是最辛苦的行业之一,"加班"对工程师来说已是"家常便饭",同时互联网技术又日新月异,很多工程师都疲 ...

  5. Candidate sampling:NCE loss和negative sample

    在工作中用到了类似于negative sample的方法,才发现我其实并不了解candidate sampling.于是看了一些相关资料,在此简单总结一些相关内容. 主要内容来自tensorflow的 ...

  6. Tensorflow中的多层感知器学习

    Tensorflow中的多层感知器学习 在这篇文章中,我们将了解多层感知器的概念和它在Python中使用TensorFlow库的实现. 多层感知 多层感知也被称为MLP.它是完全连接的密集层,可以将任 ...

  7. negative sampling负采样和nce loss

    negative sampling负采样和nce loss 一.Noise contrastive estimation(NCE) 语言模型中,在最后一层往往需要:根据上下文c,在整个语料库V中预测某 ...

  8. 从NCE loss到InfoNCE loss

    关于NCE loss:知乎上的一些介绍的文字 Noise Contrastive Estimation 学习 - 知乎 github上的介绍文字:Lei Mao's Log Book – Noise ...

  9. 神经网络中,设计loss function有哪些技巧?

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:视学算法 神经网络中,设计loss function有哪 ...

最新文章

  1. 拒绝躺平,Redis选择实现了自己的VM
  2. HDU7059-Counting Stars 线段树 (区间加最低位置,区间减最高位)
  3. 谷歌自动驾驶之父疯狂打Call, 无人车连续5小时不接管,又快又稳
  4. 查找表包含的页和页所在的表
  5. Centos设置静态IP及修改Centos配置文件
  6. C++基本序列式容器 vector (一)
  7. 算法竞赛入门经典(第二版) | 程序3-10 生成元 (UVa1584,Circular Sequence)
  8. rxjs pipe和filter组合的一个实际例子的单步调试
  9. Linux 命令 ——less命令
  10. 干货 · UI设计|APP引导页面可临摹素材
  11. 投简历没回音?你没写到点子上,HR当然不看
  12. 对check list理解
  13. (1)Zabbix2.4.5 安装配置
  14. 爬取世界各国历年的GDP数据
  15. 共轭梯度法及其matlab程序
  16. 宝峰c1对讲机写频软件_宝峰888s写频软件
  17. 废旧 Android 手机如何改造成 Linux 服务器
  18. dlp技术(dlp技术和单片lcd的区别)
  19. 增加客流量的方法_如何增加博客流量-简单的方法(27条可靠的技巧)
  20. android控件属性大全

热门文章

  1. java 实现打印条形码_激光打印机与条码打印机打印不干胶标签哪个好?
  2. php 管道,PHP 进程间通信---管道篇
  3. 免安装mysql配置图解_mysql免安装版配置步骤详解分享
  4. python 命令行运行 多进程_Python初学——多进程Multiprocessing
  5. 2013-2017蓝桥杯省赛C++A组真题总结(题型及解法)
  6. Unity3D基础33:物理射线
  7. 树链剖分(bzoj 1036: [ZJOI2008]树的统计Count)
  8. 吴恩达神经网络和深度学习-学习笔记-43-Bounding box 预测 + YOLO算法
  9. k8s优先级priority的使用
  10. js中DOM, DOCUMENT, BOM, WINDOW 区别