Python で二項係数 nCr を高速に計算したい

二項係数nCrをPythonで計算する

AtCoder の問題を解いていると、 mod p の条件下で高速に二項係数 nCr を求める場面に多く遭遇するのでそのあたりの知識をまとめます。

二項係数 nCr を計算する

まずは普通に計算してみましょう。これは二項係数を階乗を用いて表示した

$$ _n \mathrm{C} _r = \frac{n!}{r! \times (n-r)!} $$

から、次のように求めることができます。

from math import factorial
print(factorial(n) // factorial(r) // factorial(n - r))

mod p の条件下で nCr を計算する

競プロの世界では \(_n \mathrm{C} _r\) を計算する問題に、 \(10^9 + 7\) で割ったあまりを求めさせるという制約が付いたものが多く存在します。

今回はこのような条件下で高速に \(_n \mathrm{C} _r\) を計算するにはどうすればよいかを考えます。(これが本題です)

サトゥー

この条件がつく時点で大きな n, r を扱うことが多く、高速に計算できるアルゴリズムを考えないと TLE 祭りになってしまうわけですね

結論

結論から言うと、 \(_n \mathrm{C} _r = (n!) \times (r!)^{-1} \times ((n-r)!)^{-1}\) となることを利用して、次のように計算することで高速に処理を行うことが可能です。

def cmb(n, r, p):
    if (r < 0) or (n < r):
        return 0
    r = min(r, n - r)
    return fact[n] * factinv[r] * factinv[n-r] % p

p = 10 ** 9 + 7
N = 10 ** 6  # N は必要分だけ用意する
fact = [1, 1]  # fact[n] = (n! mod p)
factinv = [1, 1]  # factinv[n] = ((n!)^(-1) mod p)
inv = [0, 1]  # factinv 計算用
 
for i in range(2, N + 1):
    fact.append((fact[-1] * i) % p)
    inv.append((-inv[p % i] * (p // i)) % p)
    factinv.append((factinv[-1] * inv[-1]) % p)

print(cmb(n, r, p))

cmb は二項係数 \(_n \mathrm{C} _r\) を求める関数を指すとします。(パラメータは場合によって変わります)組合せを意味する Combination の略です。

やっていることをカンタンに説明すると次のようになります。

  • \(n!, (n!)^{-1}\) について、 \(p\) で割ったあまりを配列にまとめておく(下準備)
  • \begin{align*}
    _n \mathrm{C} _r = (n!) \times (r!)^{-1} \times ((n-r)!)^{-1}
    \end{align*}

    に基づいて計算する

下準備について、 \(n!\) を \(p\) で割ったあまりを fact に、 \((n!)^{-1}\) を \(p\) で割ったあまりを factinv にそれぞれ格納しています。

こうすることで、掛け算自体は \(O(1)\) でできるので、実質階乗を求めるのに必要な \(O(n)\) の計算量で処理をすることができます。

mod p における n! の計算

まずは fact[n] = n! を p で割ったあまり の計算について見ていきましょう。

これは \(0! =1\) であることに注意して、次のように作成することができそうです。

fact = [1]  # 0!

for i in range(1, N + 1):
    fact.append((fact[-1] * i) % p)

mod p における n! の逆元の計算

次に factinv[n] = mod p における n! の逆元 (n!)^(-1) の計算を見てみます。これは、次のように分解して考えることで、先程やった \(n!\) の計算と同じように処理できそうです。

$$ (n!)^{-1} = \prod_{i = 1}^{n} i^{-1} $$

つまり、mod p における \(a\) の逆元 \(a^{-1}\) を求めてやれば、それを掛け算していくことで mod p における \(n!\) の逆元 \((n!)^{-1}\) も求められるわけです。

モジュラ逆数

この逆元というやつは モジュラ逆数 と呼ばれるやつで、次のような数になるようです。

モジュラ逆数は、与えられた整数 \(a\) と法 \(p\) に関して

$$ a^{-1} \equiv x {\pmod {p}}$$

という関係にある整数 \(x\) をいう。

具体例を考えましょう。

ex. 整数 3 の法 11 に関するモジュラ逆数 x を求めたい

つまり、 \( 3^{-1} \equiv x {\pmod {11}}\) なる \(x\) を計算するということになるが、これは次式を満たす \(x\) を計算することである。

$$ 3 x \equiv 1 \pmod{11} $$

これは、\( 3 \times 4 = 12 \equiv 1 \pmod{11} \) から \(x = 4\) と求まる。

(一般には \(x = 4 + 11 k\) の形をしている)

このようにモジュラ逆数という考え方を用いて \(n^{-1}\) にも \(p\) で割ったあまりのようなものを定義してあげることで、結局 \(_n \mathrm{C} _r = (n!) \times (r!)^{-1} \times ((n-r)!)^{-1}\) の mod p における計算は、それぞれの mod p における値を掛け算した値になることがわかります。

モジュラ逆数が満たす性質

具体的にモジュラ逆数を計算するために、次のような性質を利用します。

\(s\) を \(t\) で割った商を \(s // t\) 、余りを \(s \% t\) と表記する。

このとき mod p における \(a\) の逆元 \(a^{-1}\) について、次の関係が成り立つ。

\begin{align*}
a^{-1} \equiv -(p \% a)^{-1} \times (p // a) \pmod{p}
\end{align*}

証明は次のようになります。

証明

\(p\) を \(a\) で割ると

$$ p = (p // a) \times a + (p \% a) $$

が成立。両辺の mod p を取って、

\( (p // a) \times a + (p \% a) \equiv 0 \\
\Leftrightarrow (p // a) + (p \% a) \times a^{-1} \equiv 0 \\
\Leftrightarrow (p \% a) \times a^{-1} \equiv -(p // a) \\
\Leftrightarrow a^{-1} \equiv -(p \% a)^{-1} \times (p // a) \\
\)

と変形できる。(証明終)

mod p における n の逆元の計算

さて、準備が整ったので inv[n] = mod p における n の逆元 (n)^(-1) から計算します。これは上の式から次のようにできます。

inv = [0, 1]  # 便宜上 inv[0] = 0 とした

for i in range(2, N + 1):
    inv.append((-inv[p % i] * (p // i)) % p)

再び、 mod p における n! の逆元の計算

本題に戻ります。factinv[n] = mod p における n! の逆元 (n!)^(-1) の話をしているんでしたね、ここまできたら inv を掛け合わせていくだけなので次のようにできます。

factinv = [1, 1]

for i in range(2, N + 1):
    factinv.append((factinv[-1] * inv[i]) % p)

改めて、結論

以上を踏まえてもう一度コードを見てみましょう。余計なコメントは消してあります。

def cmb(n, r, p):
    if (r < 0) or (n < r):
        return 0
    r = min(r, n - r)
    return fact[n] * factinv[r] * factinv[n-r] % p

p = 10 ** 9 + 7
N = 10 ** 6  # N は必要分だけ用意する
fact = [1, 1]
factinv = [1, 1]
inv = [0, 1]
 
for i in range(2, N + 1):
    fact.append((fact[-1] * i) % p)
    inv.append((-inv[p % i] * (p // i)) % p)
    factinv.append((factinv[-1] * inv[-1]) % p)

print(cmb(n, r, p))

参考記事

さらに発展的な内容や今回参考にした内容を最後に紹介します。

興味がある方はこちらにリンクをまとめておいたので御覧ください。

【Python】組み合わせ(nCr) 計算の高速化 – Qiita

よくやる二項係数 (nCk mod. p)、逆元 (a^-1 mod. p) の求め方 – けんちょんの競プロ精進記録

二項係数 (nCr) の計算方法 – izumo’s diary

この記事を書いた人

サトゥー

東大学際情報学府M1。情報科学と教養の海に溺れています。面白いことをやるのがすきです。