Suppose that I'm trying to implement a very simple domain specific language with only one operation:
printLine(line)
Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10k, and then calls itself with n + 1, until n reaches some maximum value N.
Omitting all syntactic noise caused by for-comprehensions, what I want is:
@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}
Essentially, it would be a kind of "fizzbuzz".
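For reference, here is a minimal runnable sketch of that plain loop without any free monad. The names PlainLoop and printed, and the value chosen for N, are assumptions made only for this illustration; printLine records lines in a buffer instead of printing them, so the result can be inspected.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

object PlainLoop {
  // Hypothetical stand-in for the DSL's only operation: it records lines
  // in a buffer instead of writing them to the console.
  val printed: ListBuffer[String] = ListBuffer.empty[String]
  def printLine(line: String): Unit = { printed += line; () }

  // Assumed upper bound, chosen only for this illustration.
  val N = 30000

  // Direct tail-recursive loop: constant stack space, O(N) time.
  @tailrec def p(n: Int): Unit = {
    if (n % 10000 == 0) printLine("n = " + n)
    if (n > N) () else p(n + 1)
  }
}
```

Calling PlainLoop.p(0) visits every index up to N + 1 and records one line for each multiple of 10000 it passes.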
Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:
import scalaz._

object Demo1 {

  // define operations of a little domain specific language
  sealed trait Lang[X]
  case class PrintLine(line: String) extends Lang[Unit]

  // define the domain specific language as the free monad of operations
  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}

  // lift operations into the free monad
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  // write a program that is just a loop that prints current index
  // after every few iteration steps
  val mod = 100000
  val N = 1000000

  // straightforward syntax: deadly slow, exits with OutOfMemoryError
  def p0(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- (if (i > N) ret else p0(i + 1))
  } yield ()

  // Same as above, but written out without `for`
  def p1(i: Int): Prog[Unit] =
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
    }

  // Same as above, with `map` attached to recursive call
  def p2(i: Int): Prog[Unit] =
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
    }

  // Same as above, but without the `map`; performs ok.
  def p3(i: Int): Prog[Unit] = {
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      if (i > N) ret else p3(i + 1)
    }
  }

  // Variation of the above; Ok.
  def p4(i: Int): Prog[Unit] = (for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
  } yield ()).flatMap{ ignored2 =>
    if (i > N) ret else p4(i + 1)
  }

  // try to use the variable returned by the last generator after yield,
  // hope that the final `map` is optimized away (it's not optimized away...)
  def p5(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    stopHere <- (if (i > N) ret else p5(i + 1))
  } yield stopHere

  // define an interpreter that translates the programs into Trampoline
  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case PrintLine(l) => Trampoline.delay(println(l))
    }
  }

  // try it out
  def main(args: Array[String]): Unit = {
    println("\n p0")
    p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p1")
    p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p2")
    p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p3")
    p3(0).foldMap(interpreter).run // ok
    println("\n p4")
    p4(0).foldMap(interpreter).run // ok
    println("\n p5")
    p5(0).foldMap(interpreter).run // OutOfMemory
  }
}
Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead, and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill the entire memory with reminders to "finish 'p0' and then do nothing".
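To make the "reminders" idea concrete, here is a toy sketch that does not use Scalaz at all. Prog, Done, Bind and maxPending are hypothetical names for a deliberately tiny trampoline; it only counts how many continuations are pending at the same time, and it is not meant to reproduce how Scalaz's Free is implemented or its running time.

```scala
object PendingMapsDemo {
  // Toy program type (an assumption for illustration, not Scalaz's Free).
  sealed trait Prog[A] {
    def flatMap[B](f: A => Prog[B]): Prog[B] = Bind(this, f)
    def map[B](f: A => B): Prog[B] = flatMap(a => Done(f(a)))
  }
  final case class Done[A](value: A) extends Prog[A]
  final case class Bind[A, B](prev: Prog[A], next: A => Prog[B]) extends Prog[B]

  // Runs a program with an explicit continuation stack and returns the
  // largest number of continuations that were pending at the same time.
  def maxPending(program: Prog[Unit]): Int = {
    var current: Prog[Any] = program.asInstanceOf[Prog[Any]]
    var stack: List[Any => Prog[Any]] = Nil
    var depth = 0
    var maxDepth = 0
    var finished = false
    while (!finished) {
      current match {
        case Done(v) =>
          stack match {
            case k :: rest =>
              stack = rest
              depth -= 1
              current = k(v)
            case Nil =>
              finished = true
          }
        case Bind(prev, next) =>
          stack = next.asInstanceOf[Any => Prog[Any]] :: stack
          depth += 1
          if (depth > maxDepth) maxDepth = depth
          current = prev.asInstanceOf[Prog[Any]]
      }
    }
    maxDepth
  }

  val step: Prog[Unit] = Done(())

  // Same shape as p2: a map is attached after the recursive call.
  def withMap(i: Int, n: Int): Prog[Unit] =
    step.flatMap(_ => if (i >= n) step else withMap(i + 1, n).map(_ => ()))

  // Same shape as p3: the recursive call is the last thing that happens.
  def withoutMap(i: Int, n: Int): Prog[Unit] =
    step.flatMap(_ => if (i >= n) step else withoutMap(i + 1, n))
}
```

In this toy model, maxPending(withoutMap(0, n)) stays constant, while maxPending(withMap(0, n)) grows linearly with n, because every level leaves one "then map to ()" continuation waiting until the innermost call has finished.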
If I manually "unroll" the for-comprehension, and write out the last flatMap explicitly (as in p3 and p4), then the problem goes away, and everything runs smoothly. This, however, is an extremely brittle workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) isn't even visible in the code, because it is generated automatically by the for-comprehension.
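The hidden map can be made visible with a small tracing type. The sketch below is independent of Scalaz; Traced and DesugarDemo are hypothetical names, and the comment about desugaring assumes the Scala 2 translation rules for for-comprehensions.

```scala
object DesugarDemo {
  // Hypothetical wrapper that records which combinators are invoked on it.
  final case class Traced[A](value: A, log: List[String]) {
    def flatMap[B](f: A => Traced[B]): Traced[B] = {
      val next = f(value)
      Traced(next.value, (log :+ "flatMap") ++ next.log)
    }
    def map[B](f: A => B): Traced[B] = Traced(f(value), log :+ "map")
  }

  def unit: Traced[Unit] = Traced((), Nil)

  // What one writes (same shape as p0).
  val viaFor: Traced[Unit] = for {
    _ <- unit
    _ <- unit
  } yield ()

  // What the Scala 2 translation rules turn it into (same shape as p1):
  // the last generator becomes a map, not a flatMap.
  val desugared: Traced[Unit] = unit.flatMap(_ => unit.map(_ => ()))

  // Hand-written variant without the trailing map (same shape as p3).
  val handWritten: Traced[Unit] = unit.flatMap(_ => unit)
}
```

Comparing viaFor.log, desugared.log and handWritten.log shows whether a trailing map was issued: desugared records a flatMap followed by a map, while handWritten records only the flatMap.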
In this older post, https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/, it is repeatedly advised to wrap recursive calls in a suspend. Here is an attempt with an Applicative instance and suspend:
import scalaz._

// Essentially same as in `Demo1`, but this time with
// an `Applicative` and an explicit `Suspend` in the
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) =>
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}
Inserting suspend did not really help: it's still slow, and crashes with OutOfMemoryErrors.
Should I use suspend differently somehow?
Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map in the end?
I'd really appreciate it if someone could point out what I'm doing wrong here, and how to repair it.
OutOfMemory. When I increased N tenfold, it got slower (which is to be expected, because this is O(N*N) compared with the naive tailrec solution, which is O(N)), but there was still no OOM error. – Phonemic