作者丨天雨粟

知乎专栏丨机器不学习

地址丨https://zhuanlan.zhihu.com/p/27296712

前言

本篇文章将利用TensorFlow来完成Skip-Gram模型。还不是很了解Skip-Gram思想的小伙伴可以先看一下上一篇的专栏内容。

本篇实战代码的目的主要是加深对Skip-Gram模型中一些思想和trick的理解。由于受限于语料规模、语料质量、算法细节以及训练成本的原因,训练出的结果显然是无法跟gensim封装的Word2Vec相比的,本代码适合新手去理解与练习Skip-Gram模型的思想。

工具介绍

  • 语言:Python 3

  • 包:TensorFlow(1.0版本)及其它数据处理包(见代码中)

  • 编辑器:jupyter notebook

  • 线上GPU:floyd

  • 数据集:经过预处理后的维基百科文章(英文)

正文部分

文章主要包括以下四个部分进行代码构造:

- 数据预处理
- 训练样本构建
- 模型构建
- 模型验证

1 数据预处理

关于导入包和加载数据在这里就不写了,比较简单,请参考git上的代码。

数据预处理部分主要包括:

  • 替换文本中特殊符号并去除低频词

  • 对文本分词

  • 构建语料

  • 单词映射表

首先我们定义一个函数来完成前两步,即对文本的清洗和分词操作。

上面的函数实现了替换标点及删除低频词操作,返回分词后的文本。

下面让我们来看看经过清洗后的数据:

有了分词后的文本,就可以构建我们的映射表,代码就不再赘述,大家应该都比较熟悉。

我们还可以看一下文本和词典的规模大小:

整个文本中单词大约为1660万的规模,词典大小为6万左右,这个规模对于训练好的词向量其实是不够的,但可以训练出一个稍微还可以的模型。

2 训练样本构建

我们知道skip-gram中,训练样本的形式是(input word, output word),其中output word是input word的上下文。为了减少模型噪音并加速训练速度,我们在构造batch之前要对样本进行采样,剔除停用词等噪音因素。

采样

在建模过程中,训练文本中会出现很多“the”、“a”之类的常用词(也叫停用词),这些词对于我们的训练会带来很多噪音。在上一篇Word2Vec中提过对样本进行抽样,剔除高频的停用词来减少模型的噪音,并加速训练。

我们采用以下公式来计算每个单词被删除的概率大小:

其中代表单词的出现频次。为一个阈值,一般介于1e-3到1e-5之间。

上面的代码计算了样本中每个单词被删除的概率,并基于概率进行了采样,现在我们手里就拿到了采样过的单词列表。

构造batch

我们先来分析一下skip-gram的样本格式。skip-gram不同于CBOW,CBOW是基于上下文预测当前input word。而skip-gram则是基于一个input word来预测上下文,因此一个input word会对应多个上下文。我们来举个栗子“The quick brown fox jumps over lazy dog”,如果我们固定skip_window=2的话,那么fox的上下文就是[quick, brown, jumps, over],如果我们的batch_size=1的话,那么实际上一个batch中有四个训练样本。

上面的分析转换为代码就是两个步骤,第一个是找到每个input word的上下文,第二个就是基于上下文构建batch。

首先是找到input word的上下文单词列表:

我们定义了一个get_targets函数,接收一个单词索引号,基于这个索引号去查找单词表中对应的上下文(默认window_size=5)。请注意这里有一个小trick,我在实际选择input word上下文时,使用的窗口大小是一个介于[1, window_size]区间的随机数。这里的目的是让模型更多地去关注离input word更近词。

我们有了上面的函数后,就能够轻松地通过input word找到它的上下文单词。有了这些单词我们就可以构建我们的batch来进行训练:

注意上面的代码对batch的处理。我们知道对于每个input word来说,有多个output word(上下文)。例如我们的输入是“fox”,上下文是[quick, brown, jumps, over],那么fox这一个batch中就有四个训练样本[fox, quick], [fox, brown], [fox, jumps], [fox, over]。

3 模型构建

数据预处理结束后,就需要来构建我们的模型。在模型中为了加速训练并提高词向量的质量,我们采用负采样方式进行权重更新。

输入层到嵌入层

输入层到隐层的权重矩阵作为嵌入层要给定其维度,一般embeding_size设置为50-300之间。

嵌入层的lookup通过TensorFlow中的embedding_lookup实现。

嵌入层到输出层

在skip-gram中,每个input word的多个上下文单词实际上是共享一个权重矩阵,我们将每个(input word, output word)训练样本来作为我们的输入。为了加速训练并且提高词向量的质量,我们采用negative sampling的方法来进行权重更新。

TensorFlow中的sampled_softmax_loss,由于进行了negative sampling,所以实际上我们会低估模型的训练loss。

请注意代码中的softmax_w的维度是vocab_size x embedding_size,这是因为TensorFlow中的sampled_softmax_loss中参数weights的size是[num_classes, dim]。

4 模型验证

在上面的步骤中,我们已经将模型的框架搭建出来,下面就让我们来训练训练一下模型。为了能够更加直观地观察训练每个阶段的情况。我们来挑选几个词,看看在训练过程中它们的相似词是怎么变化的。

训练模型:

在这里注意一下,尽量不要经常去让代码打印验证集相似的词,因为这里会多了一步计算步骤,就是计算相似度,会非常消耗计算资源,计算过程也很慢。所以代码中我设置1000轮打印一次结果。

从最后的训练结果来看,模型还是学到了一些常见词的语义,比如one等计数词以及gold之类的金属词,animals中的相似词也相对准确。

为了能够更全面地观察我们训练结果,我们采用sklearn中的TSNE来对高维词向量进行可视化。(具体代码见git)

上面的图中通过TSNE将高维的词向量按照距离远近显示在二维坐标系中,该图已经在git库中,想看原图的小伙伴去git看~

我们来看一下细节:

上面是显示了整张大图的局部区域,可以看到效果还不错。

关于提升效果的技巧:

增大训练样本,语料库越大,模型学习的可学习的信息会越多。

增加window size,可以获得更多的上下文信息。

增加embedding size可以减少信息的维度损失,但也不宜过大,我一般常用的规模为50-300。

附录:

git代码中还提供了中文的词向量计算代码。同时提供了中文的一个训练语料,语料是我从某招聘网站上爬取的招聘数据,做了分词和去除停用词的操作(可从git获取),但语料规模太小,训练效果并不好。

上面是我用模型训练的中文数据,可以看到有一部分语义被挖掘出来,比如word和excel、office很接近,ppt和project、文字处理等,以及逻辑思维与语言表达等,但整体上效果还是很差。一方面是由于语料的规模太小(只有70兆的语料),另一方面是模型也没有去调参。如果有兴趣的同学可以自己试下会不会有更好的效果。

备注:公众号菜单包含了整理了一本AI小抄非常适合在通勤路上用学习

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习在线手册深度学习在线手册AI基础下载(pdf更新到25集)备注:加入本站微信群或者qq群,请回复“加群”获取一折本站知识星球优惠券,请回复“知识星球”

喜欢文章,点个在看

NLP深度学习:基于TensorFlow实现Skip-Gram模型相关推荐

  1. [深度学习]-基于tensorflow的CNN和RNN-LSTM文本情感分析对比

    基于tensorflow的CNN和LSTM文本情感分析对比 1. 背景介绍 2. 数据集介绍 2.0 wordsList.npy 2.1 wordVectors.npy 2.2 idsMatrix.n ...

  2. python做神经网络有什么框架_神经网络与深度学习——基于TensorFlow框架和Python技术实现...

    目 录 第1章 绪论1 1.1 人工智能2 1.2 机器学习3 1.2.1 监督学习3 1.2.2 非监督学习3 1.2.3 半监督学习4 1.3 深度学习4 1.3.1 卷积神经网络4 1.3.2 ...

  3. TensorFlow 2.X中的动手NLP深度学习模型准备

    简介:为什么我写这篇文章 (Intro: why I wrote this post) Many state-of-the-art results in NLP problems are achiev ...

  4. 使用Pytorch实现NLP深度学习

    原文链接:https://pytorch.org/tutorials/beginner/deep_learning_nlp_tutorial.html 本文将会帮助你了解使用Pytorch进行深度学习 ...

  5. 视频教程-深度学习与TensorFlow 2入门实战-深度学习

    深度学习与TensorFlow 2入门实战 新加坡国立大学研究员 龙良曲 ¥399.00 立即订阅 扫码下载「CSDN程序员学院APP」,1000+技术好课免费看 APP订阅课程,领取优惠,最少立减5 ...

  6. 深度学习必备书籍——《Python深度学习 基于Pytorch》

    作为一名机器学习|深度学习的博主,想和大家分享几本深度学习的书籍,让大家更快的入手深度学习,成为AI达人!今天给大家介绍的是:<Python深度学习 基于Pytorch> 文章目录 一.背 ...

  7. 深度学习与TensorFlow

    深度学习与TensorFlow DNN(深度神经网络算法)现在是AI社区的流行词.最近,DNN 在许多数据科学竞赛/Kaggle 竞赛中获得了多次冠军. 自从 1962 年 Rosenblat 提出感 ...

  8. 深度学习调用TensorFlow、PyTorch等框架

    深度学习调用TensorFlow.PyTorch等框架 一.开发目标目标 提供统一接口的库,它可以从C++和Python中的多个框架中运行深度学习模型.欧米诺使研究人员能够在自己选择的框架内轻松建立模 ...

  9. NLP深度学习:近期趋势概述

    NLP&深度学习:近期趋势概述 https://www.cnblogs.com/DicksonJYL/p/9686204.html 摘要:当NLP遇上深度学习,到底发生了什么样的变化呢? 在最 ...

  10. NLP深度学习:近期趋势概述 1

    摘要:当NLP遇上深度学习,到底发生了什么样的变化呢? 在最近发表的论文中,Young及其同事汇总了基于深度学习的自然语言处理(NLP)系统和应用程序的一些最新趋势.本文的重点介绍是对各种NLP任务( ...

最新文章

  1. Python编程比较好的机构怎么选择
  2. Linux下Apache+Tomcat 负载均衡
  3. 设计模式3:装饰模式
  4. Microsoft Windows Workflow Foundation 入门:开发人员演练
  5. 电脑无线网络与服务器共享,图文详解win7笔记本如何实现内置无线局域网卡共享...
  6. (STL,set,priority_queue)丑数
  7. 能使Oracle索引失效的六大限制条件
  8. Flutter中Contrainer 组件的宽高限制分析
  9. C++编程基础二 13-函数与string对象
  10. android怎样判断插入数据是否成功_Android 端 V1/V2/V3 签名的原理
  11. 如何在GO语言中使用Kubernetes API?
  12. linux文件描述符、软硬连接、输入输出重定向
  13. Java 服务器端支持断点续传的源代码
  14. VC++中如何让RadioButton分组,并且互斥
  15. Atititi ui之道 attilax著 v3 s11.docx 1. 概览 2 1.1. 软件设计可分为两个部分:编码设计与UI设计 2 2. 用户界面设计的三大原则是:置界面于用户的控制之下;
  16. 系统安全中主要风险有哪些,弱密码怎么解决?
  17. fw300r虚拟服务器设置,迅捷(FAST)FW300R无线路由器怎么设置
  18. 创建一个三维空间形状,算立方体,球体,正三棱锥表面积体积
  19. 【模拟】桐桐的新闻系统
  20. 如何平衡工作与生活?真相在此

热门文章

  1. python面向编程:类继承、继承案例、单继承下属性查找、super方法
  2. 找出1000以内的所有完数。
  3. Yarn框架和工作流程研究
  4. 查询反模式 - 隐式的列
  5. iOS音频的后台播放 锁屏
  6. 不同的jar里边相同的包名类名怎么区别导入
  7. WCF热带鱼书学习手记 - Service Contract Overload
  8. 代谢组学在疾病诊断如何应用?
  9. 2021-04-09 linux的shell脚本简单教程
  10. 共轭梯度法求解线性方程组