一、nn.Embedding.weight初始化分布

nn.Embedding.weight随机初始化方式是标准正态分布  ,即均值$\mu=0$,方差$\sigma=1$的正态分布。

论据1——查看源代码

## class Embedding具体实现(在此只展示部分代码)
import torch
from torch.nn.parameter import Parameterfrom .module import Module
from .. import functional as Fclass Embedding(Module):def __init__(self, num_embeddings, embedding_dim, padding_idx=None,max_norm=None, norm_type=2, scale_grad_by_freq=False,sparse=False, _weight=None):if _weight is None:self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))self.reset_parameters()else:assert list(_weight.shape) == [num_embeddings, embedding_dim], \'Shape of weight does not match num_embeddings and embedding_dim'self.weight = Parameter(_weight)def reset_parameters(self):self.weight.data.normal_(0, 1)if self.padding_idx is not None:self.weight.data[self.padding_idx].fill_(0)

Embedding这个类有个属性weight,它是torch.nn.parameter.Parameter类型的,作用就是存储真正的word embeddings。如果不给weight赋值,Embedding类会自动给他初始化,看上述代码第6~8行,如果属性weight没有手动赋值,则会定义一个torch.nn.parameter.Parameter对象,然后对该对象进行reset_parameters(),看第21行,对self.weight先转为Tensor在对其进行normal_(0, 1)(调整为$N(0, 1)$正态分布)。所以nn.Embeddig.weight默认初始化方式就是N(0, 1)分布,即均值$\mu=0$,方差$\sigma=1$的标准正态分布。

论据2——简单验证nn.Embeddig.weight的分布

下面将做的是验证nn.Embeddig.weight某一行词向量的均值和方差,以便验证是否为标准正态分布。
注意:验证一行数字的均值为0,方差为1,显然不能说明该分布就是标准正态分布,只能是其必要条件,而不是充分条件,要想真正检测这行数字是不是正态分布,在概率论上有专门的较为复杂的方法,请查看概率论之假设检验。

import torch.nn as nn# dim越大,均值、方差越接近0和1
dim = 800000
# 定义了一个(5, dim)的二维embdding
# 对于NLP来说,相当于是5个词,每个词的词向量维数是dim
# 每个词向量初始化为正态分布 N(0,1)(待验证)
embd = nn.Embedding(5, dim)
# type(embd.weight) is Parameter
# type(embd.weight.data) is Tensor
# embd.weight.data[0]是指(5, dim)的word embeddings中取第1个词的词向量,是dim维行向量
weight = embd.weight.data[0].numpy()
print("weight: {}".format(weight))weight_sum = 0
for w in weight:weight_sum += w
mean = weight_sum / dim
print("均值: {}".format(mean))square_sum = 0
for w in weight:square_sum += (mean - w) ** 2
print("方差: {}".format(square_sum / dim))

代码输出:

weight: [-0.65507996  0.11627434 -1.6705967  ...  0.78397447  ...  -0.13477565]
均值: 0.0006973597864689242
方差: 1.0019535550544454

可见,均值接近0,方差接近1,从这里也可以反映出nn.Embeddig.weight是标准正态分布$N(0, 1)$。

二、torch.Tensortorch.tensortorch.randn初始化分布

1、torch.rand

返回$[0,1)$上的均匀分布(uniform distribution)。

2、torch.randn

返回$N(0, 1)$,即标准正态分布(standard normal distribution)。

3、torch.Tensor

torch.Tensor是Tensor class,torch.Tensor(2, 3)是调用Tensor的构造函数,构造了$2\times3$矩阵,但是没有分配空间,未初始化。
不推荐使用torch.Tensor创建Tensor,应使用torch.tenstortorch.onestorch.zerostorch.randtorch.randn等,原因:

t = torch.Tensor(2,3)
# 容易出现下述错误,因为t中的值取决当前内存中的随机值
# 如果当前内存中随机值特别大会溢出
RuntimeError: Overflow when unpacking long

Pytorch的默认初始化分布 nn.Embedding.weight初始化分布相关推荐

  1. 正态分布初始化 torch.nn.Embedding.weight()与torch.nn.init.normal()的验证对比

    torch.nn.Embedding.weight(num_embeddings, embedding_dim) 随机初始化,生成标准正态分布N(0,1)N(0,1)N(0,1)的张量Tensor t ...

  2. pytorch tensor 初始化_Pytorch - nn.init 参数初始化方法

    Pytorch 的参数初始化 - 给定非线性函数的推荐增益值(gain value):nonlinearity 非线性函数gain 增益 Linear / Identity1 Conv{1,2,3}D ...

  3. pytorch的词嵌入函数nn.Embedding

    看下面的代码,注释及输出就可以理解了. import torch import torch.nn as nn''' 回想一下词嵌入的理论知识,我深度学习的系列文章也可以 num_embeddings= ...

  4. 模型的第一层:详解torch.nn.Embedding和torch.nn.Linear

    文章目录 1.概述 2.Embedding 2.1 nn.Linear 2.2 nn.Embedding 对比 初始化第一层 1.概述 torch.nn.Embedding是用来将一个数字变成一个指定 ...

  5. 【Pytorch基础教程28】浅谈torch.nn.embedding

    学习总结 文章目录 学习总结 一.nn.Embedding 二.代码栗子 2.1 通过embedding降维 2.2 RNN中用embedding改进 2.3 deepFM模型中embedding R ...

  6. pytorch中nn.Embedding和nn.LSTM和nn.Linear

    使用pytorch实现一个LSTM网络很简单,最基本的有三个要素:nn.Embedding, nn.LSTM, nn.Linear 基本框架为: class LSTMModel(nn.Module): ...

  7. 什么是embedding(把物体编码为一个低维稠密向量),pytorch中nn.Embedding原理及使用

    文章目录 使embedding空前流行的word2vec 句子的表达 训练样本 损失函数 输入向量表达和输出向量表达vwv_{w}vw​ 从word2vec到item2vec 讨论环节 pytorch ...

  8. pytorch nn.Embedding

    pytorch nn.Embedding class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_n ...

  9. pytorch torch.nn.Embedding

    词嵌入矩阵,可以加载使用word2vector,glove API CLASS torch.nn.Embedding(num_embeddings: int, embedding_dim: int, ...

最新文章

  1. ABP理论学习之日志记录
  2. 【洛谷】NOIP2018原创模拟赛DAY1解题报告
  3. ubuntu无法设置亮度,触摸板失效,声音无法调节
  4. OrCAD Library Builder使用
  5. 项目管理甘特图模板_甘特图简易制作流程
  6. PHP发卡自动源码,PHP自动化售货发卡网源码
  7. 原生JavaScript开发高级课程 |智能S
  8. linux过滤端口抓包_linux抓包命令tcpdump
  9. 基于LED的室内可见光通信系统
  10. 谷歌浏览器配置微信浏览器_微信网页版 - Chrome社交与通讯插件 - 画夹插件网
  11. Pigsty是什么?
  12. HTML5系列代码:使用空格符号
  13. NodeJS:Express 框架实战解析视频教程
  14. android scrollview滚动条初始位置,ScrollView 设置滚动条的位置
  15. 共模干扰和差模干扰的理解
  16. 没错,Linux需要更多的憎恨者
  17. 软件工程学科对人类社会和生活的重要意义_2019-2020全国软件工程专业大学排名,高考生志愿填报看过来...
  18. Jetty篇教程 之Jetty 嵌入式服务器
  19. Ubuntu Windows双系统切换最简方法!!!
  20. 百度API加载离线百度电子地图和卫星切片

热门文章

  1. 2022-2028年现代农业背景下中国家庭农场深度调研及投资前景预测报告
  2. Go 学习笔记(65)— Go 中函数参数是传值还是传引用
  3. NumPy — 创建全零、全1、空、arange 数组,array 对象类型,astype 转换数据类型,数组和标量以及数组之间的运算,NumPy 数组共享内存
  4. python学习之pip常用命令
  5. Kali2021.2 VMware最新版安装步骤
  6. 【Sql Server】DateBase-SQL安全
  7. Google Colab 免费GPU服务器使用教程 挂载云端硬盘
  8. LeetCode简单题之学生分数的最小差值
  9. LeetCode简单题之猜数字大小
  10. ALD对照CVD淀积技术的优势