「Bzoj 4566」「Haoi2016」找相同字符 (后缀数组 + 单调栈)

BZOJ 4566
题意:给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

一开始乱二分挂了……
其实这题就是Bzoj 3238 差异,我们可以想到连接两个字符串,中间用一个没出现过的字符连接,然后单调栈维护所有区间的最小值。但是这里是两个字符串,可能会有答案算到是同串的,那么我们考虑容斥

即求三次后缀数组,第一次求连接串的答案,然后求两个原串的答案,最后用连接串答案减掉后面的两个原串答案即为最终答案,正确性显然。

知识点
1、多次求后缀数组,不要将思维局限
2、不要乱在height上二分

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<queue>
#include<set>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 400000 + 5;

    char ch[2][MAXN];
    int a[MAXN], n, m;
    int SA[MAXN], rk[MAXN], tp[MAXN], tax[MAXN], height[MAXN];

    void build(int n) {//11?¨¬o¨®¡Áo¨ºy¡Á¨¦ 
        for (int i = 1; i <= m; ++i) tax[i] = 0;
        for (int i = 1; i <= n; ++i) tax[rk[i] = a[i]]++;
        for (int i = 1; i <= m; ++i) tax[i] += tax[i - 1];
        for (int i = n; i >= 1; --i) SA[tax[rk[i]]--] = i;//?¨´¨ºy??D¨°??¦Ì¨²¨°???

        for (int k = 1; k <= n; k <<= 1) {
            int p = 0;
            for (int i = n - k + 1; i <= n; i++) tp[++p] = i;//(n-k)~(n-1)?T¦Ì¨²?t1??¨¹¡Á?¡ê??¨´¨°???D¨°¨®|?????¨²?¡ã??
            for (int i = 1; i <= n; i++) if (SA[i] > k) tp[++p] = SA[i] - k;
            //??¨®DSA[i]>=k¦Ì?SA[i]2?¨º?¦Ì¨²?t1??¨¹¡Á?¦Ì????? 
            //¡ä¨®¨ª??D?¨¦¨°??¡ä3?¦Ì¨²¨°?1??¨¹¡Á?o¨ª¦Ì¨²?t1??¨¹¡Á?¦Ì??????¨¤2?k¡ê?1¨ºSA[i] - k 
            for (int i = 1; i <= m; ++i) tax[i] = 0;
            for (int i = 1; i <= n; ++i) tax[rk[tp[i]]]++;//x[tp[i]]?¨¤¦Ì¨¨¨®¨²????¦Ì¨²i¦Ì?¦Ì¨²?t1??¨¹¡Á?¦Ì?¦Ì¨²¨°?1??¨¹¡Á?¦Ì????? 
            for (int i = 1; i <= m; ++i) tax[i] += tax[i - 1];
            for (int i = n; i >= 1; --i) SA[tax[rk[tp[i]]]--] = tp[i];//¡À¡ê?¡è¨¢?¦Ì¨²¨°?1??¨¹¡Á?¦Ì??3D¨°?¨´??¦Ì¨²?t1??¨¹¡Á? 
            //?¨´¨ºy??D¨°¦Ì¨²¨°?1??¨¹¡Á?(rank[i]¦Ì?¨ºy?¦Ì)o¨ª¦Ì¨²?t1??¨¹¡Á?(tp[i]¦Ì???¡À¨º) 
            swap(rk, tp);//¡ä?¨º¡Àtp??¨®?¡ê??Y¡ä?¨¦?¨°???rank¦Ì??¦Ì
            p = 1, rk[SA[1]] = 1;
            for (int i = 2; i <= n; ++i) 
                rk[SA[i]] = (tp[SA[i]] == tp[SA[i - 1]] && tp[SA[i] + k] == tp[SA[i - 1] + k]) ? p : ++p;
            //??????¦Ì¨²i¦Ì?¨ºy¦Ì?rank¡ê?¡ã¡äsa?3D¨°?¨¹1?¡À¡ê?¡èrank¦Ì??y¨¨¡¤D?¡ê?¦Ì?¨º?¨°acmp?D??¨®?¨¦?¨°???¡Á?¡¤?¡ä??¨¤¦Ì¨¨¦Ì??¨¦?? 
            if (p >= n) break;//???|¡ê?¨°??-??¨®D???¡ä?a?? 
            m = p; 
        }
        int k = 0;//k¨º?¡À¨¨i-1?¡ã¨°???¦Ì?o¨®¡Áo
        for (int i = 1; i <= n; ++i) {//H[0], H[1], H[2] ...¦Ì??3D¨°???? 
            if (k) k--;//¡ä¨®k-1?a¨º?¡À¨¨?? ,??¨®??¨¢??H[i]>=H[i-1]-1, ¡Á?3¡è1?12?¡ã¡Áo¦Ì?3¡è?¨¨?¨¢¨¦¨´¨º?k-1(k = H[i-1])
            int j = SA[rk[i] - 1]; //?¡ã¨°???¦Ì?o¨®¡Áo???? 
            while (a[i + k] == a[j + k]) k++; //¨ª¨´o¨®¡À¨¨?? 
            height[rk[i]] = k;  //?¨¹D?¡äe¡ã? 
        }
    }

    LL L[MAXN], R[MAXN], st[MAXN], top = 0, ans = 0;

    void clean() {
        n = 0, m = 200, ms(height, 0), ms(SA, 0), ms(rk, 0), ms(tax, 0), ms(tp, 0), ms(L, 0), ms(R, 0), ms(st, 0);
    }
    int solve() {

        // 1
        clean();
        scanf("%s", ch[0] + 1);
        int tmp = strlen(ch[0] + 1);
        for (int i = 1; i <= tmp; ++i) a[++n] = ch[0][i] - 'a' + 1;

        a[++n] = '$';
        scanf("%s", ch[1] + 1);
        tmp = strlen(ch[1] + 1);
        for (int i = 1; i <= tmp; ++i) a[++n] = ch[1][i] - 'a' + 1;

        build(n);

        st[top = 1] = 1;
        for (int i = 2; i <= n; ++i) {
            while (top && height[st[top]] > height[i]) R[st[top--]] = i;
            L[i] = st[top];
            st[++top] = i;
        }
        while (top) R[st[top--]] = n + 1;
        for (int i = 1; i <= n; ++i) 
            ans += 1ll * (R[i] - i) * (i - L[i]) * height[i];

        // 2
        clean();
        tmp = strlen(ch[0] + 1);
        for (int i = 1; i <= tmp; ++i) a[++n] = ch[0][i] - 'a' + 1;
        build(n);

        st[top = 1] = 1;
        for (int i = 2; i <= n; ++i) {
            while (top && height[st[top]] > height[i]) R[st[top--]] = i;
            L[i] = st[top];
            st[++top] = i;
        }
        while (top) R[st[top--]] = n + 1;
        for (int i = 1; i <= n; ++i) 
            ans -= 1ll * (R[i] - i) * (i - L[i]) * height[i];

        // 3
        clean();
        tmp = strlen(ch[1] + 1);
        for (int i = 1; i <= tmp; ++i) a[++n] = ch[1][i] - 'a' + 1;
        build(n);

        st[top = 1] = 1;
        for (int i = 2; i <= n; ++i) {
            while (top && height[st[top]] > height[i]) R[st[top--]] = i;
            L[i] = st[top];
            st[++top] = i;
        }
        while (top) R[st[top--]] = n + 1;
        for (int i = 1; i <= n; ++i) 
            ans -= 1ll * (R[i] - i) * (i - L[i]) * height[i];

        printf("%lld\n", ans);

        return 0;
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}
------ 本文结束 ------