Explain the combinations function of Python's itertools module

I have often used the itertools module in Python, but it feels like cheating if I don't know the logic behind it.

Here is the code to find combinations of a string when order is not important.

def combinations(iterable, r):
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)

Could someone please explain the basic idea? Especially line 14.

Minica answered 23/7, 2014 at 10:23 Comment(3)
Line 14 is an else statement. Which line do you mean? - Galvez
You say "here is my code", but it's a direct copy of the example here: docs.python.org/2/library/itertools.html#itertools.combinations - Coreycorf
@TomDalton Obviously it's not his code if he doesn't know how it works. What he meant was "here's the code on which this question is based". - Constraint
def combinations(iterable, r):
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    pool = tuple(iterable)
    # first, create a tuple of the original input, which you can refer to later
    # by index
    n = len(pool)
    # get the length of the tuple
    if r > n:
        return
    # if the desired combination length is greater than the length of the tuple,
    # no combination can be formed, so return without yielding anything

    indices = list(range(r))
    # create the first list of indices in increasing order (indices = [0, 1, 2, ..., r-1]),
    # i.e. up to (but not including) r

    yield tuple(pool[i] for i in indices)
    # yield the first combination, which is a tuple of the input elements at the
    # first r indices: (pool[0], pool[1], ..., pool[r-1])

    while True:
        for i in reversed(range(r)):
            # i will go from r-1, r-2, r-3, ....,0

            if indices[i] != i + n - r:
                # the condition is true unless position i already holds the last
                # index it can take (i + n - r); in that case we move on to the
                # position before it. When the condition is true, the element at
                # position i can be replaced by the next possible one.

                # example: pool = 'ABCDE', so n = 5, r = 3; indices is [0, 1, 2] at start
                # yield (A, B, C)
                # i = 2: indices[i] is 2; 2 != 4 (2 + 5 - 3) is true, so break,
                # increase indices[i] by 1 and yield (A, B, D)
                # i = 2: indices[i] is 3; 3 != 4 (2 + 5 - 3) is true, so break,
                # increase indices[i] by 1 and yield (A, B, E)
                # i = 2: indices[i] is 4; 4 != 4 (2 + 5 - 3) is false, so next loop
                # iteration: i = 1, indices[i] is 1; 1 != 3 (1 + 5 - 3) is true,
                # so break, increase indices[1] to 2, reset indices[2] to 3 and
                # yield (A, C, D) .... and so on

                break
        else:
            # when the for loop finishes without hitting break, every position is at
            # its maximum, so all combinations have been produced and the function ends
            return

        indices[i] += 1 # as described above, move on to the next character, i.e.
                        # increase the index at position i
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1 # reset every following index to one more
                                          # than the index before it: each position
                                          # may only use characters after the previous
                                          # position's character, never an earlier one
                                          # or the same one (the one at indices[i])
        yield tuple(pool[i] for i in indices)
        # yield the new tuple
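To watch the index updates that the comments above trace by hand, here is a minimal sketch: `trace_indices` is a hypothetical helper (not part of itertools) that runs the same loop but yields a copy of the index list instead of the elements. The name and the choice to expose the indices are assumptions made for illustration.

```python
def trace_indices(n, r):
    """Hypothetical helper: yield each index list produced by the
    combinations() loop above, for a pool of length n."""
    if r > n:
        return
    indices = list(range(r))
    yield list(indices)                 # first state: [0, 1, ..., r-1]
    while True:
        # find the rightmost position that has not reached its maximum i + n - r
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return                      # every position is at its maximum
        indices[i] += 1                 # advance that position
        for j in range(i + 1, r):       # reset every position to its right
            indices[j] = indices[j - 1] + 1
        yield list(indices)


pool = 'ABCDE'
for idx in trace_indices(len(pool), 3):
    print(idx, ''.join(pool[k] for k in idx))
```

The first four lines printed are `[0, 1, 2] ABC`, `[0, 1, 3] ABD`, `[0, 1, 4] ABE` and `[0, 2, 3] ACD`: once position 2 reaches its maximum 4, position 1 is advanced and position 2 is reset to 3.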
Buddhism answered 24/7, 2014 at 9:43
def combinations(iterable, r):
    # first, note that this function enumerates every valid list of indices,
    # then returns the elements at those indices

    pool = tuple(iterable)

    n = len(pool)

    if r > n:
        return
    indices = list(range(r))

    # yield the first combination before the loop,
    # because inside the "while" loop we start changing the indices by adding 1
    # for example: iterable is [1, 2, 3, 4, 5], and r = 3
    # this yield returns (1, 2, 3); then, in the "while" loop,
    # we first move the last index from 2 to 3, which returns (1, 2, 4)
    yield tuple(pool[i] for i in indices)

    while True:

        # in this for loop, we want to confirm whether indices[i] can be increased or not
        for i in reversed(range(r)):

            # after reversed(), i takes the values r-1, r-2, r-3, ..., 0
            # some facts to know before we start the 'for' loop:
            # the value of indices[r-1] can be at most n-1
            # the value of indices[r-2] can be at most n-2
            # and indices[r-1] is always the largest index (the indices are strictly increasing)
            # so the value of indices[r-1] lies between r-1 and n-r + r-1, like this:
            #       r-1 <= indices[r-1] <= n-r + r-1
            # likewise, for r-2:
            #       r-2 <= indices[r-2] <= n-r + r-2
            # in general, for any position i:
            #       i <= indices[i] <= n-r+i (n-1 is the maximum value, reached at i = r-1)
            # since we only ever increase indices[i], we can ignore the lower bound i <= indices[i];
            # we just want to know whether indices[i] can still be increased by one,
            # which is possible only while it is below its maximum n-r+i,
            # so we get:
            #       indices[i] < i + n - r
            # the official documentation uses indices[i] != i + n - r:
            # because indices[i] never exceeds i + n - r, once it reaches that boundary
            # the "for" loop moves on to indices[i-1], so there is no need to test for "> i+n-r".
            # indices[i] != i + n - r is therefore equivalent,
            # but I prefer indices[i] < i + n - r, which is easier for me to understand.
            # the next question is the "break" here:
            # it means indices[i] has not reached its boundary and can still be increased,
            # so we break out of the for loop and continue the iteration.
            # when indices[r-1] hits the boundary, the loop moves on to i = r-2.
            # So we get this result:
            # 1, 2, 3
            # 1, 2, 4
            # 1, 2, 5
            # 1, 3, 4
            # we always advance the last index; when it hits its boundary, we check the second-to-last one.
            if indices[i] < i + n - r:
                break
        else:
            # the for loop finished without a break: every index is at its maximum, so return
            return

        # first, add one to the current indices[i];
        # every pass through the "while" loop increases an index by 1 before yielding,
        # which is why the first combination is yielded before the loop
        indices[i] = indices[i] + 1
        # this for loop resets every index after indices[i]:
        # the current index was increased, so every index after it must be reset to keep them in increasing order
        # for example: after (1, 2, 5), i is 1 and indices[1] goes from 1 to 2;
        # indices[2] is then reset to 3, so the next combination is (1, 3, 4), not (1, 3, 5)
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1

        yield tuple(pool[i] for i in indices)
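The bound i <= indices[i] <= n - r + i derived in the comments can be checked by brute force. A minimal sketch, assuming we only care about index tuples, so we feed `range(n)` to the standard `itertools.combinations` (its output tuples are then exactly the index lists); the helper name `position_bounds` is hypothetical.

```python
from itertools import combinations


def position_bounds(n, r):
    """Hypothetical helper: for each position i, return the (min, max)
    index value seen there across all r-combinations of range(n).
    Assumes 0 <= r <= n, otherwise there are no combinations to inspect."""
    combos = list(combinations(range(n), r))
    return [(min(c[i] for c in combos), max(c[i] for c in combos))
            for i in range(r)]


n, r = 5, 3
bounds = position_bounds(n, r)
print(bounds)  # -> [(0, 2), (1, 3), (2, 4)]
# each position i ranges exactly from i up to i + n - r
assert bounds == [(i, i + n - r) for i in range(r)]
```

Both ends of the bound are actually reached, which is why `indices[i] < i + n - r` is exactly the test for "this position can still grow".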
Crowns answered 10/3, 2017 at 7:0

NOTICE! The code below is broken into pieces and is not properly indented in relation to each of its parts, so I recommend also looking at the code in the question itself (the same code as in the itertools documentation).

It's been over 7 years since this was asked. Wow. I've been interested in this myself, and the explanations above, while very helpful, didn't quite click for me, so here's the summary I wrote for myself.
Since I finally managed to understand it (or at least I think I do), I thought it might be useful to post this version of an explanation, in case there are others like me. So let's start.

def combinations(iterable, r):
    pool = tuple(iterable)
    n = len(pool) 

In this first section, we simply make a tuple of the iterable and get its length. These will be useful later.

if r > n:
    return
indices = list(range(r))
yield tuple(pool[i] for i in indices)

This is also quite straightforward. If the length of the needed combination is greater than our pool of elements, we cannot construct a valid combination (you cannot make a combination of 5 elements from 4), so we simply stop the execution with a return statement. Otherwise, we generate the first combination (the first r elements of our iterable).
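Both early cases can be seen directly with the standard library version; a quick check, assuming Python 3 and `itertools.combinations`:

```python
from itertools import combinations

# r > n: the generator returns immediately, so nothing is produced
print(list(combinations('ABCD', 5)))   # -> []

# otherwise the first combination is simply the first r elements
print(next(combinations('ABCD', 2)))   # -> ('A', 'B')

# r == n: exactly one combination, the whole pool
print(list(combinations('ABCD', 4)))   # -> [('A', 'B', 'C', 'D')]
```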

This next part is slightly more complex, so read carefully.

while True:
    for i in reversed(range(r)):
        if indices[i] != n - (r - i):
            break
"""
The job of the while loop is to increment the indices one after
the other, and yield all the possible element combinations based
on all the possible valid index combinations.

This for loop's job is to make sure we never over-increment any values.

In order not to run into any errors, incrementing
the last element of the index list must stop when it reaches one less
than the length of our element list; otherwise we'll run into an IndexError
(trying to access an index outside the list range).
How do we do that?
            
The range function will give us values cascading down from r-1 to 0
(r-1, r-2, r-3, ... , 0).
So first and foremost, the (r-1)st index must not be greater than (n-1)
(remember, n is the length of our element pool), as that is the largest index.
So we may only increment it while

Indices[r - 1] < n - 1

Moreover, because we'll stop incrementing the (r-1)st index when it reaches its
maximum value, we must also stop incrementing the (r-2)nd index when it reaches
its maximum value. What is the (r-2)nd index's maximum value?

Since the (r-1)st index is set based on the
(r-2)nd index, and because the maximum value of the (r-1)st
index is (n-1), and we want no duplicates, the maximum value the
(r-2)nd index can reach is (n-2).
This trend continues. More generally:

Indices[r - k] < n - k

Now, since r - k is exactly the index i generated by the reversed range function,
we can substitute:

r - k = i -----> k = r - i
Indices[r - k] < n - k -----> Indices[i] < n - (r - i)
            
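That's our limit: for any index i we generate, we may only increment
Indices[i] while this inequality { Indices[i] < n - (r - i) } holds. The for loop
breaks at the first (rightmost) i for which it is still true, skipping the
positions that have already reached their maximum.
(In the documentation it's written as (indices[i] != i + n - r), and it
means the same thing, because Indices[i] can never exceed n - (r - i). I simply
find this version easier to visualize and understand).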
"""
    else:
        return
"""
When our for loop runs out without breaking, which means we've gone through and
maximized each element in our index list, we've gone through every
single combination, and we can exit the function.

It's important to note that the else statement is not linked to
the if statement in this case, but rather to the for loop. It's a
for-else statement, meaning "If you've finished iterating through the
entire loop without hitting break, execute the else statement".
"""

If we did manage to break out of the for loop, this means that we can safely increment our index to get the next combination (the first line below). The for loop below makes sure that every time we advance an index, we reset all the indices after it back to their smallest possible values, so as not to miss any combinations.

For example, suppose our pool is (0, 1, 2, 3, 4) and the current combination indices are (0, 1, 4), so we have to move on. If we didn't reset, then when we increment 1 to 2, the last index would remain the same (4), and we'd miss (0, 2, 3), registering only (0, 2, 4) as a valid combination. Instead, after we increment 1 -> 2, we update the later indices based on it: 4 -> 3, and when we run the while loop again we increment 3 back to 4 (refer to the previous section).
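To see the missed combinations concretely, here is a deliberately broken, hypothetical variant, `combinations_no_reset`, which increments the pivot index but skips the reset loop. It is only a sketch for comparison, not something from the documentation.

```python
from itertools import combinations


def combinations_no_reset(iterable, r):
    """Hypothetical, intentionally broken: same as the documented recipe
    but WITHOUT resetting the indices to the right of the pivot."""
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        # the reset loop "for j in range(i+1, r): ..." is omitted on purpose
        yield tuple(pool[i] for i in indices)


broken = list(combinations_no_reset(range(5), 3))
correct = list(combinations(range(5), 3))
print(len(broken), len(correct))            # -> 7 10
print(sorted(set(correct) - set(broken)))   # -> [(0, 2, 3), (1, 2, 3), (1, 2, 4)]
```

Without the reset, (0, 1, 4) is followed directly by (0, 2, 4), so (0, 2, 3) is never produced, along with two other combinations.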

Notice that we never increment the earlier indices (those before position i), so as not to create duplicates.

And finally, on each iteration, the yield statement generates the element combination corresponding to the current index combination.

indices[i] += 1
for j in range(i+1, r):
    indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices)

And just as the documentation states, because we're dealing with positions, elements are treated as unique based on their position in the iterable, not their value.
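A small illustration of that point with repeated values, using the standard `itertools.combinations`. Deduplicating with `set` afterwards is just one option, assuming only distinct value combinations are wanted; the documentation does not prescribe it.

```python
from itertools import combinations

# 'AAB' has two equal values at different positions (0 and 1)
print(list(combinations('AAB', 2)))
# -> [('A', 'A'), ('A', 'B'), ('A', 'B')]
# ('A', 'B') appears twice: once from positions (0, 2), once from (1, 2)

# one way to keep only distinct value combinations
print(sorted(set(combinations('AAB', 2))))
# -> [('A', 'A'), ('A', 'B')]
```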

Norty answered 5/9, 2021 at 20:40
