Description
有两个 $B$ 进制数 $L,R$,求区间 $[L,R]$ 中,将所有 $B$ 进制数看成一个字符串,所有字符串的所有连续子串对应 $B$ 进制数的和(十进制) $\bmod 20130427$。
Solution
这个数位DP实在是太毒瘤了。。。
看了看题解才搞懂,于是我来记录一下这道题。
为方便,下面记 一个数所有连续子串的和 为其权值,记为 $w(n)$;一个数所有后缀的对应数字的和为 $s(n)$;一个数字位数为 $len(n)$。
首先考虑有一个数 $p$,在其后面填一个数 $q$,得到新数 $\overline{pq}$ 的权值 $$ w(\overline{pq})=w(p)+B\times s(p)+len(\overline{pq})\times q $$ 这时我们发现,后半部分实际就是 $s(\overline{pq})$,那么 $w(\overline{pq})=w(p)+s(\overline{pq})$,其中 $s(\overline{pq})=B\times s(p)+len(\overline{pq})\times q$,且 $len(\overline{pq})=len(p)+1$。
上述过程都还算显然。于是我们找到了一个递推的思路。
因为我们要对 $w$ 求和,如果枚举 $0\sim B-1$,直接用以上的方法递推,复杂度会达到 $O(B(n+m))$。考虑优化:
既然我们要求和,那么就直接记录答案,然后大力推式子。
按照数位DP的传统思路,设 $f(i,0/1)$ 表示从高到低 $i$ 位,否/是 紧贴上界的 $\sum{w}$。
你就会发现式子开始变得不那么显然了起来。
记录 $sum(n,0/1)=\sum s(n,0/1)$;$sl(n,0/1)=\sum len(n,0/1)$;$a(i,0/1)$ 为 否/是 紧贴上界时数的个数。
注意,当以上的第二维取 $1$ 时,要注意实际上在该位的取值只有一种(虽然有 $\sum$)。
下面记 $p(i)$ 为当前的位数上的数字。我把低的位存在数组前面,所以第 $n$ 位由 $n+1$ 为转移来。
那么 $f(n,1)=f(n+1,1)+sum(n,1)$,
而 $sum(n,1)=sum(n+1,1)\times B+p(n)\times sl(n,1)$,
其中 $sl(n,1)=sl(n+1,1)+a(n,1)$,$a(n,1)=a(n+1,1)$。
上面这些式子都还算显然,接下来的转移就有点怪了。 $$ f(n,0)=f(n+1,0)\times B+sum(n,0)+f(n+1,1)\times p(n) $$ 前半部分就是我们之前推的式子,解释一下 $f(n+1,1)\times p(n)$:$sum(n,0)$ 已经包括了前 $n$ 位数中所有不紧贴上界的后缀和的和,这也包括了在 $n+1$ 位之前紧贴上界,在第 $n$ 位不紧贴的数。而 $f(n+1,0)$ 是上一位不紧贴的答案。我们发现,上一位紧贴的全部答案都不在现在的后缀和中,也就是没有被计算进去,而这种没有被计算的数恰好有 $p(n)$ 个($0\sim p-1$)。所以有上式。
接下来要求 $sum(n,0)$,这个式子是真的恶心。首先记 $pre(n)=\sum_0^{n-1}$。
当 $n$ 是最高位的时候 $t=0$,否则 $t=B$,这是为了处理前导零,具体作用你把它换成 $B$ 就能明白了吧。 $$ \begin{align*} sum(n,0)=&sum(n+1,1)\times B\times p(n)+pre(p(n))\times sl(n,1)+\\ &sum(n+1,0)\times B\times B+pre(B)\times(sl(n+1,0)+a(n+1,0))+\\ &pre(t) \end{align*} $$ 解释暂时咕了
然后是其中的 $sl(n,0)$: $$ sl(n,0)=t-1+sl(n+1,1)\times p(n)+(sl(n+1,0)+a(n+1,0))\times B $$
对于上一位紧贴上界,现在有 $p(n)$ 个。
解释暂时咕了
对于 $a(n,0)$,有下面一个还算显然的式子: $$ a(n,0)=t-1+a(n+1,0)\times B+a(n+1,1)\times p(n) $$
这个式子不解释了。
于是把以上式子写成代码,这题就做完了。
Code
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using std::cout;
using std::endl;
const int maxn = 1e5 + 10;
const int mod = 20130427;
typedef long long ll;
ll f[maxn][2], sum[maxn][2], sl[maxn][2], a[maxn][2];
ll B, n, m;
ll S[maxn], l[maxn], r[maxn];
ll solve(ll *p, ll l)
{
memset(f, 0, sizeof(f)), memset(sum, 0, sizeof(sum));
memset(sl, 0, sizeof(sl)), memset(a, 0, sizeof(a));
a[l][1] = 1;
for (int i = l - 1; i >= 0; i--)
{
int t = i == l - 1 ? 0 : B;
a[i][1] = a[i + 1][1];
a[i][0] = (t - 1 + a[i + 1][0] * B % mod + a[i + 1][1] * p[i] % mod) % mod;
sl[i][1] = sl[i + 1][1] + a[i + 1][1];
sl[i][0] = (t - 1 + sl[i][1] * p[i] + (sl[i + 1][0] + a[i + 1][0]) * B % mod) % mod;
sum[i][1] = (sum[i + 1][1] * B % mod + p[i] * sl[i][1]) % mod;
sum[i][0] = (S[t] + sum[i + 1][1] * B * p[i] + S[p[i]] * sl[i][1] +
sum[i + 1][0] * B % mod * B % mod +
S[B] * (sl[i + 1][0] + a[i + 1][0])) %
mod;
f[i][1] = (f[i + 1][1] + sum[i][1]) % mod;
f[i][0] = (f[i + 1][0] * B + sum[i][0] + f[i + 1][1] * p[i]) % mod;
}
return (f[0][1] + f[0][0]) % mod;
}
int main()
{
scanf("%lld", &B);
for (int i = 0; i < B; i++)
S[i + 1] = S[i] + i;
scanf("%lld", &n);
for (int i = 0; i < n; i++)
scanf("%lld", &l[n - i - 1]);
for (int i = 0; i < n; i++)
{
if (l[i] > 0)
{
l[i]--;
break;
}
l[i] = B - 1;
}
if (l[n - 1] == 0)
n--;
scanf("%lld", &m);
for (int i = 0; i < m; i++)
scanf("%lld", &r[m - i - 1]);
printf("%lld\n", (solve(r, m) - solve(l, n) + mod) % mod);
return 0;
}