This isn't "an answer" so much as a spur to think harder ;-) For concreteness, I'll wrap the OP's code, slightly respelled, in a function that also weeds out duplicates:
def gen(pools, ixs):
    from itertools import combinations_with_replacement as cwr
    from itertools import chain, product
    from collections import Counter

    assert all(0 <= i < len(pools) for i in ixs)
    seen = set()
    cnt = Counter(ixs) # map index to count
    blocks = [cwr(pools[i], count) for i, count in cnt.items()]
    for t in product(*blocks):
        t = tuple(sorted(chain(*t)))
        if t not in seen:
            seen.add(t)
            yield t
I don't fear sorting here - it's memory-efficient, and for small tuples is likely faster than all the overheads involved in creating a Counter object.

But regardless of that, the point here is to emphasize the real value the OP got by reformulating the problem to use combinations_with_replacement (cwr). Consider these inputs:
N = 64
pools = [[0, 1]]
ixs = [0] * N
There are only 65 unique results, and the function generates them instantaneously, with no internal duplicates at all. On the other hand, the essentially identical
pools = [[0, 1]] * N
ixs = range(N)
also has the same 65 unique results, but essentially runs forever (as would, e.g., the other answers given so far), slogging through 2**64 possibilities. cwr doesn't help here because each pool index appears only once.
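For instance, here's a quick check of the fast spelling, assuming the gen() at the top:

N = 64
pools = [[0, 1]]
ixs = [0] * N
results = sorted(gen(pools, ixs))
assert len(results) == 65
# each result is a sorted 64-tuple with k ones, for k in 0..64
assert results == [tuple([0] * (N - k) + [1] * k) for k in range(N + 1)]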
So there's astronomical room for improvement over any solution that "merely" weeds out duplicates from a full Cartesian product, and some of that can be won by doing what the OP already did.
It seems to me the most promising approach would be to write a custom generator (not one relying primarily on itertools functions) that produced all possibilities in lexicographic order to begin with (so, by construction, no duplicates would ever be created). But that requires some "global" analysis of the input pools first, and the code I started on quickly got more complex than I can make time to wrestle with now.
One based on @user2357112's answer
Combining cwr with @user2357112's incremental de-duplicating gives a brief algorithm that runs fast on all the test cases I have. For example, it's essentially instantaneous for both spellings of the [0, 1] ** 64 examples above, and runs the example at the end of @Joseph Wood's answer approximately as fast as he said his C++ code ran (0.35 seconds on my box under Python 3.7.0, and, yes, found 162295 results):
def gen(pools, ixs):
    from itertools import combinations_with_replacement as cwr
    from collections import Counter

    assert all(0 <= i < len(pools) for i in ixs)
    result = {()}
    for i, count in Counter(ixs).items():
        result = {tuple(sorted(old + new))
                  for new in cwr(pools[i], count)
                  for old in result}
    return result
To make it easier for other Pythonistas to try the last example, here's the input as executable Python:
pools = [[1, 10, 14, 6],
         [7, 2, 4, 8, 3, 11, 12],
         [11, 3, 13, 4, 15, 8, 6, 5],
         [10, 1, 3, 2, 9, 5, 7],
         [1, 5, 10, 3, 8, 14],
         [15, 3, 7, 10, 4, 5, 8, 6],
         [14, 9, 11, 15],
         [7, 6, 13, 14, 10, 11, 9, 4]]
ixs = range(len(pools))
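For instance, with the cwr + incremental de-dup gen() just above:

result = gen(pools, ixs)
print(len(result))   # 162295, matching the count reported above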
However, the OP later added that they typically have about 20 pools, each with some thousands of elements. 1000**20 = 1e60 is waaaaay out of practical reach for any approach that builds the full Cartesian product, no matter how cleverly it weeds out duplicates. It remains clear as mud how many they expect to be duplicates, though, so also clear as mud whether this kind of "incremental de-duplicating" is good enough to be practical.
Ideally we'd have a generator that yielded one result at a time, in lexicographic order.
Lazy lexicographic one-at-a-time generation
Building on the incremental de-duplication, suppose we have a strictly increasing (lexicographic) sequence of sorted tuples, append the same tuple T to each, and sort each again. Then the derived sequence is still in strictly increasing order. For example, in the left column we have the 10 unique pairs from range(4), and in the right column what happens after we append (and sort again) 2 to each:
00 002
01 012
02 022
03 023
11 112
12 122
13 123
22 222
23 223
33 233
They started in sorted order, and the derived triples are also in sorted order. I'll skip the easy proof (sketch: if t1 and t2 are adjacent tuples, t1 < t2, let i be the smallest index such that t1[i] != t2[i]. Then t1[i] < t2[i] (that's what "lexicographic <" means). Then if you throw x into both tuples, proceed by cases: is x <= t1[i]? between t1[i] and t2[i]? is x >= t2[i]? In each case it's easy to see that the first derived tuple remains strictly less than the second derived tuple.)
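If you'd rather check than prove, a small script confirms the claim for the example above:

from itertools import combinations_with_replacement as cwr

pairs = list(cwr(range(4), 2))                      # the 10 sorted pairs
triples = [tuple(sorted(t + (2,))) for t in pairs]  # append 2, re-sort
assert pairs == sorted(pairs)
assert all(a < b for a, b in zip(triples, triples[1:]))  # still strictly increasing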
So supposing we have a sorted sequence result of all unique sorted tuples from some number of pools, what happens when we add elements of a new pool P into the tuples? Well, as above,

[tuple(sorted(old + (P[0],))) for old in result]

is also sorted, and so is

[tuple(sorted(old + (P[i],))) for old in result]

for all i in range(len(P)). These guaranteed already-sorted sequences can be merged via heapq.merge(), and another generator (killdups() below) run on the merge result to weed out duplicates on the fly. There's no need to, e.g., keep a set of all tuples seen so far. Because the output of the merge is non-decreasing, it's sufficient just to check whether the next result is the same as the last result output.
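Here's a minimal, standalone sketch of that merge-then-dedup idea on plain pre-sorted streams (the streams and names here are purely illustrative, not the final gen()):

from heapq import merge

def killdups(xs):
    # Weed out duplicates from a non-decreasing stream by comparing
    # each item only to the one most recently yielded.
    last = object()          # sentinel nothing compares equal to
    for x in xs:
        if x != last:
            yield x
            last = x

streams = [[(0, 2), (1, 2), (2, 2)],
           [(0, 1), (1, 1), (1, 2)]]
print(list(killdups(merge(*streams))))
# -> [(0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]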
Getting this to work lazily is delicate. The entire result-so-far sequence has to be accessed by each element of the new pool being added, but we don't want to materialize the whole thing in one gulp. Instead itertools.tee() allows each element of the next pool to traverse the result-so-far sequence at its own pace, and automatically frees memory for each result item after all new pool elements have finished with it.
The function build1() (or some workalike) is needed to ensure that the right values are accessed at the right times. For example, if the body of build1() is pasted in inline where it's called, the code will fail spectacularly (the body would access the final values bound to rcopy and new instead of what they were bound to at the time the generator expression was created).
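A minimal illustration of that late-binding trap, with throwaway names (not the code below):

# Both generator expressions close over the same `i`; its value is
# looked up only when the generators are consumed, after the loop
# has finished, so both see i == 2.
gens = [(c * i for c in "ab") for i in (1, 2)]
print([list(g) for g in gens])   # -> [['aa', 'bb'], ['aa', 'bb']]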
In all, of course this is somewhat slower, due to layers of delayed generator calls and heap merges. In return, it returns the results in lexicographic order, can start delivering results very quickly, and has lower peak memory burden if for no other reason than that the final result sequence isn't materialized at all (little is done until the caller iterates over the returned generator).
Tech note: don't fear sorted() here. The appending is done via old + new for a reason: old is already sorted, and new is typically a 1-tuple. Python's sort is linear-time in this case, not O(N log N).
def gen(pools, ixs):
    from itertools import combinations_with_replacement as cwr
    from itertools import tee
    from collections import Counter
    from heapq import merge

    def killdups(xs):
        last = None
        for x in xs:
            if x != last:
                yield x
                last = x

    def build1(rcopy, new):
        return (tuple(sorted(old + new)) for old in rcopy)

    assert all(0 <= i < len(pools) for i in ixs)
    result = [()]
    for i, count in Counter(ixs).items():
        poolelts = list(cwr(pools[i], count))
        xs = [build1(rcopy, new)
              for rcopy, new in zip(tee(result, len(poolelts)),
                                    poolelts)]
        result = killdups(merge(*xs))
    return result
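A hedged usage sketch (the two small pools here are just for illustration):

from itertools import islice

g = gen([[1, 10, 14, 6], [7, 2, 4, 8]], range(2))
print(list(islice(g, 5)))   # first results, in lexicographic order
# -> [(1, 2), (1, 4), (1, 7), (1, 8), (2, 6)]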
2 inputs
Turns out that for the 2-input case there's an easy approach derived from set algebra. If x and y are the same, cwr(x, 2) is the answer. If x and y are disjoint, product(x, y). Else the intersection c of x and y is non-empty, and the answer is the catenation of 4 cross-products obtained from the 3 pairwise-disjoint sets c, x-c, and y-c: cwr(c, 2), product(x-c, c), product(y-c, c), and product(x-c, y-c).

Proof is straightforward but tedious, so I'll skip it. For example, there are no duplicates between cwr(c, 2) and product(x-c, c) because every tuple in the latter contains an element from x-c, but every tuple in the former contains elements only from c, and x-c and c are disjoint by construction. There are no duplicates within product(x-c, y-c) because the two inputs are disjoint (if they contained an element in common, that would have been in the intersection of x and y, contradicting that x-c has no element in the intersection). Etc.
Alas, I haven't found a way to generalize this beyond 2 inputs, which surprised me. It can be used on its own, or as a building block in other approaches. For example, if there are many inputs, they can be searched for pairs with large intersections, and this 2-input scheme used to do those parts of the overall products directly.
Even at just 3 inputs, it's not clear to me how to get the right result for

[1, 2], [2, 3], [1, 3]

The full Cartesian product has 2**3 = 8 elements, only one of which repeats: (1, 2, 3) appears twice (as (1, 2, 3) and again as (2, 3, 1)). Each pair of inputs has a 1-element intersection, but the intersection of all 3 is empty.
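A quick enumeration confirms that:

from itertools import product
from collections import Counter

c = Counter(tuple(sorted(t)) for t in product([1, 2], [2, 3], [1, 3]))
print(c[(1, 2, 3)], len(c))   # -> 2 7  (one duplicate among the 8 raw products)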
Here's an implementation:
def pair(x, y):
    from itertools import product, chain
    from itertools import combinations_with_replacement
    x = set(x)
    y = set(y)
    c = x & y
    chunks = []
    if c:
        x -= c
        y -= c
        chunks.append(combinations_with_replacement(c, 2))
        if x:
            chunks.append(product(x, c))
        if y:
            chunks.append(product(y, c))
    if x and y:
        chunks.append(product(x, y))
    return chain.from_iterable(chunks)
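A hedged sanity check against brute force on small random inputs (illustration only):

from itertools import product
from random import randrange, sample

for _ in range(1000):
    x = sample(range(8), randrange(1, 7))
    y = sample(range(8), randrange(1, 7))
    got = [tuple(sorted(t)) for t in pair(x, y)]
    want = {tuple(sorted(t)) for t in product(x, y)}
    assert len(got) == len(set(got)), "pair() produced a duplicate"
    assert set(got) == want, "pair() missed (or invented) a result"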
A Proof-of-Concept Based on Maximal Matching
This blends ideas from @Leon's sketch and an approach @JosephWood sketched in comments. It's not polished, and can obviously be sped up, but it's reasonably quick on all the cases I tried. Because it's rather complex, it's probably more useful to post it in an already-hard-enough-to-follow un-optimized form anyway!
This doesn't make any attempt to determine the set of "free" pools (as in @Leon's sketch). Primarily because I didn't have code sitting around for that, and partly because it wasn't immediately clear how to accomplish that efficiently. I did have code sitting around to find a match in a bipartite graph, which required only a few changes to use in this context.
So this tries plausible result prefixes in lexicographic order, as in @JosephWood's sketch, and for each sees whether it's actually possible to construct, by checking whether a bipartite-graph matching exists.
So while the details of @Leon's sketch are largely unimplemented here, the visible behaviors are much the same: it produces results in lexicographic order, it doesn't need to check for duplicates, it's a lazy generator, peak memory use is proportional to the sum of the lengths of the pools, it can obviously be parallelized in many ways (set different processes to work on different regions of the result space), and the key to making it majorly faster lies in reducing the massive amounts of redundant work done by the graph-matching function (a great deal of what it does on each call merely reproduces what it did on the previous call).
def matchgen(pools, ixs):
    from collections import Counter
    from collections import defaultdict
    from itertools import chain, repeat, islice

    elt2pools = defaultdict(set)
    npools = 0
    for i, count in Counter(ixs).items():
        set_indices = set(range(npools, npools + count))
        for elt in pools[i]:
            elt2pools[elt] |= set_indices
        npools += count
    elt2count = {elt : len(ps) for elt, ps in elt2pools.items()}

    cands = sorted(elt2pools.keys())
    ncands = len(cands)

    result = [None] * npools

    # Is it possible to match result[:n] + [elt]*count?
    # We already know it's possible to match result[:n], but
    # this code doesn't exploit that.
    def match(n, elt, count):

        def extend(x, seen):
            for y in elt2pools[x]:
                if y not in seen:
                    seen.add(y)
                    if y in y2x:
                        if extend(y2x[y], seen):
                            y2x[y] = x
                            return True
                    else:
                        y2x[y] = x
                        return True
            return False

        y2x = {}
        freexs = []
        # A greedy pass first to grab easy matches.
        for x in chain(islice(result, n), repeat(elt, count)):
            for y in elt2pools[x]:
                if y not in y2x:
                    y2x[y] = x
                    break
            else:
                freexs.append(x)

        # Now do real work.
        seen = set()
        for x in freexs:
            seen.clear()
            if not extend(x, seen):
                return False
        return True

    def inner(i, j):  # fill result[j:] with elts from cands[i:]
        if j >= npools:
            yield tuple(result)
            return
        for i in range(i, ncands):
            elt = cands[i]
            # Find the most times `elt` can be added.
            count = min(elt2count[elt], npools - j)
            while count:
                if match(j, elt, count):
                    break
                count -= 1
            # Since it can be added `count` times, it can also
            # be added any number of times less than `count`.
            for k in range(count):
                result[j + k] = elt
            while count:
                yield from inner(i + 1, j + count)
                count -= 1

    return inner(0, 0)
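A hedged usage sketch of the generator it returns:

from itertools import islice

g = matchgen([[0, 1]] * 8, range(8))
print(list(islice(g, 3)))
# -> [(0, 0, 0, 0, 0, 0, 0, 0), (0, 0, 0, 0, 0, 0, 0, 1), (0, 0, 0, 0, 0, 0, 1, 1)]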
EDIT: note that there's a potential trap here, illustrated by the pair of pools range(10_000) and range(100_000). After producing (9999, 99999), the first position increments to 10000, and then it continues for a very long time deducing that there's no match for any of the possibilities in 10001 .. 99999 in the second position; and then for 10001 in the first position no match for any of the possibilities in 10002 .. 99999 in the second position; and so on. @Leon's scheme instead would have noted that range(10_000) was the only free pool remaining after picking 10000 for the first position, and noted at once that range(10_000) doesn't contain any values greater than 10000. It would apparently need to do that again for 10001, 10002, ..., 99999 in the first position. That's a linear-time rather than quadratic-time waste of cycles, but a waste all the same. Moral of the story: don't trust anything until you have actual code to try ;-)
And One Based on @Leon's Scheme
Following is a more-than-less faithful implementation of @Leon's ideas. I like the code better than my "proof of concept" (POC) code just above, but was surprised to find that the new code runs significantly slower (a factor of 3 to 4 on a variety of cases akin to @JosephWood's randomized example) relative to a comparably "optimized" variant of the POC code.

The primary reason appears to be more calls to the matching function. The POC code called that once per "plausible" prefix. The new code doesn't generate any impossible prefixes, but for each prefix it generates it may need to make multiple match() calls to determine the possibly smaller set of free pools remaining. Perhaps there's a cleverer way to do that.
Note that I added one twist: if a free pool's elements are all smaller than the last element of the prefix so far, it remains "a free pool" with respect to the prefix, but it's useless because none of its elements can appear in the candidates. This doesn't matter to the outcome, but it means the pool remains in the set of free pools for all remaining recursive calls, which in turn means they can waste time determining that it's still a "free pool". So when a free pool can no longer be used for anything, this version removes it from the set of free pools. This gave a significant speedup.
Note: there are many ways to try matching, some of which have better theoretical O() worst-case behavior. In my experience, simple depth-first search (as here) runs faster in real life on typical cases. But it depends very much on characteristics of what "typical" graphs look like in the application at hand. I haven't tried other ways here.
Bottom lines, ignoring the "2 inputs" special-case code:
Nothing here beats incremental de-duplication for speed, if you have the RAM. But nothing is worse than that for peak memory burden.
Nothing beats the matching-based approaches for frugal memory burden. They're in an entirely different universe on that measure. They're also the slowest, although at least in the same universe ;-)
The code:
def matchgen(pools, ixs):
    from collections import Counter
    from collections import defaultdict
    from itertools import islice

    elt2pools = defaultdict(list)
    allpools = []
    npools = 0
    for i, count in Counter(ixs).items():
        indices = list(range(npools, npools + count))
        plist = sorted(pools[i])
        for elt in plist:
            elt2pools[elt].extend(indices)
        for i in range(count):
            allpools.append(plist)
        npools += count
    pools = allpools
    assert npools == len(pools)

    result = [None] * npools

    # Is it possible to match result[:n] not using pool
    # bady? If not, return None. Else return a matching,
    # a dict whose keys are pool indices and whose values
    # are a permutation of result[:n].
    def match(n, bady):

        def extend(x, seen):
            for y in elt2pools[x]:
                if y not in seen:
                    seen.add(y)
                    if y not in y2x or extend(y2x[y], seen):
                        y2x[y] = x
                        return True
            return False

        y2x = {}
        freexs = []
        # A greedy pass first to grab easy matches.
        for x in islice(result, n):
            for y in elt2pools[x]:
                if y not in y2x and y != bady:
                    y2x[y] = x
                    break
            else:
                freexs.append(x)

        # Now do real work.
        for x in freexs:
            if not extend(x, {bady}):
                return None
        return y2x

    def inner(j, freepools):  # fill result[j:]
        from bisect import bisect_left
        if j >= npools:
            yield tuple(result)
            return
        if j:
            new_freepools = set()
            allcands = set()
            exhausted = set()  # free pools with elts too small
            atleast = result[j-1]
            for pi in freepools:
                if pi not in new_freepools:
                    m = match(j, pi)
                    if not m:  # match must use pi
                        continue
                    # Since `m` is a match to result[:j],
                    # any pool in freepools it does _not_
                    # use must still be free.
                    new_freepools |= freepools - m.keys()
                    assert pi in new_freepools
                # pi is free with respect to result[:j].
                pool = pools[pi]
                if pool[-1] < atleast:
                    exhausted.add(pi)
                else:
                    i = bisect_left(pool, atleast)
                    allcands.update(pool[i:])
            if exhausted:
                freepools -= exhausted
                new_freepools -= exhausted
        else: # j == 0
            new_freepools = freepools
            allcands = elt2pools.keys()
        for result[j] in sorted(allcands):
            yield from inner(j + 1, new_freepools)

    return inner(0, set(range(npools)))
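A hedged usage sketch (same call signature as the other generators here):

from itertools import islice

g = matchgen([[1, 10, 14, 6], [7, 2, 4, 8]], range(2))
print(list(islice(g, 5)))   # -> [(1, 2), (1, 4), (1, 7), (1, 8), (2, 6)]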
Note: this has its own classes of "bad cases". For example, passing 128 copies of [0, 1] consumes about 2 minutes(!) of time on my box to find the 129 results. The POC code takes under a second, while some of the non-matching approaches appear instantaneous.

I won't go into detail about why. Suffice it to say that because all the pools are the same, they all remain "free pools" no matter how deep the recursion goes. match() never has a hard time, always finding a complete match for the prefix in its first (greedy) pass, but even that takes time proportional to the square of the current prefix length (== the current recursion depth).
Pragmatic hybrid
One more here. As noted before, the matching-based approaches suffer some from the expense of graph matching as a fundamental operation repeated so often, and have some unfortunate bad cases pretty easy to stumble into.
Highly similar pools cause the set of free pools to shrink slowly (or not at all). But in that case the pools are so similar that it rarely matters which pool an element is taken from. So the approach below doesn't try to keep exact track of free pools, picks arbitrary pools so long as such are obviously available, and resorts to graph-matching only when it gets stuck. That seems to work well. As an extreme example, the 129 results from 128 [0, 1] pools are delivered in less than a tenth of a second instead of in two minutes. It turns out it never needs to do graph-matching in that case.
Another problem with the POC code (and less so for the other match-based approach) was the possibility of spinning wheels for a long time after the last result was delivered. A pragmatic hack solves that one completely ;-) The last tuple of the sequence is easily computed in advance, and code raises an internal exception to end everything immediately after the last tuple is delivered.
That's it for me! A generalization of the "two inputs" case would remain very interesting to me, but all the itches I got from the other approaches have been scratched now.
def combogen(pools, ixs):
    from collections import Counter
    from collections import defaultdict
    from itertools import islice

    elt2pools = defaultdict(set)
    npools = 0
    cands = []
    MAXTUPLE = []
    for i, count in Counter(ixs).items():
        indices = set(range(npools, npools + count))
        huge = None
        for elt in pools[i]:
            elt2pools[elt] |= indices
            for i in range(count):
                cands.append(elt)
            if huge is None or elt > huge:
                huge = elt
        MAXTUPLE.extend([huge] * count)
        npools += count
    MAXTUPLE = tuple(sorted(MAXTUPLE))
    cands.sort()
    ncands = len(cands)
    ALLPOOLS = set(range(npools))
    availpools = ALLPOOLS.copy()
    result = [None] * npools

    class Finished(Exception):
        pass

    # Is it possible to match result[:n]? If not, return None. Else
    # return a matching, a dict whose keys are pool indices and
    # whose values are a permutation of result[:n].
    def match(n):

        def extend(x, seen):
            for y in elt2pools[x]:
                if y not in seen:
                    seen.add(y)
                    if y not in y2x or extend(y2x[y], seen):
                        y2x[y] = x
                        return True
            return False

        y2x = {}
        freexs = []
        # A greedy pass first to grab easy matches.
        for x in islice(result, n):
            for y in elt2pools[x]:
                if y not in y2x:
                    y2x[y] = x
                    break
            else:
                freexs.append(x)

        # Now do real work.
        seen = set()
        for x in freexs:
            seen.clear()
            if not extend(x, seen):
                return None
        return y2x

    def inner(i, j):  # fill result[j:] with cands[i:]
        nonlocal availpools
        if j >= npools:
            r = tuple(result)
            yield r
            if r == MAXTUPLE:
                raise Finished
            return
        restore_availpools = None
        last = None
        jp1 = j + 1
        for i in range(i, ncands):
            elt = cands[i]
            if elt == last:
                continue
            last = result[j] = elt
            pools = elt2pools[elt] & availpools
            if pools:
                pool = pools.pop()  # pick one - arbitrary
                availpools.remove(pool)
            else:
                # Find _a_ matching, and if that's possible fiddle
                # availpools to pretend that's the one we used all
                # along.
                m = match(jp1)
                if not m:  # the prefix can't be extended with elt
                    continue
                if restore_availpools is None:
                    restore_availpools = availpools.copy()
                availpools = ALLPOOLS - m.keys()
                # Find a pool from which elt was taken.
                for pool, v in m.items():
                    if v == elt:
                        break
                else:
                    assert False
            yield from inner(i+1, jp1)
            availpools.add(pool)

        if restore_availpools is not None:
            availpools = restore_availpools

    try:
        yield from inner(0, 0)
    except Finished:
        pass
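A hedged check of the extreme example mentioned above:

results = list(combogen([[0, 1]] * 128, range(128)))
print(len(results))   # the text reports 129 results, delivered quickly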