Here's a short Python implementation, basically a recursive variation of Kruskal's algorithm. It uses the weight of the first MST found to limit the size of the search space thereafter. It is definitely still exponential in complexity, but better than generating every spanning tree. Some test code is also included.
[Note: this was just my own experimentation for fun and possible inspiration of further thoughts on the problem from others, it's not an attempt to specifically implement any of the solutions suggested in other supplied answers here]
# Disjoint set find (and collapse)
def find(nd, djset):
    """Return the root of the set containing ``nd`` and compress the path.

    ``djset`` maps every node to its parent when the stored value is >= 0;
    a negative value marks the node as the root of its set.

    Args:
        nd: node whose set representative is wanted.
        djset: parent map; mutated in place (only non-root entries change,
            so set membership is never altered).

    Returns:
        The root node of ``nd``'s set.
    """
    # Pass 1: follow parent links until a root (negative entry) is reached.
    root = nd
    while djset[root] >= 0:
        root = djset[root]
    # Pass 2: re-point every node on the walked path straight at the root,
    # not just the starting node, so later lookups along it take one step.
    while nd != root:
        parent = djset[nd]
        djset[nd] = root
        nd = parent
    return root
# Disjoint set union (does not modify djset)
def union(nd1, nd2, djset):
    """Return a copy of ``djset`` in which the sets rooted at ``nd1`` and
    ``nd2`` are merged; ``djset`` itself is left untouched.

    Root entries hold negative values that are added together on a merge;
    the root with the more negative value stays the root and the other one
    is attached beneath it.

    Args:
        nd1, nd2: the two nodes to merge. The code reads and rewrites their
            entries directly, so they are expected to be roots (negative
            entries), e.g. as returned by ``find``.
        djset: parent map to copy.

    Returns:
        A new dict with the merge applied.
    """
    unionset = djset.copy()
    # Merging a set with itself must be a no-op; without this guard the
    # root's value would be doubled and the root made its own parent.
    if nd1 == nd2:
        return unionset
    # Keep the root with the more negative entry as the surviving root.
    if unionset[nd2] < unionset[nd1]:
        nd1, nd2 = nd2, nd1
    unionset[nd1] += unionset[nd2]
    unionset[nd2] = nd1
    return unionset
# Bitmask convenience methods; uses bitmasks
# internally to represent MST edge combinations
def setbit(j, mask):
    """Return ``mask`` with bit number ``j`` switched on."""
    bit = 1 << j
    return mask | bit
def isbitset(j, mask):
    """Return 1 if bit number ``j`` of ``mask`` is set, otherwise 0."""
    shifted = mask >> j
    return shifted & 1
def masktoedges(mask, sedges):
    """Return the edges of ``sedges`` selected by ``mask``.

    Bit ``i`` of ``mask`` stands for ``sedges[i]``; the selected edges are
    returned in list order. Bits at positions beyond the end of ``sedges``
    are ignored.
    """
    # (mask >> i) & 1 is the same bit test that isbitset performs.
    return [edge for i, edge in enumerate(sedges) if (mask >> i) & 1]
# Upper-bound count of viable MST edge combination, i.e.
# count of edge subsequences of length: NEDGES, w/sum: WEIGHT
def count_subsequences(sedges, weight, nedges):
    """Count the edge subsequences of length ``nedges`` whose weights sum to
    ``weight``.

    Connectivity is not considered, so the result is an upper bound on the
    number of spanning trees with that weight rather than an exact count.

    Args:
        sedges: list of edges; the weight is read from index 2 of each edge.
            Weights are assumed to be non-negative: partial sums above
            ``weight`` are discarded.
        weight: target total weight.
        nedges: number of edges each subsequence must contain.

    Returns:
        The number of qualifying subsequences. The empty subsequence counts,
        so ``nedges == 0`` with ``weight == 0`` yields 1.
    """
    if nedges < 0:
        return 0
    # ways[k] maps a weight sum to the number of k-edge selections reaching
    # it. Built iteratively so long edge lists cannot hit the recursion limit.
    ways = [{} for _ in range(nedges + 1)]
    ways[0][0] = 1
    for edge in sedges:
        wt = edge[2]
        # Go from large k to small so that this edge is added at most once
        # to any selection (ways[k] is still edge-free when it is read).
        for k in range(nedges - 1, -1, -1):
            grown = ways[k + 1]
            for total, combos in ways[k].items():
                reached = total + wt
                if reached <= weight:
                    grown[reached] = grown.get(reached, 0) + combos
    return ways[nedges].get(weight, 0)
# Arg: n is number of nodes in graph [0, n-1]
# Arg: sedges is list of graph edges sorted by weight
# Return: list of MSTs, where each MST is a list of edges
def find_all_msts(n, sedges):
    """Enumerate every minimum spanning tree of a weighted graph.

    A recursive variant of Kruskal's algorithm: each edge is either taken
    (when it joins two different components) or skipped. The first complete
    tree is the one reached by always taking the edge, and its weight then
    bounds the rest of the search.

    Args:
        n: number of nodes; nodes are the integers 0 .. n-1.
        sedges: list of ``[u, v, weight]`` edges sorted by ascending weight
            (the weight bound below relies on that order).

    Returns:
        A list of MSTs, each a list of edges taken from ``sedges``. Empty
        when the graph has no spanning tree.

    Side effects:
        When the first MST is found, prints its weight, the edge counts and
        the ``count_subsequences`` upper bound to stdout.

    Note:
        ``buildmsts`` recurses once per edge index, so the recursion depth
        grows with ``len(sedges)``.
    """
    total_edges = len(sedges)
    tree_size = n - 1

    # Plain Kruskal pass to test connectivity. Without a spanning tree the
    # weight bound below is never set, and the recursive search would visit
    # every edge subset before returning nothing.
    components = {node: -1 for node in range(n)}
    merges = 0
    for u, v, _ in sedges:
        root_u, root_v = find(u, components), find(v, components)
        if root_u != root_v:
            components = union(root_u, root_v, components)
            merges += 1
            if merges == tree_size:
                break
    if merges != tree_size:
        return []

    maxweight, msts = float('inf'), []

    # Recursive variant of Kruskal: edge i is either included or skipped.
    def buildmsts(i, weight, mask, nedges, djset):
        nonlocal maxweight
        if nedges == tree_size:
            msts.append(mask)
            if maxweight == float('inf'):
                print(f"MST weight: {weight}, MST edges: {n-1}, Total graph edges: {len(sedges)}")
                print(f"Upper bound numb viable MST edge combinations: {count_subsequences(sedges, weight, n-1)}\n")
                maxweight = weight
            return
        needed = tree_size - nedges
        # Too few edges left to complete a tree (also covers i running off
        # the end of the list, since needed >= 1 here).
        if total_edges - i < needed:
            return
        u, v, wt = sedges[i]
        # Every remaining edge weighs at least wt, so this is a lower bound
        # on the weight of any tree completed from here.
        if weight + wt * needed <= maxweight:
            # Left branch: include the edge if it joins two components.
            nd1, nd2 = find(u, djset), find(v, djset)
            if nd1 != nd2:
                buildmsts(i + 1, weight + wt, setbit(i, mask),
                          nedges + 1, union(nd1, nd2, djset))
            # Right branch: always try skipping the edge.
            buildmsts(i + 1, weight, mask, nedges, djset)

    buildmsts(0, 0, 0, 0, {node: -1 for node in range(n)})
    return [masktoedges(mask, sedges) for mask in msts]
import time, numpy
def run_test_case(low=10, high=21, seed=None):
    """Build a random connected weighted graph and print all of its MSTs.

    Args:
        low: smallest node count that may be drawn (inclusive).
        high: upper limit for the node count (exclusive).
        seed: optional seed passed to ``numpy.random.default_rng`` so that a
            run can be reproduced; ``None`` keeps the run random.

    Prints the node count, the weight-sorted edge list and then every MST
    returned by ``find_all_msts``.
    """
    rng = numpy.random.default_rng(seed)
    # Convert NumPy scalars to plain ints so the printed lists stay readable.
    n = int(rng.integers(low, high))
    # Anything from a bare tree up to the complete graph (upper end included).
    nedges = int(rng.integers(n - 1, n * (n - 1) // 2 + 1))

    # Start from a random spanning tree (each node links to a random earlier
    # node) so the graph is always connected and therefore has an MST.
    edges = set()
    for node in range(1, n):
        edges.add((int(rng.integers(0, node)), node))
    # Top up with random extra edges; the set drops duplicates.
    while len(edges) < nedges:
        u, v = sorted(int(x) for x in rng.choice(n, size=2, replace=False))
        edges.add((u, v))

    # Pairing the edges with sorted weights yields a weight-sorted edge list.
    weights = sorted(int(w) for w in rng.integers(1, 2 * n, size=nedges))
    sedges = [[u, v, wt] for (u, v), wt in zip(edges, weights)]
    print(f"Numb nodes: {n}\nSorted edges: {sedges}\n")
    for idx, mst in enumerate(find_all_msts(n, sedges), start=1):
        if idx == 1:
            print("MSTs:")
        print(idx, ":", mst)
if __name__ == "__main__":
    # perf_counter is a monotonic, high-resolution clock, so the measured
    # interval is not affected by system clock adjustments.
    started = time.perf_counter()
    run_test_case(20, 35)
    print(f"\nRun time: {time.perf_counter() - started}s")
n - 1
edges too slow? – Bra