nn.Embedding使用
nn.Embedding是一种词嵌入的方式,跟one-hot相似但又不同,会生成低维稠密向量,但是初始是随机化的,需要根据模型训练时进行调节,若使用预训练词向量模型会比较好。
1. one-hot
one-hot是给定每个单词一个索引,然后根据输入文本是否含有这个单词来决定向量。
单词 | 索引 |
---|---|
we | 0 |
have | 1 |
are | 2 |
any | 3 |
all | 4 |
excellent | 5 |
people | 6 |
… | … |
给定“We are all excellent people”,生成one-hot向量[1,0,1,0,1,1,1,0,0,…]
2. nn.Embedding
nn.Embedding参数设置:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None)
- num_embeddings:字典的大小,如上表中字典大小是5000,就写5000。
- embedding_dim:嵌入向量的维度,表示将单词编码为多少维的向量。
- padding_idx:填充索引,意思是填充的索引,向量值默认为全0(可以自定义),相当于unknown,对未知的词编码为零向量。
其他参数并不常用,详情可以参考官网
Note, embedding只接受LongTensor类型的数据,且数据不能大于等于词典大小。
3. 应用
3.1 示例一
import torch
from torch import nn
embedding = nn.Embedding(10, 3) #设置字典大小为10,维度为3
input = torch.LongTensor([[0,2,0,5]]) #设置为LongTensor
vector = embedding(input)
print(vector)
3.2 示例2
import torch
from torch import nn
embedding = nn.Embedding(10, 3, padding_idx=2) #设置字典大小为10,维度为3
input = torch.LongTensor([[0,2,0,5]])
vector = embedding(input)
print(vector)
print("查看词典:",embedding.weight) #weight可以查看全部词的向量
通过结果可以看到,词向量跟字典中的向量一一对应。
3.3 示例三
import torch
from torch import nn
padding_idx = 2
embedding = nn.Embedding(10, 3, padding_idx=2) #设置字典大小为10,维度为3
input = torch.LongTensor([[0,2,0,5]])
vector = embedding(input)
with torch.no_grad(): embedding.weight[padding_idx] = torch.ones(3) #设置填充向量
print(vector)
print("查看词典:",embedding.weight) #weight可以查看全部词的向量
目前在自然语言中,主要使用词向量模型生成词向量,如word2vec,glove之类的静态词向量模型,BERT、RoBERT之类的动态词向量模型,或者使用超大模型bloom等,效果都比之前要好,具体使用那种模型根据具体情况而定。
nn.Embedding使用相关推荐
- Pytorch的默认初始化分布 nn.Embedding.weight初始化分布
一.nn.Embedding.weight初始化分布 nn.Embedding.weight随机初始化方式是标准正态分布 ,即均值$\mu=0$,方差$\sigma=1$的正态分布. 论据1--查看 ...
- pytorch nn.Embedding
pytorch nn.Embedding class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_n ...
- torch.nn.Embedding理解
Pytorch官网的解释是:一个保存了固定字典和大小的简单查找表.这个模块常用来保存词嵌入和用下标检索它们.模块的输入是一个下标的列表,输出是对应的词嵌入. torch.nn.Embedding(nu ...
- torch.nn.Embedding
在pytorch里面实现word embedding是通过一个函数来实现的:nn.Embedding 1 2 3 4 5 6 7 8 9 10 11 12 13 # -*- coding: utf-8 ...
- 深入理解padding_idx(nn.Embedding、nn.Embedding.from_pretrained)
这个参数出现在一些地方,例如: nn.Embedding.nn.Embedding.from_pretrained. import torch import torch.nn as nn import ...
- 模型的第一层:详解torch.nn.Embedding和torch.nn.Linear
文章目录 1.概述 2.Embedding 2.1 nn.Linear 2.2 nn.Embedding 对比 初始化第一层 1.概述 torch.nn.Embedding是用来将一个数字变成一个指定 ...
- pytorch 笔记: torch.nn.Embedding
pytorch中,针对词向量有一个专门的层nn.Embedding,用来实现词与词向量的映射. nn.Embedding具有一个权重,形状是(num_embeddings,embedding_dime ...
- pytorch torch.nn.Embedding
词嵌入矩阵,可以加载使用word2vector,glove API CLASS torch.nn.Embedding(num_embeddings: int, embedding_dim: int, ...
- torch.nn.Embedding()的固定化
问题 最近在运行模型时,结果不稳定,所以尝试修改随机初始化的参数,使参数是随机初始化的,但是每次都一样 发现是用了 self.embed_user = nn.Embedding(user_num, f ...
- 正态分布初始化 torch.nn.Embedding.weight()与torch.nn.init.normal()的验证对比
torch.nn.Embedding.weight(num_embeddings, embedding_dim) 随机初始化,生成标准正态分布N(0,1)N(0,1)N(0,1)的张量Tensor t ...
最新文章
- 【测试】ESP32天线信号强度比较,小龟小车A2天线esp32cam板载外置天线测试数据...
- 13.4.虚拟化工具--jmap详解
- 工控随笔_01_西门子_安装西门子软件提示重启解决方法。
- [UE4]自动旋转组件
- PHP中问号?和冒号: 的作用
- css 关闭按钮实现,CSS做的关闭按钮动效
- 给plt.axvline设置图例(label)
- php mysql 链表_浅谈PHP链表数据结构(单链表)
- python cmath模块_python中math模块常用的方法整理
- 4 计数器verilog与Systemverilog编码
- RFIC4463_F3CD
- UVA12169 Disgruntled Judge
- Proteus软件的安装与使用方法(超详细)
- java档案管理系统_基于JAVA的简单档案管理系统
- 2021考研数学二汤家凤接力题典1800【解答册】
- Python字符串逆序输出六种方法
- 0xl c语言中003是整形常量,整型常量
- 2021上海第34届创业连锁加盟展会
- 寄居蟹与海葵是一对合作互助的共栖伙伴。海葵是寄居蟹最称职的门卫。它用有毒的触角去蜇那些敢来靠近它们的所有动物,保护寄居蟹。 而寄居蟹则背着行动困难的海葵,四出觅食,有福同享。但并不是所有寄居蟹和海
- Windows10下安装Centos7系统及常见问题
热门文章
- Project ‘cv_bridge‘ specifies ‘/usr/include/opencv‘ as an include dir, which is not found的解决方法
- 兰海说成长|让孩子端正态度,你也许用错了方法
- 应用安全加上游戏盾,为您业务保驾护航
- git权威指南总结五:git克隆
- 无线接入点与无线路由器有什么区别?
- ELINK离线编程器版本说明
- 【 js中通过键盘上下左右移动图片】
- surfacei5用matlab,良心爆料微软Surface Pro 6怎么样呢?评测好不好?老司机指教诉说...
- 云服务器、虚拟主机、VPS有什么具体区别?
- 【通知】Mr.张小白