spoj QTREE2
把一条路径根据$LCA$拆成两条链,然后判断第$k$个点在哪条链上即可
注意:求 $LCA$ 之前要先比较两点的深度,让较深的点先向上跳;另外,路径上的边权和(DIST 询问的答案)与路径上的点数(KTH 询问按深度计数)是两回事,不要把权值和误当成距离来用。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
// Shorthand: memset the whole object `i` to byte value `j` (relies on `sizeof i`,
// so `i` must be a real array, not a pointer).
#define ms(i, j) memset(i, j, sizeof i)
// 64-bit signed integer used for indices, depths and weights throughout this file.
#define LL long long
// Shorthand for the container member push_back.
#define pb push_back
using namespace std;
// MAXN: capacity of every per-node array below.
// logs: highest exponent stored in the binary-lifting table `pre`.
const int MAXN = 20000 + 5, logs = 20;
// One directed half of an undirected edge: `v` is the node it leads to, `wi` its weight.
// `ed` is the 1-based pool of these records (two per undirected edge).
// NOTE(review): with `using namespace std;` and <vector> included, the name `data`
// can collide with std::data when compiling as C++17 or later — confirm on the target compiler.
struct data {LL v, wi;}ed[MAXN * 2];
// en        : number of records currently stored in `ed`.
// n         : number of nodes of the current test case.
// fa[u]     : parent of u as set by dfs (0 for the root).
// dep[u]    : depth of u, with dep[root] = 1 and dep[0] = 0 acting as a sentinel.
// pre[u][i] : ancestor of u that is 2^i steps up, or 0 when it does not exist.
// far[u]    : sum of edge weights on the path from the dfs root down to u.
LL en, n, fa[MAXN], dep[MAXN], pre[MAXN][logs + 1], far[MAXN];
// G[u] lists the indices into `ed` of the edge records leaving node u.
vector<LL> G[MAXN];
void clean() {
en = 0;
for (LL i = 0; i <= n; i++) {
far[i] = fa[i] = dep[i] = 0, G[i].clear();
for (LL j = 0; j <= logs; j++) pre[i][j] = 0;
}
}
// Adds the undirected edge u–v with weight c to the graph.
//
// Two directed records are appended to the 1-based pool `ed` (u->v, then v->u)
// and their indices are pushed onto the adjacency lists of u and v respectively.
//
// The fields are assigned directly instead of using `(data){v, c}`: that form is
// a GNU compound-literal extension rather than standard C++, and naming the type
// `data` unqualified can be ambiguous with std::data when built as C++17.
void ins(LL u, LL v, LL c) {
    // Record for the direction u -> v.
    ++en;
    ed[en].v = v;
    ed[en].wi = c;
    G[u].push_back(en);

    // Record for the direction v -> u.
    ++en;
    ed[en].v = u;
    ed[en].wi = c;
    G[v].push_back(en);
}
// Depth-first traversal that roots the tree and fills the per-node tables.
//
// u  : node being visited.
// pa : its parent (0 when u is the root; node 0 is the depth-0 sentinel).
//
// For u this sets fa, dep, the binary-lifting row pre[u][0..logs], and for each
// child the weighted distance from the root (far) before recursing into it.
void dfs(LL u, LL pa) {
    fa[u] = pa;
    pre[u][0] = pa;
    dep[u] = dep[pa] + 1;

    // The 2^level-th ancestor is the 2^(level-1)-th ancestor of the 2^(level-1)-th ancestor.
    for (LL level = 1; level <= logs; ++level) {
        const LL halfway = pre[u][level - 1];
        pre[u][level] = pre[halfway][level - 1];
    }

    for (const LL edgeId : G[u]) {
        const LL child = ed[edgeId].v;
        if (child == pa) continue;  // do not walk back up the edge we came from
        far[child] = far[u] + ed[edgeId].wi;
        dfs(child, u);
    }
}
// Returns the lowest common ancestor of nodes a and b using the binary-lifting
// table `pre` and the depths `dep` filled in by dfs.
LL getlca(LL a, LL b) {
    // Make `a` the deeper (or equally deep) node.
    if (dep[a] < dep[b]) std::swap(a, b);

    // Lift `a` until it sits on the same depth as `b`.
    for (LL level = logs; level >= 0; --level) {
        const LL up = pre[a][level];
        if (dep[up] >= dep[b]) a = up;
    }
    if (a == b) return a;

    // Lift both nodes together while their ancestors still differ; afterwards
    // they are the two distinct children of the LCA.
    for (LL level = logs; level >= 0; --level) {
        const LL upA = pre[a][level];
        const LL upB = pre[b][level];
        if (upA != upB) {
            a = upA;
            b = upB;
        }
    }
    return pre[a][0];
}
// Returns the sum of edge weights on the tree path between u and v, computed
// from the root-to-node weight sums `far` and the lowest common ancestor.
LL dist(LL u, LL v) {
    const LL top = getlca(u, v);
    const LL fromU = far[u] - far[top];  // weight of the u -> lca part
    const LL fromV = far[v] - far[top];  // weight of the v -> lca part
    return fromU + fromV;
}
// Returns the k-th node on the upward path starting at u, counting u itself as
// the first (k = 1 yields u, k = 2 its parent, and so on).
LL pc(LL u, LL k) {
    // Depth of the node we want to end on.
    const LL targetDepth = dep[u] - (k - 1);

    LL level = logs;
    while (level >= 0) {
        const LL up = pre[u][level];
        // Take the jump only if it does not overshoot the target depth.
        if (dep[up] >= targetDepth) u = up;
        --level;
    }
    return u;
}
// Returns the w-th node on the path from u to v (u is the 1st node).
//
// The path is split at the LCA: if the w-th node lies on the u -> lca part it is
// found by climbing from u; otherwise its position is converted to a position
// counted from v and found by climbing from v.
LL kth(LL u, LL v, LL w) {
    const LL top = getlca(u, v);

    // Number of nodes on the u -> lca segment, both ends included.
    const LL nodesOnUSide = dep[u] - dep[top] + 1;
    if (w <= nodesOnUSide) return pc(u, w);

    // Total nodes on the whole u -> v path; the w-th from u is the
    // (pathNodes - w + 1)-th when counted from v.
    const LL pathNodes = dep[u] + dep[v] - 2 * dep[top] + 1;
    return pc(v, pathNodes - w + 1);
}
void solve() {
scanf("%lld", &n);
clean();
for (LL x, y, c, i = 1; i < n; i++) {
scanf("%lld%lld%lld", &x, &y, &c);
ins(x, y, c);
}
dfs(1, 0);
char s[10];
while (true) {
scanf("%s", s);
if (s[1] == 'O') break;
if (s[1] == 'I') {
LL u, v;
scanf("%lld%lld", &u, &v);
printf("%lld\n", dist(u, v));
} else {
if (s[1] == 'T') {
LL u, v, k;
scanf("%lld%lld%lld", &u, &v, &k);
printf("%lld\n", kth(u, v, k));
}
}
}
}
// Entry point: reads the number of test cases T and runs solve() once per case.
// If T cannot be read, or is not positive, no test case is processed.
int main() {
    LL T = 0;  // stays 0 if the read below fails, so it is never used uninitialised
    if (scanf("%lld", &T) != 1) return 0;
    for (; T > 0; --T) solve();
    return 0;
}