复习一下。

介绍

对于一个 kk 次多项式,已知 k+1k+1 个点值就能确定这个多项式,使用待定系数法,高斯消元求解多项式的时间复杂度是 O(k3)O(k^3) 而拉格朗日插值允许我们在 O(k2)O(k^2) 的时间内求出另一个点值、

考虑构造 nn 个函数 f1(x),f2(x),,fn(x)f_1(x), f_2(x), \cdots, f_n(x),使得对于第 ii 个函数 fi(x)f_i(x),其图像过 {(xj,0),ji(xi,yi),j=i\begin{cases}(x_j,0),j\neq i\\(x_i,y_i),j=i\end{cases},则可知所求的函数 f(x)=i=1nfi(x)f(x)=\sum\limits_{i=1}^nf_i(x)

这是很显然的,可以发现我们构造出来的函数可以恰好过所有给出的点。

下面给出每个函数的构造:

fi(x)=yijixxjxixj f_i(x)=y_i\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j}

正确性很显然,首先这是个 kk 次多项式,而且当 xxix\neq x_i 时分子为 00,当 x=xix=x_i 时分子分母相等。

f(x)=i=1nyijixxjxixj f(x)=\sum_{i=1}^ny_i\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j}

代码没有,因为没啥东西,代入式子算就行。别忘了乘上 yiy_i

横坐标连续时线性做法

假设我们已知 (1,y1),,(k+1,yk+1)(1,y_1),\dots,(k+1,y_{k+1})

f(x)=i=1k+1yijixxjxixj=i=1k+1yijixjij \begin{aligned} f(x)&=\sum\limits_{i=1}^{k+1}y_i\prod\limits_{j\ne i}\frac{x-x_j}{x_i-x_j}\\ &=\sum\limits_{i=1}^{k+1}y_i\prod\limits_{j\ne i}\frac{x-j}{i-j} \end{aligned}

分子

j=1i1(xj)j=i+1k+1(xj) \prod_{j=1}^{i-1}(x-j)\prod_{j=i+1}^{k+1}(x-j)

分母

(1)k+1i(i1)!(k+1i)! (-1)^{k+1-i}(i-1)!(k+1-i)!

预处理阶乘、阶乘逆元、(xj)(x-j) 的前缀积、后缀积即可。别忘了乘上 yiy_i

自然数幂和

i=0nik \sum_{i=0}^ni^k

这是一个 k+1k+1 次多项式(这里没有证明)。

因为 iki^k 是完全积性函数,所以用线性筛可以 O(k)O(k) 求出点值。再使用上面的线性插值,即可做到 O(k)O(k) 的时间复杂度。

例题:CF622F

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
using namespace std;

namespace solve
{
    const int maxn = 1e6 + 10;
    typedef long long ll;
    const int mod = 1e9 + 7;

    ll qpow(ll a, ll x, ll p)
    {
        ll res = 1;
        for (; x; x >>= 1, a = a * a % p)
            if (x & 1)
                res = res * a % p;
        return res;
    }

    ll calc(int n, int k)
    {
        static ll y[maxn];
        static ll fac[maxn], ifac[maxn], pre[maxn], suf[maxn];
        static int pri[maxn], vis[maxn], tot;
        y[1] = 1;
        for (int i = 2; i <= k + 2; i++)
        {
            if (!vis[i])
                pri[++tot] = i, y[i] = qpow(i, k, mod);
            for (int j = 1; j <= tot && i * pri[j] <= k + 2; j++)
            {
                vis[i * pri[j]] = 1;
                y[i * pri[j]] = y[i] * y[pri[j]] % mod;
                if (i % pri[j] == 0)
                    break;
            }
        }
        for (int i = 1; i <= k + 2; i++)
            y[i] = (y[i] + y[i - 1]) % mod;
        if (n <= k + 2)
            return y[n];
        fac[0] = 1;
        for (int i = 1; i <= k + 2; i++)
            fac[i] = fac[i - 1] * i % mod;
        ifac[k + 2] = qpow(fac[k + 2], mod - 2, mod);
        for (int i = k + 1; i >= 0; i--)
            ifac[i] = ifac[i + 1] * (i + 1) % mod;
        pre[0] = 1, suf[k + 3] = 1;
        for (int i = 1; i <= k + 2; i++)
            pre[i] = pre[i - 1] * (n - i) % mod;
        suf[k + 2] = n - (k + 2);
        for (int i = k + 1; i >= 1; i--)
            suf[i] = suf[i + 1] * (n - i) % mod;
        ll res = 0;
        k++;
        for (int i = 1; i <= k + 1; i++)
            res += ((k - i + 1) % 2 ? -1 : 1) * ifac[i - 1] * ifac[k - i + 1] % mod * pre[i - 1] % mod * suf[i + 1] % mod * y[i] % mod;
        return (res % mod + mod) % mod;
    }

    void main()
    {
        int n, k;
        cin >> n >> k;
        cout << calc(n, k) << endl;
    }
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int T = 1;
    // cin >> T;
    while (T--)
        solve::main();
}