Tail-recursive bounded stream of pairs of integers (Scala)?

I'm very new to Scala, so forgive my ignorance! I'm trying to iterate over pairs of integers that are bounded by a maximum. For example, if the maximum is 5, then the iteration should return:

(0, 0), (0, 1), ..., (0, 5), (1, 0), ..., (5, 5)

I've chosen to try and tail-recursively return this as a Stream:

    @tailrec
    def _pairs(i: Int, j: Int, maximum: Int): Stream[(Int, Int)] = {
        if (i == maximum && j == maximum) Stream.empty
        else if (j == maximum) (i, j) #:: _pairs(i + 1, 0, maximum)
        else (i, j) #:: _pairs(i, j + 1, maximum)
    }

Without the tailrec annotation the code works:

scala> _pairs(0, 0, 5).take(11)
res16: scala.collection.immutable.Stream[(Int, Int)] = Stream((0,0), ?)

scala> _pairs(0, 0, 5).take(11).toList
res17: List[(Int, Int)] = List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0), (1,1), (1,2), (1,3), (1,4))

But this isn't good enough for me. The compiler is correctly pointing out that the last line of _pairs is not returning _pairs:

could not optimize @tailrec annotated method _pairs: it contains a recursive call not in tail position
    else (i, j) #:: _pairs(i, j + 1, maximum)
                ^

So, I have several questions:

  • directly addressing the implementation above, how does one tail-recursively return Stream[(Int, Int)]?
  • taking a step back, what is the most memory-efficient way to iterate over bounded sequences of integers? I don't want to iterate over Range because Range extends IndexedSeq, and I don't want the sequence to exist entirely in memory. Or am I wrong? If I iterate over Range.view do I avoid it coming into memory?

In Python (!), all I want is:

In [6]: def _pairs(maximum):
   ...:     for i in xrange(maximum+1):
   ...:         for j in xrange(maximum+1):
   ...:             yield (i, j)
   ...:             

In [7]: p = _pairs(5)

In [8]: [p.next() for i in xrange(11)]
Out[8]: 
[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 4)]

Thanks for your help! If you think I need to read references / API docs / anything else please tell me, because I'm keen to learn.

Headman answered 9/5, 2012 at 23:15 Comment(0)

This is not tail-recursion

Let's suppose you were making a list instead of a stream: (let me use a simpler function to make my point)

def foo(n: Int): List[Int] =
  if (n == 0)
    0 :: Nil
  else
    n :: foo(n - 1)

In the general case in this recursion, after foo(n - 1) returns, the function still has to do something with the list it gets back: it has to prepend another item onto the beginning of that list. So the function can't be tail recursive, because something has to be done to the list after the recursive call.

Without tail recursion, for some large value of n, you run out of stack space.

The usual list solution

The usual solution would be to pass a ListBuffer as a second parameter, and fill that.

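import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

def foo(n: Int): List[Int] = {
  // The buffer is filled front to back, so no reversal is needed at the end.
  @tailrec
  def fooInternal(n: Int, list: ListBuffer[Int]): List[Int] = {
    if (n == 0) {
      list += 0                 // include the final 0, matching the recursive foo
      list.toList
    } else {
      list += n
      fooInternal(n - 1, list)  // nothing happens after this call: a true tail call
    }
  }
  fooInternal(n, new ListBuffer[Int]())
}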

What you're doing is known as "tail recursion modulo cons". Because the pattern is so common, some Lisp and Prolog compilers recognize it and optimize it automatically. Scala's compiler does not.
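For this particular function there is also a way to stay tail-recursive without a mutable buffer: build the list from the other end. Counting up from 0 and prepending each value leaves the accumulator in the right order when the recursion stops. A minimal sketch (`AccumulatorFoo`, `fooAcc` and `loop` are illustrative names, not from the question); `@tailrec` makes the compiler reject it if the recursive call ever leaves tail position.

```scala
import scala.annotation.tailrec

object AccumulatorFoo {
  // Illustrative tail-recursive rewrite of foo: returns List(n, n - 1, ..., 0).
  def fooAcc(n: Int): List[Int] = {
    @tailrec
    def loop(k: Int, acc: List[Int]): List[Int] =
      if (k > n) acc               // every value 0..n has been prepended
      else loop(k + 1, k :: acc)   // prepend first, then recurse: nothing is left to do afterwards
    loop(0, Nil)
  }
}
```

Because loop is compiled into a loop, a call like fooAcc(100000) runs in constant stack space.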

Streams don't need tail recursion

Streams don't need tail recursion to avoid running out of stack space. This is because they use a clever trick to keep from executing the recursive call to foo at the point where it appears in the code: the function call gets wrapped in a thunk, which is only called at the point where you actually try to get the value from the stream. Only one call to foo is active at a time, so the calls never nest.
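Applied to the question, that means the recursive Stream version can simply drop @tailrec. Here is a sketch of the question's _pairs written that way (wrapped in an illustrative `LazyPairs` object so it compiles on its own). One change beyond removing the annotation: the original base case returns Stream.empty as soon as it reaches (maximum, maximum), so that last pair is never emitted; checking `i > maximum` instead includes it. Stream is deprecated from Scala 2.13 in favor of LazyList, which supports #:: in the same way; the sketch keeps Stream to match the question.

```scala
object LazyPairs {
  // The question's _pairs without @tailrec. Each recursive call sits in the
  // by-name tail of #::, so it runs only when that tail is first forced.
  def pairs(i: Int, j: Int, maximum: Int): Stream[(Int, Int)] =
    if (i > maximum) Stream.empty
    // last column of row i: emit it, then continue with row i + 1
    else if (j == maximum) (i, j) #:: pairs(i + 1, 0, maximum)
    // otherwise move one column to the right
    else (i, j) #:: pairs(i, j + 1, maximum)
}
```

Forcing the whole of pairs(0, 0, 300), about 90,000 elements, does not overflow the stack, since each tail is produced by a fresh call that returns immediately.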

I've written a previous answer explaining how the #:: operator works here on Stack Overflow. Here's what happens when you call the following recursive stream function. (It is recursive in the mathematical sense, but it doesn't make a function call from within a function call the way you usually expect.)

def foo(n: Int): Stream[Int] =
  if (n == 0)
    0 #:: Stream.empty
  else
    n #:: foo(n - 1)

You call foo(10), and it returns a stream with one element computed already; the tail is a thunk that will call foo(9) the next time you need an element from the stream. foo(9) is not called right now. Rather, the call is bound to a lazy val inside the stream, and foo(10) returns immediately. When you finally do need the second value in the stream, foo(9) is called; it computes one element and sets the tail of the stream to be a thunk that will call foo(8). foo(9) returns immediately without calling foo(8). And so on...
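To watch the calls happen, here is the same foo with a call log added (`TracedFoo` and `calls` exist only for this illustration). Right after foo(10) returns, only one call has been made; each further call happens the first time the corresponding tail is forced, and forcing that tail again reuses the memoized result instead of calling foo a second time.

```scala
import scala.collection.mutable.ArrayBuffer

object TracedFoo {
  // Records the argument of every call to foo, in the order the calls happen.
  val calls = ArrayBuffer.empty[Int]

  def foo(n: Int): Stream[Int] = {
    calls += n
    if (n == 0) 0 #:: Stream.empty
    else n #:: foo(n - 1)   // foo(n - 1) is captured by-name here, not called
  }
}
```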

This allows you to create infinite streams without running out of memory, for example:

def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)

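(Be careful what operations you call on this stream. map is fine, because it returns another lazy Stream, but forcing operations such as foreach, toList or length will never terminate, and if something still holds a reference to the head of the stream, the memoized elements will eventually fill up your whole heap. Using take is a good way to work with an arbitrary prefix of the stream.)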
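A short sketch of working with countUp safely. map on a Stream is itself lazy, so it returns immediately even on an infinite stream; take bounds how many elements a later toList forces. (`Counting`, `evens` and `firstFiveEvens` are illustrative names.)

```scala
object Counting {
  def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)

  // map only computes the head eagerly; the rest of the mapped stream stays lazy.
  val evens: Stream[Int] = countUp(0).map(_ * 2)

  // take(5) limits the stream to five elements, so toList terminates.
  val firstFiveEvens: List[Int] = evens.take(5).toList
}
```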

A simpler solution altogether

Instead of dealing with recursion and streams, why not just use Scala's for loop?

def pairs(maximum:Int) =
  for (i <- 0 to maximum;
       j <- 0 to maximum)
    yield (i, j)

This materializes the entire collection in memory, and returns an IndexedSeq[(Int, Int)].
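For reference, the compiler rewrites this for comprehension into flatMap and map calls, which is why the result is a strict collection: Range.flatMap builds all of its output immediately. A sketch of both forms side by side (the object and method names are just for this comparison):

```scala
object Desugared {
  def pairsFor(maximum: Int): IndexedSeq[(Int, Int)] =
    for (i <- 0 to maximum; j <- 0 to maximum) yield (i, j)

  // What the for comprehension above expands to: an outer flatMap and an inner map.
  def pairsDesugared(maximum: Int): IndexedSeq[(Int, Int)] =
    (0 to maximum).flatMap(i => (0 to maximum).map(j => (i, j)))
}
```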

If you need a Stream specifically, you can convert the first range into a Stream.

def pairs(maximum:Int) =
  for (i <- (0 to maximum).toStream;
       j <- 0 to maximum)
    yield (i, j)

This will return a Stream[(Int, Int)]. When you access a certain point in the sequence, it will be materialized into memory, and it will stick around as long as you still have a reference to any point in the stream before that element.
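A small sketch of that memoization, with an illustrative `evaluations` counter added to the yield body. Traversing the stream twice runs the body only once per pair, since the second traversal reads cached cells. Note that each row arrives as a unit: the inner `0 to maximum` is still a strict Range, so forcing the first pair of a row computes that whole row.

```scala
object MemoizedPairs {
  // Counts how many times the yield body runs (instrumentation for this sketch only).
  var evaluations = 0

  def pairs(maximum: Int): Stream[(Int, Int)] =
    for (i <- (0 to maximum).toStream; j <- 0 to maximum)
      yield { evaluations += 1; (i, j) }
}
```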

You can get even better memory usage by converting both ranges into views.

def pairs(maximum:Int) =
  for (i <- (0 to maximum).view;
       j <- (0 to maximum).view)
    yield (i, j)

That returns a SeqView[(Int, Int),Seq[_]] that computes each element each time you need it, and doesn't store precomputed results.
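For contrast with the Stream version, a sketch with the same kind of illustrative counter: building the view evaluates nothing, and every traversal recomputes all of the pairs rather than caching them. The static type of the result differs between Scala versions; the evaluation behavior shown here is the same in 2.12 and 2.13.

```scala
object ViewPairs {
  // Counts how many times the yield body runs (instrumentation for this sketch only).
  var evaluations = 0

  def pairs(maximum: Int) =
    for (i <- (0 to maximum).view; j <- (0 to maximum).view)
      yield { evaluations += 1; (i, j) }
}
```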

You can also get an iterator (which you can only traverse once) the same way.

def pairs(maximum:Int) =
  for (i <- (0 to maximum).iterator;
       j <- (0 to maximum).iterator)
    yield (i, j)

That returns Iterator[(Int, Int)].
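This is the closest counterpart to the Python generator in the question. A sketch (with illustrative names) that pulls elements one at a time with next(), like the question's IPython session; elements consumed this way are gone, which is the "traverse once" restriction in practice.

```scala
object IteratorPairs {
  def pairs(maximum: Int): Iterator[(Int, Int)] =
    for (i <- (0 to maximum).iterator; j <- (0 to maximum).iterator)
      yield (i, j)

  // Mirrors the question's Python session: pull eleven pairs, one next() at a time.
  def firstEleven(maximum: Int): List[(Int, Int)] = {
    val p = pairs(maximum)
    List.fill(11)(p.next())
  }
}
```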

Supper answered 9/5, 2012 at 23:26 Comment(8)
Thank you for your answer! I understand why what I did isn't tail recursive, and I'd definitely prefer to use for. The problem I have is that pairs, as you've suggested, returns IndexedSeq. Hence the whole result will exist in memory when pairs is called. Could you please elaborate on how to use views to avoid this?Headman
And do you have more details and references about Streams and thunks? I'm very curious about how I'm not going to blow the stack by recursively calling a non-tail-call optimised function where I don't use coroutines. So much to learn!Headman
+1 for the nice answer. Just one more remark: You can actually safely call map on the countUp stream, because the result will be a Stream again. Only the foreach call will have eager evaluation.Monitory
Wow, I really had no idea how Range works. Checking out the source code, github.com/scala/scala/blob/master/src/library/scala/collection/…, it's clear that they're lazy. Hence both (0 to 10) and (0 to 10000000) have the same memory occupancy (three Ints). Hence Range view or Range iterator are delightful answers, where Iterator tells callers "you can traverse the result once", and View tells callers "you can treat this like a real collection".Headman
@AsimIhsan: that's correct. Range.map, however, materializes the whole collection, and that's what's going on in the for loop without calling view or iterator first. (Scala 2.7 used to perform Range.map lazily, but that behavior was found to be surprising and too confusing, so it was changed.)
@AsimIhsan: I just looked in the code. A Range has more like 7 Ints in its internal representation, but yeah it's still constant space.Supper
@Ken Bloom Whoops! I can't read :). I misread the following sentence on this page docs.scala-lang.org/overviews/collections/…: "Ranges are represented in constant space, because they can be defined by just three numbers: their start, their end, and the stepping value." I mistook constant space for three Ints. Thanks for clearing this up, and your amazing answer.Headman
You just saved Scala for me, thanks, I had a hard time finding out about .iterator()...Howlan

Maybe an Iterator is better suited for you?

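class PairIterator(max: Int) extends Iterator[(Int, Int)] {
  private val side = max + 1   // each coordinate runs from 0 to max inclusive
  private var count = 0        // index of the next pair to emit
  def hasNext: Boolean = count < side * side
  def next(): (Int, Int) = {
    if (!hasNext) throw new NoSuchElementException("next on exhausted PairIterator")
    val pair = (count / side, count % side)
    count += 1
    pair
  }
}

val pi = new PairIterator(5)
pi.take(7).toList   // List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0))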
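The same index arithmetic can also be handed to the standard library's Iterator.tabulate, which yields f(0), f(1), ..., f(end - 1) lazily. A sketch, with `TabulatedPairs` as an illustrative name: pair number k of the grid sits at row k / side and column k % side, where side = maximum + 1 is the number of values per coordinate.

```scala
object TabulatedPairs {
  // Pair number k of the (maximum + 1) x (maximum + 1) grid, in row-major order.
  def pairs(maximum: Int): Iterator[(Int, Int)] = {
    val side = maximum + 1
    Iterator.tabulate(side * side)(k => (k / side, k % side))
  }
}
```

Like the class above, this counts pairs in an Int, so side * side overflows once maximum reaches 46,340; the nested-iterator for comprehension has no such limit.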
Stencil answered 10/5, 2012 at 1:6 Comment(1)
By the way thanks for sharing this with me. I've been using Iterators for a lot of other problems and this is the only full example I can find!Headman
