蓄水池抽样算法(Reservoir Sampling)

许多年以后,当听说蓄水池抽样算法时,邱simple将会想起,那个小学数学老师带他做“小明对水池边加水边放水,求何时能加满水”应用题的下午。

一、问题

我是在一次失败的面试经历中听说蓄水池算法的。之后上网搜了搜,知道是一个数据抽样算法,寥寥几行,却暗藏玄机。主要用来解决如下问题。

给定一个数据流,数据流长度N很大,且N直到处理完所有数据之前都不可知,请问如何在只遍历一遍数据(O(N))的情况下,能够随机选取出m个不重复的数据。

这个场景强调了3件事:

数据流长度N很大且不可知,所以不能一次性存入内存。

时间复杂度为O(N)。

随机选取m个数,每个数被选中的概率为m/N。

第1点限制了不能直接取N内的m个随机数,然后按索引取出数据。第2点限制了不能先遍历一遍,然后分块存储数据,再随机选取。第3点是数据选取绝对随机的保证。讲真,在不知道蓄水池算法前,我想破脑袋也不知道该题做何解。

二、核心代码及原理

蓄水池抽样算法的核心如下:

int[] reservoir = new int[m];

// init

for (int i = 0; i < reservoir.length; i++)

{

reservoir[i] = dataStream[i];

}

for (int i = m; i < dataStream.length; i++)

{

// 随机获得一个[0, i]内的随机整数

int d = rand.nextInt(i + 1);

// 如果随机整数落在[0, m-1]范围内,则替换蓄水池中的元素

if (d < m)

{

reservoir[d] = dataStream[i];

}

}

注:这里使用已知长度的数组dataStream来表示未知长度的数据流,并假设数据流长度大于蓄水池容量m。

算法思路大致如下:

如果接收的数据量小于m,则依次放入蓄水池。

当接收到第i个数据时,i >= m,在[0, i]范围内取以随机数d,若d的落在[0, m-1]范围内,则用接收到的第i个数据替换蓄水池中的第d个数据。

重复步骤2。

算法的精妙之处在于:当处理完所有的数据时,蓄水池中的每个数据都是以m/N的概率获得的。

下面用白话文推导验证该算法。假设数据开始编号为1.

第i个接收到的数据最后能够留在蓄水池中的概率=第i个数据进入过蓄水池的概率*之后第i个数据不被替换的概率(第i+1到第N次处理数据都不会被替换)。

当i<=m时,数据直接放进蓄水池,所以第i个数据进入过蓄水池的概率=1。

当i>m时,在[1,i]内选取随机数d,如果d<=m,则使用第i个数据替换蓄水池中第d个数据,因此第i个数据进入过蓄水池的概率=m/i。

当i<=m时,程序从接收到第m+1个数据时开始执行替换操作,第m+1次处理会替换池中数据的为m/(m+1),会替换掉第i个数据的概率为1/m,则第m+1次处理替换掉第i个数据的概率为(m/(m+1))*(1/m)=1/(m+1),不被替换的概率为1-1/(m+1)=m/(m+1)。依次,第m+2次处理不替换掉第i个数据概率为(m+1)/(m+2)...第N次处理不替换掉第i个数据的概率为(N-1)/N。所以,之后第i个数据不被替换的概率=m/(m+1)*(m+1)/(m+2)*...*(N-1)/N=m/N。

当i>m时,程序从接收到第i+1个数据时开始有可能替换第i个数据。则参考上述第3点,之后第i个数据不被替换的概率=i/N。

结合第1点和第3点可知,当i<=m时,第i个接收到的数据最后留在蓄水池中的概率=1*m/N=m/N。结合第2点和第4点可知,当i>m时,第i个接收到的数据留在蓄水池中的概率=m/i*i/N=m/N。综上可知,每个数据最后被选中留在蓄水池中的概率为m/N。

这个算法建立在统计学基础上,很巧妙地获得了“m/N”这个概率。

三、深入一些——分布式蓄水池抽样(Distributed/Parallel Reservoir Sampling)

一块CPU的计算能力再强,也总有内存和磁盘IO拖他的后腿。因此为提高数据吞吐量,分布式的硬件搭配软件是现在的主流。

如果遇到超大的数据量,即使是O(N)的时间复杂度,蓄水池抽样程序完成抽样任务也将耗时很久。因此分布式的蓄水池抽样算法应运而生。运作原理如下:

假设有K台机器,将大数据集分成K个数据流,每台机器使用单机版蓄水池抽样处理一个数据流,抽样m个数据,并最后记录处理的数据量为N1, N2, ..., Nk, ..., NK(假设m

取[1, N]一个随机数d,若d

m/N的概率验证如下:

第k台机器中的蓄水池数据被选取的概率为m/Nk。

从第k台机器的蓄水池中选取一个数据放进最终蓄水池的概率为Nk/N。

第k台机器蓄水池的一个数据被选中的概率为1/m。(不放回选取时等概率的)

重复m次选取,则每个数据被选中的概率为m*(m/Nk*Nk/N*1/m)=m/N。

四、算法验证

写一份完整的代码,用来验证蓄水池抽样的随机性。数据集大小为1000,蓄水池容量为10,做10_0000次抽样。如果程序正确,那么每个数被抽中的次数接近1000次。

package cn.edu.njupt.qyz;

import java.util.ArrayList;

import java.util.Arrays;

import java.util.LinkedList;

import java.util.List;

import java.util.Random;

import java.util.Set;

import java.util.TreeSet;

import java.util.concurrent.Callable;

import java.util.concurrent.ExecutionException;

import java.util.concurrent.ExecutorService;

import java.util.concurrent.Executors;

import java.util.concurrent.Future;

public class ReservoirSampling {

static ExecutorService exec = Executors.newFixedThreadPool(4);

// 抽样任务,用作模拟并行抽样

private static class SampleTask implements Callable

{

// 输入该任务的数据

private int[] innerData;

// 蓄水池容量

private int m;

SampleTask (int m, int[] innerData)

{

this.innerData = innerData;

this.m = m;

}

@Override

public int[] call() throws Exception

{

int[] reservoir = sample(this.m, this.innerData);

return reservoir;

}

}

// 并行抽样

public static int[] mutiSample(int m, int[] dataStream) throws InterruptedException, ExecutionException

{

Random rand = new Random();

int[] reservoir = initReservoir(m, dataStream);

// 生成3个范围内随机数,将数据切成4份

List list = getRandInt(rand, dataStream.length);

int s1 = list.get(0);

int s2 = list.get(1);

int s3 = list.get(2);

// 每个任务处理的数据量

double n1 = s1 - 0;

double n2 = s2 - s1;

double n3 = s3 - s2;

double n4 = dataStream.length - s3;

// 并行抽样

Future f1 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, 0, s1)));

Future f2 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, s1, s2)));

Future f3 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, s2, s3)));

Future f4 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, s3, dataStream.length)));

List r1 = getList(f1.get());

List r2 = getList(f2.get());

List r3 = getList(f3.get());

List r4 = getList(f4.get());

// 进行m次抽样

for (int i = 0; i < m; i++)

{

int p = rand.nextInt(dataStream.length);

// 根据随机数落在的范围选择元素

if (p < s1)

{

reservoir[i] = getRandEle(r1, rand.nextInt(r1.size()));

}

else if (p < s2)

{

reservoir[i] = getRandEle(r2, rand.nextInt(r2.size()));

}

else if (p < s3)

{

reservoir[i] = getRandEle(r3, rand.nextInt(r3.size()));

}

else

{

reservoir[i] = getRandEle(r4, rand.nextInt(r4.size()));

}

}

return reservoir;

}

// 根据输入返回随机位置的元素,并且删除该元素,模拟不放回

private static int getRandEle(List list, int idx)

{

return list.remove(idx);

}

// 获取bound范围内的3个随机数,用来分割数据集

private static List getRandInt(Random rand, int bound)

{

Set set = new TreeSet<>();

List list = new ArrayList<>();

while (set.size() < 3)

{

set.add(rand.nextInt(bound));

}

for (int e: set)

{

list.add(e);

}

return list;

}

// 数据转换成List

private static List getList(int[] arr)

{

List list = new LinkedList<>();

for (int a : arr)

{

list.add(a);

}

return list;

}

// 单机版蓄水池抽样

public static int[] sample(int m, int[] dataStream)

{

// 随机数生成器,以系统当前nano时间作为种子

Random rand = new Random();

int[] reservoir = initReservoir(m, dataStream);

// init

for (int i = 0; i < reservoir.length; i++)

{

reservoir[i] = dataStream[i];

}

for (int i = m; i < dataStream.length; i++)

{

// 随机获得一个[0, i]内的随机整数

int d = rand.nextInt(i + 1);

// 如果随机整数在[0, m-1]范围内,则替换蓄水池中的元素

if (d < m)

{

reservoir[d] = dataStream[i];

}

}

return reservoir;

}

private static int[] initReservoir (int m, int[] dataStream)

{

int[] reservoir;

if (m > dataStream.length)

{

reservoir = new int[dataStream.length];

}

else

{

reservoir = new int[m];

}

return reservoir;

}

// 单机版测试

public void test()

{

// 样本长度

int len = 1000;

// 蓄水池容量

int m = 10;

// 抽样次数,用作验证抽样的随机性

int iterTime = 100000;

// 每个数字被抽到的次数

int[] freq = new int[len];

// 样本

int[] dataStream = new int[len];

// init dataStream

for (int i = 0; i < dataStream.length; i++)

{

dataStream[i] = i;

}

// count freq

for (int k = 0; k < iterTime; k++)

{

// 进行抽样

int[] reservoir = sample(m, dataStream);

// 计算出现次数

for (int i = 0; i < reservoir.length; i++)

{

int ele = reservoir[i];

freq[ele] += 1;

}

}

printStaticInfo(freq);

}

// 测试并行抽样

public void mutiTest() throws InterruptedException, ExecutionException

{

// 样本长度

int len = 1000;

// 蓄水池容量

int m = 10;

// 抽样次数,用作验证抽样的随机性

int iterTime = 10_0000;

// 每个数字被抽到的次数

int[] freq = new int[len];

// 样本

int[] dataStream = new int[len];

// init dataStream

for (int i = 0; i < dataStream.length; i++)

{

dataStream[i] = i;

}

// count freq

for (int k = 0; k < iterTime; k++)

{

// 进行抽样

int[] reservoir = mutiSample(m, dataStream);

// 计算出现次数

for (int i = 0; i < reservoir.length; i++)

{

int ele = reservoir[i];

freq[ele] += 1;

}

}

printStaticInfo(freq);

}

// 打印统计信息

private void printStaticInfo (int[] freq)

{

// 期望、方差和标准差

double avg = 0;

double var = 0;

double sigma = 0;

// print

for (int i = 0; i < freq.length; i++)

{

if (i % 10 == 9) System.out.println();

System.out.print(freq[i] + ", ");

avg += ((double)(freq[i]) / freq.length);

var += (double)(freq[i] * freq[i]) / freq.length;

}

// 输出统计信息

System.out.println("\n===============================");

var = var - avg * avg;

sigma = Math.sqrt(var);

System.out.println("Average: " + avg);

System.out.println("Variance: " + var);

System.out.println("Standard deviation: " + sigma);

}

public static void main (String[] args) throws InterruptedException, ExecutionException

{

ReservoirSampling rs = new ReservoirSampling();

rs.mutiTest();

}

}

单机版输出和并行版的输出类似,截取片段如下:

948, 1006, 1014, 1019, 1033, 1040, 948, 1014, 1000, 951,

1014, 987, 1049, 1043, 1034, 983, 1006, 974, 1060, 1009,

986, 1021, 1024, 963, 1041, 1028, 988, 1011, 975, 980,

1055, 1017, 1010, 1018, 1013, 983, 942, 1056, 1003, 1063,

1004, 1004, 999, 976, 957, 935, 1061, 1018, 1002, 1018,

1019, 946, 985, 1057, 1012, 965, 978, 1040, 1026, 1064,

1026, 1018, 980, 996, 1025, 1028, 1006, 944, 986, 981,

923, 1015, 991, 1019, 1024, 1143, 989, 985, 1022, 1019,

1004, 1000, 989, 972, 1041, 988, 1050, 932, 975, 1037,

1016, 983, 1051, 1003, 983, 986, 1017, 1009, 936, 993,

965, 976, 1001, 1000, 988, 1030, 1050, 1024, 981, 985,

935, 1023, 996, 1007, 1013, 1046, 1003, 1006, 973, 989,

943,

===============================

Average: 1000.0000000000002

Variance: 1011.8799999983748

Standard deviation: 31.81006130139291

此外,为了对比单机版与并行版(4线程)的性能差异,使用10_0000大小的数据集,蓄水池容量10,进行100_0000次重复抽样,对比两者的运行时间。结果如下

---------单机版----------

===============================

Average: 100.00000000000125

Variance: 100.31497999751264

Standard deviation: 10.015736617818613

---------并行版----------

===============================

Average: 100.00000000000169

Variance: 100.63045999737915

Standard deviation: 10.031473470900432

单机版耗时:2006s

并行版耗时:1265s

从输出结果可以看出,算法保证了数据选取的随机性。且并行版算法能够有效提高数据吞吐量。

五、应用场景

蓄水池抽样的O(N)时间复杂度,O(m)空间复杂度令其适用于对流数据、大数据集的等概率抽样。比如一个大文本数据,随机输出其中的几行。

六、总结

象征性总结:优雅巧妙的算法——蓄水池抽样。

七、参考文献

蓄水池采样算法的python实现_蓄水池抽样算法(Reservoir Sampling)相关推荐

  1. 蓄水池采样算法的python实现_蓄水池采样算法的python实现_蓄水池抽样及实现

    蓄水池抽样(Reservoir Sampling)是一个很有趣的问题,它能够在o(n)时间内对n个数据进行等概率随机抽取,例如:从1000个数据中等概率随机抽取出100个.另外,如果数据集合的量特别大 ...

  2. 蓄水池采样算法的python实现_蓄水池抽样及实现

    蓄水池抽样(Reservoir Sampling)是一个很有趣的问题,它能够在o(n)时间内对n个数据进行等概率随机抽取,例如:从1000个数据中等概率随机抽取出100个.另外,如果数据集合的量特别大 ...

  3. Python+Django+Mysql简单在线课程推荐系统 基于用户、项目、内容的协同过滤推荐算法 SimpleOnlineCourseCFRSPyth python实现协同过滤推荐算法实现源代码下载

    Python+Django+Mysql简单在线课程推荐系统 基于用户.项目.内容的协同过滤推荐算法 SimpleOnlineCourseCFRSPyth python实现协同过滤推荐算法实现源代码下载 ...

  4. Python+Django+Mysql个性化二手车推荐系统 汽车推荐系统 基于用户、项目、内容的协同过滤推荐算法 WebCarCFRSPython python实现协同过滤推荐算法实现源代码下载

    Python+Django+Mysql个性化二手车推荐系统 汽车推荐系统 基于用户.项目.内容的协同过滤推荐算法 WebCarCFRSPython python实现协同过滤推荐算法实现源代码下载 一. ...

  5. 蓄水池采样算法的python实现_常用算法-蓄水池抽样算法

    Leetcode上遇到一道题,题目是这样的: 这道题的关键是链表的长度不知道,但是要使随机返回每个元素的概率相等,这一下就难倒我了,如果知道链表的长度k,从0到k中随机选择一个整数就好了呀,可现在不知 ...

  6. 人工鱼群算法python代码_人工鱼群算法python_鱼群算法 - Brillou的个人空间 - OSCHINA - 中文开源技术交流社区......

    本算法是参照李晓磊博士的论文实现的,详细的算法原理可阅读<一种新型的智能优化方法_人工鱼群算法_李晓磊> 算法基于鱼群的生存行为:在一片水域中,鱼存在的数目最多的地方就是本水域中富含营养物 ...

  7. 正向最大匹配算法 python代码_中文分词算法之最大正向匹配算法(Python版)

    最大匹配算法是自然语言处理中的中文匹配算法中最基础的算法,分为正向和逆向,原理都是一样的. 正向最大匹配算法,故名思意,从左向右扫描寻找词的最大匹配. 首先我们可以规定一个词的最大长度,每次扫描的时候 ...

  8. matlab算法用python做_机器学习笔记—朴素贝叶斯算法实现(matlab/python)

    原理知道一百遍不如自己动手写一遍,当然,现在基本上不需要自己来写算法的底层code了,各路大神们已经为我等凡夫俗子写好了,直接调用就行. 这里介绍在MATLAB中和Python中应用贝叶斯算法的小例子 ...

  9. 排序算法python实现_合并排序算法– Java,C和Python实现

    排序算法python实现 Merge sort is one of the most efficient sorting algorithms. It works on the principle o ...

最新文章

  1. OC__part11
  2. php并行运算,php多进程并行执行脚本的代码
  3. python玩转android_如何用python玩跳一跳 ?(安卓版)
  4. 大学本科 java教材,大学本科自学java之路——IO
  5. 【Sigmoid】操作对象(对数据结果进行验证)
  6. [codeblocks下实现多个源文件2.0]类声明和成员函数定义的分离实战操作实现
  7. hevc参考代码matlab版,HEVC部分源码剖析
  8. 30种EMC标准电路
  9. linux wamp,常见的WAMP集成环境
  10. 小马激活工具出现Cannot open file k:\OEMSF 的解决方法
  11. php faker,PHP Faker 教程
  12. Deferred Shading介绍
  13. 菜鸟渗透日记29---python渗透测试编程之信息收集1-主机发现
  14. 大数据分析案例-对电信客户流失分析预警预测
  15. 最大流、最小费用最大流【模板】
  16. js中for(i in array)和for(i=0;i<array.length;i++)之间的坑
  17. 了解常见的模拟器及交换机的基本配置
  18. mysql - database/数据库
  19. 向量中断、中断向量、向量地址3个概念是什么关系?
  20. 论「能写代码」的程序员与「会写代码」的程序员

热门文章

  1. 台式计算机快捷键大全,电脑常用快捷键有哪些
  2. vpc经典网络区别_网络运维与管理 经典网络还是VPC如何选择
  3. 轻言壁纸微信小程序源码
  4. 使用阿里云云服务器一年多的感受
  5. restore,recover的区别
  6. LOESS与LOWESS
  7. 【软考】下午题 解题思路总结
  8. Tomcat网站上的core和deployer的区别
  9. Android的公交线路查询系统
  10. 经典查找算法 --- B+树