NOIP2016 Day1 T3(概率期望DP)

设$dp(i,j,0)$为前$i$个课程申请$j$次,第$j$次成功的最小体力期望,$dp(i,j,1)$为前$i$个课程申请$j$次,第$j$次不成功的最小体力期望。
转移方程具体看代码,太长了,不在这里重复打
注意double别用memset并且赋值考虑是否会溢出

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 2000 + 5, MAXV = 300 + 5;
int n, m, v, e, ci[MAXN], di[MAXN], G[MAXV][MAXV];
db ki[MAXN], dp[MAXN][MAXN][2];
void clean() {
    for (int i = 1; i <= v; i++) 
    for (int j = 1; j <= v; j++) if (i == j) G[i][j] = 0; else G[i][j] = 1000000000.0;
}
void solve() {
    clean();
    for (int i = 1; i <= n; i++) scanf("%d", &ci[i]);
    for (int i = 1; i <= n; i++) scanf("%d", &di[i]);
    for (int i = 1; i <= n; i++) scanf("%lf", &ki[i]);
    for (int x, y, w, i = 1; i <= e; i++) {
        scanf("%d%d%d", &x, &y, &w);
        G[x][y] = min(G[x][y], w);
        G[y][x] = min(G[y][x], w);//注意邻接矩阵重边处理 
    }
    for (int k = 1; k <= v; k++) 
    for (int i = 1; i <= v; i++) 
    for (int j = 1; j <= v; j++)
    if (i != j && i != k && j != k) G[i][j] = min(G[i][j], G[i][k] + G[k][j]);
    for (int i = 0; i <= n; i++) for (int j = 0; j <= m; j++) dp[i][j][0] = dp[i][j][1] = 1000000000.0;
    db ans = dp[1][0][0];
    dp[1][0][0] = 0, dp[1][1][0] = 0, dp[1][1][1] = 0;
    for (int i = 2; i <= n; i++) {
        for (int j = 0; j <= m; j++) {
            dp[i][j][0] = min(dp[i][j][0], min(dp[i - 1][j][0] + (db)G[ci[i - 1]][ci[i]], dp[i - 1][j][1] + (db)G[di[i - 1]][ci[i]] * ki[i - 1] + G[ci[i - 1]][ci[i]] * (1 - ki[i - 1])));
            if (j - 1 >= 0) 
            dp[i][j][1] = min(dp[i][j][1], min(dp[i - 1][j - 1][0] + (db)G[ci[i - 1]][ci[i]] * (1 - ki[i]) + G[ci[i - 1]][di[i]] * ki[i],dp[i - 1][j - 1][1] + (db)G[ci[i - 1]][ci[i]] * (1 - ki[i]) * (1 - ki[i - 1]) + G[di[i - 1]][ci[i]] * (1 - ki[i]) * ki[i - 1] +G[di[i - 1]][di[i]] * ki[i] * ki[i - 1] + G[ci[i - 1]][di[i]] * ki[i] * (1 - ki[i - 1])));
        }
    }
    for (int i = 0; i <= m; i++) {
        ans = min(ans, min(dp[n][i][1], dp[n][i][0]));
    }
    printf("%.2f\n", ans);
}
int main() {
    scanf("%d%d%d%d", &n, &m, &v, &e), solve();
    return 0;
}
------ 本文结束 ------