用优先队列解决TopM和Multiway问题

TopM 问题描述

从 N 个输入中找到最大/小的 M 个元素。

使用基于堆的优先队列解决 TopM 问题

算法描述

解决此问题的算法不止一个,但这里只关注下面这个:
将前 M 个输入读入一个大小为 M 的集合 S,并使集合 S 中的最小元素位于位置 k。为了方便描述,将集合 S 中的最小元素记为 $E_k$。当再读入一个新的输入时,将其与 $E_k$ 进行比较,若新的输入较大,就从集合 S 中删除当前的 $E_k$,并将新的输入添加至集合 S,然后找出新的 $E_k$。处理完所有输入后,就得到最大的 M 个输入。

算法实现

创建一个大小为 M+1 的 MinPQ,将 N 个输入插入此优先队列,每次插入完成后就检查队列大小是否大于 M,如果大于就删除最小元素。

下面的示例代码找出交易金额最大的 M 笔交易:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

public static void main(String[] args) {
int M = Integer.valueOf(args[0]);
MinPQ<Transaction> pq = new MinPQ<Transaction>(M+1);
while (StdIn.hasNextLine()) {
pq.insert(new Transaction(StdIn.readLine()));
if (pq.size() > M) {
pq.delMin();
}
} // 最大的 M 个元素都在优先队列中

Stack<Transaction> stack = new Stack<Transaction>();
while (!pq.isEmpty()) {
stack.push(pq.delMin());
}
for (Transaction t : stack) {
StdOut.println(t);
}
}

性能评估

从 N 个输入中找到最大的 M 个元素所需的时间成本:

解决办法 时间的增长数量级 空间的增长数量级
使用排序算法 $NlogN$ N
使用初级实现的优先队列 NM M
使用基于堆的优先队列 $NlogM$ M

Multiway 问题描述

将多个有序的输入流归并成一个有序的输入流。也就是常说的多路/向归并问题。

使用基于堆的索引优先队列解决多路归并问题

算法描述

算法思路如图所示,其中的前提是每个流都是单独升序有序的:

将多个升序有序的流归并

使用索引优先队列实现算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

public class Multiway {

// Merge together the sorted input streams
// and write the sorted result to standard output
public static void merge(In[] streams) {
int n = streams.length;
IndexedMinPQ<String> pq = new IndexedMinPQ<String>(n);

for (int i = 0; i < n; i++) {
if (!streams[i].isEmpty()) {
pq.insert(i, streams[i].readString());
}
}

while (!pq.isEmpty()) {
StdOut.print(pq.min() + " ");
int i = pq.delMin();
if (!streams[i].isEmpty()) {
pq.insert(i, streams[i].readString());
}
}
}

public static void main(String[] args) {
int n = args.length;
In[] streams = new In[n];
for (int i = 0; i < n; i++) {
streams[i] = new In(args[i]);
}
merge(streams);
}
}

索引优先队列

API

定义 描述
void insert(int i, Item item) 插入一个元素,将它和索引 i 关联
void change(int i, Item item) 将与索引 i 关联的元素设置为 item
bool contains(int i) 是否存在索引为 i 的元素
void delete(int i) 删除索引 i 及其关联的元素
void min() 返回最小元素
int minIndex() 返回最小元素的索引
int delMin() 删除最小元素并返回它的索引
bool isEmpty() 优先队列是否为空
int size() 优先队列中的元素数量

基于堆的实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309

public class IndexedMinPQ<E extends Comparable<E>> implements Iterable<Integer> {

private int maxN; // 优先队列可容纳的最大元素个数
private int n; // 当前优先队列上的元素个数

// 保存用户输入的元素,下标就是对应元素的索引
private E[] items;

// 二叉堆的数组表示形式。节点之间的父子关系仍使用下标表示,
// 但下标对应的元素不是输入元素,而是输入元素的索引
private int[] pq;

// 保存输入元素的索引在二叉堆中的位置,下标就是输入元素的索引
private int[] qp;

public IndexedMinPQ(int maxN) {
if (maxN < 0) {
throw new IllegalArgumentException("maxN < 0");
}
this.maxN = maxN;
n = 0;
items = (E[]) new Comparable[maxN+1];
pq = new int[maxN+1];
qp = new int[maxN+1];
for (int i = 0; i <= maxN; i++) {
qp[i] = -1;
}
}

/**
* 插入一个索引为 i 的元素
* @param i
* @param item
*/
public void insert(int i, E item) {
if (i < 0 || i >= maxN) {
throw new IndexOutOfBoundsException("i < 0 or i >= maxN");
}
if (contains(i)) {
throw new IllegalArgumentException("Index i already exists");
}

items[i] = item;
pq[++n] = i;
qp[i] = n;
percolateUp(n);
}

/**
* 将索引为 i 的元素设置为 item
* @param i
* @param item
*/
public void change(int i, E item) {
if (i < 0 || i >= maxN) {
throw new IndexOutOfBoundsException("k < 0 or k >= maxN");
}
if (!contains(i)) {
throw new NoSuchElementException("Index i is absent");
}

E oldItem = items[i];
items[i] = item;
if (item.compareTo(oldItem) > 0) {
percolateDown(qp[i]);
} else {
percolateUp(qp[i]);
}
}

/**
* 是否存在索引为 i 的元素
*/
public boolean contains(int i) {
return qp[i] != -1;
}

/**
* 删除索引为 i 的元素
*/
public void delete(int i) {
if (i < 0 || i >= maxN) {
throw new IndexOutOfBoundsException("k < 0 or k >= maxN");
}
if (!contains(i)) {
throw new NoSuchElementException("Index i does not exist");
}

int heapIndex = qp[i];
exch(heapIndex, n--);
percolateUp(heapIndex);
percolateDown(heapIndex);
items[i] = null;
pq[n+1] = -1; // 重置为初始值,可选
qp[i] = -1;
}

/**
* 返回最小元素
*/
public E min() {
if (isEmpty()) throw new NoSuchElementException("Priority queue underflow");
return items[pq[1]];
}

/**
* 返回最小元素的索引
*/
public int minIndex() {
if (isEmpty()) throw new NoSuchElementException("Priority queue underflow");
return pq[1];
}

/**
* 删除最小元素并返回它的索引
*/
public int delMin() {
if (isEmpty()) throw new NoSuchElementException("Priority queue underflow");
int indexOfMin = pq[1];
exch(1, n--);
percolateDown(1);
assert indexOfMin == pq[n+1];
items[indexOfMin] = null;
pq[n+1] = -1; // 重置为初始值,可选
qp[indexOfMin] = -1;
return indexOfMin;
}

/**
* 减小与索引 i 关联的元素的键值(key value)
* @param i 键值要减小的元素的索引
* @param item 键值减小后的元素
*/
public void decreaseKey(int i, E item) {
if (i < 0 || i >= maxN) throw new IndexOutOfBoundsException();
if (!contains(i)) throw new NoSuchElementException("index i is absent");
if (items[i].compareTo(item) <= 0)
throw new IllegalArgumentException(
"Calling decreaseKey() with given argument would not strictly decrease the key");
items[i] = item;
percolateUp(qp[i]);
}

/**
* 增大与索引 i 关联的元素的键值(key value)
* @param i 键值要增大的元素的索引
* @param item 键值增大后的元素
*/
public void increaseKey(int i, E item) {
if (i < 0 || i >= maxN) throw new IndexOutOfBoundsException();
if (!contains(i)) throw new NoSuchElementException("index i is absent");
if (items[i].compareTo(item) >= 0)
throw new IllegalArgumentException(
"Calling increaseKey() with given argument would not strictly increase the key");
items[i] = item;
percolateDown(qp[i]);
}

public boolean isEmpty() {
return n == 0;
}

public int size() {
return n;
}

/*******************************
* General helper functions.
*******************************/

/**
* i 索引的元素是否大于 j 索引的
* @param i 第一个元素的索引在二叉堆中的位置
* @param j 第二个元素的索引在二叉堆中的位置
* @return
*/
private boolean greater(int i, int j) {
return items[pq[i]].compareTo(items[pq[j]]) > 0;
}

/**
* 第一个索引对应的元素是否大于第二个
* @param i1 第一个元素的索引
* @param i2 第二个元素的索引
* @return
*/
private boolean greater2(int i1, int i2) {
return items[i1].compareTo(items[i2]) > 0;
}

private void exch(int i, int j) {
int tmp = pq[i];
pq[i] = pq[j];
pq[j] = tmp;
qp[pq[i]] = i;
qp[pq[j]] = j;
}

/****************************
* Heap helper functions.
****************************/
private void percolateUp(int i) {
int tmp = pq[i];
int hole = i;
while (hole > 1 && greater2(pq[hole/2], tmp)) {
pq[hole] = pq[hole/2];
qp[pq[hole]] = hole;
hole = hole/2;
}
pq[hole] = tmp;
qp[tmp] = hole;
}

private void percolateDown(int i) {
int tmp = pq[i];
int hole = i;
while (2*hole <= n) {
int child = 2*hole;
// 找出较小的儿子
if (child+1 <= n && greater(child, child+1)) {
child++;
}
if (greater2(tmp, pq[child])) {
pq[hole] = pq[child];
qp[pq[hole]] = hole;
hole = child;
} else {
break;
}
}
pq[hole] = tmp;
qp[tmp] = hole;
}

/****************************
* Iterators.
****************************/

public Iterator<Integer> iterator() {
return new HeapIterator();
}

private class HeapIterator implements Iterator<Integer> {

private IndexedMinPQ<E> copy;

public HeapIterator() {
// pq.length-1 = maxN
copy = new IndexedMinPQ<E>(pq.length-1);
// Adding all elements to copy of heap
// takes linear time since already in heap order so no elements move
for (int i = 1; i <= n; i++) {
copy.insert(pq[i], items[pq[i]]);
}

// Bad performance: NlogN. Should not use!
/*for (int i = 0; i < n; i++) {
copy.insert(i, items[i]);
}*/
}

public boolean hasNext() {
return !copy.isEmpty();
}

public Integer next() {
if (!hasNext()) throw new NoSuchElementException();
return copy.delMin();
}

public void remove() {
throw new UnsupportedOperationException();
}
}

/**
* Unit tests the <tt>IndexedMinPQ</tt> data type.
*/
public static void main(String[] args) {
// insert a bunch of strings
String[] strings = { "it", "was", "the", "best", "of", "times", "it", "was", "the", "worst" };

IndexedMinPQ<String> pq = new IndexedMinPQ<String>(strings.length);
for (int i = 0; i < strings.length; i++) {
pq.insert(i, strings[i]);
}

// delete and print each key
while (!pq.isEmpty()) {
int i = pq.delMin();
StdOut.println(i + " " + strings[i]);
}
StdOut.println();

// reinsert the same strings
for (int i = 0; i < strings.length; i++) {
pq.insert(i, strings[i]);
}

// print each key using the iterator
for (int i : pq) {
StdOut.println(i + " " + strings[i]);
}
while (!pq.isEmpty()) {
pq.delMin();
}
}
}

性能评估

在一个大小为 N 的索引优先队列中,插入(insert)、改变优先级(change)、删除(delete)和删除最小元素(delMin)操作所需的比较次数和 $logN$ 成正比。

含有 N 个元素的基于堆的索引优先队列所有操作在最坏情况下的时间成本:

操作 比较次数的增长数量级
insert() $logN$
change() $logN$
contains() 1
delete() $logN$
min() 1
minIndex() 1
delMin() $logN$

参考资源

《算法(第四版)》Robert Sedgewick 等著,第 2 章-排序

0%