容斥、容斥、还是容斥。

description

一个学校有 $n$ 个人,$m$ 门课,每门课有一个最高分 $u_i$(没有 $0$ 分)。

有一个神仙,已知他在所有课的排名 $r_i$(同分同名次)。

定义一个人被神仙碾压当且仅当其所有课的分数都小于等于神仙,并且已知有恰好 $K$ 个人被神仙碾压。

你需要求出所有同学得分情况的方案数。

$n,m\leq100,u_i\leq10^9,r_i\leq n$。

solution

注意到问题可以分为两部分:计数每个人每个科目相对于神仙的偏序关系,和对每门课计数有多少种分数满足条件。两者乘起来就是答案。

第一部分,选出恰好 $K$ 个人被吊打。

选出一部分人被吊打是容易的,他们的名次一定在神仙之后,这是固定的。但其他人发现较难判断。使用经典的二项式反演,$f_i$ 表示选择 $i$ 个人被碾压,其他人无所谓的方案数。被碾压的人已经固定了,其他的人无所谓,可以列出来式子:$f_i=\dbinom{n-1}{i}\prod_{j=1}^m\dbinom{n-1-i}{n-i-r_j}$。然后 $res=\sum_{i=K}^{n-1}(-1)^{i-K}\dbinom{i}{K}f_i$。

第二部分,计数分数的方案数。

考虑枚举神仙每门课的分数。

$$ \begin{aligned} res &=\prod_{i=1}^m\sum_{j=1}^{u_i}j^{n-r_i}(u_i-j)^{r_i-1}\\ \end{aligned} $$

这里好多种做法。

法一

后面用二项式定理展开。

$$ \begin{aligned} val_i &=\sum_{j=1}^{u_i}j^{n-r_i}\sum_{k=0}^{r_i-1}\binom{r_i-1}{k}u_i^k(-1)^{r_i-1-k}\\ &=\sum_{k=0}^{r_i-1}\binom{r_i-1}{k}(-1)^{r_i-1-k}u_i^{k}\sum_{j=1}^{u_i}j^{n-1-k} \end{aligned} $$

变换一下枚举顺序就得到了上面的东西。

用插值算自然数幂求和,就得到一个 $O(n^2m\log n)$ 复杂度的算法。

法二

考虑枚举一门课有几种分数,同样发现恰好不好算,但可以钦定一些可选可不选,其他一定不选,于是 $g_j$ 表示最多 $j$ 种分数的方案数,$g_j=\sum_{k=1}^jk^{n-r_i}(j-k)^{r_i-1}$,同样二项式反演,$h_j=\sum_{k=1}^j(-1)^{j-k}\dbinom{j}{k}g_k$,可以得到答案,然后 $h_j$ 的贡献就是 $\dbinom{u_i}{j}h_j$。

法三

用第二类斯特林数暴力展开普通幂。

$$ \begin{aligned} res &=\sum_{j=1}^{u_i}j^{n-r_i}(u_i-j)^{r_i-1}\\ &=\sum_{j=1}^{u_i}\sum_{x=0}^{n-r_i}\sum_{y=0}^{r_i-1}{n-r_i\brace x}{r-1\brace y}x!y!\binom{j}{x}\binom{u_i-j}{y}\\ &=\sum_{x=0}^{n-r_i}\sum_{y=0}^{r_i-1}{n-r_i\brace x}{r-1\brace y}x!y!\sum_{j=1}^{u_i}\binom{j}{x}\binom{u_i-j}{y}\\ &=\sum_{x=0}^{n-r_i}\sum_{y=0}^{r_i-1}{n-r_i\brace x}{r-1\brace y}x!y!\binom{u_i+1}{x+y+1} \end{aligned} $$

最后一步是组合意义,得到这个式子之后直接算就好了。

code

写的插值做法。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
using namespace std;

namespace solve
{
    const int maxn = 110;
    typedef long long ll;
    const int mod = 1e9 + 7;

    ll qpow(ll a, ll x, ll p)
    {
        ll res = 1;
        for (; x; x >>= 1, a = a * a % p)
            if (x & 1)
                res = res * a % p;
        return res;
    }

    ll fac[maxn], ifac[maxn];
    ll C(int n, int m) { return m > n ? 0 : fac[n] * ifac[m] % mod * ifac[n - m] % mod; }

    void init(int n = 100)
    {
        fac[0] = 1;
        for (int i = 1; i <= n; i++)
            fac[i] = fac[i - 1] * i % mod;
        ifac[n] = qpow(fac[n], mod - 2, mod);
        for (int i = n - 1; i >= 0; i--)
            ifac[i] = ifac[i + 1] * (i + 1) % mod;
    }

    int r[maxn], u[maxn], n, m, K;

    ll calc1()
    {
        ll res = 0;
        static ll f[maxn];
        for (int i = K; i < n; i++)
        {
            f[i] = 1;
            for (int j = 1; j <= m; j++)
                f[i] = f[i] * C(n - 1 - i, r[j] - 1) % mod;
        }
        for (int i = K; i < n; i++)
            (res += ((i - K) % 2 ? -1 : 1) * C(i, K) * C(n - 1, i) % mod * f[i] % mod);
        return (res % mod + mod) % mod;
    }

    ll calc2(int n, int k)
    {
        static ll y[maxn];
        static ll fac[maxn], ifac[maxn], pre[maxn], suf[maxn];
        for (int i = 1; i <= k + 2; i++)
            y[i] = (y[i - 1] + qpow(i, k, mod)) % mod;
        if (n <= k + 2)
            return y[n];
        fac[0] = 1;
        for (int i = 1; i <= k + 2; i++)
            fac[i] = fac[i - 1] * i % mod;
        ifac[k + 2] = qpow(fac[k + 2], mod - 2, mod);
        for (int i = k + 1; i >= 0; i--)
            ifac[i] = ifac[i + 1] * (i + 1) % mod;
        pre[0] = 1, suf[k + 3] = 1;
        for (int i = 1; i <= k + 2; i++)
            pre[i] = pre[i - 1] * (n - i) % mod;
        suf[k + 2] = n - (k + 2);
        for (int i = k + 1; i >= 1; i--)
            suf[i] = suf[i + 1] * (n - i) % mod;
        ll res = 0;
        k++;
        for (int i = 1; i <= k + 1; i++)
            res += ((k - i + 1) % 2 ? -1 : 1) * ifac[i - 1] * ifac[k - i + 1] % mod * pre[i - 1] % mod * suf[i + 1] % mod * y[i] % mod;
        return res % mod;
    }

    void main()
    {
        init();
        cin >> n >> m >> K;
        for (int i = 1; i <= m; i++)
            cin >> u[i];
        for (int i = 1; i <= m; i++)
            cin >> r[i];
        ll res = calc1();
        for (int i = 1; i <= m; i++)
        {
            ll val = 0;
            for (int k = 0; k < r[i]; k++)
            {
                ll tmp = calc2(u[i], n - 1 - k);
                tmp = tmp % mod * C(r[i] - 1, k) % mod * qpow(u[i], k, mod) % mod * ((r[i] - 1 - k) % 2 ? -1 : 1);
                val += tmp;
            }
            val %= mod, val += mod, val %= mod;
            res = res * val % mod;
        }
        cout << res << endl;
    }
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int T = 1;
    // cin >> T;
    while (T--)
        solve::main();
}