树状数组

概念

普通数组修改的时间复杂度为 O(1) O(1) ,求和(查询)的时间复杂度为 O(n) O(n)

前缀和数组修改的时间复杂度为 O(n) O(n) ,求和(查询)的时间复杂度为 O(1) O(1)

树状数组修改以及求和(查询)的时间复杂度都为 O(logn) O( \log n)

基本思想

利用二进制。

例如,对于 x=(21)10=(10101)2 x=(21)_{10}=(10101)_2 ,可以将区间 [1,x] [1,x] 拆分为至多 logx \lceil \log x \rceil 个小区间:[1,24],[24+1,24+22],[24+22+1,24+22+20] [1,2^4],[2^4+1,2^4+2^2],[2^4+2^2+1,2^4+2^2+2^0]

树状数组就是局域上述思想的一种数据结构,基本用途是维护区间的前缀和。

定义 lowbit(n) lowbit(n) 为正整数 n n 二进制分解中最小的幂次,对于给定的序列 A A ,建立数组 c c ,使 c[x] c[x] 保存 A A 的子区间 [xlowbit(x)+1,x] [x-lowbit(x)+1,x] 中所有数的和。

这个数组可以视作一个树状结构,满足下述四条性质:

  • c[x] c[x] 保存以它为根的子树中所有叶子结点的和;
  • c[x] c[x] 的子结点个数为 lowbit(x) lowbit(x)
  • 除树根外,每个内部结点 c[x] c[x] 的父节点为 c[x+lowbit(x)] c[x+lowbit(x)]
  • 树的深度为 O(logn) O(\log n)

注意:树状数组的最小下标必须为 1 1

基本算法

  1. lowbit(n) lowbit(n) lowbit(n)=n&(n) lowbit(n)=n\&(-n)
  2. 单点修改
  3. 区间求和

模板

洛谷P3374

,单点修改+查询

# include<iostream>
using namespace std;

const int N=5e5+3;
int n,m,op,x,y;
long long k,c[N];

int lowbit(int x){
	return x&-x;
}

void update(int k,int x){
	for(;x<=n;x+=lowbit(x)) c[x]+=k;
} // 给a[x]加上k 

long long sum(int x){
	long long ans=0;
	for(;x>0;x-=lowbit(x)) ans+=c[x];
	return ans;
} // 查询[1,x]的区间和 

int main(){
	cin>>n>>m;
	for(int i=1;i<=n;++i){
		cin>>k;
		update(k,i);
	}
	for(int i=1;i<=m;++i){
		cin>>op;
		if(op-1){
			cin>>x>>y;
			cout<<sum(y)-sum(x-1)<<endl;
		}
		else{
			cin>>x>>k;
			update(k,x);
		}
	}
	return 0;
}

练习

# include<iostream>
using namespace std;

const int N=15e3+3,X=32e3+3;
int n,x,y,c[X],ans[N];

int lowbit(int x){
	return x&(-x);
}

void update(int x){
	for(;x<=X;x+=lowbit(x)) c[x]++; 
} // 将a[x]加上1 

int sum(int x){
	int ans=0;
	for(;x>0;x-=lowbit(x)) ans+=c[x];
	return ans;
}

int main(){
	cin>>n;
	for(int i=1;i<=n;++i){
		cin>>x>>y;
		ans[sum(x+1)]++;
		update(x+1);
	}
	for(int i=0;i<n;++i) cout<<ans[i]<<endl;
	return 0;
}
# include<iostream>
using namespace std;

const int N=5e5+3;
char op;
int n,k,m,p,c[N];

int lowbit(int x){
	return x&(-x);
}

void update(int x,int p){
	for(;x<=n;x+=lowbit(x)) c[x]+=p; 
} // 将a[x]加上p 

int sum(int x){
	int ans=0;
	for(;x>0;x-=lowbit(x)) ans+=c[x];
	return ans;
}

int main(){
	cin>>n>>k;
	while(k--){
		cin>>op;
		if(op=='A'){
			cin>>m;
			cout<<sum(m)<<endl;
		}
		else{
			cin>>m>>p;
			if(op-'B') update(m,-p);
			else update(m,p);
		}
	}
	return 0;
}