Codeforces 834D
题意:给你长度为$n$的一个序列,让你将其分成连续的$k$段,每段的价值为其中数字种类的个数,求最大价值总和。 (This solution has been updated in August 14th, 2018.)
首先可以设$dp(k,i)$为前$i$个数分了$k$个箱子的最优解
显然$dp(k,i)=max(dp(k - 1, j-1) + c_{j, i})$,其中$c_{i, j}$为$[i,j]$中不同颜色的个数
我们可以用线段树维护$c$, 但是这样仍然是$O(kn^2logn)$,过不了
那么我们用线段树(区间下标$[1, j]$)维护$dp(k - 1, j-1) + c_{j, i}$的最大值,这样的话原方程就是$dp(k,i)=query(1, i)$。先用$dp(k - 1)$来建线段树,然后考虑$c_{j, i}$的计算。
我们设$lst_i$为第$i$个数前面第一个与这个数相同的数的位置(没有为$0$)。
我们一个个让数字加入,对于每一个加入的$i$而言$[lst_i + 1, i]$都需要加一,因为这一部分都被新加进来的$i$所影响了。
注意要边$update$边$query$,不然是错的,因为每次加进来的$i$只能影响$[1,i]$的区间值,$[i, n]$中的数不能影响。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 35000 + 5;
int n, k, ai[MAXN], lst[MAXN], pre[MAXN], dp[55][MAXN];
#define lc (o << 1)
#define rc (o << 1 | 1)
#define M ((l + r) >> 1)
int maxv[MAXN * 4], lazy[MAXN * 4];
void pushup(int o, int l, int r) {
if (l == r) return ;
maxv[o] = max(maxv[lc], maxv[rc]);
}
void pushdown(int o, int l, int r) {
if (l == r) return ;
if (lazy[o] > 0) {
lazy[lc] += lazy[o], lazy[rc] += lazy[o];
maxv[lc] += lazy[o], maxv[rc] += lazy[o];
lazy[o] = 0;
}
}
void build(int o, int l, int r) {
if (l == r) lazy[o] = 0, maxv[o] = ai[l]; else {
lazy[o] = maxv[o] = 0;
build(lc, l, M), build(rc, M + 1, r);
pushup(o, l, r);
}
}
void update(int o, int l, int r, int x, int y, int v) {
pushdown(o, l, r);
if (x <= l && r <= y) {
lazy[o] += v, maxv[o] += v;
return ;
}
if (x <= M) update(lc, l, M, x, y, v);
if (M < y) update(rc, M + 1, r, x, y, v);
pushup(o, l, r);
}
int query(int o, int l, int r, int x, int y) {
int ret = -1;
pushdown(o, l, r);
if (x <= l && r <= y) {
return maxv[o];
}
if (x <= M) ret = max(ret, query(lc, l, M, x, y));
if (M < y) ret = max(ret, query(rc, M + 1, r, x, y));
return ret;
}
void clean() {
ms(lst, 0), ms(pre, 0), ms(dp, 0);
}
void solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ai[i]);
for (int i = 1; i <= n; i++) lst[i] = pre[ai[i]], pre[ai[i]] = i;
for (int i = 1; i <= n; i++) dp[1][i] = dp[1][i - 1] + (lst[i] == 0);
for (int j = 2; j <= k; j++) {
for (int i = 1; i <= n; i++) ai[i] = dp[j - 1][i - 1];
build(1, 1, n);
for (int i = 1; i <= n; i++) {
update(1, 1, n, lst[i] + 1, i, 1);
dp[j][i] += query(1, 1, n, 1, i);
}
}
printf("%d\n", dp[k][n]);
}
int main() {
scanf("%d%d", &n, &k), solve();
return 0;
}