跳转至

ST表⚓︎

ST 表(\text{Sparse Table})是一种用于在静态数组上高效处理区间查询的算法数据结构。
ST 表适用于幂等(\text{idempotent})操作,而区间和(sum)等非幂等操作则不适用。
ST 表的核心思想是预处理数组的多个区间,使得任何查询都可以通过合并预处理的区间来快速得到结果。
ST 表是一种典型的倍增(\text{Doubling})技术应用。

应用场景

  • 区间最值查询(\text{Range Minimum/Maximum Query}
  • 区间最大公约数查询(\text{Range GCD Query}
  • 其他满足结合律的幂等操作(如 \min\max\gcd\text{bitwise\_and}\text{bitwise\_or} 等)

ST 表模板

C++
struct ST {
  int n, m;

  vector<int> log2;  // floor(log2(i)), 也可以用 31 - __builtin_clz(i)
  vector<vector<int>> st;
  std::function<int(int, int)> func;  // 结合律运算函数, 如 min, max, gcd

  ST(const vector<int> &nums, const std::function<int(int, int)> &f)
      : n(nums.size()), log2(n + 1), func(f) {
    for (int i = 2; i <= n; i++) { log2[i] = log2[i / 2] + 1; }
    m = log2[n] + 1;  // m = log2[n] + 1 = 32 - __builtin_clz(n) (1)
    st = vector<vector<int>>(n, vector<int>(m));
    // 预处理 ST 表, O(nlogn)
    for (int i = 0; i < n; i++) { st[i][0] = nums[i]; }
    for (int j = 1; j < m; j++) {
      for (int i = 0; i + (1 << j) - 1 < n; i++) {  // 确保区间右端点不会越界 (2) (3)
        st[i][j] = func(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
      }
    }
  }

  // 查询区间 [l, r] 的值,0 <= l <= r < n (4)
  int query(int l, int r) {
    int j = log2[r - l + 1];  // 计算不超过区间长度的最大2的幂次 (5)
    return func(st[l][j], st[r - (1 << j) + 1][j]);
  }
};
  1. 需要的列数 mn 决定,最多需要 \lfloor log_{2}(n) \rfloor + 1 列 : 2^m = 2^{\lfloor log_{2}(n) \rfloor + 1} > n
  2. 整个区间 [i, i + 2^{j} - 1] 可以分解为两个子区间:
    [i, i + 2^{j - 1})[i + 2^{j - 1}, i + 2^{j})
    这两个子区间的长度均为 2^{j - 1}
  3. 另一种常见处理方法,根据前进 2^{j-1} 步再前进 2^{j-1} 步来合并, 见LCA
    for (int i = 0; i < n; i++) { st[i][j] = st[st[i][j - 1]][j - 1]; }
  4. 区间 [l, r] 可以分解为两个子区间:
    [l, l + 2^{j} - 1][r - 2^{j} + 1, r]
    这两个子区间覆盖了整个区间 [l, r] (可能重叠,但是幂等性不会影响结果),并且长度均为 2^{j}
  5. 也可以用 j = 31 - __builtin_clz(r - l + 1);
Balanced Lineup G

给定一个长度为 n 的整数序列 a_1, a_2, \ldots, a_n,以及 m 个查询,每个查询包含两个整数 lr,要求你在子序列 a_l, a_{l+1}, \ldots, a_r 中找到最大值与最小值的差值。

C++
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

struct ST {
  int n, m;
  vector<vector<int>> st_max;
  vector<vector<int>> st_min;

  explicit ST(const vector<int> &nums)
      : n(nums.size()),
        m(32 - __builtin_clz(n + 1)),
        st_max(n + 1, vector<int>(m)),
        st_min(n + 1, vector<int>(m)) {
    // 预处理 ST 表, O(nlogn)
    for (int i = 1; i <= n; i++) {
      st_max[i][0] = nums[i - 1];
      st_min[i][0] = nums[i - 1];
    }
    for (int j = 1; j < m; j++) {
      for (int i = 1; i + (1 << j) - 1 <= n; i++) {  // 确保不会越界
        st_max[i][j] = max(st_max[i][j - 1], st_max[i + (1 << (j - 1))][j - 1]);
        st_min[i][j] = min(st_min[i][j - 1], st_min[i + (1 << (j - 1))][j - 1]);
      }
    }
  }

  int query(int l, int r) {
    int j = 32 - __builtin_clz(r - l + 1) - 1;  // 计算不超过区间长度的最大2的幂次
    return max(st_max[l][j], st_max[r - (1 << j) + 1][j])
         - min(st_min[l][j], st_min[r - (1 << j) + 1][j]);
  }
};

int main() {
  int n, q;
  cin >> n >> q;
  vector<int> nums(n);
  for (int i = 0; i < n; i++) { cin >> nums[i]; }
  ST st(nums);
  while ((q--) != 0) {
    int l, r;
    cin >> l >> r;
    cout << st.query(l, r) << '\n';
  }
  return 0;
}

评论