Karatsuba algorithm too much recursion

I am trying to implement the Karatsuba multiplication algorithm in C++, but right now I am just trying to get it to work in Python.

Here is my code:

def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)
    x0 = x / bm
    x1 = x % bm
    y0 = y / bm
    y1 = y % bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

What I don't get is: how should z2, z1, and z0 be created? Is using the mult function recursively correct? If so, I'm messing up somewhere because the recursion isn't stopping.

Can someone point out where the error is?

Dogfish answered 14/8, 2011 at 18:34 Comment(6)
Of course the recursion isn't stopping: where's the condition that makes the recursion stop? - Mendoza
I am not sure, maybe using x0, x1 = divmod(x, bm) would be faster. - Teniafuge
@neurino, the if statement on the first line has a return. - Teniafuge
Moreover, read the comments on this question. - Mendoza
@utdemir: if it experiences infinite recursion, it means that max(x, y) < b never occurs. - Mendoza
1. Shouldn't your parameter m change in some recursive call? 2. If you work in base b, and I give you (x*b^m) with (x < b), how do you go about returning (x*b^(m+1))? (x*b^(2m))? How costly is that operation? How costly is your last line? Remember Karatsuba is a slight improvement on fast multiplication. - Supernormal

NB: the response below addresses directly the OP's question about excessive recursion, but it does not attempt to provide a correct Karatsuba algorithm. The other responses are far more informative in this regard.

Try this version:

def mult(x, y, b, m):
    bm = pow(b, m)

    if min(x, y) <= bm:
        return x * y

    # NOTE the following 4 lines (// is integer division in Python 2 and 3)
    x0 = x % bm
    x1 = x // bm
    y0 = y % bm
    y1 = y // bm

    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval

The most serious problem with your version is that your calculations of x0 and x1, and of y0 and y1 are flipped. Also, the algorithm's derivation does not hold if x1 and y1 are 0, because in this case, a factorization step becomes invalid. Therefore, you must avoid this possibility by ensuring that both x and y are greater than b**m.
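To make the split order concrete, here is a minimal sketch of a single Karatsuba step with no recursion at all. It assumes non-negative integers; `karatsuba_step` is a name chosen for this illustration, and `divmod` stands in for the separate division and remainder operations.

```python
def karatsuba_step(x, y, b, m):
    """One Karatsuba step; sub-products use built-in multiplication."""
    bm = b ** m
    # divmod returns (quotient, remainder), i.e. (high part, low part):
    # x == x1 * bm + x0 and y == y1 * bm + y0
    x1, x0 = divmod(x, bm)
    y1, y0 = divmod(y, bm)

    z2 = x1 * y1                            # high * high
    z0 = x0 * y0                            # low * low
    z1 = (x1 + x0) * (y1 + y0) - z2 - z0    # cross terms x1*y0 + x0*y1

    return z2 * bm ** 2 + z1 * bm + z0


assert karatsuba_step(1234, 5678, 10, 2) == 1234 * 5678
```

If the high and low parts are swapped, the final recombination puts each sub-product at the wrong power of bm, so the identity no longer holds.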

EDIT: fixed a typo in the code; added clarifications

EDIT2:

To be clearer, commenting directly on your original version:

def mult(x, y, b, m):
    # The termination condition will never be true when the recursive 
    # call is either
    #    mult(z2, bm ** 2, b, m)
    # or mult(z1, bm, b, m)
    #
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)

    # Even without the recursion problem, the next four lines are wrong
    x0 = x / bm  # RHS should be x % bm
    x1 = x % bm  # RHS should be x / bm
    y0 = y / bm  # RHS should be y % bm
    y1 = y % bm  # RHS should be y / bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
Outroar answered 14/8, 2011 at 20:52 Comment(3)
Can you explain what's different about this version from the version posted above? How does this address the infinite recursion? - Viridescent
This solves the termination problem, but the two final recursive calls in retval's assignment are wrong, and so is the computation of bm: this will not give you Karatsuba's O(n^(log_2(3))) running time. - Supernormal
@huitseeker, the OP explained that he/she just wants to get the basic logic working, not the optimal running time. This is clear not only from the computation of bm in every recursion, but also from the use of % and /. As for the last two recursive calls, I don't know if they are part of the original algorithm (I was not able to readily find a complete version of it), so in this sense they may be "wrong", but the algorithm as shown produced the correct results in all my test runs. If you have a counterexample, please post it. - Outroar

Usually big numbers are stored as arrays of integers, where each integer represents one digit. This approach makes it possible to multiply any number by a power of the base with a simple left shift of the array.
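As a rough illustration of the "free" shift, the sketch below stores digits least significant first and multiplies by a power of the base by prepending zeros. The helper names `to_digits`, `from_digits` and `shift_left` are assumptions made for this example only.

```python
def to_digits(n, b):
    """Digits of non-negative n in base b, least significant first."""
    digits = []
    while n:
        n, d = divmod(n, b)
        digits.append(d)
    return digits or [0]


def from_digits(digits, b):
    """Inverse of to_digits."""
    n = 0
    for d in reversed(digits):
        n = n * b + d
    return n


def shift_left(digits, k):
    """Multiply by b**k by prepending k zero digits at the low-order end."""
    return [0] * k + digits


d = to_digits(123, 10)    # [3, 2, 1]
assert from_digits(shift_left(d, 2), 10) == 123 * 10 ** 2
```

No digit is touched by the shift, which is why the recombination step of Karatsuba costs only additions in this representation.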

Here is my list-based implementation (may contain bugs):

import operator

def normalize(l,b):
    # Propagate carries (or borrows) so that every digit ends up in range(b)
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l
def sum_lists(x,y,b):
    l = min(len(x),len(y))
    # list() is needed on Python 3, where map returns an iterator
    res = list(map(operator.add,x[:l],y[:l]))
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)
def sub_lists(x,y,b):
    # Assumes x >= y numerically, so the result is non-negative
    res = list(map(operator.sub,x[:len(y)],y))
    res.extend(x[len(y):])
    return normalize(res,b)
def lshift(x,n):
    # Multiply by b**n by prepending n zero digits (zero stays as it is)
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x
def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    x0,x1 = x[:m],x[m:]   # low and high halves of x
    y0,y1 = y[:m],y[m:]   # low and high halves of y
    z0 = mult_lists(x0,y0,b)   # low * low
    z1 = mult_lists(x1,y1,b)   # high * high
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)   # cross terms, shifted by m
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)

The digit-wise addition and subtraction inside sum_lists and sub_lists produce an unnormalized result: a single digit can be greater than the base value (or negative). The normalize function solves this problem by propagating the carries.

All functions expect a list of digits in reverse order (least significant digit first). For example, 12 in base 10 should be written as [2,1]. Let's take the square of 987654321.

» a = [1,2,3,4,5,6,7,8,9]
» res = mult_lists(a,a,10)
» res.reverse()
» res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]
Incontrovertible answered 14/8, 2011 at 21:53 Comment(0)

The goal of Karatsuba multiplication is to improve on the divide-and-conquer multiplication algorithm by making three recursive calls instead of four. Therefore, the only lines in your script that should contain a recursive call to the multiplication are those assigning z0, z1 and z2. Anything else will give you a worse complexity. You also can't use pow to compute bm when you haven't defined multiplication yet (and, a fortiori, exponentiation).

For that, the algorithm crucially relies on the fact that it is using a positional notation system. If you have a representation x of a number in base b, then x*b^m is simply obtained by shifting the digits of that representation m times to the left. That shifting operation is essentially "free" with any positional notation system. That also means that if you want to implement this, you have to reproduce the positional notation and the "free" shift. Either you choose to compute in base b=2 and use Python's bit operators (or the bit operators of a given decimal, hex, ... base if your test platform has them), or you decide to implement, for educational purposes, something that works for an arbitrary b, and you reproduce this positional arithmetic with something like strings, arrays, or lists.
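The base-2 option could be sketched roughly as follows. This is only an illustration under stated assumptions: the name `karatsuba2`, the cutoff of 16 and the choice of m from `bit_length` are not taken from the question, and the operands are assumed to be non-negative.

```python
def karatsuba2(x, y):
    """Multiply non-negative ints in base 2 with three recursive calls."""
    if x < 16 or y < 16:          # small operand: multiply directly
        return x * y

    # Split both operands at m bits, about half the longer bit length.
    m = max(x.bit_length(), y.bit_length()) // 2
    mask = (1 << m) - 1
    x1, x0 = x >> m, x & mask     # high and low parts of x
    y1, y0 = y >> m, y & mask     # high and low parts of y

    z2 = karatsuba2(x1, y1)
    z0 = karatsuba2(x0, y0)
    z1 = karatsuba2(x1 + x0, y1 + y0) - z2 - z0

    # Recombine with shifts and additions only.
    return (z2 << (2 * m)) + (z1 << m) + z0


assert karatsuba2(987654321, 987654321) == 987654321 ** 2
```

The only multiplications are the three recursive calls and the direct product in the base case; the recombination never calls the multiplication routine again, which is the point made above.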

You have a solution with lists already. I like to work with strings in Python, since int(s, base) will give you the integer corresponding to the string s seen as a number representation in base base: it makes tests easy. I have posted a heavily commented string-based implementation as a gist here, including string-to-number and number-to-string primitives for good measure.

You can test it by providing padded strings with the base and their (equal) length as arguments to mult:

In [169]: mult("987654321","987654321",10,9)

Out[169]: '966551847789971041'

If you don't want to figure out the padding or count string lengths, a padding function can do it for you:

In [170]: padding("987654321","2")

Out[170]: ('987654321', '000000002', 9)

And of course it works with b>10:

In [171]: mult('987654321', '000000002', 16, 9)

Out[171]: '130eca8642'

(Check with Wolfram Alpha)
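The base-16 result can also be checked locally. The sketch below does not reproduce the gist: `to_base` is a hypothetical number-to-string helper written for this check, while `int(s, base)` is the built-in mentioned earlier.

```python
import string

DIGITS = string.digits + string.ascii_lowercase   # enough for bases 2..36


def to_base(n, base):
    """Render non-negative integer n as a lowercase string in the given base."""
    if n == 0:
        return "0"
    out = []
    while n:
        n, d = divmod(n, base)
        out.append(DIGITS[d])
    return "".join(reversed(out))


# Recompute the base-16 example with built-in arithmetic.
product = int("987654321", 16) * int("000000002", 16)
assert to_base(product, 16) == "130eca8642"
```

Round-tripping through `int(s, base)` like this gives a cheap test oracle for any string-based multiplication routine.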

Supernormal answered 14/8, 2011 at 22:44 Comment(0)

I believe that the idea behind the technique is that the z0, z1 and z2 terms are computed using the recursive algorithm, but the results are not combined that way. The net result that you want is

z0 * B^(2m) + z1 * B^m + z2

If you choose a suitable value of B (say, 2), you can compute B^m without doing any multiplications. For example, when using B = 2, you can compute B^m using bit shifts rather than multiplications. This means that the last step can be done without doing any multiplications at all.

One more thing - I noticed that you've picked a fixed value of m for the whole algorithm. Typically, you would implement this algorithm by choosing m on each call so that B^m has about half as many digits as x and y when they are written in base B. If you're using powers of two, this would be done by picking m = ceil((log2 x) / 2).
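For B = 2, one way to pick m on each call might look like the sketch below. It uses `int.bit_length()` as the digit count, which is floor(log2 x) + 1 rather than log2 x itself, so it only approximates the formula above; `split_exponent` is a name invented for this example.

```python
def split_exponent(x, y):
    """Return m such that 2**m splits the longer operand roughly in half."""
    n = max(x.bit_length(), y.bit_length())   # number of base-2 digits
    return (n + 1) // 2                       # ceil(n / 2)


x = 0b11110101                 # 8 bits
m = split_exponent(x, 3)       # 4
high, low = x >> m, x & ((1 << m) - 1)
assert (high << m) + low == x  # the two halves reassemble the operand
```

Because m is recomputed from the current operands, the halves shrink on every recursive call, which a fixed m cannot guarantee.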

Hope this helps!

Viridescent answered 14/8, 2011 at 19:57 Comment(1)
Thanks. I'll eventually change it so that it uses powers of 2 for b. For now, I'm just trying to get this to work. - Dogfish

Save this file as Karatsuba.py. It was written for Python 2.7; with // for integer division it also runs on Python 3.

def karatsuba(x, y):
    """Karatsuba multiplication algorithm.
    Return the product of two numbers in an efficient manner
    @author Shashank
    date: 23-09-2018

    Parameters
    ----------
    x : int
        First Number
    y : int
        Second Number

    Returns
    -------
    prod : int
           The product of two numbers

    Examples
    --------
    >>> import Karatsuba
    >>> a = 1234567899876543211234567899876543211234567899876543211234567890
    >>> b = 9876543211234567899876543211234567899876543211234567899876543210
    >>> Karatsuba.karatsuba(a,b)
    12193263210333790590595945731931108068998628253528425547401310676055479323014784354458161844612101832860844366209419311263526900
    """
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        n = max(len(str(x)), len(str(y)))
        m = n//2                                    # split point, in decimal digits

        a = x//10**m                                # high part of x
        b = x%10**m                                 # low part of x
        c = y//10**m                                # high part of y
        d = y%10**m                                 # low part of y

        ac = karatsuba(a,c)                             #step 1
        bd = karatsuba(b,d)                             #step 2
        ad_plus_bc = karatsuba(a+b, c+d) - ac - bd      #step 3
        prod = ac*10**(2*m) + bd + ad_plus_bc*10**m     #step 4
        return prod
Tinct answered 23/9, 2018 at 10:25 Comment(0)
