Codeforces 1095F (贪心 + 并查集)

Codeforces 1095F
题意:$n$点无向图最开始没有边,加一条边的代价为$A_x+A_y$, $A_i$为点权,现在有$m$个特殊方案$(x,y,w)$使得$x,y$连通,花费$w$。请问使图连通最小费用。

最开始我的思路:
考虑没有特殊方案,则就是个合并果子,必须保证每次都将两个连通块,并查集维护。
如果有特殊方案,那么每次操作要将两个连通块并起来。将方案按价值增序,那么堆中取出两个元素后和当前方案对比哪个方案连通最优。

其实更容易的:
考虑没有特殊方案,可以发现所有点都会连在一个最小价值的点上。
如果有,将所有点都会连在一个最小价值的点的操作一起并到特殊方案里排序取最优。

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<queue>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky { 

    const int MAXN = 200000 + 5;

    struct node {
        int u;
        LL w;
        bool operator < (const node &rhs) const {return w > rhs.w;}
    };
    struct data {
        int u, v;
        LL w;
        bool operator < (const data &rhs) const {return w < rhs.w;}
    } xw[MAXN];

    int n, m, f[MAXN];
    LL a[MAXN], ans;

    priority_queue<node > q;

    int find(int x) {return x == f[x] ? x : f[x] = find(f[x]);}

    void clean() {
        ans = 0ll;
    }
    int solve() {
        clean();
        cin >> n >> m;
        if (n == 1) return printf("0\n"), 0;
        for (int i = 1; i <= n; ++i) scanf("%I64d", &a[i]), q.push((node){i, a[i]}), f[i] = i; 
        for (int i = 1; i <= m; ++i) {
            scanf("%d%d%I64d", &xw[i].u, &xw[i].v, &xw[i].w); 
        }
        sort(xw + 1, xw + 1 + m);
        int xwcur = 1;
        for (int i = 1; i <= n - 1 && !q.empty(); ++i) {
            node p1, p2;
            if (!q.empty()) p1 = q.top(), q.pop();
            while (find(p1.u) != p1.u && !q.empty()) 
                p1 = q.top(), q.pop();
            if (!q.empty()) p2 = q.top(), q.pop();
            while ((find(p2.u) != p2.u || find(p2.u) == find(p1.u)) && !q.empty()) 
                p2 = q.top(), q.pop();
            while (xwcur <= m && find(xw[xwcur].u) == find(xw[xwcur].v)) ++xwcur;
            if (xwcur <= m && xw[xwcur].w < p1.w + p2.w) {
                int x = find(xw[xwcur].u), y = find(xw[xwcur].v);
                ans += xw[xwcur].w;
                if (a[x] < a[y]) {
                    f[y] = x;
                    q.push((node){x, a[x]});
                } else {
                    f[x] = y;
                    q.push((node){y, a[y]});
                }
                q.push(p1), q.push(p2);
            } else {
                int x = find(p1.u), y = find(p2.u);
                ans += p1.w + p2.w;
                f[y] = x;
                q.push((node){x, a[x]});
            }
        }
        cout << ans;
        return 0; 
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
------ 本文结束 ------