Codeforces 24D
给你一个$n×m$的棋盘,你初始在$(x,y)$,每一步等概率不动、向左、向下或向右走(如果向左走会越界则该步等概率不动、向右、向下走,向右会越界同理, 走到最后一行就结束, 问从开始到结束的期望步数
很容易设出方程$f[i][j]$表示从$(i,j)$走到最后一行的期望步数
$$
\begin{align}
f[i][1]&=\frac{1}{3}(f[i][1]+f[i][2]+f[i+1][1])+1, j=1\\
f[i][m]&=\frac{1}{3}(f[i][m]+f[i][m-1]+f[i+1][m])+1, j=m \\
f[i][j]&=\frac{1}{4}(f[i][j]+f[i][j-1]+f[i][j+1]+f[i+1][j])+1, 2 \leq j \leq n - 1\\
\end{align}
$$
初始化$f[n][j]=0$
那么我们发现这个方程是有后效性的。
其实行与行之间还是有无后效性的,但是列之间不满足。
考虑整理化简。我们将$f[i+1][]$看作常数,那么可以列一个方程组解出所有的$f[i][]$
即化成$a \cdot f[i][j] +b \cdot f[i][j - 1] + c \cdot f[i][j +1]=d+e \cdot f[i + 1][]$的形式
那么我们可以高斯消元解出来。
我们发现这里的每列最多三个有值,那么我们可以用特殊方法来求值,这里用倒三角的矩阵好做点,比直接求出简化阶梯矩阵好求,具体看代码
然后注意特判$m=1$即可,也可以直接输出$2n-1$,具体可以从高斯消元来推导,然后递推式转成封闭形式
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
const int MAXN = 1000 + 5;
db dp[MAXN][MAXN], M[MAXN][MAXN];
int n, m, x, y;
void gauss() {
for (int i = 1; i <= m; ++i) { // M[][m+1] 是常数项
db tmp = 1.0 / M[i][i];
M[i][i] *= tmp, M[i][m + 1] *= tmp;
if (i == m) break ;
M[i][i + 1] *= tmp;
tmp = M[i + 1][i] / M[i][i];
M[i + 1][i] -= tmp * M[i][i], M[i + 1][i + 1] -= tmp * M[i][i + 1], M[i + 1][m + 1] -= tmp * M[i][m + 1]; // 倒三角矩阵
}
for (int i = m - 1; i > 0; --i) M[i][m + 1] -= M[i + 1][m + 1] * M[i][i + 1]; // 逐一将后面的回代
}
void clean() {
}
int solve() {
clean();
cin >> n >> m >> x >> y;
for (int i = n - 1; i >= x; --i) {
M[1][1] = -2.0 / 3.0;
M[1][2] = 1.0 / 3.0;
M[1][m + 1] = -1.0 - 1.0 / 3.0 * dp[i + 1][1];
M[m][m] = -2.0 / 3.0;
M[m][m - 1] = 1.0 / 3.0;
M[m][m + 1] = -1.0 - 1.0 / 3.0 * dp[i + 1][m];
for (int j = 2; j < m; ++j) {
M[j][j - 1] = M[j][j + 1] = 1.0 / 4.0;
M[j][j] = -3.0 / 4.0;
M[j][m + 1] = -1.0 / 4.0 * dp[i + 1][j] - 1.0;
}
if (m == 1) {
M[1][1] = -1.0 / 2.0;
M[1][m + 1] = -1.0 / 2.0 * dp[i + 1][1] - 1;
}
gauss();
for (int j = 1; j <= m; ++j) dp[i][j] = M[j][m + 1];
//for (int j = 1; j <= m; ++j, puts(""))
//for (int k = 1; k <= m; ++k) printf("%.2f ", M[j][k]);
}
printf("%.8f\n", dp[x][y]);
return 0;
}
}
int main() {
flyinthesky::solve();
return 0;
}