跳转至

树状数组⚓︎

树状数组(\text{Binary Indexed Tree} / \text{Fenwick Tree})用于维护前缀累积信息(如前缀和、前缀异或、频率累加等)。
若运算可逆且满足结合(如加法、异或或模下可逆的乘法),也可用于相应的区间/差分场景。

树状数组的下标

数组的下标从 1 开始编号(\text{one-based indexing}

lowbit⚓︎

lowbit 取一个正整数二进制表示中最低位的 1(最低有效的 1 对应的值),这是因为 -x 的二进制表示是 \sim x + 1,即对 x 取反后加 1,因此 x \And -x 会保留 x 的最低位的 1,其余位均为 0
lowbit 只对正整数有意义(取二进制表示中最低的 1)。x=0 时没有任何位为 1,因此从语义上说 lowbit(0) 是无意义的。

lowbit
C++
int64_t lowbit(int64_t x) { return x & (-x); }

单点更新、区间查询⚓︎

维护一个数组 tree,其中 tree[i] 存储原数组中从 i - lowbit(i) + 1i 的区间和(或其他可逆结合运算的结果)。
前缀和 sum(1, pos) 可以通过不断减去 lowbit 来累加 tree 中的值实现。
单点更新时,需要将更新值加到所有包含该位置的 tree 元素中,这可以通过不断加上 lowbit 来实现。

【模板】树状数组 1
C++
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;

struct BIT {
  explicit BIT(int64_t n) : n(n), tree(n + 1) {}

  // 单点更新 x位置加delta, 1 <= x <= n
  void point_add(int64_t x, int64_t delta) {
    for (; x <= n; x += lowbit(x)) { tree[x] += delta; }
  }

  // 查询区间和 [1,x], 1 <= x <= n
  int64_t sum(int64_t x) {
    int64_t ret = 0;
    for (; x > 0; x -= lowbit(x)) { ret += tree[x]; }
    return ret;
  }

  // 查询区间和 [x,y], 1 <= x <= y <= n
  int64_t range_sum(int64_t x, int64_t y) { return sum(y) - sum(x - 1); }

  // x & (-x) 取出 x 最右边的 1
  static int64_t lowbit(int64_t x) { return x & (-x); }

  int64_t n;             // 数组大小
  vector<int64_t> tree;  // one-based indexing
};

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  int64_t n, m;
  cin >> n >> m;
  BIT bit(n);
  for (int64_t i = 1; i <= n; i++) {
    int64_t v;
    cin >> v;
    bit.point_add(i, v);
  }

  while ((m--) != 0) {
    int64_t op, x, y;
    cin >> op >> x >> y;
    if (op == 1) {  // 将第 x 个数加上 y
      bit.point_add(x, y);
    } else {  // 查询区间和 [x,y]
      cout << bit.range_sum(x, y) << '\n';
    }
  }
  return 0;
}

区间更新、单点查询⚓︎

  1. 维护差分数组 D: D[1]=A[1], D[i]=A[i]-A[i-1] (i>1)
  2. 区间 [l,r]val 等价于 D[l]+=val, D[r+1]-=val(两次点增)
  3. 查询 A[i] 则为 \sum_{j=1}^i D[j](前缀和)
【模板】树状数组 2
C++
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;

struct BIT {
  explicit BIT(int64_t n) : n(n), tree(n + 2) {}

  // 区间更新 [x,y] 每个位置加delta, 1 <= x <= y <= n
  void range_add(int64_t x, int64_t y, int64_t delta) {
    point_add(x, delta);
    point_add(y + 1, -delta);
  }

  // 单点查询 x位置的值, 1 <= x <= n
  int64_t sum(int64_t x) {
    int64_t ret = 0;
    for (; x > 0; x -= lowbit(x)) { ret += tree[x]; }
    return ret;
  }

  static int64_t lowbit(int64_t x) { return x & (-x); }

  // 单点更新, 维护差分数组
  void point_add(int64_t index, int64_t delta) {
    // 不查n+1位置, 上界到n即可
    for (; index <= n; index += lowbit(index)) { tree[index] += delta; }
  }

  int64_t n;             // 数组大小
  vector<int64_t> tree;  // 差分树状数组, one-based indexing
};

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  int64_t n, m;
  cin >> n >> m;
  BIT bit(n);
  for (int64_t i = 1; i <= n; i++) {
    int64_t v;
    cin >> v;
    bit.range_add(i, i, v);
  }

  while ((m--) != 0) {
    int64_t op;
    cin >> op;
    if (op == 1) {  // 将区间 [x,y] 每个位置加 k
      int64_t x, y, k;
      cin >> x >> y >> k;
      bit.range_add(x, y, k);
    } else {  // 查询 x 位置的值
      int64_t x;
      cin >> x;
      cout << bit.sum(x) << '\n';
    }
  }
  return 0;
}

区间更新、区间查询⚓︎

  1. 区间更新同上,维护差分数组 D
  2. 区间查询
    • 单点的值:A[i]=\sum_{j=1}^i D[j]
    • 前缀和: \begin{aligned} \sum_{i=1}^r A[i] &= \sum_{i=1}^r \sum_{j=1}^i D[j] = rD[1] + (r-1)D[2] + \ldots + (r-(r-1))D[r] \\ &= r\sum_{i=1}^r D[i] - \sum_{i=1}^r (i-1)D[i] \end{aligned}
    • F[i]=(i-1)D[i],则 \sum_{i=l}^r A[i]=r\sum_{i=1}^r D[i]-\sum_{i=1}^r F[i]-((l-1)\sum_{i=1}^{l-1} D[i]-\sum_{i=1}^{l-1} F[i])
  3. 维护两个树状数组:一个维护 D 用于区间更新,另一个维护 F 用于区间查询,其中 F[i]=(i-1)D[i]

区间更新、区间查询

  • 区间查询常用前缀和的差值实现,因此需要将区间查询转化为前缀和的形式
  • 区间更新常用差分数组的形式实现,因此需要将区间更新转化为差分数组的形式
  • 线段树也可以实现区间更新和区间查询
【模板】线段树 1
C++
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;

struct BIT {
  explicit BIT(int64_t n) : n(n), d(n + 2), f(n + 2) {}

  // 区间更新 [x,y] 每个位置加delta, 1 <= x <= y <= n
  void range_add(int64_t x, int64_t y, int64_t delta) {
    point_add(d, x, delta);
    point_add(d, y + 1, -delta);
    point_add(f, x, delta * (x - 1));
    point_add(f, y + 1, -delta * y);
  }

  // 查询区间和 [x, y], 1 <= x <= y <= n
  int64_t range_sum(int64_t x, int64_t y) {
    return y * sum(d, y) - sum(f, y) - ((x - 1) * sum(d, x - 1) - sum(f, x - 1));
  }

  // 单点更新 x位置加delta, 1 <= x <= n
  void point_add(vector<int64_t> &tree, int64_t x, int64_t delta) const {
    // 不查n+1位置, 上界到n即可
    for (; x <= n; x += Lowbit(x)) { tree[x] += delta; }
  }

  // 单点查询 x位置的值, 1 <= x <= n
  static int64_t sum(vector<int64_t> &tree, int64_t x) {
    int64_t ret = 0;
    for (; x > 0; x -= Lowbit(x)) { ret += tree[x]; }
    return ret;
  }

  static int64_t Lowbit(int64_t x) { return x & (-x); }

  int64_t n;          // 数组大小
  vector<int64_t> d;  // 差分树状数组, one-based indexing
  vector<int64_t> f;  // 差分树状数组, one-based indexing
};

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  int64_t n, m;
  cin >> n >> m;
  BIT bit(n);
  for (int64_t i = 1; i <= n; i++) {
    int64_t v;
    cin >> v;
    bit.range_add(i, i, v);
  }

  for (int64_t i = 0; i < m; i++) {
    int64_t op;
    cin >> op;
    if (op == 1) {
      int64_t l, r, val;
      cin >> l >> r >> val;
      bit.range_add(l, r, val);
    } else {
      int64_t l, r;
      cin >> l >> r;
      cout << bit.range_sum(l, r) << '\n';
    }
  }
  return 0;
}

二维单点更新、区间查询⚓︎

二维区域和检索 - 可变

Leetcode Premium

C++
#include <cstdint>
#include <vector>
using namespace std;

struct BIT {
  explicit BIT(int64_t n, int64_t m) : n(n), m(m), tree(n + 1, vector<int64_t>(m + 1)) {}

  // 从数组初始化树状数组, nums下标从0开始
  explicit BIT(vector<vector<int64_t>> nums) : BIT(nums.size(), nums[0].size()) {
    for (int64_t i = 1; i <= n; i++) {
      for (int64_t j = 1; j <= m; j++) { point_add(i, j, nums[i - 1][j - 1]); }
    }
  }

  // 单点更新 (x,y)位置加delta, 1 <= x <= n, 1 <= y <= m
  void point_add(int64_t x, int64_t y, int64_t delta) {
    for (int64_t i = x; i <= n; i += lowbit(i)) {
      for (int64_t j = y; j <= m; j += lowbit(j)) { tree[i][j] += delta; }
    }
  }

  // 查询区间和 (x1,y1) 到 (x2,y2) , 1 <= x1 <= x2 <= n, 1 <= y1 <= y2 <= m
  int64_t range_sum(int64_t x1, int64_t y1, int64_t x2, int64_t y2) {
    return sum(x2, y2) - sum(x1 - 1, y2) - sum(x2, y1 - 1) + sum(x1 - 1, y1 - 1);
  }

  static int64_t lowbit(int64_t x) { return x & (-x); }

  // 查询区间和 (1,1) 到 (x,y) , 1 <= x <= n, 1 <= y <= m
  int64_t sum(int64_t x, int64_t y) {
    int64_t ret = 0;
    for (int64_t i = x; i > 0; i -= lowbit(i)) {
      for (int64_t j = y; j > 0; j -= lowbit(j)) { ret += tree[i][j]; }
    }
    return ret;
  }

  int64_t n, m;                  // 数组大小
  vector<vector<int64_t>> tree;  // one-based indexing
};

二维区间更新、区间查询⚓︎

维护四个二维树状数组,分别维护 D, i*D, j*D, i*j*D,其中 D 为差分数组,D[i][j]=A[i][j]-A[i-1][j]-A[i][j-1]+A[i-1][j-1]

上帝造题的七分钟
C++
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;

struct BIT {
  explicit BIT(int64_t n, int64_t m)
      : n(n),
        m(m),
        tree1(n + 1, vector<int64_t>(m + 1, 0)),
        tree2(n + 1, vector<int64_t>(m + 1, 0)),
        tree3(n + 1, vector<int64_t>(m + 1, 0)),
        tree4(n + 1, vector<int64_t>(m + 1, 0)) {}

  // 从数组初始化树状数组, nums下标从0开始
  explicit BIT(vector<vector<int64_t>> nums) : BIT(nums.size(), nums[0].size()) {
    for (int64_t i = 1; i <= n; i++) {
      for (int64_t j = 1; j <= m; j++) { point_add(i, j, nums[i - 1][j - 1]); }
    }
  }

  static int64_t lowbit(int64_t x) { return x & (-x); }

  // 单点更新 (x,y)位置加delta, 1 <= x <= n, 1 <= y <= m
  void point_add(int64_t x, int64_t y, int64_t delta) {
    int64_t delta1 = delta;
    int64_t delta2 = delta * x;
    int64_t delta3 = delta * y;
    int64_t delta4 = delta * x * y;
    for (int64_t i = x; i <= n; i += lowbit(i)) {
      for (int64_t j = y; j <= m; j += lowbit(j)) {
        tree1[i][j] += delta1;
        tree2[i][j] += delta2;
        tree3[i][j] += delta3;
        tree4[i][j] += delta4;
      }
    }
  }

  // 区间更新 (x1,y1) 到 (x2,y2) 每个位置加delta
  // 1 <= x1 <= x2 <= n, 1 <= y1 <= y2 <= m
  void range_add(int64_t x1, int64_t y1, int64_t x2, int64_t y2, int64_t delta) {
    point_add(x1, y1, delta);
    point_add(x1, y2 + 1, -delta);
    point_add(x2 + 1, y1, -delta);
    point_add(x2 + 1, y2 + 1, delta);
  }

  // 查询区间和 (1,1) 到 (x,y) , 1 <= x <= n, 1 <= y <= m
  int64_t sum(int64_t x, int64_t y) {
    int64_t ret = 0;
    for (int64_t i = x; i > 0; i -= lowbit(i)) {
      for (int64_t j = y; j > 0; j -= lowbit(j)) {
        ret += (x + 1) * (y + 1) * tree1[i][j];
        ret -= (y + 1) * tree2[i][j];
        ret -= (x + 1) * tree3[i][j];
        ret += tree4[i][j];
      }
    }
    return ret;
  }

  // 查询区间和 (x1,y1) 到 (x2,y2) , 1 <= x1 <= x2 <= n, 1 <= y1 <= y2 <= m
  int64_t range_sum(int64_t x1, int64_t y1, int64_t x2, int64_t y2) {
    return sum(x2, y2) - sum(x1 - 1, y2) - sum(x2, y1 - 1) + sum(x1 - 1, y1 - 1);
  }

 private:
  int64_t n, m;                   // 数组大小
  vector<vector<int64_t>> tree1;  // 树状数组, one-based indexing, 维护 d[i][j]
  vector<vector<int64_t>> tree2;  // 树状数组, one-based indexing, 维护 i*d[i][j]
  vector<vector<int64_t>> tree3;  // 树状数组, one-based indexing, 维护 j*d[i][j]
  vector<vector<int64_t>> tree4;  // 树状数组, one-based indexing, 维护 i*j*d[i][j]
};

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  char op;
  int64_t n, m;
  cin >> op >> n >> m;
  BIT bit(n, m);
  while (cin >> op) {
    if (op == 'L') {  // 将 (x1, y1), (x2, y2) 为顶点的矩形区域内的所有数字加上
                      // delta
      int64_t x1, y1, x2, y2, val;
      cin >> x1 >> y1 >> x2 >> y2 >> val;
      bit.range_add(x1, y1, x2, y2, val);
    } else if (op == 'k') {  // 查询 (x1, y1), (x2, y2)
                             // 为顶点的矩形区域内所有数字的和
      int64_t x1, y1, x2, y2;
      cin >> x1 >> y1 >> x2 >> y2;
      cout << bit.range_sum(x1, y1, x2, y2) << '\n';
    }
  }
  return 0;
}

评论