Bzoj 2152(点分治)

Bzoj 2152
题意:在树上独立随机地选两个点(可以相同),求它们之间的距离为 $3$ 的倍数的概率,以最简分数输出;即统计满足条件的有序点对数(含两点相同的情况),再除以 $n^2$。

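直接上点分治:对每个分治中心,用桶 $tax_r$ 记录到中心的距离模 $3$ 余 $r$ 的点数。由乘法原理,经过中心的合法有序点对数为 $tax_0^2 + 2\,tax_1 tax_2$。再减去两端落在同一棵子树内的点对(它们的路径并不经过中心),然后递归处理各子树即可。
距离是 $k$ 的倍数等价于距离模 $k$ 为 $0$,因此只需按余数分桶。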

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    // Capacity: vertex ids are 1..n with n < MAXN; every undirected tree edge is
    // stored as two directed arcs, hence ed[] has 2 * MAXN slots.
    const int MAXN = 20000 + 5;
    // Weight of the dummy vertex 0, so the first real vertex examined by getRt
    // always replaces rt = 0.
    const int INF = 2147483647;
    // A path is counted when its length is a multiple of K (this problem: K = 3).
    // The bucket formula in calc() is valid for any K >= 1.
    const int K = 3;

    // Directed arc: target v, weight w, id of the next arc out of the same
    // vertex (a value <= 0 ends the list; hd[] is initialised to -1).
    struct edge {
        int v, w, nxt;
    } ed[MAXN * 2];

    int n, en, hd[MAXN];
    // siz: subtree sizes inside the current component; vis: vertex already used
    // as a centroid; wt: size of the largest part left if the vertex is removed;
    // rt / Tsz: centroid found by getRt and the size of the component searched;
    // tax: number of vertices per distance residue modulo K.
    int siz[MAXN], vis[MAXN], wt[MAXN], rt, Tsz, tax[K];
    long long ans;

    // Reduces x into [0, K), also for negative x.
    int modK(int x) { return (x % K + K) % K; }

    // Appends arc u -> v of weight w to u's adjacency list.
    void ins(int u, int v, int w) { ed[++en] = edge{v, w, hd[u]}, hd[u] = en; }

    // Recomputes siz[] for the unvisited component containing u, rooted at u.
    void getSiz(int u, int fa) {
        siz[u] = 1;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            const edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) getSiz(e.v, u), siz[u] += siz[e.v];
        }
    }

    // Centroid search over the unvisited component containing u.
    // Preconditions: Tsz is the size of that component and rt == 0 (wt[0] == INF).
    void getRt(int u, int fa) {
        wt[u] = 0, siz[u] = 1;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            const edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) {
                getRt(e.v, u);
                siz[u] += siz[e.v];
                wt[u] = std::max(wt[u], siz[e.v]);
            }
        }
        wt[u] = std::max(wt[u], Tsz - siz[u]);  // the part "above" u in this traversal
        if (wt[rt] > wt[u]) rt = u;
    }

    // Adds every vertex of u's unvisited component to the bucket of its
    // distance residue; D is u's own distance, already reduced modulo K.
    void dfsD(int u, int fa, int D) {
        ++tax[D];
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            const edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) dfsD(e.v, u, modK(D + e.w));
        }
    }

    // Counts ordered pairs (a, b) of vertices in u's component (a == b allowed)
    // with dist(a) + dist(b) == 0 (mod K), where dist(u) = D.
    // Multiplication principle: sum_r tax[r] * tax[(K - r) % K]; for K = 3 this
    // is tax[0]^2 + 2 * tax[1] * tax[2].
    long long calc(int u, int D) {
        std::memset(tax, 0, sizeof tax);
        dfsD(u, 0, modK(D));
        long long res = 0;
        for (int r = 0; r < K; ++r) res += static_cast<long long>(tax[r]) * tax[(K - r) % K];
        return res;
    }

    // Divide and conquer at centroid u. First count all pairs of the component
    // measured through u. Then subtract the pairs whose ends lie in the same
    // child subtree: calc(v, w(u, v)) sees exactly the same distances to u, and
    // such paths do not really pass through u. Finally recurse into each child
    // component via its own centroid.
    void dfs(int u) {
        ans += calc(u, 0);
        vis[u] = 1;
        // siz[] left by the centroid search is rooted at the search start, not
        // at u; recompute so that siz[v] is the exact size of each child part.
        getSiz(u, 0);
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            const edge &e = ed[i];
            if (!vis[e.v]) {
                ans -= calc(e.v, e.w);
                rt = 0, Tsz = siz[e.v], getRt(e.v, u);
                dfs(rt);
            }
        }
    }

    // Euclid's algorithm; gcd(a, 0) == a.
    long long gcd(long long a, long long b) { return b == 0 ? a : gcd(b, a % b); }

    // Resets the adjacency lists and the per-test state.
    void clean() {
        ans = 0;
        en = 0;
        std::memset(hd, -1, sizeof hd);
        std::memset(vis, 0, sizeof vis);
    }

    // Runs the centroid decomposition over the tree on vertices 1..n that was
    // loaded with ins(). Returns the number of ordered vertex pairs (a, b),
    // a == b included, whose path length is a multiple of K. Requires n >= 1.
    long long countPairs() {
        ans = 0;
        std::memset(vis, 0, sizeof vis);
        rt = 0, wt[0] = INF, Tsz = n, getRt(1, 0);
        dfs(rt);  // start at the centroid just found, not at vertex 1
        return ans;
    }

    // Reads n and n - 1 lines "u v w" from stdin and prints, as a reduced
    // fraction "p/q", the probability that two independently chosen vertices
    // (possibly equal) are at a distance that is a multiple of K.
    // Prints nothing on malformed or out-of-range input.
    int solve() {
        if (std::scanf("%d", &n) != 1 || n < 1 || n >= MAXN) return 0;
        clean();
        for (int u, v, w, i = 1; i < n; ++i) {
            if (std::scanf("%d%d%d", &u, &v, &w) != 3) return 0;
            ins(u, v, w), ins(v, u, w);
        }
        const long long num = countPairs();
        const long long den = static_cast<long long>(n) * n;
        const long long g = gcd(num, den);  // num >= n > 0, so g > 0
        std::printf("%lld/%lld\n", num / g, den / g);
        return 0;
    }
}
int main() {
    // Input parsing, the divide-and-conquer and the output all live in the
    // solution namespace; its status code is deliberately not propagated.
    static_cast<void>(flyinthesky::solve());
    return 0;
}
/*
Sample input (expected output: 13/25):
5
1 2 1
1 3 3
2 4 2
2 5 1
*/
------ 本文结束 ------