文章目录

  • 1.概述
  • 2.Embedding
    • 2.1 nn.Linear
    • 2.2 nn.Embedding
  • 对比
  • 初始化第一层

1.概述

torch.nn.Embedding是用来将一个数字变成一个指定维度的向量的,比如数字1变成一个128维的向量,数字2变成另外一个128维的向量。不过,这128维的向量并不是永恒不变的,这些128维的向量是模型真正的输入(也就是模型的第1层)(数字1和2并不是,可以算作模型第0层),然后这128维的向量会参与模型训练并且得到更新,从而数字1会有一个更好的128维向量的表示。

显然,这非常像全连接层,所以很多人说,Embedding层是全连接层的特例。

2.Embedding

import numpy as np
import torch.nn as nn
import torch

比如我们有2个字,

vocab={"我":0,"你":1}

我们要把这两个字变成向量,有两种做法:

2.1 nn.Linear

vocab={"我":0,"你":1}
vocab_vec=torch.eye(2)#构造one-hot向量,所以需要用两个2维的向量表示这两个字。
print(vocab_vec)

tensor(
[[1., 0.],
[0., 1.]]
)

但是众所周知one-hot向量没有任何的语义信息,而且在这个one-hot中,空间庞大。我们需要一个低维稠密的向量来代替one-hot向量。很简单,一个线性层即可。

fc=nn.Linear(2,2)#表示一个字原来向量维度是2,现在还是线性变换成2.
fc(vocab_vec)


上面才是我们想要的这两个字的表示,然后将上述输入到模型的后续层中,进行训练,这样,由于线性层fc的参数会不断变化,所以上面这个数值fc(vocab_vec)当然也会随之而变化喽。

2.2 nn.Embedding

使用这个会更简单,更方便一些。我们不需要构造one-hot向量,我们开头说过了,Embedding层将一个数字直接转化为你想要的维度的向量,这是好事!比如你内存不够的时候,就不用像上面这种做法那样需要多存储一个one-hot的矩阵。

vocab={"我":0,"你":1}
embedding=torch.nn.Embedding(vocab_size=2,emb_size=2)
#vocab_size:表示一共有多少个字需要embedding,
#emb_size:表示我们希望一个字向量的维度是多少。

然后,我们想要得到我们那两个字(“我”,“你”)的向量,只需要将"我"和"你"的编号输入即可,而不需要那个one-hot向量,比如"我"的编号是0,"你"的编号是1:

me=torch.tensor([0],dtype=torch.int64)
you=torch.tensor([1],dtype=torch.int64)

然后将上述作为参数,传入embedding层。

print(embedding(me))
print(embedding(you))

tensor([[0.6188, 1.5322]], grad_fn=<EmbeddingBackward>)
tensor([[-0.8198, -0.9139]], grad_fn=<EmbeddingBackward>)

就得到了上述两个字的向量。当然了,我们也可以合并起来得到:

meyou=torch.tensor([0,1],dtype=torch.int64)
print(embedding(meyou))


同样,上述得到的向量输入到模型的后续中,训练后,Embedding层的参数会随之改变,从而我们得到了更好的字向量。

对比

可以看到,Embedding和Linear几乎是一样的,区别就在于:输入不同,一个是输入数字,后者是输入one-hot向量。习惯上,我们在模型的第一层使用的是Embedding,而不是Linear。模型的后续不会再使用Embedding,而是使用Linear。

初始化第一层

补充:上述我们在做定义的时候,里面的参数是初始化的。

fc=nn.Linear(2,2)
embedding=torch.nn.Embedding(vocab_size,emb_size)

有的时候,我们可能觉得其初始化得不好,想要按照我们的来指定,怎么办?

1.使用nn.Parameter直接进行赋值。

myemb=nn.Parameter(torch.rand(2,2))#(0,1)均匀分布的参数
print(fc.weight)#原来的
fc.weight=myemb#修改
print(embedding.weight)#原来的
embedding.weight=myemb#修改


我们再查看修改后的:

2.使用.data:

embedding.weight.data.uniform_(-1,1)#又改成(-1,1)的均匀分布。
embedding.weight#查看一下

3.nn.Embedding.from_pretrained

以上两种方法对于Embedding和Linear都适用。接下来的一个方法只适用于Embedding。

a=torch.tensor([[1,2],[2,2]],dtype=torch.float32)
embedding=nn.Embedding.from_pretrained(a)
embedding.weight

模型的第一层:详解torch.nn.Embedding和torch.nn.Linear相关推荐

  1. 隐马尔可夫模型之Baum-Welch算法详解

    隐马尔可夫模型之Baum-Welch算法详解 前言 在上篇博文中,我们学习了隐马尔可夫模型的概率计算问题和预测问题,但正当要准备理解学习问题时,发现学习问题中需要EM算法的相关知识,因此,上一周转而学 ...

  2. 一文读懂NLP之隐马尔科夫模型(HMM)详解加python实现

    一文读懂NLP之隐马尔科夫模型(HMM)详解加python实现 1 隐马尔科夫模型 1.1 HMM解决的问题 1.2 HMM模型的定义 1.2.1HMM的两个假设 1.2.2 HMM模型 1.3 HM ...

  3. Java内存模型(JMM)详解-可见性volatile

    这里写自定义目录标题 Java内存模型(JMM)详解-可见性 什么是JMM JMM存在的意义 为什么示例demo中不会打印 i 的值 如何解决可见性问题 **深入理解JMM内存模型** JAVA内存模 ...

  4. 从零开始学前端 - 7. CSS盒模型 margin和padding详解

    作者: 她不美却常驻我心 博客地址: https://blog.csdn.net/qq_39506551 微信公众号:老王的前端分享 每篇文章纯属个人经验观点,如有错误疏漏欢迎指正.转载请附带作者信息 ...

  5. 数据库系统模式(schema)和模型(model)详解

     数据库系统模式(schema)和模型(model)详解 数据(data)是描述事物的符号记录. 模型(Model)是现实世界的抽象. 数据模型(Data Model)是数据特征的抽象,是数据库管 ...

  6. 机器学习——时间序列ARIMA模型(一):差分法详解

    机器学习--时间序列ARIMA模型(一):差分法详解 一.所需数据的性质 平稳性 样本数据需随着时间序列而发生变化,且序列的均值和方差不发生明显变化. 预测出在未来的一段期间内数据顺着现有的" ...

  7. Diffusion 扩散模型(DDPM)详解及torch复现

    文章目录 torch复现 第1步:正向过程=噪声调度器 Step 2: 反向传播 = U-Net Step 3: 损失函数 采样 Training 我公众号文章目录综述: https://wanggu ...

  8. PointNet模型的Pytorch代码详解

    前言 关于PointNet模型的构成.原理.效果等等论文部分内容,我在之前一篇论文中写到过,可以参考这个链接:PointNet论文笔记    下边我就直接放一张网络组成图,并对代码进行解释,我以一种比 ...

  9. BilSTM 实体识别_NLP-入门实体命名识别(NER)+Bilstm-CRF模型原理Pytorch代码详解——最全攻略

    最近在系统地接触学习NER,但是发现这方面的小帖子还比较零散.所以我把学习的记录放出来给大家作参考,其中汇聚了很多其他博主的知识,在本文中也放出了他们的原链.希望能够以这篇文章为载体,帮助其他跟我一样 ...

最新文章

  1. 【分治】P1228 地毯填补问题(多联骨牌覆盖棋盘问题)(递归,分治)难度⭐⭐⭐
  2. javaCountDownLatch闭锁
  3. Python之IO编程
  4. matlab窗函数带通滤波器,Matlab结合窗函数法设计数字带通FIR滤波器
  5. 今天研究 Client本来是关联的Expression接口,笔记记录一下。
  6. linux core文件乱码,.net core在linux下图片中文乱码
  7. plsql创建中文表头_不安装oracle连接plsql,Oracle instantclient安装详解
  8. mybatis简明教程
  9. md5修改器v1.0
  10. 家长进课堂 计算机ppt,家长进课堂之中华传统美德 成品ppt 三井小学一10班出品.ppt...
  11. 用Python分析了我的微信好友,原来我身边都是这样的人……绝了
  12. 全国社会消费品零售总额ARIMA建模分析
  13. unity 实验演示 教程_Unity的演示团队– Unity最出色的视觉效果背后的创造者
  14. 熊猫烧香被恶搞,网友爆笑诗词句大集合
  15. 云文件共享服务器,云文件共享服务器
  16. 切换windows系统输入法的中英文,可以忽视是哪种打字法
  17. WINCE系统调用的本质
  18. 石榴算法1.0——打击买卖超链
  19. IBM带库加磁带操作
  20. html语言中i i,html元素 i 标签的使用方法及作用

热门文章

  1. ubuntu 16.04 多个python版本切换
  2. 全球首个可繁殖活体机器人问世:AI参与设计,已自我繁殖4代
  3. 旗帜鲜明地反对“码而优则仕”
  4. 深度解析工业软件:研究框架(140页)
  5. 115页Slides带你领略深度生成模型全貌(附PPT)
  6. 如何看待「TensorFlow就是一颗定时炸弹」的说法?
  7. 国内研究生不小心跟了一个水货导师是什么样的体验?
  8. 一高校公示拟聘用人员信息,多为大龄“双非”土博,好像也没那么卷……
  9. 一张清华大学教授工资单曝光!想象与现实天壤之别……
  10. 【目标检测基础积累】常用的评价指标