复习一下。

介绍

对于一个 $k$ 次多项式,已知 $k+1$ 个点值就能确定这个多项式,使用待定系数法,高斯消元求解多项式的时间复杂度是 $O(k^3)$ 而拉格朗日插值允许我们在 $O(k^2)$ 的时间内求出另一个点值、

考虑构造 $n$ 个函数 $f_1(x), f_2(x), \cdots, f_n(x)$,使得对于第 $i$ 个函数 $f_i(x)$,其图像过 $\begin{cases}(x_j,0),j\neq i\\(x_i,y_i),j=i\end{cases}$,则可知所求的函数 $f(x)=\sum\limits_{i=1}^nf_i(x)$。

这是很显然的,可以发现我们构造出来的函数可以恰好过所有给出的点。

下面给出每个函数的构造:

$$ f_i(x)=y_i\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j} $$

正确性很显然,首先这是个 $k$ 次多项式,而且当 $x\neq x_i$ 时分子为 $0$,当 $x=x_i$ 时分子分母相等。

$$ f(x)=\sum_{i=1}^ny_i\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j} $$

代码没有,因为没啥东西,代入式子算就行。别忘了乘上 $y_i$。

横坐标连续时线性做法

假设我们已知 $(1,y_1),\dots,(k+1,y_{k+1})$

$$ \begin{aligned} f(x)&=\sum\limits_{i=1}^{k+1}y_i\prod\limits_{j\ne i}\frac{x-x_j}{x_i-x_j}\\ &=\sum\limits_{i=1}^{k+1}y_i\prod\limits_{j\ne i}\frac{x-j}{i-j} \end{aligned} $$

分子

$$ \prod_{j=1}^{i-1}(x-j)\prod_{j=i+1}^{k+1}(x-j) $$

分母

$$ (-1)^{k+1-i}(i-1)!(k+1-i)! $$

预处理阶乘、阶乘逆元、$(x-j)$ 的前缀积、后缀积即可。别忘了乘上 $y_i$。

自然数幂和

$$ \sum_{i=0}^ni^k $$

这是一个 $k+1$ 次多项式(这里没有证明)。

因为 $i^k$ 是完全积性函数,所以用线性筛可以 $O(k)$ 求出点值。再使用上面的线性插值,即可做到 $O(k)$ 的时间复杂度。

例题:CF622F

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
using namespace std;

namespace solve
{
    const int maxn = 1e6 + 10;
    typedef long long ll;
    const int mod = 1e9 + 7;

    ll qpow(ll a, ll x, ll p)
    {
        ll res = 1;
        for (; x; x >>= 1, a = a * a % p)
            if (x & 1)
                res = res * a % p;
        return res;
    }

    ll calc(int n, int k)
    {
        static ll y[maxn];
        static ll fac[maxn], ifac[maxn], pre[maxn], suf[maxn];
        static int pri[maxn], vis[maxn], tot;
        y[1] = 1;
        for (int i = 2; i <= k + 2; i++)
        {
            if (!vis[i])
                pri[++tot] = i, y[i] = qpow(i, k, mod);
            for (int j = 1; j <= tot && i * pri[j] <= k + 2; j++)
            {
                vis[i * pri[j]] = 1;
                y[i * pri[j]] = y[i] * y[pri[j]] % mod;
                if (i % pri[j] == 0)
                    break;
            }
        }
        for (int i = 1; i <= k + 2; i++)
            y[i] = (y[i] + y[i - 1]) % mod;
        if (n <= k + 2)
            return y[n];
        fac[0] = 1;
        for (int i = 1; i <= k + 2; i++)
            fac[i] = fac[i - 1] * i % mod;
        ifac[k + 2] = qpow(fac[k + 2], mod - 2, mod);
        for (int i = k + 1; i >= 0; i--)
            ifac[i] = ifac[i + 1] * (i + 1) % mod;
        pre[0] = 1, suf[k + 3] = 1;
        for (int i = 1; i <= k + 2; i++)
            pre[i] = pre[i - 1] * (n - i) % mod;
        suf[k + 2] = n - (k + 2);
        for (int i = k + 1; i >= 1; i--)
            suf[i] = suf[i + 1] * (n - i) % mod;
        ll res = 0;
        k++;
        for (int i = 1; i <= k + 1; i++)
            res += ((k - i + 1) % 2 ? -1 : 1) * ifac[i - 1] * ifac[k - i + 1] % mod * pre[i - 1] % mod * suf[i + 1] % mod * y[i] % mod;
        return (res % mod + mod) % mod;
    }

    void main()
    {
        int n, k;
        cin >> n >> k;
        cout << calc(n, k) << endl;
    }
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int T = 1;
    // cin >> T;
    while (T--)
        solve::main();
}