poj 1185(状压DP)

poj 1185

本题很像这题,唯一不同的是一个物品会影响周围两个方格,那么我们设$dp(i, st(j), st(k))$为第$i$行用状态$k$,第$i-1$行用状态$j$的最优值。那么转移方程即为:
$$dp(i, st(j), st(k)) = max(dp(i, st(t), st(j)) + num(k))$$
其中$t$是与$k$不冲突的所有状态。初始化要初始化$num[i]$(st[i]中的$1$的个数),我们可以用x&(x-1)来快速消掉$x$最后的$1$,算出$num$。判断是否冲突可以用i&(i<<1), i& (i<<2)来判断,可以视为一个平移的过程,仔细思考可以发现。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
using namespace std;
const int MAXN = 100 + 5, MAXM = 10 + 2;
int n, m, top, st[100], cur[MAXN], dp[MAXN][100][100], num[MAXN];
char map[MAXN][MAXM];
void querynum(int i) {
    int t = st[i];
    while (t) {
        num[i]++;
        t&=(t-1);
    }
}
bool fit(int j, int h) {return (st[j]&cur[h]) ? false : true;}
void clean() {
    top = 0, ms(cur, 0), ms(dp, -1), ms(num, 0);
}
void solve() {
    clean();
    for (int i=0;i<(1<<m);i++) {
        if ( (i& (i<<1) ) || (i& (i<<2) ) ) continue; else st[++top] = i;
    }
    for (int i=1;i<=n;i++) scanf("%s", map[i]+1);
    for (int i=1;i<=n;i++) for (int j=1;j<=m;j++) {
        if (map[i][j]=='H') cur[i] += (1<<(j-1));
    }
    for (int i=1;i<=top;i++) {
        querynum(i);
        if (fit(i, 1)) dp[1][1][i] = num[i];
    }
    for (int hi=2;hi<=n;hi++) {
        for (int i=1;i<=top;i++) {
            if (fit(i, hi)) {
                for (int j=1;j<=top;j++) {
                    if (st[i]&st[j]) continue;
                    for (int k=1;k<=top;k++) {
                        if (st[i]&st[k]) continue;
                        if (dp[hi-1][k][j]==-1) continue;
                        dp[hi][j][i] = max(dp[hi][j][i], dp[hi-1][k][j] + num[i]);
                    }
                }
            }
        }
    }
    int ans = 0;
    for (int hi=1;hi<=n;hi++) {
        for (int i=1;i<=top;i++) {
            for (int j=1;j<=top;j++) {
                ans = max(ans, dp[hi][i][j]);
            }
        }
    }
    printf("%d\n", ans);
}
int main() {
    #ifndef ONLINE_JUDGE 
    freopen("1.in", "r", stdin);freopen("1.out", "w", stdout);
    #endif
    while (scanf("%d%d", &n, &m)==2) solve();
    return 0;
}
------ 本文结束 ------