654 字
3 分钟
BZOJ3509: [CodeChef] COUNTARI
Description
给定一个长度为N的数组A[],求有多少对i, j, k(1<=i<j<k<=N)满足A[k]-A[j]=A[j]-A[i]。
Input
第一行一个整数N(N<=10^5)。 接下来一行N个数A[i](A[i]<=30000)。
Output
一行一个整数。
Sample Input
10
3 5 3 6 3 4 10 4 5 2
Sample Output
9
分块FFT
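先分块, 暴力求出三个下标都在同一块、以及恰有两个下标在同一块的答案;
三个下标都不在同一块的情况: 枚举中间下标所在的块, 分别统计该块左侧和右侧每个值出现的次数, 用FFT(代码中实际使用NTT)求这两个计数数组的卷积, 再对块内每个 A[j] 累加卷积在 2*A[j] 处的系数。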
// BZOJ3509 / CodeChef COUNTARI: count index triples i < j < k with
// A[k] - A[j] == A[j] - A[i], i.e. A[i] + A[k] == 2 * A[j].
//
// Approach (sqrt decomposition + convolution):
//   1. j and k in the same block, i anywhere before j      -> brute force.
//   2. i and j in the same block, k in a later block        -> brute force.
//   3. i, j, k in three different blocks: for every middle block, convolve the
//      value histogram of everything to its left with the histogram of
//      everything to its right and look up coefficient 2 * A[j].
//
// The convolution is a number-theoretic transform. A single coefficient can be
// as large as (n / 2)^2 (about 2.5e9 for n = 1e5), which exceeds one ~1e9
// prime, so the product is computed modulo two NTT primes and the exact value
// is recovered with the Chinese remainder theorem.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>

const int MAXN = 1000000;      // capacity of the per-element arrays
const int MAX_VALUE = 30000;   // largest element value allowed by the statement
const int MOD = 998244353;     // NTT prime 119 * 2^23 + 1, primitive root 3
const int MOD2 = 1004535809;   // NTT prime 479 * 2^21 + 1, primitive root 3

// Reads one optionally signed decimal integer from stdin, skipping any leading
// non-digit characters. Returns 0 if end of input is reached before a digit.
inline int read()
{
    int x = 0, f = 1;
    int ch = std::getchar();  // int, not char, so EOF can be detected
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * f;
}

// Returns a^b mod P by binary exponentiation (b >= 0, P < 2^31 so that the
// intermediate products fit in a long long).
long long pow_mod(long long a, long long b, long long P)
{
    long long ans = 1;
    a %= P;
    while (b) {
        if (b & 1) ans = ans * a % P;
        b >>= 1;
        a = a * a % P;
    }
    return ans;
}

long long N;     // transform length (a power of two)
int rev[MAXN];   // bit-reversal permutation for length N

// In-place iterative NTT of a[0..N) modulo the prime P (primitive root 3).
// op == 1 is the forward transform; op == -1 is the inverse transform,
// including the final division by N. Entries of a must lie in [0, P).
void FFt(long long *a, int op, long long P = MOD)
{
    for (int i = 0; i < N; i++)
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int k = 2; k <= N; k <<= 1) {
        const int half = k >> 1;
        // Principal k-th root of unity, or its inverse for the inverse transform.
        const long long wn =
            pow_mod(3, op == 1 ? (P - 1) / k : P - 1 - (P - 1) / k, P);
        for (int j = 0; j < N; j += k) {
            long long w = 1;
            for (int i = 0; i < half; i++, w = w * wn % P) {
                const long long t = a[i + j + half] * w % P;
                a[i + j + half] = (a[i + j] - t + P) % P;
                a[i + j] = (a[i + j] + t) % P;
            }
        }
    }
    if (op == -1) {
        const long long inv_n = pow_mod(N % P, P - 2, P);
        for (int i = 0; i < N; i++) a[i] = a[i] * inv_n % P;
    }
}

int Sum1[MAXN], Sum2[MAXN];   // value histograms used by the brute-force passes
int W[MAXN];                  // input array, 0-indexed
int l[MAXN], r[MAXN];         // inclusive bounds of block i (blocks are 1-indexed)
long long A[65537], B[65537];     // left / right histograms, transformed mod MOD
long long A2[65537], B2[65537];   // the same histograms, transformed mod MOD2

// Combines r1 (a residue mod MOD) and r2 (a residue mod MOD2) into the unique
// x in [0, MOD * MOD2) congruent to both. inv_mod is MOD^{-1} modulo MOD2.
static long long crt(long long r1, long long r2, long long inv_mod)
{
    const long long k = ((r2 - r1) % MOD2 + MOD2) % MOD2 * inv_mod % MOD2;
    return r1 + MOD * k;
}

// Multiplies the polynomials x and y (length N, coefficients in [0, P)) modulo
// P; the product is left in x and y is left in transformed form.
static void convolve(long long *x, long long *y, long long P)
{
    FFt(x, 1, P);
    FFt(y, 1, P);
    for (int j = 0; j < N; j++) x[j] = x[j] * y[j] % P;
    FFt(x, -1, P);
}

// Counts triples i < j < k in w[0..n) with w[i] + w[k] == 2 * w[j].
// Every w[i] must lie in [0, MAX_VALUE] and n must not exceed MAXN.
static long long count_triples(const int *w, int n)
{
    if (n < 3) return 0;  // no triple exists; also avoids a zero block length

    const int len =
        std::min(static_cast<int>(std::sqrt(static_cast<double>(n))) * 6, n);
    const int num = (n + len - 1) / len;
    for (int i = 1; i <= num; i++) {
        l[i] = (i - 1) * len;
        r[i] = std::min(l[i] + len - 1, n - 1);
    }

    long long ans = 0;

    // Pass 1: j and k share a block; Sum1 holds the values at all indices < j.
    std::fill(Sum1, Sum1 + 2 * MAX_VALUE + 1, 0);
    for (int i = 1; i <= num; i++) {
        for (int j = l[i]; j <= r[i]; j++) {
            for (int k = j + 1; k <= r[i]; k++) {
                const int need = 2 * w[j] - w[k];  // required value of the first element
                if (need >= 0) ans += Sum1[need];
            }
            Sum1[w[j]]++;
        }
    }

    // Pass 2: the first two indices share a block and the third lies in a
    // later block; Sum2 holds the values of all blocks to the right.
    std::fill(Sum2, Sum2 + 2 * MAX_VALUE + 1, 0);
    for (int i = num; i >= 1; i--) {
        for (int j = l[i]; j <= r[i]; j++) {
            for (int k = j + 1; k <= r[i]; k++) {
                const int need = 2 * w[k] - w[j];  // required value of the third element
                if (need >= 0) ans += Sum2[need];
            }
        }
        for (int j = l[i]; j <= r[i]; j++) Sum2[w[j]]++;
    }

    if (num < 3) return ans;  // no block has blocks on both sides

    // Pass 3: all three indices in different blocks.
    N = 65536;  // > 2 * MAX_VALUE, so the cyclic convolution does not wrap
    rev[0] = 0;
    for (int i = 1; i < N; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? static_cast<int>(N >> 1) : 0);
    const long long inv_mod = pow_mod(MOD, MOD2 - 2, MOD2);

    for (int i = 2; i < num; i++) {
        std::memset(A, 0, sizeof(A));
        std::memset(B, 0, sizeof(B));
        for (int j = 0; j < l[i]; j++) A[w[j]]++;
        for (int j = r[i] + 1; j < n; j++) B[w[j]]++;
        std::memcpy(A2, A, sizeof(A));
        std::memcpy(B2, B, sizeof(B));

        convolve(A, B, MOD);
        convolve(A2, B2, MOD2);

        for (int j = l[i]; j <= r[i]; j++) {
            const int s = 2 * w[j];
            ans += crt(A[s], A2[s], inv_mod);  // exact pair count, not a residue
        }
    }
    return ans;
}

int main()
{
    const int n = read();
    if (n > MAXN) {
        std::fprintf(stderr, "n = %d exceeds the supported maximum %d\n", n, MAXN);
        return 1;
    }
    for (int i = 0; i < n; i++) {
        W[i] = read();
        // Values index the histogram arrays directly, so reject anything
        // outside the range they were sized for.
        if (W[i] < 0 || W[i] > MAX_VALUE) {
            std::fprintf(stderr, "A[%d] = %d is outside [0, %d]\n", i + 1, W[i], MAX_VALUE);
            return 1;
        }
    }
    std::printf("%lld\n", count_triples(W, n));
    return 0;
}
BZOJ3509: [CodeChef] COUNTARI
https://www.nekomio.com/posts/127/