First of all, how does this algorithm work? It is based on the Extended Euclidean algorithm for computing the GCD. In short, the idea is the following: if we can find integer coefficients m and n such that
a*m + b*n = 1
then m is the answer to the modular inverse problem. This is easy to see because b*n is a multiple of b, so
a*m + b*n = a*m (mod b)
and therefore a*m = 1 (mod b).
Luckily, the Extended Euclidean algorithm does exactly that: if a and b are co-prime, it finds such m and n. It works in the following way: on each iteration it tracks two triplets (ai, xai, yai) and (bi, xbi, ybi) such that at every step
ai = a0*xai + b0*yai
bi = a0*xbi + b0*ybi
so when the algorithm finally stops in the state ai = 0 and bi = GCD(a0,b0), then for co-prime a0 and b0
1 = GCD(a0,b0) = a0*xbi + b0*ybi
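This is done using a more explicit way of calculating the remainder: if
q = a / b
r = a % b
then
r = a - q * b
so every step only subtracts a multiple of one triplet from the other, which keeps both equalities above true.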
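To make the bookkeeping concrete, here is a minimal sketch of the full algorithm with both triplets. It uses signed `int64_t`, so it is only meant for small inputs, and the names `bezout` and `ext_gcd` are mine, not from the code in the question:

```c
#include <stdint.h>

/* Hypothetical result type: g = GCD(a0, b0) = a0*x + b0*y */
typedef struct {
    int64_t g, x, y;
} bezout;

/* Sketch only: assumes a0, b0 > 0 and small enough that the
   products below cannot overflow int64_t. */
bezout ext_gcd(int64_t a0, int64_t b0)
{
    /* invariants: a = a0*xa + b0*ya  and  b = a0*xb + b0*yb */
    int64_t a = a0, xa = 1, ya = 0;
    int64_t b = b0, xb = 0, yb = 1;
    while (a != 0) {
        int64_t q = b / a;
        /* r = b % a, written as b - q*a and applied to the whole triplet */
        int64_t r  = b  - q * a;
        int64_t xr = xb - q * xa;
        int64_t yr = yb - q * ya;
        b = a;  xb = xa;  yb = ya;
        a = r;  xa = xr;  ya = yr;
    }
    /* stopped at a == 0, b == GCD(a0, b0) */
    bezout result = { b, xb, yb };
    return result;
}
```

For example, `ext_gcd(3, 11)` gives `g = 1, x = 4, y = -1`, i.e. `3*4 + 11*(-1) = 1`, so 4 is the inverse of 3 modulo 11.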
Another important fact is that it can be proven that for positive a and b, at every step |xai|,|xbi| <= b and |yai|,|ybi| <= a. This means there can be no overflow while calculating those coefficients. Unfortunately, negative values are possible; moreover, at every step after the first one, in each equation one coefficient is positive and the other is negative.
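I am not reproducing the proof here, but the bounds are easy to spot-check for a particular pair. A small sketch follows; the function name `coefficients_stay_bounded` is made up for this example, and it assumes small positive inputs so that signed `int64_t` arithmetic is safe:

```c
#include <stdbool.h>
#include <stdint.h>

static int64_t abs64(int64_t v) { return v < 0 ? -v : v; }

/* Runs the triplet iteration for (a0, b0) and reports whether
   |x| <= b0 and |y| <= a0 held after every step.
   Sketch only: assumes a0, b0 > 0 and small. */
bool coefficients_stay_bounded(int64_t a0, int64_t b0)
{
    int64_t a = a0, xa = 1, ya = 0;
    int64_t b = b0, xb = 0, yb = 1;
    while (a != 0) {
        int64_t q = b / a;
        int64_t r  = b  - q * a;
        int64_t xr = xb - q * xa;
        int64_t yr = yb - q * ya;
        b = a;  xb = xa;  yb = ya;
        a = r;  xa = xr;  ya = yr;
        if (abs64(xa) > b0 || abs64(xb) > b0 ||
            abs64(ya) > a0 || abs64(yb) > a0)
            return false;
    }
    return true;
}
```

This only checks the claim for the pair you pass in; it is not a substitute for the proof.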
What the code in your question does is a reduced version of the same algorithm: since all we are interested in are the x[a/b] coefficients, it tracks only them and ignores the y[a/b] ones. The simplest way to make that code work for uint64_t is to track the sign explicitly in a separate field, like this:
    #include <stdbool.h>
    #include <stdint.h>

    typedef struct tag_uint64AndSign {
        uint64_t value;
        bool isNegative;
    } uint64AndSign;

    uint64_t mul_inv(uint64_t a, uint64_t b)
    {
        if (b <= 1 || a == 0) // no inverse modulo 0 or 1, and 0 is never invertible
            return 0;

        uint64_t b0 = b;
        uint64AndSign x0 = { 0, false }; // b = 1*b + 0*a
        uint64AndSign x1 = { 1, false }; // a = 0*b + 1*a

        while (a > 1)
        {
            if (b == 0) // means original A and B were not co-prime so there is no answer
                return 0;
            uint64_t q = a / b;
            // (b, a) := (a % b, b)
            // which is the same as
            // (b, a) := (a - q * b, b)
            uint64_t t = b; b = a % b; a = t;

            // (x0, x1) := (x1 - q * x0, x0)
            uint64AndSign t2 = x0;
            uint64_t qx0 = q * x0.value;
            if (x0.isNegative != x1.isNegative)
            {
                // different signs: the magnitudes add up and the sign of x1 wins
                x0.value = x1.value + qx0;
                x0.isNegative = x1.isNegative;
            }
            else
            {
                // same signs: subtract the smaller magnitude from the larger one
                x0.value = (x1.value > qx0) ? x1.value - qx0 : qx0 - x1.value;
                x0.isNegative = (x1.value > qx0) ? x1.isNegative : !x0.isNegative;
            }
            x1 = t2;
        }

        // map a negative coefficient back into the range [0, b0)
        return x1.isNegative ? (b0 - x1.value) : x1.value;
    }
Note that if a and b are not co-prime, or when b is 0 or 1, this problem has no solution. In all those cases my code returns 0, which is an impossible value for any real solution.
Note also that although the calculated value really is the modular inverse, a simple multiplication will not always produce 1, because the multiplication overflows uint64_t. For example, for a = 688231346938900684 and b = 2499104367272547425 the result is inv = 1080632715106266389:
a * inv = 688231346938900684 * 1080632715106266389 =
= 743725309063827045302080239318310076 =
= 2499104367272547425 * 297596738576991899 + 1 =
= b * 297596738576991899 + 1
But if you do a naive multiplication of those a and inv as uint64_t values, you'll get 4042520075082636476, so (a*inv)%b will be 1543415707810089051 rather than the expected 1.
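If you want to check the result in code, the multiplication itself has to be done modulo b without overflowing. One possible sketch is a shift-and-add multiplication. The name `mul_mod` is mine, and the function assumes 0 < m < 2^63 so that doubling an already reduced value still fits in uint64_t (this holds for the b above):

```c
#include <stdint.h>

/* (x * y) % m without overflowing uint64_t.
   Sketch only: assumes 0 < m < 2^63. */
uint64_t mul_mod(uint64_t x, uint64_t y, uint64_t m)
{
    uint64_t result = 0;
    x %= m;
    while (y > 0) {
        if (y & 1) {
            /* result < m and x < m, so the sum is below 2^64 */
            result = (result + x) % m;
        }
        /* x < m < 2^63, so 2*x is below 2^64 */
        x = (x * 2) % m;
        y >>= 1;
    }
    return result;
}
```

With it, `mul_mod(a, inv, b)` for the values above gives 1, whereas the plain `(a * inv) % b` does not.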
pX can be less than zero, whereas in a modular setting it is not, and the usual way is to do if (pX < 0) pX += b. But that is signed and unsigned addition, which can (or cannot?) overflow. Ideally, for modular calculations there would be no static_cast<S>. But thanks again for writing the article, if you are the author. – Ebonee