Loj 2478
题意:给定一棵 $n$ 个点的树,边权有正有负,要求在树上选出 $k+1$ 条链,使得其权值之和最大。
考虑DP。设$dp(i,j,0/1/2)$分别为$i$点子树$j$条完整链,当前 $i$ 节点的度数为 $0/1/2$ 的最大价值。度数为 $0$ 时,这个点没有链的连边。度数为 $1$ 时,这个点拖着一条未完成的链,而这条链不计入 $j$ 。度数为 $2$ 时,这个点被一条连接两个不同子树的链穿过,计入$j$。
转移见代码,状态非常经典重要。
考虑带权二分优化DP,即我们发现这个答案是关于$k, ans$的上凸函数,所以我们可以二分斜率切这个上凸函数,然后计算$k,ans$,每个物品都要减去二分的斜率值,然后直到找到极值才输出,注意上凸函数切点越左斜率越大
知识点:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<queue>
#include<string>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
const int MAXN = 300000 + 5;
struct data {
LL v;
int k;
data(LL x = 0, LL y = 0) : v(x), k(y) {}
data operator + (const data &rhs) const {return data(v + rhs.v, k + rhs.k);}
data operator += (const data &rhs) {return *this = data(v + rhs.v, k + rhs.k);}
data operator + (const int &rhs) const {return data(v + rhs, k);}
data operator += (const int &rhs) {return *this = data(v + rhs, k);}
bool operator < (const data &rhs) const {return v == rhs.v ? k > rhs.k : v < rhs.v;}
} g;
int n, k;
data dp[MAXN][3];
vector<int > G[MAXN], c[MAXN];
void dfs(int u, int fa) {
dp[u][0] = dp[u][1] = data();
dp[u][2] = g;
for (int i = 0; i < (int)G[u].size(); ++i) {
int v = G[u][i], ci = c[u][i];
if (v != fa) {
dfs(v, u);
dp[u][2] = max(dp[u][2], max(dp[u][2] + dp[v][0], dp[u][1] + dp[v][1] + g + ci));
dp[u][1] = max(dp[u][1], max(dp[u][1] + dp[v][0], dp[u][0] + dp[v][1] + ci));
dp[u][0] = max(dp[u][0], dp[u][0] + dp[v][0]);
}
}
dp[u][0] = max(dp[u][0], max(dp[u][1] + g, dp[u][2]));
}
void clean() {
}
int solve() {
clean();
cin >> n >> k; ++k;
for (int x, y, w, i = 1; i < n; ++i) {
scanf("%d%d%d", &x, &y, &w);
G[x].push_back(y), G[y].push_back(x);
c[x].push_back(w), c[y].push_back(w);
}
LL l = -1e12, r = 1e12;
while (l < r) {
LL mid = (db)(l + r) / 2.0 - 0.5;
g = data(-mid, 1);
dfs(1, 0);
if (dp[1][0].k == k) {
return printf("%lld\n", dp[1][0].v + mid * k), 0;
}
if (dp[1][0].k < k) r = mid;
else l = mid + 1;
}
g = data(-l, 1), dfs(1, 0);
printf("%lld\n", dp[1][0].v + l * k);
return 0;
}
}
int main() {
flyinthesky::solve();
return 0;
}