「NOIP2012 Day2 T3」 疫情控制 (树上倍增+二分+贪心)

Luogu 1084
题意:见上。

可以发现答案具有二分性质,所以我们二分一个答案$mid$来判断。
只要保证每个移动长度不超过$mid$即可。
那么我们可以将一些不能到达根子节点的军队直接能爬上来多少就爬上来多少,在爬上来的位置标记,不能爬到根。
爬完之后,对树进行一次 DFS,更新根节点子节点下的子树是否全部被封锁。
然后将其他军队爬到根节点的子节点,这些军队分成两种:
第一种是剩余距离不足以经过根再回去原来的位置的军队
第二种是除了这些的其他军队

对于第一种军队所在的位置,一定是这个子树中的军队占领。
因为如果是其他子树的军队来占领,不那么优,因为第一种军队去占领其他子树的军队,那个占领这个位置的军队绝对可以占领那个位置,证毕。

所以我们可以留下一个剩余距离最小的第一种军队来占领这颗子树,然后其他军队通过贪心的方法,用剩余距离尽量小的军队取占领距离小的位置。

注意 $chk$ 实现,找了一年BUG…

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<queue>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db long double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 50000 + 5, LOGS = 20;
    const LL INF = 1e18 + 5;

    struct edge {
        int v, w, nxt;
    } ed[MAXN * 2];
    struct data {
        int u, bl, fl1;
        LL rest;
    }am[MAXN], tmpam[MAXN];

    int n, m, en, hd[MAXN];
    int pre[MAXN][LOGS + 5], dep[MAXN], val[MAXN];
    LL d[MAXN][LOGS + 5], dis[MAXN];

    void ins(int u, int v, int w) {ed[++en] = (edge){v, w, hd[u]}, hd[u] = en;}
    bool cmp(data a, data b) {return a.rest < b.rest;}

    void dfs_pre(int u, int fa) { // 预处理 
        dep[u] = dep[fa] + 1, pre[u][0] = fa;
        for (int i = 1; i <= LOGS; ++i) {
            pre[u][i] = pre[pre[u][i - 1]][i - 1];
            d[u][i] = d[u][i - 1] + d[pre[u][i - 1]][i - 1];
        }
        for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa) {
                d[e.v][0] = e.w, dis[e.v] = dis[u] + e.w;
                if (u == 1) val[e.v] = e.w;
                dfs_pre(e.v, u);
            }
        }
    }
    int pa_pre(int u) { // 找每个军队属于哪个 s 
        for (LL i = LOGS; i >= 0; --i) if (pre[u][i] && pre[u][i] != 1) {
            u = pre[u][i]; 
        }
        return u;
    }

    int vis[MAXN], ok[MAXN], lst[MAXN];
    LL whw[MAXN], H[MAXN];
    // vis[n]=n节点是否打标记, ok[n]=n节点是否覆盖, whw[n]=n点的最小rest,lst[n]=n点的最小rest是哪个m 

    void pa(int u, LL mid) { // 能爬多少爬多少 
        LL tot = 0;
        for (LL i = LOGS; i >= 0; --i) if (pre[u][i] && pre[u][i] != 1) {
            if (tot + d[u][i] <= mid) u = pre[u][i], tot += d[u][i]; 
        }
        vis[u] = 1;
    }
    bool dfs_ok(int u, int fa) { // 找没有标记的 
        if (vis[u] && u != 1) return ok[u] = 1, 1;
        int fl = false;
        for (int i = hd[u]; i >= 0; i = ed[i].nxt) {
            edge &e = ed[i];
            if (e.v != fa) {
                if (fl) ok[u] &= dfs_ok(e.v, u); // 
                else fl = true, ok[u] = dfs_ok(e.v, u);
            }
        }
        return ok[u];
    }
    bool chk(LL mid) {
        for (int u = 1; u <= n; ++u) vis[u] = 0, ok[u] = 0, whw[u] = INF, lst[u] = 0;
        for (int i = 1; i <= m; ++i) am[i].fl1 = 0;
        for (int i = 1; i <= m; ++i) if (dis[am[i].u] > mid) pa(am[i].u, mid), am[i].fl1 = 1;// dis[am[i].u] > mid
        dfs_ok(1, 0);
        for (int i = 1; i <= m; ++i) if (dis[am[i].u] <= mid) {
            LL rest = mid - dis[am[i].u];
            int pos = am[i].bl;
            am[i].rest = rest;
            if (val[pos] > rest && !ok[pos]) { // !ok[pos]
                if (whw[pos] <= rest) continue ;
                am[lst[pos]].fl1--, am[i].fl1++;
                lst[pos] = i;
                whw[pos] = rest; // 
            }
        }
        for (int i = 1; i <= n; ++i) ok[am[lst[i]].bl] = 1;
        int tot1 = 0;
        for (int i = 1; i <= m; ++i) if (!am[i].fl1) tmpam[++tot1] = am[i];
        int tot2 = 0;
        for (int i = hd[1]; i >= 0; i = ed[i].nxt) {
            int v = ed[i].v;
            if (!ok[v]) H[++tot2] = val[v];
        }
        sort(tmpam + 1, tmpam + 1 + tot1, cmp);
        sort(H + 1, H + 1 + tot2);
        int hh = 1, i;
        if (tot1 < tot2) return false;
        for (i = 1; i <= tot2; ++i) {
            while (hh <= tot1 && tmpam[hh].rest < H[i]) ++hh;
            if (hh > tot1) return false;
            ++hh;
        }
        return true;
    }

    void clean() {
        en = -1, ms(hd, -1), ms(pre, 0);
    }
    int solve() {

        clean();

        scanf("%d", &n);
        for (int u, v, w, i = 1; i <= n - 1; ++i) {
            scanf("%d%d%d", &u, &v, &w);
            ins(u, v, w), ins(v, u, w);
        }
        scanf("%d", &m);
        for (int i = 1; i <= m; ++i) scanf("%d", &am[i].u);

        dep[0] = 0, dfs_pre(1, 0);
        for (int i = 1; i <= m; ++i) am[i].bl = pa_pre(am[i].u);

        LL l = 0, r = 1e18, ans = -1;
        //LL l = 9, r = 10, ans = -1;
        //LL l = 13, r = 14, ans = -1;
        while (l < r) {
            LL mid = (l + r) >> 1;
            if (chk(mid)) 
                ans = mid, r = mid; 
            else 
                l = mid + 1;
        }

        printf("%lld\n", ans);

        return 0;
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
/*
2
1 2 5
1
2

5
2 1 2
3 2 7
4 3 1
5 3 0
2
4 3 

5
2 1 6
3 2 9
4 3 3
5 3 3
2
4 3 

5
2 1 3
3 2 6
4 1 0
5 4 8
2
2 3 

5
2 1 5
3 2 9
4 3 5
5 2 4
2
4 4 

5
2 1 5
3 1 3
4 2 7
5 4 8
2
4 2 

5
2 1 10
3 2 1
4 1 3
5 2 6
2
2 3 

5
2 1 7
3 2 10
4 1 5
5 4 9
2
4 3 

5
2 1 4
3 1 5
4 3 2
5 4 10
2
4 3 
*/
------ 本文结束 ------