线段树⚓︎
线段树(\text{Segment Tree})是一种用于高效处理区间查询和区间更新的树形数据结构。它将一个数组划分为若干个区间,并在每个节点中存储这些区间的信息,从而支持快速的查询和更新操作。
基本操作⚓︎
线段树的基本操作包括:
- 单点修改
- 区间加法
- 区间赋值
- 区间求和
其中区间操作通常使用懒标记(\text{Lazy Propagation})来优化性能。
区间加法与区间求和⚓︎
【模板】线段树 1
单点修改、区间加法与区间求和
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree { // (1)!
vector<int64_t> sum; // 区间和 (2)
vector<int64_t> tag_add; // 区间加法懒标记
explicit segment_tree(int64_t n) : sum(n * 4), tag_add(n * 4) {}
void push_up(int64_t i) { sum[i] = sum[2 * i] + sum[2 * i + 1]; } // (3)!
// 构建线段树
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) { // 叶子节点,进行初始化
sum[i] = nums[left];
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
// 单点修改: point_set(x, val, 1, 1, n) 将下标 x 的值修改为 val (4)
void point_set(int64_t index, int64_t val, int64_t i, int64_t left, int64_t right) {
if (left == index && right == index) { // 到叶子,直接修改数组中的值
sum[i] = val;
return;
}
int64_t mid = left + ((right - left) / 2);
if (index <= mid) { // 继续往下找
point_set(index, val, 2 * i, left, mid);
} else {
point_set(index, val, 2 * i + 1, mid + 1, right);
}
push_up(i); // 更新当前节点的值
}
void lazy_add(int64_t i, int64_t val, int64_t count) {
sum[i] += count * val;
tag_add[i] += val;
}
// 向下传递懒标记
void push_down(int64_t i, int64_t left_count, int64_t right_count) {
if (tag_add[i] != 0) { // 将加法标记传递给子节点
lazy_add(2 * i, tag_add[i], left_count);
lazy_add(2 * i + 1, tag_add[i], right_count);
tag_add[i] = 0; // 清空根节点加法标记
}
}
// 区间加法: range_add(x, y, val, 1, 1, n) 将区间 [x,y] 的值加上 val
void range_add(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { // 区间覆盖, 直接更新
lazy_add(i, val, r - l + 1);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_add(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_add(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 区间求和: range_sum(x, y, 1, 1, n) 查询区间 [x,y] 的和
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return sum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 汇总结果
int64_t res = 0;
if (ql <= mid) { res += range_sum(ql, qr, 2 * i, l, mid); }
if (qr > mid) { res += range_sum(ql, qr, 2 * i + 1, mid + 1, r); }
return res;
}
};
int main() {
int64_t n, m;
cin >> n >> m;
vector<int64_t> nums(n + 1);
for (int64_t i = 1; i <= n; i++) { cin >> nums[i]; }
segment_tree seg(n);
seg.build(1, 1, n, nums);
for (int64_t i = 0; i < m; i++) {
int64_t op;
cin >> op;
if (op == 1) { // 区间加法
int64_t l, r, val;
cin >> l >> r >> val;
seg.range_add(l, r, val, 1, 1, n);
} else { // 区间求和
int64_t l, r;
cin >> l >> r;
cout << seg.range_sum(l, r, 1, 1, n) << '\n';
}
}
return 0;
}
- 闭区间,下标从 1 开始 [1,...,n],根节点编号为 1
- 需要维护的一些性质:如 sum、max、min、gcd、lcm、均值、方差 gcd、lcm一般只支持单点修改、区间查询,不支持区间修改
- 向上更新,如果是其他性质也需要一并修改
- 单点修改直接到叶节点,无需懒标记 单点操作都可以用区间操作来代替,让区间左右端点相等即可
区间赋值⚓︎
懒标记优先级
- 对同一类可组合懒标记而言,区间赋值通常覆盖区间乘法与区间加法,区间乘法会影响已有的区间加法标记。
- 但区间开方通常不是常规懒标记叠加问题,而是依赖值域收敛和剪枝递归单独处理,不能简单与赋值、乘法、加法并列看作统一优先级。
扶苏的问题
区间赋值、区间加法、区间最大值、区间求和
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum; // 区间和
vector<int64_t> maximum; // 区间最大值
vector<int64_t> tag_add; // 区间加法懒标记
vector<int64_t> tag_set; // 区间赋值懒标记
vector<int64_t> tag_set_val; // 区间赋值懒标记值, 只有tag_set为true时该值才有意义
explicit segment_tree(int64_t n)
: sum(n * 4), maximum(n * 4), tag_add(n * 4), tag_set(n * 4), tag_set_val(n * 4) {}
void push_up(int64_t i) {
sum[i] = sum[2 * i] + sum[2 * i + 1];
maximum[i] = max(maximum[2 * i], maximum[2 * i + 1]);
}
// 构建线段树
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) { // 叶子节点,进行初始化
sum[i] = maximum[i] = nums[left];
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
void lazy_add(int64_t i, int64_t val, int64_t count) {
sum[i] += count * val;
maximum[i] += val;
tag_add[i] += val;
}
void lazy_set(int64_t i, int64_t val, int64_t count) {
sum[i] = count * val;
maximum[i] = val;
tag_set[i] = 1;
tag_set_val[i] = val;
tag_add[i] = 0; // 清空加法标记
}
// 向下传递懒标记
void push_down(int64_t i, int64_t left_count, int64_t right_count) {
if (tag_set[i] != 0) { // 处理赋值
lazy_set(2 * i, tag_set_val[i], left_count);
lazy_set(2 * i + 1, tag_set_val[i], right_count);
tag_set[i] = 0; // 清空根节点赋值标记
}
if (tag_add[i] != 0) { // 将加法标记传递给子节点
lazy_add(2 * i, tag_add[i], left_count);
lazy_add(2 * i + 1, tag_add[i], right_count);
tag_add[i] = 0; // 清空根节点加法标记
}
}
// 区间加法: range_add(x, y, val, 1, 1, n) 将区间 [x,y] 的值加上 val
void range_add(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { // 区间覆盖, 直接更新
lazy_add(i, val, r - l + 1);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_add(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_add(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 区间赋值: range_set(x, y, val, 1, 1, n) 将区间 [x,y] 的值修改为 val
void range_set(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { // 区间覆盖, 直接更新
lazy_set(i, val, r - l + 1);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_set(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_set(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 区间求和: range_sum(x, y, 1, 1, n) 查询区间 [x,y] 的和
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return sum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 汇总结果
int64_t res = 0;
if (ql <= mid) { res += range_sum(ql, qr, 2 * i, l, mid); }
if (qr > mid) { res += range_sum(ql, qr, 2 * i + 1, mid + 1, r); }
return res;
}
// 区间求最大值: range_max(x, y, 1, 1, n) 查询区间 [x,y] 的最大值
int64_t range_max(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return maximum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 汇总结果
int64_t res = INT64_MIN;
if (ql <= mid) { res = max(res, range_max(ql, qr, 2 * i, l, mid)); }
if (qr > mid) { res = max(res, range_max(ql, qr, 2 * i + 1, mid + 1, r)); }
return res;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int64_t n, q;
cin >> n >> q;
vector<int64_t> nums(n + 1);
for (int64_t i = 1; i <= n; i++) { cin >> nums[i]; }
segment_tree seg(n);
seg.build(1, 1, n, nums);
for (int64_t i = 0; i < q; i++) {
int64_t op;
cin >> op;
if (op == 1) { // 区间赋值
int64_t l, r, x;
cin >> l >> r >> x;
seg.range_set(l, r, x, 1, 1, n);
} else if (op == 2) { // 区间加法
int64_t l, r, x;
cin >> l >> r >> x;
seg.range_add(l, r, x, 1, 1, n);
} else if (op == 3) { // 区间最大值
int64_t l, r;
cin >> l >> r;
cout << seg.range_max(l, r, 1, 1, n) << '\n';
}
}
return 0;
}
更复杂的区间操作⚓︎
乘法⚓︎
区间乘法
区间乘法优先级高于区间加法,并且每次乘法操作会影响之前的加法标记,需要将加法标记也乘上相应的值
【模板】线段树 2(区间乘法)
给定一个长度为 n 的数组,支持以下三种操作:
- 对区间 [l,r] 乘以一个整数 val。
- 对区间 [l,r] 加上一个整数 val。
- 查询区间 [l,r] 的和。
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum; // 区间和
vector<int64_t> tag_add; // 区间加法懒标记
vector<int64_t> tag_mul; // 区间乘法懒标记
int64_t m; // 取模
explicit segment_tree(int64_t n, int64_t m)
: sum(n * 4), tag_add(n * 4), tag_mul(n * 4, 1), m(m) {}
void push_up(int64_t i) { sum[i] = (sum[2 * i] + sum[2 * i + 1]) % m; }
// 构建线段树
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) { // 叶子节点,进行初始化
sum[i] = nums[left] % m;
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
void lazy_add(int64_t i, int64_t val, int64_t count) {
sum[i] = (sum[i] + count * val) % m;
tag_add[i] = (tag_add[i] + val) % m;
}
void lazy_mul(int64_t i, int64_t val, int64_t count) {
sum[i] = (sum[i] * val) % m;
tag_mul[i] = (tag_mul[i] * val) % m;
tag_add[i] = (tag_add[i] * val) % m; // 乘法会影响加法标记
}
// 向下传递懒标记
void push_down(int64_t i, int64_t left_count, int64_t right_count) {
if (tag_mul[i] != 1) { // 将乘法标记传递给子节点
lazy_mul(2 * i, tag_mul[i], left_count);
lazy_mul(2 * i + 1, tag_mul[i], right_count);
tag_mul[i] = 1; // 清空根节点乘法标记
}
if (tag_add[i] != 0) { // 将加法标记传递给子节点
lazy_add(2 * i, tag_add[i], left_count);
lazy_add(2 * i + 1, tag_add[i], right_count);
tag_add[i] = 0; // 清空根节点加法标记
}
}
// 区间加法: range_add(x, y, val, 1, 1, n) 将区间 [x,y] 的值加上 val
void range_add(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { // 区间覆盖, 直接更新
lazy_add(i, val, r - l + 1);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_add(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_add(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 区间乘法: range_mul(x, y, val, 1, 1, n) 将区间 [x,y] 的值乘以 val
void range_mul(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { // 区间覆盖, 直接更新
lazy_mul(i, val, r - l + 1);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_mul(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_mul(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 区间求和: range_sum(x, y, 1, 1, n) 查询区间 [x,y] 的和
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return sum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 汇总结果
int64_t res = 0;
if (ql <= mid) { res = (res + range_sum(ql, qr, 2 * i, l, mid)) % m; }
if (qr > mid) { res = (res + range_sum(ql, qr, 2 * i + 1, mid + 1, r)) % m; }
return res;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int64_t n, q, m;
cin >> n >> q >> m;
vector<int64_t> nums(n + 1);
for (int64_t i = 1; i <= n; i++) { cin >> nums[i]; }
segment_tree seg(n, m);
seg.build(1, 1, n, nums);
for (int64_t i = 0; i < q; i++) {
int64_t op;
cin >> op;
if (op == 1) { // 区间乘法
int64_t l, r, x;
cin >> l >> r >> x;
seg.range_mul(l, r, x, 1, 1, n);
} else if (op == 2) { // 区间加法
int64_t l, r, x;
cin >> l >> r >> x;
seg.range_add(l, r, x, 1, 1, n);
} else if (op == 3) { // 区间求和
int64_t l, r;
cin >> l >> r;
cout << seg.range_sum(l, r, 1, 1, n) << '\n';
}
}
return 0;
}
开方⚓︎
区间开方
整数范围内不断开方的数列最终会收敛到 1,因此可以用区间最大值来判断是否需要继续递归。如果区间最大值已经是 1,则该区间内所有数均为 1,无需继续开方。否则继续递归处理子节点。
上帝造题的七分钟 2 / 花神游历各国(区间开方)
给定一个长度为 n 的数组,支持以下两种操作:
- 对区间 [l,r] 内的每个数开平方(向下取整)。(op=0)
- 查询区间 [l,r] 的和。(op=1)
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum; // 区间和
vector<int64_t> maximum; // 区间最大值
explicit segment_tree(int64_t n) : sum(n * 4), maximum(n * 4) {}
void push_up(int64_t i) {
sum[i] = sum[2 * i] + sum[2 * i + 1];
maximum[i] = max(maximum[2 * i], maximum[2 * i + 1]);
}
// 构建线段树
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) { // 叶子节点,进行初始化
sum[i] = maximum[i] = nums[left];
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
// 区间开方: range_sqrt(x, y, 1, 1, n) 将区间 [x,y] 的值开方
void range_sqrt(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (l == r) { // 叶子节点,进行开方
auto val = static_cast<int64_t>(sqrt(sum[i]));
sum[i] = val;
maximum[i] = val;
return;
}
int64_t mid = l + ((r - l) / 2);
// 只对有可能变化的子节点进行递归
if (ql <= mid && maximum[2 * i] > 1) { range_sqrt(ql, qr, 2 * i, l, mid); }
if (qr > mid && maximum[2 * i + 1] > 1) { range_sqrt(ql, qr, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 区间求和: range_sum(x, y, 1, 1, n) 查询区间 [x,y] 的和
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return sum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
// 汇总结果
int64_t res = 0;
if (ql <= mid) { res += range_sum(ql, qr, 2 * i, l, mid); }
if (qr > mid) { res += range_sum(ql, qr, 2 * i + 1, mid + 1, r); }
return res;
}
};
int main() {
int n;
cin >> n;
vector<int64_t> nums(n + 1);
for (int i = 1; i <= n; ++i) { cin >> nums[i]; }
segment_tree seg(n);
seg.build(1, 1, n, nums);
int m;
cin >> m;
for (int i = 0; i < m; ++i) {
int op, l, r;
cin >> op >> l >> r;
if (l > r) { swap(l, r); }
if (op == 0) {
seg.range_sqrt(l, r, 1, 1, n);
} else if (op == 1) {
cout << seg.range_sum(l, r, 1, 1, n) << '\n';
}
}
}
GCD⚓︎
区间加法与 \gcd
n 个数的 \gcd 可以通过差分数组来维护。由于 \gcd(x, y) = \gcd(x, y - x),因此有以下等式成立:
令 d_i = a_i - a_{i-1},则上式可改写为:
因此,可以使用线段树维护差分数组 d 的区间 gcd,以及差分数组的前缀和来计算 a_l 的值(a_l = \sum_{i=1}^{l} d_i,也可以通过树状数组维护)。
更新时,区间加法操作会影响差分数组的两个位置:d_l 增加 val,d_{r+1} 减少 val(如果 r + 1 \leq n)。
Interval GCD(区间\gcd)
给定一个长度为 n 的数组,支持以下两种操作:
- 对区间 [l,r] 加上一个整数 val。(op=C)
- 查询区间 [l,r] 的 \gcd。(op=Q)
#include <cmath>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum; // 区间和信息
vector<int64_t> gcd; // 区间gcd信息
explicit segment_tree(int64_t n) : sum(n * 4), gcd(n * 4) {}
void push_up(int64_t i) {
sum[i] = sum[2 * i] + sum[2 * i + 1];
gcd[i] = std::gcd(gcd[2 * i], gcd[2 * i + 1]);
}
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) { // 叶子节点,差分数组的值
sum[i] = gcd[i] = nums[left] - nums[left - 1];
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
void point_add(int64_t index, int64_t val, int64_t i, int64_t left, int64_t right) {
if (left == index && right == index) { // 到叶子,直接修改数组中的值
sum[i] += val;
gcd[i] += val;
return;
}
int64_t mid = left + ((right - left) / 2);
if (index <= mid) { // 继续往下找
point_add(index, val, 2 * i, left, mid);
} else {
point_add(index, val, 2 * i + 1, mid + 1, right);
}
push_up(i); // 更新当前节点的值
}
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return sum[i]; }
int64_t mid = l + ((r - l) / 2);
int64_t res = 0;
if (ql <= mid) { res += range_sum(ql, qr, 2 * i, l, mid); }
if (qr > mid) { res += range_sum(ql, qr, 2 * i + 1, mid + 1, r); }
return res;
}
int64_t range_gcd(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return gcd[i]; }
int64_t mid = l + ((r - l) / 2);
int64_t res = 0;
if (ql <= mid) { res = std::gcd(res, range_gcd(ql, qr, 2 * i, l, mid)); }
if (qr > mid) { res = std::gcd(res, range_gcd(ql, qr, 2 * i + 1, mid + 1, r)); }
return res;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int64_t n, m;
cin >> n >> m;
vector<int64_t> nums(n + 1);
for (int64_t i = 1; i <= n; ++i) { cin >> nums[i]; }
segment_tree seg(n);
seg.build(1, 1, n, nums);
for (int64_t i = 0; i < m; ++i) {
char op;
cin >> op;
if (op == 'Q') {
int64_t left, right;
cin >> left >> right;
int64_t a_left = seg.range_sum(1, left, 1, 1, n);
int64_t range_gcd = seg.range_gcd(left + 1, right, 1, 1, n);
cout << std::abs(std::gcd(a_left, range_gcd)) << '\n';
} else if (op == 'C') {
int64_t left, right, val;
cin >> left >> right >> val;
seg.point_add(left, val, 1, 1, n);
if (right + 1 <= n) { seg.point_add(right + 1, -val, 1, 1, n); }
}
}
return 0;
}
等差数列⚓︎
区间加上等差数列
对于区间 [l,r],将其加上等差数列 s, s+d, s+2d, \ldots, s+(r-l)d,可以维护一个差分数组 D,其中 D_i = a_i - a_{i-1}。则操作可以转化为:
- D_l = D_l + s
- [D_{l+1}, D_r] 增加 d(\text{if } l + 1 \leq r)
- D_{r+1} = D_{r+1} - (s + (r - l)d) = D_{r+1} - e(\text{if }r + 1 \leq n)
也可以使用树状数组维护二阶差分数组,见高阶差分数组中等差数列的性质维护。
无聊的数列
给定一个长度为 n 的数组,支持以下两种操作:
- 对区间 [l,r] 加上等差数列,l 位置加上 k,之后每个位置比前一个位置多加 d。(op=1)
- 查询下标为 index 的位置的值。(op=2)
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum;
vector<int64_t> tag_add;
explicit segment_tree(int64_t n) : sum(n * 4), tag_add(n * 4) {}
void push_up(int64_t i) { sum[i] = sum[2 * i] + sum[2 * i + 1]; }
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) {
sum[i] = nums[left];
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
void lazy_add(int64_t i, int64_t val, int64_t count) {
sum[i] += count * val;
tag_add[i] += val;
}
void push_down(int64_t i, int64_t left_count, int64_t right_count) {
if (tag_add[i] != 0) {
lazy_add(2 * i, tag_add[i], left_count);
lazy_add(2 * i + 1, tag_add[i], right_count);
tag_add[i] = 0;
}
}
void range_add(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) {
lazy_add(i, val, r - l + 1);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_add(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_add(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return sum[i]; }
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
int64_t res = 0;
if (ql <= mid) { res += range_sum(ql, qr, 2 * i, l, mid); }
if (qr > mid) { res += range_sum(ql, qr, 2 * i + 1, mid + 1, r); }
return res;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int64_t n, m;
cin >> n >> m;
vector<int64_t> nums(n + 1);
vector<int64_t> diff(n + 1);
for (int64_t i = 1; i <= n; ++i) {
cin >> nums[i];
diff[i] = nums[i] - nums[i - 1];
}
segment_tree seg_tree(n);
seg_tree.build(1, 1, n, diff);
for (int64_t i = 0; i < m; ++i) {
int64_t op;
cin >> op;
if (op == 1) {
int64_t l, r, k, d;
cin >> l >> r >> k >> d;
// 加上等差数列, l位置加k, [l+1,r]区间加d, r+1位置减去k+(r-l)*d
seg_tree.range_add(l, l, k, 1, 1, n);
if (l + 1 <= r) { seg_tree.range_add(l + 1, r, d, 1, 1, n); }
if (r + 1 <= n) {
int64_t val = k + (r - l) * d;
seg_tree.range_add(r + 1, r + 1, -val, 1, 1, n);
}
} else if (op == 2) {
int64_t index;
cin >> index;
int64_t result = seg_tree.range_sum(1, index, 1, 1, n);
cout << result << '\n';
}
}
return 0;
}
线段树二分⚓︎
线段树二分(\text{Segment Tree Binary Search})是一种结合线段树和二分查找的方法,用于在区间内快速定位满足某些条件的元素位置。其基本思路是利用线段树的区间信息,在查询时通过二分查找来缩小搜索范围,从而提高查询效率。
线段树上二分模版
给定一个长度为 n 的数组,每组询问给出 l,r,k,询问在 [l,r] 中从左往右数第一个大于 k 的数的位置(下标)。如果不存在这样的数,输出 −1。
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> maximum; // 区间最大值
explicit segment_tree(int64_t n) : maximum(n * 4) {}
// 构建线段树
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) { // 叶子节点,进行初始化
maximum[i] = nums[left];
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
maximum[i] = max(maximum[2 * i], maximum[2 * i + 1]);
}
int64_t query(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (maximum[i] <= val) { return -1; }
if (l == r) { return l; }
if (ql <= l && r <= qr) { // 当前区间完全包含在查询区间内, 分别查询左右子树
if (maximum[2 * i] > val) { return query(ql, qr, val, 2 * i, l, (l + r) / 2); }
return query(ql, qr, val, 2 * i + 1, (l + r) / 2 + 1, r);
}
int64_t mid = l + ((r - l) / 2);
int64_t res = -1;
// 左子树有交集
if (ql <= mid) {
res = query(ql, qr, val, 2 * i, l, mid);
if (res != -1) { return res; } // 左子树找到答案,直接返回
}
// 右子树有交集
if (qr > mid) { res = query(ql, qr, val, 2 * i + 1, mid + 1, r); }
return res;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int64_t n, m;
cin >> n >> m;
vector<int64_t> nums(n + 1);
for (int64_t i = 1; i <= n; ++i) { cin >> nums[i]; }
segment_tree seg(n);
seg.build(1, 1, n, nums);
while ((m--) != 0) {
int64_t l, r, x;
cin >> l >> r >> x;
cout << seg.query(l, r, x, 1, 1, n) << '\n';
}
return 0;
}
权值线段树⚓︎
权值线段树(\text{Weight Segment Tree})对权值(值域)进行维护,其基本思路是将每个节点的值视为一个权值,维护一个线段树来表示这些权值的分布情况,常常需要对值域进行离散化操作。
逆序对
求数组中的逆序对数。
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum; // 区间和
explicit segment_tree(int64_t n) : sum(n * 4) {}
void push_up(int64_t i) { sum[i] = sum[2 * i] + sum[2 * i + 1]; }
// 单点修改: point_add(x, val, 1, 1, n) 将下标 x 的值加上 val
void point_add(int64_t index, int64_t val, int64_t i, int64_t left, int64_t right) {
if (left == index && right == index) { // 到叶子,直接修改数组中的值
sum[i] += val;
return;
}
int64_t mid = left + ((right - left) / 2);
if (index <= mid) { // 继续往下找
point_add(index, val, 2 * i, left, mid);
} else {
point_add(index, val, 2 * i + 1, mid + 1, right);
}
push_up(i); // 更新当前节点的值
}
// 区间求和: range_sum(x, y, 1, 1, n) 查询区间 [x,y] 的和
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql > r || qr < l) { return 0; } // 无覆盖
if (ql <= l && r <= qr) { return sum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
// 汇总结果
int64_t res = 0;
if (ql <= mid) { res += range_sum(ql, qr, 2 * i, l, mid); }
if (qr > mid) { res += range_sum(ql, qr, 2 * i + 1, mid + 1, r); }
return res;
}
};
int main() {
int n;
cin >> n;
vector<int> nums(n);
for (int i = 0; i < n; ++i) { cin >> nums[i]; }
vector<int> sorted_nums = nums;
sort(sorted_nums.begin(), sorted_nums.end());
auto get_rank = [&](int num) {
return lower_bound(sorted_nums.begin(), sorted_nums.end(), num) - sorted_nums.begin() + 1;
};
int64_t ans = 0;
segment_tree seg(n + 1);
for (int64_t i = n - 1; i >= 0; --i) {
int rank = get_rank(nums[i]);
ans += seg.range_sum(1, rank - 1, 1, 1, n);
seg.point_add(rank, 1, 1, 1, n);
}
cout << ans << '\n';
return 0;
}
非对称信息维护⚓︎
在区间查询时,有时需要维护一些非对称的信息,比如区间的最大子段和、最大前缀和、最大后缀和等。这些信息不能简单地通过左右子区间的信息合并得到,需要额外的维护。
区间最大子段和
求数组的区间最大子段和。
最大子段和
一个区间的最大子段和可能出现在左子区间、右子区间,或者跨越左右子区间的中点。 因此,最大子段和需要维护四个属性:区间和、最大子段和、最大前缀和、最大后缀和。 其中,最大前缀和是指从区间左端点开始的连续子段的最大和,最大后缀和是指从区间右端点开始的连续子段的最大和。
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum; // 区间和
vector<int64_t> max_seq; // 区间最大子序列和
vector<int64_t> max_prefix; // 区间最大前缀和
vector<int64_t> max_suffix; // 区间最大后缀和
vector<int64_t> tag_set; // 区间赋值懒标记
vector<int64_t> tag_set_val; // 区间赋值懒标记值, 只有tag_set为true时该值才有意义
explicit segment_tree(int64_t n)
: sum(n * 4),
max_seq(n * 4),
max_prefix(n * 4),
max_suffix(n * 4),
tag_set(n * 4),
tag_set_val(n * 4) {}
void push_up(int64_t i) {
sum[i] = sum[2 * i] + sum[2 * i + 1];
// 最大子序列和相关, 需要考虑跨越左右子树的情况
max_prefix[i] = max(max_prefix[2 * i], sum[2 * i] + max_prefix[2 * i + 1]);
max_suffix[i] = max(max_suffix[2 * i + 1], sum[2 * i + 1] + max_suffix[2 * i]);
max_seq[i]
= max({max_seq[2 * i], max_seq[2 * i + 1], max_suffix[2 * i] + max_prefix[2 * i + 1]});
}
// 构建线段树
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) { // 叶子节点,进行初始化
// 包含自身初始化为val, 否则初始化为max(0, val)
sum[i] = max_seq[i] = max_prefix[i] = max_suffix[i] = nums[left];
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
void lazy_set(int64_t i, int64_t val, int64_t count) {
sum[i] = count * val;
// 若val为负数, 则最大子序列和, 前缀和, 后缀和均取val, 否则取 val*count
max_seq[i] = max_prefix[i] = max_suffix[i] = max(val, val * count);
tag_set[i] = 1;
tag_set_val[i] = val;
}
// 向下传递懒标记
void push_down(int64_t i, int64_t left_count, int64_t right_count) {
if (tag_set[i] != 0) { // 处理赋值
lazy_set(2 * i, tag_set_val[i], left_count);
lazy_set(2 * i + 1, tag_set_val[i], right_count);
tag_set[i] = 0; // 清空根节点赋值标记
}
}
// 区间赋值: range_set(x, y, val, 1, 1, n) 将区间 [x,y] 的值修改为 val
void range_set(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { // 区间覆盖, 直接更新
lazy_set(i, val, r - l + 1);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_set(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_set(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
struct Node {
int64_t sum, max_seq, max_prefix, max_suffix;
};
// 区间最大子序列和: range_maxseq(x, y, 1, 1, n) 查询区间 [x,y] 的最大子序列和
Node range_maxseq(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { // 区间覆盖,直接返回
return {.sum = sum[i],
.max_seq = max_seq[i],
.max_prefix = max_prefix[i],
.max_suffix = max_suffix[i]};
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 查询区间完全在左子树或右子树
if (qr <= mid) { return range_maxseq(ql, qr, 2 * i, l, mid); }
if (ql > mid) { return range_maxseq(ql, qr, 2 * i + 1, mid + 1, r); }
// 查询区间跨越左右子树, 需要合并结果
Node left_res = range_maxseq(ql, qr, 2 * i, l, mid);
Node right_res = range_maxseq(ql, qr, 2 * i + 1, mid + 1, r);
Node res;
res.sum = left_res.sum + right_res.sum; // 用于更新父节点的前缀和后缀和
res.max_prefix = max(left_res.max_prefix, left_res.sum + right_res.max_prefix);
res.max_suffix = max(right_res.max_suffix, right_res.sum + left_res.max_suffix);
res.max_seq
= max({left_res.max_seq, right_res.max_seq, left_res.max_suffix + right_res.max_prefix});
return res;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int64_t n;
cin >> n;
vector<int64_t> nums(n + 1);
for (int64_t i = 1; i <= n; i++) { cin >> nums[i]; }
segment_tree seg(n);
seg.build(1, 1, n, nums);
int64_t m;
cin >> m;
for (int64_t i = 0; i < m; i++) {
int64_t op;
cin >> op;
if (op == 0) { // 区间赋值
int64_t x, y;
cin >> x >> y;
seg.range_set(x, x, y, 1, 1, n);
} else { // 区间最大子序列和
int64_t l, r;
cin >> l >> r;
cout << seg.range_maxseq(l, r, 1, 1, n).max_seq << '\n';
}
}
return 0;
}
区间最值和历史最值⚓︎
在区间查询时,返回的不仅是当前的最值,还需要返回历史最值,即在所有的更新操作中,该区间曾经出现过的最值。
【模板】线段树 3(区间最值操作、区间历史最值)
区间设置最值、区间历史最值
Hint
- 区间设置最最小值操作只能将区间内的值变小,不能变大,因此如果区间的最大值已经小于等于设置值,则不需要进行任何操作。(设置最大值同理,维护区间最小值即可)
- 区间历史最值需要在每次更新时,比较当前最大值与历史最大值,取较大者。
- 区间加法操作分为对最大值的加法和对非最大值的加法两种情况,需要分别维护两个懒标记。
- 区间加法操作会影响历史最值,因此需要分别维护最大值和非最大值达到过的最大涨幅。
#include <algorithm>
#include <cstdint>
#include <ios>
#include <iostream>
#include <vector>
using namespace std;
struct segment_tree {
vector<int64_t> sum; // 区间和
vector<int64_t> maximum; // 区间最大值
vector<int64_t> max_count; // 区间内最大值的个数
vector<int64_t> second_max; // 区间内严格第二大值
vector<int64_t> history_max; // 区间内历史最大值
vector<int64_t> max_add; // 最大值的加法懒惰标记
vector<int64_t> other_add; // 非最大值的加法
vector<int64_t> max_up; // 最大值达到过的最大涨幅
vector<int64_t> other_up; // 非最大值达到过的最大涨幅
explicit segment_tree(int64_t n)
: sum(n * 4),
maximum(n * 4),
max_count(n * 4),
second_max(n * 4),
history_max(n * 4),
max_add(n * 4),
other_add(n * 4),
max_up(n * 4),
other_up(n * 4) {}
void push_up(int64_t i) {
sum[i] = sum[2 * i] + sum[2 * i + 1];
maximum[i] = max(maximum[2 * i], maximum[2 * i + 1]);
history_max[i] = max(history_max[2 * i], history_max[2 * i + 1]);
if (maximum[2 * i] > maximum[2 * i + 1]) { // 左子树最大值更大
max_count[i] = max_count[2 * i];
second_max[i] = max(second_max[2 * i], maximum[2 * i + 1]);
} else if (maximum[2 * i] < maximum[2 * i + 1]) { // 右子树最大值更大
max_count[i] = max_count[2 * i + 1];
second_max[i] = max(second_max[2 * i + 1], maximum[2 * i]);
} else { // 两个子树最大值相等
max_count[i] = max_count[2 * i] + max_count[2 * i + 1];
second_max[i] = max(second_max[2 * i], second_max[2 * i + 1]);
}
}
void build(int64_t i, int64_t left, int64_t right, const vector<int64_t> &nums) {
if (left == right) {
sum[i] = maximum[i] = history_max[i] = nums[left];
max_count[i] = 1;
second_max[i] = INT64_MIN; // 只有一个元素没有第二大值
return;
}
int64_t mid = left + ((right - left) / 2);
build(2 * i, left, mid, nums);
build(2 * i + 1, mid + 1, right, nums);
push_up(i);
}
// max_add_val: 最大值的加法更新值, other_add_val: 非最大值的加法更新值
// max_up: 最大值的涨幅, other_up: 非最大值的涨幅
void lazy_add(int64_t i, int64_t count, int64_t max_add_val, int64_t other_add_val,
int64_t max_up_val, int64_t other_up_val) {
// 首先维护区间历史最大值
history_max[i] = max(history_max[i], maximum[i] + max_up_val);
max_up[i] = max(max_up[i], max_add[i] + max_up_val);
other_up[i] = max(other_up[i], other_add[i] + other_up_val);
// 维护节点值
sum[i] += max_add_val * max_count[i] + other_add_val * (count - max_count[i]);
maximum[i] += max_add_val;
second_max[i] += second_max[i] == INT64_MIN ? 0 : other_add_val;
// 维护懒惰标记
max_add[i] += max_add_val;
other_add[i] += other_add_val;
}
void push_down(int64_t i, int64_t left_count, int64_t right_count) {
int64_t max_val = max(maximum[2 * i], maximum[2 * i + 1]);
if (maximum[2 * i] == max_val) { // 左子树最大值等于当前节点最大值
lazy_add(2 * i, left_count, max_add[i], other_add[i], max_up[i], other_up[i]);
} else { // 左子树最大值小于当前节点最大值
lazy_add(2 * i, left_count, other_add[i], other_add[i], other_up[i], other_up[i]);
}
if (maximum[2 * i + 1] == max_val) { // 右子树最大值等于当前节点最大值
lazy_add(2 * i + 1, right_count, max_add[i], other_add[i], max_up[i], other_up[i]);
} else { // 右子树最大值小于当前节点最大值
lazy_add(2 * i + 1, right_count, other_add[i], other_add[i], other_up[i], other_up[i]);
}
// 清空根节点的加法标记
max_add[i] = other_add[i] = 0;
max_up[i] = other_up[i] = 0;
}
void range_add(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) {
lazy_add(i, r - l + 1, val, val, val, val);
return;
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_add(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_add(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 区间取min操作, 如果当前节点值大于等于val则更新为val, 否则不更新
void range_set_min(int64_t ql, int64_t qr, int64_t val, int64_t i, int64_t l, int64_t r) {
if (val >= maximum[i]) { return; } // 当前节点值已经小于等于val, 不需要更新
if (ql <= l && r <= qr) {
if (val > second_max[i]) { // 只会影响最大值, 不会影响第二大值, 直接更新返回
lazy_add(i, r - l + 1, val - maximum[i], 0, val - maximum[i], 0);
return;
}
// 需要影响非最大值, 继续向下传递
}
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
if (ql <= mid) { range_set_min(ql, qr, val, 2 * i, l, mid); }
if (qr > mid) { range_set_min(ql, qr, val, 2 * i + 1, mid + 1, r); }
push_up(i);
}
// 查询区间和: range_sum(x, y, 1, 1, n) 表示查询区间 [x,y] 的和
int64_t range_sum(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return sum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 汇总结果
int64_t res = 0;
if (ql <= mid) { res += range_sum(ql, qr, 2 * i, l, mid); }
if (qr > mid) { res += range_sum(ql, qr, 2 * i + 1, mid + 1, r); }
return res;
}
// 查询区间最大值: range_max(x, y, 1, 1, n) 查询区间 [x,y] 的最大值
int64_t range_max(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return maximum[i]; } // 区间覆盖,直接返回
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 汇总结果
int64_t res = INT64_MIN;
if (ql <= mid) { res = max(res, range_max(ql, qr, 2 * i, l, mid)); }
if (qr > mid) { res = max(res, range_max(ql, qr, 2 * i + 1, mid + 1, r)); }
return res;
}
// 查询区间历史最大值: range_history_max(x, y, 1, 1, n) 查询区间 [x,y]
// 的历史最大值
int64_t range_history_max(int64_t ql, int64_t qr, int64_t i, int64_t l, int64_t r) {
if (ql <= l && r <= qr) { return history_max[i]; }
int64_t mid = l + ((r - l) / 2);
push_down(i, mid - l + 1, r - mid);
// 分割区间
int64_t res = INT64_MIN;
if (ql <= mid) { res = max(res, range_history_max(ql, qr, 2 * i, l, mid)); }
if (qr > mid) { res = max(res, range_history_max(ql, qr, 2 * i + 1, mid + 1, r)); }
return res;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n, m;
cin >> n >> m;
vector<int64_t> nums(n + 1);
for (int i = 1; i <= n; i++) { cin >> nums[i]; }
segment_tree seg(n);
seg.build(1, 1, n, nums);
for (int i = 0; i < m; i++) {
int op;
cin >> op;
if (op == 1) {
int l, r, k;
cin >> l >> r >> k;
seg.range_add(l, r, k, 1, 1, n);
} else if (op == 2) {
int l, r, v;
cin >> l >> r >> v;
seg.range_set_min(l, r, v, 1, 1, n);
} else if (op == 3) {
int l, r;
cin >> l >> r;
cout << seg.range_sum(l, r, 1, 1, n) << '\n';
} else if (op == 4) {
int l, r;
cin >> l >> r;
cout << seg.range_max(l, r, 1, 1, n) << '\n';
} else if (op == 5) {
int l, r;
cin >> l >> r;
cout << seg.range_history_max(l, r, 1, 1, n) << '\n';
}
}
return 0;
}
动态开点线段树⚓︎
动态开点线段树(\text{Dynamic Segment Tree})适用于处理大范围的区间查询和更新操作,尤其是在值域较大但实际使用的点较少的情况下。动态开点线段树通过按需创建节点来节省空间。