《信息学奥赛一本通·提高篇》 数据结构第3节—线段树
【 例 1】区间和
单点修改模板题
#include <iostream>
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
struct Node
{
int l, r;
LL sum;
}tr[N * 4];
int n, m;
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
if(l == r) tr[u] = {l, r};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if(tr[u].l == x && tr[u].r == x) tr[u].sum += v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
LL query(int u, int l, int r)
{
if(l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
else
{
LL sum = 0;
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) sum += query(u << 1, l, r);
if(r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
}
int main()
{
cin >> n >> m;
build(1, 1, n);
int t, a, b;
while(m --)
{
scanf("%d%d%d", &t, &a, &b);
if(t == 0) modify(1, a, b);
else printf("%lld\n", query(1, a, b));
}
return 0;
}
【例 2】A Simple Problem with Integers
区间修改模板题
#include <iostream>
using namespace std;
const int N = 1e6 + 10;
typedef long long LL;
struct Node
{
int l, r;
LL sum, add;
}tr[N * 4];
int n, m, a[N];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
if(l == r) tr[u] = {l, r, a[l], 0};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void change(int u, int v)
{
tr[u].sum += (LL)v * (tr[u].r - tr[u].l + 1);
tr[u].add += v;
}
void pushdown(int u)
{
if(tr[u].add)
{
change(u << 1, tr[u].add);
change(u << 1 | 1, tr[u].add);
tr[u].add = 0;
}
}
void modify(int u, int l, int r, int v)
{
if(l <= tr[u].l && tr[u].r <= r) change(u, v);
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify(u << 1, l, r, v);
if(r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
}
LL query(int u, int l, int r)
{
if(l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
else
{
pushdown(u);
LL sum = 0;
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) sum += query(u << 1, l, r);
if(r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i ++) scanf("%d", &a[i]);
build(1, 1, n);
int op, l, r, x;
while(m --)
{
scanf("%d%d%d", &op, &l, &r);
if(op == 1)
{
scanf("%d", &x);
modify(1, l, r, x);
}
else printf("%lld\n", query(1, l, r));
}
return 0;
}
最大数
又是模板题
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 2e5 + 10;
struct Node
{
int l, r;
int ma;
}tr[N * 4];
int n, m, a[N];
void pushup(int u)
{
tr[u].ma = max(tr[u << 1].ma, tr[u << 1 | 1].ma);
}
void build(int u, int l, int r)
{
if(l == r) tr[u] = {l, r};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if(tr[u].l == x && tr[u].r == x) tr[u].ma = v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
int query(int u, int l, int r)
{
if(l <= tr[u].l && tr[u].r <= r) return tr[u].ma;
else
{
int res = 0;
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) res = max(res, query(u << 1, l, r));
if(r > mid) res = max(res, query(u << 1 | 1, l, r));
return res;
}
}
int main()
{
cin >> n >> m;
build(1, 1, n);//n个操作对应n个节点,直接一次性把n个节点先建出来
char op[2];
int cnt = 0, last = 0, x;//last表示最后一次询问操作的答案
while(n --)
{
scanf("%s%d", op, &x);
if(*op == 'Q')
{
last = query(1, cnt - x + 1, cnt);
printf("%d\n", last);
}
else modify(1, ++ cnt, ((LL)last + x) % m);
//注意这里两个int加起来可能爆int
}
return 0;
}
花神游历各国
思路:
对于区间开根号,无法直接像加减一个数那样进行区间修改,所以依靠单点修改来实现区间修改
对于 1 1 1 和 0 0 0 无论如何开方都为它本身,并且数据的最大值为 1 0 9 10^9 109,所以每个数最多进行 5 5 5 次开方的操作就会变成 1 1 1
根据这一条件每个数值最多只需要修改 5 5 5 次,而区间长度最大为 1 0 5 10^5 105 所以修改操作最多执行 5 ∗ 1 0 5 5∗10^5 5∗105 次。
用线段树维护区间和,单次查询的时间复杂度为 O ( l o g n ) O(logn) O(logn),查询最差时间复杂度为 O ( m ∗ l o g n ) O(m∗logn) O(m∗logn),再维护一个值 a d d add add (你可以认为是懒标记)表示该区间是否需要进行开方的操作(当区间中的数值都为 1 1 1 和 0 0 0 时,该区间就不需要进行修改),这样修改最差的时间复杂度为 O ( n l o g n ) O(nlogn) O(nlogn)
时间复杂度 O ( m l o g n + n l o g n ) O(mlogn+nlogn) O(mlogn+nlogn)
#include <iostream>
#include <cstdio>
#include <cmath>
#define lson u << 1
#define rson u << 1 | 1
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
int n, m;
int a[N];
struct Node
{
int l, r;
LL sum;
int add;
}tr[N << 4];
template<class T>
inline void read(T &res)
{
char ch; bool flag = false;
while ((ch = getchar()) < '0' || ch > '9')
if (ch == '-') flag = true;
res = (ch ^ 48);
while ((ch = getchar()) >= '0' && ch <= '9')
res = (res << 3) + (res << 1) + (ch ^ 48);
if (flag) res = ~res + 1;
}
void pushup(int u)
{
tr[u].sum = tr[lson].sum + tr[rson].sum;
tr[u].add = tr[lson].add && tr[rson].add;
}
void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r;
if (l == r)
{
tr[u].sum = a[l];
tr[u].add = (a[l] <= 1);
return ;
}
int mid = l + r >> 1;
build(lson, l, mid), build(rson, mid + 1, r);
pushup(u);
}
//虚假的区间修改
void modify(int u, int x, int y)
{
if (tr[u].add) return ;//剪枝
if (tr[u].l == tr[u].r)//靠单点修改实现区间修改
{
tr[u].sum = sqrt(tr[u].sum);
tr[u].add = (tr[u].sum <= 1);
return ;
}
int mid = tr[u].l + tr[u].r >> 1;
if (y <= mid) modify(lson, x, y);
else if (x > mid) modify(rson, x, y);
else
{
modify(lson, x, mid);
modify(rson, mid + 1, y);
}
pushup(u);
}
LL query(int u, int x, int y)
{
if (x <= tr[u].l && tr[u].r <= y) return tr[u].sum;
int mid = tr[u].l + tr[u].r >> 1;
LL res = 0;
if (x <= mid) res += query(lson, x, y);
if (y > mid) res += query(rson, x, y);
return res;
}
int main()
{
read(n);
for (int i = 1; i <= n; i ++) read(a[i]);
build(1, 1, n);
read(m);
for (int i = 1; i <= m; i ++)
{
int x, l, r;
read(x), read(l), read(r);
if (x == 1) printf("%lld\n", query(1, l, r));
else modify(1, l, r);
}
return 0;
}
维护序列
话不多说直接上代码
#include <iostream>
#include <cstdio>
#define lson u << 1
#define rson u << 1 | 1
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
int n, m, p;
int a[N];
struct Node
{
int l, r;
LL sum;
LL add, mul;//加法和乘法的懒标记
}tr[N << 2];
template<class T>
inline void read(T &res)
{
char ch; bool flag = false;
while ((ch = getchar()) < '0' || ch > '9')
if (ch == '-') flag = true;
res = (ch ^ 48);
while ((ch = getchar()) >= '0' && ch <= '9')
res = (res << 3) + (res << 1) + (ch ^ 48);
if (flag) res = ~res + 1;
}
void pushup(int u) {tr[u].sum = (tr[lson].sum + tr[rson].sum) % p;}
void change(int u, LL mul, LL add)//别忘了LL
{
tr[u].sum = (tr[u].sum * mul + add * (tr[u].r - tr[u].l + 1)) % p;
tr[u].add = (tr[u].add * mul + add) % p;
tr[u].mul = tr[u].mul * mul % p;
}
void pushdown(int u)
{
change(lson, tr[u].mul, tr[u].add);
change(rson, tr[u].mul, tr[u].add);
tr[u].add = 0, tr[u].mul = 1;//加法懒标记清零,乘法懒标记记为一
}
void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r, tr[u].mul = 1;//乘法懒标记初始为1
if (l == r)
{
tr[u].sum = a[l];
return ;
}
int mid = l + r >> 1;
build(lson, l, mid), build (rson, mid + 1, r);
pushup(u);
}
void modify(int u, int x, int y, int mul, int add)
{
if (x <= tr[u].l && tr[u].r <= y)
{
change(u, mul, add);
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(lson, x, y, mul, add);
if (y > mid) modify(rson, x, y, mul, add);
pushup(u);
}
int query(int u, int x, int y)
{
if (x <= tr[u].l && tr[u].r <= y) return tr[u].sum % p;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL res = 0;
if (x <= mid) res += query(lson, x, y);
if (y > mid) res += query(rson, x, y);
return res % p;
}
int main()
{
read(n), read(p);
for (int i = 1; i <= n; i ++) read(a[i]);
build(1, 1, n);
read(m);
for (int i = 1; i <= m; i ++)
{
int op, l, r, v;
read(op);
if (op == 1)
{
read(l), read(r), read(v);
modify(1, l, r, v, 0);
}
else if (op == 2)
{
read(l), read(r), read(v);
modify(1, l, r, 1, v);//加法相当于乘1
}
else
{
read(l), read(r);
printf("%d\n", query(1, l, r));
}
}
return 0;
}
呃呃,剩下的三道题题网站上不让看,所以就没写了
本文含有隐藏内容,请 开通VIP 后查看