Segment Tree

Before:Long time no see!最近强度激增,B+树de了一整天,还有MLE和TLE,感觉是某个地方写爆栈了😢数理逻辑作业还没写完,下周还要复习数分。

upd:昨天晚上过了,打算51把管理系统速通了,挑点bonus写一下(其实想全写但显然没空),弥补一下bookstore啥bonus都没写的遗憾。希望周五数分顺利!

Algorithm Of Data Structure —— Segment Tree

Definition

  • 二叉树
  • 每一个叶子节点维护原序列的信息
  • 每个中间结点维护一段区间信息
  • 通过子节点的高效信息合并得到中间节点的区间信息

An example of Segment Tree:
Simple Segment Tree

Basic Rules

  • Apparently,线段树是一棵完全二叉树,故树高 h = log2nlog_2 n
  • 结点总数约为 2n,故O(n) = 2n
  • 每个结点代表一个区间[l,r],并维护该区间的信息,如区间内的数字和、最大(小)值等。该区间信息由两个分别代表[l, mid],[mid+1, r]的子结点合并而来。

Storage

1
2
3
4
5
6
7
class SegmentTree{
struct TreeNode{
int l,r; //左右端点
int lson,rson; //左右儿子编号
int data; //关键数据
}nodes[maxN << 2];
}

问题来了,存树的数组应该开多大?
通过一些数学证明,我们可以得到,数组应该开成 4n

Build Tree

Build Segment Tree

Example:求区间最大值

1
2
3
4
void push_up(int node){
auto &cur = nodes[node];
cur.data = max(nodes[cur.lson],nodes[cur.rson]);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void build(int node,int l,int r){
auto &cur = nodes[node];
cur.l = l;
cur.r = r;
if(l == r){
cur.data = data[l];
return;
}
auto mid = (l + r) / 2;
cur.lson = 2 * node;
cur.rson = 2 * node + 1;
build(cur.lson,l,mid);
build(cur.rson,mid + 1,r);
push_up(node);
}

Query

A Good Picture!
Query

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int query(int node,int l,int r){
auto &cur = nodes[node];
if(l <= cur.l && cur.r <= r){
return cur.data;
}
int res = 0;
int mid = (l + r)/2;
if(l <= mid){
res = max(res,query(cur.lson,l,r));
}
if(r > mid){
res = max(res,query(cur.rson,l,r));
}
}

Single Point Update

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void update(int node,int pos,int val){
auto &cur = nodes[node];
if(cur.lson == cur.rson){
cur.data = val;
return;
}
int mid = (cur.l + cur.r)/2;
if(pos <= mid){
update(cur.lson,pos,val);
}
if(pos > mid){
update(cur.rson,pos,val);
}
push_up(node);
}

Range Update

区间查询的核心思想:区间信息上放
复杂度:O(logN)

区间修改能不能也使用相同的思想?
若将区间 [l, r] 内元素都加上某个值:

  • 若维护最值,直接 cur.data+=val
  • 若维护区间和,则 cur.data+=(r −l+1)∗ val 
    发现:有时我们会对整棵子树做同样的操作

Solution: Lazy Tag!

1
2
3
4
5
6
7
8
class SegmentTree{
struct TreeNode{
int l,r; //左右端点
int lson,rson; //左右儿子编号
int lazy_tag;
int data; //关键数据
}nodes[maxN << 2];
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void modify(int node,int l,int r,int val){
auto &cur = nodes[node];
if(l <= cur.l && cur.r <= r){
cur.data += val;
cur.lazytag += val;
return;
}
push_down(node);
int mid = (cur.l + cur.r)/2;
if(l <= mid){
modify(cur.lson,l,r,val);
}
if(r > mid){
modify(cur.rson,l,r,val);
}
push_up(node);
}
1
2
3
4
5
6
7
8
9
10
void push_down(int node){
auto &cur = nodes[node];
if(cur.lazytag != 0){
nodes[cur.lson].data += cur.lazytag;
nodes[cur.rson].data += cur.lazytag;
nodes[cur.lson].lazytag = cur.lazytag;
nodes[cur.rson].lazytag = cur.lazytag;
cur.lazytag = 0;
}
}

Template Problem

已知一个数列,你需要进行下面三种操作:

  • 将某区间每一个数乘上 x;
  • 将某区间每一个数加上 x;
  • 求出某区间每一个数的和。

(add_tag,mul_tag)------(* k)----->(add_tag * k,mul_tag * k)
(add_tag,mul_tag)------(+ k)----->(add_tag + k,mul_tag)

AC代码捏

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <iostream>
using namespace std;
long long int n,m,mod;
long long int num[100005];
struct SegmentTree {
long long int l,r;
long long int lson,rson;
long long int lazytag;
long long int multiTag = 1;
long long int data;
}nodes[400005];
void push_up(long long int node) {
auto &cur = nodes[node];
cur.data = nodes[cur.lson].data + nodes[cur.rson].data;
cur.data = (cur.data + mod)%mod;
}
void push_down(long long int node) {
auto &cur = nodes[node];
if(cur.multiTag != 1) {
nodes[cur.lson].multiTag *= cur.multiTag;
nodes[cur.lson].multiTag = (nodes[cur.lson].multiTag + mod)%mod;
nodes[cur.rson].multiTag *= cur.multiTag;
nodes[cur.rson].multiTag = (nodes[cur.rson].multiTag + mod)%mod;
nodes[cur.lson].lazytag *= cur.multiTag;
nodes[cur.lson].lazytag = (nodes[cur.lson].lazytag + mod)%mod;
nodes[cur.rson].lazytag *= cur.multiTag;
nodes[cur.rson].lazytag = (nodes[cur.rson].lazytag + mod)%mod;
nodes[cur.lson].data *= cur.multiTag;
nodes[cur.lson].data = (nodes[cur.lson].data + mod)%mod;
nodes[cur.rson].data *= cur.multiTag;
nodes[cur.rson].data = (nodes[cur.rson].data + mod)%mod;
cur.multiTag = 1;
}
if(cur.lazytag != 0) {
nodes[cur.lson].lazytag += cur.lazytag;
nodes[cur.lson].lazytag = (nodes[cur.lson].lazytag + mod)%mod;
nodes[cur.rson].lazytag += cur.lazytag;
nodes[cur.rson].lazytag = (nodes[cur.rson].lazytag + mod)%mod;
nodes[cur.lson].data += (nodes[cur.lson].r - nodes[cur.lson].l + 1)*cur.lazytag;
nodes[cur.lson].data = (nodes[cur.lson].data + mod)%mod;
nodes[cur.rson].data += (nodes[cur.rson].r - nodes[cur.rson].l + 1)*cur.lazytag;
nodes[cur.rson].data = (nodes[cur.rson].data + mod)%mod;
cur.lazytag = 0;
}
}
void buildTree(long long int node,long long int l,long long int r) {
auto &cur = nodes[node];
cur.multiTag = 1;
cur.l = l;
cur.r = r;
if(l == r) {
cur.data = num[l];
return;
}
long long int mid = (l + r)/2;
cur.lson = 2 * node;
cur.rson = 2 * node + 1;
buildTree(cur.lson,l,mid);
buildTree(cur.rson,mid + 1,r);
push_up(node);
}
void update(long long int node,long long int l,long long int r,long long int val) {
auto &cur = nodes[node];
if(cur.l >= l && cur.r <= r) {
cur.data += (cur.r - cur.l + 1)*val;
cur.data = (cur.data + mod)%mod;
cur.lazytag += val;
cur.lazytag = (cur.lazytag + mod)%mod;
return;
}
push_down(node);
long long int mid = (cur.l + cur.r)/2;
if(l <= mid) {
update(cur.lson,l,r,val);
}
if(r > mid) {
update(cur.rson,l,r,val);
}
push_up(node);
}
void update2(long long int node,long long int l,long long int r,long long int val) {
auto &cur = nodes[node];
if(cur.l >= l && cur.r <= r) {
cur.data *= val;
cur.data = (cur.data + mod)%mod;
cur.lazytag *= val;
cur.lazytag = (cur.lazytag + mod)%mod;
cur.multiTag *= val;
cur.multiTag = (cur.multiTag + mod)%mod;
return;
}
push_down(node);
long long int mid = (cur.l + cur.r)/2;
if(l <= mid) {
update2(cur.lson,l,r,val);
}
if(r > mid) {
update2(cur.rson,l,r,val);
}
push_up(node);
}
long long int query(long long int node,long long int l,long long int r) {
auto &cur = nodes[node];
if(l <= cur.l && cur.r <= r) {
return cur.data;
}
push_down(node);
long long int mid = (cur.l + cur.r)/2;
long long int res = 0;
if(l <= mid) {
res += query(cur.lson,l,r);
}
res = (res + mod)%mod;
if(r > mid) {
res += query(cur.rson,l,r);
}
res = (res + mod)%mod;
return res;
}
int main() {
cin >> n >> m >> mod;
for(int i = 1;i <= n;i ++) {
cin >> num[i];
}
buildTree(1,1,n);
long long int x,y,k;
for(int i = 1;i <= m;i ++) {
int op;
cin >> op;
if(op == 2) {
cin >> x >> y >> k;
update(1,x,y,k);
}else if(op == 3) {
cin >> x >> y;
cout << query(1,x,y) << '\n';
}else if(op == 1) {
cin >> x >> y >> k;
update2(1,x,y,k);
}
}
return 0;
}

特别值得注意的是,push_down时应该先更新multiTag,再更新addTag(先乘后加的原则)

Classic Problem

POJ 3667 - Hotel

现有一排房间

  • 询问:是否有连续x个空房间;如果有,就将最靠前的连续x个房间填满
  • 修改:将任意一段房间清空(可能本来就是空的)

每个结点维护子树区间内左端极长、右端极长和最长空房间(l_max,r_max,max)
Tag:full/null/empty


Segment Tree
https://janezair.site/2025/04/20/AlgorithmOfDS3/
Author
Yihan Zhu
Posted on
April 20, 2025
Updated on
April 21, 2025
Licensed under