原文在这里,总结一下有关CBOW模型的原理,代码是搬运的。Skip-gram模型与之类似,就不展开详细说明了。有理解不正确的地方请指正。

模型架构


CBOW模型包含三层:输入层,投影层,输出层。与NNML相比,去掉了隐藏层。

CBOW是根据上下文预测中心词,有点类似于完形填空。上下文的多少是个超参数,可以自己调整。

在构建数据集时,根据CBOW的特点,一般是将上下文当作输入,中心词当作标签。

训练时,首先随机初始一个矩阵 C ∈ R ∣ V ∣ × d i m C\in R^{|V|\times dim} CRV×dim,其中 ∣ V ∣ |V| V表示词汇表的大小, d i m dim dim表示单词的向量维度,自己设定。然后将上下文和中心词通过矩阵 C C C映射成向量。最后使用向量进行训练,优化矩阵 C C C。矩阵 C C C即为单词的向量矩阵。

import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as FCONTEXT_SIZE = 2
raw_text = "We are about to study the idea of a computational process. Computational processes are abstract beings that inhabit computers. As they evolve, processes manipulate other abstract things called data. The evolution of a process is directed by a pattern of rules called a program. People create programs to direct processes. In effect, we conjure the spirits of the computer with our spells.".split(' ')vocab = set(raw_text)
word_to_idx = {word: i for i, word in enumerate(vocab)}data = []
for i in range(CONTEXT_SIZE, len(raw_text)-CONTEXT_SIZE):context = [raw_text[i-2], raw_text[i-1], raw_text[i+1], raw_text[i+2]]target = raw_text[i]data.append((context, target))class CBOW(nn.Module):def __init__(self, n_word, n_dim, context_size):super(CBOW, self).__init__()self.embedding = nn.Embedding(n_word, n_dim)self.linear1 = nn.Linear(2*context_size*n_dim, 128)self.linear2 = nn.Linear(128, n_word)def forward(self, x):x = self.embedding(x)x = x.view(1, -1)x = self.linear1(x)x = F.relu(x, inplace=True)x = self.linear2(x)x = F.log_softmax(x)return xmodel = CBOW(len(word_to_idx), 100, CONTEXT_SIZE)
if torch.cuda.is_available():model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)for epoch in range(40):print('epoch {}'.format(epoch))print('*'*10)running_loss = 0for word in data:context, target = wordcontext = Variable(torch.LongTensor([word_to_idx[i] for i in context]))target = Variable(torch.LongTensor([word_to_idx[target]]))if torch.cuda.is_available():context = context.cuda()target = target.cuda()out = model(context)loss = criterion(out, target)running_loss += loss.dataoptimizer.zero_grad()loss.backward()optimizer.step()print('loss: {:.6f}'.format(running_loss / len(data)))

参考文献

[1] Mikolov T, Chen K, Corrado G, et al. Efficient estimation of word representations in vector space. arXiv preprint arXiv, 2013, 1301(3781).
[2] https://www.cnblogs.com/jfdwd/p/11076977.html

CBOW(Continuous Bag-of-Words)模型原理相关推荐

  1. CBOW(Continous Bag of Words)模型学习(2020-08-19)

    CBOW(Continous Bag of Words)模型学习 大致原理看自己发的第一篇博客:https://blog.csdn.net/fuchengguo666/article/details/ ...

  2. Holt-Winters模型原理分析

    Holt-Winters模型原理分析及代码实现(python) from:https://blog.csdn.net/u010665216/article/details/78051192 引言 最近 ...

  3. Select模型原理

    Select模型原理 利用select函数,推断套接字上是否存在数据,或者是否能向一个套接字写入数据.目的是防止应用程序在套接字处于锁定模式时,调用recv(或send)从没有数据的套接字上接收数据, ...

  4. Bag of Words模型

    这几天忙里偷闲看了一些关于BOW模型的知识,虽然自己做图像检索到目前为止并没有用到过BOW模型,不过了解一下BOW并不是一件毫无意义的事情.网上关于理解BOW模型也很多,而且也很详细,再写一点关于BO ...

  5. Java开发中Netty线程模型原理解析!

    Java开发中Netty线程模型原理解析,Netty是Java领域有名的开源网络库具有高性能和高扩展性的特点,很多流行的框架都是基于它来构建.Netty 线程模型不是一成不变的,取决于用户的启动参数配 ...

  6. [zz]GMM-HMM语音识别模型 原理篇

    GMM-HMM语音识别模型 原理篇 分类: Data Structure Machine Learning Data Mining 2014-05-28 20:52 20662人阅读 评论(34) 收 ...

  7. logistic模型原理与推导过程分析(3)

    附录:迭代公式向量化 θ相关的迭代公式为: ​ 如果按照此公式操作的话,每计算一个θ需要循环m次.为此,我们需要将迭代公式进行向量化. 首先我们将样本矩阵表示如下: 将要求的θ也表示成矩阵的形式: 将 ...

  8. logistic模型原理与推导过程分析(2)

    二项逻辑回归模型 既然logistic回归把结果压缩到连续的区间(0,1),而不是离散的0或者1,然后我们可以取定一个阈值,通常以0.5为阈值,如果计算出来的概率大于0.5,则将结果归为一类(1),如 ...

  9. logistic模型原理与推导过程分析(1)

    从线性分类器谈起 给定一些数据集合,他们分别属于两个不同的类别.例如对于广告数据来说,是典型的二分类问题,一般将被点击的数据称为正样本,没被点击的数据称为负样本.现在我们要找到一个线性分类器,将这些数 ...

  10. 自然语言生成任务,如文本摘要和图像标题的生成。seq2seq的模型原理

    版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/Irving_zhang/article/details/78889364 </div>& ...

最新文章

  1. c++ float 转string
  2. SAP PP生产订单相关信息的获取
  3. java 算法 排序算法_Java七种排序算法以及实现
  4. 统计建模与r软件_【统计建模与R软件笔记】008 描述统计量(1)
  5. 【模型加速】TensorRT安装、测试及常见问题
  6. JavaWeb文件上传(1)--基础
  7. Docker环境调优
  8. HTML、CSS制作小米商城网页首页源码解析
  9. 一元函数微分学的概念与计算
  10. RocketMQ 实战与原理解析
  11. Oracle导入dmp文件步骤
  12. 打印程序在计算机上的应用程序,什么是“后台打印程序子系统应用程序”(spoolsv.exe),以及为什么它在我的电脑上运行?...
  13. DDNS的NAT穿越问题
  14. 中国移动互联网公司10年战争史
  15. ocv特性_锂离子电池的三大特性分析
  16. Creo 9.0 基准特征:基准坐标系
  17. 远程Debug远端服务器JVM配置
  18. cnc程序加工中心_cnc加工自动可制造性评估的可制造性设计
  19. 兄弟连区块链入门教程分享区块链POW证明代码实现demo
  20. vue项目打包部署在windows或linux服务器上

热门文章

  1. 什么是显热?什么是潜热?
  2. 大理石分割(回溯法)
  3. daocloud mysql_DaoCloud 平台更新汇总
  4. 模拟电话录音系统2.0
  5. 2.23 haas506 2.0开发教程 - KeyPad - 矩阵键盘(仅支持M320开发板)
  6. yolov5 | 移动端部署yolov5s模型
  7. 使用Qt创建一个时钟
  8. 最好的vsftpd配置教程
  9. RTK Query(RTKQ)
  10. Android 百度鹰眼 SDK