BZOJ 1858
题意:给定一个长度为$n$的$01$序列(位置从$0$开始编号,即$0 \sim n-1$),每个数要么是$0$,要么是$1$。对这个序列共有五种操作,其中三种为修改操作,两种为询问操作:
- `0 a b`:把$[a, b]$区间内的所有数全变成$0$
- `1 a b`:把$[a, b]$区间内的所有数全变成$1$
- `2 a b`:把$[a, b]$区间内的所有数全部取反,也就是说把所有的$0$变成$1$,把所有的$1$变成$0$
- `3 a b`:询问$[a, b]$区间内总共有多少个$1$
- `4 a b`:询问$[a, b]$区间内最多有多少个连续的$1$
与 CF 817F 类似,都需要维护$01$序列上的区间赋值和区间翻转。
但本题还要求回答区间内最长连续$1$串的长度,因此每个结点需要维护更多信息。
每个线段树结点维护:
$sum$:区间内$1$的个数
$lsum[0/1]$:从区间左端点开始的最长连续$0$串 / 连续$1$串的长度
$rsum[0/1]$:以区间右端点结尾的最长连续$0$串 / 连续$1$串的长度
$maxsum[0/1]$:区间内任意位置的最长连续$0$串 / 连续$1$串的长度
$upd$:区间赋值的$lazy$标记($-1$表示没有标记,否则为要赋成的值$0$或$1$)
$rev$:区间翻转的$lazy$标记
翻转区间时,$lsum[0]$与$lsum[1]$互换,$rsum$、$maxsum$同理,而$sum$变为区间长度减去$sum$。若结点上已有赋值标记,翻转等价于把赋值标记取反。下传标记时先处理赋值标记(同时清除翻转标记),再处理翻转标记,因为后来的赋值会覆盖之前的翻转。
对于询问$4$的处理,我们分成三种情况讨论
1、在左子树
2、在右子树
3、在中间
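注意第三种情况(跨越中点$mid=\lfloor (l+r)/2 \rfloor$的连续$1$串)也要限制在$[x,y]$内:左半部分取$\min(mid-x+1,\ rsum[lc][1])$,右半部分取$\min(y-mid,\ lsum[rc][1])$,两者相加即为跨中点的答案;只有当$x \le mid < y$时这种情况才存在。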
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<queue>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db long double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
// Segment tree over a 0/1 sequence a[0..n-1]; the root is node 1 and covers
// [0, n-1]. Supported operations: assign a range to 0 or 1, flip a range,
// count the 1s in a range, and find the longest run of consecutive 1s in a range.
const int MAXN = 100000 + 5;
int n, q, a[MAXN];

// Per-node data; 4 * MAXN slots suffice for a recursive tree rooted at 1.
int sumv[MAXN * 4];      // number of 1s in the node's interval
int upd[MAXN * 4];       // assignment not yet pushed to the children: -1 = none, else 0 or 1
int rev[MAXN * 4];       // flip not yet pushed to the children (always 0 while upd != -1)
int lsum[MAXN * 4][2];   // lsum[o][b]: longest run of b starting at the interval's left end
int rsum[MAXN * 4][2];   // rsum[o][b]: longest run of b ending at the interval's right end
int maxsum[MAXN * 4][2]; // maxsum[o][b]: longest run of b anywhere in the interval

// Child indices of node o.
inline int left_child(int o) { return o << 1; }
inline int right_child(int o) { return o << 1 | 1; }

// An interval [l, r] is split at mid = (l + r) >> 1, so for len = r - l + 1
// the left half holds ceil(len / 2) positions and the right half floor(len / 2).
inline int left_len(int len) { return len - len / 2; }
inline int right_len(int len) { return len / 2; }

// Set node o's statistics as if all of its len positions equal v.
// Precondition: v is 0 or 1 (it is used as an index).
void fill_uniform(int o, int len, int v) {
    sumv[o] = len * v;
    maxsum[o][v] = lsum[o][v] = rsum[o][v] = len;
    maxsum[o][v ^ 1] = lsum[o][v ^ 1] = rsum[o][v ^ 1] = 0;
}

// Assign v (0 or 1) to node o's whole interval and record it as a lazy tag.
// A new assignment supersedes any earlier, not-yet-pushed flip.
void apply_set(int o, int len, int v) {
    fill_uniform(o, len, v);
    upd[o] = v;
    rev[o] = 0;
}

// Flip every bit in node o's interval and record the flip lazily.
void apply_flip(int o, int len) {
    if (upd[o] != -1) {
        // The interval is uniform, so flipping it equals assigning the other value.
        apply_set(o, len, upd[o] ^ 1);
        return;
    }
    rev[o] ^= 1;
    sumv[o] = len - sumv[o];
    std::swap(lsum[o][0], lsum[o][1]);
    std::swap(rsum[o][0], rsum[o][1]);
    std::swap(maxsum[o][0], maxsum[o][1]);
}

// Recompute node o (interval length len) from its two children.
void pushup(int o, int len) {
    const int lc = left_child(o), rc = right_child(o);
    const int llen = left_len(len), rlen = right_len(len);
    sumv[o] = sumv[lc] + sumv[rc];
    for (int b = 0; b <= 1; ++b) {
        // A prefix/suffix run spills into the sibling only if it fills its half.
        lsum[o][b] = (lsum[lc][b] == llen) ? llen + lsum[rc][b] : lsum[lc][b];
        rsum[o][b] = (rsum[rc][b] == rlen) ? rlen + rsum[lc][b] : rsum[rc][b];
        maxsum[o][b] = std::max(std::max(maxsum[lc][b], maxsum[rc][b]),
                                rsum[lc][b] + lsum[rc][b]);
    }
}

// Push node o's pending tags (interval length len) down to its children.
// The assignment tag is applied before the flip tag, because an assignment
// recorded on o always happened after any flip it replaced.
void pushdown(int o, int len) {
    if (len == 1) return;  // leaves have no children
    const int lc = left_child(o), rc = right_child(o);
    if (upd[o] != -1) {
        apply_set(lc, left_len(len), upd[o]);
        apply_set(rc, right_len(len), upd[o]);
        upd[o] = -1;
        rev[o] = 0;
    }
    if (rev[o]) {
        apply_flip(lc, left_len(len));
        apply_flip(rc, right_len(len));
        rev[o] = 0;
    }
}

// Build node o over [l, r] from a[]; any non-zero a[i] is treated as 1.
void build(int o, int l, int r) {
    upd[o] = -1, rev[o] = 0;
    if (l == r) {
        fill_uniform(o, 1, a[l] != 0 ? 1 : 0);
        return;
    }
    const int mid = (l + r) >> 1;
    build(left_child(o), l, mid);
    build(right_child(o), mid + 1, r);
    pushup(o, r - l + 1);
}

// Assign v to every position in [x, y]; any non-zero v is treated as 1.
void update(int o, int l, int r, int x, int y, int v) {
    pushdown(o, r - l + 1);
    if (x <= l && r <= y) {
        apply_set(o, r - l + 1, v != 0 ? 1 : 0);
        return;
    }
    const int mid = (l + r) >> 1;
    if (x <= mid) update(left_child(o), l, mid, x, y, v);
    if (mid < y) update(right_child(o), mid + 1, r, x, y, v);
    pushup(o, r - l + 1);
}

// Flip every position in [x, y].
void reverse(int o, int l, int r, int x, int y) {
    pushdown(o, r - l + 1);
    if (x <= l && r <= y) {
        apply_flip(o, r - l + 1);
        return;
    }
    const int mid = (l + r) >> 1;
    if (x <= mid) reverse(left_child(o), l, mid, x, y);
    if (mid < y) reverse(right_child(o), mid + 1, r, x, y);
    pushup(o, r - l + 1);
}

// Return the number of 1s in [x, y].
int query_sum(int o, int l, int r, int x, int y) {
    pushdown(o, r - l + 1);
    if (x <= l && r <= y) return sumv[o];
    const int mid = (l + r) >> 1;
    int total = 0;
    if (x <= mid) total += query_sum(left_child(o), l, mid, x, y);
    if (mid < y) total += query_sum(right_child(o), mid + 1, r, x, y);
    return total;
}

// Return the length of the longest run of consecutive 1s inside [x, y].
int query_max(int o, int l, int r, int x, int y) {
    pushdown(o, r - l + 1);
    if (x <= l && r <= y) return maxsum[o][1];
    const int mid = (l + r) >> 1;
    const int lc = left_child(o), rc = right_child(o);
    int best = 0;
    if (x <= mid) best = std::max(best, query_max(lc, l, mid, x, y));
    if (mid < y) best = std::max(best, query_max(rc, mid + 1, r, x, y));
    if (x <= mid && mid < y) {
        // Run crossing the midpoint, clipped to [x, y] on both sides.
        const int cross = std::min(mid - x + 1, rsum[lc][1]) +
                          std::min(y - mid, lsum[rc][1]);
        best = std::max(best, cross);
    }
    return best;
}

// Per-test-case reset hook; nothing to reset for a single test case.
void clean() {
}

// Read "n q", then n values, then q lines "tp x y" (0-indexed positions).
// tp 0/1 assigns, tp 2 flips, and tp 3/4 print the count of 1s or the longest run of 1s.
int solve() {
    clean();
    // n <= 0 would make build() recurse forever; n > MAXN would overflow a[].
    if (scanf("%d%d", &n, &q) != 2 || n <= 0 || n > MAXN) return 0;
    for (int i = 0; i < n; ++i) {
        if (scanf("%d", &a[i]) != 1) return 0;
    }
    build(1, 0, n - 1);
    while (q--) {
        int tp, x, y;
        // Stop on truncated input instead of acting on uninitialized values.
        if (scanf("%d%d%d", &tp, &x, &y) != 3) break;
        switch (tp) {
            case 0: update(1, 0, n - 1, x, y, 0); break;
            case 1: update(1, 0, n - 1, x, y, 1); break;
            case 2: reverse(1, 0, n - 1, x, y); break;
            case 3: printf("%d\n", query_sum(1, 0, n - 1, x, y)); break;
            case 4: printf("%d\n", query_max(1, 0, n - 1, x, y)); break;
            default: break;
        }
    }
    return 0;
}
}
// Program entry point: all input parsing and output happen inside the solver,
// whose status code (always 0) becomes the process exit status.
int main() {
    return flyinthesky::solve();
}