Bzoj 2243(树链剖分+线段树)

BZOJ 2243
这题就是caioj 1102树上版。
所以我们树剖然后线段树维护

  • $lcol$:区间左端点的颜色
  • $rcol$:区间右端点的颜色
  • $cnt$:区间内连续同色段的数量

然后合并区间的时候如果中间颜色相同,要$cnt-1$
查询同理:如果查询范围同时落在左右两个子区间内,则要判断交界处两个端点的颜色是否相同,相同则答案减一

之后树剖的合并就比较麻烦。要分两边来存上一个区间端点颜色是什么,然后最后$u$到$v$的修改也要讨论,这个仔细想想就行,但是容易写错,建议画一下,详情看代码,不太好讲

注意:同一条重链上,深度越大的点在线段树中的位置越靠右,越浅的点越靠左;另外$pushdown$时记得把$lazy$标记也传给子节点

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
// Shorthand for memset over the whole object `i` (sizeof i bytes) with byte value `j`.
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
using namespace std;
// Capacity of every per-vertex array: 100000 plus a small margin.
const int MAXN = 100000 + 5;
// n, Q : number of vertices and number of operations (read in main()).
// col  : color of each vertex (read in solve()).
// dep  : depth of each vertex, fa : parent, siz : subtree size,
// son  : child with the largest subtree, or -1 when there is none (all set in dfs1()).
// top  : head of the chain the vertex belongs to,
// p    : position of the vertex in the linear order built by dfs2(),
// np   : inverse of p (position -> vertex), tb : last position handed out.
int n, Q, col[MAXN], dep[MAXN], fa[MAXN], son[MAXN], siz[MAXN], top[MAXN], p[MAXN], np[MAXN], tb;
// Adjacency lists of the tree (edges are stored in both directions by solve()).
vector<int> G[MAXN];
void dfs1(int u, int pa) {
    dep[u] = dep[pa] + 1, fa[u] = pa, siz[u] = 1;
    for (int i = 0; i < (int)G[u].size(); i++) {
        int v = G[u][i];
        if (v != pa) {
            dfs1(v, u);
            siz[u] += siz[v];
            if (son[u] == -1 || siz[v] > siz[son[u]]) son[u] = v;
        }
    }
}
/**
 * Second pass of the heavy-light decomposition.
 *
 * Assigns u the next free position (p / np), records the head of its chain
 * in top[u], then continues the same chain through son[u] before starting a
 * new chain at every other child. Visiting the heavy child first keeps each
 * chain contiguous in position order.
 *
 * @param u      vertex being visited
 * @param chain  head of the chain that u belongs to
 */
void dfs2(int u, int chain) {
    top[u] = chain;
    ++tb;
    p[u] = tb;
    np[tb] = u;

    const int heavy = son[u];
    if (heavy == -1) return;  // no heavy child recorded: nothing below to number

    dfs2(heavy, chain);
    for (size_t k = 0; k < G[u].size(); ++k) {
        const int w = G[u][k];
        if (w == fa[u] || w == heavy) continue;
        dfs2(w, w);  // every light child heads its own chain
    }
}
// Segment-tree shorthands; they expand using the local names o, l and r:
// left child index, right child index, and midpoint of [l, r].
#define lc (o << 1)
#define rc (o << 1 | 1)
#define M ((l + r) >> 1)
// Per-node data: color at the left end of the interval, color at the right
// end, number of maximal same-color runs inside it, and the pending
// range-assignment tag (0 means no tag is pending).
int lcol[MAXN * 4], rcol[MAXN * 4], cnt[MAXN * 4], lazy[MAXN * 4];
void pushup(int o, int l, int r) {
    if (l == r) return ;
    lcol[o] = lcol[lc], rcol[o] = rcol[rc], cnt[o] = cnt[rc] + cnt[lc];
    if (rcol[lc] == lcol[rc]) cnt[o]--;
}
void pushdown(int o, int l, int r) {
    if (l == r) return ;
    if (lazy[o]) {
        lazy[lc] = lazy[rc] = lazy[o];
        lcol[lc] = lcol[rc] = lazy[o];
        rcol[lc] = rcol[rc] = lazy[o];
        cnt[lc] = cnt[rc] = 1;
        lazy[o] = 0;
    }
}
void build(int o, int l, int r) {
    if (l == r) lcol[o] = rcol[o] = col[np[l]], cnt[o] = 1; else {
        build(lc, l, M), build(rc, M + 1, r);
        pushup(o, l, r);
    }
}
void update(int o, int l, int r, int x, int y, int v) {
    pushdown(o, l, r);
    if (x <= l && r <= y) {
        lazy[o] = v;
        lcol[o] = rcol[o] = v, cnt[o] = 1;
        return ;
    }
    if (x <= M) update(lc, l, M, x, y, v);
    if (M < y) update(rc, M + 1, r, x, y, v);
    pushup(o, l, r);
}
// Colors found at the two ends of the most recent query range: nowl is the
// color at position x, nowr the color at position y. Callers reset them
// before calling query().
int nowl, nowr;
/**
 * Counts the maximal same-color runs inside positions [x, y].
 *
 * Side effect: stores the color at x in nowl and the color at y in nowr
 * (each is written by the fully covered node that touches that end).
 *
 * @param o,l,r  current node and the interval it covers
 * @param x,y    query range (inclusive)
 * @return number of runs in the part of [x, y] that lies inside [l, r]
 */
int query(int o, int l, int r, int x, int y) {
    pushdown(o, l, r);
    if (x <= l && r <= y) {
        if (l == x) nowl = lcol[o];
        if (r == y) nowr = rcol[o];
        return cnt[o];
    }
    int total = 0;
    bool wentLeft = false;
    if (x <= M) {
        total += query(lc, l, M, x, y);
        wentLeft = true;
    }
    if (M < y) {
        total += query(rc, M + 1, r, x, y);
        // Both halves contribute, so positions M and M + 1 are both in range:
        // if they share a color, the two partial answers counted one run twice.
        if (wentLeft && rcol[lc] == lcol[rc]) --total;
    }
    return total;
}
void change(int u, int v, int c) {
    int f1 = top[u], f2 = top[v];
    while (f1 != f2) {
        if (dep[f1] < dep[f2]) swap(f1, f2), swap(u, v);
        update(1, 1, n, p[f1], p[u], c);
        u = fa[f1], f1 = top[u];
    }
    if (dep[u] < dep[v]) swap(u, v);
    update(1, 1, n, p[v], p[u], c);
}
int find(int u, int v) {
    int ans = 0, f1 = top[u], f2 = top[v], last1l = -1, last1r = -1, last2l = -1, last2r = -1, isu = 1;
    while (f1 != f2) {
        if (dep[f1] < dep[f2]) swap(f1, f2), swap(u, v), swap(last1l, last2l), swap(last1r, last2r), isu = !isu;

        nowl = nowr = -1;
        int ret = query(1, 1, n, p[f1], p[u]);
        if (last1l != -1 && nowr == last1l) ret--;
        last1l = nowl, last1r = nowr, ans += ret;

        u = fa[f1], f1 = top[u];
    }
    if (dep[u] < dep[v]) swap(u, v), swap(last1l, last2l), swap(last1r, last2r), isu = !isu;

    nowl = nowr = -1;
    int ret = query(1, 1, n, p[v], p[u]);
    if (last1l != -1 && nowr == last1l) ret--;
    if (last2l != -1 && nowl == last2l) ret--;
    ans += ret;

    return ans;
}
/**
 * Resets all per-run state for the current n: the position counter, the
 * segment-tree arrays for indices 0..4n, and the per-vertex arrays and
 * adjacency lists for indices 0..n. son[] is reset to -1 ("no heavy child").
 */
void clean() {
    tb = 0;
    const int lastNode = n * 4;
    for (int k = 0; k <= lastNode; ++k) {
        lcol[k] = 0;
        rcol[k] = 0;
        cnt[k] = 0;
        lazy[k] = 0;
    }
    for (int v = 0; v <= n; ++v) {
        G[v].clear();
        fa[v] = 0;
        dep[v] = 0;
        np[v] = 0;
        p[v] = 0;
        top[v] = 0;
        son[v] = -1;
    }
}
/**
 * Reads the tree and answers the Q operations; n and Q must already be set.
 *
 * Input that follows n and Q: n vertex colors, n - 1 edges, then Q lines,
 * each either "C a b c" (paint the path a..b with color c) or any other
 * word followed by "a b" (print the number of color runs on the path a..b).
 */
void solve() {
    clean();
    for (int i = 1; i <= n; i++) scanf("%d", &col[i]);
    for (int x, y, i = 1; i < n; i++) {
        scanf("%d%d", &x, &y);
        G[x].push_back(y), G[y].push_back(x);
    }
    dfs1(1, 0), dfs2(1, 1);
    build(1, 1, n);
    char s[10];
    while (Q--) {
        // %9s bounds the token to the buffer (9 chars + NUL); stop on
        // truncated input instead of reading an uninitialized s[0].
        if (scanf("%9s", s) != 1) break;
        if (s[0] == 'C') {
            int a, b, c;
            scanf("%d%d%d", &a, &b, &c);
            change(a, b, c);
        } else {
            int a, b;
            scanf("%d%d", &a, &b);
            printf("%d\n", find(a, b));
        }
    }
}
/**
 * Entry point. Unless ONLINE_JUDGE is defined, standard input and output
 * are redirected to the local files 1.in and 1.out. Reads n and Q, then
 * hands over to solve().
 */
int main() {
#ifndef ONLINE_JUDGE
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    scanf("%d%d", &n, &Q);
    solve();
    return 0;
}
/*
20 3
1 2 2 2 1 1 2 2 1 1 2 2 2 2 2 2 1 2 1 2 
2 1
3 1
4 3
5 3
6 2
7 1
8 7
9 6
10 2
11 8
12 10
13 11
14 7
15 8
16 4
17 13
18 11
19 7
20 11
C 6 13 2
C 7 15 1
Q 12 13
*/
------ 本文结束 ------