「Bzoj 4033」「HAOI2015」树上染色 (贡献型树形DP)

Bzoj 4033
题意:有一棵点数为$N$的树,树边有边权。给你一个在$[0,N]$之内的正整数$K$,你要在这棵树中选择$K$个点,将其染成黑色,并将其他的$N-K$个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。

本题抛开颜色就是一个树边贡献模板题。
然后这里加入颜色的话就肯定要DP,DP方程受前面的启发设为
$dp(u,j)$为以$u$点为根子树,有$j$个黑点对答案的贡献。
然后转移就是
$$
dp(u,j)=\max(dp(u, j - k)+dp(v, k)+val \times w)
$$

其中$w$为贡献,$val$为
$$
val=k \cdot (m - k) + (siz(v)-k) \cdot (n - m - (sz(v) - k))
$$

即考虑边$(u, v)$的贡献,左边有几个黑点,右边有几个白点,同色点相乘,异色点相加即可。

然后注意细节,$dp$初始化为$-∞$,然后要判
if (dp[u][j - k] >= 0)

并且每个循环的上界都要加上$siz$的限制,以确保复杂度大概为$O(n^2)$

for (int j = min(siz[u], m); j >= 0; --j)
for (int k = 0; k <= min(siz[e.v], j); ++k)
#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 2000 + 5;

    struct edge {
        int v, w, nxt;
    } ed[MAXN * 2];

    int n, m, en, hd[MAXN], siz[MAXN];
    LL dp[MAXN][MAXN];

    void dfs(int u, int fa) {
        siz[u] = 1, dp[u][0] = dp[u][1] = 0;
        for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa) {
                dfs(e.v, u);
                siz[u] += siz[e.v];
                for (int j = min(siz[u], m); j >= 0; --j) { //
                    for (int k = 0; k <= min(siz[e.v], j); ++k) { //
                        LL val = 1ll * k * (m - k) + 1ll * (siz[e.v] - k) * (n - m - (siz[e.v] - k));
                        if (dp[u][j - k] >= 0) //
                            dp[u][j] = max(dp[u][j], dp[u][j - k] + dp[e.v][k] + val * e.w);
                    }
                }
            }
        }
    }

    void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}

    void clean() {
        ms(hd, -1), en = -1, ms(dp, -1);
    }
    int solve() {

        clean();
        scanf("%d%d", &n, &m);
        for (int u, v, w, i = 1; i < n; ++i) {
               scanf("%d%d%d", &u, &v, &w);
            ins(u, v, w), ins(v, u, w);
        }

        dfs(1, 0);

        cout << dp[1][m];

        return 0;
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
------ 本文结束 ------