Codeforces 544D(BFS最短路+枚举)

Codeforces 544D
题意:给出$n$个点,$m$条边权为$1$的无向边,破坏最多的道路,使得从$s_1$到$t_1,s_2$到$t_2$的距离不超过$l_1,l_2$

对于$s_1$到$t_1$,最优肯定取最短路,2同理。
但是这两条路径可能重复更优,也就是说1取最短路而2可以在$l$范围之内不取最短路而使用1的某些边来减少边数的增加
所以我们枚举路径$(i,j)$,取每个起点终点经过这条路径的长度最短的。4种情况。最后取反答案即可。记得可能两条路径没有重合部分,直接赋值$dis(s_1,t_1)+dis(s_2,t_2)$

知识点:本题与CF 954D类似,都是求完最短路之后枚举边/路径来算答案。
对于权为$1$的图,可以用 BFS $O(n^2)$求多源最短路。而不是用Floyd $O(n^3)$的算法(或者$O(n(n+m)logn)=O(n^2+mnlogn)$)
本题将删除化为选择。

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double

const int MAXN = 3000 + 5, INF = 100000000;

struct node{int u, d;};

int n, m, s1, t1, l1, s2, t2, l2, vis[MAXN], dis[MAXN][MAXN];
std::vector<int > G[MAXN];
std::queue<node > q;

inline void ins(int x, int y) {G[x].push_back(y), G[y].push_back(x);}

void clean() {
    for (int i = 1; i <= n; i++)
    for (int j = 1; j <= n; j++) dis[i][j] = INF;
}
int solve() {
    clean();
    for (int a, b, i = 1; i <= m; i++) scanf("%d%d", &a, &b), ins(a, b);
    scanf("%d%d%d%d%d%d", &s1, &t1, &l1, &s2, &t2, &l2);
    for (int st = 1; st <= n; st++) {
        for (int i = 1; i <= n; i++) vis[i] = 0;
        q.push((node){st, 0}), vis[st] = 1, dis[st][st] = 0;
        while (!q.empty()) {
            node p = q.front(); q.pop();
            for (int i = 0; i < (int)G[p.u].size(); i++) {
                int v = G[p.u][i];
                if (!vis[v]) dis[st][v] = p.d + 1, vis[v] = 1, q.push((node){v, p.d + 1});
            }
        }
    }
    if (dis[s1][t1] > l1 || dis[s2][t2] > l2) return printf("-1\n"), 0;
    int ans = dis[s1][t1] + dis[s2][t2];
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            if (dis[s1][i] + dis[i][j] + dis[j][t1] <= l1 && dis[s2][i] + dis[i][j] + dis[j][t2] <= l2)
            ans = std::min(ans, dis[s1][i] + dis[i][j] + dis[j][t1] + dis[s2][i] + dis[j][t2]);
            if (dis[s1][j] + dis[i][j] + dis[i][t1] <= l1 && dis[s2][i] + dis[i][j] + dis[j][t2] <= l2)
            ans = std::min(ans, dis[s1][j] + dis[i][j] + dis[i][t1] + dis[s2][i] + dis[j][t2]);
            if (dis[s1][i] + dis[i][j] + dis[j][t1] <= l1 && dis[s2][j] + dis[i][j] + dis[i][t2] <= l2)
            ans = std::min(ans, dis[s1][i] + dis[i][j] + dis[j][t1] + dis[s2][j] + dis[i][t2]);
            if (dis[s1][j] + dis[i][j] + dis[i][t1] <= l1 && dis[s2][j] + dis[i][j] + dis[i][t2] <= l2)
            ans = std::min(ans, dis[s1][j] + dis[i][j] + dis[i][t1] + dis[s2][j] + dis[i][t2]);
        }
    }
    printf("%d\n", m - ans);
    return 0;
}
int main() {
    scanf("%d%d", &n, &m), solve();
    return 0;
}
------ 本文结束 ------