Is it possible to use continuations to make foldRight tail recursive?
S

4

10

The following blog article shows how in F# foldBack can be made tail recursive using continuation passing style.

In Scala this would mean that:

def foldBack[T,U](l: List[T], acc: U)(f: (T, U) => U): U = {
  l match {
    case x :: xs => f(x, foldBack(xs, acc)(f))
    case Nil => acc
  }
} 

can be made tail recursive by doing this:

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    l match {
      case x :: xs => loop(xs, (racc => k(f(x, racc))))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

Unfortunately, I still get a stack overflow for long lists. loop is tail recursive and gets optimized, but I guess the stack accumulation is just moved into the continuation calls.

Why is this not a problem with F#? And is there any way to work around this with Scala?

Edit: here is some code that shows the stack depth:

def showDepth(s: Any) {
  println(s.toString + ": " + (new Exception).getStackTrace.size)
}

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    showDepth("loop")
    l match {
      case x :: xs => loop(xs, (racc => { showDepth("k"); k(f(x, racc)) }))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

foldCont(List.fill(10)(1), 0)(_ + _)

This prints:

loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
k: 51
k: 52
k: 53
k: 54
k: 55
k: 56
k: 57
k: 58
k: 59
k: 60
res2: Int = 10
Sophistic answered 18/12, 2011 at 2:32 Comment(2)
It doesn't make sense. Do you have a simple test case?Greig
@DanielC.Sobral, see the code I added and that it prints.Sophistic
F
4

The problem is the continuation function (racc => k(f(x, racc))) itself. It would need to be tail-call optimized for this whole scheme to work, but it isn't.

Scala cannot perform tail-call optimization for arbitrary tail calls, only for those it can transform into loops (i.e. when a function calls itself directly, not some other function).
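
To illustrate (a sketch of my own, not part of the original answer): tail calls between two different methods are not optimized, and @annotation.tailrec will refuse them, because neither method calls itself directly:

// Mutually recursive tail calls: valid Scala, but each call still consumes a
// stack frame. Adding @annotation.tailrec to either method is a compile error,
// since the method never calls itself directly.
object Parity {
  def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
  def isOdd(n: Int): Boolean  = if (n == 0) false else isEven(n - 1)
}

// Parity.isEven(10000000) overflows the stack for the same reason foldCont does.

The continuation racc => k(f(x, racc)) is in the same situation: loop's call to itself is turned into a loop, but when the final k(acc) runs, each k invokes the next k as an ordinary call, which is exactly the growing depth the showDepth output shows.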

Fletcherfletcherism answered 18/12, 2011 at 4:34 Comment(3)
That's what I guessed. Is there anything that can be done? Like using something like trampolines?Sophistic
Trampolines will probably help, but I think in this particular case a plain foldLeft would solve the problem with far less pain. If for some reason you absolutely must have foldRight semantics, you can reverse the list and call foldLeft on the result.Fletcherfletcherism
Turns out it's really not that painful in this case, see my own answer.Sophistic
S
6

Jon, n.m., thank you for your answers. Based on your comments, I thought I'd give trampolines a try. A bit of research shows Scala has library support for trampolines in scala.util.control.TailCalls. Here is what I came up with after a bit of fiddling around:

def foldContTC[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  import scala.util.control.TailCalls._
  @annotation.tailrec
  def loop(l: List[T], k: (U) => TailRec[U]): TailRec[U] = {
    l match {
      case x :: xs => loop(xs, (racc => tailcall(k(f(x, racc)))))
      case Nil => k(acc)
    }
  }
  loop(list, u => done(u)).result
} 
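
For intuition (my own sketch, not part of this answer's code), a trampoline replaces nested calls with small "bounce" values that a single tail-recursive loop keeps running, so the pending work lives on the heap instead of the call stack; TailRec, tailcall, done and result play roughly this role in scala.util.control.TailCalls:

// A minimal hand-rolled trampoline, only to illustrate the idea behind
// TailCalls (the real library is richer and better optimized than this).
sealed trait Bounce[A]
case class Finished[A](value: A) extends Bounce[A]
case class More[A](thunk: () => Bounce[A]) extends Bounce[A]

@annotation.tailrec
def run[A](bounce: Bounce[A]): A = bounce match {
  case Finished(v) => v            // no more work: return the value
  case More(thunk) => run(thunk()) // one more bounce, in constant stack space
}

In foldContTC, tailcall(...) corresponds roughly to More and done(...) to Finished; calling .result runs the loop, so each pending continuation is invoked from a flat stack rather than as a nested call.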

I was interested to see how this compares to the solution without the trampoline as well as the default foldLeft and foldRight implementations. Here is the benchmark code and some results:

val size = 1000
val list = List.fill(size)(1)
val warm = 10
val n = 1000
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldCont", warm, lots(n, foldCont(list, 0)(_ + _)))
bench("foldRight", warm, lots(n, list.foldRight(0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))

The timings are:

foldContTC: warming...
Elapsed: 0.094
foldCont: warming...
Elapsed: 0.060
foldRight: warming...
Elapsed: 0.160
foldLeft: warming...
Elapsed: 0.076
foldLeft.reverse: warming...
Elapsed: 0.155

Based on this, it would seem that trampolining actually yields pretty good performance. I suspect the trampolining penalty, on top of the boxing/unboxing, is not that bad.

Edit: as suggested in Jon's comments, here are the timings on 1M items, which confirm that performance degrades with larger lists. I also found out that the library List.foldLeft implementation is not overridden, so I timed against the following foldLeft2 as well:

def foldLeft2[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  list match {
    case x :: xs => foldLeft2(xs, f(x, acc))(f)
    case Nil => acc
  }
} 

val size = 1000000
val list = List.fill(size)(1)
val warm = 10
val n = 2
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft2", warm, lots(n, foldLeft2(list, 0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))
bench("foldLeft2.reverse", warm, lots(n, foldLeft2(list.reverse, 0)(_ + _)))

yields:

foldContTC: warming...
Elapsed: 0.801
foldLeft: warming...
Elapsed: 0.156
foldLeft2: warming...
Elapsed: 0.054
foldLeft.reverse: warming...
Elapsed: 0.808
foldLeft2.reverse: warming...
Elapsed: 0.221

So foldLeft2.reverse is the winner...
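
If all you need is foldRight semantics on long lists, the numbers above suggest the simplest stack-safe route is reverse-then-foldLeft; a small sketch (my own packaging of the approach, not code from the original benchmark):

// foldRight semantics via reverse + the tail-recursive foldLeft2 above;
// stack-safe because foldLeft2 only ever calls itself in tail position.
def foldRightViaReverse[T, U](list: List[T], acc: U)(f: (T, U) => U): U =
  foldLeft2(list.reverse, acc)(f)

foldRightViaReverse(List.fill(1000000)(1), 0)(_ + _)   // 1000000, no stack overflow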

Sophistic answered 18/12, 2011 at 20:36 Comment(2)
"pretty good performance". Indeed. I'd call that suspiciously good performance! Perhaps the trampoline implementation is clever enough to realize that it doesn't have to kick in because your lists are so short? What performance measurements do you get with 1M-element lists?Zaller
With timings that close, cache and GC issues will also come into play, e.g. reversing the same 1k-element list over and over is cheap with a generational GC and cache efficient but 1M-element lists are likely to survive the nursery or thread-local region which will incur its overhead and degrade cache efficiency.Zaller
Z
4

Why is this not a problem with F#?

F# has all tail calls optimized.

And is there any way to work around this with Scala?

You can do TCO using other techniques like trampolines but you lose interop because it changes the calling convention and it is ~10× slower. This is one of the three reasons I don't use Scala.
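
To make the calling-convention point concrete (a sketch of my own, not the answerer's code): with scala.util.control.TailCalls a trampolined function returns TailRec[U] instead of U, so every caller either unwraps it with .result or has to be rewritten in the same style:

import scala.util.control.TailCalls._

// Mutually recursive tail calls, made stack-safe by returning bounces
// instead of calling each other directly.
def evenT(n: Int): TailRec[Boolean] =
  if (n == 0) done(true) else tailcall(oddT(n - 1))
def oddT(n: Int): TailRec[Boolean] =
  if (n == 0) done(false) else tailcall(evenT(n - 1))

evenT(1000000).result   // true, without growing the stack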

EDIT

Your benchmark results indicate that Scala's trampolines are a lot faster than they were the last time I tested them. It would also be interesting to add equivalent benchmarks using F# and to use larger lists (because there's no point in doing CPS on short lists!).

For 1,000x on a 1,000-element list on my netbook with a 1.67GHz N570 Intel Atom, I get:

List.fold     0.022s
List.rev+fold 0.116s
List.foldBack 0.047s
foldContTC    0.334s

For 1x 1,000,000-element list, I get:

List.fold     0.024s
List.rev+fold 0.188s
List.foldBack 0.054s
foldContTC    0.570s

You may also be interested in the old discussions about this on the caml-list in the context of replacing OCaml's non-tail-recursive list functions with optimized tail recursive ones.

Zaller answered 18/12, 2011 at 15:14 Comment(2)
What are the other two reasons you don't use Scala?Rations
@StephenSwensen: Lack of value types and lack of type inference. Note that the lack of tail calls and value types is a problem with the JVM and not Scala. These are also the reasons why I chose to develop HLVM on LLVM rather than the JVM. Geoff Reedy's project to port Scala to LLVM has the potential to fix both of these problems, which would be absolutely awesome.Zaller
I
3

I'm late to this question, but I wanted to show how you can write a tail-recursive foldRight without using a full trampoline: accumulate a list of continuations (instead of having them call each other when done, which leads to a stack overflow) and fold over them at the end, which is a bit like keeping the stack on the heap:

object FoldRight {

  def apply[A, B](list: Seq[A])(init: B)(f: (A, B) => B): B = {
    @scala.annotation.tailrec
    def step(current: Seq[A], conts: List[B => B]): B = current match {
      // Last element: seed with f(last, init), then run the accumulated
      // continuations; conts.foldLeft is a loop, so this stays stack-safe.
      case Seq(last) => conts.foldLeft(f(last, init)) { (acc, next) => next(acc) }
      // More elements: push a continuation for x onto a heap-allocated list
      // instead of nesting calls on the stack.
      case Seq(x, xs @ _*) => step(xs, { acc: B => f(x, acc) } +: conts)
      // Empty input (Seq() rather than Nil, since the parameter is a general Seq).
      case Seq() => init
    }
    step(list, Nil)
  }

}

The fold that happens at the end is itself tail-recursive. Try it out in ScalaFiddle
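
A quick usage check (my own addition, not from the original answer), showing it handles lists far larger than the stack would allow and preserves right-fold order:

val big = List.fill(1000000)(1)
FoldRight(big)(0)(_ + _)                            // 1000000, no stack overflow
FoldRight(List(1, 2, 3))(List.empty[Int])(_ :: _)   // List(1, 2, 3)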

In terms of performance, it performs slightly worse than the tail call version.

[info] Benchmark            (length)  Mode  Cnt   Score    Error  Units
[info] FoldRight.conts           100  avgt   30   0.003 ±  0.001  ms/op
[info] FoldRight.conts         10000  avgt   30   0.197 ±  0.004  ms/op
[info] FoldRight.conts       1000000  avgt   30  77.292 ±  9.327  ms/op
[info] FoldRight.standard        100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.standard      10000  avgt   30   0.154 ±  0.036  ms/op
[info] FoldRight.standard    1000000  avgt   30  18.796 ±  0.551  ms/op
[info] FoldRight.tailCalls       100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.tailCalls     10000  avgt   30   0.176 ±  0.004  ms/op
[info] FoldRight.tailCalls   1000000  avgt   30  33.525 ±  1.041  ms/op
Inception answered 22/9, 2016 at 7:34 Comment(0)
