Using Scala continuations with while loops

I realize this is the opposite of the usual SO question, but the following code works even though I think it should not. Below is a small Scala program that uses continuations with a while loop. According to my understanding of continuation-passing style, this code should overflow the stack, because each iteration of the while loop should add a frame. However, it works just fine.

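import scala.util.continuations.{shift, reset}

// Needs the Scala continuations compiler plugin enabled.

class InfiniteCounter extends Iterator[Int] {
    var count = 0
    var callback: Unit=>Unit = null   // the most recently captured continuation
    reset {
        while (true) {
            // Capture "the rest of the loop" as f, store it, and return out of reset.
            shift {f: (Unit=>Unit) =>
                callback = f
            }
            // Runs only when the stored continuation is invoked from next().
            count += 1
        }

    }

    def hasNext: Boolean = true

    def next(): Int = {
        callback()   // resume the loop: count += 1, then stop at the next shift
        count
    }
}

object Experiment3 {

    def main(args: Array[String]): Unit = {
        val counter = new InfiniteCounter()
        println(counter.next())
        println("Hello")
        println(counter.next())
        for (i <- 0 until 100000000) {
            counter.next()
        }
        println(counter.next())
    }

}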

The output is:

1
Hello
2
100000003

My question is: why is there no stack overflow? Is the Scala compiler doing tail call optimization (which I thought it couldn't do with continuations) or is there some other thing going on?

(This experiment is on github along with the sbt configuration needed to run it: https://github.com/jcrudy/scala-continuation-experiments. See commit 7cec9befcf58820b925bb222bc25f2a48cbec4a6)
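For contrast, here is a minimal sketch (not from the question) of the situation the question expects: a hand-written CPS loop in which each iteration invokes its continuation immediately instead of storing it, roughly what `shift { k => k(()) }` would do inside the while loop. `NestedCps` and `loop` are hypothetical names, and the stack depth is read with `Thread.currentThread.getStackTrace` purely to show frames accumulating.

```scala
// Hypothetical contrast: the continuation is run immediately, not stored.
// Each iteration therefore executes *inside* the previous one, so frames pile up.
object NestedCps {
  // Runs iterations i, i + 1, ..., limit - 1 and returns the JVM stack depth
  // observed once the last iteration has been reached.
  def loop(i: Int, limit: Int): Int = {
    // The continuation: "the rest of the loop", starting at iteration i + 1.
    val k: () => Int = () => loop(i + 1, limit)
    if (i < limit) k() // run the rest of the loop now, before this frame returns
    else Thread.currentThread.getStackTrace.length
  }
}
```

The call to `loop` happens inside a closure rather than as a direct self-call, so scalac's self-recursive tail-call optimization does not apply. With a limit in the millions this would be expected to end in a `StackOverflowError` on a default JVM thread stack.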

Legislator asked 21/12/2013 at 1:38

You don't get a stack overflow here because the way you're using shift and callback() acts like a trampoline.

Each time execution reaches the shift construct, it sets callback to the current continuation (a closure) and then immediately returns Unit out of the enclosing reset to whoever started it: the constructor the first time, next() afterwards. When you call next() and it invokes callback(), you execute the continuation closure, which runs count += 1, jumps back to the top of the loop, and executes the shift again.
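These mechanics can be sketched without the continuations plugin, which is not available for recent Scala versions. `ManualCounter` below is a hypothetical, hand-desugared analogue of `InfiniteCounter`, not the plugin's actual output: `loop()` stands in for the while loop up to the shift, storing "the rest of the loop" in `callback` and then returning at once.

```scala
// Hypothetical hand-desugared analogue of InfiniteCounter (no continuations plugin).
class ManualCounter extends Iterator[Int] {
  var count = 0
  private var callback: () => Unit = () => ()

  // Plays the role of the while loop up to the shift.
  private def loop(): Unit = {
    // Capture the continuation: increment the counter, then reach the "shift" again.
    callback = () => { count += 1; loop() }
    // Then return to whoever called us (the constructor or next()), like shift leaving reset.
  }

  loop() // plays the role of the reset block running in the constructor

  def hasNext: Boolean = true

  def next(): Int = {
    callback() // runs count += 1, re-enters loop(), which stores a new callback and returns
    count
  }
}
```

Each `next()` call runs exactly one continuation and gets control back before the next one starts, so the caller's for loop acts as the trampoline's driver and the call depth is the same on every iteration.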

One of the key benefits of the CPS transformation is that it captures the flow of control in the continuation rather than on the stack. When you set callback = f on each "iteration", you overwrite your only reference to the previous continuation (the previous state of the function), which allows it to be garbage collected.

The stack here only ever reaches a depth of a few frames (probably around 10, because of all the nested closures). Each time you execute the shift, it captures the current state in a closure on the heap, and then the stack unwinds back to your for expression.

I feel like a diagram would make this clearer, but stepping through the code with your debugger would probably be just as useful. The key point is that, since you've essentially built a trampoline, you'll never blow the stack.
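The same trampoline idea is packaged in the Scala standard library as `scala.util.control.TailCalls`. The sketch below does not use continuations; it swaps in that library to show the general pattern: each step returns a `TailRec` value describing the next call instead of making it, and `.result` runs those steps one at a time in a loop. `TrampolineDemo`, `isEven` and `isOdd` are illustrative names.

```scala
import scala.util.control.TailCalls._

// Mutually recursive even/odd check, trampolined with the standard library's TailCalls.
object TrampolineDemo {
  // done(...) ends the computation; tailcall(...) describes the next step lazily,
  // so neither function calls the other directly on the stack.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}
```

`TrampolineDemo.isEven(1000000).result` evaluates to `true` without growing the stack, whereas a directly recursive pair of functions would be expected to overflow at that depth.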

Romero answered 21/12/2013 at 19:06
Legislator: I think that's a great explanation for something that's pretty confusing. I may try to sketch this out visually, and I'll post a link here if I create something illustrative. To clarify, this means that not only will this construction not blow the stack, but the total memory requirement is also modest and does not depend on the number of "iterations"?
Romero: @Legislator Yes, the memory requirement is independent of the number of iterations in the for expression. As a simple test I ran your code like this: JAVA_OPTS="-Xmx2M -verbose:gc" scala Experiment3 (this works on my laptop). This shows that the 100M iterations run with just 2 MB of heap space. Adding the verbose garbage-collection option also lets you see that the actual memory usage stays roughly constant throughout the execution.
Legislator: There is a nice, illustrated blog post by Rich Dougherty on trampolines in Scala: blog.richdougherty.com/2009/04/…
