跳转至

树上差分⚓︎

树上差分是指在树的路径进行差分操作,以便高效地处理路径更新和查询问题。树上差分可分为点差分和边差分两种。

点差分⚓︎

树上点差分基本思想是通过差分操作记录路径上的修改值(针对节点),最后通过一次深度优先搜索(\text{DFS})将差分值累加,从而得到每个节点的最终值。具体步骤如下:

  1. 使用树上倍增或者 \text{Tarjan} 算法预处理每条查询中路径两端点的最低公共祖先\text{LCA}(1)

    1. 最近公共祖先
  2. 对于每条路径 u \longleftrightarrow v, 找到端点的最低公共祖先 \text{LCA}

    • uv 上加上修改值
    • \text{LCA} 上减去修改值
    • 如果 \text{LCA} 不是根节点,那么在 \text{LCA} 的父节点上也减去修改值
  3. 最后对树进行一次 \text{DFS}, 将差分值进行累加, 即可得到每个节点被路径覆盖后的最终点权
Max Flow P

给定一棵有 n 个节点的树和 k 条路径,每条路径连接树上的两个节点。初始时每个节点的值为 0。对于每条路径,路径上所有节点的值都加 1。请你计算经过所有路径操作后,树上节点的最大值。

C++
#include <iostream>
#include <vector>
using namespace std;

int main() {
  int n, k;
  cin >> n >> k;
  vector<vector<int>> tree(n + 1);
  for (int i = 1; i < n; ++i) {
    int x, y;
    cin >> x >> y;
    tree[x].push_back(y);
    tree[y].push_back(x);
  }

  int m = 32 - __builtin_clz(n + 1);
  vector<vector<int>> st(n + 1, vector<int>(m, -1));
  vector<int> depth(n + 1);
  vector<int> parent(n + 1, -1);
  int root = 1;
  // 预处理倍增数组
  {
    auto dfs = [&](auto &&self, int x, int from) -> void {
      st[x][0] = from;
      for (int y : tree[x]) {
        if (y != from) {
          depth[y]  = depth[x] + 1;
          parent[y] = x;
          self(self, y, x);
        }
      }
    };
    dfs(dfs, root, -1);
    for (int j = 1; j < m; ++j) {
      for (int i = 1; i <= n; ++i) {
        if (st[i][j - 1] != -1) { st[i][j] = st[st[i][j - 1]][j - 1]; }
      }
    }
  }
  // 获取节点x的第k个祖先
  auto get_kth_ancestor = [&](int node, int k) -> int {
    for (; (k != 0) && (node != -1); k &= k - 1) { node = st[node][__builtin_ctz(k)]; }
    return node;
  };
  // 获取节点x和节点y的最近公共祖先
  auto get_lca = [&](int x, int y) -> int {
    if (depth[x] > depth[y]) { swap(x, y); }
    y = get_kth_ancestor(y, depth[y] - depth[x]);
    if (y == x) { return x; }
    for (int i = m - 1; i >= 0; --i) {
      int px = st[x][i];
      int py = st[y][i];
      if (px != py) {
        x = px;
        y = py;
      }
    }
    return st[x][0];
  };

  vector<int> values(n + 1, 0);  // 树上点差分数组, 如果有点权值,可以先将点权值加到values中
  // 处理每条路径: 点差分
  for (int i = 0; i < k; ++i) {
    int x, y;
    cin >> x >> y;
    int lca      = get_lca(x, y);
    values[x]   += 1;
    values[y]   += 1;
    values[lca] -= 1;
    if (lca != root) { values[parent[lca]] -= 1; }
  }

  // dfs累加差分值
  auto dfs = [&](auto &&self, int x, int from) -> void {
    for (int y : tree[x]) {
      if (y != from) {
        self(self, y, x);
        values[x] += values[y];  // 累加子节点的差分值
      }
    }
  };
  dfs(dfs, root, -1);

  // values数组即为每个节点的最终值
  cout << *max_element(values.begin() + 1, values.end()) << '\n';
  return 0;
}

边差分⚓︎

树上边差分基本思想是通过差分操作记录路径上的修改值(针对边),最后通过一次深度优先搜索(\text{DFS})将差分值累加,从而得到每条边的最终值。具体步骤如下:

  1. 使用树上倍增或者 \text{Tarjan} 算法预处理每条查询中路径两端点的最低公共祖先\text{LCA}(1)

    1. 最近公共祖先
  2. 对于每条路径 u \longleftrightarrow v, 找到端点的最低公共祖先 \text{LCA}

    • uv 上加上修改值
    • \text{LCA} 上减去两倍的修改值
  3. 最后对树进行一次 \text{DFS}, 假设边 e 从父节点 u 连向子节点 v,令 \text{weight}[e] += \text{num}[v] 即可得到每条边的最终值。然后将点权累加到父节点上。
Network

给定一棵含 N 个节点的树,和 M 条新增边。每次删除一条原树边和一条新增边,求使网络变为不连通的方案数。

Hint

每条新增边在树上对应一条路径。若某树边被 k 条新增边跨过,则删掉它后仍连通的新边有 k 条。

用树上差分 + \text{LCA} 统计每条树边被跨过的次数 k, 答案为对每条树边求贡献:

  1. k = 0 \Rightarrow +M
  2. k = 1 \Rightarrow +1
  3. k > 1 \Rightarrow +0

数据范围

1 \leq N \leq 10^5, 1 \leq M \leq 10^5

卡常严重?

似乎vector会TLE,改用链式前向星;cin/cout也会TLE,改用scanf/printf

C++
// This file is written in C++98 standard.

#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;

const int MAXN = 1e5 + 5;
const int MAXM = 2e5 + 5;  // 2*(N-1)
const int LOGN = 18;       // log2(1e5) ≈ 17

struct Edge {
  int to, next;
} edges[MAXM];

int head[MAXN], tot;

// Tree & LCA
int N, M;
int st[MAXN][LOGN];
int depth[MAXN];
int values[MAXN];
int parent[MAXN];

// for add_edge
inline void add_edge(int u, int v) {
  edges[++tot].to = v;
  edges[tot].next = head[u];
  head[u]         = tot;
}

void dfs_lca(int x, int fa) {
  st[x][0] = fa;
  for (int j = 1; j < LOGN; ++j) {
    if (st[x][j - 1] != -1) {
      st[x][j] = st[st[x][j - 1]][j - 1];
    } else {
      st[x][j] = -1;
    }
  }
  for (int i = head[x]; i != 0; i = edges[i].next) {
    int y = edges[i].to;
    if (y == fa) { continue; }
    depth[y] = depth[x] + 1;
    dfs_lca(y, x);
  }
}

int get_kth_ancestor(int node, int k) {
  for (int i = 0; i < LOGN && node != -1; ++i) {
    if ((k & (1 << i)) != 0) { node = st[node][i]; }
  }
  return node;
}

int get_lca(int x, int y) {
  if (depth[x] < depth[y]) { swap(x, y); }
  int diff = depth[x] - depth[y];
  for (int i = 0; i < LOGN; ++i) {
    if ((diff & (1 << i)) != 0) { x = st[x][i]; }
  }
  if (x == y) { return x; }
  for (int i = LOGN - 1; i >= 0; --i) {
    if (st[x][i] != st[y][i]) {
      x = st[x][i];
      y = st[y][i];
    }
  }
  return st[x][0];
}

int ans;

int dfs_diff(int x, int fa) {
  for (int i = head[x]; i != 0; i = edges[i].next) {
    int y = edges[i].to;
    if (y == fa) { continue; }
    dfs_diff(y, x);
    int weight = values[y];
    if (weight == 0) {
      ans += M;
    } else if (weight == 1) {
      ans += 1;
    }
    values[x] += values[y];
  }
  return values[x];
}

int main() {
  scanf("%d%d", &N, &M);
  memset(head, 0, sizeof(head));
  memset(st, -1, sizeof(st));
  tot = 0;

  for (int i = 1; i < N; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    add_edge(u, v);
    add_edge(v, u);
  }

  depth[1] = 0;
  dfs_lca(1, -1);

  for (int i = 0; i < M; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    int lca      = get_lca(u, v);
    values[u]   += 1;
    values[v]   += 1;
    values[lca] -= 2;
  }

  ans = 0;
  dfs_diff(1, -1);
  printf("%d\n", ans);
  return 0;
}

评论