源码地址: https://github.com/AlbertBJ/word2vecpy.git

这也是 我 fork别人的,觉得写得很棒,所以拜读了大神的代码,先对 关键点 进行说明:

主要是 针对 train_process这个方法中 针对 负采样 计算方法:

# Randomize window size, where win is the max window size 
            # 下面4行代码,主要是 获得 目标词 的上下文词(滑动窗口大小为win,即获取 目标词 的 左右各 win各词)            
            current_win = np.random.randint(low=1, high=win+1)# 主要是利用随机的思想,每次都产生 上下文词 的 数量 为[1,win]
            context_start = max(sent_pos - current_win, 0)# 这一步骤 主要是 针对 刚开始 的目标词 左边(以字典 索引,小于目标词索引的为左边,大于的                                             # 为右边)不足以 产生 current_win个 上下文,即 当不足时,上下文索引 从 0开始计算
            context_end = min(sent_pos + current_win + 1, len(sent)) # 这一步 和上一步 目的一致,主要是 针对 目标词 右侧 不足以 产生current_win个上                                                    #  下文 
            context = sent[context_start:sent_pos] + sent[sent_pos+1:context_end] # Turn into an iterator? 这一步 主要是 产生 上下文列表

# CBOW
            if cbow:
                # Compute neu1
                neu1 = np.mean(np.array([syn0[c] for c in context]), axis=0) # 获得词嵌入向量,在此处 体现 cbow和skip-gram不同点,cbow是用周边上下文词求平均后,
                                                         # 再进行 与 目标词的 dot(此时的 目标词 包括 正样本以及负样本)                
                assert len(neu1) == dim, 'neu1 and dim do not agree'

# Init neu1e with zeros
                neu1e = np.zeros(dim)

# Compute neu1e and update syn1 syn1为辅助向量
                if neg > 0:
                    classifiers = [(token, 1)] + [(target, 0) for target in table.sample(neg)]
                else:
                    classifiers = zip(vocab[token].path, vocab[token].code)
                for target, label in classifiers:
                    z = np.dot(neu1, syn1[target]) # 利用平均后的上下文词 词向量 与 每一个 目标词 进行 dot ,syn1存储的是每个词的模型参数
                    p = sigmoid(z)
                    g = alpha * (label - p) # 计算二分类的梯度(z的梯度是label-p,具体推导可以看我的 一篇关于bp的博文)
                    neu1e += g * syn1[target] #  此处使用梯度上升方法,目的求得 概率最大化(g*syn1[target],更新embedding)
                    syn1[target] += g * neu1  # 利用梯度上升 更新 模型参数(g * neu1更新参数矩阵)

# Update syn0
                for context_word in context:    #更新 每一个 上下文对应 的词向量矩阵
                    syn0[context_word] += neu1e # 利用梯度 上升更新 词嵌入矩阵

# Skip-gram
            else:
                for context_word in context:  # 循环上下文 词的个数 
                    # Init neu1e with zeros
                    neu1e = np.zeros(dim)

# Compute neu1e and update syn1
                    if neg > 0:
                        classifiers = [(token, 1)] + [(target, 0) for target in table.sample(neg)]
                    else:
                        classifiers = zip(vocab[token].path, vocab[token].code)
                    for target, label in classifiers:
                        z = np.dot(syn0[context_word], syn1[target])
                        p = sigmoid(z)
                        g = alpha * (label - p)
                        neu1e += g * syn1[target]              # Error to backpropagate to syn0
                        syn1[target] += g * syn0[context_word] # Update syn1   使用 上下文更新 syn1

# Update syn0
                    syn0[context_word] += neu1e

知乎: https://zhuanlan.zhihu.com/albertwang

微信公众号:AI-Research-Studio

​​

cbow和skip-gram实现关键代码解析相关推荐

  1. Rasa课程、Rasa培训、Rasa面试、Rasa实战系列之Understanding Word Embeddings CBOW and Skip Gram

    Rasa课程.Rasa培训.Rasa面试.Rasa实战系列之Understanding Word Embeddings CBOW and Skip Gram 字嵌入 从第i个字符,第i+1个字符预测第 ...

  2. 《自然语言处理学习之路》02 词向量模型Word2Vec,CBOW,Skip Gram

    本文主要是学习参考莫烦老师的教学,对老师课程的学习,记忆笔记. 原文链接 文章目录 书山有路勤为径,学海无涯苦作舟. 零.吃水不忘挖井人 一.计算机如何实现对于词语的理解 1.1 万物数字化 1.2 ...

  3. Adaptive Personalized Federated Learning 论文解读+代码解析

    论文地址点这里 一. 介绍 联邦学习强调确保本地隐私情况下,对多个客户端进行训练,客户端之间不交换数据而交换参数来进行通信.目的是聚合成一个全局的模型,使得这个模型再各个客户端上读能取得较好的成果.联 ...

  4. Wormhole资产跨链项目代码解析

    1. 引言 Wormhole支持基于Solana与多个链进行资产转移,开源代码为: https://github.com/certusone/wormhole/tree/main 实际部署配置信息参见 ...

  5. 视觉SLAM开源算法ORB-SLAM3 原理与代码解析

    来源:深蓝学院,文稿整理者:何常鑫,审核&修改:刘国庆 本文总结于上交感知与导航研究所科研助理--刘国庆关于[视觉SLAM开源算法ORB-SLAM3 原理与代码解析]的公开课. ORB-SLA ...

  6. ViBe算法原理和代码解析

    ViBe - a powerful technique for background detection and subtraction in video sequences 算法官网:http:// ...

  7. (需求实战_进阶_02)SSM集成RabbitMQ 关键代码讲解、开发、测试

    接上一篇:(企业内部需求实战_进阶_01)SSM集成RabbitMQ 关键代码讲解.开发.测试 https://gblfy.blog.csdn.net/article/details/10419730 ...

  8. java分页代码思路,记录--java 分页 思路 (hibernate关键代码)

    有时会脑袋蒙圈,记录下分页的思路 下面代码是hibernate的分页,其分页就是从第几条数据为起点,取几条数据.比如在mysql中的limit(5,10)取的就是第6条到第10条 在下面代码中的pag ...

  9. PX4代码解析(5)

    一.前言 我所讨论的PX4代码是基于v1.11版本,该版本与之前的版本有不少不同,其中一个比较大的区别在于新版本大部分用到了C++中模板,使得代码没有以前那么容易理解,因此我在后面介绍PX4的姿态估计 ...

最新文章

  1. TReader高速文本浏览器 1.0 发布
  2. SqlDataReader.GetSchemaTable
  3. AC日记——[SDOI2010]大陆争霸 洛谷 P3690
  4. 周志华《机器学习》课后习题(第五章):神经网络
  5. PHP Class中public,private,protected,static的区别
  6. linux sql 语句菜鸟,Linux安装mysql
  7. 贾跃亭致信债权人:将努力打工还债,请相信我!
  8. Wampserver_开启CURL
  9. 教你做一个最简版的倒计时计时器,新手也能秒懂。
  10. (零基础)如何使用python下载哔哩哔哩视频?
  11. [附源码]java毕业设计基于的高校学生考勤管理系统
  12. vue中搜索功能如何请求数据接口来实现关键字查询
  13. linux ubuntu系统忘记root密码的解决办法
  14. 霍尔开关传感器的选型
  15. 微软Windows Phone卷土归来
  16. 计算机芯片的主要用途,汽车电脑芯片30343的主要作用是什么??
  17. VPython三维仿真(NO.1) VPython安装与开发环境
  18. 除权除息和复权复息的内容总结
  19. 新手(小白)如何使用阿里云服务器搭建FTP服务?
  20. Linux系统上如何禁用 USB 存储

热门文章

  1. Kooboo完全介绍二:创建第一个Kooboo站点
  2. linux yum卸载安装记录
  3. awvs 超详细中文手册
  4. 重点:机器学习总结之各算法常用包和函数
  5. 如何在网页中实现快速地粘贴复制代码
  6. [jzoj 4230] 淬炼神体{ 0/1分数规划+二分答案}
  7. MySQL的锁与锁机制----锁分类
  8. Windows桌面壁纸
  9. 条条大路通罗马:mPaaS 新增体验入口
  10. Gremlin学习--图的汇总与分组统计计算