654 字
3 分钟
BZOJ3509: [CodeChef] COUNTARI
Description
给定一个长度为N的数组A[],求有多少对i, j, k(1<=i<j<k<=N)满足A[k]-A[j]=A[j]-A[i]。
Input
第一行一个整数N(N<=10^5)。 接下来一行N个数A[i](A[i]<=30000)。
Output
一行一个整数。
Sample Input
10
3 5 3 6 3 4 10 4 5 2
Sample Output
9
分块FFT
先分块，暴力求出三个下标都在同一块、以及恰有两个下标在同一块的答案；
三个下标分属不同块的情况用 FFT：枚举中间下标 j 所在的块，把该块左侧和右侧元素的值域计数数组做卷积，再对块内每个 j 累加卷积结果中下标为 2*A[j] 的那一项。
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
/**
 * Reads the next decimal integer from stdin.
 *
 * Every character before the first digit is skipped; if any of the skipped
 * characters is '-', the result is negated. The character that terminates
 * the digit run is consumed as well.
 *
 * @return the parsed value, or 0 when end of input is reached before any
 *         digit is found.
 */
inline int read()
{
    int value = 0;
    int sign = 1;
    // getchar() returns int so that EOF is distinguishable from every real
    // character; storing it in a char made the skip loop spin forever at EOF.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-')
            sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Capacity of the global int tables declared in this file (rev, Sum1, Sum2, W, l, r).
const int MAXN = 1e6;
// Modulus for all pow_mod/FFt arithmetic; FFt derives its roots of unity from powers of 3 modulo this value.
const int MOD = 998244353;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a base; any value is accepted, it is first reduced into [0, P).
 * @param b exponent; expected to be non-negative. A negative exponent is
 *          treated like 0 instead of looping forever.
 * @param P modulus; must be positive and small enough that (P - 1)^2 fits
 *          in a long long (roughly P <= 3.03e9).
 * @return a^b mod P, in [0, P). In particular the result is 0 when P == 1.
 */
long long pow_mod(long long a, long long b, long long P)
{
    // Reduce the base up front: otherwise a * a overflows for large bases,
    // and negative bases would yield negative remainders.
    a %= P;
    if (a < 0)
        a += P;

    long long ans = 1 % P;  // 1 in general, 0 for the degenerate modulus 1
    while (b > 0)
    {
        if (b & 1)
            ans = ans * a % P;
        b >>= 1;
        a = a * a % P;
    }
    return ans;
}
// Inv: factor FFt multiplies into every element after an inverse transform
//      (main sets it to pow_mod(N, MOD - 2, MOD)).
// N:   transform length used by FFt (main sets it to 65536).
long long Inv, N;
// Index permutation FFt applies before its butterfly passes; filled in by main.
int rev[MAXN];
/**
 * In-place transform of a[0..N) modulo MOD, built from butterfly passes whose
 * twiddle factors are powers of 3 modulo MOD.
 *
 * Relies on the globals N (transform length), rev (index permutation applied
 * first) and Inv (scaling factor for the inverse direction).
 *
 * @param a  buffer holding at least N values; overwritten with the result.
 * @param op 1 for the forward transform; -1 for the inverse transform, which
 *           additionally multiplies every element by Inv.
 */
void FFt(long long *a, int op)
{
    // Reorder the input according to the precomputed permutation.
    for (int idx = 0; idx < N; ++idx)
    {
        const int partner = rev[idx];
        if (idx < partner)
            swap(a[idx], a[partner]);
    }

    // Butterfly passes over spans of length 2, 4, ..., N.
    for (int span = 2; span <= N; span <<= 1)
    {
        const int half = span >> 1;
        // The inverse direction uses the complementary exponent, i.e. the
        // modular inverse of the forward root.
        const long long root =
            pow_mod(3, op == 1 ? (MOD - 1) / span : MOD - 1 - (MOD - 1) / span, MOD);

        for (int base = 0; base < N; base += span)
        {
            long long twiddle = 1;
            for (int off = 0; off < half; ++off)
            {
                long long &lo = a[base + off];
                long long &hi = a[base + off + half];
                const long long prod = hi * twiddle % MOD;
                hi = (lo - prod + MOD) % MOD;
                lo = (lo + prod) % MOD;
                twiddle = twiddle * root % MOD;
            }
        }
    }

    if (op != -1)
        return;

    // Inverse direction only: scale every element by Inv.
    for (int idx = 0; idx < N; ++idx)
        a[idx] = a[idx] * Inv % MOD;
}
// Per-value occurrence counters used by main's two brute-force passes.
int Sum1[MAXN], Sum2[MAXN];
// Input sequence, stored 0-indexed.
int W[MAXN];
// l[b] / r[b]: first and last index (inclusive) of block b; blocks are numbered from 1.
int l[MAXN], r[MAXN];
// Value-frequency buffers that main transforms with FFt and multiplies pointwise.
long long A[65537], B[65537];
/**
 * Reads n and the sequence W[0..n), then prints the number of index triples
 * i < j < k with W[i] + W[k] == 2 * W[j].
 *
 * The sequence is cut into blocks and every triple is counted exactly once:
 *   1. j and k in the same block (i anywhere before j)   -> brute force.
 *   2. i and j in the same block, k in a later block     -> brute force.
 *   3. i, j, k in three different blocks                 -> for each middle
 *      block, convolve the value histogram of everything to its left with
 *      the histogram of everything to its right and read entry 2 * W[j].
 *
 * The fixed-size tables are not bounds-checked: values are expected to lie
 * in [0, 30000] (A/B hold 65537 entries and are indexed by W[j] and
 * 2 * W[j]) and n must not exceed MAXN.
 */
int main()
{
    const int n = read();
    // Fewer than three elements cannot form a triple. Bailing out here also
    // prevents a zero block length (division by zero) for n == 0.
    if (n < 3)
    {
        printf ("%lld\n", 0LL);
        return 0;
    }
    for (int i = 0; i < n; i++)
        W[i] = read();

    // Block length and block count (last block may be shorter).
    const int len = min(static_cast<int>(sqrt(static_cast<double>(n))) * 6, n);
    int num = n / len;
    if (n % len) num++;
    long long ans = 0;
    for (int i = 1; i <= num; i++)
    {
        l[i] = (i - 1) * len;
        r[i] = min(l[i] + len - 1, n - 1);
    }

    // Case 1: j < k inside one block. Sum1 counts every value seen strictly
    // before position j, so Sum1[2*W[j] - W[k]] is the number of valid i.
    for (int i = 1; i <= num; i++)
    {
        for (int j = l[i]; j <= r[i]; j++)
        {
            for (int k = j + 1; k <= r[i]; k++)
                if (2 * W[j] - W[k] >= 0)
                    ans += Sum1[2 * W[j] - W[k]];
            Sum1[W[j]]++;
        }
    }

    // Case 2: first and middle element (j < k here) inside one block, third
    // element in a later block. Blocks are processed right to left and Sum2
    // only contains values from blocks already processed.
    for (int i = num; i >= 1; i--)
    {
        for (int j = l[i]; j <= r[i]; j++)
        {
            for (int k = j + 1; k <= r[i]; k++)
                if (2 * W[k] - W[j] >= 0)
                    ans += Sum2[2 * W[k] - W[j]];
        }
        for (int j = l[i]; j <= r[i]; j++)
            Sum2[W[j]]++;
    }

    // Set up the transform: length, inverse of the length, index permutation.
    N = 65536;
    Inv = pow_mod(N, MOD - 2, MOD);
    for (int i = 1; i < N; i++)
        if (i & 1)
            rev[i] = (rev[i >> 1] >> 1) | (N >> 1);
        else
            rev[i] = (rev[i >> 1] >> 1);

    // Case 3: all three elements in different blocks; i is the middle block.
    for (int i = 2; i < num; i++)
    {
        const int left = l[i];           // number of elements before the block
        const int right = n - 1 - r[i];  // number after it; >= 1 because i < num

        // Histogram of the right side, transformed once per block.
        memset (B, 0, sizeof (B));
        for (int j = r[i] + 1; j < n; j++) B[W[j]]++;
        FFt(B, 1);

        // The transform only recovers coefficients modulo MOD, but a
        // coefficient can be as large as left * right (about 2.4e9 for
        // n = 1e5), which exceeds MOD. Feeding the left side in chunks of at
        // most (MOD - 1) / right elements keeps every partial coefficient
        // below MOD, so each one is recovered exactly. When
        // left * right < MOD this is a single chunk.
        const int chunk = (MOD - 1) / right;
        for (int start = 0; start < left; )
        {
            const int stop = (left - start > chunk) ? start + chunk : left;

            memset (A, 0, sizeof (A));
            for (int j = start; j < stop; j++) A[W[j]]++;
            FFt(A, 1);
            // Reduce the pointwise product so FFt receives values in [0, MOD).
            for (int j = 0; j < N; j++) A[j] = A[j] * B[j] % MOD;
            FFt(A, -1);
            for (int j = l[i]; j <= r[i]; j++) ans += A[2 * W[j]];

            start = stop;
        }
    }
    printf ("%lld\n", ans);
    return 0;
}
BZOJ3509: [CodeChef] COUNTARI
https://www.nekomio.com/posts/127/