Modular multiplication of large numbers in C++
S

5

9

I have three integers: A and B (each less than 10^12) and C (less than 10^15). I want to calculate (A * B) % C. I know that

(A * B) % C = ((A % C) * (B % C)) % C

but if, say, A = B = 10^11, then the expression above will cause an integer overflow. Is there a simple solution for this case, or do I have to use a fast multiplication algorithm?

If I have to use a fast multiplication algorithm, which one should I use?

EDIT: I have tried the above in C++ (which does not report an overflow, not sure why), but shouldn't the answer be zero?

Thanks in advance.

Sordino answered 7/1, 2014 at 12:40 Comment(3)
The RHS will only overflow if C is sufficiently large (that's what is wonderful about remainders).Levana
Arithmetic overflows in C++ are usually silent - there's no error, they just happen. You find out about it when you see your output is 712049423024128 when you were expecting 0.Alixaliza
If you want something fast, I fear it will have to be platform-specific. What platform(s) are you interested in?Phillips
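
The silent overflow described in the comments can be reproduced with a short sketch. The function name naiveMulMod is hypothetical, and unsigned 64-bit integers are assumed so that the wraparound is well defined (with signed types the overflow would be undefined behaviour):

```cpp
#include <cassert>
#include <cstdint>

// Naive formula from the question: ((A % C) * (B % C)) % C.
// The 64-bit product wraps around modulo 2^64 before the final % is applied,
// so the result is wrong whenever (a % c) * (b % c) does not fit in 64 bits.
std::uint64_t naiveMulMod(std::uint64_t a, std::uint64_t b, std::uint64_t c) {
    assert(c != 0);
    return ((a % c) * (b % c)) % c;
}
```

For A = B = 10^11 and C = 10^15, this returns 712049423024128 (the value quoted in the comment above) instead of the mathematically correct 0, and no error is reported.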
C
6

Given your formula and the following variation:

(A + B) mod C = ((A mod C) + (B mod C)) mod C 

You can use a divide-and-conquer approach to develop an algorithm that is both easy and fast:

#include <iostream>

long long bigMod(long long a, long long b, long long c) {
    if (a == 0 || b == 0) {
        return 0;
    }
    if (a == 1) {
        return b % c;
    }
    if (b == 1) {
        return a % c;
    }

    // Returns: (a * b/2) mod c
    long long a2 = bigMod(a, b / 2, c);

    // Even factor
    if ((b & 1) == 0) {
        // [((a * b/2) mod c) + ((a * b/2) mod c)] mod c
        return (a2 + a2) % c;
    } else {
        // Odd factor
        // [(a mod c) + ((a * b/2) mod c) + ((a * b/2) mod c)] mod c
        // The sum stays below 3 * c, so it cannot overflow while 3 * c fits in a long long
        return ((a % c) + (a2 + a2)) % c;
    }
}

int main() { 
    // Use the min(a, b) as the second parameter
    // This prints: 27
    std::cout << bigMod(64545, 58971, 144) << std::endl;
    return 0;
}

This takes O(log N) additions, where N is the second argument b.

Crumley answered 7/1, 2014 at 15:38 Comment(6)
This is doing exponentiation, but the question was to do multiplication. You can probably change multiplication for addition in your code though and it should work.Itacolumite
You would also need to use if (b==0) return 0;Itacolumite
Yes! You're right, thanks for noticing (and for not downvoting the answer although it deserved it). I updated properlyCrumley
+1 I think this version of the algorithm is more readable (though longer) than the accepted answer which as far as I can tell does the same thing.Itacolumite
It works correctly but the algorithm is terribly slow :(Fermin
a2 + a2 can absolutely overflow, and (a % c) + (a2 + a2) is even more prone to overflowing.Knopp
G
18

You can solve this using Schrage's method. It lets you compute az mod m for two signed numbers a and z without generating an intermediate value greater than the modulus m.

It's based on an approximate factorisation of the modulus m,

m = aq + r 

i.e.

q = [m / a]

and

r = m mod a

where [] denotes the integer part. If r < q and 0 < z < m − 1, then both a(z mod q) and r[z / q] lie in the range 0,...,m − 1 and

az mod m = a(z mod q) − r[z / q]

If this is negative then add m.

[This technique is frequently used in linear congruential random number generators].
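
A minimal sketch of the decomposition above, assuming unsigned 64-bit operands. The name schrageMulMod is hypothetical, and the recursive call used when r >= q is an extension that goes beyond the basic method described in this answer:

```cpp
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;

// Computes (a * z) mod m via Schrage's decomposition m = a*q + r.
// Assumption of this sketch: m > 0. No intermediate value reaches m or more.
u64 schrageMulMod(u64 a, u64 z, u64 m) {
    assert(m > 0);
    a %= m;
    z %= m;
    if (a == 0 || z == 0) return 0;
    if (a == 1) return z;

    const u64 q = m / a;  // q = [m / a]
    const u64 r = m % a;  // r = m mod a

    // a * (z mod q) <= a * (q - 1) < m, so this product cannot overflow.
    const u64 t1 = a * (z % q);

    // If r < q, then r * [z / q] < q * [z / q] <= z < m, so it is computed directly.
    // Otherwise recurse on the strictly smaller multiplier r.
    const u64 t2 = (r < q) ? r * (z / q) : schrageMulMod(r, z / q, m);

    // a*z mod m = t1 - t2, with m added when the difference would be negative.
    return (t1 >= t2) ? t1 - t2 : m - (t2 - t1);
}
```

With the values from the question, schrageMulMod(100000000000, 100000000000, 1000000000000000) gives 0: here q = 10^4 and r = 0, so both terms vanish.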

Grayling answered 7/1, 2014 at 13:3 Comment(1)
Additionally, you can apply this recursively to cover all cases where r >= q. Repeat the above algorithm for the product r*[z/q], so the new values become a2 = r and z2 = [z/q]. This guarantees that the a values decrease: a > r = a2, so a2 < a. Eventually a <= sqrt(m), and then a*a <= m implies q >= a > r, so r < q and the algorithm terminates.Breadthways
G
4

UPDATED: Fixed error when high bit of a % c is set. (hat tip: Kevin Hopps)

If you're looking for simple over fast, then you can use the following:

typedef unsigned long long u64;

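u64 multiplyModulo(u64 a, u64 b, u64 c)
{
    u64 result = 0;
    a %= c;
    b %= c;
    while(b) {
        if(b & 0x1) {
            // result = (result + a) % c, written so the sum cannot wrap around
            if(result < c - a) {
                result += a;
            } else {
                result -= (c - a);
            }
        }
        b >>= 1;
        // a = (2 * a) % c, again without overflow
        if(a < c - a) {
            a <<= 1;
        } else {
            a -= (c - a);
        }
    }
    return result;
}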
Gan answered 7/1, 2014 at 13:1 Comment(2)
When "a" has the high bit set, this produces an incorrect result. See my post below.Serle
The first addition can also overflow; the rest of the code is correct (and the comment below about the left shift does not apply to the current version of the code). Correction: if (result < c - a) result = result + a; else result -= (c - a);Mesopause
S
1

Sorry, but godel9's algorithm will produce an incorrect result when the variable "a" holds a value that has the high bit set. This is because "a <<= 1" loses information. Here is a corrected algorithm that works for any integer type, signed or unsigned.

#include <cassert>
#include <utility> // std::swap

template <typename IntType>
IntType add(IntType a, IntType b, IntType c)
    {
    assert(c > 0 && 0 <= a && a < c && 0 <= b && b < c);
    IntType room = (c - 1) - a;
    if (b <= room)
        a += b;
    else
        a = b - room - 1;
    return a;
    }

template <typename IntType>
IntType mod(IntType a, IntType c)
    {
    assert(c > 0);
    IntType q = a / c; // q may be negative
    a -= q * c; // now -c < a && a < c
    if (a < 0)
        a += c;
    return a;
    }

template <typename IntType>
IntType multiplyModulo(IntType a, IntType b, IntType c)
    {
    IntType result = 0;
    a = mod(a, c);
    b = mod(b, c);
    if (b > a)
        std::swap(a, b);
    while (b)
        {
        if (b & 0x1)
            result = add(result, a, c);
        a = add(a, a, c);
        b >>= 1;
        }
    return result;
    }
Serle answered 10/2, 2015 at 22:25 Comment(2)
What's the reason for the additional function 'mod'? Why not just operator %?Synergetic
@Hsilgos, the reason for using mod instead of operator% is to ensure a non-negative result.Serle
J
0

In this case, A and B are 40-bit numbers and C is a 50-bit number. That isn't an issue in 64-bit mode if you have an intrinsic, or can write assembly code, for a 64-bit by 64-bit multiply that produces a 128-bit result (the product here is actually at most 80 bits). You then divide the 128-bit dividend by the 50-bit divisor to produce a 50-bit remainder (the modulo).
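
One way to get the widening multiply without writing assembly is sketched below. It assumes a compiler that provides the non-standard unsigned __int128 type (GCC and Clang do on 64-bit targets), and the name mulMod128 is hypothetical:

```cpp
#include <cassert>
#include <cstdint>

// Sketch relying on the unsigned __int128 compiler extension; not standard C++.
std::uint64_t mulMod128(std::uint64_t a, std::uint64_t b, std::uint64_t c) {
    assert(c != 0);
    // Widen one operand first so the full 128-bit product is kept.
    const unsigned __int128 product = static_cast<unsigned __int128>(a) * b;
    // The remainder is smaller than c, so it fits back into 64 bits.
    return static_cast<std::uint64_t>(product % c);
}
```

The compiler decides how the 128-bit remainder is computed, so whether this beats the loop-based answers depends on the platform.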

Depending on the processor, it may be faster to implement the divide by a 50-bit constant as a multiply by an 81-bit (or smaller) constant. Again assuming a 64-bit processor, this takes 4 multiplies and some adds, followed by a shift of the upper bits of the 4-multiply product to produce a quotient. The quotient is then multiplied by the 50-bit modulus and subtracted from the 80-bit product to produce the 50-bit remainder.

Jungian answered 15/1, 2017 at 17:27 Comment(0)
