openmp官方源码_Faiss 源码解析
Faiss 源码解析
faiss
是 facebook
开源的一个专门用于做高维向量的相似性搜索的库,有 c++
和 python
的接口;目前项目地址在 https://github.com/facebookresearch/faiss。本文主要结合 faiss
的官方示例,介绍如何使用 faiss
以及 暴力/IVF/IVFPQ
检索算法在 faiss
的具体实现。
检索算法介绍
检索算法的介绍可以参考 科普,本文主要关注3种检索算法:
- 暴力搜索:顾名思义,
query
和base
一一比对,选择最近的 IVF
:首先在具有代表性的数据上训练聚类中心,然后将base
加入到最近的聚类中心的桶里,在search
的时候,query
先和聚类中心比对,再在一定数目的桶里做暴力搜索IVFPQ
:在IVF
的基础上,将base
做PQ
量化,加速比对
faiss 的编译与安装
可以参考官方给出的编译方法,这里我没有安装 cuda,所以采用的命令是
./configure --without-cuda && make
在编译完 faiss
之后,我们对官方提供的示例也进行编译,路径在 ./tutorial/cpp
下,cd
到目录下直接 make
就可以了
如何使用 faiss
官方总共提供了五个示例,其中有两个是 gpu
版本的,三个是 cpu
版本的,我们这里主要关注 cpu
的,分别是 1-Flat.cpp
,2-IVFFLAT.cpp
,3-IVFPQ.cpp
,分别对应着暴力算法检索,IVF
算法检索,IVFPQ
算法检索。不同的算法在用户侧代码基本一致,我们选取 IVFPQ
做简单介绍。
#include <cstdio>
#include <cstdlib>#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>int main() {int d = 64; // 特征维度int nb = 100000; // base 样本数量int nq = 10000; // query 样本数量float *xb = new float[d * nb];float *xq = new float[d * nq];for(int i = 0; i < nb; i++) {for(int j = 0; j < d; j++)xb[d * i + j] = drand48();xb[d * i] += i / 1000.;} // 随机初始化 base 数据for(int i = 0; i < nq; i++) {for(int j = 0; j < d; j++)xq[d * i + j] = drand48();xq[d * i] += i / 1000.;} // 随机初始化 query 数据int nlist = 100; // 聚类中心个数int k = 4;int m = 8; // bytes per vectorfaiss::IndexFlatL2 quantizer(d); // 初始化用 L2 暴力 search 的 indexfaiss::IndexIVFPQ index(&quantizer, d, nlist, m, 8); // 初始化 ivfpq 的 index,用 L2 暴力 search 的 index 初始化index.train(nb, xb); // 训练 indexindex.add(nb, xb); // 将 base 数据加入到 index 中,用于之后的搜索{ // search xqlong *I = new long[k * nq];float *D = new float[k * nq];index.nprobe = 10; // 搜索 10 个中心点index.search(nq, xq, k, D, I);printf("I=n");for(int i = nq - 5; i < nq; i++) {for(int j = 0; j < k; j++)printf("%5ld ", I[i * k + j]);printf("n");}delete [] I;delete [] D;}delete [] xb;delete [] xq;return 0;
}
这段代码主要包括了四个部分,分别是
- 初始化
base/query
数据和index
- 训练
index
- 加入
base
到index
- 对
query
做search
其中,使用 faiss
主要包含了三步。初始化数据准备不用多说,faiss
中要求的数据格式都是 n * d
的矩阵格式,然后被展平到一维 float
数组中。剩下的两步,都是对 index
进行操作。
源码解析
检索流程
参考官方给的例子,检索分为三步:train
,add
,search
,不同的检索算法,体现在使用不同的 index
进行这三步上
train
:选取有代表性的数据,训练index
add
:将base
数据加入到index
中search
:对于给定的query
,返回其对应的在底库中的topk
重要类
Index
index
的基类,后续各种各样的检索算法,都会继承这个基类或者这个类的派生类,然后实现具体的方法,在这个类中,有如下的数据成员:
d
:维度,每个向量的维度ntotal
:索引的向量的数目,可以理解成检索时的base
数目metric_type
:检索时使用的metric
类型,比如L2
,内积等
IndexFlat
用于做暴力搜索的 index
类,直接继承 index
。暴力搜索思路很简单,无需 train
,add
的所有 base 都被存储起来,然后在 search
的时候把 query
和所有 base
进行比对,选取最近的。我们看下具体实现。
add
add
就是把所有的 base
都存储起来
void IndexFlat::add (idx_t n, const float *x) {xb.insert(xb.end(), x, x + n * d);ntotal += n;
}
Search
Search
的时候,根据 metric type
的不同,返回 query
的 topk
。具体计算时采用了 openmp
和sse
/avx
优化
void IndexFlat::search (idx_t n, const float *x, idx_t k,float *distances, idx_t *labels) const
{// we see the distances and labels as heapsif (metric_type == METRIC_INNER_PRODUCT) {float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};knn_inner_product (x, xb.data(), d, n, ntotal, &res); //函数内部有并行优化} else if (metric_type == METRIC_L2) {float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};knn_L2sqr (x, xb.data(), d, n, ntotal, &res);} else {float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};knn_extra_metrics (x, xb.data(), d, n, ntotal,metric_type, metric_arg,&res);}
}
Clustering
实现 K-means
聚类的类,提供train
,需要训练数据和 index
(用于 search
最近的向量),结果得到训练数据的类中心向量,如果是量化的向量,那么还需要提供量化使用的 index codec
,我们去除量化的部分,只看 float
数据
核心代码如下,包括如下部分:
search
过程,将聚类中心作为底库加入到index
中,并对训练数据做search
,得到assign
- 计算新的聚类中心,计算新的聚类中心的代码在
compute_centroids
中,具体就是对于相同的类别的向量,将向量的均值作为新的中心,在实现上,利用 openmp 进行了并行优化
重复以上两步,就可以得到最优的聚类中心
void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,const Index * codec, Index & index,const float *weights) {// 前处理省略 for (int redo = 0; redo < nredo; redo++) {if (verbose && nredo > 1) {printf("Outer iteration %d / %dn", redo, nredo);}// initialize (remaining) centroids with random points from the datasetcentroids.resize (d * k);std::vector<int> perm (nx);rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);for (int i = n_input_centroids; i < k ; i++) {memcpy (¢roids[i * d], x + perm[i] * line_size, line_size);}post_process_centroids ();// prepare the indexif (index.ntotal != 0) {index.reset();}index.add (k, centroids.data());// k-means iterationsfloat err = 0;for (int i = 0; i < niter; i++) {double t0s = getmillisecs();index.search (nx, reinterpret_cast<const float *>(x), 1,dis.get(), assign.get());InterruptCallback::check();t_search_tot += getmillisecs() - t0s;// accumulate errorerr = 0;for (int j = 0; j < nx; j++) {err += dis[j];}// update the centroidsstd::vector<float> hassign (k);size_t k_frozen = frozen_centroids ? n_input_centroids : 0;compute_centroids (d, k, nx, k_frozen,x, codec, assign.get(), weights,hassign.data(), centroids.data());index.reset ();if (update_index) {index.train (k, centroids.data());}index.add (k, centroids.data());InterruptCallback::check ();}}//保存最优聚类中心if (nredo > 1) {centroids = best_centroids;iteration_stats = best_obj;index.reset();index.add(k, best_centroids.data());}}void compute_centroids (size_t d, size_t k, size_t n,size_t k_frozen,const uint8_t * x, const Index *codec,const int64_t * assign,const float * weights,float * hassign,float * centroids)
{k -= k_frozen;centroids += k_frozen * d;memset (centroids, 0, sizeof(*centroids) * d * k);size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);#pragma omp parallel{int nt = omp_get_num_threads();int rank = omp_get_thread_num();// this thread is taking care of centroids c0:c1size_t c0 = (k * rank) / nt;size_t c1 = (k * (rank + 1)) / nt;std::vector<float> decode_buffer (d);for (size_t i = 0; i < n; i++) {int64_t ci = assign[i];assert (ci >= 0 && ci < k + k_frozen);ci -= k_frozen;if (ci >= c0 && ci < c1) {float * c = centroids + ci * d;const float * xi;if (!codec) {xi = reinterpret_cast<const float*>(x + i * line_size);} else {float *xif = decode_buffer.data();codec->sa_decode (1, x + i * line_size, xif);xi = xif;}if (weights) {float w = weights[i];hassign[ci] += w;for (size_t j = 0; j < d; j++) {c[j] += xi[j] * w;}} else {hassign[ci] += 1.0;for (size_t j = 0; j < d; j++) {c[j] += xi[j];}}}}}#pragma omp parallel forfor (size_t ci = 0; ci < k; ci++) {if (hassign[ci] == 0) {continue;}float norm = 1 / hassign[ci];float * c = centroids + ci * d;for (size_t j = 0; j < d; j++) {c[j] *= norm;}}}
IndexIVF
用于做 IVF 搜索的 index 类。
train
ivf
算法会把给定的数据进行聚类,得到固定数目的聚类中心。具体的,就是 train_q1
的过程,train_residual
在 ivf 中是一个空函数
void IndexIVF::train (idx_t n, const float *x)
{train_q1 (n, x, verbose, metric_type);train_residual (n, x);is_trained = true;
}void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {if (verbose)printf("IndexIVF: no residual trainingn");// does nothing by default
}
train_q1
用的是 Level1Quantizer
的具体实现,如下,对训练数据进行聚类,得到聚类中心并保存下来
void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
{// 省略无关代码Clustering clus (d, nlist, cp);quantizer->reset();if (clustering_index) {clus.train (n, x, *clustering_index);quantizer->add (nlist, clus.centroids.data());} else {clus.train (n, x, *quantizer);}quantizer->is_trained = true;
}
add
- 分片。根据输入的大小,按照固定的大小依次进行
add
- 建立
invlists
。根据train
得到的聚类中心(保存在quantizer
中),每一个类中心对应invlists
中的一个桶。 - 向
invlists
的桶里加入base
。利用了openmp
进行了并行加速
- 分片。根据输入的大小,按照固定的大小依次进行
void IndexIVF::add (idx_t n, const float * x)
{add_with_ids (n, x, nullptr);
}void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
{// do some blocking to avoid excessive allocsidx_t bs = 65536;if (n > bs) {for (idx_t i0 = 0; i0 < n; i0 += bs) {idx_t i1 = std::min (n, i0 + bs);if (verbose) {printf(" IndexIVF::add_with_ids %ld:%ldn", i0, i1);}add_with_ids (i1 - i0, x + i0 * d,xids ? xids + i0 : nullptr);}return;}std::unique_ptr<idx_t []> idx(new idx_t[n]);quantizer->assign (n, x, idx.get());size_t nadd = 0, nminus1 = 0;#pragma omp parallel reduction(+: nadd){int nt = omp_get_num_threads();int rank = omp_get_thread_num();// each thread takes care of a subset of listsfor (size_t i = 0; i < n; i++) {idx_t list_no = idx [i];if (list_no >= 0 && list_no % nt == rank) {idx_t id = xids ? xids[i] : ntotal + i;size_t ofs = invlists->add_entry (list_no, id,flat_codes.get() + i * code_size);dm_adder.add (i, list_no, ofs);nadd++;} else if (rank == 0 && list_no == -1) {dm_adder.add (i, -1, 0);}}}ntotal += n;
}
search
Search corse_dis
。搜索离query
最近的聚类中心Search invlists
。在最近的nprobe
个聚类中心对应的invlists
中进行暴力heap
搜索,得到topk
void IndexIVF::search (idx_t n, const float *x, idx_t k,float *distances, idx_t *labels) const
{std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);double t0 = getmillisecs();quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());indexIVF_stats.quantization_time += getmillisecs() - t0;t0 = getmillisecs();invlists->prefetch_lists (idx.get(), n * nprobe);search_preassigned (n, x, k, idx.get(), coarse_dis.get(),distances, labels, false);indexIVF_stats.search_time += getmillisecs() - t0;
}
ProductQuantizer
用来做 PQ
量化算法的类,关于 PQ
量化算法,可以参考 pq算法。简单来说,我们需要得到用来量化的码本,然后我们可以对输入的向量进行解码和编码。得到码本的过程在 ProductQuantizer::train
中,包含
- 将输入向量按照维度切分成
PQ
段,每段的维度是dsub
- 得到每段的聚类中心,这就是码本
编码和解码的过程就是将输入向量转化为码本里的 idx
,可以看出,量化是存在一定的误差,其中,PQ
越大,误差越小
void ProductQuantizer::train (int n, const float * x)
{if (train_type != Train_shared) {train_type_t final_train_type;final_train_type = train_type;if (train_type == Train_hypercube ||train_type == Train_hypercube_pca) {if (dsub < nbits) {final_train_type = Train_default;printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)n",nbits, dsub);}}float * xslice = new float[n * dsub];ScopeDeleter<float> del (xslice);for (int m = 0; m < M; m++) {for (int j = 0; j < n; j++)memcpy (xslice + j * dsub,x + j * d + m * dsub,dsub * sizeof(float));Clustering clus (dsub, ksub, cp);// we have some initialization for the centroidsif (final_train_type != Train_default) {clus.centroids.resize (dsub * ksub);}switch (final_train_type) {case Train_hypercube:init_hypercube (dsub, nbits, n, xslice,clus.centroids.data ());break;case Train_hypercube_pca:init_hypercube_pca (dsub, nbits, n, xslice,clus.centroids.data ());break;case Train_hot_start:memcpy (clus.centroids.data(),get_centroids (m, 0),dsub * ksub * sizeof (float));break;default: ;}if(verbose) {clus.verbose = true;printf ("Training PQ slice %d/%zdn", m, M);}IndexFlatL2 index (dsub);clus.train (n, xslice, assign_index ? *assign_index : index);set_params (clus.centroids.data(), m);}} else {Clustering clus (dsub, ksub, cp);if(verbose) {clus.verbose = true;printf ("Training all PQ slices at oncen");}IndexFlatL2 index (dsub);clus.train (n * M, x, assign_index ? *assign_index : index);for (int m = 0; m < M; m++) {set_params (clus.centroids.data(), m);}}
}
IndexIVFPQ
ivfpq
算法在 ivf
的基础上,对 base
做 pq
。大家可以自行参考代码
openmp官方源码_Faiss 源码解析相关推荐
- Android Fragment 从源码的角度去解析(上)
###1.概述 本来想着昨天星期五可以早点休息,今天可以早点起来跑步,可没想到事情那么的多,晚上有人问我主页怎么做到点击才去加载Fragment数据,而不是一进入主页就去加载所有的数据,在这里自己就对 ...
- 【Android源码】源码分析深度好文+精编内核解析分享
阅读Android源码的好处有很多,比如:可以加深我们对系统的了解:可以参考牛人优雅的代码实现:可以从根本上找出一些bug的原因-我们应该庆幸Android是开源的,所有的功能都可以看到实现,所有的b ...
- MyBatis 源码分析 - 映射文件解析过程
1.简介 在上一篇文章中,我详细分析了 MyBatis 配置文件的解析过程.由于上一篇文章的篇幅比较大,加之映射文件解析过程也比较复杂的原因.所以我将映射文件解析过程的分析内容从上一篇文章中抽取出来, ...
- 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 TaskMasger 启动
1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [四] 上一篇: [flink]Flink 1.12.2 源码浅析 : yarn-per-job模式解析 Jo ...
- 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 JobMasger启动 YarnJobClusterEntrypoint
1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [三] 上一章:[flink]Flink 1.12.2 源码浅析 : yarn-per-job模式解析 yar ...
- 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 yarn 提交过程解析
1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [二] 请大家看原文去. 接上文Flink 1.12.2 源码分析 : yarn-per-job模式浅析 [一 ...
- 【flink】Flink 1.12.2 源码浅析 : yarn-per-job模式解析 从脚本到主类
1.概述 转载:Flink 1.12.2 源码浅析 : yarn-per-job模式解析 [一] 可以去看原文.这里是补充专栏.请看原文 2. 前言 主要针对yarn-per-job模式进行代码分析. ...
- [darknet源码系列-2] darknet源码中的cfg解析
[darknet源码系列-2] darknet源码中的cfg解析 FesianXu 20201118 at UESTC 前言 笔者在[1]一文中简单介绍了在darknet中常见的数据结构,本文继续上文 ...
- UE4官方滚球项目源码笔记
UE4官方滚球项目源码笔记 我的项目名称:test_0511,读者请根据自己的项目名称自行查找(YourProgramNameBall.h/YourProgramNameBall.cpp) 笔者是UE ...
最新文章
- “大型票务系统”和“实物电商系统”的数据库选型
- 730版本去掉恼人的提示信息
- libsvm java下载_java-libsvm 版 结合已有数据集的demo,方便初学者使用 Develop 238万源代码下载- www.pudn.com...
- dns的主从服务器的简单配置
- DELL服务器T410进行系统修复,ibm T410 BIOS修复过程-BIOS维修网站www.biosrepair.com
- CentOS6.8下安装memcached并设置开机自启动
- 计算机学生的高职英语课程,高职计算机英语课程教学方法探索
- Xmind 8 Pro破解版安装激活教程(Windows版)
- HDU - 1598
- 企业邮箱注册申请入口,公司邮箱申请哪个好?
- chrome插件莫名消失【已解决】
- 007需求分析中的重要知识点(马斯洛需求层次理论+KANO优先级筛选模型+金字塔模型)
- 信息系统项目管理师必背核心考点(四十九)合同法
- HIT CSAPP大作业论文
- r语言c()函数格式,R语言基本操作函数
- windos10本地安装git工具并使用
- 少不读水浒——揭秘水浒传
- 王桂林 C++基础与提高 练习题——string数组
- 无线网络中AP及AC的概念及作用:
- 数据库优化的八种方式
热门文章
- Log4j2基本使用入门
- 阿里首推的“SpringBoot+Vue全栈项目”有多牛X?
- 排序 -> 快速排序
- 理解交换机通过逆向自学习算法建立地址转发表的过程_交换机与 VLAN 到底是怎么来的...
- 克隆对象和对象的继承
- Spring Boot系列(十二)Spring Boot整合ActiveQ实现消息收发和订阅
- 自定义格式字符串随笔(IFormattable,IFormatProvider,ICustomFormatter三接口的实现)
- python 判断当前系统的Python编译器类型
- HTML5 history新特性pushState、replaceState,popstate
- Optimizing regular expressions in Java