Is it possible to perform a fold in the State monad in constant stack and heap space? Or is a different functional technique a better fit to my problem?
The next sections describe the problem and a motivating use case. I'm using Scala, but solutions in Haskell are welcome too.
Fold in the State Monad Fills the Heap
Assume Scalaz 7. Consider a monadic fold in the State monad. To avoid stack overflows, we'll trampoline the fold.
import scalaz._
import Scalaz._
import scalaz.std.iterable._
import Free.Trampoline
type TrampolinedState[S, B] = StateT[Trampoline, S, B] // monad type constructor
type S = Int // state is an integer
type M[B] = TrampolinedState[S, B] // our trampolined state monad
type R = Int // or some other monoid
val col: Iterable[R] = largeIterableofRs() // defined elsewhere
val (count, sum): (S, R) = col.foldLeftM[M, R](Monoid[R].zero){
  (acc: R, x: R) => StateT[Trampoline, S, R] {
    s: S => Trampoline.done {
      (s + 1, Monoid[R].append(acc, x))
    }
  }
} run 0 run
// In Scalaz 7, foldLeftM is implemented in terms of foldRight, which in turn
// is a reversed.foldLeft. This pulls the whole collection into memory and kills
// the heap. Ignore this heap overflow. We could reimplement foldLeftM to avoid
// this overflow or use a foldRightM instead.
// Our real issue is the heap used by the unexecuted State mobits.
For a large collection col, this will fill the heap.
I believe that during the fold, a closure (a State mobit) is created for each value in the collection (the x: R parameter), filling the heap. None of those can be evaluated until run 0 is executed, providing the initial state.
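To make the suspected mechanism concrete, here is a minimal sketch using a hand-rolled state monad. MiniState, pure and this foldLeftM are hypothetical stand-ins written for illustration, not Scalaz's StateT or its foldLeftM, so the exact allocation pattern inside Scalaz may differ. In this model the fold returns one wrapper per element, each holding a reference to the previous one, and nothing is combined until run supplies the initial state.

```scala
// Hypothetical minimal state monad for illustration only (not Scalaz).
final case class MiniState[S, A](run: S => (S, A)) {
  // Sequencing builds a NEW closure that refers to this one; nothing runs yet.
  def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState[S, B] { s =>
      val (s1, a) = run(s)
      f(a).run(s1)
    }
}

def pure[S, A](a: A): MiniState[S, A] = MiniState[S, A](s => (s, a))

// Monadic left fold: every element adds one more layer around the previous result.
def foldLeftM[S, A, B](xs: Iterator[A], z: B)(f: (B, A) => MiniState[S, B]): MiniState[S, B] =
  xs.foldLeft(pure[S, B](z))((mb, x) => mb.flatMap(b => f(b, x)))

// The input is kept small because this sketch is not trampolined.
val pending: MiniState[Int, Int] =
  foldLeftM[Int, Int, Int](Iterator.range(1, 1001), 0) { (acc, x) =>
    MiniState[Int, Int](n => (n + 1, acc + x))
  }

// At this point `pending` is a chain of 1000 nested closures; no addition has happened.
val (count, sum) = pending.run(0)
```

Only the final pending.run(0) does any work, which matches the O(n) retention described above: the whole chain must exist before the first step can execute.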
Can this O(n) heap usage be avoided?
More specifically, can the initial state be provided before the fold so that the State monad can execute during each bind, rather than nesting closures for later evaluation?
Or can the fold be constructed such that it is executed lazily after the State monad is run? In this way, the next x: R closure would not be created until after the previous ones have been evaluated and made suitable for garbage collection.
Or is there a better functional paradigm for this sort of work?
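One way to read the "provide the initial state before the fold" idea is sketched below in plain Scala, without Scalaz. foldWithState is a hypothetical helper name. The step function keeps a State-like shape, (B, A) => S => (S, B), but the fold applies each step to the current state immediately, so no chain of suspended steps builds up. Whether this counts as staying inside the State monad is debatable; it is only meant to show the shape of the computation being asked for.

```scala
// Hypothetical helper: a left fold that is handed the initial state up front
// and runs each state-passing step as soon as it is produced.
def foldWithState[S, A, B](xs: Iterator[A], s0: S, z: B)(f: (B, A) => S => (S, B)): (S, B) =
  xs.foldLeft((s0, z)) { case ((s, acc), x) => f(acc, x)(s) }

// Count and sum one million integers; only the current (state, accumulator)
// pair is live at any time.
val (seen, total) =
  foldWithState[Int, Int, Long](Iterator.range(1, 1000001), 0, 0L) {
    (acc: Long, x: Int) => (n: Int) => (n + 1, acc + x)
  }
```

Iterator.foldLeft is a simple loop, so this runs in constant stack, and each per-element closure becomes garbage right after it is applied.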
Example Application
But perhaps I'm using the wrong tool for the job. The evolution of an example use case follows. Am I wandering down the wrong path here?
Consider reservoir sampling, i.e., picking in one pass a uniform random k items from a collection too large to fit in memory. In Scala, such a function might be
def sample[A](col: TraversableOnce[A])(k: Int): Vector[A]
and if pimped into the TraversableOnce type could be used like this
val tenRandomInts = (Int.MinValue to Int.MaxValue) sample 10
The work done by sample is essentially a fold:
def sample[A](col: Traversable[A])(k: Int): Vector[A] = {
  // Pass update(k) itself so it is built once; update(k)(_, _) would rebuild it per element
  col.foldLeft(Vector.empty[A])(update[A](k))
}
However, update is stateful; it depends on n, the number of items already seen. (It also depends on an RNG, but for simplicity I assume that is global and stateful. The techniques used to handle n would extend trivially.) So how to handle this state?
The impure solution is simple and runs with constant stack and heap.
/* Impure version of update function */
def update[A](k: Int) = new Function2[Vector[A], A, Vector[A]] {
  var n = 0 // number of items seen so far
  def apply(sample: Vector[A], x: A): Vector[A] = {
    n += 1
    algorithmR(k, n, sample, x)
  }
}
def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A): Vector[A] = {
  if (sample.size < k) {
    sample :+ x // must keep first k elements
  } else {
    val r = rand.nextInt(n) + 1 // for simplicity, rand is global/stateful
    if (r <= k)
      sample.updated(r - 1, x) // sample is 0-indexed
    else
      sample
  }
}
But what about a purely functional solution? update must take n as an additional parameter and return the new value along with the updated sample. We could include n in the implicit state, the fold accumulator, e.g.,
(col.foldLeft ((0, Vector.empty[A])) (update(k)(_: (Int, Vector[A]), _: A)))._2
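For reference, a self-contained sketch of this explicit-state fold follows. It makes two assumptions that are not in the code above: the RNG is passed in as a parameter rather than being global, and the input is an Iterator so that elements can be discarded as soon as they are consumed. The RNG itself is still a stateful scala.util.Random, so this version is pure only with respect to n.

```scala
import scala.util.Random

// One step of Algorithm R. `n` counts the items seen so far, including `x`.
// `rand` is an explicit parameter here (an assumption; the text uses a global).
def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A, rand: Random): Vector[A] =
  if (sample.size < k) sample :+ x // must keep the first k elements
  else {
    val r = rand.nextInt(n) + 1 // uniform in 1..n
    if (r <= k) sample.updated(r - 1, x) else sample
  }

// The accumulator is the pair (n, sample); n is threaded explicitly.
def sample[A](col: Iterator[A], k: Int, rand: Random): Vector[A] = {
  val (_, result) = col.foldLeft((0, Vector.empty[A])) { case ((n, acc), x) =>
    (n + 1, algorithmR(k, n + 1, acc, x, rand))
  }
  result
}

// Ten items from a stream of one million, with a fixed seed for repeatability.
val picked: Vector[Int] = sample(Iterator.range(0, 1000000), 10, new Random(42L))
```

Only the current pair and the k-element vector are live during the fold, so heap use does not grow with the length of the input.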
But that obscures the intent; we only really intend to accumulate the sample vector. This problem seems ready made for the State monad and a monadic left fold. Let's try again.
We'll use Scalaz 7, with these imports
import scalaz._
import Scalaz._
import scalaz.std.iterable._
and operate over an Iterable[A], since Scalaz doesn't support monadic folding of a Traversable.
sample is now defined
// sample using State monad
def sample[A](col: Iterable[A])(k: Int): Vector[A] = {
  type M[B] = State[Int, B]
  // foldLeftM is implemented using foldRight, which must reverse `col`, blowing
  // the heap for large `col`. Ignore this issue for now.
  // foldLeftM could be implemented differently or we could switch to
  // foldRightM, implemented using foldLeft.
  col.foldLeftM[M, Vector[A]](Vector())(update[A](k)(_: Vector[A], _: A)) eval 0
}
where update is
// update using State monad
def update[A](k: Int) = {
  (acc: Vector[A], x: A) => State[Int, Vector[A]] {
    n => (n + 1, algorithmR(k, n + 1, acc, x)) // algR same as impure solution
  }
}
Unfortunately, this blows the stack on a large collection.
So let's trampoline it. sample is now
// sample using trampolined State monad
def sample[A](col: Iterable[A])(k: Int): Vector[A] = {
  import Free.Trampoline
  type TrampolinedState[S, B] = StateT[Trampoline, S, B]
  type M[B] = TrampolinedState[Int, B]
  // Same caveat about foldLeftM using foldRight and blowing the heap
  // applies here. Ignore for now. This solution blows the heap anyway;
  // let's fix that issue first.
  col.foldLeftM[M, Vector[A]](Vector())(update[A](k)(_: Vector[A], _: A)) eval 0 run
}
where update is
// update using trampolined State monad
def update[A](k: Int) = {
  (acc: Vector[A], x: A) => StateT[Trampoline, Int, Vector[A]] {
    n => Trampoline.done { (n + 1, algorithmR(k, n + 1, acc, x)) }
  }
}
This fixes the stack overflow, but still blows the heap for very large collections (or very small heaps). One anonymous function per value in the collection is created during the fold (I believe to close over each x: A parameter), consuming the heap before the trampoline is even run. (FWIW, the State version has this issue too; the stack overflow just surfaces first with smaller collections.)
If f = s => bigFun(), then bigFun does not actually occur until you pass s. At which point f can be discarded unless you are holding on to it. More likely what is happening is that your collection is overly strict. Try with an EphemeralStream and compare the results. – Chopstick