主席树 学习笔记

模板及讲解

什么是主席树

主席树也称为函数式线段树、可持久化线段树,主要是利用动态开点每个点建线段树(维护$[1,i]$的区间)、线段树可加可减性($[x,y]=[1,y]-[1,x-1]$)来解决如区间内的某些问题。主席树实际上是树套树,最普通的主席树问题就是前缀和套线段树。

主席树的实现

例题

caioj 1441

给$n$($1 \leq n \leq 100000$)个数字,
$a[1],a[2],……,a[n](0 \leq a[i]<=1000000000),m(1 \leq m \leq 100000)$次询问$l$到$r$之间的第$k$小的值。

由题不需要修改操作,就是最普通的主席树问题。

从全局入手

对于整个区间的$k$小,我们可以开权值线段树记录每个值的大小,然后查询时仿造平衡树的方法可以找到第$k$大值。

线段树可加可减性

那么对于区间$[x,y]$,我们怎么办?
想到每个点$i$开$[1,i]$的线段树(整个线段树维护区间不变,只是每个数值的范围,不然不能满足加减性), 则$[x,y]=[1,y]-[1,x-1]$

这样可以看出我们要研究线段树是否可加可减,看下面的例子(借用了 caioj 的图片)

两棵线段树显然可加,并且对应位置上的和相加(维护区间和)。

主席树实现

首先要对每个点开$[1,i]$的线段树,先开一条只包含$i$点信息的链,再与前面一棵线段树合并(相加)。合并线段树也很方便,只要加上$i$点的信息,合并$[1,i-1]$( $merge$ 操作,代码中的$mge$)

查询的时候类似平衡树的查询,例如求$k$小,因为权值线段树,所以左边点都小于这个点,右边点都大于这个点,判断一下第$k$小在左边还是右边,就可以找到了。

代码

注意要离散化。

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 100000 + 5;
int n, m, whw, tax[MAXN], ai[MAXN], rt[MAXN];
#define M ((l + r) >> 1)
int sumv[MAXN * 20], lc[MAXN * 20], rc[MAXN * 20], sz;
int getPos(int x) {return lower_bound(tax + 1, tax + 1 + whw, x) - tax;}
int mge(int &x, int y) {//合并 
    if (y == 0) return 0;
    if (x == 0) return x = y, 0;//x及其子树与y一致,直接使用
    sumv[x] += sumv[y];//合并信息
    mge(lc[x], lc[y]), mge(rc[x], rc[y]);
    return 0;
}
void build(int l, int r, int &x, int pos) {//建一条链 
    if (x == 0) x = ++sz, sumv[x] = 0, lc[x] = rc[x] = 0;//动态开点
    sumv[x]++; 
    if (l == r) return ;
    if (pos <= M) build(l, M, lc[x], pos); else build(M + 1, r, rc[x], pos);
}
int query(int l, int r, int x, int y, int kth) {//查询 
    if (l == r) return l;
    int dlt = sumv[lc[y]] - sumv[lc[x]];
    if (kth <= dlt) return query(l, M, lc[x], lc[y], kth); 
    else return query(M + 1, r, rc[x], rc[y], kth - dlt);//类似平衡树查询
}
void clean() {
    sz = 0;
    for (int i = 1; i <= 2000001; i++) sumv[i] = lc[i] = rc[i] = 0;
    for (int i = 1; i <= 100001; i++) tax[i] = ai[i] = rt[i] = 0;
}
int solve() {
    clean();
    for (int i = 1; i <= n; i++) scanf("%d", &ai[i]), tax[i] = ai[i];
    sort(tax + 1, tax + 1 + n), whw = unique(tax + 1, tax + 1 + n) - tax - 1;//离散化
    for (int i = 1; i <= n; i++) build(1, whw, rt[i], getPos(ai[i])), mge(rt[i], rt[i - 1]);//建链、合并
    for (int x, y, k, i = 1; i <= m; i++) {
        scanf("%d%d%d", &x, &y, &k);
        printf("%d\n", tax[query(1, whw, rt[x - 1], rt[y], k)]);
    }
    return 0; 
}
int main() {
    scanf("%d%d", &n, &m), solve();
    return 0;
}

树上主席树

caioj 1443

给定一棵$N(1 \leq N \leq 100000)$个节点的树,每个点有一个权值,对于$M(1 \leq M \leq 100000)$个询问$(x,y,k)$,你需要回答$x$和$y$这两个节点间第$k$小的点权。

我们对于每个点建主席树维护$(u,rt)​$路径链,$rt​$为根,合并时与他的父亲节点合并,计算$(u,v)​$信息线段树时使用
$(u,v)=(u, rt)+(v,rt)-(lca,rt)-(fa[lca], rt)$, $lca=LCA(u,v), fa[lca]$为$lca$的父亲节点$rt$为根,画图理解
然后按照普通的主席树做就行了

代码

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 100000 + 5, logs = 18;
int n, q, whw, ai[MAXN], tax[MAXN], rt[MAXN], sz, dep[MAXN], pre[MAXN][25];
vector<int> G[MAXN];
int getPos(int x) {return lower_bound(tax + 1, tax + 1 + whw, x) - tax;}
void ins(int a, int b) {G[a].push_back(b), G[b].push_back(a);}
#define M ((l + r) >> 1)
int sumv[MAXN * 20], lc[MAXN * 20], rc[MAXN * 20];
int mge(int &x, int y) {
    if (y == 0) return 0;
    if (x == 0) return x = y, 0;
    sumv[x] += sumv[y];
    mge(lc[x], lc[y]), mge(rc[x], rc[y]);
    return 0;
}
void build(int l, int r, int &x, int pos) {
    if (x == 0) x = ++sz, sumv[x] = lc[x] = rc[x] = 0;
    sumv[x]++;
    if (l == r) return ;
    if (pos <= M) build(l, M, lc[x], pos); else build(M + 1, r, rc[x], pos);
}
int query(int l, int r, int x, int y, int lca, int flca, int kth) {
    if (l == r) return tax[l];
    int sum = sumv[lc[x]] + sumv[lc[y]] - sumv[lc[lca]] - sumv[lc[flca]];
    if (sum >= kth) return query(l, M, lc[x], lc[y], lc[lca], lc[flca], kth); 
    else return query(M + 1, r, rc[x], rc[y], rc[lca], rc[flca], kth - sum); 
}
void dfs(int u, int pa) {
    dep[u] = dep[pa] + 1, pre[u][0] = pa, mge(rt[u], rt[pa]);
    for (int i = 1; i <= logs; i++) pre[u][i] = pre[pre[u][i - 1]][i - 1];
    for (int i = 0; i < (int)G[u].size(); i++) {
        int v = G[u][i];
        if (v != pa) dfs(v, u);
    }
}
int LCA(int a, int b) {
    if (dep[a] < dep[b]) swap(a, b);
    for (int i = logs; i >= 0; i--) if (dep[pre[a][i]] >= dep[b]) a = pre[a][i];
    if (a == b) return a;
    for (int i = logs; i >= 0; i--) if (pre[a][i] != pre[b][i]) a = pre[a][i], b = pre[b][i];
    return pre[a][0];
}
void clean() {
    sz = 0;
    for (int i = 0; i <= 100001; i++) {
        G[i].clear(), dep[i] = tax[i] = ai[i] = rt[i] = 0;
        for (int j = 0; j <= 19; j++) pre[i][j] = 0;
    }
    for (int i = 0; i <= 2000001; i++) sumv[i] = lc[i] = rc[i] = 0;
}
int solve() {
    clean();
    for (int i = 1; i <= n; i++) scanf("%d", &ai[i]), tax[i] = ai[i];
    sort(tax + 1, tax + 1 + n), whw = unique(tax + 1, tax + 1 + n) - tax - 1;
    for (int x, y, i = 1; i < n; i++) scanf("%d%d", &x, &y), ins(x, y);
    for (int i = 1; i <= n; i++) build(1, whw, rt[i], getPos(ai[i]));
    dfs(1, 0);
    while (q--) {
        int x, y, k, lca;
        scanf("%d%d%d", &x, &y, &k);
        lca = LCA(x, y);
        printf("%d\n", query(1, whw, rt[x], rt[y], rt[lca], rt[pre[lca][0]], k));
    }
    return 0; 
}
int main() {
    scanf("%d%d", &n, &q), solve();
    return 0;
}
/*
13 100
3 4 1 2 3 2 4 5 3 2 1 1 3
1 2
2 3
3 4
4 5
5 6
5 7
2 8
8 9
9 10
10 11
10 12
11 13
7 13 8
*/

带修主席树

caioj 1442

给$n(1 \leq n \leq 50000)$个数字,进行$m(1 \leq m \leq 10000)$次操作,有两种操作:
$Q,l,r,k$:询问$l$到$r$第$k$小的数。
$C,x,k$:改变第$x$个数的值为$k$。

因为普通的主席树是前缀和套线段树,所以不能修改。那么我们想到修改,就发现可以用树状数组/线段树套线段树,由于此题单点修改,所以用树状数组。
对于前缀和套线段树先建主席树,然后再建树状数组套线段树的,修改在树状数组上操作,原数组在前缀和中,综合可以得到修改后的信息,要注意树状数组上的点在线段树上跳动(jump函数调节,存在$ust$数组),查询就用$ust$数组即可

实际上可以不必要建$2n$棵线段树,原数组直接加到树状数组中即可,不过会慢一点,参见此处

代码

建$2n$棵线段树的代码:

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 50000 + 5, MV = 1000000000;
int n, q, sz, rt[MAXN * 2], ai[MAXN];
#define M ((l + r) >> 1)
int sumv[MAXN * 100], lc[MAXN * 100], rc[MAXN * 100], ust[MAXN * 100];
int lowbit(int x) {return x & (-x);}
int mge(int &x, int y) {//线段树 合并 ->将线段树y合并至线段树x 
    if (x == 0) return x = y, 0;
    if (y == 0) return 0;
    sumv[x] += sumv[y];
    mge(lc[x], lc[y]), mge(rc[x], rc[y]);
    return 0;
}
void build(int l, int r, int &x, int pos, int v) {//线段树 建链 -> 维护[l,r]区间,当前线段树上点x,修改位置为pos=v 
    if (x == 0) x = ++sz, lc[x] = rc[x] = sumv[x] = 0;
    sumv[x] += v;
    if (l == r) return ;
    if (pos <= M) build(l, M, lc[x], pos, v); else build(M + 1, r, rc[x], pos, v);
}
void add(int u, int x, int c) {//Bit 加 -> bit上u点,x位置加c 
    for (int i = u; i <= 2 * n; i += lowbit(i)) build(1, MV, rt[i], x, c);
}
void jump(int u, int tp) {//Bit 更新 -> bit上u跳 
    for (int i = u; i > n; i -= lowbit(i)) {
        if (tp == -1) ust[i] = rt[i];
        if (tp == 0)  ust[i] = lc[ust[i]];
        if (tp == 1)  ust[i] = rc[ust[i]];
    }
}
int getBitSum(int u) {//Bit 查询 -> bit上u查询
    int ret = 0;
    for (int i = u; i > n; i -= lowbit(i)) {
        ret += sumv[lc[ust[i]]];
    }
    return ret;
}
int query(int l, int r, int x, int y, int x_2, int y_2, int kth) {//线段树 查询 -> 维护[l,r]区间,当前线段树上点x,y, 位置x_2, y_2, 查询第kth大 
    if (l == r) return l;
    int sum = sumv[lc[y]] - sumv[lc[x]] + getBitSum(y_2 + n) - getBitSum(x_2 + n);
    if (sum >= kth) {
        jump(x_2 + n, 0), jump(y_2 + n, 0);
        return query(l, M, lc[x], lc[y], x_2, y_2, kth);
    } else {
        jump(x_2 + n, 1), jump(y_2 + n, 1);
        return query(M + 1, r, rc[x], rc[y], x_2, y_2, kth - sum);
    }
}
void clean() {
    sz = 0;
    for (int i = 0; i <= 100000 + 5; i++) rt[i] = 0;
    for (int i = 0; i <= 5000000 + 5; i++) sumv[i] = lc[i] = rc[i] = ust[i] = 0;
}
int solve() {
    clean();
    for (int i = 1; i <= n; i++) scanf("%d", &ai[i]);
    for (int i = 1; i <= n; i++) build(1, MV, rt[i], ai[i], 1), mge(rt[i], rt[i - 1]);
    char s[5];
    while (q--) {
        scanf("%s", s);
        if (s[0] == 'C') {
            int x, k; scanf("%d%d", &x, &k);
            add(x + n, ai[x], -1), ai[x] = k, add(x + n, ai[x], 1);
        } else {
            int l, r, k; scanf("%d%d%d", &l, &r, &k);
            jump(r + n, -1), jump(l - 1 + n, -1);
            printf("%d\n", query(1, MV, rt[l - 1], rt[r], l - 1, r, k));
        }
    }
    return 0; 
}
int main() {
    scanf("%d%d", &n, &q), solve();
    return 0;
}

主席树维护区间问题

spoj DQUERY

给出一个$n$个数的序列,求区间$[l,r]$里有多少种不同数字。

与树状数组类似,主席树维护区间,相当于可持久化维护每次加点后的情况。每个点按顺序建树,如果这个点的数之前没有出现过,直接在本棵主席树该位置加$1$。否则就把之前出现这个值的位置减$1$,再重复做没有出现的情况。为的是把数尽可能放到右边,因为记录值中位置不影响答案。这样就方便求解$[l,r]$的信息。
询问直接用右端点的主席树,由于上述操作后答案可减,所以直接把右端点的主席树左端点以右的值求和即可

代码

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 30000 + 5;
int n, q, sz, ai[MAXN], rt[MAXN], lst[1000000 + 5];
#define M ((l + r) >> 1)
int sumv[MAXN * 20], lc[MAXN * 20], rc[MAXN * 20];
int mge(int &x, int y) {//主席树x合并主席树y 
    if (y == 0) return 0;
    if (x == 0) return x = y, 0;
    sumv[x] += sumv[y];
    mge(lc[x], lc[y]), mge(rc[x], rc[y]);
    return 0;
}
void build(int l, int r, int &x, int pos, int v) {//建链维护[l,r], 主席树上x点,修改位置和值 
    if (x == 0) x = ++sz, lc[x] = rc[x] = sumv[x] = 0;
    sumv[x] += v;
    if (l == r) return ;
    if (pos <= M) build(l, M, lc[x], pos, v); else build(M + 1, r, rc[x], pos, v);
}
int query(int l, int r, int x, int u) {//查询[l,r]答案,主席树上x点,左边临界点u 
    if (l == r) return 0;
    if (u <= M) return sumv[rc[x]] + query(l, M, lc[x], u); //加上右边,查询左边
    else return query(M + 1, r, rc[x], u); //不要加左,左边有临界点
}
void clean() {
    sz = 0, ms(lst, -1);
}
int solve() {
    clean();
    for (int i = 1; i <= n; i++) scanf("%d", &ai[i]);
    for (int i = 1; i <= n; i++) {//维护 [0, n] 区间,因为l - 1可能为 0
        if (lst[ai[i]] < 0) build(0, n, rt[i], i, 1), mge(rt[i], rt[i - 1]);//之前没有
        else {
            build(0, n, rt[i], lst[ai[i]], -1), build(0, n, rt[i], i, 1);
            mge(rt[i], rt[i - 1]);
        }//之前有
        lst[ai[i]] = i;
    }
    scanf("%d", &q);
    while (q--) {
        int l, r; scanf("%d%d",&l, &r);
        printf("%d\n", query(0, n, rt[r], l - 1));
    }
    return 0; 
}
int main() {
    scanf("%d", &n), solve();
    return 0;
}

可持久化

caioj 1447

维护区间和,有区间增加,要求可持久化 (回退、查询某个版本)

每个询问开一棵线段树,回退直接删掉中间的线段树即可。由于是主席树不能pushdown,pushup,所以增加的时候直接更新sumv的值,查询时lazy值直接累加乘以查询区间长度即可,具体操作可以看代码

代码

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const LL MAXN = 100000 + 5;
LL n, m, ai[MAXN], qzh[MAXN], rt[MAXN], now, sz;
#define M ((l + r) >> 1)
LL sumv[MAXN * 40], lc[MAXN * 40], rc[MAXN * 40], lazy[MAXN * 40];
int mge(LL &x, LL y) {
    if (y == 0) return 0;
    if (x == 0) return x = y, 0;
    sumv[x] += sumv[y], lazy[x] += lazy[y];
    mge(lc[x], lc[y]), mge(rc[x], rc[y]);
    return 0;
}
void build(LL l, LL r, LL &x, LL cl, LL cr, LL v) {
    if (x == 0) x = ++sz;
    sumv[x] += (cr - cl + 1) * v;//直接加,免去pushup 
    if (l == cl && r == cr) {
        lazy[x] += v;
        return ;
    }
    if (cr <= M) build(l, M, lc[x], cl, cr, v); else if (cl > M) build(M + 1, r, rc[x], cl, cr, v);
    else build(l, M, lc[x], cl, M, v), build(M + 1, r, rc[x], M + 1, cr, v);
    //整个区间在左边、右边、分开两边 
}
LL query(LL l, LL r, LL x, LL cl, LL cr, LL tmp) {
    if (l == cl && r == cr) return (r - l + 1) * tmp + sumv[x];
    if (cr <= M) return query(l, M, lc[x], cl, cr, tmp + lazy[x]); 
    else if (cl > M) return query(M + 1, r, rc[x], cl, cr, tmp + lazy[x]);
    else return query(l, M, lc[x], cl, M, tmp + lazy[x]) + query(M + 1, r, rc[x], M + 1, cr, tmp + lazy[x]);
    //整个查询区间在左边、右边、分开两边,和普通线段树不同,相当于用 M 分离查询区间
    //直接累加lazy最后乘查询区间长度 
}
void clean() {
    now = sz = 0;
    for (LL i = 0; i <= 100000 + 3; i++) rt[i] = qzh[i] = 0;
    for (LL i = 0; i <= 4000000 + 3; i++) sumv[i] = lc[i] = rc[i] = lazy[i] = 0;
}
int solve() {
    clean();
    for (LL i = 1; i <= n; i++) scanf("%lld", &ai[i]), qzh[i] = qzh[i - 1] + ai[i];
    for (LL i = 1; i <= m; i++) {
        LL tp; scanf("%lld", &tp);
        if (tp == 1) {
            LL l, r, k; scanf("%lld%lld%lld", &l, &r, &k);
            build(1, n, rt[++now], l, r, k), mge(rt[now], rt[now - 1]);
        }
        if (tp == 2) {
            LL l, r; scanf("%lld%lld", &l, &r);
            printf("%lld\n", qzh[r] - qzh[l - 1] + query(1, n, rt[now], l, r, 0));
        }
        if (tp == 3) {
            LL l, r, h; scanf("%lld%lld%lld", &l, &r, &h);
            printf("%lld\n", qzh[r] - qzh[l - 1] + query(1, n, rt[h], l, r, 0));
        }
        if (tp == 4) {
            LL h; scanf("%lld", &h);
            for (LL i = h + 1; i <= now; i++) rt[i] = 0;
            now = h;
        }
    }
    return 0; 
}
int main() {
    scanf("%lld%lld", &n, &m), solve();
    return 0;
}

注意事项

1、主席树节点数和操作次数有关,与值域无关
2、主席树边更新边合并和更新完合并两种写法不要写混

常见题型

1、逆序对问题

1、静态区间逆序对 (离线莫队):Bzoj 3289
2、动态逆序对 (每次删除一个数,求序列逆序对个数):Bzoj 3295
3、区间逆序对 (强制在线求区间 $ [l,r] $ 的逆序对):Bzoj 3744

2、众数问题
1、强制在线区间众数:Bzoj 2724
2、摩尔投票法 (序列大于一半数的众数):Luogu 3765

3、区间k大 / 区间小于某数的个数
1、区间$k$大:Bzoj 3932
2、区间小于某数的个数:Bzoj 1926, Bzoj 3295

------ 本文结束 ------