「Loj 2652」「Code+1」大吉大利,晚上吃鸡! (最短路+最短路图DP+Bitset)

Loj 2652
题意:见上。

设$g(i,j)$为$i \to j$的最短路径条数,并设$F(i)=g(S,i) \cdot g(i,T)$(当$i$在某条$S \to T$最短路上时,$F(i)$就是经过$i$的最短路条数;否则记$F(i)=0$)。
可以发现,在满足第二个条件的前提下,第一个条件等价于$F(A)+F(B)=F(T)$。
第二个条件即$A,B$在最短路图上不存在前驱后继关系(互相不可达)。

那么第一个条件用一个map优化,第二个条件我们可以求出最短路图,然后传递闭包,bitset优化

第一个条件和第二个条件的可能点的交集即为答案。

知识点:
1、Bitset空间

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<queue>
#include<bitset>
#include<map>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    // Solver for LOJ 2652: count unordered vertex pairs {A, B} such that
    //   (1) F(A) + F(B) == F(T), where F(v) = (#shortest S->v paths) *
    //       (#shortest v->T paths) for v on a shortest S->T path, else 0, and
    //   (2) neither of A, B can reach the other in the shortest-path DAG.
    // Path counts are kept in unsigned 64-bit integers, so all arithmetic on
    // them is performed modulo 2^64 (well-defined wrap-around).

    const int MAXN = 50000 + 5;                 // capacity of per-vertex arrays (1-based ids)
    const long long INF = 0x3f3f3f3f3f3f3f3f;   // value produced by memset(dis, 0x3f, ...)

    typedef std::bitset<50001> vset;            // bit (v - 1) stands for vertex v
    typedef unsigned long long cnt_t;           // path counter, wraps modulo 2^64

    int n, m, S, T;
    int en_o, hd_o[MAXN], vis_o[MAXN];          // input graph adjacency + shared "visited" marks
    int en_z, hd_z[MAXN];                       // shortest-path DAG, edges oriented S -> T
    int en_f, hd_f[MAXN];                       // the same DAG with every edge reversed
    long long dis[MAXN][2];                     // dis[v][0] = dist(S, v), dis[v][1] = dist(T, v)
    cnt_t g[MAXN][2], f[MAXN];                  // g[v][k] = #shortest paths for dis[v][k]; f[v] = F(v)
    std::vector<vset> bs[2];                    // bs[0][v] = {v} + DAG descendants, bs[1][v] = {v} + DAG ancestors
    std::map<cnt_t, vset > ma;                  // F value -> set of vertices having that value

    struct edge {int v, w, nxt;} ed_o[MAXN * 2], ed_z[MAXN * 2], ed_f[MAXN * 2];
    struct data {
        int u;
        long long dis;
        // Reversed comparison turns std::priority_queue into a min-heap on dis.
        bool operator < (const data &rhs) const {
            return dis > rhs.dis;
        }
    };

    std::priority_queue<data > q;

    // Prepend a directed edge u -> v (weight w) to the input graph's list of u.
    void ins_o(int u, int v, int w) {
        const edge e = {v, w, hd_o[u]};
        ed_o[++en_o] = e;
        hd_o[u] = en_o;
    }
    // Prepend a directed edge u -> v to the forward shortest-path DAG.
    void ins_z(int u, int v, int w) {
        const edge e = {v, w, hd_z[u]};
        ed_z[++en_z] = e;
        hd_z[u] = en_z;
    }
    // Prepend a directed edge u -> v to the reversed shortest-path DAG.
    void ins_f(int u, int v, int w) {
        const edge e = {v, w, hd_f[u]};
        ed_f[++en_f] = e;
        hd_f[u] = en_f;
    }

    // Debug helper: print the first n bits of b, then two newlines.
    void prbs(const vset &b) {
        for (int i = 0; i < n; ++i) std::printf("%d ", (int)b[i]);
        std::puts("\n");
    }

    // Dijkstra with shortest-path counting.
    // op == 0 starts from S, op == 1 starts from T; results land in
    // dis[.][op] and g[.][op]. Expects dis[.][op] == INF and g[.][op] == 0
    // on entry (see clean()).
    void dij(int op) {
        std::memset(vis_o, 0, sizeof vis_o);
        const int src = (op == 0) ? S : T;
        dis[src][op] = 0;
        g[src][op] = 1;
        const data start = {src, 0};
        q.push(start);
        while (!q.empty()) {
            const data p = q.top();
            q.pop();
            if (vis_o[p.u]) continue;           // stale heap entry
            vis_o[p.u] = 1;
            for (int i = hd_o[p.u]; i >= 0; i = ed_o[i].nxt) {
                const edge &e = ed_o[i];
                if (vis_o[e.v]) continue;       // finalized vertices must not change any more
                const long long cand = dis[p.u][op] + e.w;
                if (cand < dis[e.v][op]) {
                    dis[e.v][op] = cand;
                    // A strictly better distance inherits the predecessor's
                    // count (resetting it to 1 would lose paths).
                    g[e.v][op] = g[p.u][op];
                    const data nx = {e.v, cand};
                    q.push(nx);
                } else if (cand == dis[e.v][op]) {
                    g[e.v][op] += g[p.u][op];   // another equally short way in
                }
            }
        }
    }

    // Build the shortest-path DAG reachable from u (call with u == S and
    // vis_o cleared). An edge u -> v is kept when it is tight from S and v
    // lies on some shortest S -> T path. Each vertex is expanded only once,
    // so every DAG edge is inserted exactly once. Recursion depth is bounded
    // by the number of vertices.
    void csdag(int u) {
        if (vis_o[u]) return;
        vis_o[u] = 1;
        for (int i = hd_o[u]; i >= 0; i = ed_o[i].nxt) {
            const edge &e = ed_o[i];
            if (dis[e.v][0] != dis[u][0] + e.w) continue;
            if (dis[e.v][0] + dis[e.v][1] != dis[T][0]) continue;
            ins_z(u, e.v, e.w);
            ins_f(e.v, u, e.w);
            csdag(e.v);
        }
    }

    // Memoized DFS over the forward DAG: afterwards bs[0][u] holds u and
    // everything reachable from u. vis_o must be cleared by the caller.
    void dp_z(int u) {
        if (vis_o[u]) return ;
        vis_o[u] = 1;
        for (int i = hd_z[u]; i >= 0; i = ed_z[i].nxt) {
            const int v = ed_z[i].v;
            dp_z(v);
            bs[0][u] |= bs[0][v];
        }
    }
    // Memoized DFS over the reversed DAG: afterwards bs[1][u] holds u and
    // everything that can reach u. vis_o must be cleared by the caller.
    void dp_f(int u) {
        if (vis_o[u]) return ;
        vis_o[u] = 1;
        for (int i = hd_f[u]; i >= 0; i = ed_f[i].nxt) {
            const int v = ed_f[i].v;
            dp_f(v);
            bs[1][u] |= bs[1][v];
        }
    }

    // Reset every piece of global state so that solve() starts from scratch.
    void clean() {
        std::memset(hd_o, -1, sizeof hd_o), en_o = -1;
        std::memset(hd_z, -1, sizeof hd_z), en_z = -1;
        std::memset(hd_f, -1, sizeof hd_f), en_f = -1;
        std::memset(dis, 0x3f, sizeof dis);
        std::memset(g, 0, sizeof g);
        std::memset(f, 0, sizeof f);
        std::memset(vis_o, 0, sizeof vis_o);
        bs[0].clear();
        bs[1].clear();
        ma.clear();
        while (!q.empty()) q.pop();
    }

    // Read one instance from stdin ("n m S T" followed by m lines "u v w"),
    // print the number of valid unordered pairs to stdout, and return 0.
    // Input that cannot be parsed or exceeds the array capacity produces no output.
    int solve() {

        clean();
        if (std::scanf("%d%d%d%d", &n, &m, &S, &T) != 4) return 0;
        if (n < 1 || n > 50000 || m < 0 || m > MAXN) return 0;   // 2 * m entries must fit in ed_o
        if (S < 1 || S > n || T < 1 || T > n) return 0;
        for (int i = 1; i <= m; ++i) {
            int u, v, w;
            if (std::scanf("%d%d%d", &u, &v, &w) != 3) return 0;
            if (u < 1 || u > n || v < 1 || v > n) continue;      // ignore edges outside the graph
            ins_o(u, v, w), ins_o(v, u, w);
        }
        dij(0), dij(1);
        if (dis[T][0] == INF) {
            // T is unreachable: the answer is simply the number of vertex pairs.
            std::printf("%lld\n", 1ll * n * (n - 1) / 2);
            return 0;
        }

        std::memset(vis_o, 0, sizeof vis_o);
        csdag(S);

        // Closure bitsets are sized for this instance; each starts as {v}.
        bs[0].assign(n + 1, vset());
        bs[1].assign(n + 1, vset());
        for (int u = 1; u <= n; ++u) {
            bs[0][u].set(u - 1);
            bs[1][u].set(u - 1);
        }
        std::memset(vis_o, 0, sizeof vis_o);
        dp_z(S);
        std::memset(vis_o, 0, sizeof vis_o);
        dp_f(T);

        // Group vertices by F; vertices off every shortest path keep F == 0.
        for (int u = 1; u <= n; ++u) {
            if (dis[u][0] + dis[u][1] == dis[T][0]) f[u] = g[u][0] * g[u][1];
            ma[f[u]].set(u - 1);
        }

        long long ans = 0;
        for (int u = 1; u <= n; ++u) {
            // find() instead of operator[]: a missing key must not insert a
            // fresh multi-kilobyte bitset into the map.
            const std::map<cnt_t, vset >::const_iterator it = ma.find(f[T] - f[u]);
            if (it == ma.end()) continue;
            // Partners with the complementary F value, minus u itself and
            // everything comparable with u in the DAG.
            ans += (long long)(it->second & ~(bs[0][u] | bs[1][u])).count();
        }

        // Every unordered pair was counted once from each endpoint.
        std::printf("%lld", ans / 2);

        return 0;
    } 
}
// Program entry point: runs the solver once; its return value is not used.
int main() {
    (void)flyinthesky::solve();
    return 0;
}
------ 本文结束 ------