【算法笔记】树状数组

发布于:2025-07-22 ⋅ 阅读:(17) ⋅ 点赞:(0)

前置知识:lowbit运算

想理解树状数组的原理,首先一定要理解 l o w b i t lowbit lowbit运算。

模板

int lowbit(int x){ // lowbit运算,返回二进制的最低一位1
    reuturn x & -x;
}

什么叫二进制的最低一位1?

准确的说, l o w b i t lowbit lowbit返回的其实是一个数的二进制表示的最低一位1,以及它后面所有的0所构成的二进制数。

举一个例子:

比如 22 22 22这个数,它的二进制表示是: ( 10110 ) 2 (10110)_2 (10110)2,最低一位 1 1 1也就是从右往左第一位 1 1 1,所以 l o w b i t lowbit lowbit返回的值是 ( 10 ) 2 (10)_2 (10)2,转换成十进制也就是 2 2 2,即 l o w b i t ( 22 ) = 2 lowbit(22) = 2 lowbit(22)=2

再比如 2232 2232 2232这个数,它的二进制表示是 ( 100010111000 ) 2 (100010111000)_2 (100010111000)2 l o w b i t lowbit lowbit返回的也就是 ( 1000 ) 2 (1000)_2 (1000)2,转换成十进制也就是 8 8 8,即 l o w b i t ( 2232 ) = 8 lowbit(2232) = 8 lowbit(2232)=8

为什么模板这么写

首先要知道,计算机存储数据,都是以二进制补码的形式来存的,包含一位符号位和若干数值位,符号位为 0 0 0表示正数,符号位为 1 1 1表示负数,对于正数的补码,就是在正数的二进制表示前面加一个 0 0 0,而对于负数的补码,是将负数的绝对值的二进制表示整体取反再加一后,前面再加一个 1 1 1

举个例子, x = 2232 x = 2232 x=2232

( x ) 2 = (x)_2 = (x)2= ( 100010111000 ) 2 (100010111000)_2 (100010111000)2

( x ˉ ) 2 (\bar x)_2 (xˉ)2 = ( 011101000111 ) 2 (011101000111)_2 (011101000111)2

( x ˉ + 1 ) 2 (\bar x + 1)_2 (xˉ+1)2 = ( 011101001000 ) 2 (011101001000)_2 (011101001000)2

所以 x x x的补码等于: 0 0 0 100010111000 100010111000 100010111000

− x -x x的补码等于: 1 1 1 011101001000 011101001000 011101001000

x & − x x \& -x x&x,是不是恰好等于 1000 1000 1000

这也就是 l o w b i t lowbit lowbit的原理。

通过给 ∣ x ∣ |x| x的二进制位取反,从而保证最低一位 1 1 1前面的部分按位与后都为 0 0 0

取反后最低一位 1 1 1的位置变为了 0 0 0,因为他是最低一位 1 1 1,他后面几位都是 0 0 0,取反后都变成 1 1 1,加一后,恰好把最低一位 1 1 1变回了 1 1 1

xxxxx10000 -取反-> xxxxx01111 -加一-> xxxxx10000 -& xxxxx10000 -> 0000010000 = 10000

树状数组有什么用?

快速( O ( l o g n ) O(logn) O(logn))进行单点修改、区间求和的操作。简单来讲,就是快速求前缀和。

为什么要用树状数组

拿一道题举例子:

n ( 1 ≤ n ≤ 2 × 1 0 5 ) n(1 \le n \le 2 \times 10^5) n(1n2×105)个数, q ( 1 ≤ q ≤ 2 × 1 0 5 ) q(1 \le q \le 2 \times 10^5) q(1q2×105)组查询,每组查询有如下两种:

  • 1 1 1 x x x y y y : 将数组的第 x x x个数的值加上 y y y
  • 2 2 2 x x x y y y:求数组中第 x x x个数到第 y y y个数的和。

对于每个操作 2 2 2,输出一行一个整数。

看到求区间和,第一时间能想到前缀和,但因为有操作1,这题的每个数的值随时都有可能变,也就是每次操作2的时候都要修改一遍前缀和,时间复杂度也就成了 O ( q × n ) O(q \times n) O(q×n) 2 e 5 2e5 2e5的数据范围肯定是过不去的。

对这题而言,

如果用普通数组,修改操作是 O ( 1 ) O(1) O(1)的( a [ x ] = y a[x] = y a[x]=y),但求和操作是 O ( n ) O(n) O(n)的( f o r ( i n t i = x ; i < = y ; i + + ) s u m + = a [ i ] for(int i = x; i <= y; i++) sum += a[i] for(inti=x;i<=y;i++)sum+=a[i]

如果用前缀和数组,求和是 O ( 1 ) O(1) O(1)的( p r e [ y ] − p r e [ x − 1 ] pre[y] - pre[x - 1] pre[y]pre[x1]),但修改操作是 O ( n ) O(n) O(n)的( f o r ( i n t i = x ; i < = n ; i + + ) p r e [ i ] + = y for(int i = x; i <= n; i++) pre[i] += y for(inti=x;i<=n;i++)pre[i]+=y

因为一道题分析时间复杂度指的都是最坏情况下的时间复杂度,所以这两种方法的时间复杂度都是 O ( n ) O(n) O(n)

那有没有一种折中的数据结构,能让两种操作的复杂度都不高不低,来提高整体效率呢?

接下来正式来使介绍树状数组:修改和求和复杂度都是 O ( l o g n ) O(logn) O(logn)的数据结构。

模板

int n;
int tr[N];

int lowbit(int x){ // lowbit运算,返回二进制的最低一位1
    return x & -x;
}

void update(int x, int c){ // 相当于a[x] += c
    for(int i = x; i <= n; i += lowbit(i)){
        tr[i] += c;
    }
}

LL getsum(int x){ // 相当于求数组中a[1 ~ x]的前缀和
    LL res = 0;
    for(int i = x; i; i -= lowbit(i)){
        res += tr[i];
    }
    return res;
}

图示

在这里插入图片描述

模板解释

这块看图会好理解一点。

getsum

g e t s u m ( x ) getsum(x) getsum(x)即求从 a [ 1 ] a[1] a[1]~ a [ x ] a[x] a[x]的前缀和。

常规的前缀和,是一个数一个数的加,所以效率是 O ( n ) O(n) O(n)的,那能不能将每一个从 1 1 1~ n n n的前缀和,都拆成若干个区间,并且这些区间能不重不漏的以某种特定且唯一的规律表示所有的前缀和呢?

可以基于二进制拆分的思想,每个数的二进制表示都是唯一的,每个数都是由若干 2 2 2的整数次幂构成的。

比如 13 = 2 3 + 2 2 + 2 0 13 = 2^3 + 2^2 + 2^0 13=23+22+20 23 = 2 4 + 2 2 + 2 1 + 2 0 23 = 2^4 + 2^2 + 2 ^ 1 + 2 ^ 0 23=24+22+21+20 ,这些二的次幂,对应着数的二进制表示的每一位 1 1 1,也就是几个 l o w b i t lowbit lowbit值。

树状数组的 g e t s u m getsum getsum的过程,便是从 x x x开始,每次先查询 t r [ x ] tr[x] tr[x],再查询 t r [ x − l o w b i t ( x ) ] tr[x - lowbit(x)] tr[xlowbit(x)], 再查询…

t r tr tr数组便是划分的小区间,每个区间都是$tr[x] = sum[x - lowbit(x) + 1 $ ~ $ x]$ ,区间长度为 l o w b i t ( x ) lowbit(x) lowbit(x)

比如对于 g e t s u m ( 23 ) getsum(23) getsum(23)

x = 23 x = 23 x=23,$ lowbit(x) = (1)_2 = 2^0 = 1$ , x = x − l o w b i t ( x ) = 22 x = x - lowbit(x) = 22 x=xlowbit(x)=22 t r [ x ] = s u m [ 23 , 23 ] tr[x] = sum[23, 23] tr[x]=sum[23,23]

x = 22 x = 22 x=22,$ lowbit(x) = (10)_2 = 2^1 = 2$ , x = x − l o w b i t ( x ) = 20 x = x - lowbit(x) = 20 x=xlowbit(x)=20 t r [ x ] = s u m [ 21 , 22 ] tr[x] = sum[21, 22] tr[x]=sum[21,22]

x = 20 x = 20 x=20,$ lowbit(x) = (100)_2 = 2^2 = 4$ , x = x − l o w b i t ( x ) = 16 x = x - lowbit(x) = 16 x=xlowbit(x)=16 t r [ x ] = s u m [ 17 , 20 ] tr[x] = sum[17, 20] tr[x]=sum[17,20]

x = 16 x = 16 x=16,$ lowbit(x) = (10000)_2 = 2^4 = 16$ , x = x − l o w b i t ( x ) = 0 x = x - lowbit(x) = 0 x=xlowbit(x)=0 t r [ x ] = s u m [ 1 , 16 ] tr[x] = sum[1, 16] tr[x]=sum[1,16]

这几个 t r tr tr加到一起,刚好是 s u m [ 1 , 23 ] sum[1, 23] sum[1,23]

因为每个区间长度为 l o w b i t ( x ) lowbit(x) lowbit(x),这几个 l o e w b i t loewbit loewbit加在一起还恰好等于 n n n,所以一定刚好不重不漏的包含所有的数的和。

g e t s u m ( x ) getsum(x) getsum(x)的过程,就是从 t r [ x ] tr[x] tr[x]开始找到所有能拼凑成 x x x的前缀和的区间,而每个区间的前一个区间都是 t r [ x − l o w b i t ( x ) ] tr[x - lowbit(x)] tr[xlowbit(x)],所以不断地往下减 l o w b i t ( x ) lowbit(x) lowbit(x),一直到 0 0 0,即可。

update

如果要给一个小区间的和加上 c c c,对应的包含这个小区间的所有大区间的和也都要加上 c c c

观察图中的数,

t r [ 3 ] tr[3] tr[3]的父节点是 t r [ 4 ] tr[4] tr[4] 4 − 3 = 1 4 - 3 = 1 43=1 l o w b i t ( 3 ) = 1 lowbit(3) = 1 lowbit(3)=1

t r [ 6 ] tr[6] tr[6]的父节点是 t r [ 8 ] tr[8] tr[8] 8 − 6 = 2 8 - 6 = 2 86=2 l o w b i t ( 6 ) = 2 lowbit(6) = 2 lowbit(6)=2

t r [ 12 ] tr[12] tr[12]的父节点是 t r [ 16 ] tr[16] tr[16] 16 − 12 = 4 16 - 12 = 4 1612=4 l o w b i t ( 12 ) = 4 lowbit(12) = 4 lowbit(12)=4

u p d a t e x ( x ) updatex(x) updatex(x)的过程,就是从 t r [ x ] tr[x] tr[x]开始不断地向上找父节点,而每个 t r [ x ] tr[x] tr[x]的父节点,都是 t r [ x + l o w b i t ( x ) ] tr[x + lowbit(x)] tr[x+lowbit(x)],所以不断地往上加 l o w b i t ( x ) lowbit(x) lowbit(x),一直到 > n >n >n,即可。

例题1:树状数组求前缀和

P3374 【模板】树状数组 1

const int N = 5e5 + 10;

int n, m;
int a[N];
int tr[N];

int lowbit(int x){ // lowbit运算,返回二进制的最低一位1
    return x & -x;
}

void update(int x, int c){ // 相当于a[x] += c
    for(int i = x; i <= n; i += lowbit(i)){
        tr[i] += c;
    }
}

LL getsum(int x){ // 相当于求数组中a[1 ~ x]的前缀和
    LL res = 0;
    for(int i = x; i; i -= lowbit(i)){
        res += tr[i];
    }
    return res;
}

int main(){
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i++){
        cin >> a[i];
        update(i, a[i]); // 一开始树状数组都为0, 这句话即给i的位置赋值a[i] 
    }
    while(m--){
        int op, x, y;
        cin >> op >> x >> y;
        if(op == 1){
            update(x, y);
        }
        else cout << getsum(y) - getsum(x - 1) << endl; // 前y个数的和 - 前x-1个数的和 = a[x, y]的和
    }
    return 0;
}

例题2:树状数组求逆序对

用树状数组求逆序对是一个非常经典并常用的 t r i c k trick trick,大体原理如下:

  • a [ 1 ] a[1] a[1] a [ n ] a[n] a[n]遍历,让 a [ i ] a[i] a[i]作为逆序对中的第二个数
  • 对于 a [ i ] a[i] a[i],以 a [ i ] a[i] a[i]作为第二个数的逆序对数即所有满足 j < i & & a [ j ] > a [ i ] j < i \&\& a[j] > a[i] j<i&&a[j]>a[i] a [ j ] a[j] a[j]的数量
  • 简单点说,就是找 a [ i ] a[i] a[i]前面比 a [ i ] a[i] a[i]大的数有几个
  • 可以维护一个出现次数的数组(假想是 c n t cnt cnt c n t [ x ] cnt[x] cnt[x]就表示数字 x x x出现的次数),在遍历的时候,给 a [ i ] a[i] a[i]加一( c n t [ a [ i ] ] + + cnt[a[i]]++ cnt[a[i]]++)
  • 既然是从前往后遍历的,那是不是就能保证遍历到 a [ i ] a[i] a[i]的时候, c n t [ x ] cnt[x] cnt[x]存的就是在 a [ i ] a[i] a[i]前面 x x x出现的次数?
  • 这时候算一下 c n t cnt cnt数组的前缀和, p r e [ a [ i ] ] = c n t [ 1 ] + c n t [ 2 ] + . . . + c n t [ a [ i ] ] pre[a[i]] = cnt[1] + cnt[2] + ... + cnt[a[i]] pre[a[i]]=cnt[1]+cnt[2]+...+cnt[a[i]],那 p r e [ a [ i ] ] pre[a[i]] pre[a[i]]是不是就是:所有在 a [ i ] a[i] a[i]前面,且小于等于 a [ i ] a[i] a[i]的数的个数? p r e [ N ] pre[N] pre[N]是不是就是:所有在 a [ i ] a[i] a[i]前面,且小于等于 N N N的数的个数?
  • p r e [ N ] − p r e [ a [ i ] ] pre[N] - pre[a[i]] pre[N]pre[a[i]]是不是就是:所有在 a [ i ] a[i] a[i]前面,且大于 a [ i ] a[i] a[i]、小于等于 N N N的数的个数?如果 N N N是题目中的最大值, p r e [ N ] − p r e [ a [ i ] ] pre[N] - pre[a[i]] pre[N]pre[a[i]]便是你想要的数,每次累加到答案即可。
  • 按上面的思路,把 c n t cnt cnt数组转化成树状数组。

5910. 求逆序对

const int N = 1e5 + 10;

int n, m;
int a[N];
int tr[N];

int lowbit(int x){ // lowbit运算,返回二进制的最低一位1
    return x & -x;
}

void update(int x, int c){ // 相当于a[x] += c,这里相当于给a[x]出现次数 + c
    for(int i = x; i <= 100000; i += lowbit(i)){
        tr[i] += c;
    }
}

LL getsum(int x){ // 相当于求数组中a[1 ~ x]的前缀和,这里相当于求所有 <= x的数的出现次数
    LL res = 0;
    for(int i = x; i; i -= lowbit(i)){
        res += tr[i];
    }
    return res;
}

int main(){
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n;
    for(int i = 1; i <= n; i++){
        cin >> a[i];
    }
    LL res = 0;
    for(int i = 1; i <= n; i++){
        res += getsum(100000) - getsum(a[i]); // a[i]前面,所有小于等于a[i]的数的出现次数
        update(a[i], 1); // a[i]又多了一个
    }
    cout << res << endl;
    return 0;
}

例题3:树状数组求逆序对 + 离散化

上面那道题的数据范围比较小,但一般的题都是 n ≤ 2 × 1 0 5 , a [ i ] ≤ 1 0 9 n \le 2 \times 10^5,a[i] \le 10^9 n2×105,a[i]109这样的范围,如果你想像上面那样求逆序对,就要开一个 1 0 9 10^9 109的树状数组,肯定是会超内存的,一定要离散化一下,而且不能改变数原本的相对大小(你要求逆序对,你改变他大小了还求什么)

下面讲一下怎么离散化。

离散化

常见的有两种方法,第一步都是先将所有的数都存到一个 v e c t o r vector vector中,然后去重排序。

vector<int> v;
for(int i = 1; i <= n; i++){
    cin >> a[i];
    v.push_back(a[i]);
}
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());

现在,每个 a [ i ] a[i] a[i]已经从小到大的映射到了 v [ 0 ] v[0] v[0]~ v [ v . s i z e ( ) − 1 ] v[v.size() - 1] v[v.size()1]的位置。

接下来,就是要将 v [ i ] v[i] v[i]映射到 i + 1 i + 1 i+1 ,像下面这样:

在这里插入图片描述
这时候如果原数组是 a [ ] = { 1000000 , 1000000000 , − 1 , 0 , 100 , 1 } a[] = \{1000000, 1000000000, -1, 0, 100, 1\} a[]={1000000,1000000000,1,0,100,1}

离散化后就是: a 1 [ ] = 5 , 6 , 1 , 2 , 4 , 3 a^1[] = {5, 6, 1, 2, 4, 3} a1[]=5,6,1,2,4,3

这个时候在新的数组求逆序对,是不是效果是一样的?

离散化下面有两种方法:

u n o r d e r e d _ m a p unordered\_map unordered_map

遍历一遍直接将每个数映射到的数存起来,很好理解,直接看代码

unordered_map<int, int> ma;
for(int i = 0; i < v.size(); i++){
    ma[v[i]] = i + 1;
}

x x x时,就改成用 m a [ x ] ma[x] ma[x]即可。

优点:需要用数 x x x时,直接用 m a [ x ] ma[x] ma[x]就行,复杂度为 O ( 1 ) O(1) O(1)

缺点:需要再遍历一遍数组,需要另开一个哈希表。

用lower_bound()

v v v数组中二分查找 x x x的下标,如果下标是 p o s pos pos,直接用 p o s + 1 pos + 1 pos+1即可。

int get_pos(int x){
    return lower_bound(v.begin(), v.end(), x) - v.begin() + 1;
}

x x x时,就改成用 g e t p o s ( x ) get_pos(x) getpos(x)即可。

优点:不用另开数组和 m a p map map

缺点:每次用都要二分一遍,复杂度 O ( l o g n ) O(logn) O(logn)

例题

P1908 逆序对

const int N = 5e5 + 10;

int n, m;
int a[N];
int tr[N];

int lowbit(int x){ // lowbit运算,返回二进制的最低一位1
    return x & -x;
}

void update(int x, int c){ // 相当于a[x] += c
    for(int i = x; i <= m; i += lowbit(i)){
        tr[i] += c;
    }
}

LL getsum(int x){ // 相当于求数组中a[1 ~ x]的前缀和
    LL res = 0;
    for(int i = x; i; i -= lowbit(i)){
        res += tr[i];
    }
    return res;
}

int main(){
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n;
    vector<int> v;
    for(int i = 1; i <= n; i++){
        cin >> a[i];
        v.push_back(a[i]);
    }
    unordered_map<int, int> ma;
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    m = v.size();
    for(int i = 0; i < m; i++){
        ma[v[i]] = i + 1;
    }
    LL res = 0;
    for(int i = 1; i <= n; i++){
        res += getsum(m) - getsum(ma[a[i]]);
        update(ma[a[i]], 1);
    }
    cout << res << endl;
    return 0;
}

举一反三1

https://ac.nowcoder.com/acm/contest/109459/I

const int N = 1e5 + 10;

int n, a[N], tr[N];
LL l[N];

int lowbit(int x){
	return x & -x;
}

LL getsum(int x){
    LL res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

void update(int x, int c){
    for (int i = x; i < N; i += lowbit(i)) tr[i] += c;
}

void solve(){
	cin >> n;
	LL res = 0;
	memset(tr, 0, sizeof tr);
	for(int i = 1; i <= n; i++){
		cin >> a[i];
		update(a[i], 1);
		l[i] = getsum(a[i] - 1);
	}
	memset(tr, 0, sizeof tr);
	for(int i = n; i >= 1; i--){
		update(a[i], 1);
		int y = getsum(a[i] - 1);
		res += l[i] * (n - i - y) + y * (i - l[i] - 1);
	}
	cout << res << endl;
}

int main(){
	ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
	int T;
	cin >> T;
	while(T--){
		solve();
	}
	return 0;
}

举一反三2

P1637 三元上升子序列

const int N = 3e4 + 10;

int n;
int tr1[N], tr2[N];

int lowbit(int x){
    return x & -x;
}

void update(int tr[], int x, int d){
    for(int i = x; i <= n; i += lowbit(i)){
        tr[i] += d;
    }
}

int query(int tr[], int x){
    int res = 0;
    for(int i = x; i; i -= lowbit(i)){
        res += tr[i];
    }
    return res;
}

int main(){
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n;
    vector<int> a(n + 1);
    for(int i = 1; i <= n; i++){
        cin >> a[i];
    }
    vector<int> b = a;
    sort(b.begin(), b.end());
    b.erase(unique(b.begin(), b.end()), b.end());
    unordered_map<int, int> ma;
    int m = b.size() - 1;
    for(int i = 1; i <= m; i++){
        ma[b[i]] = i;
    }
    vector<int> pre(n + 1, 0), suf(n + 1, 0);
    for(int i = 1; i <= n; i++){
        update(tr1, ma[a[i]], 1);
        pre[i] = query(tr1, ma[a[i]] - 1);
    }
    for(int i = n; i >= 1; i--){
        update(tr2, ma[a[i]], 1);
        suf[i] = query(tr2, m) - query(tr2, ma[a[i]]);
    }
    LL res = 0;
    for(int i = 2; i < n; i++){
        res += (LL)pre[i] * suf[i];
    }
    cout << res << endl;
    return 0;
}

网站公告

今日签到

点亮在社区的每一天
去签到