谷歌16年出的论文《Deep Neural Networks for Youtube Recommendation》中提到文章采用了负采样的思想来进行extreme multiclass分类任务
我的tensorflow实现已上传CSDN资源https://download.csdn.net/download/weixin_41864878/11107472
Tensorflow提供了两种负采样,分别是NCE loss 和Sampled softmax loss,两者最大的区别就是针对的任务不同,代码实现上两者也只有最后的loss函数不同,两者用的采样函数及算logits方法都相同
NCE loss

  sampled_losses = sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, name="sampled_losses")

Sampled softmax loss

  labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)

显然,NCE能针对多标签,在NLP任务上应用比较多, 而后者针对单个标签分类任务
实验中我采用了后者
sampled softmax 原文:On Using Very Large Target Vocabulary for Neural Machine Translation
在这里强调TF中的采样方式,有4种:

function 采样方式
1 log_uniform_candidate_sampler 只能用于标签顺序和出现频率成反比的情况,因此需要清洗数据,重新进行标签映射
2 learned_unigram_candidate_sampler 适用于不知道标签分布的任何情况
3 uniform_candidate_sampler 均匀采样
4 fixed_unigram_candidate_sampler 允许用户指定概率

代码中默认的是第一种,如果要更改需要自己重写
根据官方的指导代码,我构造了如下代码:
(具体参数请移步官方代码~)

def net_factory():input = inputnet = net() #网络结构,注意这里是softmax输出前的embeddingoutput = tf.nn.softmax()##weights = tf.Variable(tf.truncated_normal([num_class, embedding_size],stddev=1.0 / math.sqrt(embedding_size)))biases = tf.Variable(tf.zeros([num_class]))loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights, biases, train_labels, net,num_sampled, num_class))

首先出现的第一个错误是device的错误,原因是这个计算不能在GPU上算,必须在CPU中计算(反正在GPU跑不起来,如果理解有误欢迎指正)
因此在session的config中写入:

config = tf.ConfigProto(allow_soft_placement = True)

或者网络定义之中,在loss计算之前加上with tf.device(’/cpu:0’)
但是出现了不能在tf.train.GradientDescentOptimizer(learning_rate).compute_gradients(loss)中进行计算,报错是梯度名字为None,但是能用tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)计算,但是由于我还有正则损失需要计算,必须解决这个问题,于是在查阅资料之后,修改代码如下:
(里面self我复制过来懒得修改了。。)

    self.weights = tf.get_variable('soft_weight',[self.item_classes, self.embedding_size], initializer=tf.variance_scaling_initializer())biases = tf.get_variable('soft_biases', initializer=tf.zeros([self.item_classes]), trainable=False)loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights, biases, train_labels, net,num_sampled, num_class))

于是就可以愉快的和别的loss相加计算梯度了,我看到github上开源项目基本上都是第一种方式写的,如果程序中是单一loss基本就用通用形式就可以了

负样本个数的影响:目前来看对loss计算来说基本没有影响,设置为100和10的时候loss值有显著变化,个数为10时loss下降较快

参考:
github代码段:https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/nn_impl.py
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/candidate_sampling_ops.py
官方指导手册demo:https://docs.pythontab.com/tensorflow/tutorials/word2vec/#_3
http://www.algorithmdog.com/tf-candidate-sampling
https://blog.csdn.net/u010223750/article/details/69948463
https://blog.csdn.net/wuzqChom/article/details/77073246

Tensorflow的负采样函数Sampled softmax loss踩坑之旅相关推荐

  1. Tensorflow之负采样函数Sampled softmax loss

    Tensorflow之负采样函数Sampled softmax loss 谷歌16年出的论文<Deep Neural Networks for Youtube Recommendation> ...

  2. Tensorflow的负采样函数Sampled softmax loss学习笔记

    最近阅读了YouTube的推荐系统论文,在代码实现中用到的负采样方法我比较疑惑,于是查了大量资料,总算能够读懂关于负采样的一些皮毛. 本文主要针对tf.nn.sampled_softmax_loss这 ...

  3. 【机器学习】sampled softmax loss

    目录 1.前置知识softmax loss 2.sampled softmax 1.1.问题引入 1.2.如何通俗理解sampled softmax机制? 3.sampled softmax loss ...

  4. 一文讲懂召回中的 NCE NEG sampled softmax loss

    深度学习中与分类相关的问题都会涉及到softmax的计算.当目标类别较少时,直接用标准的softmax公式进行计算没问题,当目标类别特别多时,则需采用估算近似的方法简化softmax中归一化的计算. ...

  5. [Tensorflow] Ubuntu下NVIDIA Driver+CUDA+cuDNN 安装踩坑总结

    最近安装了3台workstation, 显卡分布是Quadro P2000, Quadro K220和Quadro 2000.其中第一台工作站是去年新入的,另外两台都是3-5年历史的旧机器了. 第一台 ...

  6. TensorFlow Object Detection API 超详细教程和踩坑过程

    安装Anacond:https://blog.csdn.net/ITLearnHall/article/details/81708148 安装Pycharm:https://blog.csdn.net ...

  7. Windows10下Tensorflow启用GPU加速,显卡GTX1060,踩坑记录

    因为需要用到tensorflow学习深度学习,所以有N卡就想开启GPU加速,结果各种坑 1.安装VS和Python环境  (不用VS的可以不安装,使用其他工具也是一样的) 这里使用VS2019作为开发 ...

  8. negative sampling负采样和nce loss

    negative sampling负采样和nce loss 一.Noise contrastive estimation(NCE) 语言模型中,在最后一层往往需要:根据上下文c,在整个语料库V中预测某 ...

  9. NLP-词向量(Word Embedding)-2013:Word2vec模型(CBOW、Skip-Gram)【对NNLM的简化】【层次Softmax、负采样、重采样】【静态表示;无法解决一词多义】

    一.文本的表示方法 (Representation) 文本是一种非结构化的数据信息,是不可以直接被计算的.因为文本不能够直接被模型计算,所以需要将其转化为向量. 文本表示的作用就是将这些非结构化的信息 ...

最新文章

  1. 网易云音乐:基于分布式图学习的推荐系统优化之路
  2. how to setup a Kubernetes cluster on GCP
  3. OpenCL的安装与配置
  4. SQL Server中的TempDB管理——TempDB基本知识(为什么需要版本存储区)
  5. 还不会用typedef?C语言typedef的详细用法总结,一篇解决你的困惑。(学习笔记2--typedef设置别名)
  6. c web mysql数据库_C语言操作MySQL数据库
  7. 大数据与BI的区别在于哪里
  8. 腾跃英语计算机学院微信公众号,英语四级报名_微信还能这么玩:Geek大学生搭建英语课堂互动系统_沪江英语...
  9. kafka和flink的动态扩容
  10. web前端需要学MySQL吗_HTML是web前端工程师必须要学的
  11. 【欧拉计划第 1 题】3 或 5 的倍数 Multiples of 3 or 5
  12. unity地图路径编辑器
  13. 2016书单总结--看透SpringMvc源代码分析与实践-概述
  14. 15、Gantt 修改样式部分
  15. xmanager无法连接Linux服务器,解决xmanager连接linux出错问题
  16. Flutter系列之在 macOS 上安装和配置 Flutter 开发环境
  17. 32位64位Office 2010 beta 简体中文版下载
  18. UIWebView使用app内自定义字体
  19. 朱嘉明:产业周期、科技周期与金融周期的失衡
  20. CSS快速学习(2021.2.7-15)

热门文章

  1. 护网行动(防守方)linux服务器通用安全加固指南(1)
  2. python 爬取boss直聘招聘信息实现
  3. 阿里钉钉亮相重庆智博会,七大资本逾10亿资金赋能钉钉生态
  4. Python多线程入门指南
  5. mysql truncate delete 释放磁盘空间
  6. FPGA:计算滑动求和----信号检测计算信号功率
  7. android官方的wifi direct demo.....,Android WIFI Direct开发实例演示
  8. 超火的人生重开模拟器小程序源码
  9. 揭秘|智慧树千万师生“停课不停学”背后的URTC技术实践之路
  10. html+css美图卡片小练习(未用浮动)