Faiss 源码解析

faissfacebook 开源的一个专门用于做高维向量的相似性搜索的库,有 c++python 的接口;目前项目地址在 https://github.com/facebookresearch/faiss。本文主要结合 faiss 的官方示例,介绍如何使用 faiss 以及 暴力/IVF/IVFPQ 检索算法在 faiss 的具体实现。

检索算法介绍

检索算法的介绍可以参考 科普,本文主要关注3种检索算法:

  1. 暴力搜索:顾名思义,querybase 一一比对,选择最近的
  2. IVF:首先在具有代表性的数据上训练聚类中心,然后将 base 加入到最近的聚类中心的桶里,在 search 的时候,query 先和聚类中心比对,再在一定数目的桶里做暴力搜索
  3. IVFPQ:在 IVF 的基础上,将 basePQ 量化,加速比对

faiss 的编译与安装

可以参考官方给出的编译方法,这里我没有安装 cuda,所以采用的命令是

./configure --without-cuda && make

在编译完 faiss 之后,我们对官方提供的示例也进行编译,路径在 ./tutorial/cpp 下,cd到目录下直接 make 就可以了

如何使用 faiss

官方总共提供了五个示例,其中有两个是 gpu 版本的,三个是 cpu 版本的,我们这里主要关注 cpu 的,分别是 1-Flat.cpp2-IVFFLAT.cpp3-IVFPQ.cpp,分别对应着暴力算法检索,IVF 算法检索,IVFPQ 算法检索。不同的算法在用户侧代码基本一致,我们选取 IVFPQ 做简单介绍。

#include <cstdio>
#include <cstdlib>#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>int main() {int d = 64;                            // 特征维度int nb = 100000;                       // base 样本数量int nq = 10000;                        // query 样本数量float *xb = new float[d * nb];float *xq = new float[d * nq];for(int i = 0; i < nb; i++) {for(int j = 0; j < d; j++)xb[d * i + j] = drand48();xb[d * i] += i / 1000.;} // 随机初始化 base 数据for(int i = 0; i < nq; i++) {for(int j = 0; j < d; j++)xq[d * i + j] = drand48();xq[d * i] += i / 1000.;}    // 随机初始化 query 数据int nlist = 100;  // 聚类中心个数int k = 4;int m = 8;                             // bytes per vectorfaiss::IndexFlatL2 quantizer(d);       // 初始化用 L2 暴力 search 的 indexfaiss::IndexIVFPQ index(&quantizer, d, nlist, m, 8); // 初始化 ivfpq 的 index,用 L2 暴力 search 的 index 初始化index.train(nb, xb); // 训练 indexindex.add(nb, xb); // 将 base 数据加入到 index 中,用于之后的搜索{       // search xqlong *I = new long[k * nq];float *D = new float[k * nq];index.nprobe = 10; // 搜索 10 个中心点index.search(nq, xq, k, D, I);printf("I=n");for(int i = nq - 5; i < nq; i++) {for(int j = 0; j < k; j++)printf("%5ld ", I[i * k + j]);printf("n");}delete [] I;delete [] D;}delete [] xb;delete [] xq;return 0;
}

这段代码主要包括了四个部分,分别是

  1. 初始化 base/query 数据和 index
  2. 训练 index
  3. 加入baseindex
  4. querysearch

其中,使用 faiss 主要包含了三步。初始化数据准备不用多说,faiss 中要求的数据格式都是 n * d 的矩阵格式,然后被展平到一维 float 数组中。剩下的两步,都是对 index 进行操作。

源码解析

检索流程

参考官方给的例子,检索分为三步:trainaddsearch,不同的检索算法,体现在使用不同的 index 进行这三步上

  1. train:选取有代表性的数据,训练 index
  2. add:将 base 数据加入到 index
  3. search:对于给定的 query,返回其对应的在底库中的 topk

重要类

Index

index 的基类,后续各种各样的检索算法,都会继承这个基类或者这个类的派生类,然后实现具体的方法,在这个类中,有如下的数据成员:

  • d:维度,每个向量的维度
  • ntotal:索引的向量的数目,可以理解成检索时的 base 数目
  • metric_type:检索时使用的 metric 类型,比如 L2,内积等

IndexFlat

用于做暴力搜索的 index 类,直接继承 index。暴力搜索思路很简单,无需 trainadd 的所有 base 都被存储起来,然后在 search 的时候把 query 和所有 base 进行比对,选取最近的。我们看下具体实现。

  • add

add 就是把所有的 base 都存储起来

void IndexFlat::add (idx_t n, const float *x) {xb.insert(xb.end(), x, x + n * d);ntotal += n;
}

  • Search

Search 的时候,根据 metric type 的不同,返回 querytopk。具体计算时采用了 openmpsse/avx 优化

void IndexFlat::search (idx_t n, const float *x, idx_t k,float *distances, idx_t *labels) const
{// we see the distances and labels as heapsif (metric_type == METRIC_INNER_PRODUCT) {float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};knn_inner_product (x, xb.data(), d, n, ntotal, &res); //函数内部有并行优化} else if (metric_type == METRIC_L2) {float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};knn_L2sqr (x, xb.data(), d, n, ntotal, &res);} else {float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};knn_extra_metrics (x, xb.data(), d, n, ntotal,metric_type, metric_arg,&res);}
}

Clustering

实现 K-means 聚类的类,提供train ,需要训练数据和 index(用于 search 最近的向量),结果得到训练数据的类中心向量,如果是量化的向量,那么还需要提供量化使用的 index codec,我们去除量化的部分,只看 float 数据

核心代码如下,包括如下部分:

  • search过程,将聚类中心作为底库加入到 index 中,并对训练数据做 search,得到 assign
  • 计算新的聚类中心,计算新的聚类中心的代码在 compute_centroids中,具体就是对于相同的类别的向量,将向量的均值作为新的中心,在实现上,利用 openmp 进行了并行优化

重复以上两步,就可以得到最优的聚类中心

void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,const Index * codec, Index & index,const float *weights) {// 前处理省略  for (int redo = 0; redo < nredo; redo++) {if (verbose && nredo > 1) {printf("Outer iteration %d / %dn", redo, nredo);}// initialize (remaining) centroids with random points from the datasetcentroids.resize (d * k);std::vector<int> perm (nx);rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);for (int i = n_input_centroids; i < k ; i++) {memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);}post_process_centroids ();// prepare the indexif (index.ntotal != 0) {index.reset();}index.add (k, centroids.data());// k-means iterationsfloat err = 0;for (int i = 0; i < niter; i++) {double t0s = getmillisecs();index.search (nx, reinterpret_cast<const float *>(x), 1,dis.get(), assign.get());InterruptCallback::check();t_search_tot += getmillisecs() - t0s;// accumulate errorerr = 0;for (int j = 0; j < nx; j++) {err += dis[j];}// update the centroidsstd::vector<float> hassign (k);size_t k_frozen = frozen_centroids ? n_input_centroids : 0;compute_centroids (d, k, nx, k_frozen,x, codec, assign.get(), weights,hassign.data(), centroids.data());index.reset ();if (update_index) {index.train (k, centroids.data());}index.add (k, centroids.data());InterruptCallback::check ();}}//保存最优聚类中心if (nredo > 1) {centroids = best_centroids;iteration_stats = best_obj;index.reset();index.add(k, best_centroids.data());}}void compute_centroids (size_t d, size_t k, size_t n,size_t k_frozen,const uint8_t * x, const Index *codec,const int64_t * assign,const float * weights,float * hassign,float * centroids)
{k -= k_frozen;centroids += k_frozen * d;memset (centroids, 0, sizeof(*centroids) * d * k);size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);#pragma omp parallel{int nt = omp_get_num_threads();int rank = omp_get_thread_num();// this thread is taking care of centroids c0:c1size_t c0 = (k * rank) / nt;size_t c1 = (k * (rank + 1)) / nt;std::vector<float> decode_buffer (d);for (size_t i = 0; i < n; i++) {int64_t ci = assign[i];assert (ci >= 0 && ci < k + k_frozen);ci -= k_frozen;if (ci >= c0 && ci < c1)  {float * c = centroids + ci * d;const float * xi;if (!codec) {xi = reinterpret_cast<const float*>(x + i * line_size);} else {float *xif = decode_buffer.data();codec->sa_decode (1, x + i * line_size, xif);xi = xif;}if (weights) {float w = weights[i];hassign[ci] += w;for (size_t j = 0; j < d; j++) {c[j] += xi[j] * w;}} else {hassign[ci] += 1.0;for (size_t j = 0; j < d; j++) {c[j] += xi[j];}}}}}#pragma omp parallel forfor (size_t ci = 0; ci < k; ci++) {if (hassign[ci] == 0) {continue;}float norm = 1 / hassign[ci];float * c = centroids + ci * d;for (size_t j = 0; j < d; j++) {c[j] *= norm;}}}

IndexIVF

用于做 IVF 搜索的 index 类。

  • train

ivf 算法会把给定的数据进行聚类,得到固定数目的聚类中心。具体的,就是 train_q1​ 的过程,train_residual 在 ivf 中是一个空函数

void IndexIVF::train (idx_t n, const float *x)
{train_q1 (n, x, verbose, metric_type);train_residual (n, x);is_trained = true;
}void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {if (verbose)printf("IndexIVF: no residual trainingn");// does nothing by default
}

train_q1用的是 Level1Quantizer 的具体实现,如下,对训练数据进行聚类,得到聚类中心并保存下来

void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
{// 省略无关代码Clustering clus (d, nlist, cp);quantizer->reset();if (clustering_index) {clus.train (n, x, *clustering_index);quantizer->add (nlist, clus.centroids.data());} else {clus.train (n, x, *quantizer);}quantizer->is_trained = true;
}

  • add

    • 分片。根据输入的大小,按照固定的大小依次进行 add
    • 建立 invlists。根据 train得到的聚类中心(保存在 quantizer 中),每一个类中心对应 invlists 中的一个桶。
    • invlists 的桶里加入 base。利用了 openmp 进行了并行加速
void IndexIVF::add (idx_t n, const float * x)
{add_with_ids (n, x, nullptr);
}void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
{// do some blocking to avoid excessive allocsidx_t bs = 65536;if (n > bs) {for (idx_t i0 = 0; i0 < n; i0 += bs) {idx_t i1 = std::min (n, i0 + bs);if (verbose) {printf("   IndexIVF::add_with_ids %ld:%ldn", i0, i1);}add_with_ids (i1 - i0, x + i0 * d,xids ? xids + i0 : nullptr);}return;}std::unique_ptr<idx_t []> idx(new idx_t[n]);quantizer->assign (n, x, idx.get());size_t nadd = 0, nminus1 = 0;#pragma omp parallel reduction(+: nadd){int nt = omp_get_num_threads();int rank = omp_get_thread_num();// each thread takes care of a subset of listsfor (size_t i = 0; i < n; i++) {idx_t list_no = idx [i];if (list_no >= 0 && list_no % nt == rank) {idx_t id = xids ? xids[i] : ntotal + i;size_t ofs = invlists->add_entry (list_no, id,flat_codes.get() + i * code_size);dm_adder.add (i, list_no, ofs);nadd++;} else if (rank == 0 && list_no == -1) {dm_adder.add (i, -1, 0);}}}ntotal += n;
}

  • search

    • Search corse_dis。搜索离 query 最近的聚类中心
    • Search invlists。在最近的 nprobe 个聚类中心对应的 invlists 中进行暴力 heap 搜索,得到 topk
void IndexIVF::search (idx_t n, const float *x, idx_t k,float *distances, idx_t *labels) const
{std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);double t0 = getmillisecs();quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());indexIVF_stats.quantization_time += getmillisecs() - t0;t0 = getmillisecs();invlists->prefetch_lists (idx.get(), n * nprobe);search_preassigned (n, x, k, idx.get(), coarse_dis.get(),distances, labels, false);indexIVF_stats.search_time += getmillisecs() - t0;
}

ProductQuantizer

用来做 PQ 量化算法的类,关于 PQ 量化算法,可以参考 pq算法。简单来说,我们需要得到用来量化的码本,然后我们可以对输入的向量进行解码和编码。得到码本的过程在 ProductQuantizer::train 中,包含

  • 将输入向量按照维度切分成 PQ 段,每段的维度是 dsub
  • 得到每段的聚类中心,这就是码本

编码和解码的过程就是将输入向量转化为码本里的 idx,可以看出,量化是存在一定的误差,其中,PQ 越大,误差越小

void ProductQuantizer::train (int n, const float * x)
{if (train_type != Train_shared) {train_type_t final_train_type;final_train_type = train_type;if (train_type == Train_hypercube ||train_type == Train_hypercube_pca) {if (dsub < nbits) {final_train_type = Train_default;printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)n",nbits, dsub);}}float * xslice = new float[n * dsub];ScopeDeleter<float> del (xslice);for (int m = 0; m < M; m++) {for (int j = 0; j < n; j++)memcpy (xslice + j * dsub,x + j * d + m * dsub,dsub * sizeof(float));Clustering clus (dsub, ksub, cp);// we have some initialization for the centroidsif (final_train_type != Train_default) {clus.centroids.resize (dsub * ksub);}switch (final_train_type) {case Train_hypercube:init_hypercube (dsub, nbits, n, xslice,clus.centroids.data ());break;case  Train_hypercube_pca:init_hypercube_pca (dsub, nbits, n, xslice,clus.centroids.data ());break;case  Train_hot_start:memcpy (clus.centroids.data(),get_centroids (m, 0),dsub * ksub * sizeof (float));break;default: ;}if(verbose) {clus.verbose = true;printf ("Training PQ slice %d/%zdn", m, M);}IndexFlatL2 index (dsub);clus.train (n, xslice, assign_index ? *assign_index : index);set_params (clus.centroids.data(), m);}} else {Clustering clus (dsub, ksub, cp);if(verbose) {clus.verbose = true;printf ("Training all PQ slices at oncen");}IndexFlatL2 index (dsub);clus.train (n * M, x, assign_index ? *assign_index : index);for (int m = 0; m < M; m++) {set_params (clus.centroids.data(), m);}}
}

IndexIVFPQ

ivfpq 算法在 ivf 的基础上,对 basepq。大家可以自行参考代码

openmp官方源码_Faiss 源码解析相关推荐

  1. Android Fragment 从源码的角度去解析(上)

    ###1.概述 本来想着昨天星期五可以早点休息,今天可以早点起来跑步,可没想到事情那么的多,晚上有人问我主页怎么做到点击才去加载Fragment数据,而不是一进入主页就去加载所有的数据,在这里自己就对 ...

  2. 【Android源码】源码分析深度好文+精编内核解析分享

    阅读Android源码的好处有很多,比如:可以加深我们对系统的了解:可以参考牛人优雅的代码实现:可以从根本上找出一些bug的原因-我们应该庆幸Android是开源的,所有的功能都可以看到实现,所有的b ...

  3. MyBatis 源码分析 - 映射文件解析过程

    1.简介 在上一篇文章中,我详细分析了 MyBatis 配置文件的解析过程.由于上一篇文章的篇幅比较大,加之映射文件解析过程也比较复杂的原因.所以我将映射文件解析过程的分析内容从上一篇文章中抽取出来, ...

  4. 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 TaskMasger 启动

    1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [四] 上一篇: [flink]Flink 1.12.2 源码浅析 : yarn-per-job模式解析 Jo ...

  5. 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 JobMasger启动 YarnJobClusterEntrypoint

    1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [三] 上一章:[flink]Flink 1.12.2 源码浅析 : yarn-per-job模式解析 yar ...

  6. 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 yarn 提交过程解析

    1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [二] 请大家看原文去. 接上文Flink 1.12.2 源码分析 : yarn-per-job模式浅析 [一 ...

  7. 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 从脚本到主类

    1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [一] 可以去看原文.这里是补充专栏.请看原文 2. 前言 主要针对yarn-per-job模式进行代码分析. ...

  8. [darknet源码系列-2] darknet源码中的cfg解析

    [darknet源码系列-2] darknet源码中的cfg解析 FesianXu 20201118 at UESTC 前言 笔者在[1]一文中简单介绍了在darknet中常见的数据结构,本文继续上文 ...

  9. UE4官方滚球项目源码笔记

    UE4官方滚球项目源码笔记 我的项目名称:test_0511,读者请根据自己的项目名称自行查找(YourProgramNameBall.h/YourProgramNameBall.cpp) 笔者是UE ...

最新文章

  1. “大型票务系统”和“实物电商系统”的数据库选型
  2. 730版本去掉恼人的提示信息
  3. libsvm java下载_java-libsvm 版 结合已有数据集的demo,方便初学者使用 Develop 238万源代码下载- www.pudn.com...
  4. dns的主从服务器的简单配置
  5. DELL服务器T410进行系统修复,ibm T410 BIOS修复过程-BIOS维修网站www.biosrepair.com
  6. CentOS6.8下安装memcached并设置开机自启动
  7. 计算机学生的高职英语课程,高职计算机英语课程教学方法探索
  8. Xmind 8 Pro破解版安装激活教程(Windows版)
  9. HDU - 1598
  10. 企业邮箱注册申请入口,公司邮箱申请哪个好?
  11. chrome插件莫名消失【已解决】
  12. 007需求分析中的重要知识点(马斯洛需求层次理论+KANO优先级筛选模型+金字塔模型)
  13. 信息系统项目管理师必背核心考点(四十九)合同法
  14. HIT CSAPP大作业论文
  15. r语言c()函数格式,R语言基本操作函数
  16. windos10本地安装git工具并使用
  17. 少不读水浒——揭秘水浒传
  18. 王桂林 C++基础与提高 练习题——string数组
  19. 无线网络中AP及AC的概念及作用:
  20. 数据库优化的八种方式

热门文章

  1. Log4j2基本使用入门
  2. 阿里首推的“SpringBoot+Vue全栈项目”有多牛X?
  3. 排序 -> 快速排序
  4. 理解交换机通过逆向自学习算法建立地址转发表的过程_交换机与 VLAN 到底是怎么来的...
  5. 克隆对象和对象的继承
  6. Spring Boot系列(十二)Spring Boot整合ActiveQ实现消息收发和订阅
  7. 自定义格式字符串随笔(IFormattable,IFormatProvider,ICustomFormatter三接口的实现)
  8. python 判断当前系统的Python编译器类型
  9. HTML5 history新特性pushState、replaceState,popstate
  10. Optimizing regular expressions in Java