蓄水池抽样算法(Reservoir Sampling)

前言

什么是抽样问题?抽样问题就是从一个大数据量 N 中抽取出 M 个不重复的数据。例如:

  1. 从 100000 份调查报告中抽取 1000 份进行统计
  2. 从一本很厚的电话簿中抽取 1000 人进行统计

抽样问题最重要的是做到公平,也就是保证每个元素被抽到的概率是相同的。最容易想到的解决方法是生成随机数,例如对于问题 1,可以通过算法生成 [0, 100000-1] 间的随机数 1000 个,并且保证这 1000 个数字不重复,然后取出相应的元素即可。

但是对于问题 2 就不一样了,我们事先不知道数据的规模有多大,虽然可以先对数据进行一次遍历,计算出数据的数量,但这样在数据量很大时可能会很浪费时间。

这时就轮到蓄水池抽样算法出场了,它能够在只遍历一次数据的情况下,随机抽取出指定数量的不重复数据

算法过程

假设数据序列的规模为 n,需要抽取的数量为 m。

  1. 首先构建一个容量为 m 的蓄水池数组,将序列的前 m 个元素放入蓄水池中。
  2. 对于之后的元素,假设该元素为第 i 个(i >= m),在 [0, i] 中取得随机数 a,若 a 落在 [0, m-1] 范围内,那么该元素替换掉原来蓄水池中的第 a 个元素。
  3. 在遍历完数据序列后,蓄水池中剩下的元素即为要抽取的样本

算法证明

假设数据序列的规模为 n,需要抽取的数量为 m,对于第 i 个元素(i 从 0 开始):

  1. 当 i < m 时,元素进入水池的概率为 1,这个不难理解,下面看元素不被替换的概率,当遍历到第 m 个元素时,替换池中元素的概率为 m/(m+1),池中第 i 个元素被替换的概率为 1/m,所以总的来说,遍历到第 m 个元素时,第 i 个元素不被替换的概率为 1 - (m/(m+1) 1/m) = m/(m+1)。依次类推,当遍历到第 j 个元素时(j >= m),第 i 个元素不被替换的概率为 j/(j+1)。所以当遍历完 n 个元素时,第 i 个元素不被替换的概率为 m/(m+1) (m+1)/(m+2) … (n-1)/n = m/n。总的来说,当 i < m 时,第 i 个元素被抽取到的概率为 1 m/n = m/n
  2. 当 i >= m 时,在 [0, i] 中抽取随机数 d,如果 d < m,则替换掉池中的第 d 个元素,所以此时元素可以进入水池的概率为 m/(i+1)。元素进入到蓄水池后,由上面的证明可以得知,元素不被替换掉的概率为 (i+1)/n。所以总的来说,当 i >= m 时,第 i 个元素被抽取到的概率为 m/(i+1) * (i+1)/n = m/n

综上,可知遍历完一遍后,所以元素被抽取到的概率都为 m/n,即抽样是公平的。

代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
public class ReservoirSamplingDemo {
public static void main(String[] args) {
// 假设要从 data 数组中随机抽取 1000 个元素
int[] data = new int[100000];
for (int i = 0; i < data.length; i++) {
data[i] = i;
}
// 进行蓄水池抽样
int[] res = new ReservoirSamplingDemo().sampling(data, 1000);
// 打印结果
System.out.println(Arrays.toString(res));
}

/**
* 从 data 数组中随机抽取 m 个元素
*
* @param data
* @param m
* @return
*/
private int[] sampling(int[] data, int m) {
int[] reservoir = new int[m]; // 蓄水池
int n = data.length; // 数据序列的长度
// 先将前 m 个元素放入蓄水池
for (int i = 0; i < m; i++) {
reservoir[i] = data[i];
}
// 遍历之后的元素
Random random = new Random();
for (int i = m; i < n; i++) {
// 获得 [0, i] 中的一个随机数
int d = random.nextInt(i+1);
// 如果随机数 d 落在 [0, m) 的范围内,则替换掉第 d 个元素
if (d < m) {
reservoir[d] = data[i];
}
}

return reservoir;
}
}

应用:LeetCode 382 链表随机节点

题目描述

给定一个单链表,随机选择链表的一个节点,并返回相应的节点值。保证每个节点被选的概率一样。

进阶:
如果链表十分大且长度未知,如何解决这个问题?你能否使用常数级空间复杂度实现?

示例

1
2
3
4
5
6
7
8
// 初始化一个单链表 [1,2,3].
ListNode head = new ListNode(1);
head.next = new ListNode(2);
head.next.next = new ListNode(3);
Solution solution = new Solution(head);

// getRandom()方法应随机返回1,2,3中的一个,保证每个元素被返回的概率相等。
solution.getRandom();

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public class Solution {

private ListNode head;

/** @param head The linked list's head.
Note that the head is guaranteed to be not null, so it contains at least one node. */
public Solution(ListNode head) {
this.head = head;
}

/** Returns a random node's value. */
public int getRandom() {
int res = head.val; // 头结点进入蓄水池
ListNode temp = head.next;
int index = 1; // 记录当前节点(temp)的索引
Random random = new Random();
while (temp != null) {
// 获得 [0, index] 中的随机数
int rand = random.nextInt(index+1);
// 当获得的随机数为 0 时,替换掉蓄水池中的节点
if (rand == 0) {
res = temp.val;
}
// 下一节点
index++;
temp = temp.next;
}

return res;
}
}

参考

-------------    本文到此结束  感谢您的阅读    -------------
0%