BZOJ3509: [CodeChef] COUNTARI

Description

给定一个长度为N的数组A[],求有多少对i, j, k(1<=i<j<k<=N)满足A[k]-A[j]=A[j]-A[i]。

Input

第一行一个整数N(N<=10^5)。
接下来一行N个数A[i](A[i]<=30000)。

Output

一行一个整数。

Sample Input

10
3 5 3 6 3 4 10 4 5 2

Sample Output

9

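分块 + FFT(代码中用模 998244353 的 NTT 实现卷积)
先分块,暴力统计三个下标在同一块、以及恰有两个下标在同一块(第三个下标在该块之前或之后)的答案;
三个下标分属三个不同块的情况:枚举中间下标 j 所在的块,用 FFT 求出该块左侧与右侧数值出现次数的卷积,再对块内每个 j 累加卷积在 2*A[j] 处的值。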

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
/**
 * Reads the next decimal integer from standard input.
 *
 * Every character that is not a digit is skipped; a '-' seen while skipping
 * makes the result negative. Digits are then accumulated until the first
 * non-digit, which is consumed as well.
 *
 * @return the parsed value, or 0 when end of input is reached before any digit.
 */
inline int read()
{
    int value = 0, sign = 1;
    // getchar() returns an int so that EOF can be told apart from real
    // characters. Stopping on EOF keeps the skip loop from spinning forever
    // once the input is exhausted.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Capacity of the counting tables, the input array, the block boundary
// arrays and the bit-reversal table declared in this file.
const int MAXN = 1000000;
// Modulus for all modular arithmetic in pow_mod and FFt.
const int MOD = 998244353;
/**
 * Computes a^b mod P by binary exponentiation.
 *
 * @param a base; any value, it is reduced into [0, P) before it is used
 * @param b exponent; expected to be non-negative (a negative exponent is not
 *          supported and yields 1 % P)
 * @param P modulus; must be positive and small enough that (P - 1)^2 fits in
 *          a long long
 * @return a^b mod P in the range [0, P); in particular 0 when P == 1
 */
long long pow_mod(long long a, long long b, long long P)
{
    // Reduce the base first so that the products below cannot overflow for
    // large or negative inputs.
    a %= P;
    if (a < 0) a += P;
    long long ans = 1 % P; // 1 % P is 0 when P == 1
    while (b > 0)
    {
        if (b & 1) ans = ans * a % P;
        b >>= 1;
        a = a * a % P;
    }
    return ans;
}
// Transform length used by FFt; main sets it before the first transform.
long long N;
// Factor applied to every element by FFt when op == -1; main sets it to the
// modular inverse of N.
long long Inv;
// rev[i] is the position element i is exchanged with at the start of FFt;
// main fills it as a bit-reversal table for length N.
int rev[MAXN];
/**
 * In-place modular transform of length N (despite the name this is a
 * number-theoretic transform modulo MOD, built on powers of 3).
 *
 * Relies on the globals N (transform length), rev (permutation table) and,
 * for the inverse direction, Inv; all of them must be set up by the caller.
 * Entries of a are expected to be non-negative and small enough that the
 * products below do not overflow.
 *
 * @param a  buffer with at least N elements, overwritten with the result
 * @param op 1 for the forward transform; -1 for the inverse transform, which
 *           additionally multiplies every element by Inv
 */
void FFt(long long *a, int op)
{
    // Reorder the input according to the precomputed permutation.
    for (int idx = 0; idx < N; ++idx)
    {
        const int partner = rev[idx];
        if (idx < partner)
            std::swap(a[idx], a[partner]);
    }

    // Butterfly passes over spans of length 2, 4, ..., N.
    for (int span = 2; span <= N; span <<= 1)
    {
        const int half = span >> 1;
        const int step = (MOD - 1) / span;
        // Root for this pass; the inverse direction uses the opposite exponent.
        const long long root = pow_mod(3, op == 1 ? step : MOD - 1 - step, MOD);
        for (int base = 0; base < N; base += span)
        {
            long long twiddle = 1;
            for (int off = 0; off < half; ++off)
            {
                long long &lo = a[base + off];
                long long &hi = a[base + off + half];
                const long long t = hi * twiddle % MOD;
                hi = (lo - t + MOD) % MOD;
                lo = (lo + t) % MOD;
                twiddle = twiddle * root % MOD;
            }
        }
    }

    if (op != -1)
        return;
    // Inverse direction: scale every element by Inv.
    for (int idx = 0; idx < N; ++idx)
        a[idx] = a[idx] * Inv % MOD;
}
// Sum1[v]: how many already-processed elements (scanning left to right) have
// value v.
int Sum1[MAXN];
// Sum2[v]: how many elements in the blocks after the current one have value v.
int Sum2[MAXN];
// The input sequence, 0-indexed.
int W[MAXN];
// l[b] / r[b]: first and last index of block b (blocks are numbered from 1).
int l[MAXN];
int r[MAXN];
// Work buffers handed to FFt: value histograms on input, products on output.
long long A[65537];
long long B[65537];
/**
 * Counts the triples i < j < k with W[k] - W[j] == W[j] - W[i].
 *
 * The sequence is cut into blocks and every triple is counted exactly once
 * according to where its indices fall:
 *   1. j and k share a block (i anywhere before j)      - brute force, Sum1
 *   2. i and j share a block, k is in a later block     - brute force, Sum2
 *   3. i, j, k lie in three different blocks            - convolution of the
 *      value histograms left and right of j's block, read at index 2 * W[j]
 *
 * Values are assumed to lie in [0, 30000] as stated by the problem, so every
 * sum of two values fits into the transform length 65536.
 */
int main()
{
    const int n = read();
    for (int i = 0; i < n; i++)
        W[i] = read();

    // Fewer than three elements cannot form a triple. Returning here also
    // keeps the block length below from becoming zero (n / len would then
    // divide by zero).
    if (n < 3)
    {
        printf ("%lld\n", 0LL);
        return 0;
    }

    const int len = std::min(static_cast<int>(std::sqrt(static_cast<double>(n))) * 6, n);
    const int num = (n + len - 1) / len;
    long long ans = 0;
    for (int i = 1; i <= num; i++)
    {
        l[i] = (i - 1) * len;
        r[i] = std::min(l[i] + len - 1, n - 1);
    }

    // Case 1: j < k in the same block. Sum1 holds every element before j,
    // so the matching first element is looked up by value.
    for (int i = 1; i <= num; i++)
    {
        for (int j = l[i]; j <= r[i]; j++)
        {
            for (int k = j + 1; k <= r[i]; k++)
                if (2 * W[j] - W[k] >= 0)
                    ans += Sum1[2 * W[j] - W[k]];
            Sum1[W[j]]++;
        }
    }

    // Case 2: first and middle element in the same block, third element in a
    // later block. Blocks are visited right to left so that Sum2 contains
    // exactly the later blocks.
    for (int i = num; i >= 1; i--)
    {
        for (int j = l[i]; j <= r[i]; j++)
        {
            for (int k = j + 1; k <= r[i]; k++)
                if (2 * W[k] - W[j] >= 0)
                    ans += Sum2[2 * W[k] - W[j]];
        }
        for (int j = l[i]; j <= r[i]; j++)
            Sum2[W[j]]++;
    }

    // Transform setup: length, its modular inverse and the bit-reversal table.
    N = 65536;
    Inv = pow_mod(N, MOD - 2, MOD);
    for (int i = 1; i < N; i++)
        if (i & 1)
            rev[i] = (rev[i >> 1] >> 1) | (N >> 1);
        else
            rev[i] = (rev[i >> 1] >> 1);

    // Case 3: the middle element lies in block i, the other two in blocks
    // strictly to the left and to the right.
    //
    // The transform only yields pair counts modulo MOD, but the true count
    // can be as large as leftCount * rightCount, which exceeds MOD for large
    // inputs (e.g. 10^5 equal values). When that is possible, the right
    // histogram is split as B = BHigh * 2^SPLIT_BITS + BLow. For n <= 10^5
    // both partial convolutions stay below MOD (BLow part <= 511 * leftCount,
    // BHigh part <= leftCount * rightCount / 512), so they are recovered
    // exactly and recombined.
    const int SPLIT_BITS = 9;
    const long long SPLIT_MASK = (1LL << SPLIT_BITS) - 1;
    static long long BHigh[65537];
    for (int i = 2; i < num; i++)
    {
        memset (A, 0, sizeof (A));
        memset (B, 0, sizeof (B));
        for (int j = 0; j < l[i]; j++) A[W[j]]++;
        for (int j = r[i] + 1; j < n; j++) B[W[j]]++;

        const long long leftCount = l[i];
        const long long rightCount = n - 1 - r[i];
        const bool split = leftCount * rightCount >= MOD;
        if (split)
        {
            for (int v = 0; v < N; v++)
            {
                BHigh[v] = B[v] >> SPLIT_BITS;
                B[v] &= SPLIT_MASK;
            }
        }

        FFt(A, 1), FFt(B, 1);
        if (split) FFt(BHigh, 1);
        for (int v = 0; v < N; v++)
        {
            // Reduce the pointwise products so the inverse transform always
            // receives values in [0, MOD).
            B[v] = A[v] * B[v] % MOD;
            if (split) BHigh[v] = A[v] * BHigh[v] % MOD;
        }
        FFt(B, -1);
        if (split) FFt(BHigh, -1);

        for (int j = l[i]; j <= r[i]; j++)
        {
            const int target = 2 * W[j];
            ans += B[target];
            if (split) ans += BHigh[target] << SPLIT_BITS;
        }
    }
    printf ("%lld\n", ans);
    return 0;
}
本文作者 : NekoMio
本文采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。
本文链接 : https://www.nekomio.com/2018/02/25/127/
上一篇
下一篇