树状数组 学习笔记

模板及讲解

维护区间$[1, n]$的数据结构,可以用$[1, b]-[1, a]$来求$[a,b]$
可以$add(a, x), add(b+1, -x)$实现区间修改单点查询(差分序列求前缀和优化到$logn$)

二维树状数组(容斥原理)

区间修改区间查询
设原数组为$a_i$,原数组差分序列为$d_i$,$x$为查询区间$[1,x]$,则
$$a_x=\sum_{i=1}^x d_i$$

$$\sum_{i=1}^x a_i= \sum_{i=1}^x \sum_{j=1}^i d_j =\sum_{i=1}^x(x-i+1)d_i$$
那么
$$\sum_{i=1}^x a_i=(x+1)\sum_{i=1}^x d_i-\sum_{i=1}^x d_i \times i$$
这样我们维护两个树状数组,一个维护$d_i$,一个维护$d_i \times i$,每次查询修改对两个树状数组进行操作即可。(常数比线段树小)

单点修改区间最大值
Bzoj 1012
$c[i]$维护$[i-\operatorname{lowbit}(i)+1,i]$的最大值,$a[i]$为原数组

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<map>
#include<string>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db long double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    LL m, D, x, cnt, lastans = 0;
    LL a[200000 + 5], c[200000 + 5]; // c[i]维护[i-lowbit(i)+1,i]的最大值,a[i]为原数组
    char s[5];

    LL lowbit(LL x) {return x & (-x);}
    LL query(LL l, LL r) {
        LL ret = a[r];
        while (l <= r) {
            ret = max(ret, a[r]);
            for (--r; r - l >= lowbit(r); r -= lowbit(r)) 
                ret = max(ret, c[r]);
        }
        return ret;
    }

    void clean() {
        cnt = 0, ms(a, 0), ms(c, 0);
    }
    int solve() {

        clean();
        cin >> m >> D;
        for (LL i = 1; i <= m; ++i) {
            scanf("%s%lld", s, &x);
            if (s[0] == 'A') {
                a[++cnt] = (x + lastans) % D;
                c[cnt] = max(a[cnt], query(cnt - lowbit(cnt) + 1, cnt));
            } else {
                printf("%lld\n", lastans = query(cnt - x + 1, cnt));
            }
        }

        return 0;
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}

树状数组求$k$大

LL C[MAXN];
void update(LL x) {while (x <= m) C[x]++, x += x & (-x);}
LL query(LL x) {
    LL p = 0;
    for (LL i = 20; ~i; --i) {
        if (p + (1 << i) <= m && C[p + (1 << i)] <= x) x -= C[p + (1 << i)], p += (1 << i);
    }
    return p;
}

常见题型

1、单点修改区间查询/区间修改单点查询
解:直接套用模板即可,见下面的相关代码
2、开多棵树状数组解决问题
Q:一个区间(矩阵)有多种颜色,每个点有一个权值,每次修改(查询)指定颜色上的权值
解:对于每一种颜色(类型)都开一棵树状数组。
例题:BZOJ 1452
3、二维树状数组
Q:在矩阵上查询某个子矩阵的值。
解:建二维树状数组,见模板讲解。
例题:BZOJ 1452
4、求逆序对
Q:求逆序对。
解:类似权值线段树,每次使$[1, i] + 1$, 然后$i$对答案的贡献为$[1, i]$的值(即小于i的元素个数)
例题:NOIP2013 D1T2
5、区间修改区间查询
解:推公式,开两个树状数组求值,见模板讲解。
例题:BZOJ 2017-07-20集训-t2
5、树状数组离线排序右端点
解:离线,删点加点
例题:spoj DQUERY

相关代码

1 点修改,求x~y区间值

#include<cstdio>      
#include<algorithm>      
#include<cstring>      
#include<queue>      
#define ms(i,j) memset(i,j, sizeof i);      
using namespace std;    
int a[500005];   
int n,m;  
int abss(int x){return x>=0 ? x : -x;}  
int lowbit(int x)  
{  
    return x&(-x);  
}  
int getsum(int x)//求1~x的和   
{  
    int ret = 0;  
    for (int i=x;i>0;i-=lowbit(i))  
    {  
        ret += a[i];  
    }  
    return ret;  
}  
void addsum(int x, int y)//1~x加y   
{  
    for (int i=x;i<=n;i+=lowbit(i))  
    {  
        a[i] += y;  
    }  
}  
int main()      
{      
    a[0] = 0;  
    scanf("%d%d", &n, &m);  
    for (int i=1;i<=n;i++)  
    {  
        int x;  
        scanf("%d", &x);  
        addsum(i,x);  
    }  
    for (int i=1;i<=m;i++)  
    {  
        int ty;  
        scanf("%d", &ty);  
        if (ty==1)  
        {  
            int x,k;  
            scanf("%d%d", &x, &k);  
            addsum(x,k);  
        } else  
        {  
            int x,y;  
            scanf("%d%d", &x, &y);  
            printf("%d\n", abss(getsum(y)-getsum(x-1)));  
        }  
    }  
    return 0;      
}

2 区间修改,求某一点值(差分序列)

#include<cstdio>  
#include<cstring>  
#include<algorithm>  
#include<vector>  
using namespace std;  
#define ms(i,j) memset(i,j,sizeof i);  
int n,m;  
const int maxn = 500005;  
int a[maxn];//a记录的是比i-lowbit(i)多的值  
int lowbit(int x)  
{  
    return x&(-x);  
}  
int add(int x, int v)  
{  
    for (int i=x;i<=n;i+=lowbit(i))  
    {  
        a[i] += v;  
    }  
}  
int sub(int x)  
{  
    int ret = 0;  
    for (int i=x;i>0;i-=lowbit(i))  
    {  
        ret += a[i];  
    }  
    return ret;  
}  
int main()  
{  
    scanf("%d%d", &n ,&m);  
    ms(a,0);  
    for (int i=1;i<=n;i++)  
    {  
        int x;  
        scanf("%d", &x);  
        add(i,x);  
        add(i+1,-x);  
    }  
    for (int i=1;i<=m;i++)  
    {  
        int ty;  
        scanf("%d", &ty);  
        if(ty==1)  
        {  
            int x,y,k;  
            scanf("%d%d%d", &x,&y,&k);  
            add(x,k); add(y+1,-k);  
        } else   
        {  
            int x;  
            scanf("%d", &x);  
            printf("%d\n", sub(x));  
        }  
    }  
    system("pause");  
    return 0;  
}
------ 本文结束 ------