「Bzoj 1559」「JSOI2009」密码 (AC自动机+状压DP+爆搜结论)

BZOJ 1559
题意:见上。

很容易想到这是一个AC自动机(经典题,多模式串)上的状压DP(数据范围),那么列出方程即为$dp(i,st,u)$为密码前$i$位包含模式串状态$st$,在AC自动机上$u$点的方案。

那么转移很显然,即刷表
$$
dp(i+1,st|zt[v], v)= \sum dp(i,st,u)
$$

然后考虑如何输出方案。
因为这个数字小于等于$42$才输出方案,可以考虑爆搜(方案DP不好记录前驱)

重要结论

字符串都是紧密结合的,不存在自由的可以填$26$种字母的位置

证明很显然,因为如果存在这个,那么会对答案产生$26$的贡献,而字符串可以交换位置,那么一定方案数大于$42$,那么我们只需要预处理两个模式串组合最少的字符长度,$O(n!)$枚举模式串排列即可,注意要舍掉不合法的,即长度大于$L$的。

知识点:
1、发现有阻碍的地方,要从这里下手,在纸上应该多写写

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<queue>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 102;

    int n, m, ch[MAXN][26], f[MAXN], zt[MAXN], sz;
    char s[11][11];
    LL dp[30][1025][MAXN];

    void insert(char *s, int ith) {
        int now = 0, len = strlen(s);
        for (int i = 0; i < len; ++i) {
            int c = s[i] - 'a';
            if (!ch[now][c]) ch[now][c] = ++sz;
            now = ch[now][c];
            if (i == len - 1) zt[now] = (1 << (ith - 1));
        }
    }
    void getFail() {
        queue<int > q;
        f[0] = 0;
        for (int c = 0; c < 26; ++c) {
            int v = ch[0][c];
            if (v) q.push(v), f[v] = 0;
        }
        while (!q.empty()) {
            int u = q.front(); q.pop();
            for (int c = 0; c < 26; ++c) {
                int v = ch[u][c];
                if (!v) {ch[u][c] = ch[f[u]][c]; continue ;}
                q.push(v);
                int j = f[u];
                while (j && !ch[j][c]) j = f[j];
                f[v] = ch[j][c], zt[v] |= zt[f[v]];
            }
        }
    }
    int gg[20], vis[20], orz[20][20], tot = 0;
    string pt[50], tmp;
    void dfs(int a) {
        if (a == m + 1) {
            tmp.clear();
            for (int i = 1; i <= m; ++i) {
                int st = 0, whw = strlen(s[gg[i]]);
                if (i != 1) st = orz[gg[i - 1]][gg[i]];
                for (int j = st; j < whw; ++j) {
                    tmp += s[gg[i]][j];
                }
            }
            if ((LL)tmp.size() == n) pt[++tot] = tmp;
            return ; 
        }
        for (int i = 1; i <= m; ++i) {
            if (!vis[i]) gg[a] = i, vis[i] = 1, dfs(a + 1), vis[i] = 0, gg[a] = 0;
        }
    }

    void clean() {
        ms(ch, 0), ms(f, 0), ms(zt, 0), sz = 0, ms(dp, 0), ms(vis, 0), ms(gg, 0), ms(orz, 0);
    }
    int solve() {

        clean();
        cin >> n >> m;
        for (int i = 1; i <= m; ++i) {
            scanf("%s", s[i]);
            insert(s[i], i);
        }
        getFail();

        dp[0][0][0] = 1;
        for (int i = 0; i <= n; ++i) {
            for (int st = 0; st < (1 << m); ++st) {
                for (int u = 0; u <= sz; ++u) if (dp[i][st][u]) {
                    for (int c = 0; c < 26; ++c) {
                        int v = ch[u][c];
                        dp[i + 1][st | zt[v]][v] += dp[i][st][u];
                    }
                }
            }
        }

        LL ans = 0;
        for (int u = 0; u <= sz; ++u) ans += dp[n][(1 << m) - 1][u];
        printf("%lld\n", ans);

        if (ans <= 42) {
            for (int x = 1; x <= m; ++x) {
                for (int y = 1; y <= m; ++y) if (x != y) {
                    int lx = strlen(s[x]), ly = strlen(s[y]);
                    for (int len = min(lx, ly); len >= 0; --len) {
                        int fl = 0;
                        for (int i = 0; i < len; ++i) {
                            int j = lx - len + i;
                            if (s[x][j] != s[y][i]) {fl = 1; break;} 
                        }
                        if (!fl) {
                            orz[x][y] = len;
                            break;
                        }
                    }
                }
            }
            dfs(1);
            sort(pt + 1, pt + 1 + ans);
            for (int i = 1; i <= ans; ++i) cout << pt[i] << endl;
        }

        return 0;
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
/*
10 2
abcd
bc

2 2
a
b

1 1
a

6 2
good
day

3 2
go
oe
*/
------ 本文结束 ------