BZOJ 5314
题意:见上。
设$dp(u,j,0/1,0/1)$为$u$子树放了$j$个,$u$放不放,$u$有没有信号的方案。
然后按照树形背包DP的方法转移即可,注意要用$dp[u][i+j][][] \Leftarrow dp[u][i][][], dp[v][j][][]$来保证复杂度。
这种树形背包dp就是形如$dp[x][j+k]=\sum dp[x][j] \times dp[y][k]$的
这里有对这样的树形背包dp复杂度为$O(nk)$的证明。证明用了一个基本性质即在$\text{LCA}$上算贡献,可以将复杂度$O(n^3)$分析到$O(n^2)$
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<queue>
#include<set>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
const int MAXN = 100000 + 5;
const LL MO = 1000000007;
int n, m, dp[MAXN][102][2][2], gg[102][2][2], siz[MAXN];
vector<int > G[MAXN];
void dfs(int u, int fa) {
siz[u] = 1;
dp[u][0][0][0] = 1, dp[u][1][1][0] = 1;
for (int o = 0; o < (int)G[u].size(); ++o) {
int v = G[u][o];
if (v != fa) {
dfs(v, u);
LL t = 0;
memcpy(gg, dp[u], sizeof dp[u]), ms(dp[u], 0);
for (int i = 0; i <= min(siz[v], m); ++i) {
for (int k = 0; k <= min(siz[u], m) && i + k <= m; ++k) {
int j = i + k;
t = 1ll * gg[k][0][0] * dp[v][j - k][0][1] % MO;
dp[u][j][0][0] = 1ll * (1ll * dp[u][j][0][0] + t) % MO;
t = 1ll * gg[k][0][1] * ((dp[v][j - k][0][1] + dp[v][j - k][1][1]) % MO) % MO;
t = (t + 1ll * gg[k][0][0] * dp[v][j - k][1][1]) % MO;
dp[u][j][0][1] = 1ll * (1ll * dp[u][j][0][1] + t) % MO;
t = 1ll * gg[k][1][0] * ((dp[v][j - k][0][0] + dp[v][j - k][0][1]) % MO) % MO;
dp[u][j][1][0] = 1ll * (1ll * dp[u][j][1][0] + t) % MO;
LL gg1 = (1ll * dp[v][j - k][0][0] + 1ll * dp[v][j - k][0][1] + 1ll * dp[v][j - k][1][0] + 1ll * dp[v][j - k][1][1]) % MO;
LL gg2 = (1ll * dp[v][j - k][1][0] + 1ll * dp[v][j - k][1][1]) % MO;
t = 1ll * gg[k][1][1] * gg1 % MO;
t = (t + 1ll * gg[k][1][0] * gg2 % MO) % MO;
dp[u][j][1][1] = 1ll * (1ll * dp[u][j][1][1] + t) % MO;
}
}
siz[u] += siz[v];
}
}
}
void clean() {
ms(gg, 0);
}
int solve() {
clean();
cin >> n >> m;
for (int x, y, i = 1; i < n; ++i) {
scanf("%d%d", &x, &y);
G[x].push_back(y);
G[y].push_back(x);
}
dfs(1, 0);
LL ans = dp[1][m][0][1] + dp[1][m][1][1];
ans %= MO;
printf("%lld\n", ans);
return 0;
}
}
int main() {
flyinthesky::solve();
return 0;
}