dsu on tree

dsu on tree

一月 01, 2022

dsu on tree

前置知识

dfs建树 轻重链剖分

核心思想

对于以 u 为根的子树

①. 先统计它轻子树(轻儿子为根的子树)的答案,统计完后删除信息

②. 再统计它重子树(重儿子为根的子树)的答案 ,统计完后保留信息

③. 然后再将重子树的信息合并到 u上

④. 再去遍历 u 的轻子树,然后把轻子树的信息合并到 u 上

⑤. 判断 u 的信息是否需要传 递给它的父节点(u 是否是它父节点的重儿子)

首先我们进行 $\mathrm{dfs}$ 找出节点的重儿子

1
2
3
4
5
6
7
8
9
10
11
inline void dfs(int u, int f){
fa[u] = f;
siz[u] = 1;
for(int i = head[u]; i; i = e[i].nxt){
int v = e[i].to;
if(v == f) continue;
dfs(v, u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}

然后我们进行统计答案的操作

首先标记好是要进行添加还是删除

然后我们像上面说的一样,先统计完轻字数的答案,然后删除信息,然后再统计答案

最后放出总代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/*
BlackPink is the Revolution
light up the sky
Blackpink in your area
*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<cctype>
#include<bitset>
#include<vector>
#include<ctime>
#include<map>
#include<set>
#include<iomanip>
#define int long long
#define rep(i, a, b) for(int i = (a); (i) <= (b); ++i)
#define per(i, a, b) for(int i = (a); (i) >= (b); --i)
#define whlie while
using namespace std;
#define rep(i, l, r) for(int (i) = (l); (i) <= (r); (i)++)
#define pre(i, l, r) for(int (i) = (l); (i) >= (r); (i)--)
const int N=4e5+10;
typedef long long ll;
typedef pair<int, int> P;

struct edge{
int to, nxt, val;
}e[N<<1];
int cnt, head[N], dis[N], siz[N], son[N], dep[N], top[N], tot[N], ans[N], c[N];

inline void add(int u, int v, int w){
e[++cnt] = (edge){v, head[u], w}, head[u] = cnt;
}

template <typename T> inline void read(T &x){
x=0; int f=0; char c = getchar();
for(; !isdigit(c); c = getchar()) f |= (c == '-');
for(; isdigit(c); c = getchar()) x = x * 10+(c ^ 48);
if(f) x = -x;
}

template <typename T> inline void write(T &x, char ch){
if(x<0) putchar('-'), x = -x;
static short st[30], tp = 0;
do st[++tp] = x % 10, x /= 10; while(x);
while(tp) putchar(st[tp--] | 48);
putchar(ch);
}

int n, m, res;
int maxn, minn;

inline void dfs(int u, int f){
siz[u] = 1;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v == f) continue;
dfs(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}

inline void calc(int u, int f, int val, int sonu) {
if (val == 1) {
tot[c[u]] ++;
if (tot[c[u]] > maxn) maxn = tot[c[u]], res = c[u];
else if (tot[c[u]] == maxn) res += c[u];
}
else tot[c[u]]--;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v == f || v == sonu) continue;
calc(v, u, val, sonu);
}
}

inline void dsu(int u, int f, int val) { //val为1表示是重儿子,不会进行二次递归,val为0表示是轻儿子,第一次统计信息并删除,第二次统计信息
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v == f || v == son[u]) continue;
dsu(v, u, 0);
}
if (son[u]) dsu(son[u], u, 1);
calc(u, f, 1, son[u]);
ans[u] = res;
if (!val) calc(u, f, -1, 0), maxn = 0, res = 0;
}

signed main(){
read(n);
rep (i, 1, n) read(c[i]);
rep (i, 2, n) {
int u, v;
read(u), read(v);
add(u, v, 0), add(v, u, 0);
}
dfs(1, 0);
dsu(1, 0, 0);
rep (i, 1, n) write(ans[i], ' ');
return 0;
}