What type to use to store an in-memory mutable data table in Scala?
Each time a function is called, if its result for a given set of argument values is not yet memoized, I'd like to put the result into an in-memory table. One column is meant to store the result, the others to store the argument values.

How do I best implement this? Arguments are of diverse types, including some enums.

In C# I'd generally use DataTable. Is there an equivalent in Scala?

Strain answered 4/9, 2010 at 3:50 Comment(1)
If you search the Web for "Scala Function Memoization" you'll find several treatments of this topic. - Malpighi
L
27

You could use a mutable.Map[TupleN[A1, A2, ..., AN], R], or, if memory is a concern, a WeakHashMap[1]. The definitions below (built on the memoization code from michid's blog) let you memoize functions with multiple arguments easily. For example:

import Memoize._

def reallySlowFn(i: Int, s: String): Int = {
   Thread.sleep(3000)
   i + s.length
}

val memoizedSlowFn = memoize(reallySlowFn _)
memoizedSlowFn(1, "abc") // returns 4 after about 3 seconds
memoizedSlowFn(1, "abc") // returns 4 almost instantly

Definitions:

/**
 * A memoized unary function.
 *
 * @param f A unary function to memoize
 * @tparam T the argument type
 * @tparam R the return type
 */
class Memoize1[-T, +R](f: T => R) extends (T => R) {
   import scala.collection.mutable
   // map that stores (argument, result) pairs
   private[this] val vals = mutable.Map.empty[T, R]

   // Given an argument x, 
   //   If vals contains x return vals(x).
   //   Otherwise, update vals so that vals(x) == f(x) and return f(x).
   def apply(x: T): R = vals getOrElseUpdate (x, f(x))
}

object Memoize {
   /**
    * Memoize a unary (single-argument) function.
    *
    * @param f the unary function to memoize
    */
   def memoize[T, R](f: T => R): (T => R) = new Memoize1(f)

   /**
    * Memoize a binary (two-argument) function.
    * 
    * @param f the binary function to memoize
    * 
    * This works by turning a function that takes two arguments of type
    * T1 and T2 into a function that takes a single argument of type 
    * (T1, T2), memoizing that "tupled" function, then "untupling" the
    * memoized function.
    */
   def memoize[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) = 
      Function.untupled(memoize(f.tupled))

   /**
    * Memoize a ternary (three-argument) function.
    *
    * @param f the ternary function to memoize
    */
   def memoize[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) =
      Function.untupled(memoize(f.tupled))

   // ... more memoize methods for higher-arity functions ...

   /**
    * Fixed-point combinator (for memoizing recursive functions).
    */
   def Y[T, R](f: (T => R) => T => R): (T => R) = {
      lazy val yf: (T => R) = memoize(f(yf)(_))
      yf
   }
}

The fixed-point combinator (Memoize.Y) makes it possible to memoize recursive functions:

val fib: BigInt => BigInt = {                         
   def fibRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
      if (n == 0) 1 
      else if (n == 1) 1 
      else (f(n-1) + f(n-2))                           
   }                                                     
   Memoize.Y(fibRec)
}

[1] WeakHashMap does not work well as a cache. See http://www.codeinstructions.com/2008/09/weakhashmap-is-not-cache-understanding.html and this related question.
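
If memory is the concern, one alternative to a WeakHashMap is to bound the cache explicitly. The following is only a sketch under assumptions, not something the answer above proposes: `LruMemo` and `maxEntries` are hypothetical names, and the eviction relies on `java.util.LinkedHashMap` in access order with `removeEldestEntry` overridden.

```scala
import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

// Hypothetical helper: memoizes f but keeps at most maxEntries results,
// evicting the least recently used entry first.
final class LruMemo[T, R](maxEntries: Int)(f: T => R) extends (T => R) {
  // accessOrder = true makes iteration order "least recently used first"
  private val cache = new JLinkedHashMap[T, R](16, 0.75f, true) {
    override protected def removeEldestEntry(eldest: JMap.Entry[T, R]): Boolean =
      size() > maxEntries
  }

  // Coarse-grained lock: simple, but callers are serialized while f runs
  def apply(x: T): R = synchronized {
    if (cache.containsKey(x)) cache.get(x)
    else {
      val res = f(x)
      cache.put(x, res)
      res
    }
  }
}
```

With `maxEntries = 2`, calling the memoized function with 1, 2 and 3 evicts the entry for 1, so a later call with 1 recomputes it.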

Libbie answered 4/9, 2010 at 3:50 Comment(2)
Note that the above implementation is not thread-safe, so if you need to cache some computation from multiple threads, this can break. To make it thread-safe, just do: private[this] val vals = new HashMap[T, R] with SynchronizedMap[T, R] - Luscious
There is another way to memoize recursive functions: https://mcmap.net/q/324396/-scala-memoization-how-does-this-scala-memo-work. It doesn't require the Y combinator, and thus doesn't require formulating a non-recursive form, which might be daunting for recursive functions with more than one parameter. Actually, both methods rely on Scala's own support for function recursion: with the Y combinator, yf calls yf, while in wrick's linked variant a memoized function calls itself. - Spirited
C
10

The version suggested by anovstrup using a mutable Map is basically the same as in C#, and therefore easy to use.

If you prefer, you can also use a more functional style. It uses immutable maps, which act as a kind of accumulator. Using tuples (instead of Int, as in the example) as keys works exactly as in the mutable case.

def fib(n:Int) = fibM(n, Map(0->1, 1->1))._1

def fibM(n:Int, m:Map[Int,Int]):(Int,Map[Int,Int]) = m.get(n) match {
   case Some(f) => (f, m)
   case None => val (f_1,m1) = fibM(n-1,m)
                val (f_2,m2) = fibM(n-2,m1)
                val f = f_1+f_2
                (f, m2 + (n -> f))   
}

Of course this is a little bit more complicated, but a useful technique to know (note that the code above aims for clarity, not for speed).
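
To illustrate the remark about tuple keys, here is a minimal sketch of the same map-threading technique for a two-argument function. The binomial coefficient and the names `binom` and `binomM` are chosen purely for illustration, and the code assumes 0 <= k <= n.

```scala
// C(n, k) via Pascal's rule, threading an immutable Map keyed by (n, k).
// Assumes 0 <= k <= n; other inputs are not handled.
def binomM(n: Int, k: Int, m: Map[(Int, Int), BigInt]): (BigInt, Map[(Int, Int), BigInt]) =
  if (k == 0 || k == n) (BigInt(1), m)
  else m.get((n, k)) match {
    case Some(v) => (v, m) // already computed: return the map unchanged
    case None =>
      val (a, m1) = binomM(n - 1, k - 1, m)
      val (b, m2) = binomM(n - 1, k, m1) // reuse everything learned so far
      val v = a + b
      (v, m2 + ((n, k) -> v))
  }

def binom(n: Int, k: Int): BigInt = binomM(n, k, Map.empty)._1
```

As in the Fibonacci version, each call returns the updated map together with the result, so later calls see the entries added by earlier ones.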

Crumley answered 4/9, 2010 at 7:41 Comment(0)
S
3

Being a newbie in this subject, I could not fully understand any of the examples given (but thank you anyway). Respectfully, I'd like to present my own solution in case someone at the same level comes here with the same problem. I think my code should be crystal clear to anybody with just the most basic Scala knowledge.



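import scala.collection.mutable

// The cache: keys are argument tuples, values are results.
// Declared first so that it is initialized before any lookup.
val Memo = new mutable.HashMap[(DateTime, Int), Double]

// Memoized entry point: return the cached result, or compute and store it.
def MyFunction(dt: DateTime, param: Int): Double = {
  val argsTuple = (dt, param)
  if (Memo.contains(argsTuple)) Memo(argsTuple) else Memoize(dt, param, MyRawFunction(dt, param))
}

def MyRawFunction(dt: DateTime, param: Int): Double = {
  1.0 // A heavy calculation/querying here
}

// Store the result under its argument tuple and hand it back.
def Memoize(dt: DateTime, param: Int, result: Double): Double = {
  Memo += (dt, param) -> result
  result
}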


Works perfectly. I'd appreciate critique if I've missed something.
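
The check-then-store logic above is what `getOrElseUpdate` (already used in the accepted answer) does in a single call. A minimal sketch, assuming `java.time.LocalDate` in place of `DateTime` so that it is self-contained; `PriceCache`, `rawPrice` and `rawCalls` are hypothetical names:

```scala
import java.time.LocalDate
import scala.collection.mutable

object PriceCache {
  // Counts how often the expensive function actually runs (for demonstration only)
  var rawCalls = 0

  // Hypothetical stand-in for a heavy calculation or query
  def rawPrice(day: LocalDate, param: Int): Double = {
    rawCalls += 1
    day.getDayOfYear * param.toDouble
  }

  private val memo = mutable.HashMap.empty[(LocalDate, Int), Double]

  // The second argument is by-name, so rawPrice runs only on a cache miss
  def price(day: LocalDate, param: Int): Double =
    memo.getOrElseUpdate((day, param), rawPrice(day, param))
}
```

Like the original, this is not thread-safe; it only shortens the single-threaded version.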

Strain answered 5/9, 2010 at 1:15 Comment(2)
I added some comments to my solution that will hopefully clarify it for you. The advantage of the approach I've outlined is that it allows you to memoize any function (ok, there are some caveats, but many functions). Sort of like the memoize keyword you posted about in a related question. - Libbie
The one aspect that probably remains mystifying is the fixed-point combinator -- for that I encourage you to read michid's blog, drink lots of coffee, and maybe get friendly with some functional programming texts. The good news is that you only need it if you're memoizing a recursive function. - Libbie
S
1

When using a mutable map for memoization, keep in mind that it is exposed to the usual concurrency problems, e.g. a get issued before a write has completed. However, the thread-safe attempt below suggests that making the cache thread-safe in this way is of little value, if any.

The following thread-safe code creates a memoized Fibonacci function and starts four threads (named 'a' through 'd') that call it. Run the code a few times (in the REPL) and you will easily see "f(2) set" printed more than once. This means thread A has started computing f(2) while thread B, unaware of it, starts its own copy of the computation. Such duplication is pervasive while the cache is being built, because all threads see no sub-solution established yet and enter the else clause.

object ScalaMemoizationMultithread {

  // do not use case class as there is a mutable member here
  class Memo[-T, +R](f: T => R) extends (T => R) {
    // don't even know what would happen if immutable.Map used in a multithreading context
    private[this] val cache = new java.util.concurrent.ConcurrentHashMap[T, R]
    def apply(x: T): R =
      // no synchronized needed as there is no removal during memoization
      if (cache containsKey x) {
        Console.println(Thread.currentThread().getName() + ": f(" + x + ") get")
        cache.get(x)
      } else {
        val res = f(x)
        Console.println(Thread.currentThread().getName() + ": f(" + x + ") set")
        cache.putIfAbsent(x, res) // atomic
        res
      }
  }

  object Memo {
    def apply[T, R](f: T => R): T => R = new Memo(f)

    def Y[T, R](F: (T => R) => T => R): T => R = {
      lazy val yf: T => R = Memo(F(yf)(_))
      yf
    }
  }

  val fibonacci: Int => BigInt = {
    def fiboF(f: Int => BigInt)(n: Int): BigInt = {
      if (n <= 0) 1
      else if (n == 1) 1
      else f(n - 1) + f(n - 2)
    }

    Memo.Y(fiboF)
  }

  def main(args: Array[String]): Unit = {
    // start four threads; each one calls the memoized function twice
    ('a' to 'd').foreach { ch =>
      new Thread(new Runnable {
        def run(): Unit = {
          Thread.currentThread().setName("Thread " + ch)
          (1 to 2).foreach(_ => fibonacci(5))
        }
      }).start()
    }
  }
}
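
One way to avoid the duplicated work described above is to cache a lazily evaluated cell per argument instead of the result itself, so that threads racing on the same key share one computation. This is only a sketch, not part of the answer's code: `OnceMemo`, `Cell` and `OnceMemoDemo` are hypothetical names, and it assumes `scala.collection.concurrent.TrieMap` as the concurrent map.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap

final class OnceMemo[T, R](f: T => R) extends (T => R) {
  // A cell defers the computation; its lazy val is initialized by one thread only,
  // other threads asking for the same cell wait for that result.
  private final class Cell(x: T) { lazy val value: R = f(x) }

  private val cache = TrieMap.empty[T, Cell]

  def apply(x: T): R = {
    val cell = cache.get(x).getOrElse {
      val fresh = new Cell(x)
      // putIfAbsent returns the cell already stored if another thread won the race
      cache.putIfAbsent(x, fresh).getOrElse(fresh)
    }
    cell.value
  }
}

object OnceMemoDemo {
  // Calls the same slow function from four threads; returns how often it really ran.
  def countCalls(): Int = {
    val calls = new AtomicInteger(0)
    val slowSquare = new OnceMemo[Int, Int]({ n =>
      calls.incrementAndGet()
      Thread.sleep(50)
      n * n
    })
    val threads = (1 to 4).map { _ =>
      new Thread(new Runnable { def run(): Unit = { slowSquare(7); () } })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    calls.get
  }
}
```

Extra cells may still be allocated during a race, but only the stored one is ever evaluated. The trade-off is that threads now block on a cell while another thread computes it.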
Spirited answered 9/12, 2013 at 3:58 Comment(0)
S
0

In addition to Landei's answer, I'd like to point out that the bottom-up (non-memoized) way of doing DP is also possible in Scala; the core idea is to use foldLeft.

Example for computing Fibonacci numbers

  def fibo(n: Int) = (1 to n).foldLeft((0, 1)) {
    (acc, i) => (acc._2, acc._1 + acc._2)
  }._1

Example for longest increasing subsequence

// Assumes xs is non-empty (maxBy throws on an empty list).
def longestIncrSubseq[T](xs: List[T])(implicit ord: Ordering[T]): List[T] = {
  xs.foldLeft(List[(Int, List[T])]()) {
    (memo, x) =>
      if (memo.isEmpty) List((1, List(x)))
      else {
        // pair each earlier element y with the best subsequence ending at y
        val resultIfEndsAtCurr = memo.zip(xs).map {
          case ((len, seq), y) =>
            if (ord.lt(y, x)) { // current is greater than the previous end
              (len + 1, x :: seq) // reversely recorded to avoid O(n)
            } else {
              (1, List(x)) // start over
            }
        }
        memo :+ resultIfEndsAtCurr.maxBy(_._1)
      }
  }.maxBy(_._1)._2.reverse
}
Spirited answered 16/12, 2014 at 0:26 Comment(0)
