模板及讲解
点分治
点分治就是处理树上路径的一个树上分治算法。
其基本思路与分治类似。
模板题
给出一颗无根树,求树上距离小于等于 $k$ 的点对(路径)有多少个。
$n \leq 10000$
基本思路
点分治基本步骤:
1、DFS 每一棵子树 (以重心为根保证复杂度)
2、对当前 DFS 的子树进行求值 (calc
)
3、求值完删除当前重心继续往下 DFS 子树,顺便对子树进行求值后容斥
求值方法 (因题而异,用 Poj 1741 为例):
1、DFS 求出所有连通点到当前子树根的距离
2、将距离排序,用两点法等对信息进行处理
点分治核心函数是 calc
代码实现
函数
void dfs(int u)
:DFS 求解子树,根为 $u$int calc(int u, int D)
:计算以$u$为根子树的答案,$(fa[u], u)$边权为$D$void dfsD(int u, int D, int fa)
:计算以$u$为根子树节点到 $u$ (根)的距离数组,$D,fa$为临时局部变量void getRt(int u, int fa)
:求解以$u$为根的子树的重心代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
const int MAXN = 10000 + 5;
const int INF = 2147483647;
struct edge {
int v, w, nxt;
}ed[MAXN * 2];
int n, k, ans, en, hd[MAXN];
int vis[MAXN], siz[MAXN], rt, wt[MAXN], Tsz; // 访问数组 (点分治中删除的点设为 1),子树大小,当前根,以 i 为重心的最大深度(用作求重心),当前树大小
int arr[MAXN], cnt; // 距离数组 arr ,在 dfsD() 和 calc() 中使用
void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}
void getRt(int u, int fa) { // 找 u 子树的重心
siz[u] = 1; wt[u] = 0;
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (e.v != fa && !vis[e.v]) getRt(e.v, u), siz[u] += siz[e.v], wt[u] = max(wt[u], siz[e.v]);
}
wt[u] = max(wt[u], Tsz - siz[u]);
if(wt[rt] > wt[u]) rt = u;
}
void dfsD(int u, int D, int fa) { // DFS 求子树到 u (根)的距离数组 arr
arr[++cnt] = D;
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (e.v != fa && !vis[e.v]) dfsD(e.v, D + e.w, u);
}
}
int calc(int u, int D) { // 计算 u 子树的答案,(fa[u], u) 边权为 D
cnt = 0; dfsD(u, D, 0); // 计算 子树到 u (根)的距离数组 arr
int l = 1, r = cnt, sum = 0;
sort(arr + 1, arr + cnt + 1);
for( ; ; ++l) {
while (r && arr[l] + arr[r] > k) --r;
if(r < l) break;
sum += r - l + 1;
} // 两点法求答案
return sum;
}
void dfs(int u) { // DFS 求解
ans += calc(u, 0); // 求解当前树
vis[u] = 1; // 记录当前删除标记
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (!vis[e.v]) {
ans -= calc(e.v, e.w); // 容斥,减去无用答案
rt = 0, Tsz = siz[e.v], getRt(e.v, 0); // 找子树的重心作为子树根来进行下一步求解
dfs(rt);
}
}
}
void clean() {
ans = en = 0, ms(hd, -1), ms(vis, 0), ms(wt, 0);
}
int solve() {
clean();
for (int u, v, w, i = 1; i < n; ++i) scanf("%d%d%d", &u, &v, &w), ins(u, v, w), ins(v, u, w);
wt[0] = INF, Tsz = n, getRt(1, 0); // 第一次找重心作为根
dfs(rt);
printf("%d\n", ans - n); // 减去一个点单独成路径
return 0;
}
}
int main() {
while(scanf("%d%d", &flyinthesky::n, &flyinthesky::k) == 2 && (flyinthesky::n || flyinthesky::k)) flyinthesky::solve();
return 0;
}
例题
树上距离为 $ k $ 的点对
询问树上距离为 $ k $ 的点对是否存在。
解:直接上点分治,算答案的时候用桶存答案,然后乘法原理即可。
每次询问做一次点分治或者打表。注意边权为 0 的情况
Ps:CF 161D 数据更强
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
const int MAXN = 10000 + 5;
const int INF = 2147483647;
struct edge {
int v, w, nxt;
}ed[MAXN * 2];
int n, q, k, en, hd[MAXN];
int vis[MAXN], wt[MAXN], Tsz, rt, siz[MAXN];
int cnt, arr[MAXN], tax[10000005], whw[10000005];
LL ans;
void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}
void getRt(int u, int fa) {
wt[u] = 0, siz[u] = 1;
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (e.v != fa && !vis[e.v]) getRt(e.v, u), siz[u] += siz[e.v], wt[u] = max(wt[u], siz[e.v]);
}
wt[u] = max(wt[u], Tsz - siz[u]);
if (wt[rt] > wt[u]) rt = u;
}
void dfsD(int u, int fa, int D) {
arr[++cnt] = D;
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (e.v != fa && !vis[e.v]) dfsD(e.v, u, D + e.w);
}
}
LL calc(int u, int D) {
cnt = 0; dfsD(u, 0, D);
LL sum = 0;
for (int i = 1; i <= cnt; ++i) ++tax[arr[i]];
for (int i = 1; i <= cnt; ++i) {
if (!whw[arr[i]] && k >= arr[i]) {
if (arr[i] == k - arr[i]) sum += tax[arr[i]] * (tax[arr[i]] - 1) / 2; // 写对两两配对不重复的公式
else sum += tax[k - arr[i]] * tax[arr[i]];
whw[arr[i]] = 1, whw[k - arr[i]] = 1;
}
}
for (int i = 1; i <= cnt; ++i) --tax[arr[i]];
for (int i = 1; i <= cnt; ++i) {
whw[arr[i]] = 0;
if (k >= arr[i]) whw[k - arr[i]] = 0;
}
return sum;
}
void dfs(int u) {
ans += calc(u, 0);
vis[u] = 1;
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (!vis[e.v]) {
ans -= calc(e.v, e.w);
rt = 0, Tsz = siz[e.v], getRt(e.v, u);
dfs(rt);
}
}
}
void clean() {
en = 0, ms(hd, -1), ms(tax, 0);
}
int solve() {
scanf("%d%d", &n, &q);
clean();
for (int u, v, w, i = 1; i < n; ++i) scanf("%d%d%d", &u, &v, &w), ins(u, v, w), ins(v, u, w);
while (q--) {
ans = 0, ms(vis, 0);
scanf("%d", &k);
rt = 0, Tsz = n, wt[0] = INF, getRt(1, 0);
dfs(1);
printf("%s\n", ans > 0 ? "AYE" : "NAY");
}
return 0;
}
}
int main() {
flyinthesky::solve();
return 0;
}
/*
5 50
1 2 3
1 3 1
1 4 2
3 5 1
4
2 100
1 2 0
0
*/
树上距离模 $ k $ / 是 $ k $ 的倍数的点对
询问树上距离是 $ k $ 的倍数的点对。
解:直接上点分治,算答案的时候用桶存答案,然后乘法原理即可。
是 $ k $ 的倍数相当于模$ k $为 $ 0 $
不用容斥实现的点分治
给一棵树,每条边有权。求一条简单路径,权值和等于 $K$,且边的数量最小。
点分治模板题,但是这里维护的是不可加信息,不能用容斥的方法,我们求每个点的值的时候按一定顺序遍历他的儿子,然后将前面的存起来供后面的合并(类似树直径DP),然后就不用考虑容斥了,都是可行的合并。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<queue>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db long double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
const int MAXN = 200000 + 5, INF = 0x3f3f3f3f;
struct edge {
int v, w, nxt;
} ed[MAXN * 2];
int n, k, en, hd[MAXN];
int wt[MAXN], siz[MAXN], Tsz, rt;
int vis[MAXN];
int ans;
int tot, dis[MAXN], bs[MAXN];
void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}
void getRt(int u, int fa) { // 找 u 子树的重心
siz[u] = 1, wt[u] = 0;
for (int i = hd[u]; i > 0; i = ed[i].nxt) {
edge &e = ed[i];
if (e.v != fa && !vis[e.v]) getRt(e.v, u), siz[u] += siz[e.v], wt[u] = max(wt[u], siz[e.v]);
}
wt[u] = max(wt[u], Tsz - siz[u]);
if(wt[rt] > wt[u]) rt = u;
}
void getD(int u, int fa, int D, int b) {
if (D > k) return ;
dis[++tot] = D, bs[tot] = b;
for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
edge &e = ed[i];
if (e.v != fa && !vis[e.v]) getD(e.v, u, D + e.w, b + 1);
}
}
int tax[1000000 + 5];
void calc(int u) {
tot = 0, tax[0] = 0;
for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
edge &e = ed[i];
if (!vis[e.v]) {
int whw = tot;
getD(e.v, u, e.w, 1);
for (int j = whw + 1; j <= tot; ++j) ans = min(ans, tax[k - dis[j]] + bs[j]);
for (int j = whw + 1; j <= tot; ++j) tax[dis[j]] = min(tax[dis[j]], bs[j]);
}
}
for (int i = 1; i <= tot; ++i) tax[dis[i]] = INF;
}
void dfs(int u) {
vis[u] = 1;
calc(u);
for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
edge &e = ed[i];
if (!vis[e.v]) {
rt = 0, Tsz = siz[e.v], wt[0] = INF, getRt(e.v, 0);
dfs(rt); // rt
}
}
}
void clean() {
en = -1, ms(hd, -1);
ms(wt, 0), ms(siz, 0), ms(vis, 0);
ms(tax, 0x3f), ans = INF;
}
int solve() {
clean();
cin >> n >> k;
if (k == 0) return printf("0\n"), 0;
for (int x, y, w, i = 1; i < n; ++i) {
scanf("%d%d%d", &x, &y, &w);
++x, ++y;
ins(x, y, w), ins(y, x, w);
}
rt = 0, Tsz = n, wt[0] = INF, getRt(1, 0);
dfs(rt);
if (ans >= n) printf("-1\n"); else printf("%d\n", ans);
return 0;
}
}
int main() {
flyinthesky::solve();
return 0;
}
写时注意:
1、找完根后是dfs(rt)
而不是dfs(v)