相比普通并查集,带权并查集记录了结点到根的距离,可以用来解决某些求路径长度的问题
问题
食物链(POJ-1182)
Description
动物王国中有三类动物A,B,C,这三类动物的食物链构成了有趣的环形。A吃B, B吃C,C吃A。
现有N个动物,以1-N编号。每个动物都是A,B,C中的一种,但是我们并不知道它到底是哪一种。
有人用两种说法对这N个动物所构成的食物链关系进行描述:
第一种说法是”1 X Y”,表示X和Y是同类。
第二种说法是”2 X Y”,表示X吃Y。
此人对N个动物,用上述两种说法,一句接一句地说出K句话,这K句话有的是真的,有的是假的。当一句话满足下列三条之一时,这句话就是假话,否则就是真话。
1) 当前的话与前面的某些真的话冲突,就是假话;
2) 当前的话中X或Y比N大,就是假话;
3) 当前的话表示X吃X,就是假话。
你的任务是根据给定的N(1 <= N <= 50,000)和K句话(0 <= K <= 100,000),输出假话的总数。
Input
第一行是两个整数N和K,以一个空格分隔。
以下K行每行是三个正整数 D,X,Y,两数之间用一个空格隔开,其中D表示说法的种类。
若D=1,则表示X和Y是同类。
若D=2,则表示X吃Y。
Output
只有一个整数,表示假话的数目。
Sample Input
1 | 100 7 |
Sample Output
1 | 3 |
蒟蒻的我看到描述的话应该会想到,首先建一个图,然后根据图上两点的距离判断两者的关系。但又想到万一这么一种情况:
那判断2和3的关系或者4和5的关系就比较困难了。
一个想法是把这个图分层:
所以需要有一个数据结构来记录图的层次,然后嫩快速地计算出两点的关系。
带权并查集
带权并查集就有这么一个作用,它可以记录下每个结点相对于根结点的距离。
普通并查集,通过一个fa
数组储存其父结点,经过路径压缩可以让同一个集合里的结点指向同一个根结点:1
2
3
4int fa[MAXN];
int find(int x) {
return fa[x] = (fa[x] == x) ? x : find(fa[x]);
}
而带权并查集加入了d
数组记录距离:
1 | int fa[MAXN],d[MAXN]; |
查找函数
1 | int find(int x) { |
用下图做示例,将 $3$ 到 $1$ 的路径压缩
- 记录 $3$ 的父结点 : $oldFa \gets 2$
- 更改 $3$ 的父结点为根结点 $fa[3] \gets 1$
- $3$ 到现在的父结点的距离就为 $3$ 到 $oldFa(2)$ 的距离加上 $oldFa$ 到根结点的距离: $d[3] \gets 3 + 4$。
合并函数
1 | void merge(int x,int y,int w) { |
这个函数的作用,是将 x
所在分支合并到 y
所在分支。
以下图为例,$1$、$2$ 为一个分支,$3$、$4$ 为一个分支,已知 $2$ 和 $4$ 的距离是 3 目的是将 $4$ 所在分支合并到 $2$ 所在分支。。
- 将 $3$ 的父结点设为 $1$ : $fa[3] \gets 1$
- 计算 $3$ 到 $1$ 的距离
如何计算$3$ 到 $1$ 的距离呢。可以设想一下合并后的图形,应有如下规律
$dist(x,y)$ 表示 $x$ 到 $y$ 的距离。那么可以得出
这样就合并完了。之后在查找的时候会自动路径压缩.
计算距离
1 | int dist(int x,int y) |
只要将两者到根结点的距离相减就得到啦
完整模板
1 |
|
题解
带权并查集可以很容易地知道两点的距离,不过依照题意,只有 $A$、$B$、$C$三种动物,所以距离只能在 0,1,2 里取,只要将模板里的距离都 $ \mod 3 $ 就可以了。
比如在 $A \to B \to C \to A$,这个食物链里,$dist(B,A) = 1$ 表示 $A$ 捕食 $B$,$dist(C,A) = 2$ 表示 $C$ 捕食 $A$,如果算出来距离等于0的话,就表示两者是同类。
代码1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
int fa[MAXN],d[MAXN];
int find(int x) {
if(fa[x] == x) return x;
else {
int oldFa = fa[x];
fa[x] = find(oldFa);
d[x] = (d[x] + d[oldFa]) % 3;
return fa[x];
}
}
void merge(int x,int y,int w) {
int fax = find(x),fay = find(y);
if(fax == fay) return;
fa[fax] = fay;
d[fax] = (-d[x] + d[y] + w + 3) % 3;
}
int dist(int x,int y) {
int fax = find(x),fay = find(y);
if(fax != fay) return -1;
else return (d[x] - d[y] + 3) % 3;
}
int main() {
int n,k,ans = 0;
scanf("%d%d",&n,&k);
for(int i = 1 ;i <= n;i++) fa[i] = i;
for(int i = 0;i < k ;i++) {
int o,x,y;
scanf("%d%d%d",&o,&x,&y);
if(x > n || y > n) {ans++;continue;}
if(o == 1) {
int fax = find(x),fay = find(y);
if(fax != fay) merge(x,y,0);
else if(dist(x,y) != 0) ans++;
}
else {
if(x == y) {ans++;continue;}
int fax = find(x),fay = find(y);
if(fax != fay) merge(x,y,1);
else if(dist(x,y) != 1) ans++;
}
}
printf("%d",ans);
return 0;
}