CloudySky

纵使世界万般残酷

总有温暖值得守护

树分治 学习笔记

点分治

点分治是用来解决树上路径问题的一种思想。

和序列上的分治思想类似,只不过是向下递归时把序列的中点换成了子树的重心。一般要搭配双指针或者树状数组

考虑这样的问题

  1. 求树上是否存在距离恰好为 kk 的路径。
  2. 求树上距离小于等于 kk 的路径有几条。

这个时候就可以用点分治解决问题。

具体实现:

考虑选定一个点 xx,那么树上的所有路径就可以分成两类:

  1. 经过 xx,即两个端点分别位于 xx两个子树上。
  2. 不经过 xx,即两个端点在 xx 的同一个子树上。

对于第二类点,直接递归处理就好,对于第一类点,分别求出 xx 子树上所有点距离 xxdisdis 值,那么任意两个点 y,zy, z 之间的路径长一定为 len=disy+diszlen = dis_y + dis_z 。那么就可以边求距离边将所有的距离放入一个数组里,排一遍序,双指针扫一下即可。或者开一个桶树状数组维护一下。

但这样会出现重复,有两种解决方案:

  1. 对树进行染色,不同的子树染上不同的颜色。
    这种方法对于只找特殊点的 1 类问题比较适用,常数较小。
  2. 进行容斥,对于每个点先计算答案,再分别计算只有一棵子树上的点时的答案,将同子树贡献剪掉。
    这种方法可以较好地处理范围性 2 类问题,缺点就是常数很大。

还要注意每次递归子树时,都要再找一遍重心,而不是直接选择与当前节点相连的点。

代码实现(1 类问题):
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// 点分治 学习笔记
int rt, siz[Maxn], mx[Maxn], vis[Maxn];
vector <pair<int, int> > id;

// 计算距离并染色
void dfs(int x, int f, int d, int c) {
id.push_back({d, c});
for (int i = hd[x]; i; i = e[i].nxt) {
int y = e[i].v;
if (y == f || vis[y]) continue;
dfs(y, x, d + e[i].t, c);
}
}nt y = e[i].v;
if (y == f || vis[y]) continue;
getrt(y, x, s), siz[x] += siz[y];
mx[x] = max(mx[x], siz[y]);
}
mx[x] = max(mx[x], s - siz[x]);
if (!rt || mx[x] < mx[rt]) rt = x;
}

// 分治
void solve(int x) {
vis[x] = 1, calc(x);
for (int i = hd[x]; i; i = e[i].nxt) {
int y = e[i].v, s = siz[y];
if (vis[y]) continue;
rt = 0, getrt(y, x, s), getrt(y = rt, x, s), solve(y);
}
}

点分树

点分树是点分治的升级版,支持动态进行点分治可进行的操作。

具体实现:

原理是将点分治每次的分治重心相连,构建成一颗辅助树,把点分治需要用的东西用数据结构动态维护起来。在这棵树上进行的一些看似复杂度不对的东西因为有了层数的限制都可以被接受。同时还需要配一个快速求出两点在原树上的距离的数据结构,大多数情况下树剖就行了。

P6329 【模板】点分树 | 震波

代码实现:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// 树分治 学习笔记 || 点分树的具体实现
// Code By CloudySky
#include <bits/stdc++.h>
// #define int long long
namespace IO {
inline int read() {
int x = 0, f = 1; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); }
while (c >= '0' && c <= '9') { x = x * 10 + (c ^ 48); c = getchar(); }
return x * f;
}
void print_n(int x) {
if (x > 9) print_n(x / 10);
putchar(x % 10 + '0');
}
inline void print(int x, char s = '\n') {
if (x < 0) putchar('-'), x = -x;
print_n(x), putchar(s);
}
} // namespace IO
using namespace IO;
const int Maxn = 1e5 + 10;
using namespace std;

struct BIT {
int N; vector <int> v;
void init(int n) {N = n + 1, v.resize (N + 1);}
void add(int x, int y) {for (x = x + 1; x <= N; x += x & -x) v[x] += y;}
int ask(int x, int ans = 0) {
for (x = min (x + 1, N); x; x -= x & -x) {ans += v[x];} return ans;
}
} T1[Maxn], T2[Maxn];

struct edge {int v, nxt;} e[Maxn << 1];
int hd[Maxn], cnt;
void add(int u, int v) {e[++cnt] = (edge) {v, hd[u]}; hd[u] = cnt;}

int dep[Maxn], fa[Maxn], siz[Maxn], son[Maxn];

void dfs1(int x, int f) {
dep[x] = dep[f] + 1, fa[x] = f, siz[x] = 1;
for (int i = hd[x]; i; i = e[i].nxt) {
int y = e[i].v; if (y == fa[x]) continue;
dfs1(y, x), siz[x] += siz[y];
if (siz[y] > siz[son[x]]) son[x] = y;
}
}
int top[Maxn];
void dfs2(int x, int f) {
top[x] = f;
if (son[x]) dfs2(son[x], f);
for (int i = hd[x]; i; i = e[i].nxt) {
int y = e[i].v;
if (y != fa[x] && y != son[x]) dfs2(y, y);
}
}

int lca(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
return dep[x] < dep[y] ? x : y;
}

int dis(int x, int y) {return dep[x] + dep[y] - 2 * dep[lca(x, y)];}

void calc(int x) {
siz[fa[x]] -= siz[x], siz[x] += siz[fa[x]];
}

int mx[Maxn], vis[Maxn], rt;
void getrt(int x, int f, int Sz) {
siz[x] = 1, mx[x] = 0;
for (int i = hd[x]; i; i = e[i].nxt) {
int y = e[i].v; if (y == f || vis[y]) continue;
getrt(y, x, Sz);
siz[x] += siz[y], mx[x] = max(mx[x], siz[y]);
}
mx[x] = max(mx[x], Sz - siz[x]);
if (!rt || mx[x] < mx[rt]) rt = x;
}

int p[Maxn];

void build(int x, int f, int Sz) {
p[x] = f, vis[x] = 1;
T1[x].init(Sz), T2[x].init(Sz);
for (int i = hd[x]; i; i = e[i].nxt) {
int y = e[i].v, s = siz[y];
if (vis[y]) continue;
rt = 0, getrt(y, x, s), getrt(rt, x, s),
build(rt, x, s);
}
}

void change(int x, int k) {
T1[x].add(0, k);
for (int y = x; p[y]; y = p[y]) {
int d = dis(x, p[y]);
T1[p[y]].add(d, k), T2[y].add(d, k);
}
}

int ask(int x, int k) {
int ans = T1[x].ask(k);
for (int y = x; p[y]; y = p[y]) {
int d = k - dis(x, p[y]);
if (d < 0) continue;
ans += T1[p[y]].ask(d) - T2[y].ask(d);
}
return ans;
}

int a[Maxn];

signed main() {
int n = read(), m = read();
for (int i = 1; i <= n; ++i) a[i] = read();
for (int i = 1, u, v; i < n; ++i) {u = read(), v = read(), add(u, v), add(v, u);}
dfs1(1, 1), dfs2(1, 1);
getrt(1, 0, n), getrt(rt, 0, n), build(rt, 0, n);
for (int i = 1; i <= n; ++i)
change(i, a[i]);
for (int i = 1, lst = 0; i <= m; ++i) {
int op = read(), x = read() ^ lst, y = read() ^ lst;
if (op == 0) print(lst = ask(x, y));
if (op == 1) change(x, y - a[x]), a[x] = y;
}
return 0;
}

本文作者:CloudySky
写作时间: 2022-01-20
最后更新时间: 2022-01-20
遵循协议: BY-NC-SA

Top