Spark word2vec: a detailed source-code walkthrough

  • A brief introduction to Spark word2vec
    • Source-code walkthrough of the skip-gram + hierarchical softmax version
    • For the underlying theory, only the hierarchical-softmax (Huffman-tree) skip-gram part is needed
    • Source-code walkthrough of the skip-gram negative-sampling version

A brief introduction to Spark word2vec

Word2Vec creates vector representation of words in a text corpus.
The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary.
The vector representation can be used as features in natural language processing and machine learning algorithms.
We used the skip-gram model in our implementation and the hierarchical softmax method to train the model. The variable names in the implementation match the original C implementation.
For the original C implementation, see https://code.google.com/p/word2vec/
For research papers, see
Efficient Estimation of Word Representations in Vector Space
and
Distributed Representations of Words and Phrases and their Compositionality.
In short: the word2vec algorithm builds a vocabulary from the corpus and then learns a vector representation for each word in it, which can be used as features in NLP and machine-learning pipelines. Spark MLlib implements only the skip-gram model, trained with hierarchical softmax, and the implementation deliberately mirrors the original word2vec C code, down to the variable names.
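Before reading the source, it helps to see how the API is used. Below is a minimal usage sketch (assuming `sc` is an existing SparkContext; the corpus path is hypothetical):

import org.apache.spark.mllib.feature.Word2Vec

// Each RDD element is one tokenized sentence (an Iterable[String]).
val sentences = sc.textFile("hdfs:///path/to/corpus.txt").map(_.split(" ").toSeq)

val model = new Word2Vec()
  .setVectorSize(100)     // dimensionality of the learned vectors
  .setMinCount(5)         // drop words that occur fewer than 5 times
  .setWindowSize(5)       // context window of +-5 words
  .setNumIterations(1)
  .setNumPartitions(1)
  .fit(sentences)

val vector   = model.transform("spark")            // vector for one word
val synonyms = model.findSynonyms("spark", 10)     // top-10 (word, cosineSimilarity) pairs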

Source-code walkthrough of the skip-gram + hierarchical softmax version

package org.apache.spark.mllib.feature
import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
 *  Entry in vocabulary — the case class describing each vocabulary entry
 */
private case class VocabWord(
  var word: String,        // the word itself
  var cn: Int,             // frequency count
  var point: Array[Int],   // the inner nodes on the path from the root to this word's leaf in the Huffman tree
  var code: Array[Int],    // the word's Huffman code
  var codeLen: Int         // length of the code, i.e. how many nodes the path to this leaf passes through
)
This article covers only the skip-gram + hierarchical softmax part, following the original C implementation: https://code.google.com/p/word2vec/
and the two papers: Efficient Estimation of Word Representations in Vector Space & Distributed Representations of Words and Phrases and their Compositionality
@Since("1.1.0")
class Word2Vec extends Serializable with Logging {
//默认参数private var vectorSize = 100  //训练vector的长度private var learningRate = 0.025  //训练时的学习率private var numPartitions = 1   //分区数private var numIterations = 1   //迭代次数private var seed = Utils.random.nextLong()  //随机种子private var minCount = 5   //词的最小出现频次private var maxSentenceLength = 1000  //句子的长度//如果大于maxSentenceLength 句子的长度,将会截断为多个块。/*** Sets the maximum length (in words) of each sentence in the input data.* Any sentence longer than this threshold will be divided into chunks of* up to `maxSentenceLength` size (default: 1000)*/@Since("2.0.0")def setMaxSentenceLength(maxSentenceLength: Int): this.type = {require(maxSentenceLength > 0,s"Maximum length of sentences must be positive but got ${maxSentenceLength}")this.maxSentenceLength = maxSentenceLengththis}/*** Sets vector size (default: 100).*/@Since("1.1.0")def setVectorSize(vectorSize: Int): this.type = {require(vectorSize > 0,s"vector size must be positive but got ${vectorSize}")this.vectorSize = vectorSizethis}/*** Sets initial learning rate (default: 0.025).*/@Since("1.1.0")def setLearningRate(learningRate: Double): this.type = {require(learningRate > 0,s"Initial learning rate must be positive but got ${learningRate}")this.learningRate = learningRatethis}/*** Sets number of partitions (default: 1). Use a small number for accuracy. * //设置少数分区有利于准确性*/@Since("1.1.0")def setNumPartitions(numPartitions: Int): this.type = {require(numPartitions > 0,s"Number of partitions must be positive but got ${numPartitions}")this.numPartitions = numPartitionsthis}/*** Sets number of iterations (default: 1), which should be smaller than or equal to number of * partitions. *  //设置迭代次数,要小于或者等于分区数*/@Since("1.1.0")def setNumIterations(numIterations: Int): this.type = {require(numIterations >= 0,s"Number of iterations must be nonnegative but got ${numIterations}")this.numIterations = numIterationsthis}/*** Sets random seed (default: a random long integer).*/@Since("1.1.0")def setSeed(seed: Long): this.type = {this.seed = seedthis}/*** Sets the window of words (default: 5) * //根据单个文本的长度合理设置,目前针对于标题40个字,设置为5*/@Since("1.6.0")def setWindowSize(window: Int): this.type = {require(window > 0,s"Window of words must be positive but got ${window}")this.window = windowthis}/*** Sets minCount, the minimum number of times a token must appear to be included in the word2vec* model's vocabulary (default: 5).* //根据文本的词的频次分布设置,保证覆盖大多数的文本。*/@Since("1.3.0")def setMinCount(minCount: Int): this.type = {require(minCount >= 0,s"Minimum number of times must be nonnegative but got ${minCount}")this.minCount = minCountthis}private val EXP_TABLE_SIZE = 1000private val MAX_EXP = 6private val MAX_CODE_LENGTH = 40/** context words from [-window, window] */  //滑动窗口以中心词的左右各+-window选词。private var window = 5private var trainWordsCount = 0Lprivate var vocabSize = 0
A note on the transient keyword:
An object can be serialized as long as its class implements Serializable; Java then handles the whole process automatically, and every field of the class is serialized along with the object.
In practice, however, some fields should be serialized while others should not. For example, if an object carries sensitive information (passwords, bank card numbers), we may not want it transmitted during network operations (or written into local serialized caches); marking those fields with the transient keyword excludes them. In other words, such a field lives only in the caller's memory and is never persisted to disk.
In short, transient is a convenience: implement Serializable, prefix the fields that should not be serialized with transient, and they will simply be skipped when the object is serialized.
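A minimal Scala sketch of the idea (class and field names are illustrative, not from the Spark source):

class UserSession(val userName: String) extends Serializable {
  // serialized together with the object
  var lastLogin: Long = System.currentTimeMillis()
  // skipped during serialization; comes back as null after deserialization
  @transient var password: String = null
}

In the Word2Vec class below, vocab and vocabHash are marked @transient presumably so that the driver-side vocabulary is never dragged into serialized task closures; the executors receive it through the explicit broadcast variables bcVocab and bcVocabHash instead.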
@transient private var vocab: Array[VocabWord] = null
@transient private var vocabHash = mutable.HashMap.empty[String, Int]
*********************************************************************************************************************
For reference, the corresponding entry point from org.apache.spark.ml.feature.Word2Vec#fit:
override def fit(dataset: Dataset[_]): Word2VecModel = {
  transformSchema(dataset.schema, logging = true)
  val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
  val wordVectors = new feature.Word2Vec()
    .setLearningRate($(stepSize))
    .setMinCount($(minCount))
    .setNumIterations($(maxIter))
    .setNumPartitions($(numPartitions))
    .setSeed($(seed))
    .setVectorSize($(vectorSize))
    .setWindowSize($(windowSize))
    .setMaxSentenceLength($(maxSentenceLength))
    .fit(input)
  copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
}
*********************************************************************************************************************
//dataset来自上面的input,里面是:Seq[String]private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {  //构建每个词的类val words = dataset.flatMap(x => x)  //把所有的词压平,统计词频vocab = words.map(w => (w, 1)).reduceByKey(_ + _).filter(_._2 >= minCount)  //过滤词频大于minCount的词.map(x => VocabWord(x._1,x._2,new Array[Int](MAX_CODE_LENGTH),new Array[Int](MAX_CODE_LENGTH),0)).collect().sortWith((a, b) => a.cn > b.cn)  //按频数从大到小排序vocabSize = vocab.lengthrequire(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +"the setting of minCount, which could be large enough to remove all your words in sentences.")var a = 0while (a < vocabSize) {vocabHash += vocab(a).word -> a   //@transient private var vocabHash = mutable.HashMap.empty[String, Int],【词,词频】  生成hashMap(K:word,V:a)--> 对词典中所有元素进行映射,方便查找trainWordsCount += vocab(a).cn    //训练词的个数统计a += 1}logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")}//创建sigmoid函数查询表private def createExpTable(): Array[Float] = {val expTable = new Array[Float](EXP_TABLE_SIZE)var i = 0while (i < EXP_TABLE_SIZE) {val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)expTable(i) = (tmp / (tmp + 1.0)).toFloati += 1}expTable}//构造哈夫曼树private def createBinaryTree(): Unit = {val count = new Array[Long](vocabSize * 2 + 1)  //二叉树中所有的结点val binary = new Array[Int](vocabSize * 2 + 1)  //设置每个结点的Huffman编码:左1,右0val parentNode = new Array[Int](vocabSize * 2 + 1)  //存储每个结点的父节点val code = new Array[Int](MAX_CODE_LENGTH)  //存储每个叶子结点的Huffman编码val point = new Array[Int](MAX_CODE_LENGTH)  //存储每个叶子结点的路径(经历过哪些结点)var a = 0while (a < vocabSize) { //节点 0~vocabSize-1  赋值为该节点词的频次  左边都是叶子结点count(a) = vocab(a).cna += 1}while (a < 2 * vocabSize) {  //节点 vocabSize~2*vocabSize-1  赋值为1e9  右边都是父节点count(a) = 1e9.toInta += 1}var pos1 = vocabSize - 1var pos2 = vocabSize//min1i和min2i是左右节点var min1i = 0var min2i = 0a = 0while (a < vocabSize - 1) {if (pos1 >= 0) {if (count(pos1) < count(pos2)) {min1i = pos1pos1 -= 1} else {min1i = pos2pos2 += 1}} else {min1i = pos2pos2 += 1}if (pos1 >= 0) {if (count(pos1) < count(pos2)) {min2i = pos1pos1 -= 1} else {min2i = pos2pos2 += 1}} else {min2i = pos2pos2 += 1}count(vocabSize + a) = count(min1i) + count(min2i)   //从三个点里面找到和最小的两个点parentNode(min1i) = vocabSize + a    //父节点parentNode(min2i) = vocabSize + a    //父节点binary(min2i) = 1          //定义右子树为1a += 1}// Now assign binary code to each vocabulary wordvar i = 0a = 0while (a < vocabSize) {var b = ai = 0while (b != vocabSize * 2 - 2) {  //哈弗曼树一共有2n-1个节点,所以vocabSize*2-2指的是根节点,遍历a二叉树路径上的每个节点,除了根节点code(i) = binary(b)         //第b个结点的Huffman编码是0 or 1point(i) = b                //存储路径,经过b结点i += 1b = parentNode(b)          //按照路径去找下一个节点,遍历b的下个节点}vocab(a).codeLen = i         //存储到达叶子结点a,要经过多少个结点vocab(a).point(0) = vocabSize - 2 //每个词的point(0)都是一样的为vocabSize-2,这个是根节点,在这里哈弗曼树已经建立完成了,point记录的是叶子结点a的从根节点以来的路径,因为哈弗曼树所有词的节点是叶子结点,从根节点到叶子节点上的路径都是中间节点如图一所示的,路径里面的节点都减了vocabSize,因为中间节点是vocabSize-1个,所以又都放在0到vocabSize-1的范围了。b = 0while (b < i) {        //遍历a二叉树路径上的每个节点vocab(a).code(i - b - 1) = code(b)   //根据上一步的结果,对节点a的哈夫曼编码赋值vocab(a).point(i - b) = point(b) - vocabSize  //根据上一步的结果,对节点a的路径节点进行赋值b += 1}a += 1    //下一个词}}/*** Computes the vector representation of each word in vocabulary.* @param dataset an RDD of sentences,*                each sentence is expressed as an iterable collection of words* @return a Word2VecModel*/@Since("1.1.0")def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {learnVocab(dataset)         //构建词汇类createBinaryTree()          //构建 Huffman 树val sc = dataset.contextval 
expTable = sc.broadcast(createExpTable())   //广播sigmod查询表val bcVocab = sc.broadcast(vocab)               //广播词汇类val bcVocabHash = sc.broadcast(vocabHash)       //广播词 词索引try {doFit(dataset, sc, expTable, bcVocab, bcVocabHash)  } finally {expTable.destroy(blocking = false)   //销毁广播变量bcVocab.destroy(blocking = false)bcVocabHash.destroy(blocking = false)}}private def doFit[S <: Iterable[String]](dataset: RDD[S], sc: SparkContext,expTable: Broadcast[Array[Float]],bcVocab: Broadcast[Array[VocabWord]],bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {// each partition is a collection of sentences,// will be translated into arrays of Index integerval sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>  //RDD[S] S为RDD里面最细粒度的数据结构,里面装的就是这个类型S的数据// Each sentence will map to 0 or more Array[Int]sentenceIter.flatMap { sentence =>// Sentence of words, some of which map to a word indexval wordIndexes = sentence.flatMap(bcVocabHash.value.get) // flatMap对句子中每个词得到index,得到每个句子每个词的index// break wordIndexes into trunks of maxSentenceLength when has morewordIndexes.grouped(maxSentenceLength).map(_.toArray)  //如果有的句子的长度大于1000,就给它分组为1000单位,并是array | wordIndexes是个Iterable[Int]格式利用grouped函数对其分组。wordIndexes.grouped(maxSentenceLength)返回的是:Iterator[Array[Int]]}}//val newSentences = sentences.repartition(numPartitions).cache()   //按照给定的分区数,进行重分区  并且全部cache//可以改为:
//todo change the storage level
val newSentences = sentences.repartition(numPartitions).cache()
//todo checkpoint the sentences RDD
newSentences.sparkContext.setCheckpointDir("hdfs://ns4/user/dd_edw/tmp.db/item_relationship/item_embedding/graph_embedding_rdwalk")
newSentences.checkpoint()
newSentences.count
bcVocabHash.destroy(blocking = false) //TODO no longer needed here, destroy it to free memory
//    val newSentences = sentences.repartition(numPartitions).persist(StorageLevel.MEMORY_AND_DISK_SER)
//    val newSentences = sentences.repartition(numPartitions).persist(StorageLevel.DISK_ONLY)val initRandom = new XORShiftRandom(seed)                         //if (vocabSize.toLong * vectorSize >= Int.MaxValue) {   //如果词汇量*词向量长度 大于或等于 INT最大值 就抛出异常throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +" to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +"which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.")}val syn0Global =Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)   //初始化叶子节点,分词向量随机设置初始值val syn1Global = new Array[Float](vocabSize * vectorSize)                                   //初始化非叶子结点,参数向量设置初始值为0val totalWordsCounts = numIterations * trainWordsCount + 1                                  //迭代次数*所有分词的个数 +1var alpha = learningRate                                                                    //学习率for (k <- 1 to numIterations) {   //开始迭代val bcSyn0Global = sc.broadcast(syn0Global)     val bcSyn1Global = sc.broadcast(syn1Global)val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount //已经迭代过的词数val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))val syn0Modify = new Array[Int](vocabSize)val syn1Modify = new Array[Int](vocabSize)/**def foldLeft[B](z: B)(op: (B, A) => B): B = {var result = zthis foreach (x => result = op(result, x))result}*/val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) { //{}里面是OP[具体操作],初始值是(bcSyn0Global.value, bcSyn1Global.value, 0L, 0L),然后在每个分区里面串行运行,x是case ((syn0, syn1, lastWordCount, wordCount), sentence),最终结果是:(syn0, syn1, lwc, wc) 和Z同种类型。最后的结果(syn0, syn1, lwc, wc)总是更新存在的。总是赋值给B类型。最后结果也是B,B就是(syn0, syn1, lwc, wc)类型的数据。iter每迭代一次sentence就会更新一次Bcase ((syn0, syn1, lastWordCount, wordCount), sentence) =>  //每个分区里面的每个sentencevar lwc = lastWordCount  //每次迭代的最新的var wc = wordCountif (wordCount - lastWordCount > 10000) { //当句子迭代10000个词的时候。每迭代10000词的时候就更新一下alphalwc = wordCount   //更改上次词数alpha = learningRate *(1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) /totalWordsCounts)   //随着wordCount变大,alpha变小if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001  //当小于learningRate * 0.0001时候,直接等于learningRate * 0.0001logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " +s"alpha = $alpha")}wc += sentence.length //wc就是上次的wordCount,一直累加句子的长度。var pos = 0while (pos < sentence.length) {  //开始迭代,一个sentence中的pos位置的词,pos从0开始val word = sentence(pos)val b = random.nextInt(window)  //b是window内的随机数// Train Skip-gramvar a = bwhile (a < window * 2 + 1 - b) {  //因为开始a = b ,从b开始到 window * 2 + 1 - b,也就是取pos词左右window - b 个词,迭代pos附近的窗口:window - bif (a != window) { //当a不是中心词val c = pos - window + a   //pos位置的词pos-(window - a)[真实位置]if (c >= 0 && c < sentence.length) {  //pos的左右位置迭代取值可能是负的或者超出句子长度,限定范围val lastWord = sentence(c)    //该词的indexval l1 = lastWord * vectorSize  //syn0的indexval neu1e = new Array[Float](vectorSize) //相当于公式里面的e,就是x的梯度迭代项// Hierarchical softmaxvar d = 0while (d < bcVocab.value(word).codeLen) {  //迭代中心词的路径哈夫曼二分类val inner = bcVocab.value(word).point(d)  //路径上节点indexval l2 = inner * vectorSize               //syn1对应的index// Propagate hidden -> output    blas.sdot函数解释:sdot(int n, float[] sx, int _sx_offset, int incx, float[] sy, int _sy_offset, int incy),结果是:sx .* sy,并且sx[_sx_offset,incx*n + _sx_offset],sy[_sy_offset,incy*n + _sy_offset]var f = 
blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)   //向量点乘,syn0 .* syn1 , syn0[l1,l1+1*vectorSize],syn1[l2,l2+1*vectorSize]if (f > -MAX_EXP && f < MAX_EXP) {                    val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toIntf = expTable.value(ind)   //索引到sigmod函数表的值val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat   //梯度blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)   //neu1e = g * syn1 + neu1e  blas.saxpy函数解释:saxpy(int n, float sa, float[] sx, int _sx_offset, int incx, float[] sy, int _sy_offset, int incy),结果是:sy= sa*sx+sy,并且sx[_sx_offset,_sx_offset+incx*n],sy[_sy_offset,_sy_offset+incy*n]blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)   //syn1 = g * syn0 + syn1syn1Modify(inner) += 1          //记录参数向量里面的点被更新次数}d += 1}blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)  //syn0 = 1.0f * neu1e + syn0   a的情况下,迭代完成中心词pos附近的一个词的参数向量和词向量syn0Modify(lastWord) += 1}}a += 1}pos += 1   //循环到这个句子的下一个中心词}(syn0, syn1, lwc, wc)}val syn0Local = model._1   //syn0 为叶子结点向量,即分词向量val syn1Local = model._2   //syn1 为非叶子结点向量,即参数向量// Only output modified vectors.   Iterator.tabulate函数: Creates an iterator producing the values of a given function over a range of integer values starting from 0.Iterator.tabulate(vocabSize) { index =>if (syn0Modify(index) > 0) {Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))} else {None}}.flatten ++ Iterator.tabulate(vocabSize) { index =>if (syn1Modify(index) > 0) {Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))} else {None}}.flatten    //得到n个词向量的结果,n-1个中间节点的向量结果,两个结果(index,array)拼接起来,并且中间参数节点向量的index 从vocabSize开始编号}val synAgg = partial.reduceByKey { case (v1, v2) =>   //注意partial是所有分区内部的结果,按照同样的index下的array进行聚合,直接把所有分区的结果暴力累加blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)v1}.collect()var i = 0while (i < synAgg.length) {  //分别得到分词向量和参数向量val index = synAgg(i)._1if (index < vocabSize) {Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)} else {Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)}i += 1}bcSyn0Global.destroy(false)bcSyn1Global.destroy(false)}newSentences.unpersist()val wordArray = vocab.map(_.word)new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)  //得到分词向量}/*** Computes the vector representation of each word in vocabulary (Java version).* @param dataset a JavaRDD of words* @return a Word2VecModel*/@Since("1.1.0")def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {fit(dataset.rdd.map(_.asScala))}
}/**
* Word2Vec model
* @param wordIndex maps each word to an index, which can retrieve the corresponding
*                  vector from wordVectors
* @param wordVectors array of length numWords * vectorSize, vector corresponding
*                    to the word mapped with index i can be retrieved by the slice
*                    (i * vectorSize, i * vectorSize + vectorSize)
*/
@Since("1.1.0")
class Word2VecModel private[spark] (private[spark] val wordIndex: Map[String, Int],private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable {private val numWords = wordIndex.size// vectorSize: Dimension of each word's vector.private val vectorSize = wordVectors.length / numWords// wordList: Ordered list of words obtained from wordIndex.private val wordList: Array[String] = {val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzipwl.toArray}// wordVecNorms: Array of length numWords, each value being the Euclidean norm//               of the wordVector.   长度为numWords的数组,每个值都是wordVector的欧几里得范数。private val wordVecNorms: Array[Float] = {val wordVecNorms = new Array[Float](numWords)var i = 0while (i < numWords) {val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)  i += 1}wordVecNorms}@Since("1.5.0")def this(model: Map[String, Array[Float]]) = {this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))}override protected def formatVersion = "1.0"@Since("1.4.0")def save(sc: SparkContext, path: String): Unit = {Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)}/*** Transforms a word to its vector representation* @param word a word* @return vector representation of word*/@Since("1.1.0")def transform(word: String): Vector = {wordIndex.get(word) match {case Some(ind) =>val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)Vectors.dense(vec.map(_.toDouble))case None =>throw new IllegalStateException(s"$word not in vocabulary")}}/*** Find synonyms of a word; do not include the word itself in results.* @param word a word* @param num number of synonyms to find* @return array of (word, cosineSimilarity)*/@Since("1.1.0")def findSynonyms(word: String, num: Int): Array[(String, Double)] = {val vector = transform(word)findSynonyms(vector, num, Some(word))}/*** Find synonyms of the vector representation of a word, possibly* including any words in the model vocabulary whose vector respresentation* is the supplied vector.* @param vector vector representation of a word* @param num number of synonyms to find* @return array of (word, cosineSimilarity)*/@Since("1.1.0")def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {findSynonyms(vector, num, None)}/*** Find synonyms of the vector representation of a word, rejecting* words identical to the value of wordOpt, if one is supplied.* @param vector vector representation of a word* @param num number of synonyms to find* @param wordOpt optionally, a word to reject from the results list* @return array of (word, cosineSimilarity)*/private def findSynonyms(vector: Vector,   //需要找的这个词向量在所有词向量里面的相似结果num: Int,         //需要找的TOPNwordOpt: Option[String]): Array[(String, Double)] = {  //返回形式是 (词,相似度)require(num > 0, "Number of similar words should > 0")val fVector = vector.toArray.map(_.toFloat)  //由double类型变为Float类型,可以节省存储空间val cosineVec = new Array[Float](numWords)   //vector 与每个词向量直接的cosine相似度值val alpha: Float = 1val beta: Float = 0// Normalize input vector before blas.sgemv to avoid Inf value  这样对归一化后的结果避免了无穷大的异常出现val vecNorm = blas.snrm2(vectorSize, fVector, 1)  //blas.snrm2函数:SNRM2 := sqrt( x'*x ).通过函数名称返回向量的欧几里得范数if (vecNorm != 0.0f) {blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)  //blas.sscal函数:scales a vector by a constant, uses unrolled loops for increment equal to 1,sscal(int n, float sa, float[] sx, int _sx_offset, int incx) 结果是向量所元素sx*sa,其中 sx[_sx_offset,_sx_offset+incx*n]}blas.sgemv("T", vectorSize, numWords, 
alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) //wordVectors * fVector   //wordVectors 500万*100维  fVector 100*1维  cosineVec 500万*1维
//sgemv(java.lang.String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy)  y := alpha*A*x + beta*y,   or   y := alpha*A'*x + beta*yvar i = 0while (i < numWords) {val norm = wordVecNorms(i)   //每个词向量的欧几里得范数if (norm == 0.0f) {cosineVec(i) = 0.0f} else {cosineVec(i) /= norm  //之前fVector已经除了vecNorm,后面只需各个词向量除以自己的范式就行了}i += 1}//堆排序取数val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2))var j = 0while (j < numWords) {pq += Tuple2(wordList(j), cosineVec(j))j += 1}val scored = pq.toSeq.sortBy(-_._2)val filtered = wordOpt match {case Some(w) => scored.filter(tup => w != tup._1)case None => scored}filtered.take(num).map { case (word, score) => (word, score.toDouble) }.toArray}/*** Returns a map of words to their vector representations.*/@Since("1.2.0")def getVectors: Map[String, Array[Float]] = {wordIndex.map { case (word, ind) =>(word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))}}}@Since("1.4.0")
object Word2VecModel extends Loader[Word2VecModel] {private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {model.keys.zipWithIndex.toMap}private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {require(model.nonEmpty, "Word2VecMap should be non-empty")val (vectorSize, numWords) = (model.head._2.length, model.size)val wordList = model.keys.toArrayval wordVectors = new Array[Float](vectorSize * numWords)var i = 0while (i < numWords) {Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)i += 1}wordVectors}private object SaveLoadV1_0 {val formatVersionV1_0 = "1.0"val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"case class Data(word: String, vector: Array[Float])def load(sc: SparkContext, path: String): Word2VecModel = {val spark = SparkSession.builder().sparkContext(sc).getOrCreate()val dataFrame = spark.read.parquet(Loader.dataPath(path))// Check schema explicitly since erasure makes it hard to use match-case for checking.Loader.checkSchema[Data](dataFrame.schema)val dataArray = dataFrame.select("word", "vector").collect()val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMapnew Word2VecModel(word2VecMap)}def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {val spark = SparkSession.builder().sparkContext(sc).getOrCreate()val vectorSize = model.values.head.lengthval numWords = model.sizeval metadata = compact(render(("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))// We want to partition the model in partitions smaller than// spark.kryoserializer.buffer.maxval bufferSize = Utils.byteStringAsBytes(spark.conf.get("spark.kryoserializer.buffer.max", "64m"))// We calculate the approximate size of the model// We only calculate the array size, considering an// average string size of 15 bytes, the formula is:// (floatSize * vectorSize + 15) * numWordsval approxSize = (4L * vectorSize + 15) * numWordsval nPartitions = ((approxSize / bufferSize) + 1).toIntval dataArray = model.toSeq.map { case (w, v) => Data(w, v) }spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path))}}@Since("1.4.0")override def load(sc: SparkContext, path: String): Word2VecModel = {val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)implicit val formats = DefaultFormatsval expectedVectorSize = (metadata \ "vectorSize").extract[Int]val expectedNumWords = (metadata \ "numWords").extract[Int]val classNameV1_0 = SaveLoadV1_0.classNameV1_0(loadedClassName, loadedVersion) match {case (classNameV1_0, "1.0") =>val model = SaveLoadV1_0.load(sc, path)val vectorSize = model.getVectors.values.head.lengthval numWords = model.getVectors.sizerequire(expectedVectorSize == vectorSize,s"Word2VecModel requires each word to be mapped to a vector of size " +s"$expectedVectorSize, got vector of size $vectorSize")require(expectedNumWords == numWords,s"Word2VecModel requires $expectedNumWords words, but got $numWords")modelcase _ => throw new Exception(s"Word2VecModel.load did not recognize model with (className, format version):" +s"($loadedClassName, $loadedVersion).  Supported:\n" +s"  ($classNameV1_0, 1.0)")}}
}

Word2vec theory: only the hierarchical-softmax (Huffman-tree) skip-gram part is needed

Recommended reading on the theory: https://www.cnblogs.com/shixiangwan/p/7808249.html
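To tie the theory to the code above, this is the standard hierarchical-softmax skip-gram update (restated here, not quoted from the Spark source). For a word vector $\mathbf{v}$ (a row of syn0), the inner-node vectors $\boldsymbol{\theta}_j$ (rows of syn1) along the Huffman path of the predicted word, code bits $d_j \in \{0,1\}$, and learning rate $\alpha$:

$$ f = \sigma\!\left(\mathbf{v}^{\top}\boldsymbol{\theta}_j\right), \qquad g = \alpha\,(1 - d_j - f), \qquad \boldsymbol{\theta}_j \leftarrow \boldsymbol{\theta}_j + g\,\mathbf{v}, \qquad \mathbf{e} \leftarrow \mathbf{e} + g\,\boldsymbol{\theta}_j, \qquad \mathbf{v} \leftarrow \mathbf{v} + \mathbf{e} $$

This is exactly what the inner while loop over codeLen computes: blas.sdot gives $\mathbf{v}^{\top}\boldsymbol{\theta}_j$, the expTable lookup gives $\sigma(\cdot)$, the two blas.saxpy calls update neu1e ($\mathbf{e}$) and syn1 ($\boldsymbol{\theta}_j$), and the final saxpy after the loop adds neu1e back into syn0.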
In practice, Spark word2vec has the following issues:

  1. When the number of iterations or partitions is too large, the vector values overflow to Infinity.
  2. Too many partitions during training lowers accuracy.
  3. Memory consumption is excessive, because everything is cache()d.
  4. The Huffman-tree approach is time-consuming. All of these issues have since been worked through and fixed, as described below.

Problem 1:
How to fix the first problem:
From the source code above we can see that Spark's skip-gram implementation directly reproduces the original word2vec C code.
The final word vectors are the sum of the results from every partition and every iteration, so as the number of iterations and partitions grows, the vector values become abnormal. The fix is to normalize the aggregation: accumulate the per-iteration, per-partition results and then average them.
The concrete change is as follows.
Replace this code:

val synAgg = partial.reduceByKey { case (v1, v2) =>
  blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
  v1
}

with:

// The idea: average the aggregated vectors — count how many partition results contribute to
// each vector, sum them, then divide by the count.
// https://github.com/apache/spark/pull/26722
// SPARK-24666: do normalization for aggregating weights from partitions.
// Original Word2Vec either single-thread or multi-thread which do Hogwild-style aggregation.
// Our approach needs to do extra normalization, otherwise adding weights continuously may
// cause overflow on float and lead to infinity/-infinity weights.
val synAgg = partial.mapPartitions { iter =>
  iter.map { case (id, vec) =>
    (id, (vec, 1))
  }
}.reduceByKey { case ((v1, count1), (v2, count2)) =>
  blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
  (v1, count1 + count2)
}.map { case (id, (vec, count)) =>
  blas.sscal(vectorSize, 1.0f / count, vec, 1)
  (id, vec)
}

That's it: the problem of the vector values blowing up is completely solved.
Problem 2:
Partitioning is a divide-and-conquer approach: mapPartitions lets each partition maintain its own copy of the parameters, and afterwards the per-partition parameters are aggregated. The iterations therefore accumulate updates in parallel rather than processing the samples strictly sequentially. With too many partitions, each partition sees too little data, which lowers accuracy; that said, word2vec quality also depends heavily on tokenization quality and corpus size.
To reduce memory pressure, the call bcVocabHash.destroy(blocking = false) (the broadcast is no longer needed once the sentences have been indexed) can be moved earlier so the broadcast is released as soon as possible; a serialized storage level for the cached sentences also helps, as sketched below.
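For the memory issue (problem 3), one option already hinted at by the commented-out lines in doFit is to persist the repartitioned sentences with a serialized, disk-spillable storage level instead of the default cache():

import org.apache.spark.storage.StorageLevel

// cache() means MEMORY_ONLY; MEMORY_AND_DISK_SER stores the blocks serialized and spills
// them to disk when memory is tight, trading some CPU for a much smaller footprint.
val newSentences = sentences.repartition(numPartitions).persist(StorageLevel.MEMORY_AND_DISK_SER)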

Source-code walkthrough of the skip-gram negative-sampling version:

Reference code: the SGNS implementation on GitHub (the link appears in the comment at the top of the code below).

/** Licensed to the Apache Software Foundation (ASF) under one or more* contributor license agreements.  See the NOTICE file distributed with* this work for additional information regarding copyright ownership.* The ASF licenses this file to You under the Apache License, Version 2.0* (the "License"); you may not use this file except in compliance with* the License.  You may obtain a copy of the License at**    http://www.apache.org/licenses/LICENSE-2.0** Unless required by applicable law or agreed to in writing, software* distributed under the License is distributed on an "AS IS" BASIS,* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.* See the License for the specific language governing permissions and* limitations under the License.*/package org.apache.spark.ml.feature.sgns// https://github.com/shubhamchopra/spark/tree/Word2VecSGNS/mllib/src/main/scala/org/apache/spark/ml/featureimport com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.feature
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandomobject Word2VecCBOWSolver extends Logging {// learning rate is updated for every batch of size batchSizeprivate val batchSize = 10000// power to raise the unigram distribution withprivate val power = 0.75//  private val EXP_TABLE_SIZE = 1000private val EXP_TABLE_SIZE = 10000private val MAX_EXP = 6case class Vocabulary(totalWordCount: Long,vocabMap: Map[String, Int],unigramTable: Array[Int],samplingTable: Array[Float])/*** This method implements Word2Vec Continuous Bag Of Words based implementation using* negative sampling optimization, using BLAS for vectorizing operations where applicable.* The algorithm is parallelized in the same way as the skip-gram based estimation.* We divide input data into N equally sized random partitions.* We then generate initial weights and broadcast them to the N partitions. This way* all the partitions start with the same initial weights. We then run N independent* estimations that each estimate a model on a partition. The weights learned* from each of the N models are averaged and rebroadcast the weights.* This process is repeated `maxIter` number of times.** @param input A RDD of strings. Each string would be considered a sentence.* @return Estimated word2vec model*/def fit[S <: Iterable[String]](
//                                spark:SparkSession, //TODOword2Vec: Word2Vec,skipGramMode: Boolean,input: RDD[S]): feature.Word2VecModel = {val negativeSamples = word2Vec.getNegativeSamples //负采样的词数目val sample = word2Vec.getSample //针对高频词的衰减系数/*** totalWordCount   Long类型的实数* vocabMap         Map[String, Int]长度为所有词的大小* unigramTable     Array[Int]   长度为设置的采样table的长度* samplingTable    Array[Float] 所有词的长度大小,采样表**/val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =generateVocab(input, word2Vec.getMinCount, sample, word2Vec.getUnigramTableSize)val vocabSize = vocabMap.size//TODO 确认负采样的词的数目要小于整个词汇的大小assert(negativeSamples < vocabSize, s"Vocab size ($vocabSize) cannot be smaller" +s" than negative samples($negativeSamples)")val seed = word2Vec.getSeedval initRandom = new XORShiftRandom(seed)val vectorSize = word2Vec.getVectorSizeval syn0Global = Array.fill(vocabSize * vectorSize)(initRandom.nextFloat - 0.5f)  //随机初始化的向量作为输入val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)val sc = input.context//以下是广播的数据量val vocabMapBroadcast = sc.broadcast(vocabMap)val unigramTableBroadcast = sc.broadcast(uniTable)val sampleTableBroadcast = sc.broadcast(sampleTable)val expTable = sc.broadcast(createExpTable())val windowSize = word2Vec.getWindowSize  //滑动窗口大小val maxSentenceLength = word2Vec.getMaxSentenceLength //最大的句子长度val numPartitions = word2Vec.getNumPartitions //运行时的分区数目/*    import spark.implicits._val xxx = input.map(x => x.toSeq).toDF("seq")sc.parallelize(vocabMap.map(x => (x._1,x._2)).toArray,100).toDF("skuid","index")*///就是把每个句子里的的word转换为index Int格式val digitSentences = input.flatMap { sentence =>val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get) // 针对每个句子里面的词,得到该词的index,在flatMap里面应用这个函数[vocabMapBroadcast.value.get],得到每个//grouped按照maxSentenceLength分组,把wordIndexes按照最大长度分成几个部分,每个部分的长度不超过 maxSentenceLengthwordIndexes.grouped(maxSentenceLength).map(_.toArray)}.repartition(numPartitions)  //指定分区数目.cache() //TODO cache 可以修改存储方式,减少内存digitSentences.sparkContext.setCheckpointDir("hdfs://ns4/user/dd_edw/tmp.db/item_relationship/item_embedding/graph_embedding_rdwalk")digitSentences.checkpoint()digitSentences.count //TODO actionvocabMapBroadcast.destroy()  //TODO 用完了直接销毁val learningRate = word2Vec.getStepSize  //学习率设置的是0.025Dval wordsPerPartition = totalWordCount / numPartitions   //每个partitions的数据量(以总的词数目不去重为准)logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount: $totalWordCount")val maxIter = word2Vec.getMaxIterfor {iteration <- 1 to maxIter} {   //迭代次数,每次迭代里面有若干个分组batch运行,注意是在每个partition里面的。logInfo(s"Starting iteration: $iteration")val iterationStartTime = System.nanoTime()val syn0bc = sc.broadcast(syn0Global)   //广播词参数val syn1bc = sc.broadcast(syn1Global)val partialFits = digitSentences.mapPartitionsWithIndex { case (i_, iter) =>logInfo(s"Iteration: $iteration, Partition: $i_")val random = new XORShiftRandom(seed ^ ((i_ + 1) << 16) ^ ((-iteration - 1) << 8))val contextWordPairs = iter.flatMap { s => //iter为一个分区里的所有句子迭代器  s为遍历一个分区里的句子val doSample = sample > Double.MinPositiveValue   //是否有采样系数 boolean类型/***得到的是Iterator[(Seq[Int], Int)],其中Seq[Int]是后面Int的上下窗口词集合,后者Int是中心词index*/generateContextWordPairs(s,windowSize,doSample,sampleTableBroadcast.value,random)}//把所有的中心词对应的窗口词集合分组批次,按照batchSize大小  batchSize大小默认是10000val groupedBatches = contextWordPairs.grouped(batchSize)//负采样标签,negativeSamples负采样词的个数val negLabels = 1.0f +: Array.fill(negativeSamples)(0.0f)val syn0 = syn0bc.valueval syn1 = syn1bc.valueval unigramTable = unigramTableBroadcast.value 
//长度大小为负采样表的长度,负采样的table,目前设置的长度为2千万// initialize intermediate arraysval contextVec = new Array[Float](vectorSize)val l2Vectors = new Array[Float](vectorSize * (negativeSamples + 1))val gb = new Array[Float](negativeSamples + 1)val neu1e = new Array[Float](vectorSize)val wordIndices = new Array[Int](negativeSamples + 1)val time = System.nanoTimevar batchTime = System.nanoTimevar idx = -1Lfor (batch <- groupedBatches) { // 一个batch就是Seq[(Seq[Int], Int)]集合idx = idx + 1 //每迭代一个batch就会idx增加1val wordRatio = //会随着idx和iteration的增大而增大,越往后迭代得到的wordRatio越小idx.toFloat * batchSize / (maxIter * (wordsPerPartition.toFloat + 1)) +((iteration - 1).toFloat / maxIter)//学习率会随着wordRatio的增大而减小,但是不会小于设置的learningRate * 0.0001,越往后迭代学习率会越小val alpha = math.max(learningRate * 0.0001, learningRate * (1 - wordRatio)).toFloatif(idx % 10 == 0 && idx > 0) { //TODO 每迭代10个batch会做一次汇总,对当前运行过得batch的时间进行统计。logInfo(s"Partition: $i_, wordRatio = $wordRatio, alpha = $alpha") //打印各个分区index,学习率等数据val wordCount = batchSize * idx   //本分区总共已经迭代的中心词个数val timeTaken = (System.nanoTime - time) / 1e6  // 对每个分区定时器,计算到这步所用的时间val batchWordCount = 10 * batchSize  //val currentBatchTime = (System.nanoTime - batchTime) / 1e6  // 对每个分区定时器,计算到这步所用的时间batchTime = System.nanoTimelogDebug(s"Partition: $i_, Batch time: $currentBatchTime ms, batch speed: " +s"${batchWordCount / currentBatchTime * 1000} words/s")logDebug(s"Partition: $i_, Cumulative time: $timeTaken ms, cumulative speed: " +s"${wordCount / timeTaken * 1000} words/s")}val errors = for ((ids, word) <- batch) yield {  //遍历每个batch里面的内容val contexts = if (skipGramMode) { //如果是sg-ns模型ids.map(i => Seq(i))  //把每个上下窗口的词变成Seq集合} else {Seq(ids)  //Seq[Seq[int]]}val errs = for (contextIds <- contexts) yield {// initialize vectors to 0zeroVector(contextVec)zeroVector(l2Vectors)zeroVector(gb)zeroVector(neu1e)val scale = 1.0f / contextIds.length  //上下SKU个数 如果是sg-ns的话,一个词就是一个Seq,如果是CNOW-ns的话上下SKU集合是一个Seq// feed forward   前馈contextIds.foreach { c =>//blas.saxpy函数解释:saxpy(int n, float sa, float[] sx, int _sx_offset, int incx, float[] sy, int _sy_offset, int incy),// 结果是:sy= sa*sx+sy,并且sx[_sx_offset , _sx_offset + incx*n],sy[ _sy_offset , _sy_offset + incy*n],// 本语句的意思是:contextVec = scale*syn0 + contextVec 其中的位移范围是一个vectorSize的长度blas.saxpy(vectorSize, scale, syn0, c * vectorSize, 1, contextVec, 0, 1)}//word是指当前中心词,针对每个中心词采样negativeSamples个词generateNegativeSamples(random, word, unigramTable, negativeSamples, wordIndices)Iterator.range(0, wordIndices.length).foreach { i =>// copy(src: AnyRef, srcPos: Int, dest: AnyRef, destPos: Int, length: Int)// 把syn1复制到l2Vectors  l2Vectors是负采样词向量表Array.copy(syn1, vectorSize * wordIndices(i), l2Vectors, vectorSize * i, vectorSize)}// propagating hidden to output in batch  传播隐藏层到output层val rows = negativeSamples + 1val cols = vectorSize//sgemv(trans: String, m: Int, n: Int, alpha: Float, a: Array[Float], _a_offset: Int, lda: Int, x: Array[Float],// _x_offset: Int, incx: Int, beta: Float, y: Array[Float], _y_offset: Int, incy: Int)// y := alpha*A*x + beta*y,   or   y := alpha*A'*x + beta*y  |||||||====|||||||  gb = 1.0f * l2Vectors * contextVec + 0.0f * gbblas.sgemv("T", cols, rows, 1.0f, l2Vectors, 0, cols, contextVec, 0, 1, 0.0f, gb, 0, 1)Iterator.range(0, negativeSamples + 1).foreach { i =>if (gb(i) > -MAX_EXP && gb(i) < MAX_EXP) {val ind = ((gb(i) + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt  //TODO 修改sigmod函数的调用采用查表方式val v = expTable.value(ind)//                  val v = 1.0f / (1 + math.exp(-gb(i)).toFloat)  //sigmod计算,其实可以进行查表运算// computing error gradientval 
err = (negLabels(i) - v) * alpha  //梯度// update hidden -> output layer, syn1// syn1 = err * contextVec + syn1blas.saxpy(vectorSize, err, contextVec, 0, 1, syn1, wordIndices(i) * vectorSize, 1)// update for word vectors// neu1e = err * l2Vectors + neu1eblas.saxpy(vectorSize, err, l2Vectors, i * vectorSize, 1, neu1e, 0, 1)gb.update(i, err)} else {gb.update(i, 0.0f)}}// update input -> hidden layer, syn0contextIds.foreach { i =>// syn0 = 1.0f * neu1e + syn0blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, i * vectorSize, 1)}gb.map(math.abs).sum / alpha}errs.sum  //每个中心词迭代后的误差}logInfo(s"Partition: $i_, Average Batch Error = ${errors.sum / batchSize}")}Iterator.tabulate(vocabSize) { index =>(index, syn0.slice(index * vectorSize, (index + 1) * vectorSize))} ++ Iterator.tabulate(vocabSize) { index =>(vocabSize + index, syn1.slice(index * vectorSize, (index + 1) * vectorSize))}}val aggedMatrices = partialFits.reduceByKey { case (v1, v2) =>blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)v1/*val aggedMatrices =  partialFits.mapPartitions { iter =>iter.map { case (id, vec) =>(id, (vec, 1))}}.reduceByKey { case ((v1, count1), (v2, count2)) =>blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)(v1, count1 + count2)}.map { case (id, (vec, count)) =>blas.sscal(vectorSize, 1.0f / count, vec, 1)(id, vec)*/}.collect()val norm = 1.0f / numPartitionsaggedMatrices.foreach {case (index, v) =>blas.sscal(v.length, norm, v, 0, 1)if (index < vocabSize) {Array.copy(v, 0, syn0Global, index * vectorSize, vectorSize)} else {Array.copy(v, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)}}syn0bc.destroy(false)syn1bc.destroy(false)val timePerIteration = (System.nanoTime() - iterationStartTime) / 1e6logInfo(s"Total time taken per iteration: ${timePerIteration} ms")}digitSentences.unpersist()
//    vocabMapBroadcast.destroy()unigramTableBroadcast.destroy()sampleTableBroadcast.destroy()new feature.Word2VecModel(vocabMap, syn0Global)}/*** Similar to InitUnigramTable in the original code.   跟源码一样的操作*/private def generateUnigramTable(normalizedWeights: Array[Double], tableSize: Int): Array[Int] = {val table = new Array[Int](tableSize)var index = 0var wordId = 0while (index < table.length) { //遍历tabletable.update(index, wordId)//[index.toFloat / table.length]这个值最大值是1,normalizedWeights数组最大值也是1,强制把normalizedWeights分成tableSize个区间,// 按照table索引的刻度进行分割,最终得到的table里面的每个元素是相邻的是wordID正是table一个刻度所包含的wordif (index.toFloat / table.length >= normalizedWeights(wordId)) {wordId = math.min(normalizedWeights.length - 1, wordId + 1)}index += 1}table}private def createExpTable(): Array[Float] = {val expTable = new Array[Float](EXP_TABLE_SIZE)var i = 0while (i < EXP_TABLE_SIZE) {val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)expTable(i) = (tmp / (tmp + 1.0)).toFloati += 1}expTable}/**** @param input* @param minCount  最小的频次限制* @param sample   sample的系数设置* @param unigramTableSize  采样的表格大小* @tparam S* @return*/private def generateVocab[S <: Iterable[String]](input: RDD[S],minCount: Int,sample: Double,unigramTableSize: Int): Vocabulary = {val sc = input.contextval words = input.flatMap(x => x)//按照词的频次进行排序,并且zip上indexval sortedWordCounts = words.map(w => (w, 1L)).reduceByKey(_ + _).filter{case (w, c) => c >= minCount}.collect().sortWith{case ((w1, c1), (w2, c2)) => c1 > c2}.zipWithIndexval totalWordCount = sortedWordCounts.map(_._1._2).sum   //所有词的总和//每个词的对应的index索引val vocabMap = sortedWordCounts.map{case ((w, c), i) =>w -> i}.toMap//所有词大小的 抽样表val samplingTable = new Array[Float](vocabMap.size)if (sample > Double.MinPositiveValue) {  // sample 大于double最小值sortedWordCounts.foreach { case ((w, c), i) =>val samplingRatio = sample * totalWordCount / c  //采样的概率samplingTable.update(i, (math.sqrt(samplingRatio) + samplingRatio).toFloat)}}val weights = sortedWordCounts.map{ case((_, x), _) => scala.math.pow(x, power)} //对每个词的频次进行 f^0.75次方val totalWeight = weights.sum   //所有的权重和//scanLeft:扫描,即对某个集合的所有元素做fold操作,但是会把产生的所有中间结果放置于一个集合中保存 ,跟foldLeft还是有区别的,foldLeft不存储中间结果//TODO normalizedCumWeights数组的长度为所有词的个数大小val normalizedCumWeights = weights.scanLeft(0.0)(_ + _).tail.map(x => x / totalWeight) //数组的tail操作,除了头全是尾部//Unigram table size. The unigram table is used to generate negative samples.    
本程序设置的2千万val unigramTable = generateUnigramTable(normalizedCumWeights, unigramTableSize)/*** totalWordCount   Long类型的实数* vocabMap         Map[String, Int]长度为所有词的大小* unigramTable     Array[Int]   长度为设置的采样table的长度* samplingTable    Array[Float] 所有词的长度大小,采样表**/Vocabulary(totalWordCount, vocabMap, unigramTable, samplingTable)}private def zeroVector(v: Array[Float]): Unit = {var i = 0while(i < v.length) {v.update(i, 0.0f)i+= 1}}/**** @param sentence* @param window* @param doSample* @param samplingTable* @param random* @return   生成中心词的上下文 词对*/private def generateContextWordPairs(sentence: Array[Int],window: Int,doSample: Boolean,samplingTable: Array[Float],random: XORShiftRandom): Iterator[(Seq[Int], Int)] = {val reducedSentence = if (doSample) {sentence.filter(i => samplingTable(i) > random.nextFloat)  //每个句子里,随机选取一些词} else {sentence}val sentenceLength = reducedSentence.lengthIterator.range(0, sentenceLength)//该句子的长度.map { i =>val b = window - random.nextInt(window) // (window - a) in original code// pick b words around the current word indexval start = math.max(0, i - b) // c in original code, floor ar 0val end = math.min(sentenceLength, i + b + 1) // cap at sentence length// make sure current word is not a part of the contextval contextIds = Iterator.range(start, end).filter(_ != i).map(reducedSentence(_)) //得到start,end范围内的词的索引val word = reducedSentence(i)(contextIds.toSeq, word)}}/**** @param random* @param word* @param unigramTable* @param numSamples* @param arr  最终返回的是arr,里面第一个是word中心词,后面依次是numSamples个采样词*/// This essentially helps translate from uniform distribution to a distribution// resembling uni-gram frequency distribution.private def generateNegativeSamples(random: XORShiftRandom,word: Int,unigramTable: Array[Int],numSamples: Int,arr: Array[Int]): Unit = {assert(numSamples + 1 == arr.length,s"Input array should be large enough to hold ${numSamples} negative samples")arr.update(0, word)  //arr的第一个元素是word(中心词Word)var i = 1while (i <= numSamples) { //迭代随机选取样本val negSample = unigramTable(random.nextInt(unigramTable.length))if(negSample != word) {arr.update(i, negSample)i += 1}}}
}
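To make the negative-sampling machinery easier to study in isolation, here is a small standalone sketch of the unigram table used by generateNegativeSamples above: word i occupies a share of the table proportional to count(i)^0.75, so drawing a uniformly random slot samples words from the smoothed unigram distribution. This mirrors generateUnigramTable but is written as an independent helper, not a drop-in replacement.

// counts(i) is the corpus frequency of word i; tableSize is e.g. the 20,000,000 configured above.
def buildUnigramTable(counts: Array[Long], tableSize: Int, power: Double = 0.75): Array[Int] = {
  val weights = counts.map(c => math.pow(c.toDouble, power))        // count^0.75 smoothing
  val total = weights.sum
  val cumulative = weights.scanLeft(0.0)(_ + _).tail.map(_ / total) // normalized cumulative weights
  val table = new Array[Int](tableSize)
  var wordId = 0
  var index = 0
  while (index < tableSize) {
    table(index) = wordId
    // advance to the next word once this word's share of the table has been filled
    if (index.toDouble / tableSize >= cumulative(wordId)) {
      wordId = math.min(counts.length - 1, wordId + 1)
    }
    index += 1
  }
  table
}

// Drawing one negative sample for a center word `word` (resample on collision, as in the source):
// var neg = table(random.nextInt(table.length)); while (neg == word) { neg = table(random.nextInt(table.length)) }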
