numba-safe version of itertools.combinations?
I have some code which loops through a large set of itertools.combinations, which is now a performance bottleneck. I'm trying to turn to numba's @jit(nopython=True) to speed it up, but I'm running into some issues.

First, it seems numba can't handle itertools.combinations itself, per this small example:

import itertools
import numpy as np
from numba import jit

arr = [1, 2, 3]
c = 2

@jit(nopython=True)
def using_it(arr, c):
    return itertools.combinations(arr, c)

for i in using_it(arr, c):
    print(i)

throws the error: numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend) Unknown attribute 'combinations' of type Module(<module 'itertools' (built-in)>)

After some googling, I found this GitHub issue where the questioner proposed this numba-safe function for calculating permutations:

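@jit(nopython=True)
def permutations(A, k):
    # Start from one empty list; the comprehension over range(0) gives numba
    # an integer element type to infer for the inner lists.
    r = [[i for i in range(0)]]
    for i in range(k):
        # Extend every partial permutation with each element not already in it.
        r = [[a] + b for a in A for b in r if a not in b]
    return r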

Leveraging that, I can then easily filter down to combinations:

@jit(nopython=True)
def combinations(A, k):
    return [item for item in permutations(A, k) if sorted(item) == item]
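For comparison: the filter above first builds all n!/(n−k)! ordered k-permutations and then keeps only the one sorted ordering out of every k!, so for n = 20 and k = 2 it builds 380 lists to keep 190. A sketch in the same list-building style (an illustration only; `combinations_by_extension` is a hypothetical name, and I have not checked whether it compiles under @jit(nopython=True)) builds the combinations directly by only ever appending a larger index, so nothing has to be filtered:

```python
def combinations_by_extension(A, k):
    # Build index combinations directly: each partial combination is extended
    # only with indices larger than its last one, so the results come out in
    # lexicographic order and no sorting/filtering step is needed.
    n = len(A)
    partial = [[i for i in range(0)]]  # a single empty combination
    for _ in range(k):
        partial = [p + [i] for p in partial for i in range(p[-1] + 1 if p else 0, n)]
    # Map index combinations back to the elements of A.
    return [[A[i] for i in p] for p in partial]


A = list(range(20))
result = combinations_by_extension(A, 2)
print(len(result), result[:3])
```

Because it works on indices rather than comparing values, it also does not rely on the elements of A being sortable.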

Now I can run that combinations function without errors and get the correct result. However, it is dramatically slower with @jit(nopython=True) than without it. Running this timing test:

import pandas as pd

A = list(range(20))  # numba throws 'cannot determine numba type of range' w/o list
k = 2
start = pd.Timestamp.utcnow()
print(combinations(A, k))
print(f"took {pd.Timestamp.utcnow() - start}")

clocks in at 2.6 seconds with the numba @jit(nopython=True) decorators, and at under 1/1000 of a second with them commented out. So that's not really a workable solution for me either.
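One factor worth separating out in a timing like this (an assumption about where the time goes, not something measured in this thread): numba compiles a jitted function the first time it is called with a given argument type, so a single timed call includes compilation. The sketch below times a first and a second call separately. `count_pairs` is a hypothetical stand-in workload that visits every 2-combination with nested loops, and the fallback `njit` only exists so the sketch still runs where numba is not installed.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without numba (no compilation happens then).
    def njit(func):
        return func


@njit
def count_pairs(values):
    # Hypothetical workload: visit every 2-combination (i, j), i < j,
    # and count the pairs whose values are increasing.
    n = values.shape[0]
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            if values[i] < values[j]:
                count += 1
    return count


def timed(func, *args):
    # Wall-clock time of a single call.
    start = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - start


values = np.arange(20)
first_result, first_time = timed(count_pairs, values)    # includes JIT compilation under numba
second_result, second_time = timed(count_pairs, values)  # reuses the already compiled code
print(f"first call: {first_time:.6f} s, second call: {second_time:.6f} s")
```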

Supplement asked 17/4, 2020 at 0:17. Comments (8):
I highly suspect your performance bottleneck is fundamentally due to combinatorial explosion: the work is inherently O(nCk), and numba will only shave constant factors off your computation, so it is not really an effective way to improve your runtime. Using the Python version isn't going to be very efficient either, since it requires you to materialize the whole thing, which also kills your performance. – Microphotograph
What problem are you actually trying to solve? There may be better alternatives than what you have suggested here. Also, itertools.combinations/permutations are already written in C and are very efficient. I would like to see how the homemade numba permutations function compares to itertools. – Pedo
I actually just did a benchmark myself. list(itertools.permutations(list(range(10)), 8)) ran in 0.3338 seconds. The homemade numba version took about 3.5 seconds. As @Microphotograph says, efficiency in generating combinations is not your problem here. – Pedo
The problem I am actually trying to solve is a local search algorithm for swapping points in a Traveling Salesman tour. For a problem with, e.g., 100 nodes, I calculate swap_options = itertools.combinations(range(100), 2), then for each swap_option (consisting of two node indexes) I check whether swapping them improves the tour length. So I'm not actually exhausting my combinations (or I only do so once, when I hit a local optimum), but I am constantly looping through them and recalculating them. – Supplement
So I have a two_opt function which calls itertools.combinations, and I would like to numba @jit() the whole function, but I can't without a numba-safe alternative to itertools.combinations. – Supplement
@MaxPower you mention “constantly looping through and recalculating them”; this sounds like a good place for a hash table, maybe? – Romance
Well, I want to iterate through each combination of nodes in my tour. Every time I change the tour, I (think I want to) restart proceeding through every possible combination again, because I know I've reached a local optimum (for 2-opt local search moves) only when I've tried swapping each 2-combination of nodes for a given tour. This may be due to my lack of understanding, but I don't see how a hash map is better suited to that than a generator of combinations such as itertools.combinations. – Supplement
This post provides a fast Numba implementation for a similar question. The case where c=2 can be computed very efficiently even though the number of combinations is huge. – Vinnievinnitsa
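For the c = 2 case discussed in these comments, one numba-friendly option (a sketch, not code from the thread) is to drop itertools.combinations entirely and write the pair enumeration as two nested loops, which visit (i, j) in the same order as itertools.combinations(range(n), 2). The helper names `tour_length` and `first_improving_swap`, the swap-two-positions move, and recomputing the full tour length for every candidate are illustrative assumptions; a real 2-opt implementation would compute the length change incrementally.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without numba (plain Python, no speed-up).
    def njit(func):
        return func


@njit
def tour_length(tour, dist):
    # Length of the closed tour, including the edge back to the start.
    n = tour.shape[0]
    total = 0.0
    for idx in range(n):
        total += dist[tour[idx], tour[(idx + 1) % n]]
    return total


@njit
def first_improving_swap(tour, dist):
    # Visit the 2-combinations (i, j), i < j, of tour positions in the same
    # order as itertools.combinations(range(n), 2), without itertools.
    n = tour.shape[0]
    best = tour_length(tour, dist)
    for i in range(n - 1):
        for j in range(i + 1, n):
            tour[i], tour[j] = tour[j], tour[i]  # try the swap in place
            if tour_length(tour, dist) < best - 1e-12:
                return i, j  # improving swap found; leave it applied
            tour[i], tour[j] = tour[j], tour[i]  # undo the swap
    return -1, -1  # no improving swap: local optimum for this move type


# Corners of a unit square, visited in a self-crossing order.
points = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
dist = np.sqrt(((points[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1))
tour = np.array([0, 2, 1, 3])
swap = first_improving_swap(tour, dist)
print(swap, tour, tour_length(tour, dist))
```

Calling first_improving_swap repeatedly until it returns (-1, -1) gives the restart-from-the-first-pair behaviour described above, with no generator object involved.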

There is not much to gain with Numba in this case, as itertools.combinations is written in C.

If you want to benchmark it, here is a Numba / Python implementation of what itertools.combinations does:

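from numba import jit


@jit(nopython=True)
def using_numba(pool, r):
    # Index-based algorithm in the style of itertools.combinations; yields
    # lists (not tuples) of r elements of pool in lexicographic order.
    # Unlike itertools.combinations, r == 0 yields nothing here.
    n = len(pool)
    indices = list(range(r))
    empty = not (n > 0 and 0 < r <= n)

    if not empty:
        result = [pool[i] for i in indices]
        yield result

    while not empty:
        # Find the rightmost index that can still be incremented.
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            # Every index is at its maximum: all combinations were produced.
            empty = True
        else:
            indices[i] += 1
            # Reset the indices to the right of i to consecutive values.
            for j in range(i+1, r):
                indices[j] = indices[j-1] + 1

            result = [pool[i] for i in indices]
            yield result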

On my machine, this is about 15 times slower than itertools.combinations. Getting the permutations and filtering the combinations would certainly be even slower.

Ascent answered 23/4, 2020 at 17:43. Comments (1):
I think the fact that itertools is written in C explains why sequentially generating the combinations won't run faster with numba; I suspect that generating an array of all combinations can be substantially faster with numba, using an algorithm that allows the combinations to be generated in parallel via a bijection from their enumeration in, e.g., lexicographic order. – Hallow
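A sketch of the bijection idea in the comment above (an illustration, not code from the thread): `unrank_combination` (a hypothetical name) maps a rank m in [0, C(n, k)) directly to the m-th k-combination of range(n) in lexicographic order, so each row of the output array depends only on its own rank. It is plain Python using math.comb (Python 3.8+); actually filling the rows in parallel, for example with numba's prange, would additionally need a numba-compatible binomial function and is not shown here.

```python
import itertools
from math import comb

import numpy as np


def unrank_combination(m, n, k):
    # Return the m-th (0-based) k-combination of range(n) in lexicographic order.
    if not 0 <= m < comb(n, k):
        raise ValueError("rank out of range")
    combo = []
    x = 0  # smallest value still available for the current slot
    for slot in range(k):
        remaining = k - slot - 1  # slots left to fill after this one
        # Skip the whole block of combinations whose current slot holds x.
        while comb(n - x - 1, remaining) <= m:
            m -= comb(n - x - 1, remaining)
            x += 1
        combo.append(x)
        x += 1
    return combo


def all_combinations(n, k):
    # Each row depends only on its own rank m, so rows are independent.
    total = comb(n, k)
    out = np.empty((total, k), dtype=np.int64)
    for m in range(total):
        out[m] = unrank_combination(m, n, k)
    return out


table = all_combinations(5, 3)
matches = [tuple(row) for row in table] == list(itertools.combinations(range(5), 3))
print(table[:3], matches)
```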

© 2022 - 2024 — McMap. All rights reserved.