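As the source below shows, approxSimilarityJoin first computes the hash values. It then joins the two datasets on equal hash values (the code comment calls this a hash join), so rows that share a hash value end up together. Next it computes the distance of every candidate pair with a UDF. Finally it keeps only the pairs whose distance is below the threshold: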

  /**
   * Join two datasets to approximately find all pairs of rows whose distance are smaller than
   * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the
   * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
   * data when necessary.
   *
   * @param datasetA One of the datasets to join.
   * @param datasetB Another dataset to join.
   * @param threshold The threshold for the distance of row pairs.
   * @param distCol Output column for storing the distance between each pair of rows.
   * @return A joined dataset containing pairs of rows. The original rows are in columns
   *         "datasetA" and "datasetB", and a column "distCol" is added to show the distance
   *         between each pair.
   */
  def approxSimilarityJoin(
      datasetA: Dataset[_],
      datasetB: Dataset[_],
      threshold: Double,
      distCol: String): Dataset[_] = {

    val leftColName = "datasetA"
    val rightColName = "datasetB"
    val explodeCols = Seq("entry", "hashValue")
    val explodedA = processDataset(datasetA, leftColName, explodeCols)

    // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity.
    // TODO: Remove recreateCol logic once SPARK-17154 is resolved.
    val explodedB = if (datasetA != datasetB) {
      processDataset(datasetB, rightColName, explodeCols)
    } else {
      val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}")
      processDataset(recreatedB, rightColName, explodeCols)
    }

    // Do a hash join on where the exploded hash values are equal.
    val joinedDataset = explodedA.join(explodedB, explodeCols)
      .drop(explodeCols: _*).distinct()

    // Add a new column to store the distance of the two rows.
    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
    val joinedDatasetWithDist = joinedDataset.select(col("*"),
      distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
    )

    // Filter the joined datasets where the distance are smaller than the threshold.
    joinedDatasetWithDist.filter(col(distCol) < threshold)
  }
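To make the hash, join, distance, filter pipeline concrete, here is a minimal sketch on plain Scala collections instead of Spark Datasets. It assumes the MinHash signature of each row has already been computed. `Row`, `explode`, `jaccardDistance` and `similarityJoin` are hypothetical names, not Spark API. Two rows become candidates when they agree on at least one (hash-function index, hash value) pair, mirroring the exploded ("entry", "hashValue") join columns. Duplicate pairs are then dropped, as `distinct()` does, and only pairs whose exact Jaccard distance is below the threshold are kept.

```scala
// Hypothetical row: an id, the set of non-zero indices (what keyDistance looks at),
// and a precomputed MinHash signature with one value per hash function.
final case class Row(id: Int, indices: Set[Int], signature: Seq[Long])

// Jaccard distance, same formula as MinHashLSHModel.keyDistance.
def jaccardDistance(x: Set[Int], y: Set[Int]): Double = {
  val inter = x.intersect(y).size.toDouble
  val union = x.size + y.size - inter
  require(union > 0, "The union of two input sets must have at least 1 element")
  1.0 - inter / union
}

// "Explode" each row into ((hashFunctionIndex, hashValue), row) entries,
// mirroring the ("entry", "hashValue") join columns.
def explode(rows: Seq[Row]): Seq[((Int, Long), Row)] =
  rows.flatMap(r => r.signature.zipWithIndex.map { case (h, i) => ((i, h), r) })

def similarityJoin(a: Seq[Row], b: Seq[Row], threshold: Double): Seq[(Int, Int, Double)] = {
  // Equi-join on (entry, hashValue): index side b by key, then probe with side a.
  val byKey: Map[(Int, Long), Seq[Row]] =
    explode(b).groupBy(_._1).map { case (k, entries) => k -> entries.map(_._2) }
  // distinct drops pairs that collided on more than one hash function.
  val candidates = explode(a).flatMap { case (key, ra) =>
    byKey.getOrElse(key, Seq.empty).map(rb => (ra, rb))
  }.distinct
  // Exact distance for every candidate pair, then keep those below the threshold.
  candidates
    .map { case (ra, rb) => (ra.id, rb.id, jaccardDistance(ra.indices, rb.indices)) }
    .filter { case (_, _, d) => d < threshold }
}
```

A hash collision only nominates a pair. The exact distance still decides, so a row that shares every hash value with another can be dropped if its real distance is too large.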


Time complexity of SQL join algorithms

SELECT  T1.name, T2.date
FROM    T1, T2
WHERE T1.id=T2.id AND T1.color='red' AND T2.type='CAR'

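Suppose T1 has m rows and T2 has n rows. Naively, for each of the m rows of T1 we take its id and scan all n rows of T2 for rows with T2.id = T1.id, then join them. The time complexity is therefore O(m*n).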
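This naive strategy is a nested-loop join. Here is a minimal sketch in Scala over in-memory collections. `T1Row`, `T2Row` and `nestedLoopJoin` are hypothetical stand-ins for the tables in the query. The field `tpe` replaces `type`, which is a Scala keyword. The inner generator runs n times for each of the m outer rows, which is where O(m*n) comes from.

```scala
// Hypothetical in-memory stand-ins for the tables in the query.
final case class T1Row(id: Int, name: String, color: String)
final case class T2Row(id: Int, date: String, tpe: String) // `type` is reserved in Scala

// Nested-loop join: for each of the m rows of T1, scan all n rows of T2,
// so the join predicate is evaluated m * n times.
def nestedLoopJoin(t1: Seq[T1Row], t2: Seq[T2Row]): Seq[(String, String)] =
  for {
    r1 <- t1
    r2 <- t2
    if r1.id == r2.id && r1.color == "red" && r2.tpe == "CAR"
  } yield (r1.name, r2.date)
```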

If there is no index to use, the engine will typically pick a hash join or a merge join instead.

A hash join works like this:

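  1. Pick the table to hash, usually the smaller one. Let's simply assume T1 is the smaller one.
  2. Scan every row of T1. If a row satisfies color='red', put it into the hash table with id as the key and name as the value.
  3. Scan every row of T2. If a row satisfies type='CAR', use its id to look up the hash table. For every matching entry, return that name together with the current row's date. This is how the two tables get joined.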
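The three steps above can be sketched in Scala over in-memory collections. `T1Row`, `T2Row` (with `tpe` standing in for `type`) and `hashJoin` are hypothetical names. A `Map` from id to a list of names stands in for the engine's hash table, since several T1 rows may share an id. Each table is scanned once, so the expected cost is O(m + n) plus the size of the output.

```scala
// Hypothetical in-memory stand-ins for the tables in the query.
final case class T1Row(id: Int, name: String, color: String)
final case class T2Row(id: Int, date: String, tpe: String) // `type` is reserved in Scala

def hashJoin(t1: Seq[T1Row], t2: Seq[T2Row]): Seq[(String, String)] = {
  // Steps 1-2 (build): one pass over the smaller table T1, keep color = 'red',
  // key by id. Values are lists because several rows may share an id.
  val table: Map[Int, Seq[String]] =
    t1.filter(_.color == "red").groupBy(_.id).map { case (id, rows) => id -> rows.map(_.name) }
  // Step 3 (probe): one pass over T2, keep type = 'CAR', look up each id
  // and emit (name, date) for every match.
  for {
    r2 <- t2 if r2.tpe == "CAR"
    name <- table.getOrElse(r2.id, Seq.empty)
  } yield (name, r2.date)
}
```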


A merge join works like this:

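1. Copy T1(id, name) and sort it by id.
2. Copy T2(id, date) and sort it by id.
3. Put a pointer at the start of each sorted list. The left column is T1.id and the right column is T2.id; ">" marks the T1 pointer and "<" marks the T2 pointer:

    >1  2<
     2  3
     2  4
     3  5

4. Compare the ids under the two pointers, and repeat until either pointer runs past the end:

    >1  2<   - no match: the left value is smaller, so advance the left pointer
     2  3
     2  4
     3  5

     1  2<   - match (2 = 2): return a record for each of the two T1 rows with id 2, then move both pointers past id 2
    >2  3
     2  4
     3  5

     1  2    - match (3 = 3): return the record and advance both pointers
     2  3<
     2  4
    >3  5

     1  2    - the left pointer is past the end, so the query is finished
     2  3
     2  4<
     3  5
    >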
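Here is a minimal sort-merge join in Scala. It assumes the inputs are the (id, name) and (id, date) copies of T1 and T2 described in steps 1 and 2; `mergeJoin` is a hypothetical name. The sketch handles runs of equal ids by emitting the cross product of the two runs, so duplicate ids such as the two 2s on the left are not lost. Sorting costs O(m log m + n log n), and the merge itself is a single O(m + n) pass plus the output size.

```scala
import scala.collection.mutable.ArrayBuffer

// Sort-merge join on id. t1 holds (id, name) and t2 holds (id, date).
def mergeJoin(t1: Seq[(Int, String)], t2: Seq[(Int, String)]): List[(String, String)] = {
  val left = t1.sortBy(_._1).toIndexedSeq  // step 1: sort T1 copy by id
  val right = t2.sortBy(_._1).toIndexedSeq // step 2: sort T2 copy by id
  val out = ArrayBuffer.empty[(String, String)]
  var i = 0 // left pointer
  var j = 0 // right pointer
  while (i < left.length && j < right.length) {
    val li = left(i)._1
    val rj = right(j)._1
    if (li < rj) i += 1      // left is smaller: advance left
    else if (li > rj) j += 1 // right is smaller: advance right
    else {
      // Match: find the end of the run of equal ids on each side,
      // emit every (left, right) combination, then skip past both runs.
      var iEnd = i
      while (iEnd < left.length && left(iEnd)._1 == li) iEnd += 1
      var jEnd = j
      while (jEnd < right.length && right(jEnd)._1 == li) jEnd += 1
      for (p <- i until iEnd; q <- j until jEnd) out += ((left(p)._2, right(q)._2))
      i = iEnd
      j = jEnd
    }
  }
  out.toList
}
```

With the ids from the trace (T1: 1, 2, 2, 3 and T2: 2, 3, 4, 5), it returns one row for each T1 row with id 2 and one row for id 3.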









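As the code below shows, hashFunction works on the vector's indices field, i.e. the positions (subscripts) of its non-zero entries. The distance in keyDistance is the Jaccard distance, which is 1 minus the Jaccard similarity of the two index sets.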
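For reference, keyDistance computes the Jaccard distance between the index sets X and Y of the two vectors. The second form matches the code's `xSet.size + ySet.size - intersectionSize`. The worked example uses two made-up sets:

```latex
d_J(X, Y) = 1 - \frac{|X \cap Y|}{|X \cup Y|}
          = 1 - \frac{|X \cap Y|}{|X| + |Y| - |X \cap Y|}

X = \{0, 2, 5\},\quad Y = \{2, 5, 7\}:\quad
|X \cap Y| = 2,\quad |X \cup Y| = 3 + 3 - 2 = 4,\quad
d_J(X, Y) = 1 - \tfrac{2}{4} = 0.5
```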


/**
 * :: Experimental ::
 *
 * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function
 * is picked from the following family of hash functions, where a_i and b_i are randomly chosen
 * integers less than prime:
 *    `h_i(x) = ((x \cdot a_i + b_i) \mod prime)`
 *
 * This hash family is approximately min-wise independent according to the reference.
 *
 * Reference:
 * Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations."
 * Electronic Journal of Combinatorics 7 (2000): R26.
 *
 * @param randCoefficients Pairs of random coefficients. Each pair is used by one hash function.
 */
class MinHashLSHModel private[ml](
    override val uid: String,
    private[ml] val randCoefficients: Array[(Int, Int)])
  extends LSHModel[MinHashLSHModel] {

  /** @group setParam */
  @Since("2.4.0")
  override def setInputCol(value: String): this.type = super.set(inputCol, value)

  /** @group setParam */
  @Since("2.4.0")
  override def setOutputCol(value: String): this.type = super.set(outputCol, value)

  @Since("2.1.0")
  override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
    require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
    val elemsList = elems.toSparse.indices.toList
    val hashValues = randCoefficients.map { case (a, b) =>
      elemsList.map { elem: Int =>
        ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME
      }.min.toDouble
    }
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    hashValues.map(Vectors.dense(_))
  }

  @Since("2.1.0")
  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
    val xSet = x.toSparse.indices.toSet
    val ySet = y.toSparse.indices.toSet
    val intersectionSize = xSet.intersect(ySet).size.toDouble
    val unionSize = xSet.size + ySet.size - intersectionSize
    assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
    1 - intersectionSize / unionSize
  }

  @Since("2.1.0")
  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
    // Since it's generated by hashing, it will be a pair of dense vectors.
    // TODO: This hashDistance function requires more discussion in SPARK-18454
    x.zip(y).map(vectorPair =>
      vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
    ).min
  }

  @Since("2.1.0")
  override def copy(extra: ParamMap): MinHashLSHModel = {
    val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
    copyValues(copied, extra)
  }

  @Since("2.1.0")
  override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
}
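Here is a standalone sketch of the same hashing scheme without Spark. For each coefficient pair (a, b), it hashes every non-zero index as ((1 + elem) * a + b) mod prime and keeps the minimum, as hashFunction does above. Some parts are assumptions. The prime 2038074743 is what I believe MinHashLSH.HASH_PRIME is set to. `randCoefficients` is a hypothetical generator using a seeded `scala.util.Random`, not Spark's own seeding. The fraction of hash functions on which two signatures agree estimates the Jaccard similarity of the two index sets. That is why equal hash values make good join candidates.

```scala
import scala.util.Random

// Assumed value of MinHashLSH.HASH_PRIME; check it against your Spark version.
val HashPrime: Long = 2038074743L

// Hypothetical coefficient generator: a in [1, prime - 1], b in [0, prime - 2].
def randCoefficients(numHashes: Int, seed: Long): Array[(Int, Int)] = {
  val rand = new Random(seed)
  Array.fill(numHashes)((1 + rand.nextInt(HashPrime.toInt - 1), rand.nextInt(HashPrime.toInt - 1)))
}

// One value per hash function: min over non-zero indices of ((1 + elem) * a + b) mod prime.
// (1L + elem) forces Long arithmetic, so the product does not overflow Int.
def minHash(indices: Set[Int], coeffs: Array[(Int, Int)]): Array[Long] = {
  require(indices.nonEmpty, "Must have at least 1 non zero entry.")
  coeffs.map { case (a, b) => indices.map(elem => ((1L + elem) * a + b) % HashPrime).min }
}

// Fraction of hash functions whose values agree: an estimate of Jaccard similarity.
def estimatedSimilarity(x: Array[Long], y: Array[Long]): Double =
  x.zip(y).count { case (p, q) => p == q }.toDouble / x.length
```

For example, the index sets {0, ..., 49} and {25, ..., 74} have Jaccard similarity 25/75 = 1/3. With a few hundred hash functions, the estimate should land near that value, though the exact number depends on the seed.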



