use_one_hot_embeddings解析

为什么使用onehot或者tf.gather方法

BERT源码中构造模型时关于embedding的参数有一个为use_one_hot_embeddings
这个参数值为Boolean类型的值,且默认为False
具体的深层代码如下:

def embedding_lookup(input_ids,vocab_size,embedding_size=128,initializer_range=0.02,word_embedding_name="word_embeddings",use_one_hot_embeddings=False):"""Looks up words embeddings for id tensor.Args:input_ids: int32 Tensor of shape [batch_size, seq_length] containing word ids.vocab_size: int. Size of the embedding vocabulary.embedding_size: int. Width of the word embeddings.initializer_range: float. Embedding initialization range.word_embedding_name: string. Name of the embedding table.use_one_hot_embeddings: bool. If True, use one-hot method for word embeddings. If False, use `tf.gather()`.Returns:float Tensor of shape [batch_size, seq_length, embedding_size]."""# This function assumes that the input is of shape [batch_size, seq_length, num_inputs].## If the input is a 2D tensor of shape [batch_size, seq_length], we reshape to [batch_size, seq_length, 1].if input_ids.shape.ndims == 2:input_ids = tf.expand_dims(input_ids, axis=[-1])embedding_table = tf.get_variable(name=word_embedding_name,shape=[vocab_size, embedding_size],initializer=create_initializer(initializer_range))flat_input_ids = tf.reshape(input_ids, [-1])if use_one_hot_embeddings:one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)output = tf.matmul(one_hot_input_ids, embedding_table)else:output = tf.gather(embedding_table, flat_input_ids)input_shape = get_shape_list(input_ids)output = tf.reshape(output,input_shape[0:-1] + [input_shape[-1] * embedding_size])return (output, embedding_table)

use_one_hot_embeddings参数设置为True时,则会在生成embedding时先生成一个对应的onehot张量,并用onehot 张量与embedding table相乘最终获得对应的embedding张量。

而当use_one_hot_embeddings参数设置为False时,则会直接利用tf.gather()方法在embedding table中将对应的embedding提取出来,组合成为对应的embedding张量。

其实两种方法殊途同归,为何要采用onehot的形式呢?

猜想的原因就在于使用TPU加速时,矩阵运算相较于切片运算更加快速,所以采用了onehot的方法。

BERT源码embedding_lookup解析相关推荐

  1. iOS开发之Masonry框架源码深度解析

    Masonry是iOS在控件布局中经常使用的一个轻量级框架,Masonry让NSLayoutConstraint使用起来更为简洁.Masonry简化了NSLayoutConstraint的使用方式,让 ...

  2. 【TarsosDSP】TarsosDSP 简介 ( TarsosDSP 功能 | 相关链接 | 源码和相关资源收集 | TarsosDSP 示例应用 | TarsosDSP 源码路径解析 )

    文章目录 I . TarsosDSP 函数库简介 II . TarsosDSP 功能 III . TarsosDSP 相关资源链接 ( 官方资料 ) IV . TarsosDSP 源码和相关资源收集 ...

  3. 机器学习算法源码全解析(三)-范数规则化之核范数与规则项参数选择

    前言 参见上一篇博文,我们聊到了L0,L1和L2范数,这篇我们絮叨絮叨下核范数和规则项参数选择.知识有限,以下都是我一些浅显的看法,如果理解存在错误,希望大家不吝指正.谢谢. 机器学习算法源码全解析( ...

  4. 【NLP】NLP实战篇之bert源码阅读(run_classifier)

    本文主要会阅读bert源码 (https://github.com/google-research/bert )中run_classifier.py文件,已完成modeling.py.optimiza ...

  5. BERT源码分析(PART III)

    写在前面 继续之前没有介绍完的 Pre-training 部分,在上一篇中(BERT源码分析(PART II))我们已经完成了对输入数据的处理,接下来看看 BERT 是怎么完成「Masked LM」和 ...

  6. 从源码角度解析Android中APK安装过程

    从源码角度解析Android中APK的安装过程 1. Android中APK简介 Android应用Apk的安装有如下四种方式: 1.1 系统应用安装 没有安装界面,在开机时自动完成 1.2 网络下载 ...

  7. Huggingface BERT源码详解:应用模型与训练优化

    ©PaperWeekly 原创 · 作者|李泺秋 学校|浙江大学硕士生 研究方向|自然语言处理.知识图谱 接上篇,记录一下对 HuggingFace 开源的 Transformers 项目代码的理解. ...

  8. dubbo源码深度解析_Spring源码深度解析:手把手教你搭建Spring开发环境

    Spring环境搭建流程,如果是第一次接触spring源码的环境搭建,确实还是比较麻烦的. 作者使用的编译器为目前流行的lntelliJ IDEA,版本为2018旗舰版.Eclipse用户还需要自己揣 ...

  9. Jdk1.8 JUC源码增量解析(1)-atomic-Striped64

    转载自  Jdk1.8 JUC源码增量解析(1)-atomic-Striped64 功能简介: Striped64是jdk1.8提供的用于支持如Long累加器,Double累加器这样机制的基础类. S ...

  10. Jdk1.8 JUC源码增量解析(2)-atomic-LongAdder和LongAccumulator

    转载自 Jdk1.8 JUC源码增量解析(2)-atomic-LongAdder和LongAccumulator 功能简介: LongAdder是jdk1.8提供的累加器,基于Striped64实现. ...

最新文章

  1. 算法面试的理想与现实
  2. Android之用adb命令快速获取手机IP方法总结
  3. 浅谈多线程——NSThread
  4. 并发编程(十六)——java7 深入并发包 ConcurrentHashMap 源码解析
  5. “反应快”的程序猿更优秀吗?
  6. 2017.4.20 比例简化 思考记录
  7. c语言1076素数,大学C语言考试题库(答案)-20210412093908.docx-原创力文档
  8. dts数据库迁移工具_5分钟学会如何玩转云数据库组件(迁移,审计,订阅)
  9. Bootstrap validation
  10. html 图片滑动验证码,html+jQuery实现拖动滑块图片拼图验证码插件【移动端适用】...
  11. 信息系统项目管理师自学笔记(二十二)—— 网络应用与管理
  12. 椰子树和平等 文:王小波
  13. Virtual Box 网络静态IP配置
  14. c++输入10个数/输入n个数,求其平均值
  15. c语言 do while 素数,c语言题目:用while语句求2000以内所有质数(素数)
  16. ST_Intersects
  17. 什么是CISP-PTE证书?考什么?
  18. 使用Arthas排查问题
  19. Docker容器已正式支持苹果M1Mac电脑
  20. 卫星伪距定位matlab,GPS卫星运动及定位matlab仿真.doc

热门文章

  1. Shell /dev/null 文件的含义
  2. CSDN日报20170602 ——《程序员、技术主管和架构师》
  3. 关于计算机的好处的英语作文,关于电脑好处的英语作文
  4. 我的世界服务器显示伤害指令,我的世界指令代码大全
  5. Ubuntu20.04 虚拟机 联网
  6. 吴恩达机器学习系列内容汇总
  7. iOS 系统分享功能
  8. axure能做剪切蒙版吗_二手车销售好做吗?没经验能做二手车销售吗?
  9. 量化金融笔记1-股票量化基础
  10. js 实现当有省略号时,显示title,无省略号不显示title