线段树介绍:
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b] (除法是向下取整)。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,实际应用时一般还要开4N的数组以免越界,因此有时需要离散化让空间压缩。
4N的原因:
设倒数第二层是n个,那么第1->(n-1)层总共最多(n-1)个,最后一层2n个,那么就是
n + ( n − 1 ) + 2 n = 4 n − 1 n+(n-1)+2n=4n-1 n+(n−1)+2n=4n−1
线段[1,10]的线段树结构图:
线段树建树:
void build(int i,int l,int r){//递归建树
tree[i].l=l;tree[i].r=r;
if(l==r){//如果这个节点是叶子节点
tree[i].sum=input[l];
return ;
}
int mid=(l+r)>>1;
build(i*2,l,mid);//分别构造左子树和右子树
build(i*2+1,mid+1,r);
tree[i].sum=tree[i*2].sum+tree[i*2+1].sum;//线段树性质 return ;
//一颗二叉树,她的左儿子和右儿子编号分别是她*2和她*2+1
}
区间查询:
inline int search(int i,int l,int r){
if(tree[i].l>=l && tree[i].r<=r)//如果这个区间被完全包括在目标区间里面,直接返回这个区间的值
return tree[i].sum;
if(tree[i].r<l || tree[i].l>r) return 0;//如果这个区间和目标区间毫不相干,返回0
int mid=tree[i].l+tree[i].r>>1;
int s=0;
if(l <= mid) s+=search(i*2,l,r);//如果这个区间的左儿子和目标区间有交集,那么搜索左儿子
if(r > mid) s+=search(i*2+1,l,r);//如果这个区间的右儿子和目标区间有交集,那么搜索右儿子
return s;
}
单点修改:
void add(int i,int dis,int k){
if(tree[i].l==tree[i].r){//如果是叶子节点,那么说明找到了
tree[i].sum+=k;
return ;
}
if(dis<=tree[i*2].r) add(i*2,dis,k);//在哪往哪跑
else add(i*2+1,dis,k);
tree[i].sum=tree[i*2].sum+tree[i*2+1].sum;//返回更新
return ;
}
例一:夹娃娃(牛客IOI周赛17-普及组)
代码:
#include<bits/stdc++.h>
#define LL long long
#define fo(i,a,b) for(int i=a;i<b;i++)
#define js ios::sync_with_stdio(false);cin.tie(0); cout.tie(0)
using namespace std;
const int maxn = 1e5 + 7;
struct Tree {
int l,r;
int sum;
}t[maxn<<2];
void build(int i,int l,int r) {
t[i].l = l;t[i].r = r;
if(l == r) {
scanf("%d",&t[i].sum);
return;
}
int m = (l + r) >> 1;
build(i * 2,l,m);
build(i * 2 + 1,m + 1,r);
t[i].sum = t[i * 2].sum + t[i * 2 + 1].sum;
}
int query(int i,int x,int y) {
if(x <= t[i].l && t[i].r <= y) {return t[i].sum;}
if(t[i].r<x || t[i].l>y) return 0;
int mid=t[i].l+t[i].r>>1;
int res = 0;
if(x <= mid) res += query(i<<1,x,y);
if(y > mid) res += query(i<<1|1,x,y);
return res;
}
int main(){
int n,k;
scanf("%d%d",&n,&k);
build(1,1,n);
for(int i = 1;i <= k;i++) {
int x,y;scanf("%d%d",&x,&y);
printf("%d\n",query(1,x,y));
}
return 0;
}
代码:
#include<bits/stdc++.h>
#define LL long long
#define fo(i,a,b) for(int i=a;i<b;i++)
using namespace std;
const int maxn=5e5;
struct tree{
LL l,r,sum;
}t[maxn<<2];
void build(int i,int l,int r){
t[i].l=l;t[i].r=r;
if(l==r){
scanf("%lld",&t[i].sum);
return ;
}
int m=(l+r)>>1;
build(i*2,l,m); build(i*2+1,m+1,r);
t[i].sum=t[i*2].sum+t[i*2+1].sum;
}
void add(int i,int pos,int n){
if(t[i].l==t[i].r){
t[i].sum+=n;
return ;
}
LL mid=t[i].l+t[i].r>>1;
if(pos<=mid)add(i*2,pos,n);
else add(i*2+1,pos,n);
t[i].sum=t[i*2].sum+t[i*2+1].sum;
return;//注意
}
LL search(int i,int l,int r){
if(t[i].l>=l&&t[i].r<=r)return t[i].sum;
if(t[i].r<l||t[i].l>r)return 0;//剪枝
LL sum=0;
LL mid=t[i].l+t[i].r>>1;
if(l<=mid)sum+=search(2*i,l,r);
if(r>mid)sum+=search(2*i+1,l,r);
return sum;//注意
}
int main(){
int n,m;
scanf("%d%d",&n,&m);
build(1,1,n);
int q,x,k;
while(m--){
scanf("%d%d%d",&q,&x,&k);
if(q==1)add(1,x,k);
else {
LL re=search(1,x,k);
printf("%lld\n",re);
}
}
return 0;
}
例三:最大数
代码:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 200010;
int m, p;
struct Node{
int l, r;
int v; // 区间[l, r]中的最大值
}tr[N * 4];
//pushup:用子节点的信息来计算父节点的信息
void pushup(int u){
//父节点最大值=max(左儿子节点最大值,右儿子节点最大值)
tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}
void build(int u, int l, int r){
tr[u] = {l, r};
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
int query(int u, int l, int r){
if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;// 树中节点,已经被完全包含在[l, r]中了
int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if (l <= mid) v = query(u << 1, l, r);//否则分别考虑再左节点和右节点是否有查询区间
if (r > mid) v = max(v, query(u << 1 | 1, l, r));
return v;
}
void modify(int u, int x, int v){
if (tr[u].l == x && tr[u].r == x) tr[u].v = v;//找到了要添加的位置
else{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
int main(){
int n = 0, last = 0;
scanf("%d%d", &m, &p);
build(1, 1, m);
int x;
char op[2];
while (m -- ){
scanf("%s%d", op, &x);
if (*op == 'Q'){
last = query(1, n - x + 1, n);
printf("%d\n", last);
}
else{
modify(1, n + 1, (last + x) % p);
n ++ ;
}
}
return 0;
}
部分参考: 线段树 从入门到进阶