BZOJ 4827
题意:给定数列$a,b, b$可以循环移动,选择整数$c$,求
$$
\min\sum\limits_{i=1}^n(a_i-b_i+c)^2
$$
化简式子,可得
$$
\sum_{i=1}^na_i^2+\sum_{i=1}^nb_i^2-2\sum_{i=1}^na_ib_i+nc^2+2c\sum_{i=1}^n(a_i-b_i)
$$
前面两项是定值,后面两项可以看作一个关于$c$的二次函数,显然有最小值,求最近对称轴$\frac{\sum\limits_{i=1}^n (a_i - b_i)}{n}$的两个点的最小值加入答案。
现在问题变为求$\min\sum\limits_{i=1}^na_ib_i$
我们将环断为链,复制一份在后面
考虑平移的情况
$$
\min\sum\limits_{i=1}^na_{i}b_{i+k}
$$
我们可以转化一下,翻转$a$数组,得
$$
\sum\limits_{i=1}^na_{n-i+1}b_{i+k}
$$
那么两项下标为常数,是卷积的形式,FFT求卷积即可,最后扫描一遍求最大值,答案减去这个最大值。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<complex>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace flyinthesky {
const LL MAXN = 200000 + 5;
const db PI = acos(-1);
LL ans = 0, n, m, a[MAXN], b[MAXN], r[MAXN], ret[MAXN];
complex<db > f[MAXN], g[MAXN];
void FFT(complex<db > *a, int op) {
for (int i = 0; i < n; ++i)
if (i < r[i]) swap(a[i], a[r[i]]);
for (int i = 1; i < n; i *= 2) {
complex<db > Wn(cos(PI / i), sin(PI / i) * op);
for (int j = 0; j < n; j += i * 2) {
complex<db > w(1, 0), *a0 = a + j, *a1 = a0 + i;
for (int k = 0; k < i; ++k) {
complex<db > t = *a1 * w;
*a1 = *a0 - t, *a0 += t;
w *= Wn, ++a0, ++a1;
}
}
}
}
void clean() {
}
int solve() {
clean();
cin >> n >> m;
for (LL i = 1; i <= n; ++i) scanf("%lld", &a[i]), ans += a[i] * a[i];
for (LL i = 1; i <= n; ++i) scanf("%lld", &b[i]), ans += b[i] * b[i];
LL t = 0;
for (LL i = 1; i <= n; ++i) t += a[i] - b[i];
LL c1 = (LL)floor(-(db)t / (db)n), c2 = (LL)ceil(-(db)t / (db)n);
ans += min(n * c1 * c1 + 2ll * c1 * t, n * c2 * c2 + 2ll * c2 * t);
reverse(a + 1, a + 1 + n);
for (LL i = 1; i <= n; ++i) b[n + i] = b[i];
for (LL i = 1; i <= n; ++i) f[i] = (db)a[i];
for (LL i = 1; i <= 2 * n; ++i) g[i] = (db)b[i];
int l = 0;
for (m = n * 2, n = 1; n <= m; ++l, n <<= 1);
for (LL i = 1; i <= n; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
FFT(f, 1), FFT(g, 1);
for (LL i = 0; i < n; ++i) f[i] *= g[i];
FFT(f, -1);
for (LL i = 0; i < n; ++i) ret[i] = (int)(fabs(f[i].real()) / (db)n + 0.5);
LL whw = 0;
for (LL i = 0; i < m / 2; ++i) whw = max(whw, ret[m / 2 + i + 1]);
printf("%lld\n", ans - 2ll * whw);
return 0;
}
}
int main() {
flyinthesky::solve();
return 0;
}