State transformations with a shapeless State monad
The Scalaz State monad's modify has the following signature:

def modify[S](f: S => S): State[S, Unit]

This only allows the state to be replaced by a state of the same type. That does not work well when the state includes a shapeless value, such as a Record, whose type changes as new fields are added. In that case what we need is:

def modify[S, T](f: S => T): State[T, Unit]

What is a good way to adapt Scalaz's State monad to use shapeless state so that one can use Records as opposed to, say, the dreaded Map[String, Any]?

Example:

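import scalaz._, Scalaz._
import shapeless._, shapeless.syntax.singleton._

case class S[L <: HList](total: Int, scratch: L)

// Does not compile with Scalaz's modify: prepending a field turns the
// state type S[L] into S[FieldType[...] :: L], and modify requires S => S.
def contrivedAdd[L <: HList](n: Int): State[S[L], Int] =
  for {
    a <- init
    _ <- modify(s => S(s.total + n, ('latestAddend ->> n) :: s.scratch))
    r <- get
  } yield r.total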

Update:

The full code for Travis's answer is here.

Boehmenism answered 19/1, 2016 at 7:30 Comment(3)
How about def modify[S, T](f: S => T): State[T, Unit] = State((s: S) => (f(s), ())) ? – Choragus
@Choragus are you suggesting extending State and overloading modify? – Boehmenism
No, just write this function anywhere in your code. There's no need to extend or overload anything. – Choragus
State is a type alias for the more general IndexedStateT, which is specifically designed to represent functions that change the state type as state computations:

type StateT[F[_], S, A] = IndexedStateT[F, S, S, A]
type State[S, A] = StateT[Id, S, A]
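To make the role of the extra type parameter concrete, here is a minimal, stdlib-only sketch of an indexed state computation with the effect fixed to the identity, which is roughly what IndexedState means. It is a hypothetical hand-rolled IndexedState, not the Scalaz class, and assumes a Scala 2.13 or Scala 3 script. Its flatMap lets each step start from whatever state type the previous step left behind, and State falls out as the special case where both state types are the same.

```scala
// Hypothetical, stdlib-only stand-in for IndexedStateT[Id, S1, S2, A]:
// a computation that reads a state of type S1, leaves a state of type S2,
// and produces a result of type A. Not the Scalaz implementation.
final case class IndexedState[S1, S2, A](run: S1 => (S2, A)) {
  def map[B](f: A => B): IndexedState[S1, S2, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); (s2, f(a)) })

  // The state type this step leaves behind (S2) is the one the next step reads.
  def flatMap[S3, B](f: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); f(a).run(s2) })
}

// The ordinary State monad is the special case where the state type never changes.
type State[S, A] = IndexedState[S, S, A]

def get[S]: State[S, S] = IndexedState((s: S) => (s, s))
def modify[S](f: S => S): State[S, Unit] = IndexedState((s: S) => (f(s), ()))
def transform[S, T](f: S => T): IndexedState[S, T, Unit] =
  IndexedState((s: S) => (f(s), ()))

// Starts with an Int state and ends with a String state.
val incrementThenShow: IndexedState[Int, String, Int] = for {
  _ <- modify[Int](_ + 1)
  n <- get[Int]
  _ <- transform[Int, String](i => s"n=$i")
} yield n
```

Running incrementThenShow.run(41) starts from the Int state 41 and finishes with the String state "n=42" and the result 42. modify alone could never produce that final state, because its signature pins both state types to the same S.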

While it's not possible to write your modify[S, T] using State, it is possible with IndexedState (which is another type alias for IndexedStateT that fixes the effect type to Id):

import scalaz._, Scalaz._

def transform[S, T](f: S => T): IndexedState[S, T, Unit] =
  IndexedState(s => (f(s), ()))

You can even use this in for-comprehensions (which has always seemed a little odd to me, since the monadic type changes between operations, but it works):

val s = for {
  a <- init[Int]
  _ <- transform[Int, Double](_.toDouble)
  _ <- transform[Double, String](_.toString)
  r <- get
} yield r * a

And then:

scala> s(5)
res5: scalaz.Id.Id[(String, String)] = (5.0,5.05.05.05.05.0)
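To trace where (5.0,5.05.05.05.05.0) comes from, here is the same computation written against a hypothetical hand-rolled IndexedState rather than Scalaz. It is a sketch assuming a Scala 2.13 or Scala 3 script, with init, get and transform re-defined locally. The comments track how the state type moves from Int to Double to String while a stays bound to the initial 5.

```scala
// Hypothetical hand-rolled indexed state (identity effect), not the Scalaz class.
final case class IndexedState[S1, S2, A](run: S1 => (S2, A)) {
  def map[B](f: A => B): IndexedState[S1, S2, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); (s2, f(a)) })
  def flatMap[S3, B](f: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); f(a).run(s2) })
}

// init and get both expose the current state as the result without changing it.
def init[S]: IndexedState[S, S, S] = IndexedState((s: S) => (s, s))
def get[S]: IndexedState[S, S, S] = IndexedState((s: S) => (s, s))
def transform[S, T](f: S => T): IndexedState[S, T, Unit] =
  IndexedState((s: S) => (f(s), ()))

val s = for {
  a <- init[Int]                              // state 5 (Int), a = 5
  _ <- transform[Int, Double](_.toDouble)     // state 5.0 (Double)
  _ <- transform[Double, String](_.toString)  // state "5.0" (String)
  r <- get[String]                            // r = "5.0"
} yield r * a                                 // "5.0" repeated a = 5 times
```

The first element of the pair is the final state ("5.0") and the second is the yielded result, which is why the REPL shows a (String, String) pair.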

In your case you might write something like this:

import shapeless._, shapeless.labelled.{ FieldType, field }

case class S[L <: HList](total: Int, scratch: L)

def addField[K <: Symbol, A, L <: HList](k: Witness.Aux[K], a: A)(
  f: Int => Int
): IndexedState[S[L], S[FieldType[K, A] :: L], Unit] =
  IndexedState(s => (S(f(s.total), field[K](a) :: s.scratch), ()))

And then:

def contrivedAdd[L <: HList](n: Int) = for {
  a <- init[S[L]]
  _ <- addField('latestAdded, n)(_ + n)
  r <- get
} yield r.total

(This may not be the best way of factoring out the pieces of the update operation, but it shows how the basic idea works.)
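If shapeless is not at hand, the shape of this solution can still be checked with a stdlib-only sketch. It assumes a hypothetical hand-rolled IndexedState instead of Scalaz's and swaps the HList for nested pairs, so (A, L) plays the role of FieldType[K, A] :: L without the key. The point is only that each addField step changes the state type, which an indexed state can express and State cannot; the tuple version is not a record.

```scala
// Hypothetical hand-rolled indexed state (identity effect), not the Scalaz class.
final case class IndexedState[S1, S2, A](run: S1 => (S2, A)) {
  def map[B](f: A => B): IndexedState[S1, S2, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); (s2, f(a)) })
  def flatMap[S3, B](f: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); f(a).run(s2) })
}

def init[St]: IndexedState[St, St, St] = IndexedState((s: St) => (s, s))
def get[St]: IndexedState[St, St, St] = IndexedState((s: St) => (s, s))

// Nested pairs stand in for the HList: (A, L) plays the role of A :: L,
// but without the Witness key that makes a shapeless record a record.
final case class S[L](total: Int, scratch: L)

// Hypothetical tuple-based analogue of addField: prepends `a` to the scratch
// data and updates the total, so the state type goes from S[L] to S[(A, L)].
def addField[A, L](a: A)(f: Int => Int): IndexedState[S[L], S[(A, L)], Unit] =
  IndexedState((s: S[L]) => (S(f(s.total), (a, s.scratch)), ()))

def contrivedAdd[L](n: Int): IndexedState[S[L], S[(Int, L)], Int] = for {
  _ <- init[S[L]]
  _ <- addField[Int, L](n)(_ + n)
  r <- get[S[(Int, L)]]
} yield r.total
```

For example, contrivedAdd[Unit](3).run(S(10, ())) should end in the state S(13, (3, ())) with result 13; the type of the final state records that one Int was prepended.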

It's also worth noting that if you don't care about representing the state transformation as a state computation, you can just use imap on any old State:

init[S[HNil]].imap(s =>
  S(1, field[Witness.`'latestAdded`.T](1) :: s.scratch)
)

This doesn't allow you to use these operations compositionally in the same way, but it may be all you need in some situations.
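In a hypothetical hand-rolled IndexedState (again not the Scalaz class; a stdlib-only sketch), imap amounts to applying a function to the final state after the computation runs. The sketch checks that this gives the same outcome as sequencing an explicit transform step; the difference is that imap is applied to a finished computation rather than being a step you can place inside a for-comprehension.

```scala
// Hypothetical hand-rolled indexed state (identity effect), not the Scalaz class.
final case class IndexedState[S1, S2, A](run: S1 => (S2, A)) {
  def map[B](f: A => B): IndexedState[S1, S2, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); (s2, f(a)) })
  def flatMap[S3, B](f: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); f(a).run(s2) })
  // Rewrites the final state after the computation runs; the result A is untouched.
  def imap[S3](f: S2 => S3): IndexedState[S1, S3, A] =
    IndexedState((s1: S1) => { val (s2, a) = run(s1); (f(s2), a) })
}

def init[S]: IndexedState[S, S, S] = IndexedState((s: S) => (s, s))
def transform[S, T](f: S => T): IndexedState[S, T, Unit] =
  IndexedState((s: S) => (f(s), ()))

// Same final state and result, built two ways.
val viaImap: IndexedState[Int, String, Int] = init[Int].imap(_.toString)
val viaTransform: IndexedState[Int, String, Int] =
  init[Int].flatMap(a => transform[Int, String](_.toString).map(_ => a))
```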

Callida answered 19/1, 2016 at 15:43 Comment(1)
Your answer, in addition to being elegant, holds a neat little gem inside it, which is the piece of addField that extends a record using a key and a value. The combination of the witness and field[K](a) is not obvious without reading the shapeless source. – Boehmenism
