Free ~> Trampoline: recursive program crashes with OutOfMemoryError

Suppose that I'm trying to implement a very simple domain specific language with only one operation:

printLine(line)

Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10000, and then calls itself with n + 1, until n exceeds some maximum value N.

Omitting all syntactic noise caused by for-comprehensions, what I want is:

@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}

Essentially, it would be a kind of "fizzbuzz".

Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:

import scalaz._

object Demo1 {

  // define operations of a little domain specific language
  sealed trait Lang[X]
  case class PrintLine(line: String) extends Lang[Unit]

  // define the domain specific language as the free monad of operations
  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}

  // lift operations into the free monad
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  // write a program that is just a loop that prints current index 
  // after every few iteration steps
  val mod =  100000
  val N =   1000000

  // straightforward syntax: deadly slow, exits with OutOfMemoryError
  def p0(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- (if (i > N) ret else p0(i + 1))
  } yield ()

  // Same as above, but written out without `for`
  def p1(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
    }

  // Same as above, with `map` attached to recursive call
  def p2(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
    }

  // Same as above, but without the `map`; performs ok.
  def p3(i: Int): Prog[Unit] = {
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ 
      ignore1 =>
      if (i > N) ret else p3(i + 1)
    }
  }

  // Variation of the above; Ok.
  def p4(i: Int): Prog[Unit] = (for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
  } yield ()).flatMap{ ignored2 => 
    if (i > N) ret else p4(i + 1) 
  }

  // try to use the variable returned by the last generator after yield,
  // hope that the final `map` is optimized away (it's not optimized away...)
  def p5(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    stopHere <- (if (i > N) ret else p5(i + 1))
  } yield stopHere

  // define an interpreter that translates the programs into Trampoline
  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]  
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case PrintLine(l) => Trampoline.delay(println(l))
    }
  }

  // try it out
  def main(args: Array[String]): Unit = {
    println("\n p0")
    p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p1")
    p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p2")
    p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p3")
    p3(0).foldMap(interpreter).run // ok 
    println("\n p4")
    p4(0).foldMap(interpreter).run // ok
    println("\n p5")
    p5(0).foldMap(interpreter).run // OutOfMemory
  }
}

Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0. This forces the Free monad to fill the entire memory with reminders to "finish p0 and then do nothing". If I manually "unroll" the for-comprehension and write out the last flatMap explicitly (as in p3 and p4), the problem goes away and everything runs smoothly.

This, however, is an extremely brittle workaround. The behavior of the program changes dramatically if we simply append a map(id) to it. Worse, this map(id) isn't even visible in the code, because the for-comprehension generates it automatically.
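One quick way to see the generated map is to run the same two-generator shape through a container that logs its combinator calls. The `Logged` type and `step` helper below are hypothetical (not part of Scalaz). The sketch assumes the standard Scala 2 desugaring of `for`, under which `for { _ <- a; _ <- b } yield ()` becomes `a.flatMap(_ => b.map(_ => ()))`:

```scala
object Desugar {
  // Hypothetical container that records which combinators were called, in order.
  final case class Logged[A](value: A, log: Vector[String]) {
    def flatMap[B](f: A => Logged[B]): Logged[B] = {
      val next = f(value)
      Logged(next.value, (log :+ "flatMap") ++ next.log)
    }
    def map[B](f: A => B): Logged[B] = Logged(f(value), log :+ "map")
  }

  // A trivial step that only records its own name.
  def step(name: String): Logged[Unit] = Logged((), Vector(name))

  // Same shape as p0: two generators whose results are ignored, then `yield ()`.
  val viaFor: Logged[Unit] = for {
    _ <- step("first")
    _ <- step("second")
  } yield ()

  // The hand-written equivalent: note the trailing map on the last generator.
  val desugared: Logged[Unit] =
    step("first").flatMap(_ => step("second").map(_ => ()))
}
```

Both values log `first, flatMap, second, map`. That final `map` is the one the for-comprehension wraps around the recursive call in p0 and p5.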

In this older post, https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/, it is repeatedly advised to wrap recursive calls in a suspend. Here is an attempt with an Applicative instance and suspend:

import scalaz._

// Essentially same as in `Demo1`, but this time with 
// an `Applicative` and an explicit `Suspend` in the 
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]

  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) => 
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}

Inserting suspend did not really help: it's still slow, and crashes with OutOfMemoryErrors.

Should I be using suspend differently?

Is there perhaps some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map at the end?

I'd really appreciate it if someone could point out what I'm doing wrong here, and how to fix it.

Asked by Amphipod on 13/12/2016 at 14:1. Comments:
Hi, I copied and ran your code, and it was neither slow nor did I get an OutOfMemoryError. When I increased N tenfold, it got slower compared with the naive tailrec solution. That is expected, because you should get O(N*N) instead of the O(N) of the tailrec version. But there was still no OOM error. (Phonemic)
It probably depends on JVM settings and hardware. If you don't see the effect right away, try something like n = 1000000, N = 10000000 instead. On my laptop, some of the programs run noticeably slower and fail with an OutOfMemoryError for N = 5000000. But you should see the slowdown even for smaller values of N. (Amphipod)

The superfluous map added by the Scala compiler moves the recursion from tail position to non-tail position. The Free monad still makes this stack-safe, but the space complexity becomes O(N) instead of O(1). (In particular, it is still not O(N^2).)

Whether it is possible to make scalac optimize that map away makes for a separate question (which I don't know the answer to).

I will try to illustrate what happens when interpreting p1 versus p3. I will ignore the translation to Trampoline, which is redundant (see below).

p3 (i.e. without extra map)

Let me use the following shorthand:

def cont(i: Int): Unit => Prog[Unit] =
  ignore1 => if (i > N) ret else p3(i + 1)

Now p3(0) is interpreted as follows:

p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)

and so on... You see that the amount of memory needed at any point doesn't exceed some constant upper bound.

p1 (i.e. with extra map)

I will use the following shorthands:

def cont(i: Int): Unit => Prog[Unit] =
  ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

def cpu: Unit => Prog[Unit] = // constant pure unit
  ignore => Free.pure(())

Now p1(0) is interpreted as follows:

p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

and so on... You see that the memory consumption depends linearly on N. We just moved the evaluation from stack to heap.
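The same contrast can be reproduced without Scalaz. The sketch below is a simplified model, not Scalaz's implementation. It uses a hand-rolled, Free-like `Prog` type with hypothetical constructors `Pure`, `PrintLine` and `FlatMap`. Its interpreter `run` keeps pending continuations in an explicit list on the heap, instead of reassociating nested flatMaps as Scalaz's Free does, and reports the largest number of continuations pending at once. `suspend` is likewise an assumed model (a trivial step followed by flatMap), used to mimic p7 from the question:

```scala
// Simplified, hand-rolled model of a Free program (NOT Scalaz's Free).
object FreeSpace {
  sealed trait Prog[+A] {
    def flatMap[B](f: A => Prog[B]): Prog[B] = FlatMap(this, f)
    // map is implemented via flatMap, as in the trace above.
    def map[B](f: A => B): Prog[B] = flatMap(a => Pure(f(a)))
  }
  final case class Pure[+A](a: A) extends Prog[A]
  final case class PrintLine(line: String) extends Prog[Unit]
  final case class FlatMap[X, +A](sub: Prog[X], k: X => Prog[A]) extends Prog[A]

  def ret: Prog[Unit] = Pure(())
  def printLine(l: String): Prog[Unit] = PrintLine(l)
  // Assumed model of suspend: a trivial step, then the deferred program.
  def suspend[A](p: => Prog[A]): Prog[A] = ret.flatMap(_ => p)

  // Runs a program with an explicit continuation stack kept on the heap.
  // Returns the printed lines and the peak number of pending continuations.
  def run(prog: Prog[Any]): (Vector[String], Int) = {
    var current: Prog[Any] = prog
    var stack: List[Any => Prog[Any]] = Nil
    var depth = 0
    var maxDepth = 0
    var done = false
    val out = Vector.newBuilder[String]

    // Pass a value to the most recent continuation, or finish if none is left.
    def resume(value: Any): Unit = stack match {
      case k :: rest =>
        stack = rest
        depth -= 1
        current = k(value)
      case Nil =>
        done = true
    }

    while (!done) {
      current match {
        case FlatMap(sub, k) =>
          // Remember what to do after `sub`, then evaluate `sub` first.
          stack = k.asInstanceOf[Any => Prog[Any]] :: stack
          depth += 1
          maxDepth = math.max(maxDepth, depth)
          current = sub
        case Pure(a) =>
          resume(a)
        case PrintLine(l) =>
          out += l // collect instead of printing, so the result can be checked
          resume(())
      }
    }
    (out.result(), maxDepth)
  }

  val mod = 10000
  val N = 100000

  // p3-style: the recursive call is the entire result of the flatMap body.
  def p3(i: Int): Prog[Unit] =
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap { _ =>
      if (i > N) ret else p3(i + 1)
    }

  // p1-style: a trailing map wraps the recursive call.
  def p1(i: Int): Prog[Unit] =
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap { _ =>
      (if (i > N) ret else p1(i + 1)).map(_ => ())
    }

  // p7-style: suspend the second generator; `yield ()` still adds a map.
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- suspend(if (i > N) ret else p7(i + 1))
  } yield ()
}
```

In this model, the p3-style loop never has more than one pending continuation. The p1-style loop keeps one `map` continuation per iteration, so its peak exceeds N. Suspending the recursive call, as in p7, only inserts an extra step in front of it. The trailing map is still there, so that peak grows with N as well.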

Takeaway: To keep Free memory-friendly, keep the recursion in "tail position", that is, on the right-hand side of flatMap (or map).

Aside: The translation to Trampoline is not necessary, since Free is already trampolined. You could interpret directly to Id and use foldMapRec for stack-safe interpretation:

val idInterpreter = new (Lang ~> Id) {
  def apply[A](cmd: Lang[A]): Id[A] = cmd match {
    case PrintLine(l) => println(l)
  }
}

p0(0).foldMapRec(idInterpreter)

This will save you some memory, but it doesn't make the problem go away.

Answered by Thermoluminescent on 13/12/2016 at 18:53. Comments:
Thank you very much for the detailed illustration; it confirms my intuition that p0 leaves an O(N) trail in memory while running. I wasn't sure about the time overhead. With some older implementations of Free, which used F: Functor to append the next operation to the end of a linked-list-like structure, I could imagine that it could actually be O(N^2). I'll have to look at the current implementation and think about it again. (Amphipod)
On "Aside": I used Trampoline just for illustration; I'll probably use something else as the "interpretation target". The reasoning "translation to Trampoline is not necessary, since Free is already trampolined" seems to differ from the answer to this question: #29660567 (Amphipod)
On "optimizing map away": maybe something like def noMap[X](x: X) = new { def map(f: Unit => Unit): X = x }, and then wrap the last generator in noMap? It works, and it eliminates the last map produced by for. But it would be nicer to use something more conventional, if it already exists somewhere (in Scalaz or elsewhere). (Amphipod)
@AndreyTyukin Yes, there was an implementation of Free that depended on Functor and suffered from quadratic time complexity. That has been improved since. The answer you linked to predates foldMapRec. I failed to emphasize the importance of foldMapRec for stack-safe interpretation to Id. That's a clever trick with noMap! I'm not aware of a "standard" solution, but I would be interested in hearing if you find one. (Thermoluminescent)
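The sketch below is a class-based variant of the `noMap` trick from the comments. It uses a hypothetical logging container (`Logged`, `step`) rather than Scalaz's Free. A small class `NoMap` replaces the structural type, which avoids the reflective call that `new { def map ... }` implies in Scala 2. Because its `map` ignores the function it is given, it is only sound when the comprehension yields `()` and the wrapped program already returns Unit:

```scala
object NoMapDemo {
  // Hypothetical container that records which combinators were called, in order.
  final case class Logged[A](value: A, log: Vector[String]) {
    def flatMap[B](f: A => Logged[B]): Logged[B] = {
      val next = f(value)
      Logged(next.value, (log :+ "flatMap") ++ next.log)
    }
    def map[B](f: A => B): Logged[B] = Logged(f(value), log :+ "map")
  }

  def step(name: String): Logged[Unit] = Logged((), Vector(name))

  // Hypothetical class-based noMap: `map` discards `f` and returns the wrapped
  // value, so the for-comprehension's trailing `map(_ => ())` becomes a no-op.
  final class NoMap[P](p: P) {
    def map(f: Unit => Unit): P = p
  }
  def noMap[P](p: P): NoMap[P] = new NoMap(p)

  // Desugars to step("first").flatMap(_ => noMap(step("second")).map(_ => ())),
  // which evaluates to step("first").flatMap(_ => step("second")).
  val withTrick: Logged[Unit] = for {
    _ <- step("first")
    _ <- noMap(step("second"))
  } yield ()
}
```

The log ends with `second` instead of `map`: the last generator is passed to flatMap unchanged. In the Free setting, this is what puts the recursive call back in tail position.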
