Hdu 2825(AC自动机+状压DP)

hdu 2825
题意:给出$m$个模式串,求至少包含$k$个模式串长为$n$的主串个数。
用模式串建立AC自动机,设$dp(i,j,S)$为主串前$i$个字符,在AC自动机上$j$点,当前存在模式串状态的方案数。
$$dp(i,v,S|val_v)=dp(i-1,j,S)$$
因为有用的只有两层,所以其中可以运用滚动数组节省空间

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MO = 20090717, MAXN = 10 * 10 + 5;
int n, m, k, sz = 10 * 10;
char s[15];
int ch[MAXN][35], val[MAXN], f[MAXN], dp[2][MAXN][(1 << 10) + 100];
void insert(char *s, int ith) {
    int now = 0, len = strlen(s);
    for (int i = 0; i < len; i++) {
        int c = s[i] - 'a';
        if (!ch[now][c]) ch[now][c] = ++sz;
        now = ch[now][c];
        if (i == len - 1) val[now] += (1 << (ith - 1));
    }
}
void getFail() {
    queue<int> q;
    f[0] = 0;
    for (int c = 0; c < 26; c++) {
        int v = ch[0][c];
        if (v) q.push(v), f[v] = 0;
    }
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int c = 0; c < 26; c++) {
            int v = ch[u][c];
            if (!v) {ch[u][c] = ch[f[u]][c]; continue;}
            q.push(v);
            int j = f[u]; while (j && !ch[j][c]) j = f[j];
            f[v] = ch[j][c];
            val[v] |= val[f[v]];//注意传递 
        }
    }
}
void cal() {
    dp[0][0][0] = 1;
    int x = 1;//滚动数组当前位置 
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j <= sz; j++)
        for (int S = 0; S < (1 << m); S++) dp[x][j][S] = 0;
        for (int j = 0; j <= sz; j++) {
            for (int S = 0; S < (1 << m); S++) {
                if (!dp[x ^ 1][j][S]) continue;
                for (int c = 0; c < 26; c++) {
                    int v = ch[j][c];
                    dp[x][v][S | val[v]] = (dp[x][v][S | val[v]] + dp[x ^ 1][j][S]) % MO;//方程不要写错了 
                }
            }
        } 
        x ^= 1;
    }
}
void clean() {
    for (int i = 0; i <= sz; i++) {
        for (int j = 0; j < 28; j++) ch[i][j] = 0;
        for (int j = 0; j < 1030; j++) dp[1][i][j] = dp[0][i][j] = 0;
        val[i] = f[i] = 0;
    }
    sz = 0;
}
bool check(int x) {
    int ret = 0, tmp = x;
    do {
        ret += tmp & 1;
        tmp >>= 1;
    } while (tmp != 0);
    return ret >= k;
}
void solve() {
    clean();
    for (int i = 1; i <= m; i++) {
        scanf("%s", s);
        insert(s, i);
    }
    getFail(), cal();
    int taki = 0;
    for (int j = 0; j <= sz; j++) {
        for (int S = 0; S < (1 << m); S++) {
            if (check(S)) taki = (taki + dp[n % 2][j][S]) % MO;
        }
    }
    printf("%d\n", taki); 
}
int main() {
    while (scanf("%d%d%d", &n, &m, &k) == 3 && (n || m || k)) solve();
    return 0;
}
------ 本文结束 ------