Fast n choose k mod p for large n?

What I mean by "large n" is something in the millions. p is prime.

I've tried http://apps.topcoder.com/wiki/display/tc/SRM+467 But the function seems to be incorrect (I tested it with 144 choose 6 mod 5 and it gives me 0 when it should give me 2)

I've tried http://online-judge.uva.es/board/viewtopic.php?f=22&t=42690 But I don't understand it fully

I've also made a memoized recursive function that uses the logic (combinations(n-1, k-1, p)%p + combinations(n-1, k, p)%p) but it gives me stack overflow problems because n is large

I've tried Lucas Theorem but it appears to be either slow or inaccurate.

All I'm trying to do is create a fast/accurate n choose k mod p for large n. If anyone could help show me a good implementation for this I'd be very grateful. Thanks.

As requested, the memoized version that hits stack overflows for large n:

std::map<std::pair<long long, long long>, long long> memo;

long long combinations(long long n, long long k, long long p){
   if (n  < k) return 0;
   if (0 == n) return 0;
   if (0 == k) return 1;
   if (n == k) return 1;
   if (1 == k) return n;

   map<std::pair<long long, long long>, long long>::iterator it;

   if((it = memo.find(std::make_pair(n, k))) != memo.end()) {
        return it->second;
   }
   else
   {
        long long value = (combinations(n-1, k-1,p)%p + combinations(n-1, k,p)%p)%p;
        memo.insert(std::make_pair(std::make_pair(n, k), value));
        return value;
   }  
}
Psi answered 12/4, 2012 at 5:57 Comment(8)
Do you need to know the exact remainder, or is it enough to know whether the number is evenly divisible by p? (n choose k mod p == 0)Dresden
Not sure I understand the question. The answer to n choose k mod p needs to be exact/accurate.Psi
what does the combinations function return(why does it take 3 arguments)Unpolite
combinations function takes three arguments because it's finding (n choose k) mod pPsi
So you need to compute combination(n, k)%p?Unpolite
post the recursive version, someone can perhaps help you turn it into an iterative solution avoiding the stackoverflow exceptionDresden
No, @dbaupp, it is not homework.Psi
The solution on TopCoder works for p > n.Sarsen

So, here is how you can solve your problem.

Of course you know the formula:

comb(n,k) = n!/(k!*(n-k)!) = (n*(n-1)*...(n-k+1))/k! 

(See http://en.wikipedia.org/wiki/Binomial_coefficient#Computing_the_value_of_binomial_coefficients)

You know how to compute the numerator:

long long res = 1;
for (long long i = n; i > n- k; --i) {
  res = (res * i) % p;
}

Now, as p is prime, the reciprocal of each integer that is coprime with p is well defined, i.e. a^(-1) can be found. This can be done using Fermat's theorem: a^(p-1) ≡ 1 (mod p) => a * a^(p-2) ≡ 1 (mod p), and so a^(-1) ≡ a^(p-2) (mod p). Now all you need to do is to implement fast exponentiation (for example using the binary method):

long long degree(long long a, long long k, long long p) {
  long long res = 1;
  long long cur = a;

  while (k) {
    if (k % 2) {
      res = (res * cur) % p;
    }
    k /= 2;
    cur = (cur * cur) % p;
  }
  return res;
}
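
As a cross-check of the idea above, here is a minimal Python sketch of the same binary exponentiation and the Fermat inverse built on it. The names power_mod and inverse_mod are made up for this illustration; it assumes p is prime and does not divide a.

```python
def power_mod(a, k, p):
    """Return a**k mod p using binary (square-and-multiply) exponentiation."""
    result = 1
    a %= p
    while k:
        if k & 1:              # current lowest bit of the exponent is set
            result = result * a % p
        a = a * a % p          # square the base for the next bit
        k >>= 1
    return result


def inverse_mod(a, p):
    """Inverse of a modulo a prime p via Fermat's little theorem.

    Assumption: p is prime and a is not a multiple of p.
    """
    return power_mod(a, p - 2, p)


print(inverse_mod(3, 7))  # 5, since 3 * 5 = 15 ≡ 1 (mod 7)
```

In Python the built-in pow(a, p - 2, p) does the same job; the explicit loop is only there to mirror the C++ function.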

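And now you can fold the denominator into the result (res still holds the numerator from the loop above):

for (long long i = 1; i <= k; ++i) {
  // multiply by the modular inverse of i, i.e. i^(p-2) mod p
  res = (res * degree(i, p - 2, p)) % p;
}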

Please note I am using long long everywhere to avoid type overflow. Of course you don't need to do k exponentiations - you can compute k!(mod p) and then divide only once:

long long denom = 1;
for (long long i = 1; i <= k; ++i) {
  denom = (denom * i) % p;
}
res = (res * degree(denom, p - 2, p)) % p;

EDIT: as per @dbaupp's comment, if k >= p then k! is equal to 0 modulo p and (k!)^-1 is not defined. To avoid that, first compute the exponent of p in n*(n-1)*...*(n-k+1) and in k!, and compare them:

int get_degree(long long n, long long p) { // returns the degree with which p is in n!
  int degree_num = 0;
  long long u = p;
  long long temp = n;

  while (u <= temp) {
    degree_num += temp / u;
    u *= p;
  }
  return degree_num;
}

long long combinations(int n, int k, long long p) {
  int num_degree = get_degree(n, p) - get_degree(n - k, p);
  int den_degree = get_degree(k, p);

  if (num_degree > den_degree) {
    return 0;
  }
  long long res = 1;
  for (long long i = n; i > n - k; --i) {
    long long ti = i;
    while(ti % p == 0) {
      ti /= p;
    }
    res = (res * ti) % p;
  }
  for (long long i = 1; i <= k; ++i) {
    long long ti = i;
    while(ti % p == 0) {
      ti /= p;
    }
    res = (res * degree(ti, p-2, p)) % p;
  }
  return res;
}

EDIT: There is one more optimization that can be added to the solution above: instead of computing the inverse of each factor of k! separately, we can compute k! (mod p) and then invert that single number, so the logarithmic cost of the exponentiation is paid only once. Of course we still have to strip the factors of p from each factor. Only the last loop needs to change:

long long denom = 1;
for (long long i = 1; i <= k; ++i) {
  long long ti = i;
  while(ti % p == 0) {
    ti /= p;
  }
  denom = (denom * ti) % p;
}
res = (res * degree(denom, p-2, p)) % p;
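
For readers who want to check the final version against exact values, here is a rough Python sketch of the same steps: compare the exponents of p, strip the factors of p, then invert the denominator once. The function names are hypothetical, and math.comb is used only to verify the result.

```python
from math import comb


def prime_exponent_in_factorial(n, p):
    """Exponent of the prime p in n! (sum of n // p**i)."""
    count = 0
    while n:
        n //= p
        count += n
    return count


def strip_p(x, p):
    """Remove every factor p from x."""
    while x % p == 0:
        x //= p
    return x


def binomial_mod(n, k, p):
    """C(n, k) mod p for prime p, in O(k) multiplications plus one inverse."""
    if k < 0 or k > n:
        return 0
    num_exp = prime_exponent_in_factorial(n, p) - prime_exponent_in_factorial(n - k, p)
    den_exp = prime_exponent_in_factorial(k, p)
    if num_exp > den_exp:      # a factor p survives, so the result is 0 mod p
        return 0
    num = den = 1
    for i in range(k):
        num = num * strip_p(n - i, p) % p
        den = den * strip_p(i + 1, p) % p
    return num * pow(den, p - 2, p) % p   # single Fermat inverse


print(binomial_mod(144, 6, 5), comb(144, 6) % 5)  # both values are 2
```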
Unpolite answered 12/4, 2012 at 6:17 Comment(9)
Are you just computing n*(n-1)*...*(n-k+1) * (k!)^-1? This is only defined if k < p, otherwise k! == 0 and no inverse exists.Tactic
If k > p then special care should be taken to compute the degree of p in n*(n-1)*...*(n-k+1) and in k! and then to cancel those occurrencesUnpolite
I think the "computing the degree of p and cancelling" bit isn't trivial. At least, not to do efficiently.Tactic
This seems similar to the implementation I showed in the first link I posted about (how 144 choose 6 mod 5 didn't work etc)Psi
I have updated my post please read it again. Sorry for the mistake.Unpolite
I had a typo - you need to replace i with ti in both cycles. I think this is obvious. I have edited my final code and now it should be correct. I have tested the code and it does return 2.Unpolite
This seems to work now (I just tested it for very large inputs and it matches my Python output) -- thank you very much!Psi
@IvayloStrandjev I just used this to solve a combinations problem. The last optimization actually was useful to resolve the TLE. :) Thanks so much!Banded
@IvayloStrandjev Thx :)Pr

For large k, we can reduce the work significantly by exploiting two fundamental facts:

  1. If p is a prime, the exponent of p in the prime factorisation of n! is given by (n - s_p(n)) / (p-1), where s_p(n) is the sum of the digits of n in the base p representation (so for p = 2, it's popcount). Thus the exponent of p in the prime factorisation of choose(n,k) is (s_p(k) + s_p(n-k) - s_p(n)) / (p-1), in particular, it is zero if and only if the addition k + (n-k) has no carry when performed in base p (the exponent is the number of carries).

  2. Wilson's theorem: p is a prime, if and only if (p-1)! ≡ (-1) (mod p).

The exponent of p in the factorisation of n! is usually calculated by

long long factorial_exponent(long long n, long long p)
{
    long long ex = 0;
    do
    {
        n /= p;
        ex += n;
    }while(n > 0);
    return ex;
}
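
The two ways of getting the exponent (repeated division versus the base-p digit sum from fact 1) can be compared with a small Python sketch. digit_sum and binomial_exponent are names invented for this illustration.

```python
def factorial_exponent(n, p):
    """Exponent of the prime p in n!, by repeated division."""
    ex = 0
    while n:
        n //= p
        ex += n
    return ex


def digit_sum(n, p):
    """Sum of the digits of n written in base p, i.e. s_p(n)."""
    s = 0
    while n:
        s += n % p
        n //= p
    return s


def binomial_exponent(n, k, p):
    """Exponent of p in choose(n, k): the number of carries when adding
    k and n - k in base p."""
    return (digit_sum(k, p) + digit_sum(n - k, p) - digit_sum(n, p)) // (p - 1)


print(factorial_exponent(144, 5), (144 - digit_sum(144, 5)) // 4)  # 34 34
print(binomial_exponent(10, 5, 2))  # 2, since choose(10, 5) = 252 = 4 * 63
```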

The check for divisibility of choose(n,k) by p is not strictly necessary, but it's reasonable to have that first, since it will often be the case, and then it's less work:

long long choose_mod(long long n, long long k, long long p)
{
    // We deal with the trivial cases first
    if (k < 0 || n < k) return 0;
    if (k == 0 || k == n) return 1;
    // Now check whether choose(n,k) is divisible by p
    if (factorial_exponent(n, p) > factorial_exponent(k, p) + factorial_exponent(n-k, p)) return 0;
    // If it's not divisible, do the generic work
    return choose_mod_one(n,k,p);
}

Now let us take a closer look at n!. We separate the numbers ≤ n into the multiples of p and the numbers coprime to p. With

n = q*p + r, 0 ≤ r < p

The multiples of p contribute p^q * q!. The numbers coprime to p contribute the product of (j*p + k), 1 ≤ k < p for 0 ≤ j < q, and the product of (q*p + k), 1 ≤ k ≤ r.

For the numbers coprime to p we will only be interested in the contribution modulo p. Each of the full runs j*p + k, 1 ≤ k < p is congruent to (p-1)! modulo p, so altogether they produce a contribution of (-1)^q modulo p. The last (possibly) incomplete run produces r! modulo p.
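
The claim about the coprime part can be verified numerically. The following Python sketch compares a direct product with the closed form (-1)^q * r! that follows from Wilson's theorem; both function names are hypothetical.

```python
from math import factorial


def coprime_product_mod(n, p):
    """Product of all 1 <= i <= n not divisible by p, reduced mod p (direct loop)."""
    prod = 1
    for i in range(1, n + 1):
        if i % p:
            prod = prod * i % p
    return prod


def coprime_product_formula(n, p):
    """Same value via Wilson's theorem: (-1)^q * r! mod p, where n = q*p + r."""
    q, r = divmod(n, p)
    return (-1) ** q * factorial(r) % p


# n = 144, p = 5: q = 28, r = 4, so the value is (+1) * 24 mod 5 = 4
print(coprime_product_mod(144, 5), coprime_product_formula(144, 5))  # 4 4
```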

So if we write

n   = a*p + A
k   = b*p + B
n-k = c*p + C

we get

choose(n,k) = p^a * a!/ (p^b * b! * p^c * c!) * cop(a,A) / (cop(b,B) * cop(c,C))

where cop(m,r) is the product of all numbers coprime to p which are ≤ m*p + r.

There are two possibilities, a = b + c and A = B + C, or a = b + c + 1 and A = B + C - p.

In our calculation, we have eliminated the second possibility beforehand, but that is not essential.

In the first case, the explicit powers of p cancel, and we are left with

choose(n,k) = a! / (b! * c!) * cop(a,A) / (cop(b,B) * cop(c,C))
            = choose(a,b) * cop(a,A) / (cop(b,B) * cop(c,C))

Any powers of p dividing choose(n,k) come from choose(a,b) - in our case, there will be none, since we've eliminated these cases before - and, although cop(a,A) / (cop(b,B) * cop(c,C)) need not be an integer (consider e.g. choose(19,9) (mod 5)), when considering the expression modulo p, cop(m,r) reduces to (-1)^m * r!, so, since a = b + c, the (-1) cancel and we are left with

choose(n,k) ≡ choose(a,b) * choose(A,B) (mod p)

In the second case, we find

choose(n,k) = choose(a,b) * p * cop(a,A)/ (cop(b,B) * cop(c,C))

since a = b + c + 1. The carry in the last digit means that A < B, so modulo p

p * cop(a,A) / (cop(b,B) * cop(c,C)) ≡ 0 = choose(A,B)

(where we can either replace the division with a multiplication by the modular inverse, or view it as a congruence of rational numbers, meaning the numerator is divisible by p). Anyway, we again find

choose(n,k) ≡ choose(a,b) * choose(A,B) (mod p)

Now we can recur for the choose(a,b) part.

Example:

choose(144,6) (mod 5)
144 = 28 * 5 + 4
  6 =  1 * 5 + 1
choose(144,6) ≡ choose(28,1) * choose(4,1) (mod 5)
              ≡ choose(3,1) * choose(4,1) (mod 5)
              ≡ 3 * 4 = 12 ≡ 2 (mod 5)

choose(12349,789) ≡ choose(2469,157) * choose(4,4)
                  ≡ choose(493,31) * choose(4,2) * choose(4,4)
                  ≡ choose(98,6) * choose(3,1) * choose(4,2) * choose(4,4)
                  ≡ choose(19,1) * choose(3,1) * choose(3,1) * choose(4,2) * choose(4,4)
                  ≡ 4 * 3 * 3 * 1 * 1 = 36 ≡ 1 (mod 5)

Now the implementation:

// Preconditions: 0 <= k <= n; p > 1 prime
long long choose_mod_one(long long n, long long k, long long p)
{
    // For small k, no recursion is necessary
    if (k < p) return choose_mod_two(n,k,p);
    long long q_n, r_n, q_k, r_k, choose;
    q_n = n / p;
    r_n = n % p;
    q_k = k / p;
    r_k = k % p;
    choose = choose_mod_two(r_n, r_k, p);
    // If the exponent of p in choose(n,k) isn't determined to be 0
    // before the calculation gets serious, short-cut here:
    /* if (choose == 0) return 0; */
    choose *= choose_mod_one(q_n, q_k, p);
    return choose % p;
}

// Preconditions: 0 <= k <= min(n,p-1); p > 1 prime
long long choose_mod_two(long long n, long long k, long long p)
{
    // reduce n modulo p
    n %= p;
    // Trivial checks
    if (n < k) return 0;
    if (k == 0 || k == n) return 1;
    // Now 0 < k < n, save a bit of work if k > n/2
    if (k > n/2) k = n-k;
    // calculate numerator and denominator modulo p
    long long num = n, den = 1;
    for(n = n-1; k > 1; --n, --k)
    {
        num = (num * n) % p;
        den = (den * k) % p;
    }
    // Invert denominator modulo p
    den = invert_mod(den,p);
    return (num * den) % p;
}

To calculate the modular inverse, you can use Fermat's (so-called little) theorem

If p is prime and a not divisible by p, then a^(p-1) ≡ 1 (mod p).

and calculate the inverse as a^(p-2) (mod p), or use a method applicable to a wider range of arguments, the extended Euclidean algorithm or continued fraction expansion, which give you the modular inverse for any pair of coprime (positive) integers:

long long invert_mod(long long k, long long m)
{
    if (m == 0) return (k == 1 || k == -1) ? k : 0;
    if (m < 0) m = -m;
    k %= m;
    if (k < 0) k += m;
    int neg = 1;
    long long p1 = 1, p2 = 0, k1 = k, m1 = m, q, r, temp;
    while(k1 > 0) {
        q = m1 / k1;
        r = m1 % k1;
        temp = q*p1 + p2;
        p2 = p1;
        p1 = temp;
        m1 = k1;
        k1 = r;
        neg = !neg;
    }
    return neg ? m - p2 : p2;
}

Like calculating a^(p-2) (mod p), this is an O(log p) algorithm. More precisely it is O(min(log k, log p)), so for small k and large p it is considerably faster than exponentiation; for other inputs it can be slower.

Overall, this way we need to calculate at most O(log_p k) binomial coefficients modulo p, where each binomial coefficient needs at most O(p) operations, yielding a total complexity of O(p*log_p k) operations. When k is significantly larger than p, that is much better than the O(k) solution. For k <= p, it reduces to the O(k) solution with some overhead.
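
The recursion derived above (it is Lucas' theorem, as noted in the comments) can also be written as a short loop over base-p digits. This Python sketch is not a port of the C++ code: it uses math.comb for the per-digit binomials, which is only sensible for small p; for large p one would substitute a modular product and inverse as in choose_mod_two.

```python
from math import comb


def lucas_binomial_mod(n, k, p):
    """choose(n, k) mod p for prime p, multiplying digit-wise binomials in base p."""
    if k < 0 or k > n:
        return 0
    result = 1
    while k > 0:
        n, n_digit = divmod(n, p)
        k, k_digit = divmod(k, p)
        if k_digit > n_digit:      # a carry occurs, so p divides choose(n, k)
            return 0
        result = result * comb(n_digit, k_digit) % p
    return result                  # remaining digits of k are 0 and contribute 1


print(lucas_binomial_mod(144, 6, 5))      # 2, matching the first worked example
print(lucas_binomial_mod(12349, 789, 5))  # 1, matching the second worked example
```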

Cressi answered 2/6, 2012 at 13:59 Comment(4)
Can you post a summary of your algorithm? It is a bit hard for me to follow the steps.Beffrey
Can you give me a hint where you have difficulties? Would be easier to do if I didn't have to entirely guess at what parts might be problematic for people not being able to read my mind.Cressi
It seems that you are running a loop (under guise of recursive function) through result of Lucas theorem in first part, and use multiplicative inverse to calculate nCk mod p in second part? (This is something I'm looking for). Lucas theorem will take care the case p is small.Beffrey
Yes, that's it (didn't know somebody went to the trouble of making a theorem of the relation when I wrote it, hence no mention of master Lucas; now that I know that, I should add a reference to it).Cressi

If you're calculating it more than once, there's another way that's faster. I'm going to post code in python because it'll probably be the easiest to convert into another language, although I'll put the C++ code at the end.

Calculating Once

Brute force:

def choose(n, k, m):
    ans = 1
    for i in range(k): ans *= (n-i)
    for i in range(1, k + 1): ans //= i  # divide by 1..k; starting at 0 would raise ZeroDivisionError
    return ans % m

But the calculation can get into very big numbers, so we can use modular arithmetic tricks instead:

(a * b) mod m = (a mod m) * (b mod m) mod m

(a / (b*c)) mod m = (a mod m) / ((b mod m) * (c mod m) mod m)

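(a / b) mod m = ((a mod m) * (b mod m)^-1) mod m

Note the ^-1 at the end of the last equation. This is the multiplicative inverse of b mod m. It basically means that ((b mod m) * (b mod m)^-1) mod m = 1, just like how a * a^-1 = a * 1/a = 1 for non-zero real numbers. The inverse exists only when b and m are coprime, which is always the case when m is prime and b is not a multiple of m.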

This can be calculated in a few ways, one of which is the extended euclidean algorithm:

def multinv(n, m):
    ''' Multiplicative inverse of n mod m '''
    if m == 1: return 0
    m0, y, x = m, 0, 1

    while n > 1:
        y, x = x - n//m*y, y
        m, n = n%m, m
    
    return x+m0 if x < 0 else x

Note that another method, exponentiation, works only if m is prime. If it is, you can do this:

def powmod(b, e, m):
    ''' b^e mod m '''
    # Note: If you use python, there's a built-in pow(b, e, m) that's probably faster
    # But that's not in C++, so you can convert this instead:
    P = 1
    while e:
        if  e&1: P = P * b % m
        e >>= 1; b = b * b % m
    return P

def multinv(n, m):
    ''' Multiplicative inverse of n mod m, only if m is prime '''
    return powmod(n, m-2, m)
    

But note that the Extended Euclidean Algorithm tends to still run faster, even though they technically have the same time complexity, O(log m), because it has a lower constant factor.
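
A quick numeric sanity check of the division identity and of the two ways to get an inverse, as a Python sketch. The values m = 13, a = 84 and b = 6 are arbitrary choices for illustration, and pow(b, -1, m) requires Python 3.8 or newer.

```python
m = 13         # small prime modulus, chosen only for illustration
a, b = 84, 6   # b divides a, so a / b = 14 is an exact integer

inv_fermat = pow(b, m - 2, m)   # Fermat: valid because m is prime and m does not divide b
inv_builtin = pow(b, -1, m)     # Python 3.8+: inverse for any b coprime to m

assert inv_fermat == inv_builtin
assert b * inv_builtin % m == 1
# (a / b) mod m equals (a mod m) * b^-1 mod m
assert (a // b) % m == (a % m) * inv_builtin % m

print(inv_builtin, (a // b) % m)  # 11 1
```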

So now the full code:

def multinv(n, m):
    ''' Multiplicative inverse of n mod m in log(m) '''
    if m == 1: return 0
    m0, y, x = m, 0, 1

    while n > 1:
        y, x = x - n//m*y, y
        m, n = n%m, m
    
    return x+m0 if x < 0 else x


def choose(n, k, m):
    num = den = 1
    for i in range(k): num = num * (n-i) % m
    for i in range(1, k + 1): den = den * i % m  # k! mod m; starting at 0 would zero the product
    return num * multinv(den, m) % m

Querying Multiple Times

We can calculate the numerator and denominator separately, and then combine them. But notice that the product we're calculating for the numerator is n * (n-1) * (n-2) * (n-3) ... * (n-k+1). If you've ever learned about something called prefix sums, this is awfully similar. So let's apply it.

Precalculate fact[i] = i! mod m for i up to whatever the max value of n is, maybe 1e7 (ten million). Then, the numerator is (fact[n] * fact[n-k]^-1) mod m, and the denominator is fact[k]. So we can calculate choose(n, k, m) = fact[n] * multinv(fact[n-k], m) % m * multinv(fact[k], m) % m. This requires m to be prime and n < m, so that none of the factorials is 0 mod m and the inverses exist.

Python code:

MAXN = 1000 # Increase if necessary
MOD = 10**9+7 # A common mod that's used, change if necessary

fact = [1]
for i in range(1, MAXN+1):
    fact.append(fact[-1] * i % MOD)

def multinv(n, m):
    ''' Multiplicative inverse of n mod m in log(m) '''
    if m == 1: return 0
    m0, y, x = m, 0, 1

    while n > 1:
        y, x = x - n//m*y, y
        m, n = n%m, m
    
    return x+m0 if x < 0 else x


def choose(n, k, m):
    return fact[n] * multinv(fact[n-k] * fact[k] % m, m) % m
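
One possible variant of the table approach, not part of the code above: precompute the inverse factorials as well, so that each query needs no inverse at all. This is a sketch under the assumptions that MOD is prime and MAXN < MOD; inv_fact and choose_table are names introduced here.

```python
MOD = 10**9 + 7   # prime modulus, as in the example above
MAXN = 10**5      # hypothetical table size for this sketch; must be < MOD

# fact[i] = i! mod MOD
fact = [1] * (MAXN + 1)
for i in range(1, MAXN + 1):
    fact[i] = fact[i - 1] * i % MOD

# One modular inverse for the largest factorial, then walk downwards
# using 1/(i-1)! = (1/i!) * i.
inv_fact = [1] * (MAXN + 1)
inv_fact[MAXN] = pow(fact[MAXN], MOD - 2, MOD)
for i in range(MAXN, 0, -1):
    inv_fact[i - 1] = inv_fact[i] * i % MOD


def choose_table(n, k):
    """C(n, k) mod MOD for 0 <= n <= MAXN, using only table lookups."""
    if k < 0 or k > n:
        return 0
    return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD


print(choose_table(4, 2))  # 6
```

The trade-off is a second table of the same size in exchange for avoiding an O(log MOD) inverse per query.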

C++ code:

#include <iostream>
using namespace std;

const int MAXN = 1000000; // Must be at least the largest n passed to choose() (main uses 1e6)
const int MOD = 1e9+7; // A common mod that's used, change if necessary

int fact[MAXN+1];

int multinv(int n, int m) {
    /* Multiplicative inverse of n mod m in log(m) */
    if (m == 1) return 0;
    int m0 = m, y = 0, x = 1, t;

    while (n > 1) {
        t = y;
        y = x - n/m*y;
        x = t;
        
        t = m;
        m = n%m;
        n = t;
    }
    
    return x<0 ? x+m0 : x;
}

int choose(int n, int k, int m) {
    return (long long) fact[n]
         * multinv((long long) fact[n-k] * fact[k] % m, m) % m;
}

int main() {
    fact[0] = 1;
    for (int i = 1; i <= MAXN; i++) {
        fact[i] = (long long) fact[i-1] * i % MOD;
    }

    cout << choose(4, 2, MOD) << '\n';
    cout << choose(1e6, 1e3, MOD) << '\n';
}

Note that I'm casting to long long to avoid overflow.

Gerlachovka answered 23/5, 2021 at 20:33 Comment(3)
Thanks! I found this helpful. It is missing the last m parameter in the calls to multinv() in the last Python version though.Disobey
Adding c++ code was great for people who don't know python.Gallaway
Downvoted. Most of the Python code doesn't work at all. For example, for i in range(k): ans //= i immediately produces a ZeroDivisionError. I'm not sure the C++ code is much better.Savadove

I'm surprised nobody else mentioned this, but if you need to calculate these binomial coefficients often, and the numbers n and k are very large but close (a difference of less than 10**3) to the n and k of an earlier query, there is a way to speed this up massively, using these three formulas:

C(n, k) = (n / k) * C(n - 1, k - 1)

(The other two formulas were posted as images and did not survive in this copy.)

You cache the previous value of the binomial coefficient, and from there it takes only one multiplication and one division to get the next value for n+1 or k+1. You don't have to calculate any more factorials, which is nice. The running time also barely depends on the size of n and k: you compute the binomial coefficient once, then apply these formulas, so the cost depends only on how far n and k are from a previous query whose result you cached (a 2D binary tree can be used to find the closest computed (n k)).

Of course, these formulas can be adapted for modular arithmetic: instead of dividing, multiply by the modular multiplicative inverse, and reduce everything mod p.
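
As an illustration of the one formula that is written out, here is a minimal Python sketch of a single modular step from C(n, k) to C(n+1, k+1). The function name next_diagonal is made up, and the sketch assumes p is prime and k+1 is not a multiple of p, so that the inverse exists.

```python
def next_diagonal(c, n, k, p):
    """Given c = C(n, k) mod p, return C(n+1, k+1) mod p.

    Uses C(n+1, k+1) = (n+1)/(k+1) * C(n, k), with the division replaced
    by a Fermat inverse. Assumes p is prime and p does not divide k+1.
    """
    return c * ((n + 1) % p) % p * pow(k + 1, p - 2, p) % p


p = 10**9 + 7
c = 120                          # C(10, 3)
c = next_diagonal(c, 10, 3, p)
print(c)                         # 330 = C(11, 4)
```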

See https://en.wikipedia.org/wiki/Binomial_coefficient#Identities_involving_binomial_coefficients

(sorry for formatting, SO still doesn't support Latex)

Wed answered 27/4 at 14:24 Comment(0)

This code is much faster and can handle much larger n than the existing answers. It uses Lucas's Theorem:

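from sympy import binomial
from functools import cache

@cache
def mybinomial(n,k,p):
    return binomial(n,k)%p

@cache
def num_p(n, p):
    # Exponent of p in n!, using integer arithmetic only.
    # floor(log(n, p)) can be off by one due to floating-point rounding
    # (math.log(1000, 10), for instance, evaluates to 2.9999999999999996).
    count = 0
    while n:
        n //= p
        count += n
    return count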
     
def n_choose_m_mod_p(n,m,p):
    if num_p(n,p)>num_p(n-m,p)+num_p(m,p):
        return 0
    elif n<p**2:
        return (mybinomial(n//p,m//p,p)*mybinomial(n%p,m%p,p))%p
    return (n_choose_m_mod_p(n//p,m//p,p)*mybinomial(n%p,m%p,p))%p
Watch answered 13/7 at 1:54 Comment(0)
