spoj QTREE2(倍增LCA)

spoj QTREE2
以 $LCA$ 为分界，把 $u \to v$ 的路径拆成 $u \to LCA$ 与 $LCA \to v$ 两条链，然后判断第 $k$ 个点落在哪条链上，再用倍增跳到该链上对应深度的祖先即可。
注意：求 LCA 时要先比较两点深度，把较深的点往上跳；另外，DIST 询问要的是路径上的边权和，而 KTH 定位节点用的是深度差（节点个数），两者不要混淆。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define pb push_back
using namespace std;
// MAXN: capacity for vertex-indexed arrays (vertex count plus slack).
// logs: highest binary-lifting level stored in pre[][] (2^20 exceeds MAXN).
const int MAXN = 20000 + 5, logs = 20;
// Edge record: target vertex v and edge weight wi. ed[] holds both directed
// copies of every undirected edge; ins() fills it starting at index 1.
// NOTE(review): with C++17 and `using namespace std;`, the bare name `data`
// can clash with std::data at use sites — consider renaming this struct.
struct data {LL v, wi;}ed[MAXN * 2];
// en: number of edge records used in ed[]; n: vertex count of the current case.
// fa[u] / pre[u][0]: parent of u (0 for the root, vertex 0 acts as a sentinel).
// dep[u]: depth, root 1 has depth 1 and the sentinel 0 has depth 0.
// pre[u][i]: 2^i-th ancestor of u (0 once the jump passes the root).
// far[u]: sum of edge weights on the path from root 1 to u.
LL en, n, fa[MAXN], dep[MAXN], pre[MAXN][logs + 1], far[MAXN];
// G[u]: indices into ed[] of the edge records leaving u.
vector<LL> G[MAXN];
void clean() {
    en = 0;
    for (LL i = 0; i <= n; i++) {
        far[i] = fa[i] = dep[i] = 0, G[i].clear();
        for (LL j = 0; j <= logs; j++) pre[i][j] = 0;
    }
}
// Adds the undirected edge u-v with weight c as two directed records in ed[]
// (1-based; en counts the records) and registers each record in the adjacency
// list of its source vertex.
// Fields are assigned directly rather than through a `(data){...}` compound
// literal: that syntax is a GNU extension in C++, and naming `data` here is
// reported as ambiguous with std::data under C++17 + `using namespace std;`.
void ins(LL u, LL v, LL c) {
    ++en;
    ed[en].v = v;
    ed[en].wi = c;
    G[u].push_back(en);

    ++en;
    ed[en].v = u;
    ed[en].wi = c;
    G[v].push_back(en);
}
// Roots the tree at u with parent pa. For u it records the parent, the depth
// and the binary-lifting ancestor table. For each child it records the
// weighted root distance far[], then recurses into that child.
void dfs(LL u, LL pa) {
    fa[u] = pa;
    dep[u] = dep[pa] + 1;
    pre[u][0] = pa;
    for (LL lvl = 1; lvl <= logs; ++lvl) {
        const LL half = pre[u][lvl - 1];   // 2^(lvl-1)-th ancestor
        pre[u][lvl] = pre[half][lvl - 1];
    }
    for (LL id : G[u]) {
        const LL to = ed[id].v;
        if (to == pa) continue;            // skip the edge back to the parent
        far[to] = far[u] + ed[id].wi;
        dfs(to, u);
    }
}
// Lowest common ancestor of a and b by binary lifting.
// 1. Raise the deeper vertex to the other's depth. If the two now coincide,
//    that vertex is the answer.
// 2. Otherwise climb both together while their 2^lvl-th ancestors differ.
//    The LCA is then the parent of where they stop.
LL getlca(LL a, LL b) {
    if (dep[b] > dep[a]) swap(a, b);
    for (LL lvl = logs; lvl >= 0; --lvl) {
        const LL up = pre[a][lvl];
        if (dep[up] >= dep[b]) a = up;
    }
    if (a == b) return a;
    for (LL lvl = logs; lvl >= 0; --lvl) {
        const LL ua = pre[a][lvl], ub = pre[b][lvl];
        if (ua != ub) {
            a = ua;
            b = ub;
        }
    }
    return pre[a][0];
}
// Sum of edge weights on the u-v path. Each endpoint contributes its own
// weighted distance from the root minus the part shared up to their LCA.
LL dist(LL u, LL v) {
    const LL shared = far[getlca(u, v)];
    return (far[u] - shared) + (far[v] - shared);
}
// Returns the k-th vertex on the upward path starting at u (k = 1 is u itself),
// i.e. the ancestor of u whose depth is dep[u] - k + 1, found by binary lifting.
LL pc(LL u, LL k) {
    const LL target = dep[u] + 1 - k;
    LL cur = u;
    for (LL lvl = logs; lvl >= 0; --lvl) {
        const LL up = pre[cur][lvl];
        if (dep[up] >= target) cur = up;
    }
    return cur;
}
// Returns the w-th vertex (1-based) on the path from u to v.
// The path is split at the LCA. If w falls on the u..LCA part, climb from u.
// Otherwise count from the v end instead: the path has
// dep[u] + dep[v] - 2*dep[lca] + 1 vertices, so position w from u is
// position (that total - w + 1) from v.
LL kth(LL u, LL v, LL w) {
    const LL top = getlca(u, v);
    const LL upper = dep[u] - dep[top] + 1;   // vertices on u..lca inclusive
    if (w > upper) {
        return pc(v, dep[u] + dep[v] - 2 * dep[top] - w + 2);
    }
    return pc(u, w);
}
// Handles one test case. It reads n and the n-1 weighted edges, roots the tree
// at vertex 1, then answers queries until "DONE":
//   "DIST u v"  -> sum of edge weights on the u-v path
//   "KTH u v k" -> k-th vertex on the path from u to v
// Commands are told apart by their second letter (I / T / O).
// Every read is checked, so truncated input ends the case instead of looping
// forever on a stale command. The command read is width-limited to the buffer.
void solve() {
    if (scanf("%lld", &n) != 1) return;
    clean();
    for (LL x, y, c, i = 1; i < n; i++) {
        if (scanf("%lld%lld%lld", &x, &y, &c) != 3) return;
        ins(x, y, c);
    }
    dfs(1, 0);
    char s[10];
    while (scanf("%9s", s) == 1) {
        if (s[1] == 'O') break;                 // "DONE"
        if (s[1] == 'I') {                      // "DIST u v"
            LL u, v;
            if (scanf("%lld%lld", &u, &v) != 2) return;
            printf("%lld\n", dist(u, v));
        } else if (s[1] == 'T') {               // "KTH u v k"
            LL u, v, k;
            if (scanf("%lld%lld%lld", &u, &v, &k) != 3) return;
            printf("%lld\n", kth(u, v, k));
        }
    }
}
// Reads the number of test cases and runs solve() once per case.
// T is initialised and the read is checked, so empty or malformed input exits
// cleanly instead of looping on an indeterminate count. `T-- > 0` also keeps a
// negative count from running (almost) forever.
int main() {
    LL T = 0;
    if (scanf("%lld", &T) != 1) return 0;
    while (T-- > 0) solve();
    return 0;
}
------ 本文结束 ------