「Bzoj 1016」「JSOI2008」最小生成树计数 (最小生成树+乘法原理)

Bzoj 1016
先看定理(转自discuss, 证明看原文):
定理一:如果 $A, B$ 同为 $G$ 的最小生成树,且 $A$ 的边权从小到大为 $w(a_1), w(a_2), w(a_3), \cdots w(a_n)$,$B$ 的边权从小到大为 $w(b_1), w(b_2), w(b_3), \cdots w(b_n)$,则有 $w(a_i) = w(b_i)$。
定理二:如果 $A, B$ 同为 $G$ 的最小生成树,如果 $A, B$ 都从零开始从小到大加边($A$ 加 $A$ 的边,$B$ 加 $B$ 的边)的话,每种权值加完后图的联通性相同。
定理三:如果在最小生成树 $A$ 中权值为 $v$ 的边有 $k$ 条,用任意 $k$ 条权值为 $v$ 的边替换 $A$ 中的权为 $v$ 的边且不产生环的方案都是一棵合法最小生成树。

那么这题我们先做一次最小生成树,然后统计各种权值用的边数(用权值分组),然后每组再dfs找出相应边数的边使得每一条边都可以使图的连通分量减少。然后每组的方案再乘法原理求。一个组的搜完了,还要把这个权值的边都连上,再搜下一次。

注意这时并查集不要路径压缩,因为dfs要回溯,此时并查集复杂度为$O(nlogn)$。
还有图不连通的情况,这种情况不存在最小生成树,然后注意这题还要取模

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 100 + 5, MAXM = 1000 + 5, MO = 31011;
struct data {
    int x, y, w;
    bool operator < (const data &b) const {
        return w < b.w;
    }
}e[MAXM], a[MAXM];
int cnt, n, m, f[MAXN];
LL sum;
int find(int x) {return f[x] == x ? x : find(f[x]);}//不要路径压缩 
void dfs(int tt, int a, int ed, int ys) {
    if (ys == tt) {
        sum++; 
        return ;
    }
    if (a > ed) return ;
    int x = find(e[a].x), y = find(e[a].y);
    if (x != y) {
        f[x] = y;
        dfs(tt, a + 1, ed, ys + 1);
        f[x] = x, f[y] = y;//并查集恢复 
    }
    dfs(tt, a + 1, ed, ys);
}
void clean() {
    cnt = 0;
    for (int i=1;i<=n;i++) f[i] = i;
}
void solve() {
    clean(); 
    for (int i=1;i<=m;i++) scanf("%d%d%d", &e[i].x, &e[i].y, &e[i].w);
    sort(e + 1, e + 1 + m);
    int tot = 0;
    for (int i=1;i<=m;i++) {
        if (e[i].w != e[i - 1].w) a[++cnt].x = i, a[cnt - 1].y = i - 1;
        int x = find(e[i].x), y = find(e[i].y);
        if (x != y) f[x] = y, tot++, a[cnt].w++;
    }
    a[cnt].y = m;
    if (tot != n - 1) {printf("0\n"); return ;}//图不连通,不构成最小生成树
    for (int i=1;i<=n;i++) f[i] = i;
    LL ans = 1;
    for (int i=1;i<=cnt;i++) {
        sum = 0;
        dfs(a[i].w, a[i].x, a[i].y, 0);
        ans = (ans * sum % MO) % MO;
        for (int j=a[i].x;j<=a[i].y;j++) {
            int x = find(e[j].x), y = find(e[j].y);
            if (x != y) f[x] = y;
        }
    }
    printf("%lld\n", ans);
}
int main() {
    #ifndef ONLINE_JUDGE 
    freopen("1.in", "r", stdin);freopen("1.out", "w", stdout);
    #endif
    scanf("%d%d", &n, &m), solve();
    return 0;
}
------ 本文结束 ------