Codeforces 609E(最小生成树+倍增)

Codeforces 609E
题意:询问一个图包含某条边的最小生成树。对每条边进行询问。

先求出最小生成树,然后对于在最小生成树的边直接输出最小生成树的最优值,否则,考虑将这条边$(u, v)$加入最小生成树,则必然会产生环,其中环为最小生成树上$u$到$v$以及边$(u, v)$,我们要在$u$到$v$找权最大的边删除。使用倍增就能够完成。用 LCA 爬即可。

本题思路与次小生成树思路类似。

知识点:
1、记得开LL
2、分段调试后一定要静态查错

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const LL MAXN = 200000 + 5, LOGS = 20;

LL dep[MAXN], pre[MAXN][LOGS + 5], mks[MAXN][LOGS + 5]; 
LL n, m, en2, f[MAXN], ve[MAXN];
vector<LL> G2[MAXN];

struct edge {
    LL u, v, w, no;
    bool operator < (const edge &b) const {
        return w < b.w;
    }
}ed[MAXN];
struct edge2 {
    LL v, w;
}ed2[MAXN * 2];

bool cmp(edge a, edge b) {return a.no < b.no;}
LL find(LL x) {return x == f[x] ? x : f[x] = find(f[x]);}
inline void ins(LL x, LL y, LL w) {
    ed2[++en2] = (edge2){y, w}, G2[x].push_back(en2);
    ed2[++en2] = (edge2){x, w}, G2[y].push_back(en2);
}

void dfs(LL u, LL fa, LL w) {
    pre[u][0] = fa, dep[u] = dep[fa] + 1;
    for (LL i = 1; i <= LOGS; i++) pre[u][i] = pre[pre[u][i - 1]][i - 1];
    mks[u][0] = w;
    for (LL i = 1; i <= LOGS; i++) mks[u][i] = max(mks[u][i - 1], mks[pre[u][i - 1]][i - 1]);
    for (LL i = 0; i < (LL)G2[u].size(); i++) {
        edge2 &e = ed2[G2[u][i]];
        if (e.v != fa) dfs(e.v, u, e.w);
    }
}
LL lca(LL a, LL b) {
    if (dep[a] < dep[b]) swap(a, b);
    for (LL i = LOGS; i >= 0; i--) if (dep[pre[a][i]] >= dep[b]) a = pre[a][i];
    if (a == b) return a;
    for (LL i = LOGS; i >= 0; i--) if (pre[a][i] != pre[b][i]) a = pre[a][i], b = pre[b][i];
    return pre[a][0];
}
LL getMax(LL u, LL v) {
    LL LCA = lca(u, v), ret = 0;
    for (LL i = LOGS; i >= 0; i--) if (dep[pre[u][i]] >= dep[LCA]) ret = max(mks[u][i], ret), u = pre[u][i];
    for (LL i = LOGS; i >= 0; i--) if (dep[pre[v][i]] >= dep[LCA]) ret = max(mks[v][i], ret), v = pre[v][i];
    return ret;
}
void clean() {
    ms(dep, 0), ms(mks, 0), ms(ve, 0), en2 = 0;
}
int solve() {
    clean();
    for (LL i = 1; i <= n; i++) f[i] = i;
    for (LL i = 1; i <= m; i++) scanf("%I64d%I64d%I64d", &ed[i].u, &ed[i].v, &ed[i].w), ed[i].no = i;
    sort(ed + 1, ed + 1 + m);
    LL tot = 0, ans = 0;
    for (LL i = 1; i <= m; i++) {
        LL x = find(ed[i].u), y = find(ed[i].v);
        if (x != y) f[x] = y, tot++, ans += ed[i].w, ins(ed[i].u, ed[i].v, ed[i].w), ve[ed[i].no] = 1;
        if (tot >= n - 1) break;
    }
    dfs(1, 0, 0);
    sort(ed + 1, ed + 1 + m, cmp);
    for (LL i = 1; i <= m; i++) {
        if (ve[i]) printf("%I64d\n", ans); else {
            LL whw = getMax(ed[i].u, ed[i].v);
            printf("%I64d\n", ans + ed[i].w - whw);
        }
    }
    return 0; 
}
int main() {
    scanf("%I64d%I64d", &n, &m), solve();
    return 0;
}
------ 本文结束 ------