Bzoj 3531(树链剖分+动态开点线段树)

BZOJ 3531

树剖以后为每个宗教各建立一棵线段树。如果每棵树都按传统方法预先开满数组,节点总数太多、空间肯定不够,所以这里进行改进,使用动态开点线段树:只有在更新时真正需要某个节点,才去创建这个节点,其余节点一律不开。

知识点:
1、注意区分节点的原编号 u 与树剖后的新编号 p[u]:线段树的下标用的是 p[u],而不是 u
2、动态开点时,没有开到的点(空儿子及其对应的值)一律按 0 处理,不要赋 -1:这样 0 号节点的 sumv、maxv 恒为 0,pushup 和查询时都不需要特判,否则要处理很多细节

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define ll long long
using namespace std;

// MAXN bounds every per-vertex array (and root[], which is indexed by religion);
// MAXTN bounds the node pool shared by all dynamically allocated segment trees.
const int MAXN = 1e5 + 5, MAXTN = 10000000 + 5;

// Adjacency lists of the tree; vertices are 1-based.
vector<int> G[MAXN];
// n: number of vertices, q: number of operations.
// wi[i]: value stored for vertex i; ci[i]: religion of vertex i (selects which segment tree holds it).
int n, q, wi[MAXN], ci[MAXN];
// Heavy-light decomposition data:
//   top[u] head of u's chain, fa[u] parent, son[u] heavy child (-1 if none),
//   siz[u] subtree size, p[u] position of u in the decomposition order,
//   dep[u] depth, pre: last position handed out by dfs2.
int top[MAXN], fa[MAXN], son[MAXN], siz[MAXN], p[MAXN], dep[MAXN], pre;
// Segment-tree node pool: lt/rt are child indices (0 = child not created),
// root[c] is the root node of religion c's tree (0 = not created), nd = nodes allocated so far.
int lt[MAXTN], rt[MAXTN], root[MAXN], nd;
// Per-node aggregates over the node's segment: sum and maximum.
ll sumv[MAXTN], maxv[MAXTN];
// Heavy-light decomposition
// First decomposition pass.
// Records parent, depth and subtree size for every vertex in the subtree of
// `u` (entered from `pa`) and picks the child with the largest subtree as
// son[u]. son[] must hold -1 for every vertex before the call.
void dfs1(int u, int pa) {
    fa[u] = pa;
    dep[u] = dep[pa] + 1;
    siz[u] = 1;

    const int degree = static_cast<int>(G[u].size());
    for (int k = 0; k < degree; ++k) {
        const int child = G[u][k];
        if (child == pa) continue;  // do not walk back up the tree

        dfs1(child, u);
        siz[u] += siz[child];

        // Keep the first child that reaches the largest subtree size.
        const bool heavier = (son[u] == -1) || (siz[child] > siz[son[u]]);
        if (heavier) son[u] = child;
    }
}
// Second decomposition pass.
// Assigns `u` the next position in the decomposition order and marks `chain`
// as the head of its chain. The heavy child is visited first so that every
// chain occupies consecutive positions; each light child starts its own chain.
void dfs2(int u, int chain) {
    top[u] = chain;
    p[u] = ++pre;

    const int heavy = son[u];
    if (heavy == -1) return;  // leaf: nothing below

    dfs2(heavy, chain);

    const int degree = static_cast<int>(G[u].size());
    for (int k = 0; k < degree; ++k) {
        const int child = G[u][k];
        if (child == fa[u] || child == heavy) continue;
        dfs2(child, child);
    }
}
// Segment tree
// M is the midpoint of the current segment. It expands textually, so it can
// only be used where variables named `l` and `r` are in scope.
#define M ((l+r)>>1)
void pushup(int o) {
    sumv[o] = sumv[lt[o]] + sumv[rt[o]];
    maxv[o] = max(maxv[lt[o]], maxv[rt[o]]);
}
// Point assignment: sets position `p` to `w` in the tree rooted at node `o`,
// which covers [l, r]. Children on the path are allocated from the shared pool
// on first use, and aggregates are refreshed on the way back up.
void update(int o, int l, int r, int p, ll w) {
    if (l == r) {
        sumv[o] = w;
        maxv[o] = w;
        return;
    }

    const int mid = (l + r) >> 1;
    if (p <= mid) {
        if (lt[o] == 0) lt[o] = ++nd;
        update(lt[o], l, mid, p, w);
    } else {
        if (rt[o] == 0) rt[o] = ++nd;
        update(rt[o], mid + 1, r, p, w);
    }
    pushup(o);
}
// Maximum over positions [x, y] in the tree rooted at node `o`, which covers
// [l, r]. Children that were never created are skipped and contribute 0.
ll queryMax(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return maxv[o];

    const int mid = (l + r) >> 1;
    ll best = 0;
    if (x <= mid && lt[o] != 0) {
        best = max(best, queryMax(lt[o], l, mid, x, y));
    }
    if (mid < y && rt[o] != 0) {
        best = max(best, queryMax(rt[o], mid + 1, r, x, y));
    }
    return best;
}
// Sum over positions [x, y] in the tree rooted at node `o`, which covers
// [l, r]. Children that were never created are skipped and contribute 0.
ll querySum(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return sumv[o];

    const int mid = (l + r) >> 1;
    ll total = 0;
    if (x <= mid && lt[o] != 0) {
        total += querySum(lt[o], l, mid, x, y);
    }
    if (mid < y && rt[o] != 0) {
        total += querySum(rt[o], mid + 1, r, x, y);
    }
    return total;
}
// Path queries via heavy-light decomposition
// Maximum stored value on the tree path u..v, read from the segment tree
// whose root node is `rt`. Repeatedly lifts the endpoint whose chain head is
// deeper until both endpoints share a chain, then queries the remaining span.
ll findMax(int u, int v, int rt) {
    ll best = 0;
    while (top[u] != top[v]) {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        best = max(best, queryMax(rt, 1, n, p[top[u]], p[u]));
        u = fa[top[u]];
    }
    // Same chain now: make v the shallower endpoint so p[v] <= p[u].
    if (dep[u] < dep[v]) swap(u, v);
    return max(best, queryMax(rt, 1, n, p[v], p[u]));
}
// Sum of stored values on the tree path u..v, read from the segment tree
// whose root node is `rt`. Repeatedly lifts the endpoint whose chain head is
// deeper until both endpoints share a chain, then adds the remaining span.
ll findSum(int u, int v, int rt) {
    ll total = 0;
    while (top[u] != top[v]) {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        total += querySum(rt, 1, n, p[top[u]], p[u]);
        u = fa[top[u]];
    }
    // Same chain now: make v the shallower endpoint so p[v] <= p[u].
    if (dep[u] < dep[v]) swap(u, v);
    return total + querySum(rt, 1, n, p[v], p[u]);
}
// Main program
// Resets all global state before a test case is read: the node pool and its
// counters, the per-religion roots, the adjacency lists and the decomposition
// arrays for vertices 0..n (son[] is reset to -1, meaning "no heavy child").
void clear() {
    nd = 0;
    pre = 0;

    memset(sumv, 0, sizeof sumv);
    memset(maxv, 0, sizeof maxv);
    memset(lt, 0, sizeof lt);
    memset(rt, 0, sizeof rt);
    memset(root, 0, sizeof root);

    for (int v = 0; v <= n; ++v) {
        G[v].clear();
        top[v] = 0;
        fa[v] = 0;
        siz[v] = 0;
        p[v] = 0;
        dep[v] = 0;
        son[v] = -1;
    }
}
void init() {
    clear();
    for (int i=1;i<=n;i++) scanf("%d%d", &wi[i], &ci[i]);
    for (int i=1;i<n;i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        G[x].push_back(y), G[y].push_back(x);
    }
}
// Returns the root node of religion c's segment tree, allocating it from the
// shared pool on first use. Node 0 is the "missing child" sentinel and must
// never be used as a real root.
static int ensureRoot(int c) {
    if (!root[c]) root[c] = ++nd;
    return root[c];
}

// Builds the decomposition, inserts every vertex into the segment tree of its
// religion, then processes the q operations:
//   CC x c : vertex x switches to religion c
//   CW x w : vertex x's value becomes w
//   QS x y : print the sum over path x..y within vertex x's religion
//   QM x y : print the maximum over path x..y within vertex x's religion
void solve() {
    dfs1(1, 0);
    dfs2(1, 1);

    for (int i = 1; i <= n; i++) {
        update(ensureRoot(ci[i]), 1, n, p[i], wi[i]);
    }

    char ch[10];
    for (int i = 1; i <= q; i++) {
        // Width-limited read so an over-long token cannot overflow ch;
        // stop if the input ends early instead of acting on stale data.
        if (scanf("%9s", ch) != 1) break;

        if (ch[0] == 'C') {
            if (ch[1] == 'C') {  // CC
                int x, c;
                if (scanf("%d%d", &x, &c) != 2) break;
                // Remove x from its old religion's tree ...
                update(root[ci[x]], 1, n, p[x], 0);
                ci[x] = c;
                // ... and insert it into the new one. The target religion may
                // not have a tree yet; updating through root 0 would corrupt
                // the sentinel node that every missing child refers to.
                update(ensureRoot(c), 1, n, p[x], wi[x]);
            } else {  // CW
                int x, w;
                if (scanf("%d%d", &x, &w) != 2) break;
                wi[x] = w;
                update(root[ci[x]], 1, n, p[x], w);
            }
        } else if (ch[0] == 'Q') {
            int x, y;
            if (scanf("%d%d", &x, &y) != 2) break;
            if (ch[1] == 'S') {  // QS
                printf("%lld\n", findSum(x, y, root[ci[x]]));
            } else {  // QM
                printf("%lld\n", findMax(x, y, root[ci[x]]));
            }
        }
    }
}
// Entry point. Outside an online judge, stdin/stdout are redirected to the
// local files 1.in / 1.out. Test cases are processed until an "n q" header
// can no longer be read.
int main() {
#ifndef ONLINE_JUDGE
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    while (scanf("%d%d", &n, &q) == 2) {
        init();
        solve();
    }
    return 0;
}
------ 本文结束 ------