How to do arithmetic modulo another number, without overflow?

I'm trying to implement a fast primality test for Rust's u32 and u64 datatypes. As part of it, I need to compute (n*n)%d where n and d are u32 (or u64, respectively).

While the result can easily fit in the datatype, I'm at a loss for how to compute this. As far as I know there is no processor primitive for this.

For u32 we can fake it: cast up to u64 so that the product won't overflow, take the modulus, then cast back down to u32, knowing this won't overflow. However, since I don't have a u128 datatype (as far as I know), this trick won't work for u64.
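
The u32 trick described above can be sketched as follows. `mul_mod_u32` is a hypothetical helper name; it panics if `m == 0`, like any `%` by zero.

```rust
/// Computes (a * b) % m for u32 by widening to u64 first.
fn mul_mod_u32(a: u32, b: u32, m: u32) -> u32 {
    // (2^32 - 1)^2 < 2^64, so the product of two u32 values always fits in u64.
    let product = u64::from(a) * u64::from(b);
    // The remainder is < m <= u32::MAX, so narrowing back loses nothing.
    (product % u64::from(m)) as u32
}

#[test]
fn test_mul_mod_u32() {
    assert_eq!(mul_mod_u32(42, 42, 41), 1);
    assert_eq!(mul_mod_u32(u32::MAX, u32::MAX, u32::MAX), 0);
}
```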

So for u64, the most obvious way I can think of to accomplish this is to somehow compute x*y to get a pair (carry, product) of u64, so we capture the amount of overflow instead of just losing it (or panicking, or whatever).
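
A (carry, product) pair can be obtained with only u64 arithmetic by splitting each operand into 32-bit halves, schoolbook style. This is a sketch; `mul_wide` is a hypothetical name. It only produces the 128-bit product as `(high, low)`; getting the remainder from that pair still requires a 128-by-64-bit division, which is not shown here.

```rust
/// Full 64x64 -> 128-bit multiply, returned as (high, low) words.
fn mul_wide(a: u64, b: u64) -> (u64, u64) {
    const MASK: u64 = 0xFFFF_FFFF;
    let (a_lo, a_hi) = (a & MASK, a >> 32);
    let (b_lo, b_hi) = (b & MASK, b >> 32);

    // Each partial product has factors < 2^32, so each fits in u64.
    let lo_lo = a_lo * b_lo;
    let hi_lo = a_hi * b_lo;
    let lo_hi = a_lo * b_hi;
    let hi_hi = a_hi * b_hi;

    // Middle column plus the carry out of the low product; bounded by 2^64 - 2.
    let cross = (lo_lo >> 32) + (hi_lo & MASK) + lo_hi;
    // The overflow ("carry") word: everything at or above 2^64.
    let high = hi_hi + (hi_lo >> 32) + (cross >> 32);
    // `<<` discards the bits of `cross` that were already counted in `high`.
    let low = (cross << 32) | (lo_lo & MASK);
    (high, low)
}
```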

Is there a way to do this? Or another standard way to solve the problem?

Grand answered 28/8, 2017 at 11:39 Comment(7)
doc.rust-lang.org/std/u128 – Pinard
Just use this: huonw.github.io/primal/primal/fn.is_prime.html – Taurus
Wiki modular arithmetic example implementation – Anisole
@Anisole if you (or someone) can turn that into an answer, I would love to accept it. – Grand
@Anisole Why don't you post it as an answer? – Pinard
@Pinard that's a good trick if you're on nightly, but then I have to wonder how to do it for u128 (and so on until I run out of higher-bit datatypes). – Grand
@RichardRast Sure, I just wanted to answer your question about u128. – Pinard

Richard Rast pointed out that the Wikipedia version works only with 63-bit integers. I extended the code provided by Boiethios to work with the full range of 64-bit unsigned integers.

fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
    let msb = 0x8000_0000_0000_0000;
    let mut d = 0;
    let mp2 = m >> 1;
    x %= m;
    y %= m;

    if m & msb == 0 {
        for _ in 0..64 {
            d = if d > mp2 {
                (d << 1) - m
            } else {
                d << 1
            };
            if x & msb != 0 {
                d += y;
            }
            if d >= m {
                d -= m;
            }
            x <<= 1;
        }
        d
    } else {
        for _ in 0..64 {
            d = if d > mp2 {
                d.wrapping_shl(1).wrapping_sub(m)
            } else {
                // the case d == m && x == 0 is taken care of 
                // after the end of the loop
                d << 1
            };
            if x & msb != 0 {
                let (mut d1, overflow) = d.overflowing_add(y);
                if overflow {
                    d1 = d1.wrapping_sub(m);
                }
                d = if d1 >= m { d1 - m } else { d1 };
            }
            x <<= 1;
        }
        if d >= m { d - m } else { d }
    }
}

#[test]
fn test_mul_mod64() {
    let half = 1 << 16;
    let max = u64::MAX;

    assert_eq!(mul_mod64(0, 0, 2), 0);
    assert_eq!(mul_mod64(1, 0, 2), 0);
    assert_eq!(mul_mod64(0, 1, 2), 0);
    assert_eq!(mul_mod64(1, 1, 2), 1);
    assert_eq!(mul_mod64(42, 1, 2), 0);
    assert_eq!(mul_mod64(1, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 42), 0);
    assert_eq!(mul_mod64(42, 42, 41), 1);
    assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);

    assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
    assert_eq!(mul_mod64(half, half, half), 0);
    assert_eq!(mul_mod64(half+1, half+1, half), 1);

    assert_eq!(mul_mod64(max, max, max), 0);
    assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
    assert_eq!(mul_mod64(1239876, max, max), 0);
    assert_eq!(mul_mod64(1239876, max-1, max), max-1239876);
    assert_eq!(mul_mod64(max, 2948635, max), 0);
    assert_eq!(mul_mod64(max-1, 2948635, max), max-2948635);
    assert_eq!(mul_mod64(max-1, max-1, max), 1);
    assert_eq!(mul_mod64(2, max/2, max-1), 0);
}
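
For comparison, the same shift-and-add idea can be written more compactly by moving all overflow handling into a small modular-addition helper. This is an alternative formulation, not the code above: `add_mod` and `mul_mod_shift_add` are hypothetical names. It relies on `a + b >= m` being equivalent to `a >= m - b` when `a, b < m`, so no intermediate sum ever exceeds `u64::MAX`.

```rust
/// (a + b) % m for a, b < m, without overflowing u64.
fn add_mod(a: u64, b: u64, m: u64) -> u64 {
    // m - b cannot underflow because b < m.
    if a >= m - b { a - (m - b) } else { a + b }
}

/// (x * y) % m by scanning the bits of x from most to least significant.
/// Works for any m > 0 (full 64-bit range); panics if m == 0.
fn mul_mod_shift_add(x: u64, y: u64, m: u64) -> u64 {
    let (x, y) = (x % m, y % m);
    let mut d = 0; // invariant: d < m
    for i in (0..64).rev() {
        d = add_mod(d, d, m); // d = 2d mod m
        if (x >> i) & 1 == 1 {
            d = add_mod(d, y, m); // d = d + y mod m
        }
    }
    d
}
```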
Anisole answered 28/8, 2017 at 18:5 Comment(3)
@mcarton what is half half of? – Switch
@Switch now that I think about it, nothing :) – Fricassee
I regret to say that I have not gone through the work of understanding this code :( but I have tested it thoroughly and it works very well. Thank you! – Grand

Here's an alternative approach (there's now a u128 datatype):

fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    let (a, b, m) = (a as u128, b as u128, m as u128);
    ((a * b) % m) as u64
}

This approach just leans on LLVM's 128-bit integer arithmetic.

The thing I like about this version is that it's really easy to convince yourself that the solution is correct for the entire domain. Since a and b are u64s the product is guaranteed to fit in a u128, and since m is a u64 the downcast at the end is guaranteed to be safe.

I don't know how performance compares to other approaches, but I would be pretty surprised if it were dramatically slower. If you really care about performance you're going to want to run some benchmarks and try a few alternatives in any case.
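
Since the goal in the question is a primality test, the (n*n)%d step typically shows up as the squaring step of modular exponentiation (e.g. in a Fermat or Miller-Rabin test; this use is an assumption, not something the question states). A sketch of square-and-multiply on top of the u128 `mul_mod`, with `pow_mod` as a hypothetical name:

```rust
/// (a * b) % m via a 128-bit intermediate, as in the answer above.
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    ((u128::from(a) * u128::from(b)) % u128::from(m)) as u64
}

/// base^exp % m by square-and-multiply; every product goes through mul_mod,
/// so no intermediate value overflows. Panics if m == 0.
fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1 % m; // 1 % m makes m == 1 return 0
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mul_mod(result, base, m);
        }
        base = mul_mod(base, base, m); // the (n*n) % d step from the question
        exp >>= 1;
    }
    result
}
```

For a prime `p` and `a` not divisible by `p`, `pow_mod(a, p - 1, p)` returns 1 (Fermat's little theorem).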

Buchan answered 20/3, 2021 at 14:0 Comment(0)

Use simple mathematics:

(n*n)%d = (n%d)*(n%d)%d

To see that this is indeed true, write n = k*d + r, where r = n % d:

n*n % d = (k**2*d**2 + 2*k*d*r + r**2) % d = r**2 % d = (n%d)*(n%d) % d
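
Reducing first keeps the operands below d, but as the comments point out, (n%d)*(n%d) can still overflow once d > 2^16 for u32. A sketch that applies the identity and uses `checked_mul` to report that case instead of panicking or wrapping (`sq_mod_reduced` is a hypothetical name):

```rust
/// (n * n) % d for u32 using (n*n)%d = ((n%d)*(n%d))%d.
/// Returns None when (n % d)^2 still does not fit in u32.
fn sq_mod_reduced(n: u32, d: u32) -> Option<u32> {
    let r = n % d; // r < d
    r.checked_mul(r).map(|sq| sq % d)
}
```

With d = 2^16 the largest possible r is 65535 and the square still fits; with d = 65537, r = 65536 squares to exactly 2^32 and overflows.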
Annulment answered 28/8, 2017 at 11:42 Comment(3)
strictly speaking n * n % d = (n % d) * (n % d) % d – Thoroughgoing
This can still overflow if d is large. – Heartwhole
@Heartwhole Specifically, it can overflow if d > 2^16 – Kimon

red75prime added a useful comment. Here is Rust code, adapted from Wikipedia, that computes the product of two numbers modulo a third:

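fn mul_mod(mut x: u64, mut y: u64, m: u64) -> u64 {
    // Only valid for m < 2^63: for larger m, `d << 1` can overflow.
    let mut d = 0_u64;
    let mp2 = m >> 1;
    x %= m;
    y %= m;

    for _ in 0..64 {
        // d = 2d mod m
        d = if d > mp2 {
            (d << 1) - m
        } else {
            d << 1
        };
        // If the current top bit of x is set, add y.
        if x & 0x8000_0000_0000_0000_u64 != 0 {
            d += y;
        }
        // `>=` rather than `>`, so that d == m is reduced to 0.
        if d >= m {
            d -= m;
        }
        x <<= 1;
    }
    d
}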
Pinard answered 28/8, 2017 at 16:22 Comment(4)
@Anisole If you want to post your own answer, I will delete mine. – Pinard
No, it's fine. The problem is that this algorithm isn't correct. I found one bug: `if d > m ...` should be `if d >= m ...`. Another one causes a subtract-with-overflow in `(d << 1) - m`. I haven't found why yet. – Anisole
Also, it seems to give incorrect answers: 11552001120680995*15777587326414455 (mod 18442563521290148565) should be 844062957336182220, but the algorithm gives 13054753449364403936. – Anisole
Note that the Wikipedia section indicates that all arguments (x, y and m) must be at most 63 bits (that is, < 2^63). @Anisole your arguments (here, the modulus) use all 64 bits. – Grand
