并查集模板

1
2
3
4
5
6
7
8
# 每个点初始化的时候指向自己
p = list(range(n))
def find(x):
# 当前值并不指向其根
if x != p[x]:
# 不断递归直到找到自己的根
p[x] = find(p[x])
return p[x]

经典例题总结
难度值:1 对于模板的基本应用
合并集合
连通块内点的数量
格子游戏 (二维并查集)
难度值:2 结合一些基本场景
省份数量
冗余链接
连通网络的操作次数
难度值:3 对实际问题进行解读转换 但还只涉及并查集这一种算法
搭配购买 (并查集结合01背包问题)
好的路径数量
删除操作后的最大字段和

例1

acwing836 合并集合 模板的直接应用

难度值: 1
原题链接:
https://www.acwing.com/problem/content/description/838/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
n,m = map(int,input().split())
p = list(range(n + 1))

def find(x):
if p[x]!=x:
p[x] = find(p[x])
return p[x]

while m:
m = m - 1
arr = input().split()
op,a,b = arr[0],int(arr[1]),int(arr[2])
a = find(a)
b = find(b)
if op == "M":
p[a] = b
elif op == "Q":
if a == b:
print("Yes")
else:
print("No")

例2

连通块内点的数量 acwing837
难度值: 1
原题链接:
https://www.acwing.com/problem/content/839/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
n,m = map(int,input().split())
p = list(range(n + 1))
cnt = (n + 1) * [1]
cnt[0] = 0

def find(x):
if x != p[x]:
p[x] = find(p[x])
return p[x]

while m:
m = m - 1
arr = input().split()
if len(arr) == 3:
op,a,b = arr[0],int(arr[1]),int(arr[2])
a,b = find(a),find(b)
if op == "C":
# 这里要注意下
# 算的是点的数量 所以如果合并就不能再次合并了
# 否则cnt重复计算结果就会出错
if a == b:
continue
p[a] = b
cnt[b] = cnt[b] + cnt[a]
elif op == "Q1":
if a == b:
print("Yes")
else:
print("No")
else:
_,a = arr[0],int(arr[1])
a = find(a)
print(cnt[a])

例3

好路径的数目
难度值: 3
原题链接:
https://leetcode.cn/problems/number-of-good-paths/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution:
def numberOfGoodPaths(self, vals: List[int], edges: List[List[int]]) -> int:
n = len(vals)
g = [ [] for _ in range(n)]

# 建图
for x,y in edges:
# 无向图
g[x].append(y)
g[y].append(x)
p = list(range(n))
def find(x):
if p[x]!=x:
p[x] = find(p[x])
return p[x]

# 记录每个符合条件的路径中点的个数
cnt = [1] * n
# 单个点也算一条路径 故ans 初始值为n
ans = n
# 排序以后从小到大 遍历
# 从而保证当前val_x 就是连通路径中最大的值
# 故只要节点x 的邻接点 小于val_x 即可进行合并操作
for val_x,x in sorted(zip(vals,range(n))):
fx = find(x)
for y in g[x]:
fy = find(y)
val_y = vals[fy]
# 如果已经处于同一集合 则无需合并
# 或者其邻接点的值 比val_x 大则放到后续再进行合并
if fy == fx or val_y > val_x:
continue
# 等于最大值 则可以从两边加入集合 更新新的好路径
# 故更新答案
if val_y == val_x:
ans = ans + cnt[fx] * cnt[fy]
cnt[fx] = cnt[fx] + cnt[fy]
# 否则不等于最大值 不更新新的好路径 但可以并入集合 其新集合的最大值 为原集合的最大值
p[fy] = fx

return ans

例4

删除操作后的最大字段和
难度值: 3
原题链接:
https://leetcode.cn/problems/maximum-segment-sum-after-removals/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution:
def maximumSegmentSum(self, nums: List[int], removeQueries: List[int]) -> List[int]:
n,p = len(nums),list(range(len(nums)))
def find(x):
if p[x] != x:
p[x] = find(p[x])
return p[x]

cnt = n * [0]
ans = [0]
# 维护当前数组最大子段和
m = 0
# 按删除顺序反向操作 在复原过程中获得字段和
# 用并查集维护连续性信息
# 加入最后一个点时 全段长度此时默认为起始状态为0 不需考虑
for i in range(n-1,0,-1):
idx = removeQueries[i]
# 先将当前点复原
cnt[idx] = nums[idx]
m = max(m,nums[idx])
fi = find(idx)
# 然后判断该点左右集合是连续
# 如连续则合并集合
for j in (idx - 1,idx + 1):
# 越界 或者该点还没填入元素 则必然无法连续跳过
if j < 0 or j == n or not cnt[j]:
continue
fj = find(j)
p[fj] = fi
cnt[fi] = cnt[fi] + cnt[fj]
m = max(cnt[fi],m)
ans.append(m)

ans.reverse()
return ans

例5

省份数量
原题链接:
https://leetcode.cn/problems/number-of-provinces/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution:
"""
就是计算连通块的数量
"""
def findCircleNum(self, isConnected: List[List[int]]) -> int:
n = len(isConnected)
p = list(range(n))

def find(x):
if x!=p[x]:
p[x] = find(p[x])
return p[x]

for i in range(n):
for j in range(i + 1,n):
if isConnected[i][j]:
a = find(i)
b = find(j)
if a == b:
continue
p[a] = b

s = set()

for v in p:
# 这里一定要注意 并查集判断该点在哪个集合都是find(root)
s.add(find(v))
return len(s)

例6

冗余连接
原题链接:
https://leetcode.cn/problems/redundant-connection/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution:
def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
n = len(edges)
p = list(range(n + 1))
def find(x):
if x != p[x]:
p[x] = find(p[x])
return p[x]
ans = []
for x,y in edges:
a = find(x)
b = find(y)
if a == b:
ans.append([x,y])
continue
if a!=b:
p[a] = b

# 最后的加进去的附加边长 就是在edges中最后出现的边
return ans[-1]

例7

连通网络的操作次数
原题链接:
https://leetcode.cn/problems/number-of-operations-to-make-network-connected/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
def makeConnected(self, n: int, connections: List[List[int]]) -> int:
# 线的数量
m = len(connections)
# n个点连通最少需要n - 1条边
if m < n - 1:
return -1
p = list(range(n))
def find(x):
if p[x]!=x:
p[x] = find(p[x])
return p[x]


# 只有一台计算机 连通分量值就是1
cur = 1
for x,y in connections:
a = find(x)
b = find(y)
if a == b:
continue
p[a] = b
cur = cur + 1

# 最少只需将剩余连通分量连上即可
return n-cur

例8

格子游戏
原题链接:
https://www.acwing.com/problem/content/1252/
收录这题的主要原因是这题是二维并查集 怎么说呢 倒也不算难 不过之前如果没有经历过一次我觉得就很难做出来

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package main

import "fmt"

const N int = 210

var p = make([]int, N*N)

func find(x int) int {
if p[x] != x {
p[x] = find(p[x])
}
return p[x]
}

func main() {
n, m := 0, 0
fmt.Scanf("%d %d",&n,&m)
// 这里已经转换为二维的并查集了 所以要初始化的值为n*n
for i := 0; i < n * n; i++ {
p[i] = i
}

get := func(x, y int) int {
return x*n + y
}

flag := false

for i := 1; i <= m; i++ {
var x, y int
var d string
fmt.Scanf("%d %d %s", &x, &y, &d)
// 将下标映射到0开始
x--
y--
//fmt.Println(x,y,d)
if d == "D" {
a := get(x, y)
b := get(x+1, y)
if find(a) == find(b) {
fmt.Println(i)
flag = true
break
} else {
p[b] = find(a)
}
} else {
a := get(x, y)
b := get(x, y+1)
if find(a) == find(b) {
fmt.Println(i)
flag = true
break
} else {
p[b] = find(a)
}
}
}

if flag == false {
fmt.Println("draw")
}

}

例9

搭配购买
原题链接:
https://www.acwing.com/problem/content/1254/
并查集结合01背包问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package main

import "fmt"

const N = 10010

var f = make([]int, N)
var w = make([]int, N)
var v = make([]int, N)
var p = make([]int, N)

func find(x int) int {
if x != p[x] {
p[x] = find(p[x])
}
return p[x]
}

func main() {
var n, m, money int
fmt.Scanf("%d %d %d", &n, &m, &money)

for i := 1; i <= n; i++ {
fmt.Scanf("%d %d", &v[i], &w[i])
p[i] = i
}

max := func(a, b int) int {
if a > b {
return a
}
return b
}

for i := 0; i < m; i++ {
var a, b int
fmt.Scanf("%d %d", &a, &b)
pa := find(a)
pb := find(b)
if pa == pb {
continue
}
v[pb] += v[pa]
w[pb] += w[pa]
p[pa] = pb
}

for i := 1; i <= n; i++ {
if p[i] == i {
for j := money; j >= v[i]; j-- {
f[j] = max(f[j], f[j-v[i]]+w[i])
}
}
}

fmt.Println(f[money])
}