「Bzoj 4071」「Apio2015」巴邻旁之桥 (权值线段树维护前缀中位数 / Splay / 主席树)

BZOJ 4071
题意:$[0,10^9]$范围有$n$个区间,要求选一个点$x$或两个点$x,y$,对于任意区间$i$代价为$\min(|l_i-x|+|r_i-x|, |l_i-y|+|r_i-y|)$,求最小代价

对于只选一个点的,设选了$p$点,则
$$
ans=\sum_{i=1}^n |l_i-p|+|r_i-p|
$$
将端点视为同等的,那么答案即为
$$
ans=\sum |x-p|
$$
这是一个中位数的模型,直接取中位数是最优的。

对于选两个点的,考虑每个区间$[l,r]$会过离$\frac{l+r}{2}$最近的桥,所以我们可以将区间按照$l_i+r_i$排序,然后顺序枚举每个区间作分界点,左边右边分开处理,变成一个点的做法。

那么我们需要支持一个动态中位数。显然可以用 Splay / 主席树 求第$k$大($k=\frac{n}{2}$)
然而这里只需要求出一个前缀、后缀的中位数,那么可以考虑不用主席树而是权值线段树代替之。
在权值线段树上找$k$大,然后再分类讨论求所有点到中位数距离即可。

另外这题还满足单峰性质,那么三分套三分枚举桥位置即可。

知识点:
1、中位数的运用:多个点到某个点距离和最近
2、权值线段树 / Splay / 主席树动态中位数的方法

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<set>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db long double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 200000 + 10;

    struct data {
        int x, y;
        bool operator < (const data &rhs) const {return (x + y) < (rhs.x + rhs.y);}
    } itv[MAXN];

    int k, n, itv_tot, tax[MAXN], tax_tot, fl;
    LL ans, pre[MAXN];

    #define lc (o << 1)
    #define rc (o << 1 | 1)
    #define M ((l + r) >> 1)
    LL siz[MAXN * 4], sumv[MAXN * 4];
    void pushup(int o) {siz[o] = siz[lc] + siz[rc], sumv[o] = sumv[lc] + sumv[rc];}
    void build(int o, int l, int r) {
        siz[o] = sumv[o] = 0;
        if (l == r) return ; else {
            build(lc, l, M), build(rc, M + 1, r);
            pushup(o);
        } 
    }
    void update(int o, int l, int r, int p, int v) {
        if (l == r) sumv[o] += v, ++siz[o];
        else {
            if (p <= M) update(lc, l, M, p, v);
            else update(rc, M + 1, r, p, v);
            pushup(o);
        }
    }
    int querykth(int o, int l, int r, int k) {
        if (l == r) return l;
        if (k <= siz[lc]) return querykth(lc, l, M, k); // <=
        else return querykth(rc, M + 1, r, k - siz[lc]);
    }
    LL queryitv_sum(int o, int l, int r, int x, int y) {
        if (x <= l && r <= y) {
            return sumv[o];
        }
        LL ret = 0;
        if (x <= M) ret += queryitv_sum(lc, l, M, x, y);
        if (M < y)  ret += queryitv_sum(rc, M + 1, r, x, y);
        return ret;
    }
    LL queryitv_N(int o, int l, int r, int x, int y) {
        if (x <= l && r <= y) {
            return siz[o];
        }
        LL ret = 0;
        if (x <= M) ret += queryitv_N(lc, l, M, x, y);
        if (M < y)  ret += queryitv_N(rc, M + 1, r, x, y);
        return ret;
    }

    void clean() {
        fl = tax_tot = itv_tot = ans = 0;
    }
    int solve() {

        clean();
        scanf("%d%d", &k, &n);
        char sa[5], sb[5];
        for (int i = 1; i <= n; ++i) {
            int s, t;
            scanf("%s%d%s%d", sa, &s, sb, &t);
            if (s > t) swap(s, t);
            if (sa[0] == sb[0]) {
                ans += t - s;
                continue ;
            }
            ++ans, fl = 1;
            itv[++itv_tot] = (data){s, t};
            tax[++tax_tot] = s;
            tax[++tax_tot] = t;
        }

        if (!fl) return printf("%lld\n", ans), 0; // 特判 

        sort(tax + 1, tax + 1 + tax_tot), tax_tot = unique(tax + 1, tax + 1 + tax_tot) - tax - 1;

        sort(itv + 1, itv + 1 + itv_tot);

        for (int i = 1; i <= itv_tot; ++i) 
            itv[i].x = lower_bound(tax + 1, tax + 1 + tax_tot, itv[i].x) - tax, itv[i].y = lower_bound(tax + 1, tax + 1 + tax_tot, itv[i].y) - tax;

        build(1, 1, tax_tot);
        for (int i = 1; i <= itv_tot; ++i) {
            update(1, 1, tax_tot, itv[i].x, tax[itv[i].x]);
            update(1, 1, tax_tot, itv[i].y, tax[itv[i].y]);
            int p = querykth(1, 1, tax_tot, i);
            LL S1 = queryitv_sum(1, 1, tax_tot, 1, p), N1 = queryitv_N(1, 1, tax_tot, 1, p);
            LL S2 = queryitv_sum(1, 1, tax_tot, p + 1, tax_tot), N2 = queryitv_N(1, 1, tax_tot, p + 1, tax_tot);
            pre[i] = tax[p] * N1 - S1 + S2 - tax[p] * N2; // tax[p]
        }

        if (k == 1) return printf("%lld\n", pre[itv_tot] + ans), 0;

        build(1, 1, tax_tot);
        LL whw = pre[itv_tot];
        for (int i = itv_tot; i >= 1; --i) {
            update(1, 1, tax_tot, itv[i].x, tax[itv[i].x]);
            update(1, 1, tax_tot, itv[i].y, tax[itv[i].y]);
            int p = querykth(1, 1, tax_tot, itv_tot - i + 1); // 注意 
            LL S1 = queryitv_sum(1, 1, tax_tot, 1, p), N1 = queryitv_N(1, 1, tax_tot, 1, p);
            LL S2 = queryitv_sum(1, 1, tax_tot, p + 1, tax_tot), N2 = queryitv_N(1, 1, tax_tot, p + 1, tax_tot);
            whw = min(whw, pre[i - 1] + tax[p] * N1 - S1 + S2 - tax[p] * N2);
        }

        return printf("%lld\n", whw + ans), 0;


        return 0;
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
------ 本文结束 ------