spoj Count on a tree II(树上莫队)

SPOJ-COT2
题意:给一棵树,每个点有一个权值。多次询问路径$(u, v)$上有多少个权值不同的点。
树上莫队的资料
考虑树上莫队。难点在于怎么将树搬到序列中。将树的括号序求出来,树的括号序就是维护一个序列,这个序列是dfs这棵树时,每个节点在dfs到的时候将本身加入序列,然后离开该节点(返回父节点)也将本身加入序列的一个序列(具体可以看资料,有图解)。
我们设$st,ed$分别为某个点最开始在序列出现位置和最后出现在序列的位置。对于一条路径$(u,v)$(这里$st_u \leq st_v$):
如果$LCA=u$,那么路径就是$st_u$到$st_v$之间一段中出现次数为奇数次的点。
否则,路径就是$ed_u$到$st_v$之间一段中出现次数为奇数次的点。注意这一段并不包含LCA,要手动加上。
然后就是注意莫队转移的时候的方法,用奇偶性判断这个节点出现几次,只考虑奇偶性。奇数就可以让记录权值出现次数的数组加一,反之减一。

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 40000 + 5, logs = 18;
int n, m, whw, blolen, bl[MAXN * 2], ai[MAXN], st[MAXN], ed[MAXN], T[MAXN * 2], tax[MAXN], sz;
int dep[MAXN], vis[MAXN], nl, nr, nans, pre[MAXN][22], ans[100000 + 5];
vector<int> G[MAXN];
struct data {
    int l, r, u, v, id, lca;
    bool operator < (const data &b) const {
        if (bl[l] == bl[b.l]) return r < b.r;
        return bl[l] < bl[b.l];
    }
}xw[100000 + 5];
void dfs(int u, int pa) {
    dep[u] = dep[pa] + 1, T[++sz] = u, st[u] = sz, pre[u][0] = pa;
    for (int i = 1; i <= logs; i++) pre[u][i] = pre[pre[u][i - 1]][i - 1];
    for (int i = 0; i < (int)G[u].size(); i++) {
        int v = G[u][i];
        if (v != pa) dfs(v, u);
    }
    T[++sz] = u, ed[u] = sz;
}
void ins(int a, int b) {G[a].push_back(b), G[b].push_back(a);}
int LCA(int a, int b) {
    if (dep[a] < dep[b]) swap(a, b);
    for (int i = logs; i >= 0; i--) if (dep[pre[a][i]] >= dep[b]) a = pre[a][i];
    if (a == b) return a;
    for (int i = logs; i >= 0; i--) if (pre[a][i] != pre[b][i]) a = pre[a][i], b = pre[b][i];
    return pre[a][0];
}
void adjust(int x) {
    tax[T[x]] = 1 - tax[T[x]];
    if (tax[T[x]] % 2) {
        vis[ai[T[x]]]++;
        if (vis[ai[T[x]]] == 1) nans++;
    } else {
        vis[ai[T[x]]]--;
        if (vis[ai[T[x]]] == 0) nans--;
    }
}
void clean() {
    ms(dep, 0), sz = 0;
}
int solve() {
    clean(); 
    blolen = (int)sqrt(2 * n);
    for (int i = 1; i <= n; i++) scanf("%d", &ai[i]), tax[i] = ai[i];
    for (int i = 1; i <= 2 * n; i++) bl[i] = (i - 1) / blolen + 1;
    sort(tax + 1, tax + 1 + n), whw = unique(tax + 1, tax + 1 + n) - tax - 1;
    for (int i = 1; i <= n; i++) ai[i] = lower_bound(tax + 1, tax + 1 + whw, ai[i]) - tax;
    for (int u, v, i = 1; i < n; i++) scanf("%d%d", &u, &v), ins(u, v);
    dfs(1, 0);
    for (int i = 1; i <= m; i++) scanf("%d%d", &xw[i].u, &xw[i].v), xw[i].id = i, xw[i].lca = LCA(xw[i].u, xw[i].v); 
    for (int i = 1; i <= m; i++) {
        if (st[xw[i].u] > st[xw[i].v]) swap(xw[i].u, xw[i].v);
        if (xw[i].lca != xw[i].u) xw[i].l = ed[xw[i].u], xw[i].r = st[xw[i].v];
        else xw[i].l = st[xw[i].u], xw[i].r = st[xw[i].v];
    }
    sort(xw + 1, xw + 1 + m);
    ms(tax, 0), ms(vis, 0), nl = 1, nr = 0, nans = 0;
    for (int i = 1; i <= m; i++) {
        while (nl < xw[i].l) adjust(nl), nl++;
        while (nl > xw[i].l) adjust(nl - 1), nl--;
        while (nr < xw[i].r) adjust(nr + 1), nr++;
        while (nr > xw[i].r) adjust(nr), nr--;
        ans[xw[i].id] = nans;
        if (xw[i].lca != xw[i].u) {
            if (vis[ai[xw[i].lca]] == 0) ans[xw[i].id]++;
        }
    }
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
    return 0; 
}
int main() {
    scanf("%d%d", &n, &m), solve();
    return 0;
}
/*
8 3
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
1 8
6 8
5 4
*/
------ 本文结束 ------