Monadic fold with State monad in constant space (heap and stack)?

Is it possible to perform a fold in the State monad in constant stack and heap space? Or is a different functional technique a better fit to my problem?

The next sections describe the problem and a motivating use case. I'm using Scala, but solutions in Haskell are welcome too.


Fold in the State Monad Fills the Heap

Assume Scalaz 7. Consider a monadic fold in the State monad. To avoid stack overflows, we'll trampoline the fold.

import scalaz._
import Scalaz._
import scalaz.std.iterable._
import Free.Trampoline

type TrampolinedState[S, B] = StateT[Trampoline, S, B] // monad type constructor

type S = Int  // state is an integer
type M[B] = TrampolinedState[S, B] // our trampolined state monad

type R = Int  // or some other monoid

val col: Iterable[R] = largeIterableOfRs() // defined elsewhere

val (count, sum): (S, R) = col.foldLeftM[M, R](Monoid[R].zero){ 
    (acc: R, x: R) => StateT[Trampoline, S, R] {
      s: S => Trampoline.done { 
        (s + 1, Monoid[R].append(acc, x))
      }
    }
} run 0 run

// In Scalaz 7, foldLeftM is implemented in terms of foldRight, which in turn
// does a reversed.foldLeft. This pulls the whole collection into memory and
// blows the heap. Ignore that heap overflow for now; we could reimplement
// foldLeftM to avoid it, or use a foldRightM instead.
// Our real issue is the heap used by the unexecuted State mobits.

For a large collection col, this will fill the heap.

I believe that during the fold, a closure (a State mobit) is created for each value in the collection (the x: R parameter), filling the heap. None of those can be evaluated until run 0 is executed, providing the initial state.
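To make that concrete, here is a hand-expanded sketch of the structure I believe the fold builds for a hypothetical three-element collection (x1, x2, x3 stand in for elements of col; step is just the folding function from above pulled out into a def):

// Sketch only: what I believe foldLeftM effectively builds before any state
// is supplied -- one nested bind (and one closure) per element.
def step(acc: R, x: R): M[R] = StateT[Trampoline, S, R] { s =>
  Trampoline.done((s + 1, Monoid[R].append(acc, x)))
}

val (x1, x2, x3) = (1, 2, 3) // hypothetical elements of col

val program: M[R] =
  step(Monoid[R].zero, x1)
    .flatMap(a1 => step(a1, x2))
    .flatMap(a2 => step(a2, x3))

// None of the step closures can be evaluated until the initial state arrives:
val (n, total) = program.run(0).run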

Can this O(n) heap usage be avoided?

More specifically, can the initial state be provided before the fold so that the State monad can execute during each bind, rather than nesting closures for later evaluation?

Or can the fold be constructed such that it is executed lazily after the State monad is run? In this way, the next x: R closure would not be created until after the previous ones have been evaluated and made suitable for garbage collection.

Or is there a better functional paradigm for this sort of work?


Example Application

But perhaps I'm using the wrong tool for the job. The evolution of an example use case follows. Am I wandering down the wrong path here?

Consider reservoir sampling, i.e., picking, in one pass, k uniformly random items from a collection too large to fit in memory. In Scala, such a function might be

def sample[A](col: TraversableOnce[A])(k: Int): Vector[A]

and if pimped onto the TraversableOnce type, it could be used like this

val tenRandomInts = (Int.MinValue to Int.MaxValue) sample 10

The work done by sample is essentially a fold:

def sample[A](col: Traversable[A])(k: Int): Vector[A] = {
    col.foldLeft(Vector.empty[A]){update[A](k)(_: Vector[A], _: A)}
}

However, update is stateful; it depends on n, the number of items already seen. (It also depends on an RNG, but for simplicity I assume that is global and stateful; the techniques used to handle n would extend trivially.) So how do we handle this state?

The impure solution is simple and runs with constant stack and heap.

/* Impure version of update function */
def update[A](k: Int) = new Function2[Vector[A], A, Vector[A]] {
    var n = 0
    def apply(sample: Vector[A], x: A): Vector[A] = {
        n += 1
        algorithmR(k, n, sample, x)
    }
}

def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A): Vector[A] = {
    if (sample.size < k) {
        sample :+ x // must keep first k elements
    } else {
        val r = rand.nextInt(n) + 1 // for simplicity, rand is global/stateful
        if (r <= k)
            sample.updated(r - 1, x) // sample is 0-indexed
        else
            sample
    }
}

But what about a purely functional solution? update must take n as an additional parameter and return the new value along with the updated sample. We could include n in the implicit state, the fold accumulator, e.g.,

(col.foldLeft((0, Vector.empty[A]))(update[A](k)(_: (Int, Vector[A]), _: A)))._2
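For reference, the update this fold expects, carrying n alongside the sample in the accumulator, might look like this (a sketch only, reusing algorithmR from above):

// Sketch: update with n threaded through the fold accumulator by hand.
def update[A](k: Int): ((Int, Vector[A]), A) => (Int, Vector[A]) = {
  (state, x) =>
    val (n, sample) = state
    (n + 1, algorithmR(k, n + 1, sample, x))
}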

But that obscures the intent; we only really intend to accumulate the sample vector. This problem seems ready-made for the State monad and a monadic left fold. Let's try again.

We'll use Scalaz 7, with these imports

import scalaz._
import Scalaz._
import scalaz.std.iterable._

and operate over an Iterable[A], since Scalaz doesn't support monadic folding of a Traversable.

sample is now defined

// sample using State monad
def sample[A](col: Iterable[A])(k: Int): Vector[A] = {       
    type M[B] = State[Int, B]

    // foldLeftM is implemented using foldRight, which must reverse `col`, blowing
    // the heap for large `col`.  Ignore this issue for now.
    // foldLeftM could be implemented differently or we could switch to
    // foldRightM, implemented using foldLeft.
    col.foldLeftM[M, Vector[A]](Vector())(update[A](k)(_: Vector[A], _: A)) eval 0
}

where update is

// update using State monad
def update[A](k: Int) = {
    (acc: Vector[A], x: A) => State[Int, Vector[A]] {
        n => (n + 1, algorithmR(k, n + 1, acc, x)) // algR same as impure solution
    }
}

Unfortunately, this blows the stack on a large collection.

So let's trampoline it. sample is now

// sample using trampolined State monad
def sample[A](col: Iterable[A])(k: Int): Vector[A] = {
    import Free.Trampoline

    type TrampolinedState[S, B] = StateT[Trampoline, S, B]
    type M[B] = TrampolinedState[Int, B]

    // Same caveat about foldLeftM using foldRight and blowing the heap
    // applies here.  Ignore for now. This solution blows the heap anyway;
    // let's fix that issue first.
    col.foldLeftM[M, Vector[A]](Vector())(update[A](k)(_: Vector[A], _: A)) eval 0 run
}

where update is

// update using trampolined State monad
def update[A](k: Int) = {
    (acc: Vector[A], x: A) => StateT[Trampoline, Int, Vector[A]] {
        n => Trampoline.done { (n + 1, algorithmR(k, n + 1, acc, x)) }
    }
}

This fixes the stack overflow, but still blows the heap for very large collections (or very small heaps). One anonymous function per value in the collection is created during the fold (I believe to close over each x: A parameter), consuming the heap before the trampoline is even run. (FWIW, the State version has this issue too; the stack overflow just surfaces first with smaller collections.)

Marasmus answered 24/12, 2013 at 7:5 Comment(3)
I don't think your guess is accurate, that there is "one function per value" being created on the heap and that this is what's eating your memory. The composite function is lazily created. Think about it. When you say f = s => bigFun() then bigFun does not actually occur until you pass s. At which point f can be discarded unless you are holding on to it. More likely what is happening is that your collection is overly strict. Try with an EphemeralStream and compare the results.Chopstick
Lazy creation was my initial understanding, but I am seeing those closures created (using a profiler). It's after the initial state is provided and the trampoline is run, but before the trampoline actually executes each thing. See my comments on your answer.Marasmus
Incidentally, once my confusion is resolved, I'll edit my question to remove red herrings (e.g., whether the collection fits in memory; that's not actually relevant, just the big-O heap usage of the monadic fold...)Marasmus

Our real issue is the heap used by the unexecuted State mobits.

No, it is not. The real issue is that the collection doesn't fit in memory and that foldLeftM and foldRightM force the entire collection. A side effect of the impure solution is that you are freeing memory as you go. In the "purely functional" solution, you're not doing that anywhere.

Your use of Iterable ignores a crucial detail: what kind of collection col actually is, how its elements are created and how they are expected to be discarded. And so, necessarily, does foldLeftM on Iterable. It is likely too strict, and you are forcing the entire collection into memory. For example, if it is a Stream, then as long as you are holding on to col all the elements forced so far will be in memory. If it's some other kind of lazy Iterable that doesn't memoize its elements, then the fold is still too strict.
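For example, even a plain foldLeft over a memoizing Stream exhibits this, simply because holding a reference to the head keeps every forced cell alive (a quick sketch, nothing Scalaz-specific):

// Because `col` itself stays referenced, every cell the fold forces remains
// reachable, so the whole Stream ends up in memory even though the fold only
// needs one element at a time.
val col: Stream[Int] = Stream.from(0).take(5000000)
val total: Long = col.foldLeft(0L)(_ + _) // heap grows with the number of forced cells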

I tried your first example with an EphemeralStream and did not see any significant heap pressure, even though it will clearly have the same "unexecuted State mobits". The difference is that an EphemeralStream's elements are weakly referenced and its foldRight doesn't force the entire stream.

I suspect that if you used Foldable.foldr, then you would not see the problematic behaviour since it folds with a function that is lazy in its second argument. When you call the fold, you want it to return a suspension that looks something like this immediately:

Suspend(() => head |+| tail.foldRightM(...))

When the trampoline resumes the first suspension and runs up to the next suspension, all of the allocations between suspensions will become available to be freed by the garbage collector.

Try the following:

def foldM[M[_]:Monad,A,B](a: A, bs: Iterable[B])(f: (A, B) => M[A]): M[A] =
  if (bs.isEmpty) Monad[M].point(a)
  else Monad[M].bind(f(a, bs.head))(fax => foldM(fax, bs.tail)(f))

val MS = StateT.stateTMonadState[Int, Trampoline]
import MS._

foldM[M,R,Int](Monoid[R].zero, col) {
  (x, r) => modify(_ + 1) map (_ => Monoid[R].append(x, r))
} run 0 run

This will run in constant heap for a trampolined monad M, but will overflow the stack for a non-trampolined monad.
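For contrast, here is the same fold in the plain, non-trampolined State monad (a sketch only, with PS standing in for the State monad instance; as noted, this version blows the stack on a large col):

// Sketch: foldM in State over Id instead of State over Trampoline.
// The binds are never suspended, so a large `col` overflows the stack.
type PlainM[B] = State[Int, B]
val PS = StateT.stateTMonadState[Int, Id]

foldM[PlainM, R, Int](Monoid[R].zero, col) {
  (x, r) => PS.modify(_ + 1) map (_ => Monoid[R].append(x, r))
} run 0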

But the real problem is that Iterable is not a good abstraction for data that are too large to fit in memory. Sure, you can write an imperative side-effecty program where you explicitly discard elements after each iteration or use a lazy right fold. That works well until you want to compose that program with another one. And I'm assuming that the whole reason you're investigating doing this in a State monad to begin with is to gain compositionality.

So what can you do? Here are some options:

  1. Make use of Reducer, Monoid, and composition thereof, then run in an imperative explicitly-freeing loop (or a trampolined lazy right fold) as the last step, after which composition is not possible or expected (a sketch of this follows below).
  2. Use Iteratee composition and monadic Enumerators to feed them.
  3. Write compositional stream transducers with Scalaz-Stream.

The last of these options is the one that I would use and recommend in the general case.
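For a sense of what option 1 looks like, here is a minimal plain-Scala sketch, with the composed Reducer/Monoid stood in by an ordinary step function:

// Sketch of option 1: the folding logic stays pure and composable; only the
// final step is an imperative loop that lets each element become garbage as
// soon as it has been consumed.
def runFold[A, B](it: Iterator[A])(z: B)(step: (B, A) => B): B = {
  var acc = z
  while (it.hasNext) acc = step(acc, it.next())
  acc
}

// e.g. runFold(Iterator.range(0, 10000000))(0L)((acc, x) => acc + x)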

Chopstick answered 25/12, 2013 at 0:15 Comment(11)
For my tests, I used an anonymous new Iterator{...} (that simply incremented a var: Int). This doesn't hold prior elements in memory (confirmed with the stateful solution in the example application). Its behavior should be the same on the other sample implementations.Marasmus
I'm not concerned about some Iterable collections requiring all elements to be in memory---that should be considered when the collection is chosen. I am concerned about foldLeftM[State[B]](...)(...) using an additional O(n) heap space. (I should have been more specific in the question; I just thought the "too large to fit in memory" explanation was simpler.)Marasmus
To be clear, I'm not trying to dispute your analysis (heck, I learned about StateT[Trampoline, S, B] from your paper on free monads and stack overflows), just to understand the root cause of my issue. Your three suggestions may be better for my problem (thanks!), but I'd like to understand why my foldLeftM[StateT[Trampoline, S, B]] isn't using constant additional heap.Marasmus
With foldLeftM on an N-element collection, I'm seeing N scalaz.IterableInstances$$anonfunc$foldRight$1$1$$anonfun$apply$1 functions created, each referencing a scalaz.Foldable$$anonfun$foldLeftM$2$$anonfun$apply$10 closure over the collection element. These are all created at once when Trampoline.run is called, before any are executed and unwound.Marasmus
The EphemeralStream fold operations differ in two ways from the collections hierarchy (e.g., my iterator). First, foldRight is recursive---not a reversed.foldLeft that puts the whole collection into a new list. I don't think this is the root cause. Second, the reduce function parameter for foldRight must use lazy evaluation (call-by-name) for its arguments. The strict evaluation used by the collections hierarchy might cause the closures to be created before evaluation. A modified Iterable with the EphemeralStream implementation of foldRight works properly: heap usage is O(1).Marasmus
It's irrelevant whether there is a closure holding on to the stream element. The problem is that each stream element is never freed from memory during the fold.Chopstick
Updated the answer to reflect your reason why this works with EphemeralStream.Chopstick
I don't understand why you say "It's irrelevant whether there is a closure holding on to the stream element". That is my issue! From my real use case: I have a 50 GB collection that fits in memory (on a machine with 60 GB of ram). All my processing must use an additional O(1) heap space, but those closures are using an additional O(n) and overflowing the heap.Marasmus
Foldr makes sense. Do you know why the Scala collection foldLeft and foldRight can't be lazy in their arguments?Marasmus
That's entirely a design decision of the standard library. In a strict language, making lazy versions of everything takes considerable effort and is difficult to get exactly right.Chopstick
Yeah, I guess it doesn't help matters that you are creating all of those closures from the inside out. You should be creating only the closure that you need in order to return control to the trampoline once.Chopstick

Using State, or any similar monad, isn't a good approach to the problem. Using State is condemned to blow the stack/heap on large collections. Consider a value x: State[A,B] constructed from a large collection (for example by folding over it). Then x can be evaluated for different values of the initial state A, yielding different results, so x needs to retain all information contained in the collection. And in a pure setting, x can't forget information in order to spare the stack/heap, so anything that is computed remains in memory until the whole monadic value is freed, which happens only after the result is evaluated. So the memory consumption of x is proportional to the size of the collection.
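A tiny sketch of why (the collection and combining function here are just illustrative): the same unevaluated State value must be able to answer for any initial state, so it cannot drop anything it was built from until it is actually run.

import scalaz._
import Scalaz._

// The very same x gives different results for different initial states,
// so before it is run it has to keep everything it closes over.
val xs = List(1, 2, 3, 4, 5)
val x: State[Int, Int] =
  xs.foldLeft(State.state[Int, Int](0)) { (st, e) =>
    st.flatMap(acc => State[Int, Int](s => (s + 1, acc + e * s)))
  }

println(x.eval(0))   // one answer
println(x.eval(100)) // a different answer from the same x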

I believe a fitting approach to this problem is to use functional iteratees/pipes/conduits. This concept (referred to under these three names) was invented to process large collections of data with constant memory consumption, and to describe such processes using simple combinators.

I tried to use Scalaz' Iteratees, but it seems this part isn't mature yet; it suffers from stack overflows just as State does (or perhaps I'm not using it right; the code is available here, if anybody is interested).

However, it was simple using my (still a bit experimental) scala-conduit library (disclaimer: I'm the author):

import conduit._
import conduit.Pipe._

object Run extends App {
  // Define a sampling function as a sink: It consumes
  // data of type `A` and produces a vector of samples.
  def sampleI[A](k: Int): Sink[A, Vector[A]] =
    sampleI[A](k, 0, Vector())

  // Create a sampling sink with a given state. It requests
  // a value from the upstream conduit. If there is one,
  // update the state and continue (the first argument to `requestF`).
  // If not, return the current sample (the second argument).
  // The `Finalizer` part isn't important for our problem.
  private def sampleI[A](k: Int, n: Int, sample: Vector[A]):
                  Sink[A, Vector[A]] =
    requestF((x: A) => sampleI(k, n + 1, algorithmR(k, n + 1, sample, x)),
             (_: Any) => sample)(Finalizer.empty)


  // The sampling algorithm copied from the question.
  val rand = new scala.util.Random()

  def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A): Vector[A] = {
    if (sample.size < k) {
      sample :+ x // must keep first k elements
    } else {
      val r = rand.nextInt(n) + 1 // for simplicity, rand is global/stateful
      if (r <= k)
        sample.updated(r - 1, x) // sample is 0-index
      else
        sample
    }
  }

  // Construct an iterable of all `Short` values, pipe it into our sampling
  // function, and run the combined pipe.
  {
    print(runPipe(Util.fromIterable(Short.MinValue to Short.MaxValue) >->
          sampleI(10)))
  }
}

Update: It'd be possible to solve the problem using State, but we'd need to implement a custom fold specifically for State that knows how to do it in constant space:

import scala.collection._
import scala.language.higherKinds
import scalaz._
import Scalaz._
import scalaz.std.iterable._

object Run extends App {
  // Folds in a state monad over a foldable
  def stateFold[F[_],E,S,A](xs: F[E],
                            f: (A, E) => State[S,A],
                            z: A)(implicit F: Foldable[F]): State[S,A] =
    State[S,A]((s: S) => F.foldLeft[E,(S,A)](xs, (s, z))((p, x) => f(p._2, x)(p._1)))


  // Sample a lazy collection view
  def sampleS[F[_],A](k: Int, xs: F[A])(implicit F: Foldable[F]):
                  State[Int,Vector[A]] =
    stateFold[F,A,Int,Vector[A]](xs, update(k), Vector())

  // update using State monad
  def update[A](k: Int) = {
    (acc: Vector[A], x: A) => State[Int, Vector[A]] {
        n => (n + 1, algorithmR(k, n + 1, acc, x)) // algR same as impure solution
    }
  }

  def algorithmR[A](k: Int, n: Int, sample: Vector[A], x: A): Vector[A] = ...

  {
    print(sampleS[Iterable, Int](10, (Short.MinValue to Short.MaxValue)).eval(0))
  }
}
Derosa answered 25/12, 2013 at 21:4 Comment(3)
Your first paragraph jibes with my understanding of what is happening---the monad references the entire collection with its own O(N)-sized set of closures and can't unwind/free those until provided the initial state. I believe that @Chopstick is saying that for an appropriate implementation of fold, the collection isn't iterated until the initial state is provided and the trampoline is run---an element can be freed when the next one is loaded.Marasmus
Regardless, the Iteratee/conduit approach suggested by both of you should avoid all this complication and fold-implementation-specific headache anyway.Marasmus
@DavidB. Another option (updated in the answer) is to create a specific function for folding in the State monad in constant space. I guess this is something half-way to the trampolined monad approach.Derosa
