「Bzoj 4665」小w的喜糖 (容斥DP)

BZOJ 4665
题意:求有重复元素排列的错排。

考虑将所有元素都看成不同的(重复元素也看做),然后最后再除以各自的个数消序即可。

然后我们考虑设$F(i)$为至少有$i$个位置上是原来的数,那么我们可以利用这个来容斥,即
$$
\text{ans}=\sum_{i=0}^n F(i) \cdot (-1)^i
$$

我们考虑怎么求这个$F(i)$,设$dp(i,j)$为前$i$个值全部考虑完,至少有$j$个位置上是原来的数的方案,那么

$$
dp(i,j)=\sum_{k=0}^{min(j,cnt[i])}dp(i-1,j-k) \cdot C_{cnt[i]}^{k} \cdot \prod_{p=cnt[i]-k+1}^{cnt[i]}p
$$

其中$cnt[i]$为值$i$的元素个数。
那么我们可以得到
$$
F(i)=dp(n,i)\cdot (n-i)!
$$
那么就能容斥了。

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<queue>
#include<set>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const LL MO = 1000000009, MAXN = 2000 + 5;

    LL n, cnt[MAXN], jc[MAXN], jc_inv[MAXN], dp[MAXN][MAXN], f[MAXN];

    LL ksm(LL a, LL b) {
        LL bs = a, ans = 1;
        while (b) {
            if (b & 1) ans = ans * bs % MO;
            bs = bs * bs % MO;
            b >>= 1;
        }
        return ans;
    }
    LL C(LL n, LL m) {
        if (m > n) return 0;
        return jc[n] * jc_inv[m] % MO * jc_inv[n - m] % MO;
    }

    void clean() {
    }
    int solve() {

        clean();
        cin >> n;
        jc[0] = 1, jc_inv[0] = 1;
        for (LL i = 1; i <= n; ++i) jc[i] = jc[i - 1] * i % MO, jc_inv[i] = ksm(jc[i], MO - 2);
        for (LL x, i = 1; i <= n; ++i) scanf("%lld", &x), ++cnt[x];

        dp[0][0] = 1;
        for (LL i = 1; i <= n; ++i) {
            for (LL j = 0; j <= n; ++j) {
                for (LL k = 0; k <= min(j, cnt[i]); ++k) {
                    dp[i][j] = (dp[i][j] + dp[i - 1][j - k] * C(cnt[i], k) % MO * jc[cnt[i]] % MO * jc_inv[cnt[i] - k] % MO) % MO;
                }
            }
        }

        for (LL i = 0; i <= n; ++i) f[i] = dp[n][i] * jc[n - i] % MO;

        LL ans = 0, hh = 1;
        for (LL i = 0; i <= n; ++i) {
            ans = (ans + hh * f[i] % MO) % MO;
            hh = (hh * -1 + MO) % MO;
        }

        for (LL i = 1; i <= n; ++i) ans = (ans * jc_inv[cnt[i]]) % MO;

        cout << ans;

        return 0;
    } 
}
int main() {
    flyinthesky::solve();
    return 0;
}
------ 本文结束 ------