素数求和的动态规划方法

@jameslao / www.jlao.net

怎么求出从 $2$ 到 $N$ 之间的所有素数之和?

素数求和是数论中一个很典型的问题。这事情乍听起来并不难,只要把素数都求出来再加起来就好了嘛。可能你已经了解了一些求素数的方法,比如最简单的试除。写过这个程序的同学都知道这样搞有多慢。或者常用的是埃氏筛法,简单来说,就是列好从 $2$ 到 $N$ 的范围,先用 $2$ 去筛,把从 $2 \times 2$ 开始的 2 的倍数剔掉;再用下一个素数 3 筛,把 $3 \times 3$ 开始的 3 的倍数剔除掉…… 一直筛完不大于 $\sqrt{N}$ 所有素数为止。我直接从 wiki 上借了一个动画来演示这个过程,有兴趣的同学可以自己写一个试试看。

本文中的方法来自于解决了 ProjectEuler 全部问题的大神 Lucy_Hedgehog。

这个巧妙的思路是这样的:按照筛法,用从 $2$ 到 $N$ 之间所有小于等于 $m$ 的素数筛过后,设剩下的所有数字的和是 $S_{N, m}$,也就是说 $S_{N, m}$ 是 (1) 要么是素数 (2) 要么是素因子全部大于 $m$ 的所有数之和。

怎么算 $S_{N, m}$ 呢?注意筛法是从各个素数的平方 $p^2$ 开始筛的(原因很显然哈),所以如果 $m$ 是合数,或者 $N < m^2$,那么从 $S_{N, m-1}$ 到 $S_{N, m}$ 并没有多筛掉数字,于是

$$S_{N, m} = S_{N, m-1} \qquad m \notin \mathbb{P} \quad \text{or} \quad N < m^2$$

如果 $m$ 恰好是个素数呢?看看在这一轮里面都筛掉了谁……筛掉的是最小素因子是 $m$ 的所有合数,也就是 $m$ 乘上一个素因子不比 $m$ 小的数。注意到这个乘积不能超过 $N$,那么 $S_{N/m, m-1}$ 就是 $N/m$ 之内的 (1)要么是素数 (2)要么素因子大于等于 $m$ 的数之和。我们再去掉里面比 $m$ 小的素数的和 $S_{m-1, m-1}$,于是就有

$$S_{N, m} = S_{N, m-1} – m \cdot (S_{N/m, m-1} – S_{m-1,m-1}), \quad m \in \mathbb{P}, N \leq m^2$$

从 $2$ 到 $100$ 中间筛掉 $7$ 的时候,实际上筛掉的是 $49$、$77$、$91$,我们可以找 $S_{100/7, 6} = 2 + 3 + 5 + 7 + 11 + 13$,再去掉里面的 $S_{6,6} = 2 + 3 + 5$,再乘上 $7$ 就是这个筛选所造成的变化。 递推关系有了,我们就写程序来算吧。由于我们只关心 $N$、$N/m$ 和 $m-1$,所以其实并不是所有的 $N$ 都需要计算:


In [4]:
def prime_sum(num):
    r = int(num**0.5)
    N = [num // i for i in range(1, r + 1)] + list(range(num // r - 1, 0, -1))
    S = {i: i * (i + 1) // 2 - 1 for i in N}
    for m in range(2, r + 1):
        if S[m] > S[m - 1]:  # m is prime
            for n in N:
                if n < m * m:
                    break
                S[n] -= m * (S[n // m] - S[m - 1])
    return S[num]

%time print(prime_sum(10**10))


2220822432581729238
CPU times: user 6.8 s, sys: 24 ms, total: 6.82 s
Wall time: 6.82 s

在 cPython 下计算也只需要几秒钟,而如果要把每个的素数都算出来的话就远远不止这点时间了呢。 同样的思路,稍加变形就可以计算素数的个数。当然性能和专门的计数函数还是不能比,这个可以下次再写咯~


In [3]:
def prime_count(num):
    r = int(num**0.5)
    N = [num // i for i in range(1, r + 1)] + list(range(num // r - 1, 0, -1))
    S = {i: i - 1 for i in N}
    for m in range(2, r + 1):
        if S[m] > S[m - 1]:  # m is prime
            for n in N:
                if n < m * m:
                    break
                S[n] -= S[n // m] - S[m - 1]
    return S[num]

print(prime_count(10**10))


455052511