树上差分⚓︎
树上差分是指在树的路径进行差分操作,以便高效地处理路径更新和查询问题。树上差分可分为点差分和边差分两种。
点差分⚓︎
树上点差分基本思想是通过差分操作记录路径上的修改值(针对节点),最后通过一次深度优先搜索(\text{DFS})将差分值累加,从而得到每个节点的最终值。具体步骤如下:
-
使用树上倍增或者 \text{Tarjan} 算法预处理每条查询中路径两端点的最低公共祖先\text{LCA}(1)
- 见 最近公共祖先
-
对于每条路径 u \longleftrightarrow v, 找到端点的最低公共祖先 \text{LCA}
- 在 u 和 v 上加上修改值
- 在 \text{LCA} 上减去修改值
- 如果 \text{LCA} 不是根节点,那么在 \text{LCA} 的父节点上也减去修改值
- 最后对树进行一次 \text{DFS}, 将差分值进行累加, 即可得到每个节点被路径覆盖后的最终点权
Max Flow P
给定一棵有 n 个节点的树和 k 条路径,每条路径连接树上的两个节点。初始时每个节点的值为 0。对于每条路径,路径上所有节点的值都加 1。请你计算经过所有路径操作后,树上节点的最大值。
#include <iostream>
#include <vector>
using namespace std;
int main() {
int n, k;
cin >> n >> k;
vector<vector<int>> tree(n + 1);
for (int i = 1; i < n; ++i) {
int x, y;
cin >> x >> y;
tree[x].push_back(y);
tree[y].push_back(x);
}
int m = 32 - __builtin_clz(n + 1);
vector<vector<int>> st(n + 1, vector<int>(m, -1));
vector<int> depth(n + 1);
vector<int> parent(n + 1, -1);
int root = 1;
// 预处理倍增数组
{
auto dfs = [&](auto &&self, int x, int from) -> void {
st[x][0] = from;
for (int y : tree[x]) {
if (y != from) {
depth[y] = depth[x] + 1;
parent[y] = x;
self(self, y, x);
}
}
};
dfs(dfs, root, -1);
for (int j = 1; j < m; ++j) {
for (int i = 1; i <= n; ++i) {
if (st[i][j - 1] != -1) { st[i][j] = st[st[i][j - 1]][j - 1]; }
}
}
}
// 获取节点x的第k个祖先
auto get_kth_ancestor = [&](int node, int k) -> int {
for (; (k != 0) && (node != -1); k &= k - 1) { node = st[node][__builtin_ctz(k)]; }
return node;
};
// 获取节点x和节点y的最近公共祖先
auto get_lca = [&](int x, int y) -> int {
if (depth[x] > depth[y]) { swap(x, y); }
y = get_kth_ancestor(y, depth[y] - depth[x]);
if (y == x) { return x; }
for (int i = m - 1; i >= 0; --i) {
int px = st[x][i];
int py = st[y][i];
if (px != py) {
x = px;
y = py;
}
}
return st[x][0];
};
vector<int> values(n + 1, 0); // 树上点差分数组, 如果有点权值,可以先将点权值加到values中
// 处理每条路径: 点差分
for (int i = 0; i < k; ++i) {
int x, y;
cin >> x >> y;
int lca = get_lca(x, y);
values[x] += 1;
values[y] += 1;
values[lca] -= 1;
if (lca != root) { values[parent[lca]] -= 1; }
}
// dfs累加差分值
auto dfs = [&](auto &&self, int x, int from) -> void {
for (int y : tree[x]) {
if (y != from) {
self(self, y, x);
values[x] += values[y]; // 累加子节点的差分值
}
}
};
dfs(dfs, root, -1);
// values数组即为每个节点的最终值
cout << *max_element(values.begin() + 1, values.end()) << '\n';
return 0;
}
边差分⚓︎
树上边差分基本思想是通过差分操作记录路径上的修改值(针对边),最后通过一次深度优先搜索(\text{DFS})将差分值累加,从而得到每条边的最终值。具体步骤如下:
-
使用树上倍增或者 \text{Tarjan} 算法预处理每条查询中路径两端点的最低公共祖先\text{LCA}(1)
- 见 最近公共祖先
-
对于每条路径 u \longleftrightarrow v, 找到端点的最低公共祖先 \text{LCA}
- 在 u 和 v 上加上修改值
- 在 \text{LCA} 上减去两倍的修改值
- 最后对树进行一次 \text{DFS}, 假设边 e 从父节点 u 连向子节点 v,令 \text{weight}[e] += \text{num}[v] 即可得到每条边的最终值。然后将点权累加到父节点上。
Network
给定一棵含 N 个节点的树,和 M 条新增边。每次删除一条原树边和一条新增边,求使网络变为不连通的方案数。
Hint
每条新增边在树上对应一条路径。若某树边被 k 条新增边跨过,则删掉它后仍连通的新边有 k 条。
用树上差分 + \text{LCA} 统计每条树边被跨过的次数 k, 答案为对每条树边求贡献:
- k = 0 \Rightarrow +M
- k = 1 \Rightarrow +1
- k > 1 \Rightarrow +0
数据范围
1 \leq N \leq 10^5, 1 \leq M \leq 10^5
卡常严重?
似乎vector会TLE,改用链式前向星;cin/cout也会TLE,改用scanf/printf
// This file is written in C++98 standard.
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int MAXN = 1e5 + 5;
const int MAXM = 2e5 + 5; // 2*(N-1)
const int LOGN = 18; // log2(1e5) ≈ 17
struct Edge {
int to, next;
} edges[MAXM];
int head[MAXN], tot;
// Tree & LCA
int N, M;
int st[MAXN][LOGN];
int depth[MAXN];
int values[MAXN];
int parent[MAXN];
// for add_edge
inline void add_edge(int u, int v) {
edges[++tot].to = v;
edges[tot].next = head[u];
head[u] = tot;
}
void dfs_lca(int x, int fa) {
st[x][0] = fa;
for (int j = 1; j < LOGN; ++j) {
if (st[x][j - 1] != -1) {
st[x][j] = st[st[x][j - 1]][j - 1];
} else {
st[x][j] = -1;
}
}
for (int i = head[x]; i != 0; i = edges[i].next) {
int y = edges[i].to;
if (y == fa) { continue; }
depth[y] = depth[x] + 1;
dfs_lca(y, x);
}
}
int get_kth_ancestor(int node, int k) {
for (int i = 0; i < LOGN && node != -1; ++i) {
if ((k & (1 << i)) != 0) { node = st[node][i]; }
}
return node;
}
int get_lca(int x, int y) {
if (depth[x] < depth[y]) { swap(x, y); }
int diff = depth[x] - depth[y];
for (int i = 0; i < LOGN; ++i) {
if ((diff & (1 << i)) != 0) { x = st[x][i]; }
}
if (x == y) { return x; }
for (int i = LOGN - 1; i >= 0; --i) {
if (st[x][i] != st[y][i]) {
x = st[x][i];
y = st[y][i];
}
}
return st[x][0];
}
int ans;
int dfs_diff(int x, int fa) {
for (int i = head[x]; i != 0; i = edges[i].next) {
int y = edges[i].to;
if (y == fa) { continue; }
dfs_diff(y, x);
int weight = values[y];
if (weight == 0) {
ans += M;
} else if (weight == 1) {
ans += 1;
}
values[x] += values[y];
}
return values[x];
}
int main() {
scanf("%d%d", &N, &M);
memset(head, 0, sizeof(head));
memset(st, -1, sizeof(st));
tot = 0;
for (int i = 1; i < N; ++i) {
int u, v;
scanf("%d%d", &u, &v);
add_edge(u, v);
add_edge(v, u);
}
depth[1] = 0;
dfs_lca(1, -1);
for (int i = 0; i < M; ++i) {
int u, v;
scanf("%d%d", &u, &v);
int lca = get_lca(u, v);
values[u] += 1;
values[v] += 1;
values[lca] -= 2;
}
ans = 0;
dfs_diff(1, -1);
printf("%d\n", ans);
return 0;
}