How can I make a Cartesian product with Java 8 streams?

I have the following collection type:

Map<String, Collection<String>> map;

I would like to create all unique combinations of size map.size(), each taking a single value from the collection of every key.

For example suppose the map looks like the following:

A, {a1, a2, a3, ..., an}
B, {b1, b2, b3, ..., bn}
C, {c1, c2, c3, ..., cn}

The result I would like to get would be a List<Set<String>> looking similar to the following (ordering is not important; it just needs to be a 'complete' result consisting of all possible combinations):

{a1, b1, c1},
{a1, b1, c2},
{a1, b1, c3},
{a1, b2, c1},
{a1, b2, c2},
{a1, b2, c3},
...
{a2, b1, c1},
{a2, b1, c2},
...
{a3, b1, c1},
{a3, b1, c2},
...
{an, bn, cn}

This is basically a counting problem, but I would like to see if a solution is possible using Java 8 streams.
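Since the question describes this as a counting problem, that idea can be expressed with streams directly: number the combinations from 0 to (product of the collection sizes) - 1 and decode each index as a mixed-radix number whose digits select one element per collection. This is a minimal sketch, not taken from any of the answers below; the class `CountingProduct` and the method `combinations` are hypothetical names, and it assumes the values have been copied into lists (for random access) and that the total count fits in a long.

```java
import java.util.*;
import java.util.stream.*;

public class CountingProduct {
    // Hypothetical helper: decode every index in [0, total) as a mixed-radix number.
    static <T> Stream<List<T>> combinations(List<? extends List<T>> lists) {
        // total number of combinations = product of the list sizes (assumed to fit in a long)
        long total = lists.stream().mapToLong(List::size).reduce(1L, (x, y) -> x * y);
        return LongStream.range(0, total).mapToObj(index -> {
            List<T> combo = new ArrayList<>(lists.size());
            long rest = index;
            // walk from the last list to the first, so the last list varies fastest
            for (int i = lists.size() - 1; i >= 0; i--) {
                List<T> list = lists.get(i);
                combo.add(list.get((int) (rest % list.size())));
                rest /= list.size();
            }
            Collections.reverse(combo); // restore the key order
            return combo;
        });
    }

    public static void main(String[] args) {
        Map<String, List<String>> map = new LinkedHashMap<>();
        map.put("A", Arrays.asList("a1", "a2"));
        map.put("B", Arrays.asList("b1", "b2", "b3"));
        map.put("C", Arrays.asList("c1", "c2"));
        List<List<String>> lists = new ArrayList<>(map.values());
        // 12 combinations, from [a1, b1, c1] to [a2, b3, c2]
        combinations(lists).forEach(System.out::println);
    }
}
```

It produces lists; wrapping each one in a LinkedHashSet would give the List<Set<String>> shape from the question. An empty collection makes the total 0, so no division by zero occurs and no combinations are produced.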

Floyfloyd answered 21/8, 2015 at 3:38

You can solve this using a recursive chain of flatMap calls.

First, as we need to move back and forth through the map values, it's better to copy them into an ArrayList (this is not a deep copy; in your case it's an ArrayList of only 3 elements, so the additional memory usage is low).

Second, to maintain a prefix of previously visited elements, let's create a helper immutable Prefix class:

private static class Prefix<T> {
    final T value;
    final Prefix<T> parent;

    Prefix(Prefix<T> parent, T value) {
        this.parent = parent;
        this.value = value;
    }

    // put the whole prefix into given collection
    <C extends Collection<T>> C addTo(C collection) {
        if (parent != null)
            parent.addTo(collection);
        collection.add(value);
        return collection;
    }
}

This is a very simple immutable linked list, which can be used like this:

List<String> list = new Prefix<>(new Prefix<>(new Prefix<>(null, "a"), "b"), "c")
                          .addTo(new ArrayList<>()); // [a, b, c];

Next, let's create the internal method which chains flatMaps:

private static <T, C extends Collection<T>> Stream<C> comb(
        List<? extends Collection<T>> values, int offset, Prefix<T> prefix,
        Supplier<C> supplier) {
    if (offset == values.size() - 1)
        return values.get(offset).stream()
                     .map(e -> new Prefix<>(prefix, e).addTo(supplier.get()));
    return values.get(offset).stream()
            .flatMap(e -> comb(values, offset + 1, new Prefix<>(prefix, e), supplier));
}

It looks like recursion, but it's more subtle: the method doesn't call itself directly; instead, it passes to flatMap a lambda that calls the outer method. Parameters:

  • values: the List of original values (new ArrayList<>(map.values()) in your case).
  • offset: the current offset within this list.
  • prefix: the current prefix of length offset (or null if offset == 0). It contains the currently selected elements from the collections values.get(0), values.get(1) up to values.get(offset-1).
  • supplier: the factory method to create the resulting collection.

When we reach the last collection in the values list (offset == values.size() - 1), we map each of its elements to a final combination using the supplier. Otherwise, we use flatMap, which for each intermediate element enlarges the prefix and calls the comb method again for the next offset.

Finally, here's the public method to use this feature:

public static <T, C extends Collection<T>> Stream<C> ofCombinations(
        Collection<? extends Collection<T>> values, Supplier<C> supplier) {
    if (values.isEmpty())
        return Stream.empty();
    return comb(new ArrayList<>(values), 0, null, supplier);
}

A usage example:

Map<String, Collection<String>> map = new LinkedHashMap<>(); // to preserve the order
map.put("A", Arrays.asList("a1", "a2", "a3", "a4"));
map.put("B", Arrays.asList("b1", "b2", "b3"));
map.put("C", Arrays.asList("c1", "c2"));

ofCombinations(map.values(), LinkedHashSet::new).forEach(System.out::println);

We collect individual combinations into a LinkedHashSet, again to preserve the order. You can use any other collection instead (e.g. ArrayList::new).
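For reference, here is a sketch that assembles Prefix, comb and ofCombinations from above into a single compilable class (the class name `Combinations` and the sample input in main are my own assumptions). The main method also illustrates the choice of supplier: when the same value appears under two keys, a LinkedHashSet collapses that combination to one element, whereas ArrayList::new keeps one element per key.

```java
import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class Combinations {
    // immutable linked list holding the elements chosen so far
    private static class Prefix<T> {
        final T value;
        final Prefix<T> parent;

        Prefix(Prefix<T> parent, T value) {
            this.parent = parent;
            this.value = value;
        }

        // copy the whole prefix, oldest element first, into the given collection
        <C extends Collection<T>> C addTo(C collection) {
            if (parent != null)
                parent.addTo(collection);
            collection.add(value);
            return collection;
        }
    }

    private static <T, C extends Collection<T>> Stream<C> comb(
            List<? extends Collection<T>> values, int offset, Prefix<T> prefix,
            Supplier<C> supplier) {
        // last collection: turn each element plus the prefix into a finished combination
        if (offset == values.size() - 1)
            return values.get(offset).stream()
                         .map(e -> new Prefix<>(prefix, e).addTo(supplier.get()));
        // otherwise extend the prefix and recurse (via the lambda) into the next collection
        return values.get(offset).stream()
                .flatMap(e -> comb(values, offset + 1, new Prefix<>(prefix, e), supplier));
    }

    public static <T, C extends Collection<T>> Stream<C> ofCombinations(
            Collection<? extends Collection<T>> values, Supplier<C> supplier) {
        if (values.isEmpty())
            return Stream.empty();
        return comb(new ArrayList<>(values), 0, null, supplier);
    }

    public static void main(String[] args) {
        Map<String, Collection<String>> map = new LinkedHashMap<>();
        map.put("A", Arrays.asList("x", "a2"));
        map.put("B", Arrays.asList("x", "b2"));
        // LinkedHashSet: the combination (x, x) collapses to [x]
        System.out.println(ofCombinations(map.values(), LinkedHashSet::new)
                .collect(Collectors.toList())); // [[x], [x, b2], [a2, x], [a2, b2]]
        // ArrayList: every combination keeps one element per key
        System.out.println(ofCombinations(map.values(), ArrayList::new)
                .collect(Collectors.toList())); // [[x, x], [x, b2], [a2, x], [a2, b2]]
    }
}
```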

Heulandite answered 21/8, 2015 at 4:30

A simpler answer, for a simpler situation where you just want the Cartesian product of the elements of two collections.

Here's some code which uses flatMap to generate the Cartesian product of two short lists:

public static void main(String[] args) {
    List<Integer> aList = Arrays.asList(1, 2, 3);
    List<Integer> bList = Arrays.asList(4, 5, 6);

    Stream<List<Integer>> product = aList.stream().flatMap(a ->
            bList.stream().flatMap(b ->
                    Stream.of(Arrays.asList(a, b))));

    product.forEach(p -> { System.out.println(p); });

    // prints:
    // [1, 4]
    // [1, 5]
    // [1, 6]
    // [2, 4]
    // [2, 5]
    // [2, 6]
    // [3, 4]
    // [3, 5]
    // [3, 6]
}

If you want to add more collections, just nest the streams a little further:

aList.stream().flatMap(a ->
    bList.stream().flatMap(b ->
        cList.stream().flatMap(c ->
            Stream.of(Arrays.asList(a, b, c)))));
Dibranchiate answered 21/6, 2019 at 11:26
This is good, but the second flatMap and the last stream are unnecessary: you can just do Stream<List<Integer>> product = aList.stream().flatMap(a -> bList.stream().map(b -> Arrays.asList(a, b)));. – Ellingson
True, but I liked the symmetry of this code. Not optimal in efficiency as you said. Thanks for the feedback! – Dibranchiate
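Applied to the three-list case, that suggestion means only the innermost level needs map; the outer levels still need flatMap, because each of their elements expands into several combinations. A minimal sketch (the class `ThreeLists`, the method `product` and the sample lists are hypothetical):

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class ThreeLists {
    static List<List<Integer>> product(List<Integer> aList, List<Integer> bList, List<Integer> cList) {
        return aList.stream().flatMap(a ->
                bList.stream().flatMap(b ->
                        // innermost level: exactly one combination per element, so map suffices
                        cList.stream().map(c -> Arrays.asList(a, b, c))))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(product(Arrays.asList(1, 2), Arrays.asList(3, 4), Arrays.asList(5, 6)));
        // [[1, 3, 5], [1, 3, 6], [1, 4, 5], [1, 4, 6], [2, 3, 5], [2, 3, 6], [2, 4, 5], [2, 4, 6]]
    }
}
```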

Cartesian product in Java 8 with forEach:

List<String> listA = Arrays.asList("0", "1");
List<String> listB = Arrays.asList("a", "b");

List<String> cartesianProduct = new ArrayList<>();
listA.forEach(a -> listB.forEach(b -> cartesianProduct.add(a + b)));

System.out.println(cartesianProduct);
// Output: [0a, 0b, 1a, 1b]
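The same result can be built without mutating an external list, by letting collect create it; this stream version could also be switched to parallel() without extra synchronization. A minimal sketch (the class `ConcatProduct` and the method `concatProduct` are hypothetical names):

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class ConcatProduct {
    // every a is paired with every b; collect() builds the result list itself
    static List<String> concatProduct(List<String> listA, List<String> listB) {
        return listA.stream()
                .flatMap(a -> listB.stream().map(b -> a + b))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> listA = Arrays.asList("0", "1");
        List<String> listB = Arrays.asList("a", "b");
        System.out.println(concatProduct(listA, listB)); // [0a, 0b, 1a, 1b]
    }
}
```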
Commensal answered 15/3, 2017 at 16:56

A solution that mainly operates on lists, which makes things a lot simpler. It makes a recursive call inside flatMap, keeping track of the elements that have already been combined and of the collections whose elements are still missing, and offers the results of this nested recursive construction as a stream of lists:

import java.util.*;
import java.util.stream.Stream;

public class CartesianProduct {
    public static void main(String[] args) {
        Map<String, Collection<String>> map =
                new LinkedHashMap<String, Collection<String>>();
        map.put("A", Arrays.asList("a1", "a2", "a3", "a4"));
        map.put("B", Arrays.asList("b1", "b2", "b3"));
        map.put("C", Arrays.asList("c1", "c2"));
        ofCombinations(map.values()).forEach(System.out::println);
    }

    public static <T> Stream<List<T>> ofCombinations(
            Collection<? extends Collection<T>> collections) {
        return ofCombinations(
                new ArrayList<Collection<T>>(collections),
                Collections.emptyList());
    }

    private static <T> Stream<List<T>> ofCombinations(
            List<? extends Collection<T>> collections, List<T> current) {
        return collections.isEmpty() ? Stream.of(current) :
                collections.get(0).stream().flatMap(e -> {
                    List<T> list = new ArrayList<T>(current);
                    list.add(e);
                    return ofCombinations(
                            collections.subList(1, collections.size()), list);
                });
    }
}
Seringapatam answered 21/8, 2015 at 17:26

While it's not a Stream solution, Guava's com.google.common.collect.Sets does this for you with cartesianProduct (Set.of in this example requires Java 9 or later):

Set<List<String>> result = Sets.cartesianProduct(
        Set.of("a1", "a2"), Set.of("b1", "b2"), Set.of("c1", "c2"));
Raney answered 9/4, 2019 at 12:1

Here is another solution, which does not use as many Stream features as Tagir's example; however, I believe it to be more straightforward:

public class Permutations {
    transient List<Collection<String>> perms;
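    public List<Collection<String>> list(Map<String, Collection<String>> map) {
        perms = null; // reset, so repeated calls don't extend the previous result
        SortedMap<String, Collection<String>> sortedMap = new TreeMap<>();
        sortedMap.putAll(map);
        sortedMap.values().forEach((v) -> perms = expand(perms, v));
        // an empty map leaves perms null; return an empty list instead
        return perms != null ? perms : new LinkedList<>();
    }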

    private List<Collection<String>> expand(
            List<Collection<String>> list, Collection<String> elements) {
        List<Collection<String>> newList = new LinkedList<>();
        if (list == null) {
            elements.forEach((e) -> {
                SortedSet<String> set = new TreeSet<>();
                set.add(e);
                newList.add(set);
            });
        } else {
            list.forEach((set) ->
                    elements.forEach((e) -> {
                        SortedSet<String> newSet = new TreeSet<>();
                        newSet.addAll(set);
                        newSet.add(e);
                        newList.add(newSet);
                    }));
        }
        return newList;
    }
}

You can remove the Sorted prefix if you are not interested in the ordering of elements, though I think it's easier to debug if everything is sorted.

Usage:

Permutations p = new Permutations();
List<Collection<String>> plist = p.list(map);
plist.forEach((s) -> System.out.println(s));

Enjoy!

Femi answered 21/8, 2015 at 5:58
Note that your solution actually uses zero Stream API features (Collection.forEach is not part of the Stream API). You can replace .forEach with a good old for-in loop and your code will be Java 5-compatible. Also note that you store all the combinations in memory. While this seems OK for the OP, it can become problematic with larger input. Finally, there's no easy way to parallelize it. – Heulandite

The map-and-reduce approach with nested loops within one stream

The single outer stream can easily be made parallel, which may reduce the computation time in some cases. The inner iterations are implemented with loops.

/**
 * @param map a map of lists
 * @param <T> the type of the elements
 * @return the Cartesian product of map values
 */
public static <T> List<List<T>> cartesianProduct(Map<?, List<T>> map) {
    // check if incoming data is not null
    if (map == null) return Collections.emptyList();
    return map.values().stream().parallel()
            // non-null and non-empty lists
            .filter(list -> list != null && list.size() > 0)
            // represent each list element as a singleton list
            .map(list -> {
                List<List<T>> nList = new ArrayList<>(list.size());
                for (T e : list) nList.add(Collections.singletonList(e));
                return nList;
            })
            // summation of pairs of inner lists
            .reduce((list1, list2) -> {
                // number of combinations
                int size = list1.size() * list2.size();
                // list of combinations
                List<List<T>> list = new ArrayList<>(size);
                for (List<T> inner1 : list1)
                    for (List<T> inner2 : list2) {
                        List<T> inner = new ArrayList<>();
                        inner.addAll(inner1);
                        inner.addAll(inner2);
                        list.add(inner);
                    }
                return list;
            }).orElse(Collections.emptyList());
}
public static void main(String[] args) {
    Map<String, List<String>> map = new LinkedHashMap<>();
    map.put("A", Arrays.asList("A1", "A2", "A3", "A4"));
    map.put("B", Arrays.asList("B1", "B2", "B3"));
    map.put("C", Arrays.asList("C1", "C2"));

    List<List<String>> cp = cartesianProduct(map);
    // column-wise output
    int rows = 6;
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cp.size(); j++)
            System.out.print(j % rows == i ? cp.get(j) + " " : "");
        System.out.println();
    }
}

Output:

[A1, B1, C1] [A2, B1, C1] [A3, B1, C1] [A4, B1, C1] 
[A1, B1, C2] [A2, B1, C2] [A3, B1, C2] [A4, B1, C2] 
[A1, B2, C1] [A2, B2, C1] [A3, B2, C1] [A4, B2, C1] 
[A1, B2, C2] [A2, B2, C2] [A3, B2, C2] [A4, B2, C2] 
[A1, B3, C1] [A2, B3, C1] [A3, B3, C1] [A4, B3, C1] 
[A1, B3, C2] [A2, B3, C2] [A3, B3, C2] [A4, B3, C2] 
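Switching the stream to parallel is only safe because the combining lambda is associative, which the reduce(BinaryOperator) overload requires: a parallel reduction may group the lists as (A × B) × C or as A × (B × C), and both must yield the same combinations in the same order. A small sketch that checks this, where `combine` and `singletons` are my own standalone copies of the reduce and map steps above (the class name `AssociativityCheck` is hypothetical):

```java
import java.util.*;

public class AssociativityCheck {
    // standalone copy of the reduce lambda: concatenate every inner list of list1 with every one of list2
    static <T> List<List<T>> combine(List<List<T>> list1, List<List<T>> list2) {
        List<List<T>> list = new ArrayList<>(list1.size() * list2.size());
        for (List<T> inner1 : list1)
            for (List<T> inner2 : list2) {
                List<T> inner = new ArrayList<>(inner1);
                inner.addAll(inner2);
                list.add(inner);
            }
        return list;
    }

    // wrap each element in a singleton list, as the map step above does
    static <T> List<List<T>> singletons(List<T> list) {
        List<List<T>> result = new ArrayList<>(list.size());
        for (T e : list) result.add(Collections.singletonList(e));
        return result;
    }

    public static void main(String[] args) {
        List<List<String>> a = singletons(Arrays.asList("A1", "A2"));
        List<List<String>> b = singletons(Arrays.asList("B1", "B2", "B3"));
        List<List<String>> c = singletons(Arrays.asList("C1", "C2"));
        List<List<String>> left = combine(combine(a, b), c);   // (A × B) × C
        List<List<String>> right = combine(a, combine(b, c));  // A × (B × C)
        System.out.println(left.equals(right)); // true
    }
}
```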

See also: How to get Cartesian product from multiple lists?

Morrison answered 4/8, 2021 at 1:10

Use a Consumer functional interface, a List<T> and forEach:

public void tester() {
    String[] strs1 = {"2", "4", "9"};
    String[] strs2 = {"9", "0", "5"};

    // Final output is [29, 49, 99, 20, 40, 90, 25, 45, 95]
    List<String> result = new ArrayList<>();
    Consumer<String> consumer = (String str) -> result.addAll(
            Arrays.stream(strs1).map(s -> s + str).collect(Collectors.toList()));
    Arrays.stream(strs2).forEach(consumer);

    System.out.println(result);
}
Dearborn answered 26/7, 2016 at 17:23

Create the combined list in a loop:

List<String> cartesianProduct(List<List<String>> wordLists) {
    List<String> cp = wordLists.get(0);
    for (int i = 1; i < wordLists.size(); i++) {
        List<String> secondList = wordLists.get(i);
        List<String> combinedList = cp.stream()
                .flatMap(s1 -> secondList.stream()
                        .map(s2 -> s1 + s2))
                .collect(Collectors.toList());
        cp = combinedList;
    }
    return cp;
}
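The same loop works for elements of any type if each partial result is kept as a list of lists rather than as concatenated strings. A sketch under that assumption (the class `LoopProduct` and the method `cartesianProductOf` are hypothetical). Starting from a single empty combination also means an empty input yields one empty combination, the usual convention for a product of zero sets, instead of failing on wordLists.get(0):

```java
import java.util.*;
import java.util.stream.Collectors;

public class LoopProduct {
    static <T> List<List<T>> cartesianProductOf(List<? extends List<T>> lists) {
        // start with one empty combination
        List<List<T>> cp = new ArrayList<>();
        cp.add(new ArrayList<>());
        for (List<T> next : lists) {
            List<List<T>> current = cp; // effectively final copy for the lambda
            // extend every existing combination with every element of the next list
            cp = current.stream()
                    .flatMap(prefix -> next.stream().map(e -> {
                        List<T> combo = new ArrayList<>(prefix);
                        combo.add(e);
                        return combo;
                    }))
                    .collect(Collectors.toList());
        }
        return cp;
    }

    public static void main(String[] args) {
        System.out.println(cartesianProductOf(Arrays.asList(
                Arrays.asList("a", "b"), Arrays.asList("1", "2"))));
        // [[a, 1], [a, 2], [b, 1], [b, 2]]
    }
}
```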
Waxen answered 4/4, 2018 at 10:38

I wrote a class that implements Iterable and holds only the current item in memory. Both the Iterable and the Iterator can be converted to a Stream if desired.

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

class CartesianProduct<T> implements Iterable<List<T>> {
  private final Iterable<? extends Iterable<T>> factors;

  public CartesianProduct(final Iterable<? extends Iterable<T>> factors) {
    this.factors = factors;
  }

  @Override
  public Iterator<List<T>> iterator() {
    return new CartesianProductIterator<>(factors);
  }
}

class CartesianProductIterator<T> implements Iterator<List<T>> {
  private final List<Iterable<T>> factors;
  private final Stack<Iterator<T>> iterators;
  private final Stack<T> current;
  private List<T> next;
  private int index = 0;

  private void computeNext() {
    while (true) {
      if (iterators.get(index).hasNext()) {
        current.add(iterators.get(index).next());
        if (index == factors.size() - 1) {
          next = new ArrayList<>(current);
          current.pop();
          return;
        }
        index++;
        iterators.add(factors.get(index).iterator());
      } else {
        index--;
        if (index < 0) {
          return;
        }
        iterators.pop();
        current.pop();
      }
    }
  }

  public CartesianProductIterator(final Iterable<? extends Iterable<T>> factors) {
    this.factors = StreamSupport.stream(factors.spliterator(), false)
          .collect(Collectors.toList());
    iterators = new Stack<>();
    current = new Stack<>();
    if (this.factors.size() == 0) {
      index = -1;
    } else {
      iterators.add(this.factors.get(0).iterator());
      computeNext();
    }
  }

  @Override
  public boolean hasNext() {
    if (next == null && index >= 0) {
      computeNext();
    }
    return next != null;
  }

  @Override
  public List<T> next() {
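    if (!hasNext()) {
      // the Iterator contract specifies NoSuchElementException here
      throw new NoSuchElementException();
    }
    List<T> result = next; // explicit type instead of var keeps this Java 8-compatible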
    next = null;
    return result;
  }
}
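Both conversions use only standard JDK calls: StreamSupport.stream over the Iterable's spliterator, or Spliterators.spliteratorUnknownSize for a bare Iterator. A minimal sketch with hypothetical helper names (`ToStream`, `streamOf`), demonstrated on a plain List so that it runs on its own; a CartesianProduct<T> instance (or its iterator()) can be passed the same way:

```java
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class ToStream {
    // Iterable -> sequential Stream
    static <T> Stream<T> streamOf(Iterable<T> iterable) {
        return StreamSupport.stream(iterable.spliterator(), false);
    }

    // Iterator -> sequential Stream; the size is unknown up front, the order is fixed
    static <T> Stream<T> streamOf(Iterator<T> iterator) {
        return StreamSupport.stream(
                Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false);
    }

    public static void main(String[] args) {
        List<List<String>> combos = Arrays.asList(
                Arrays.asList("a1", "b1"), Arrays.asList("a1", "b2"));
        System.out.println(streamOf(combos).map(l -> String.join("", l))
                .collect(Collectors.toList())); // [a1b1, a1b2]
        System.out.println(streamOf(combos.iterator()).count()); // 2
    }
}
```

Because the CartesianProduct iterator computes one combination at a time, the resulting stream stays lazy: operations such as limit or findFirst stop before the remaining combinations are generated.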
Mystery answered 6/1, 2019 at 20:37

You can use the Stream.reduce method as follows:

Map<String, List<String>> map = new LinkedHashMap<>();
map.put("A", List.of("a1", "a2", "a3"));
map.put("B", List.of("b1", "b2", "b3"));
map.put("C", List.of("c1", "c2", "c3"));
List<List<String>> cartesianProduct = map.values().stream()
        // represent each list element as a singleton list
        .map(list -> list.stream().map(Collections::singletonList)
                .collect(Collectors.toList()))
        // reduce the stream of lists to a single list by
        // sequentially summing pairs of elements of two lists
        .reduce((list1, list2) -> list1.stream()
                // combinations of inner lists
                .flatMap(first -> list2.stream()
                        // merge two inner lists into one
                        .map(second -> Stream.of(first, second)
                                .flatMap(List::stream)
                                .collect(Collectors.toList())))
                // list of combinations
                .collect(Collectors.toList()))
        // List<List<String>>
        .orElse(Collections.emptyList());
// column-wise output
int rows = 9;
IntStream.range(0, rows)
        .mapToObj(i -> IntStream.range(0, cartesianProduct.size())
                .filter(j -> j % rows == i)
                .mapToObj(j -> cartesianProduct.get(j).toString())
                .collect(Collectors.joining("  ")))
        .forEach(System.out::println);

Output:

[a1, b1, c1]  [a2, b1, c1]  [a3, b1, c1]
[a1, b1, c2]  [a2, b1, c2]  [a3, b1, c2]
[a1, b1, c3]  [a2, b1, c3]  [a3, b1, c3]
[a1, b2, c1]  [a2, b2, c1]  [a3, b2, c1]
[a1, b2, c2]  [a2, b2, c2]  [a3, b2, c2]
[a1, b2, c3]  [a2, b2, c3]  [a3, b2, c3]
[a1, b3, c1]  [a2, b3, c1]  [a3, b3, c1]
[a1, b3, c2]  [a2, b3, c2]  [a3, b3, c2]
[a1, b3, c3]  [a2, b3, c3]  [a3, b3, c3]

See also: String permutations using recursion in Java

Sympetalous answered 17/4, 2021 at 10:23

© 2022 - 2024 — McMap. All rights reserved.