「Loj 2478」「九省联考 2018」林克卡特树 (树形DP + 带权二分)

Loj 2478
题意:给定一棵 $n$ 个点的树,边权有正有负,要求在树上选出 $k+1$ 条链,使得其权值之和最大。

考虑DP。设$dp(i,j,0/1/2)$分别为$i$点子树$j$条完整链,当前 $i$ 节点的度数为 $0/1/2$ 的最大价值。度数为 $0$ 时,这个点没有链的连边。度数为 $1$ 时,这个点拖着一条未完成的链,而这条链不计入 $j$ 。度数为 $2$ 时,这个点被一条连接两个不同子树的链穿过,计入$j$。

转移见代码,状态非常经典重要。

考虑带权二分优化DP,即我们发现这个答案是关于$k, ans$的上凸函数,所以我们可以二分斜率切这个上凸函数,然后计算$k,ans$,每个物品都要减去二分的斜率值,然后直到找到极值才输出,注意上凸函数切点越左斜率越大

知识点:

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<queue>
#include<string>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 300000 + 5;

    struct data {
        LL v;
        int k;
        data(LL x = 0, LL y = 0) : v(x), k(y) {}
        data operator + (const data &rhs) const {return data(v + rhs.v, k + rhs.k);}
        data operator += (const data &rhs) {return *this = data(v + rhs.v, k + rhs.k);}
        data operator + (const int &rhs) const {return data(v + rhs, k);}
        data operator += (const int &rhs) {return *this = data(v + rhs, k);}
        bool operator < (const data &rhs) const {return v == rhs.v ? k > rhs.k : v < rhs.v;}
    } g;

    int n, k;
    data dp[MAXN][3];
    vector<int > G[MAXN], c[MAXN];

    void dfs(int u, int fa) {
        dp[u][0] = dp[u][1] = data();
        dp[u][2] = g;
        for (int i = 0; i < (int)G[u].size(); ++i) {
            int v = G[u][i], ci = c[u][i];
            if (v != fa) {
                dfs(v, u);
                dp[u][2] = max(dp[u][2], max(dp[u][2] + dp[v][0], dp[u][1] + dp[v][1] + g + ci));
                dp[u][1] = max(dp[u][1], max(dp[u][1] + dp[v][0], dp[u][0] + dp[v][1] + ci));
                dp[u][0] = max(dp[u][0], dp[u][0] + dp[v][0]);
            }
        }
        dp[u][0] = max(dp[u][0], max(dp[u][1] + g, dp[u][2]));
    }

    void clean() {
    }
    int solve() {

        clean();
        cin >> n >> k; ++k;
        for (int x, y, w, i = 1; i < n; ++i) {
            scanf("%d%d%d", &x, &y, &w);
            G[x].push_back(y), G[y].push_back(x);
            c[x].push_back(w), c[y].push_back(w);
        }
        LL l = -1e12, r = 1e12;
        while (l < r) {
            LL mid = (db)(l + r) / 2.0 - 0.5;
            g = data(-mid, 1);
            dfs(1, 0);
            if (dp[1][0].k == k) {
                return printf("%lld\n", dp[1][0].v + mid * k), 0;
            }
            if (dp[1][0].k < k) r = mid;
            else l = mid + 1;
        }
        g = data(-l, 1), dfs(1, 0);
        printf("%lld\n", dp[1][0].v + l * k);

        return 0;
    }  
}
int main() {
    flyinthesky::solve();
    return 0;
}
------ 本文结束 ------