本文主要复现的是word2vec中基于negative sampling的skip-gram模型,原文是《Distributed Representations of Words and Phrases and their Compositionality》。word2vec原理部分请戳这里(虽然很长,但是真的干货满满,强推!)
      什么是word_embedding呢?简单来说,就是用一个向量来表示一个word。
      比如我们现在有10个word,我们想用数字来表示这些word,使得我们一看到数字就能知道它代表的是什么单词,那么最简单的操作就是one-hot。但是当有1k个、1w个单词时,使用one-hot就会使得维数很大,并且one-hot还不能显示出单词之间的相关性。
      如果使用word-embedding,假设embedding-size设置为3,那么每个单词就用3个数字来表示。此时,我们可认为这10个单词存在于一个3维空间,3个数字代表了各自的x\y\z坐标,那么只要看到数字,我们就能在这3维空间中定位到这个word;并且根据比较两个单词之间的x\y\z坐标具体数值,我们还能计算它们之间的相关性。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tu
from  collections import Counter
import numpy as np
import random
import math
import scipy
from sklearn.metrics.pairwise import cosine_similarityrandom.seed(1)
np.random.seed(1)
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
if use_cuda:torch.cuda.manual_seed(1)

设置一些超参数

c = 2 #context window
k = 100 #负采样的个数 number of negative samples
num_eopchs =10
max_vocab_size = 30000
batch_size = 128
lr = 0.2
embedding_size = 100

载入文件,进行数据预处理

with open('../第二课资料/text8.train.txt','r') as f:text = f.read()
print(text[:148])
text = text.split()
#原始文件很干净,全部是英文单词,没有标点符号,并且使用空格隔开,所以这里只使用split就可以。
anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans
vocab = dict(Counter(text).most_common(max_vocab_size-1)) # 留一个位置给unkvocab['<unk>'] = len(text) - np.sum(list(vocab.values()))
# len(text)是所有单词出现次数加总
# vocab.values()记录的是最常见的3000个单词每个词出现的次数
# 因此,np.sum(list(vocab.values()))就是这些常见词出现的总次数
# 那么len(text)-np.sum(list(vocab.values()))就是不常见单词出现的总次数idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word:i for i,word in enumerate(idx_to_word,)}
# idx_to_word 是一个list 包含了30000个单词
# word_to_idx 是一个dictionary key是单词,value是该单词在idx_to_word中的索引。它的作用是将word进行编码,code就是该词在idx_to_word中的索引
# ps: vocab也是一个dictionary,key是单词,value是该单词出现的次数
# 每个单词出现的频率
word_counts = np.array([count for count in vocab.values()], dtype=np.float32)
word_freqs = word_counts / np.sum(word_counts)
# 然后文章中对频率做了一个处理,这里我们也照做
word_freqs = word_freqs**(3./4.)
word_freqs = word_freqs/np.sum(word_freqs)
class embedding_dataset(tud.Dataset):def __init__(self,text,word_to_idx,idx_to_word,word_freqs):super(embedding_dataset,self).__init__()self.text = textself.text_encoded = torch.LongTensor([word_to_idx.get(word,word_to_idx['<unk>']) for word in text])# 将text中的每一个单词进行编码,编码规则就是word_to_idxself.word_to_idx = word_to_idxself.idx_to_word = idx_to_wordself.word_freqs = torch.Tensor(word_freqs)def __len__(self):return len(self.text_encoded)# len(text_encoded) = len(text) = 15313011(有重复)def __getitem__(self,idx):center_word_code = self.text_encoded[idx]pos_indices = list(range(idx-c,idx)) + list(range(idx+1,idx+c+1))pos_indices = [i%len(self.text_encoded) for i in pos_indices]# 以上这行代码是防止idx小于c,以及idx+1到idx+c超过范围# 以环形取索引pos_words_code = self.text_encoded[pos_indices]neg_words_code = torch.multinomial(self.word_freqs,k*pos_words_code.shape[0],True)# 对input(就是第一个参数)的每一行做num_samples次取值(num_samples不能超过每一行的元素个数),输出的张量是每一次取值时该元素在该行的索引。True表示有放回抽样# input的值必须大于等于0,值越大,在该行中越容易被抽中。若为0,并且最后一个参数为True,那么0永远不会被抽中return center_word_code,pos_words_code,neg_words_code
dataset = embedding_dataset(text,word_to_idx,idx_to_word,word_freqs)
dataloader = tud.DataLoader(dataset,batch_size=batch_size,shuffle=True)center,pos,neg = next(iter(dataloader))
print(center.size(),pos.size(),neg.size())
# 每个batch中,有128个中心词,每个中心词对应4个postive_words(窗口大小2)、每个postive_words对应100个negative_words
print(center[:10],'\n',pos[:5],'\n',neg[:5])
torch.Size([128]) torch.Size([128, 4]) torch.Size([128, 400])
tensor([   26,     1,   375,  1634, 29999,     1,  3170,   594,    24,    23]) tensor([[  282,    99,    10,     5],[  801,   523,   145,     2],[  296, 15248,    56,     1],[ 1949,  2429,  9211,     2],[   15,     7,     7,    16]]) tensor([[  127,  1764,    16,  ...,  1424,   862,  4194],[ 1527,    56, 19949,  ...,   747,   995,   272],[    4, 26812,    91,  ...,     7, 13702,    30],[ 1952, 22953, 27444,  ...,  2814,   596,  2603],[ 2007, 29999,  5844,  ..., 29999,   960,  3189]])

定义模型

class embeddingmodel(nn.Module):def __init__(self,vocab_size,embed_size):super(embeddingmodel,self).__init__()self.vocab_size = vocab_size #3000self.embed_size = embed_size #100initrange = 0.5/self.embed_sizeself.in_embed = nn.Embedding(self.vocab_size,self.embed_size,sparse=False)self.in_embed.weight.data.uniform_(-initrange,initrange)self.out_embed = nn.Embedding(self.vocab_size,self.embed_size,sparse=False)self.out_embed.weight.data.uniform_(-initrange,initrange)def forward(self,center_word_code,pos_words_code,neg_words_code):batch_size = center_word_code.size(0)input_embedding = self.in_embed(center_word_code) # b*embed_sizepos_embedding = self.out_embed(pos_words_code) # (b)*(2c)*(embed_size)neg_embedding = self.out_embed(neg_words_code) # (b)*(2c*k)*embed_sizelog_pos = torch.bmm(pos_embedding,input_embedding.unsqueeze(2)).squeeze()# bmm之后的size是 b*2c*1 然后squeeze 变成 b*2c# log_pos表示的center_word和自己周围的单词之间的相关性,如果embedding效果好的话,那么相关性越高,log_pos越大。log_neg = torch.bmm(neg_embedding,-input_embedding.unsqueeze(2)).squeeze()# bmm之后的size是b*2ck*1 squeeze之后变成b*2ck# log_neg表示center_word与负采样得到的word之间的相关性,因为是负采样,所以embedding效果好的话,那么它们之间的相关性应该较低,log_neg越大(因为input_embedding前面加了一个负号)log_pos = F.logsigmoid(log_pos).sum(1) #logsigmoid后size不变,sum(1)之后size变成b*1log_neg = F.logsigmoid(log_neg).sum(1) #...,...loss = log_pos+log_neg # loss的size为b*1return -loss# 不断迭代的结果是-loss越来越小,即loss越来越大,loss越来越大,说明pos_embedding与input_embedding越来越接近、neg_embedding与input_embedding越来越不接近def input_embeddings(self):# 按照论文作者的意思,in_embed的效果较好,所以最后我们取出in_embedreturn self.in_embed.weight.data.cpu().numpy()

初始化模型,如果有GPU,就把模型部署到GPU上

model = embeddingmodel(max_vocab_size,embedding_size)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
embeddingmodel((in_embed): Embedding(30000, 100)(out_embed): Embedding(30000, 100)
)
optimizer = torch.optim.SGD(model.parameters(),lr=0.2)# 开始训练
for epoch in range(2):for batch_idx,(center_word_code,pos_words_code,neg_words_code) in enumerate(dataloader):center_word_code = center_word_code.long()pos_words_code = pos_words_code.long()neg_words_code = neg_words_code.long()if use_cuda:center_word_code = center_word_code.to(device)pos_words_code = pos_words_code.to(device)neg_words_code = neg_words_code.to(device)optimizer.zero_grad()loss = model(center_word_code,pos_words_code,neg_words_code).mean()loss.backward()optimizer.step()if batch_idx % 100 == 0 :print('epoch:{},batch_idx:{},loss:{}'.format(epoch,batch_idx,loss))
epoch:0,batch_idx:0,loss:280.031494140625
epoch:0,batch_idx:100,loss:223.23663330078125
epoch:0,batch_idx:200,loss:175.07229614257812
epoch:0,batch_idx:300,loss:141.5162353515625
epoch:0,batch_idx:400,loss:120.5655517578125
epoch:0,batch_idx:500,loss:116.16273498535156
... ...
epoch:0,batch_idx:119400,loss:19.99425506591797
epoch:0,batch_idx:119500,loss:19.939586639404297
epoch:0,batch_idx:119600,loss:19.87261199951172
... ...
embedding_weights = model.in_embed.weight.data.cpu().numpy()
# 原文中说in_embed的效果要比out_embed的效果要好,所以我们这里也按照原文来
# embedding_weights的size为(vocab_size,embedding_size) 即(30000,100),每行代表vocab中的一个单词def find_nearest(word):index = word_to_idx[word] # 拿到指定单词在词典中的索引embedding = embedding_weights[index] # 根据索引得到该词的embedding表示cos_dis = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights]) # 计算该词与vocab中所有单词的余弦相似度return [idx_to_word[i] for i in cos_dis.argsort()[:10]] #选出余弦相似度最高的10个单词# 展示与所选word含义最相近的10个word
for word in ["good", "fresh", "monster", "green", "like", "america", "chicago", "work", "computer", "language"]:print(word, find_nearest(word))print('~'*80)
good ['good', 'bad', 'perfect', 'practical', 'poor', 'false', 'hard', 'unique', 'evil', 'truth']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
fresh ['fresh', 'warm', 'thermal', 'drinking', 'smooth', 'static', 'minimal', 'grain', 'clean', 'mild']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
monster ['monster', 'cow', 'clown', 'hammer', 'golem', 'triangle', 'demon', 'pig', 'giant', 'serpent']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
green ['green', 'blue', 'yellow', 'white', 'cross', 'orange', 'mountain', 'black', 'salt', 'crescent']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
like ['like', 'whereas', 'etc', 'unlike', 'similarly', 'amongst', 'involving', 'rich', 'plant', 'containing']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
america ['america', 'korea', 'africa', 'carolina', 'india', 'korean', 'australia', 'pakistan', 'japan', 'indonesia']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
chicago ['chicago', 'boston', 'texas', 'massachusetts', 'illinois', 'london', 'harvard', 'florida', 'cambridge', 'michigan']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
work ['work', 'writing', 'songs', 'job', 'experiments', 'dylan', 'experience', 'experiment', 'marx', 'speech']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
computer ['computer', 'digital', 'electronic', 'video', 'graphics', 'software', 'hardware', 'audio', 'programs', 'computers']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
language ['language', 'languages', 'alphabet', 'grammar', 'spelling', 'dialect', 'arabic', 'pronunciation', 'dialects', 'translation']
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

pytorch实现word_embedding(negative_sampling的skip-gram模型)相关推荐

  1. 《自然语言处理学习之路》02 词向量模型Word2Vec,CBOW,Skip Gram

    本文主要是学习参考莫烦老师的教学,对老师课程的学习,记忆笔记. 原文链接 文章目录 书山有路勤为径,学海无涯苦作舟. 零.吃水不忘挖井人 一.计算机如何实现对于词语的理解 1.1 万物数字化 1.2 ...

  2. 【Pytorch基础教程34】EGES召回模型

    note 文章目录 note 一.EGES图算法 1.0 回顾GNN 1.1 基本定义和数据预处理 1.2 GES: GNN with side info 1.3 EGES: enhanced版本 二 ...

  3. 目标检测-基于Pytorch实现Yolov3(1)- 搭建模型

    原文地址:https://www.cnblogs.com/jacklu/p/9853599.html 本人前段时间在T厂做了目标检测的项目,对一些目标检测框架也有了一定理解.其中Yolov3速度非常快 ...

  4. 【Pytorch神经网络理论篇】 35 GaitSet模型:步态识别思路+水平金字塔池化+三元损失

    代码: [Pytorch神经网络实战案例]28 GitSet模型进行步态与身份识别(CASIA-B数据集)_LiBiGor的博客-CSDN博客1 CASIA-B数据集本例使用的是预处理后的CASIA- ...

  5. (pytorch-深度学习系列)pytorch实现多层感知机(自动定义模型)对Fashion-MNIST数据集进行分类-学习笔记

    pytorch实现多层感知机(自动定义模型)对Fashion-MNIST数据集进行分类 导入模块: import torch from torch import nn from torch.nn im ...

  6. Facebook 发布 PyTorch Hub:一行代码实现经典模型调用!

    作者 | Team PyTorch 译者 | Monanfei 责编 | 夕颜 出品 | AI科技大本营(ID: rgznai100) 6月11日,Facebook PyTorch 团队推出了全新 A ...

  7. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  8. Rasa课程、Rasa培训、Rasa面试、Rasa实战系列之Understanding Word Embeddings CBOW and Skip Gram

    Rasa课程.Rasa培训.Rasa面试.Rasa实战系列之Understanding Word Embeddings CBOW and Skip Gram 字嵌入 从第i个字符,第i+1个字符预测第 ...

  9. pytorch笔记: 搭建Skip—gram

    skip-gram 理论部分见:NLP 笔记:Skip-gram_刘文巾的博客-CSDN博客 1 导入库 import numpy as np import torch from torch impo ...

最新文章

  1. 谷歌兄弟公司Wing将于10月开始试点无人机配送
  2. 电视机当计算机屏幕,怎么实现电视机当电脑的显示器和音箱用?
  3. python3.7和3.8的区别-python3.8.0与3.7.0哪个好?
  4. 深入理解DOM节点类型第一篇——12种DOM节点类型概述
  5. C语言程序设计的常用算法
  6. Android 系统(264)---android进阶——自定义View
  7. 《图形学》实验一:钻石图案
  8. ios手机怎么连接adb命令_没有 mac 的福音,windows 下对 ios 进行操作 (类似 android 的 adb 操作)...
  9. 声源测向: TDOA-GCC-PATH方法
  10. 笨方法学Python-1
  11. matlab双闭环绘图,matlab双闭环直流调速系统设计及仿真+电路图
  12. 如何解决控件附件上传时超大附件无法上传的问题
  13. EOF到底是什么意思?
  14. 阿基里斯与乌龟的悖论
  15. Kanzi for Android Demo
  16. 这几天来的第一篇日志
  17. 网页自动关机代码HTML,电脑如何自动关机
  18. 电商平台-安全设计与架构
  19. Drop Shipment PO以及Replenishment PO有何异同?
  20. 单页活动页面html,优秀HTML5活动页面

热门文章

  1. 2022年机动车新规,外地人上京牌不需要居住证啦
  2. word操作:单独调整英文字体
  3. 利用Matlab解决线性规划问题并绘制特定形状的空间曲面(约束区域的绘图)
  4. Photoshop 2023 Mac(PS 2023)v24.0.0中英文已发布,新功能详细介绍,支持M1/M2/intel
  5. [stanford NLP] 原理小结
  6. 科普:飞针测试机探针分类概要
  7. SharePoint站点图片轮转器imageRotator
  8. Int相乘为负数问题
  9. 【2019-TGRS】Aerial LaneNet: Lane-Marking Semantic Segmentation in Aerial Imagery Using Wavelet-Enhanc
  10. 我的机器学习支线「模型复杂度」