I want to combine the elements of multiple Stream instances into a single Stream. What's the best way to do this?

This article compares a few different solutions.

Stream.concat(a, b)

The JDK provides Stream.concat(a, b) for concatenating two streams.

void exampleConcatTwo() {
  Stream<String> a = Stream.of("one", "two");
  Stream<String> b = Stream.of("three", "four");
  Stream<String> out = Stream.concat(a, b);
  out.forEach(System.out::println);
  // Output:
  // one
  // two
  // three
  // four
}

What if we have more than two streams?

We could use Stream.concat(a, b) multiple times. With three streams we could write Stream.concat(Stream.concat(a, b), c).

To me that approach is depressing at three streams, and it rapidly gets worse as we add more streams.
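
To make the growth concrete, here is a small sketch of the hand-nested form for four streams. The class and method names (NestedConcatExample, concatFour) are made up for this illustration and are not part of the JDK.

```java
import java.util.stream.Stream;

class NestedConcatExample {
  // Hypothetical helper: joins exactly four streams by nesting Stream.concat by hand.
  // Every additional stream needs one more level of nesting.
  static <T> Stream<T> concatFour(Stream<T> a, Stream<T> b, Stream<T> c, Stream<T> d) {
    return Stream.concat(Stream.concat(Stream.concat(a, b), c), d);
  }

  public static void main(String[] args) {
    Stream<String> out = concatFour(
        Stream.of("one"), Stream.of("two"), Stream.of("three"), Stream.of("four"));
    out.forEach(System.out::println);
  }
}
```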

Reduce

Alternatively, we can use reduce to perform the multiple incantations of Stream.concat(a, b) for us. The code adapts elegantly to handle any number of input streams.

void exampleReduce() {
  Stream<String> a = Stream.of("one", "two");
  Stream<String> b = Stream.of("three", "four");
  Stream<String> c = Stream.of("five", "six");
  Stream<String> out = Stream.of(a, b, c)
      .reduce(Stream::concat)
      .orElseGet(Stream::empty);
  out.forEach(System.out::println);
  // Output:
  // one
  // two
  // three
  // four
  // five
  // six
}

Be careful using this pattern! Note the warning in the documentation of Stream.concat(a, b):

Use caution when constructing streams from repeated concatenation. Accessing an element of a deeply concatenated stream can result in deep call chains, or even StackOverflowError.

It takes quite a few input streams to trigger this problem, but it is trivial to demonstrate:

void exampleStackOverflow() {
  List<Stream<String>> inputs = new AbstractList<Stream<String>>() {
    @Override
    public Stream<String> get(int index) {
      return Stream.of("one", "two");
    }

    @Override
    public int size() {
      return 1_000_000; // try changing this number
    }
  };
  Stream<String> out = inputs.stream()
      .reduce(Stream::concat)
      .orElseGet(Stream::empty);
  long count = out.count(); // probably throws
  System.out.println("count: " + count); // probably never reached
}

On my workstation, this method throws StackOverflowError after several seconds of churning.

What's going on here?

We can think of the calls to Stream.concat(a, b) as forming a binary tree. At the root is the concatenation of all the input streams. At the leaves are the individual input streams. Let's look at the trees for up to five input streams as formed by our reduce operation.

Two streams:

  concat(a,b)
  ├── a
  └── b

Three streams:

  concat(concat(a,b),c)
  ├── concat(a,b)
  │   ├── a
  │   └── b
  └── c

Four streams:

  concat(concat(concat(a,b),c),d)
  ├── concat(concat(a,b),c)
  │   ├── concat(a,b)
  │   │   ├── a
  │   │   └── b
  │   └── c
  └── d

Five streams:

  concat(concat(concat(concat(a,b),c),d),e)
  ├── concat(concat(concat(a,b),c),d)
  │   ├── concat(concat(a,b),c)
  │   │   ├── concat(a,b)
  │   │   │   ├── a
  │   │   │   └── b
  │   │   └── c
  │   └── d
  └── e

The trees are perfectly unbalanced! Each additional input stream adds one layer of depth to the tree and one layer of indirection to reach all the other streams. This can have a noticeable negative impact on performance. With enough layers of indirection we'll see a StackOverflowError.

Balance

If we're worried that we'll concatenate a large number of streams and run into the aforementioned problems, we can balance the tree. This is as if we're optimizing an O(n) algorithm into an O(log n) one: the depth of the tree now grows with the logarithm of the number of input streams rather than linearly. We won't totally eliminate the possibility of StackOverflowError, and there may be other approaches that perform even better, but this should be quite an improvement over the previous solution.

void exampleBalance() {
  Stream<String> a = Stream.of("one", "two");
  Stream<String> b = Stream.of("three", "four");
  Stream<String> c = Stream.of("five", "six");
  Stream<String> out = concat(a, b, c);
  out.forEach(System.out::println);
  // Output:
  // one
  // two
  // three
  // four
  // five
  // six
}

@SafeVarargs
static <T> Stream<T> concat(Stream<T>... in) {
  return concat(in, 0, in.length);
}

static <T> Stream<T> concat(Stream<T>[] in, int low, int high) {
  switch (high - low) {
    case 0: return Stream.empty();
    case 1: return in[low];
    default:
      int mid = (low + high) >>> 1;
      Stream<T> left = concat(in, low, mid);
      Stream<T> right = concat(in, mid, high);
      return Stream.concat(left, right);
  }
}
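
To put rough numbers on the difference, the sketch below models only the shape of the two trees, without building any streams. It counts how many concat layers sit above the deepest input. The class and method names are hypothetical; balancedDepth mirrors the midpoint split used by concat(in, low, high) above.

```java
class ConcatTreeDepth {
  // Depth of the tree built by reduce(Stream::concat): each input after the
  // first wraps everything so far in one more concat.
  static int leftDeepDepth(int n) {
    int depth = 0;
    for (int i = 1; i < n; i++) {
      depth = depth + 1;
    }
    return depth;
  }

  // Depth of the tree built by splitting the inputs at the midpoint.
  static int balancedDepth(int n) {
    return balancedDepth(0, n);
  }

  private static int balancedDepth(int low, int high) {
    if (high - low <= 1) {
      return 0; // zero or one input: no concat needed
    }
    int mid = (low + high) >>> 1;
    return 1 + Math.max(balancedDepth(low, mid), balancedDepth(mid, high));
  }

  public static void main(String[] args) {
    System.out.println(leftDeepDepth(5));         // 4
    System.out.println(balancedDepth(5));         // 3
    System.out.println(leftDeepDepth(1_000_000)); // 999999
    System.out.println(balancedDepth(1_000_000)); // 20
  }
}
```

With one million inputs, the balanced tree is only 20 layers deep, which is why the stack overflow from the earlier example is no longer a practical concern at that size.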

Flatmap

There is another way to concatenate streams that is built into the JDK, and it does not involve Stream.concat(a, b) at all. It is flatMap.

void exampleFlatMap() {
  Stream<String> a = Stream.of("one", "two");
  Stream<String> b = Stream.of("three", "four");
  Stream<String> c = Stream.of("five", "six");
  Stream<String> out = Stream.of(a, b, c).flatMap(s -> s);
  out.forEach(System.out::println);
  // Output:
  // one
  // two
  // three
  // four
  // five
  // six
}

This generally outperforms the solutions based on Stream.concat(a, b) when each input stream contains fewer than 32 elements. Past 32 elements per stream, flatMap falls further and further behind as the element count rises.

flatMap avoids the StackOverflowError issue but it comes with its own set of quirks. For example, it interacts poorly with infinite streams. Calling findAny on the concatenated stream may cause the program to enter an infinite loop, whereas the other solutions would terminate almost immediately.

void exampleInfiniteLoop() {
  Stream<String> a = Stream.generate(() -> "one");
  Stream<String> b = Stream.generate(() -> "two");
  Stream<String> c = Stream.generate(() -> "three");
  Stream<String> out = Stream.of(a, b, c).flatMap(s -> s);
  Optional<String> any = out.findAny(); // infinite loop
  System.out.println(any); // never reached
}

(The infinite loop is an implementation detail. This could be fixed in the JDK without changing the contract of flatMap.)
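
For contrast, here is a minimal sketch of the same question asked through nested Stream.concat(a, b) calls on sequential streams. The class and method names are hypothetical. The concatenated stream is lazy, so findAny only needs to pull a single element from the first input.

```java
import java.util.Optional;
import java.util.stream.Stream;

class ConcatInfiniteExample {
  // Hypothetical helper: concatenates three infinite streams with
  // Stream.concat and asks for any element. This returns promptly.
  static Optional<String> findAnyWithConcat() {
    Stream<String> a = Stream.generate(() -> "one");
    Stream<String> b = Stream.generate(() -> "two");
    Stream<String> c = Stream.generate(() -> "three");
    Stream<String> out = Stream.concat(Stream.concat(a, b), c);
    return out.findAny();
  }

  public static void main(String[] args) {
    // Prints a non-empty Optional; findAny makes no promise about which element.
    System.out.println(findAnyWithConcat());
  }
}
```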

Also, flatMap forces its input streams into sequential mode even if they were originally parallel. The outermost concatenated stream can still be made parallel, and we will be able to process elements from distinct input streams in parallel, but the elements of each individual input stream must all be processed sequentially.
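
One way to observe this is to record which threads touch the elements. The sketch below is only an illustration with made-up names. It marks both inputs as parallel, leaves the outer stream sequential, and collects thread names from inside flatMap. For comparison, it also checks the documented behavior that Stream.concat(a, b) returns a parallel stream if either input is parallel.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class FlatMapSequentialDemo {
  // Names of the threads that process elements of two parallel inputs joined
  // by flatMap on a sequential outer stream.
  static Set<String> threadsUsedByFlatMap() {
    Set<String> threads = ConcurrentHashMap.newKeySet();
    Stream<Integer> a = IntStream.range(0, 10_000).boxed().parallel();
    Stream<Integer> b = IntStream.range(0, 10_000).boxed().parallel();
    Stream.of(a, b)
        .flatMap(s -> s.peek(x -> threads.add(Thread.currentThread().getName())))
        .forEach(x -> { });
    return threads;
  }

  // Stream.concat keeps the parallel flag if either input is parallel.
  static boolean concatIsParallel() {
    Stream<String> a = Stream.of("one", "two").parallel();
    Stream<String> b = Stream.of("three", "four");
    return Stream.concat(a, b).isParallel();
  }

  public static void main(String[] args) {
    // Expect a single thread name: the thread that ran the terminal operation.
    System.out.println(threadsUsedByFlatMap());
    System.out.println(concatIsParallel());
  }
}
```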

Analysis

Let me share a few trends that I've noticed when dealing with streams and stream concatenation in general, having written a fair amount of code in Java 8 by now.

  • There have been maybe a dozen cases where I've needed to concatenate streams. That's not all that many, so no matter how good the solution is, it's not going to have much of an impact for me.
  • In all but one of those cases, I needed to concatenate exactly two streams, so Stream.concat(a, b) was sufficient.
  • In the remaining case, I needed to concatenate exactly three streams. I was not even close to the point where StackOverflowError would become an issue. Stream.concat(Stream.concat(a, b), c) would have worked just fine, although I went with flatMap because I felt that it was easier to read.
  • I have never needed to concatenate streams in performance-critical sections of code.
  • I use infinite streams very rarely. When I do use them, it is obvious in context that they are infinite. And so concatenating infinite streams together and then asking a question like findAny on the result is just not something that I would be tempted to do. That particular issue with flatMap seems like one that I'll never come across.
  • I use parallel streams very rarely. I think I've only used them twice in production code. Going parallel almost never improves performance, and even when it might, I am unlikely to want that work to run in the singleton ForkJoinPool.commonPool(). The issue with flatMap forcing the input streams to be sequential seems very unlikely to be a real problem for me.
  • Let's suppose that I do want to concatenate parallel streams and have them processed in parallel. If I have eight input streams on an eight core machine, and each stream has roughly the same number of elements, the fact that flatMap forces the individual streams to be sequential will not degrade performance for me at all. All eight cores will be fully utilized, each core processing one of the eight input streams. If I have seven input streams on that same machine, I will see only slightly degraded performance. With six, slightly more degraded, and so on.

What's the takeaway from all this? Here is my advice:

For two input streams, use:
Stream.concat(a, b)

For more than two input streams, use:
Stream.of(a, b, c, ...).flatMap(s -> s)

That solution is good enough...

Overboard

...but what if we're not satisfied with "good enough"? What if we want a solution that's really fast no matter the size and shape of the input and doesn't have any of the quirks of the other solutions?

It is a bit much to inline in a blog article, so take a look at StreamConcatenation.java for the source code.

This implementation is similar to Stream.concat(a, b) in that it uses a custom Spliterator, except this implementation handles any number of input streams.
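
To give a feel for the idea, here is a heavily simplified sketch of a Spliterator that walks any number of inputs one after another. It is not the code from StreamConcatenation.java. The class name is made up, it is sequential-only, it reports no characteristics or size estimate, and it does not forward close handlers. A real implementation would need to handle splitting, characteristics, and sizes properly.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

class MultiConcatSketch {
  // Hypothetical helper: one flat spliterator over all inputs, so there is no
  // tree of nested concat calls and the call depth stays constant.
  static <T> Stream<T> concat(List<? extends Stream<? extends T>> inputs) {
    List<Spliterator<? extends T>> parts = new ArrayList<>();
    for (Stream<? extends T> in : inputs) {
      parts.add(in.spliterator());
    }
    Spliterator<T> joined =
        new Spliterators.AbstractSpliterator<T>(Long.MAX_VALUE, 0) {
          private int index = 0; // input currently being drained

          @Override
          public boolean tryAdvance(Consumer<? super T> action) {
            while (index < parts.size()) {
              if (parts.get(index).tryAdvance(action)) {
                return true;
              }
              index++; // current input is exhausted, move to the next
            }
            return false;
          }

          @Override
          public void forEachRemaining(Consumer<? super T> action) {
            while (index < parts.size()) {
              parts.get(index++).forEachRemaining(action);
            }
          }
        };
    return StreamSupport.stream(joined, false);
  }

  public static void main(String[] args) {
    Stream<String> out = concat(Arrays.asList(
        Stream.of("one", "two"), Stream.of("three", "four"), Stream.of("five", "six")));
    out.forEach(System.out::println);
  }
}
```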

It performs quite well. It does not outperform every other solution in every scenario (flatMap is generally better for very small input streams), but it never performs much worse and it scales nicely with the number and size of the input streams.

Benchmark

I wrote a JMH benchmark to compare the four solutions discussed in this article. The benchmark uses each solution to concatenate a variable number of input streams with a variable number of elements per stream, then iterates over the elements of the concatenated stream. Here is the raw JMH output from my workstation and a prettier visualization of the benchmark results.