Tf中的NCE-loss实现学习【转载】
转自:http://www.jianshu.com/p/fab82fa53e16
1.tf中的nce_loss的API
defnce_loss(weights, biases, inputs, labels, num_sampled, num_classes,num_true=1,sampled_values=None,remove_accidental_hits=False,partition_strategy="mod",name="nce_loss")
假设nce_loss之前的输入数据是K维的,一共有N个类,那么
- weight.shape = (N, K)
- bias.shape = (N)
- inputs.shape = (batch_size, K)
- labels.shape = (batch_size, num_true)
- num_true : 实际的正样本个数
- num_sampled: 采样出多少个负样本
- num_classes = N
- sampled_values: 采样出的负样本,如果是None,就会用不同的sampler去采样。待会儿说sampler是什么。
- remove_accidental_hits: 如果采样时不小心采样到的负样本刚好是正样本,要不要干掉
- partition_strategy:对weights进行embedding_lookup时并行查表时的策略。TF的embeding_lookup是在CPU里实现的,这里需要考虑多线程查表时的锁的问题。
nce_loss的实现逻辑如下:
- _compute_sampled_logits: 通过这个函数计算出正样本和采样出的负样本对应的output和label
- sigmoid_cross_entropy_with_logits: 通过 sigmoid cross entropy来计算output和label的loss,从而进行反向传播。这个函数把最后的问题转化为了num_sampled+num_real个两类分类问题,然后每个分类问题用了交叉熵的损伤函数,也就是logistic regression常用的损失函数。TF里还提供了一个softmax_cross_entropy_with_logits的函数,和这个有所区别。
2.tf中word2vec实现
loss =tf.reduce_mean(tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,num_sampled, vocabulary_size))
它这里并没有传sampled_values,那么它的负样本是怎么得到的呢?继续看nce_loss的实现,可以看到里面处理sampled_values=None的代码如下:
if sampled_values isNone:sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(true_classes=labels,num_true=num_true,num_sampled=num_sampled,unique=True,range_max=num_classes)
所以,默认情况下,他会用log_uniform_candidate_sampler去采样。那么log_uniform_candidate_sampler是怎么采样的呢?他的实现在这里:
- 他会在[0, range_max)中采样出一个整数k
- P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1)
可以看到,k越大,被采样到的概率越小。那么在TF的word2vec里,类别的编号有什么含义吗?看下面的代码:
defbuild_dataset(words):count= [['UNK', -1]]count.extend(collections.Counter(words).most_common(vocabulary_size- 1))dictionary=dict()for word, _ incount:dictionary[word]=len(dictionary)data=list()unk_count=0for word inwords:if word indictionary:index=dictionary[word]else:index= 0 #dictionary['UNK']unk_count += 1data.append(index)count[0][1] =unk_countreverse_dictionary=dict(zip(dictionary.values(), dictionary.keys()))return data, count, dictionary, reverse_dictionary
可以看到,TF的word2vec实现里,词频越大,词的类别编号也就越小。因此,在TF的word2vec里,负采样的过程其实就是优先采词频高的词作为负样本。
在提出负采样的原始论文中, 包括word2vec的原始C++实现中。是按照热门度的0.75次方采样的,这个和TF的实现有所区别。但大概的意思差不多,就是越热门,越有可能成为负样本。
转载于:https://www.cnblogs.com/BlueBlueSea/p/10615766.html
Tf中的NCE-loss实现学习【转载】相关推荐
- 一文讲懂召回中的 NCE NEG sampled softmax loss
深度学习中与分类相关的问题都会涉及到softmax的计算.当目标类别较少时,直接用标准的softmax公式进行计算没问题,当目标类别特别多时,则需采用估算近似的方法简化softmax中归一化的计算. ...
- Debug深度学习中的NAN Loss
深度学习中遇到NAN loss 什么都不改,重新训练一下,有时也能解决问题 学习率减小 检查输入数据(x和y),如果是正常突然变为NAN,有可能是学习率策略导致,也可能是脏数据导致 If using ...
- 【对比学习】CUT模型论文解读与NCE loss代码解析
标题:Contrastive Learning for Unpaired Image-to-Image Translation(基于对比学习的非配对图像转换) 作者:Taesung Park, Ale ...
- 工作中如何做好技术积累『转载-保持学习的空杯心态』
引言 古人云:"活到老,学到老."互联网算是最辛苦的行业之一,"加班"对工程师来说已是"家常便饭",同时互联网技术又日新月异,很多工程师都疲 ...
- Candidate sampling:NCE loss和negative sample
在工作中用到了类似于negative sample的方法,才发现我其实并不了解candidate sampling.于是看了一些相关资料,在此简单总结一些相关内容. 主要内容来自tensorflow的 ...
- Tensorflow中的多层感知器学习
Tensorflow中的多层感知器学习 在这篇文章中,我们将了解多层感知器的概念和它在Python中使用TensorFlow库的实现. 多层感知 多层感知也被称为MLP.它是完全连接的密集层,可以将任 ...
- negative sampling负采样和nce loss
negative sampling负采样和nce loss 一.Noise contrastive estimation(NCE) 语言模型中,在最后一层往往需要:根据上下文c,在整个语料库V中预测某 ...
- 从NCE loss到InfoNCE loss
关于NCE loss:知乎上的一些介绍的文字 Noise Contrastive Estimation 学习 - 知乎 github上的介绍文字:Lei Mao's Log Book – Noise ...
- 神经网络中,设计loss function有哪些技巧?
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:视学算法 神经网络中,设计loss function有哪 ...
最新文章
- 拒绝躺平,Redis选择实现了自己的VM
- HDU7059-Counting Stars 线段树 (区间加最低位置,区间减最高位)
- 谷歌自动驾驶之父疯狂打Call, 无人车连续5小时不接管,又快又稳
- 查找表包含的页和页所在的表
- Centos设置静态IP及修改Centos配置文件
- C++基本序列式容器 vector (一)
- 算法竞赛入门经典(第二版) | 程序3-10 生成元 (UVa1584,Circular Sequence)
- rxjs pipe和filter组合的一个实际例子的单步调试
- Linux 命令 ——less命令
- 干货 · UI设计|APP引导页面可临摹素材
- 投简历没回音?你没写到点子上,HR当然不看
- 对check list理解
- (1)Zabbix2.4.5 安装配置
- 爬取世界各国历年的GDP数据
- 共轭梯度法及其matlab程序
- 宝峰c1对讲机写频软件_宝峰888s写频软件
- 废旧 Android 手机如何改造成 Linux 服务器
- dlp技术(dlp技术和单片lcd的区别)
- 增加客流量的方法_如何增加博客流量-简单的方法(27条可靠的技巧)
- android控件属性大全
热门文章
- java 实现打印条形码_激光打印机与条码打印机打印不干胶标签哪个好?
- php 管道,PHP 进程间通信---管道篇
- 免安装mysql配置图解_mysql免安装版配置步骤详解分享
- python 命令行运行 多进程_Python初学——多进程Multiprocessing
- 2013-2017蓝桥杯省赛C++A组真题总结(题型及解法)
- Unity3D基础33:物理射线
- 树链剖分(bzoj 1036: [ZJOI2008]树的统计Count)
- 吴恩达神经网络和深度学习-学习笔记-43-Bounding box 预测 + YOLO算法
- k8s优先级priority的使用
- js中DOM, DOCUMENT, BOM, WINDOW 区别