点分治 学习笔记

模板及讲解

点分治

点分治就是处理树上路径的一个树上分治算法。
其基本思路与分治类似。

模板题

Poj 1741

给出一颗无根树,求树上距离小于等于 $k$ 的点对(路径)有多少个。
$n \leq 10000$

基本思路

点分治基本步骤:
1、DFS 每一棵子树 (以重心为根保证复杂度)
2、对当前 DFS 的子树进行求值 (calc)
3、求值完删除当前重心继续往下 DFS 子树,顺便对子树进行求值后容斥

求值方法 (因题而异,用 Poj 1741 为例):
1、DFS 求出所有连通点到当前子树根的距离
2、将距离排序,用两点法等对信息进行处理

点分治核心函数是 calc

代码实现

函数

  • void dfs(int u):DFS 求解子树,根为 $u$
  • int calc(int u, int D):计算以$u$为根子树的答案,$(fa[u], u)$边权为$D$
  • void dfsD(int u, int D, int fa):计算以$u$为根子树节点到 $u$ (根)的距离数组,$D,fa$为临时局部变量
  • void getRt(int u, int fa):求解以$u$为根的子树的重心

    代码

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 10000 + 5;
    const int INF = 2147483647;

    struct edge {
        int v, w, nxt;
    }ed[MAXN * 2];

    int n, k, ans, en, hd[MAXN]; 
    int vis[MAXN], siz[MAXN], rt, wt[MAXN], Tsz; // 访问数组 (点分治中删除的点设为 1),子树大小,当前根,以 i 为重心的最大深度(用作求重心),当前树大小 
    int arr[MAXN], cnt; // 距离数组 arr ,在 dfsD() 和 calc() 中使用 

    void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}

    void getRt(int u, int fa) { // 找 u 子树的重心 
        siz[u] = 1; wt[u] = 0;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) getRt(e.v, u), siz[u] += siz[e.v], wt[u] = max(wt[u], siz[e.v]);
        }
        wt[u] = max(wt[u], Tsz - siz[u]);
        if(wt[rt] > wt[u]) rt = u;
    }
    void dfsD(int u, int D, int fa) { // DFS 求子树到 u (根)的距离数组 arr 
        arr[++cnt] = D;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) dfsD(e.v, D + e.w, u);
        }
    }
    int calc(int u, int D) { // 计算 u 子树的答案,(fa[u], u) 边权为 D 
        cnt = 0; dfsD(u, D, 0); // 计算 子树到 u (根)的距离数组 arr 
        int l = 1, r = cnt, sum = 0;
        sort(arr + 1, arr + cnt + 1);
        for( ; ; ++l) {
            while (r && arr[l] + arr[r] > k) --r;
            if(r < l) break;
            sum += r - l + 1;
        } // 两点法求答案 
        return sum;
    }
    void dfs(int u) { // DFS 求解 
        ans += calc(u, 0); // 求解当前树 
        vis[u] = 1; // 记录当前删除标记 
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (!vis[e.v]) {
                ans -= calc(e.v, e.w); // 容斥,减去无用答案 
                rt = 0, Tsz = siz[e.v], getRt(e.v, 0); // 找子树的重心作为子树根来进行下一步求解 
                dfs(rt);
            }
        }
    }

    void clean() {
        ans = en = 0, ms(hd, -1), ms(vis, 0), ms(wt, 0);
    }
    int solve() {
        clean();
        for (int u, v, w, i = 1; i < n; ++i) scanf("%d%d%d", &u, &v, &w), ins(u, v, w), ins(v, u, w);
        wt[0] = INF, Tsz = n, getRt(1, 0); // 第一次找重心作为根 
        dfs(rt);
        printf("%d\n", ans - n); // 减去一个点单独成路径 
        return 0; 
    }
}
int main() {
    while(scanf("%d%d", &flyinthesky::n, &flyinthesky::k) == 2 && (flyinthesky::n || flyinthesky::k)) flyinthesky::solve();
    return 0;
}

例题

树上距离为 $ k $ 的点对

Luogu 3806

询问树上距离为 $ k $ 的点对是否存在。

解:直接上点分治,算答案的时候用桶存答案,然后乘法原理即可。
每次询问做一次点分治或者打表。注意边权为 0 的情况

Ps:CF 161D 数据更强

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 10000 + 5;
    const int INF = 2147483647; 

    struct edge {
        int v, w, nxt;
    }ed[MAXN * 2];

    int n, q, k, en, hd[MAXN];
    int vis[MAXN], wt[MAXN], Tsz, rt, siz[MAXN];
    int cnt, arr[MAXN], tax[10000005], whw[10000005];
    LL ans;

    void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}

    void getRt(int u, int fa) {
        wt[u] = 0, siz[u] = 1;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) getRt(e.v, u), siz[u] += siz[e.v], wt[u] = max(wt[u], siz[e.v]);
        }
        wt[u] = max(wt[u], Tsz - siz[u]);
        if (wt[rt] > wt[u]) rt = u;
    }
    void dfsD(int u, int fa, int D) {
        arr[++cnt] = D;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) dfsD(e.v, u, D + e.w);
        }
    }
    LL calc(int u, int D) {
        cnt = 0; dfsD(u, 0, D);
        LL sum = 0;
        for (int i = 1; i <= cnt; ++i) ++tax[arr[i]];
        for (int i = 1; i <= cnt; ++i) {
            if (!whw[arr[i]] && k >= arr[i]) {
                if (arr[i] == k - arr[i]) sum += tax[arr[i]] * (tax[arr[i]] - 1) / 2; // 写对两两配对不重复的公式
                else sum += tax[k - arr[i]] * tax[arr[i]];
                whw[arr[i]] = 1, whw[k - arr[i]] = 1;
            }
        }
        for (int i = 1; i <= cnt; ++i) --tax[arr[i]];
        for (int i = 1; i <= cnt; ++i) {
            whw[arr[i]] = 0;
            if (k >= arr[i]) whw[k - arr[i]] = 0;
        }
        return sum;
    }
    void dfs(int u) {
        ans += calc(u, 0);
        vis[u] = 1;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (!vis[e.v]) {
                ans -= calc(e.v, e.w);
                rt = 0, Tsz = siz[e.v], getRt(e.v, u);
                dfs(rt);
            }
        }
    }

    void clean() {
        en = 0, ms(hd, -1), ms(tax, 0);
    }
    int solve() {
        scanf("%d%d", &n, &q);
        clean();
        for (int u, v, w, i = 1; i < n; ++i) scanf("%d%d%d", &u, &v, &w), ins(u, v, w), ins(v, u, w);
        while (q--) {
            ans = 0, ms(vis, 0);
            scanf("%d", &k);
            rt = 0, Tsz = n, wt[0] = INF, getRt(1, 0);
            dfs(1);
            printf("%s\n", ans > 0 ? "AYE" : "NAY");
        }
        return 0; 
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
/*
5 50
1 2 3
1 3 1
1 4 2
3 5 1
4

2 100
1 2 0
0
*/

树上距离模 $ k $ / 是 $ k $ 的倍数的点对

Bzoj 2152

询问树上距离是 $ k $ 的倍数的点对。

解:直接上点分治,算答案的时候用桶存答案,然后乘法原理即可。
是 $ k $ 的倍数相当于模$ k $为 $ 0 $

不用容斥实现的点分治

BZOJ 2559

给一棵树,每条边有权。求一条简单路径,权值和等于 $K$,且边的数量最小。

点分治模板题,但是这里维护的是不可加信息,不能用容斥的方法,我们求每个点的值的时候按一定顺序遍历他的儿子,然后将前面的存起来供后面的合并(类似树直径DP),然后就不用考虑容斥了,都是可行的合并。

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<queue>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db long double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 200000 + 5, INF = 0x3f3f3f3f;

    struct edge {
        int v, w, nxt;
    } ed[MAXN * 2];

    int n, k, en, hd[MAXN];
    int wt[MAXN], siz[MAXN], Tsz, rt;
    int vis[MAXN];
    int ans;
    int tot, dis[MAXN], bs[MAXN];

    void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}

    void getRt(int u, int fa) { // 找 u 子树的重心 
        siz[u] = 1, wt[u] = 0;
        for (int i = hd[u]; i > 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) getRt(e.v, u), siz[u] += siz[e.v], wt[u] = max(wt[u], siz[e.v]);
        }
        wt[u] = max(wt[u], Tsz - siz[u]);
        if(wt[rt] > wt[u]) rt = u;
    }
    void getD(int u, int fa, int D, int b) {
        if (D > k) return ;
        dis[++tot] = D, bs[tot] = b;
        for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa && !vis[e.v]) getD(e.v, u, D + e.w, b + 1);
        }
    }
    int tax[1000000 + 5];
    void calc(int u) {
        tot = 0, tax[0] = 0;
        for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (!vis[e.v]) {
                int whw = tot;
                getD(e.v, u, e.w, 1);
                for (int j = whw + 1; j <= tot; ++j) ans = min(ans, tax[k - dis[j]] + bs[j]);
                for (int j = whw + 1; j <= tot; ++j) tax[dis[j]] = min(tax[dis[j]], bs[j]);
            }
        }
        for (int i = 1; i <= tot; ++i) tax[dis[i]] = INF;
    }
    void dfs(int u) {
        vis[u] = 1;
        calc(u);
        for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (!vis[e.v]) {
                rt = 0, Tsz = siz[e.v], wt[0] = INF, getRt(e.v, 0);
                dfs(rt); // rt
            }
        }
    }

    void clean() {
        en = -1, ms(hd, -1);
        ms(wt, 0), ms(siz, 0), ms(vis, 0);
        ms(tax, 0x3f), ans = INF;
    }
    int solve() {

        clean();
        cin >> n >> k;
        if (k == 0) return printf("0\n"), 0;
        for (int x, y, w, i = 1; i < n; ++i) {
            scanf("%d%d%d", &x, &y, &w);
            ++x, ++y;
            ins(x, y, w), ins(y, x, w);
        }
        rt = 0, Tsz = n, wt[0] = INF, getRt(1, 0);
        dfs(rt);

        if (ans >= n) printf("-1\n"); else printf("%d\n", ans);

        return 0;
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}

写时注意:
1、找完根后是dfs(rt)而不是dfs(v)

常见题型

------ 本文结束 ------