复习一下。
介绍
对于一个 $k$ 次多项式,已知 $k+1$ 个点值就能确定这个多项式,使用待定系数法,高斯消元求解多项式的时间复杂度是 $O(k^3)$ 而拉格朗日插值允许我们在 $O(k^2)$ 的时间内求出另一个点值、
考虑构造 $n$ 个函数 $f_1(x), f_2(x), \cdots, f_n(x)$,使得对于第 $i$ 个函数 $f_i(x)$,其图像过 $\begin{cases}(x_j,0),j\neq i\\(x_i,y_i),j=i\end{cases}$,则可知所求的函数 $f(x)=\sum\limits_{i=1}^nf_i(x)$。
这是很显然的,可以发现我们构造出来的函数可以恰好过所有给出的点。
下面给出每个函数的构造:
$$ f_i(x)=y_i\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j} $$
正确性很显然,首先这是个 $k$ 次多项式,而且当 $x\neq x_i$ 时分子为 $0$,当 $x=x_i$ 时分子分母相等。
故
$$ f(x)=\sum_{i=1}^ny_i\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j} $$
代码没有,因为没啥东西,代入式子算就行。别忘了乘上 $y_i$。
横坐标连续时线性做法
假设我们已知 $(1,y_1),\dots,(k+1,y_{k+1})$
$$ \begin{aligned} f(x)&=\sum\limits_{i=1}^{k+1}y_i\prod\limits_{j\ne i}\frac{x-x_j}{x_i-x_j}\\ &=\sum\limits_{i=1}^{k+1}y_i\prod\limits_{j\ne i}\frac{x-j}{i-j} \end{aligned} $$
分子
$$ \prod_{j=1}^{i-1}(x-j)\prod_{j=i+1}^{k+1}(x-j) $$
分母
$$ (-1)^{k+1-i}(i-1)!(k+1-i)! $$
预处理阶乘、阶乘逆元、$(x-j)$ 的前缀积、后缀积即可。别忘了乘上 $y_i$。
自然数幂和
求
$$ \sum_{i=0}^ni^k $$
这是一个 $k+1$ 次多项式(这里没有证明)。
因为 $i^k$ 是完全积性函数,所以用线性筛可以 $O(k)$ 求出点值。再使用上面的线性插值,即可做到 $O(k)$ 的时间复杂度。
例题:CF622F
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
using namespace std;
namespace solve
{
const int maxn = 1e6 + 10;
typedef long long ll;
const int mod = 1e9 + 7;
ll qpow(ll a, ll x, ll p)
{
ll res = 1;
for (; x; x >>= 1, a = a * a % p)
if (x & 1)
res = res * a % p;
return res;
}
ll calc(int n, int k)
{
static ll y[maxn];
static ll fac[maxn], ifac[maxn], pre[maxn], suf[maxn];
static int pri[maxn], vis[maxn], tot;
y[1] = 1;
for (int i = 2; i <= k + 2; i++)
{
if (!vis[i])
pri[++tot] = i, y[i] = qpow(i, k, mod);
for (int j = 1; j <= tot && i * pri[j] <= k + 2; j++)
{
vis[i * pri[j]] = 1;
y[i * pri[j]] = y[i] * y[pri[j]] % mod;
if (i % pri[j] == 0)
break;
}
}
for (int i = 1; i <= k + 2; i++)
y[i] = (y[i] + y[i - 1]) % mod;
if (n <= k + 2)
return y[n];
fac[0] = 1;
for (int i = 1; i <= k + 2; i++)
fac[i] = fac[i - 1] * i % mod;
ifac[k + 2] = qpow(fac[k + 2], mod - 2, mod);
for (int i = k + 1; i >= 0; i--)
ifac[i] = ifac[i + 1] * (i + 1) % mod;
pre[0] = 1, suf[k + 3] = 1;
for (int i = 1; i <= k + 2; i++)
pre[i] = pre[i - 1] * (n - i) % mod;
suf[k + 2] = n - (k + 2);
for (int i = k + 1; i >= 1; i--)
suf[i] = suf[i + 1] * (n - i) % mod;
ll res = 0;
k++;
for (int i = 1; i <= k + 1; i++)
res += ((k - i + 1) % 2 ? -1 : 1) * ifac[i - 1] * ifac[k - i + 1] % mod * pre[i - 1] % mod * suf[i + 1] % mod * y[i] % mod;
return (res % mod + mod) % mod;
}
void main()
{
int n, k;
cin >> n >> k;
cout << calc(n, k) << endl;
}
}
int main()
{
ios::sync_with_stdio(false), cin.tie(0);
int T = 1;
// cin >> T;
while (T--)
solve::main();
}