torch.nn.Embedding

在使用pytorch进行词嵌入使用torch.nn.Embedding()就可以做到

nn.Embedding在pytoch中的解释

class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2, scale_grad_by_freq=False, sparse=False)

模块的输入是一个下标的列表,输出是对应的词嵌入。

参数:

  • num_embeddings (int) - 嵌入字典的大小
  • embedding_dim (int) - 每个嵌入向量的大小
  • padding_idx (int, optional) - 如果提供的话,输出遇到此下标时用零填充
  • max_norm (float, optional) - 如果提供的话,会重新归一化词嵌入,使它们的范数小于提供的值
  • norm_type (float, optional) - 对于max_norm选项计算p范数时的p
  • scale_grad_by_freq (boolean, optional) - 如果提供的话,会根据字典中单词频率缩放梯度

变量:

  • weight (Tensor) -形状为(num_embeddings, embedding_dim)的模块中可学习的权值

形状:

  • 输入: LongTensor (N, W), N = mini-batch, W = 每个mini-batch中提取的下标数
  • 输出: (N, W, embedding_dim)

Embedding的使用

参数的初始化

import torch
import torch.nn as nn
vocab={'a':0,'b':1,'c':2}
#nn.Embedding(1000,50)表示有1000个词,嵌入到50维中
#对比若是one-hot,1000个词将会是1000维
embedd=nn.Embedding(3,5)

Embedd

以简单一维的向量为例

#输入:LongTensor (N, W), N = mini-batch, W = 每个mini-batch中提取的下标数
#输出: (N, W, embedding_dim)
a_idx=torch.LongTensor(vocab['a'])#输入类型为LongTensor  a_idx=0
#a_idx在embedding处理后
a=embedd(a_idx)
print(a)
#输出:
#tensor([[ 0.9256,  1.1268, -0.5181, -0.2635,  1.9282]],grad_fn=<EmbeddingBackward>)#

我们输入二维的 i n p u t input input
输入: L o n g T e n s o r ( n , w ) LongTensor (n, w) LongTensor(n,w), n n n为样本的数量, w w w为每个样本中词数

#两组,每组有三个词0 , 1 ,2
inputs=torch.LongTensor([[0,2,1,1],[1,2,1,0]])
outputs=embedd(inputs)
print(outputs)

输出: ( n , w , e m b e d d i n g − d i m ) (n, w, embedding-dim) (n,w,embeddingdim)

tensor([[[ 0.9256,  1.1268, -0.5181, -0.2635,  1.9282],[-0.1351,  0.7759,  1.4697,  1.2225,  0.0390],[ 1.2778, -1.8050, -0.7962, -1.5764,  0.0664],[ 1.2778, -1.8050, -0.7962, -1.5764,  0.0664]],[[ 1.2778, -1.8050, -0.7962, -1.5764,  0.0664],[-0.1351,  0.7759,  1.4697,  1.2225,  0.0390],[ 1.2778, -1.8050, -0.7962, -1.5764,  0.0664],[ 0.9256,  1.1268, -0.5181, -0.2635,  1.9282]]],grad_fn=<EmbeddingBackward>)

因为初始化时* embedd=nn.Embedding(3,5)*
所以是 0, 1, 2三个词嵌入
当超过时就会报错
当inputs中出现 3

inputs=torch.LongTensor([[0,2,1,3],[1,2,1,3]])
outputs=embedd(inputs)
print(outputs)
#RuntimeError: index out of range: Tried to access index 3 out of table with 2 rows.#

torch.nn.Embedding的使用相关推荐

  1. torch.nn.Embedding理解

    Pytorch官网的解释是:一个保存了固定字典和大小的简单查找表.这个模块常用来保存词嵌入和用下标检索它们.模块的输入是一个下标的列表,输出是对应的词嵌入. torch.nn.Embedding(nu ...

  2. 模型的第一层:详解torch.nn.Embedding和torch.nn.Linear

    文章目录 1.概述 2.Embedding 2.1 nn.Linear 2.2 nn.Embedding 对比 初始化第一层 1.概述 torch.nn.Embedding是用来将一个数字变成一个指定 ...

  3. pytorch torch.nn.Embedding

    词嵌入矩阵,可以加载使用word2vector,glove API CLASS torch.nn.Embedding(num_embeddings: int, embedding_dim: int, ...

  4. torch.nn.Embedding()的固定化

    问题 最近在运行模型时,结果不稳定,所以尝试修改随机初始化的参数,使参数是随机初始化的,但是每次都一样 发现是用了 self.embed_user = nn.Embedding(user_num, f ...

  5. 正态分布初始化 torch.nn.Embedding.weight()与torch.nn.init.normal()的验证对比

    torch.nn.Embedding.weight(num_embeddings, embedding_dim) 随机初始化,生成标准正态分布N(0,1)N(0,1)N(0,1)的张量Tensor t ...

  6. torch.nn.Embedding()中的padding_idx参数解读

    torch.nn.Embedding() Word Embedding 词嵌入,就是把一个词典,随机初始化映射为一个向量矩阵. 列如:有一组词典,有两个词"hello"和" ...

  7. embedding = torch.nn.Embedding(10, 3)

    embedding = torch.nn.Embedding(10, 3) 通过 word embedding,就可以将自然语言所表示的单词或短语转换为计算机能够理解的由实数构成的向量或矩阵形式(比如 ...

  8. 【Pytorch基础教程28】浅谈torch.nn.embedding

    学习总结 文章目录 学习总结 一.nn.Embedding 二.代码栗子 2.1 通过embedding降维 2.2 RNN中用embedding改进 2.3 deepFM模型中embedding R ...

  9. torch.nn.Embedding

    在pytorch里面实现word embedding是通过一个函数来实现的:nn.Embedding 1 2 3 4 5 6 7 8 9 10 11 12 13 # -*- coding: utf-8 ...

最新文章

  1. 网易2017校招编程:计算糖果
  2. linux 内核中一个全局变量引发的性能问题
  3. BOOST使用 proto 转换进行任意类型操作的简单示例
  4. JQuery判断数组中是否包含某个元素$.inArray(js, arr);
  5. 地理必修一三大类岩石_高一地理必修一知识点总结归纳
  6. Linux shell 进制转换
  7. tensor数据类型转换_PyTorch的tensor数据类型及其相关转换
  8. Google十大真理带给中国网络公司的启示
  9. apscheduler Trigger
  10. Vue自定义组件封装及使用Excel
  11. 当当网Python图书数据分析
  12. 打印机服务器不支持1020,HP1020打印机驱动安装不上的解决办法
  13. 关于机械臂仿真的几款软件简介
  14. 软件测试的术语SRS,HLD,LLD,BD,FD,DD意义
  15. pyecharts学习笔记
  16. 莫队算法+带修莫队+回滚莫队
  17. 基于51单片机的智能汽车雨刮器的程序设计proteus仿真
  18. win10蓝屏代码_?联想电脑蓝屏的解决方法教程
  19. 一起来找茬:下面这段代码是让计算机在屏幕上输出“hi”。其中有三个错误,快来改正吧
  20. Node.js之 express写后端接口

热门文章

  1. Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic
  2. 看《神探夏洛克》经典台词
  3. 什么是快速连接器?如何选择合适的快速密封接头?
  4. 创宇技能表_知道创宇研发技能表v3.0 来了!
  5. 自制英语翻译(调用有道翻译接口)
  6. 取得system权限
  7. 跳板机的工作原理和简单的跳板机实现
  8. 让欺诈风险无处遁形的计算机视觉
  9. 基于Android系统的高精度定位SDK方案
  10. 安卓开发设置系统文件夹下图片为控件背景