968 字
5 分钟
BZOJ1129 [POI2008]Per
2018-03-24

Description#

给你一个序列s,你把这个序列的所有不同排列按字典序排列后,求s的排名mod m

Input#

序列的长度n<300000,mn < 300000,m
nn个数,代表序列s

Output#

排名mod m

Sample Input#

4 1000
2 1 10 2

Sample Output#

5

HINT#

All the permutations smaller (with respect to lexicographic order) than the one given are: (1,2,2,10), (1,2,10,2), (1,10,2,2) and (2,1,2,10). 2m1092 \leq m \leq 10^9, 1Si3000001 \leq S_i \leq 300000

题解#

逐位考虑
考虑前i1i-1位与这个数列相等, 第ii位比他小的数量
设有g[i]g[i]个数比他这一位小,那么答案为g[i]F[n1]j=1nF[w[i]]g[i]*\frac{F[n - 1]}{\prod_{j = 1}^{n}{F[w[i]]}}
其中F[i]F[i]ii的阶乘,W[i]W[i]为排除了前面已经确定的数后ii的数量

这样就可以算出答案了
然而我们发现模数可能不是质数,那怎么办呢?
我们可以把他质因数分解,将模数拆成pkp^k相乘的形式
对于每一个pkp^k求模他的结果然后使用CRTCRT合并
可是阶乘如果是pkp^k的倍数就会得00, 我们可以将阶乘拆成apba*p^b的形式
可以实现乘法与除法。
然后解决了。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
const int MAXN = 300005;
long long a[MAXN], b[MAXN], c[MAXN], g[MAXN];
int Sum[MAXN], n, m;
long long B, P, PhiP, K;
long long pow_mod(long long a, int b, long long MP)
{
    long long ans = 1;
    while (b)
    {
        if (b & 1) ans = ans * a % MP;
        b >>= 1;
        a = a * a % MP;
    }
    return ans;
}
long long Inv(long long x)
{
    return pow_mod(x, PhiP - 1, P);
}
struct Num
{
    long long a, b;
    Num(){a = 0, b = 0; }
    Num(int _a)
    {
        a = _a;
        while (a % B == 0)
            a /= B, b++;
    }
    Num(long long _a, long long _b): a(_a), b(_b) {}
    Num operator * (const Num &c) { return Num(a * c.a % P, b + c.b); }
    Num operator / (const Num &c) { return Num(a * Inv(c.a) % P, b - c.b); }
    long long val()
    {
        if (!a || b >= K) return 0;
        long long x = a, k = b, c = B;
        while (k)
        {
            if (k & 1) x = x * c % P;
            k >>= 1;
            c = c * c % P;
        }
        return x;
    }
    void Set(int x)
    {
        a = x, b = 0;
        while (a % B == 0)
            a /= B, b++;
    }
}F[MAXN];
#define lowbit(_) ((_) & (-_))
void A(int x, int c)
{
    for (int i = x; i <= n; i += lowbit(i))
        Sum[i] += c;
}
int Q(int x)
{
    int ans = 0;
    for (int i = x; i >= 1; i -= lowbit(i))
        ans += Sum[i];
    return ans;
}
int cnt;
long long ans[5000], Mp[5000], Phi[5000];
void Solve()
{
    cnt++;
    Mp[cnt] = P, Phi[cnt] = PhiP;
    F[0].Set(1);
    for (int i = 1; i <= n; i++) F[i].Set(i);
    for (int i = 1; i <= n; i++) F[i] = F[i] * F[i - 1];
    for (int i = 1; i <= n; i++) c[i] = 0;
    for (int i = 1; i <= n; i++) c[a[i]]++;
    Num t(1, 0), tmp;
    for (int i = 1; i <= n; i++) t = t * F[c[i]];
    for (int i = 1; i <= n; i++)
    {
        if (g[i]) tmp.Set(g[i]), ans[cnt] = (ans[cnt] + (tmp * F[n - i] / t).val()) % P;
        t = t / F[c[a[i]]] * F[c[a[i]] - 1];
        c[a[i]]--;
    }
}
void Divide(int x)
{
    for (int i = 2; i * i <= x; i++)
        if (x % i == 0)
        {
            B = i, P = 1, PhiP = 1, K = 0;
            while (x % i == 0)
                P *= i, PhiP *= i, K++, x /= i;
            PhiP /= i, PhiP *= i - 1;
            Solve();
        }
    if (x != 1)
    {
        B = x, P = x, PhiP = x - 1, K = 1;
        Solve();
    }
}
long long CRT(long long *a, long long *b, int n)
{
    long long N = 1, Ni, now, ans = 0;
    for (int i = 1; i <= n; i++) N *= a[i];
    for (int i = 1; i <= n; i++)
    {
        Ni = N / a[i];
        now = pow_mod(Ni, Phi[i] - 1, a[i]);
        ans = (ans + (b[i] * now % N) * Ni % N) % N;    
    }
    return ans;
}
int main()
{
    n = read(), m = read();
    for (int i = 1; i <= n; i++)
        b[i] = a[i] = read();
    sort(b + 1, b + n + 1);
    int cnt = unique(b + 1, b + n + 1) - b - 1;
    for (int i = 1; i <= n; i++)
        a[i] = lower_bound(b + 1, b + cnt + 1, a[i]) - b;
    for (int i = 1; i <= n; i++)
        c[a[i]]++;
    for (int i = 1; i <= n; i++) A(i, c[i]);
    for (int i = 1; i <= n; i++)
        g[i] = Q(a[i] - 1), A(a[i], -1);
    Divide(m);
    printf ("%lld\n", (CRT(Mp, ans, ::cnt) + 1) % m);
}

BZOJ1129 [POI2008]Per
https://www.nekomio.com/posts/139/
作者
NekoMio
发布于
2018-03-24
许可协议
CC BY-NC-SA 4.0