「Bzoj 1036」「ZJOI2008」树的统计Count (树链剖分)

bzoj 1036

树链剖分 + 线段树:先用两次 DFS 把树剖分成若干条重链,使每条重链在线段树中占据一段连续区间;之后单点修改直接更新线段树,路径上的最大值与权值和则沿重链逐段查询后合并。

第一个树剖题,终于AC

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
// Inclusive ascending loop: i runs from j up to and including k.
// i must already be declared by the caller. Because the test is `i<=(k)`,
// passing an unsigned bound such as `size()-1` wraps around when the
// container is empty.
#define fo(i, j, k) for (i=(j);i<=(k);i++)
// Inclusive descending loop: i runs from k down to and including j.
#define fd(i, k, j) for (i=(k);i>=(j);i--)
// scanf shorthands for reading one, two or three ints. Each expands to the
// scanf call itself, so the number of converted items is available as the
// value of the expression.
#define rd(a) scanf("%d", &a)
#define rd2(a, b) scanf("%d%d", &a, &b)
#define rd3(a, b, c) scanf("%d%d%d", &a, &b, &c)
// Fills object i with byte value j. The size comes from `sizeof i`, so i has
// to be an array (not a pointer) for the whole buffer to be covered.
#define ms(i, j) memset(i, j, sizeof i)
// Base name of the local input/output files that main opens when
// ONLINE_JUDGE is not defined.
#define FN2 "bzoj1036" 
using namespace std;

// Capacity of the per-node arrays below; nodes are indexed from 1.
const int MAXN = 30000 + 5;

// Per-node data filled in by dfs1. son[u] is -1 while u has no heavy child.
int dep[MAXN], son[MAXN], fa[MAXN], siz[MAXN]; // depth, heavy child, parent, subtree size
// Filled in by dfs2; pre is the last segment-tree position handed out.
int p[MAXN], top[MAXN], pre; // position in the segment tree, top node of the heavy chain, running position counter
// n: number of nodes in the current test case; wi[i]: weight read for node i.
int n, wi[MAXN];
// Adjacency lists of the undirected tree (each edge is stored in both directions).
vector<int> G[MAXN];

void dfs1(int u, int f)//第一次dfs记录值 
{
    int i;
    dep[u] = dep[f] + 1, fa[u] = f, siz[u] = 1;
    fo (i, 0, G[u].size()-1) {
        int v = G[u][i];
        if (v!=f) {
            dfs1(v, u);
            siz[u] += siz[v];
            if (son[u]==-1||siz[son[u]]<siz[v]) son[u] = v;
        }
    }
}
void dfs2(int u, int chain) {//第二次dfs连重儿子成重链 
    int i; 
    p[u] = ++pre, top[u] = chain;
    if (son[u]!=-1) {
        dfs2(son[u], chain);
        fo (i, 0, G[u].size()-1) {
            int v = G[u][i];
            if (v!=son[u]&&v!=fa[u]) dfs2(v, v);
        }
    }
}

// Segment-tree storage, indexed by tree node id (root is 1, the children of
// node o are 2*o and 2*o+1): maximum and sum of the values in each node's range.
int maxv[MAXN*4], sumv[MAXN*4];
// Recomputes the maximum and the sum stored at segment-tree node o from its
// two children, 2*o and 2*o+1.
void pushup(int o) {
    const int left = 2 * o;
    const int right = left + 1;
    sumv[o] = sumv[left] + sumv[right];
    maxv[o] = max(maxv[left], maxv[right]);
}
// Point assignment: stores v at position p and refreshes the sum and maximum
// of every node on the way back up to o.
//   o    - current segment-tree node, covering positions [l, r]
//   l, r - inclusive range covered by o
//   p    - target position (this parameter hides the global p[] array here)
//   v    - value to store
void update(int o, int l, int r, int p, int v) {
    if (l == r) {
        // Reached the leaf for position p.
        maxv[o] = v;
        sumv[o] = v;
        return;
    }

    const int mid = (l + r) / 2;
    if (p > mid) {
        update(2 * o + 1, mid + 1, r, p, v);
    } else {
        update(2 * o, l, mid, p, v);
    }
    pushup(o);
}
// Range-maximum query on the segment tree.
//   o    - current segment-tree node, covering positions [l, r]
//   x, y - inclusive query range
// Returns the maximum stored value over the part of [x, y] covered by o;
// -200000000 stands in for "no value" on a side that is not visited.
int getMax(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return maxv[o]; // node lies completely inside the query

    const int mid = (l + r) / 2;
    const int none = -200000000;
    const int leftBest = (x <= mid) ? getMax(2 * o, l, mid, x, y) : none;
    const int rightBest = (mid < y) ? getMax(2 * o + 1, mid + 1, r, x, y) : none;
    return max(max(none, leftBest), rightBest);
}
// Range-sum query on the segment tree.
//   o    - current segment-tree node, covering positions [l, r]
//   x, y - inclusive query range
// Returns the sum of the stored values over the part of [x, y] covered by o.
int getSum(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return sumv[o]; // node lies completely inside the query

    const int mid = (l + r) / 2;
    int total = 0;
    if (x <= mid) {
        total += getSum(2 * o, l, mid, x, y);
    }
    if (mid < y) {
        total += getSum(2 * o + 1, mid + 1, r, x, y);
    }
    return total;
}
int findMax(int u, int v)
{
    int f1 = top[u], f2 = top[v];
    int ret = -200000000;
    while (f1!=f2) {
        if (dep[f1]<dep[f2]) swap(f1, f2), swap(u, v);
        ret = max(ret, getMax(1, 1, n, p[f1], p[u]));
        u = fa[f1], f1 = top[u];
    }    
    if (dep[u]<dep[v]) swap(u, v);
    return max(ret, getMax(1, 1, n, p[v], p[u]));
}
int findSum(int u, int v)
{
    int f1 = top[u], f2 = top[v];
    int ret = 0;
    while (f1!=f2) {
        if (dep[f1]<dep[f2]) swap(f1, f2), swap(u, v);
        ret += getSum(1, 1, n, p[f1], p[u]);
        u = fa[f1], f1 = top[u];
    }    
    if (dep[u]<dep[v]) swap(u, v);
    return ret+getSum(1, 1, n, p[v], p[u]);
}
void init() {
    int i; pre = 0;
    fo (i, 1, n) dep[i] = fa[i] = siz[i] = p[i] = top[i] = 0, son[i] = -1, G[i].clear();
    fo (i, 1, n*4) maxv[i] = -200000000, sumv[i] = 0;
    fo (i, 1, n-1) {
        int a, b; rd2(a, b);
        G[a].push_back(b), G[b].push_back(a);
    }
}
void solve() {
    int q, i; 
    dfs1(1, 0);
    dfs2(1, 1);
    fo (i, 1, n) rd(wi[i]), update(1, 1, n, p[i], wi[i]); 
    rd(q);
    fo (i, 1, q) {
        char ch[10]; scanf("%s", ch);
        if (ch[0]=='C') {
            int u, t; rd2(u, t), update(1, 1, n, p[u], t);
        } else if (ch[1]=='M') {
            int u, v; rd2(u, v), printf("%d\n", findMax(u, v));
        } else if (ch[1]=='S') {
            int u, v; rd2(u, v), printf("%d\n", findSum(u, v));
        }
    }
}
// Entry point. Unless ONLINE_JUDGE is defined, stdin and stdout are redirected
// to the FN2 ".in" / ".out" files. Test cases are then processed one after
// another for as long as a node count can be read.
int main() {
#ifndef ONLINE_JUDGE
    freopen(FN2 ".in", "r", stdin);
    freopen(FN2 ".out", "w", stdout);
#endif
    for (;;) {
        if (scanf("%d", &n) != 1) break;
        init();
        solve();
    }
    return 0;
}
------ 本文结束 ------