Extensible state machines in Haskell

I can define a toy state machine (with trivial input) as follows:

--------------------------------------------
-- module State where

data State = A | B Int

--------------------------------------------
-- module A where
-- import State

transitionA :: State
transitionA = B 10

--------------------------------------------
-- module B where
-- import State

transitionB :: Int -> State
transitionB i
  | i < 0     = A
  | otherwise = B (i-1)

--------------------------------------------
-- module StateMachine where
-- import State
-- import A
-- import B

transition :: State -> State
transition A     = transitionA
transition (B i) = transitionB i

If I now decide to add a new state, I have to:

  1. modify the State module to add the new state, say

       data State = A | B Int | C Double Double

  2. add a new transition function transitionC in a module C;

  3. import C in the last module, and add the C case to the pattern match.
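To make the three steps concrete, here is a sketch of the closed version after adding C, collapsed into a single file so it runs as-is. The body of transitionC is a hypothetical placeholder (the question does not say what C does), and the comments mark which edit belongs to which step.

```haskell
-- Step 1: the closed State type grows a constructor.
data State = A | B Int | C Double Double
  deriving (Eq, Show)

transitionA :: State
transitionA = B 10

transitionB :: Int -> State
transitionB i
  | i < 0     = A
  | otherwise = B (i - 1)

-- Step 2: a new transition function.
-- Hypothetical behaviour, for illustration only.
transitionC :: Double -> Double -> State
transitionC x y
  | x > y     = C (x - 1) y
  | otherwise = A

-- Step 3: the central dispatcher must learn about the new constructor.
transition :: State -> State
transition A       = transitionA
transition (B i)   = transitionB i
transition (C x y) = transitionC x y

main :: IO ()
main = do
  mapM_ print (take 4 (iterate transition A))
  mapM_ print (take 4 (iterate transition (C 2 0)))
```

Every new state touches the shared State type and the dispatcher, which is exactly the coupling the question wants to avoid.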

I'd like to set things up so that I only have to perform step 2 (write a new transition function) and everything else gets automatically taken care of.
For instance, one could try to use existential types to do something like the following:

--------------------------------------------
{-# LANGUAGE ExistentialQuantification #-}
-- module State where

class State s where
    transition :: s -> AState

data AState = forall s. State s => AState s

instance State AState where
    transition (AState s) = transition s

-------------------------------------
-- module A where
-- import State
-- import B

data A = A

instance State A where
  transition _ = AState (B 10)

-------------------------------------
-- module B where
-- import State
-- import A

data B = B Int

instance State B where
    transition (B i)
      | i < 0     = AState ( A )
      | otherwise = AState ( B (i-1) )
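The cycle is a module-level problem only: GHC accepts the existential encoding once A and B live in the same module, which of course defeats the goal of one module per state. A minimal sketch, adding a Show superclass to State (an assumption not in the question, only so that states can be printed):

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- Show superclass added (assumption) so that AState can be printed.
class Show s => State s where
  transition :: s -> AState

data AState = forall s. State s => AState s

instance Show AState where
  show (AState s) = show s

instance State AState where
  transition (AState s) = transition s

-- A and B refer to each other, which is fine within one module.
data A = A deriving Show

instance State A where
  transition _ = AState (B 10)

data B = B Int deriving Show

instance State B where
  transition (B i)
    | i < 0     = AState A
    | otherwise = AState (B (i - 1))

main :: IO ()
main = mapM_ print (take 15 (iterate transition (AState A)))
```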

This is very convenient: to add a new state, we only need to do one thing, which is to write a data type and its associated transition function in a new module, and nothing else needs to be changed. Unfortunately, this approach doesn't work because it creates cyclic dependencies, e.g. in this case A needs to refer to B and B needs to refer to A.

I also tried to look into using extensible sum types (polymorphic variants), but the same problem arises unless we declare all the possible states ahead of time in a separate module so that subsequent modules can refer to them. In other words, it can eliminate step 3, but not step 1.

Is this the kind of problem that can be tackled using (Conor McBride's version of) indexed monads? It seems we could use some kind of indexed state monad where we don't know the return state in advance, which I gather, from his answer to "What is indexed monad?", is something that MonadIx achieves.

Umont asked 29/5, 2018 at 8:10
Comment by Chitarrone: How do you intend to get around the cyclic dependencies? If A doesn't know about B (or B doesn't know about A), how can you refer to it?

Using extensible sums we can remove step 1 and reduce step 3 to "import C".

Removing both step 3 and step 1 entirely poses the problem of making the final module aware of the new transition, and I'm not sure that's possible purely using Haskell. Some kind of metaprogramming would be needed (e.g., via TH or CPP).

As an alternative (and simpler) approach, I infer the set of states as those reachable from a predetermined initial state, implying that step 2 might also include some changes to existing transition functions to make the new state reachable. I hope that is a fair assumption to make.


If we take as a constraint that the states do not need to be pre-declared, we still need some kind of alphabet to refer to these states. A convenient alphabet is given by GHC's Symbol type (type-level strings). We wrap symbols in a fresh type constructor to make things a bit more hygienic: an application can create a new namespace of states by declaring its own version of Named.

data Named (s :: Symbol)
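A small sketch of the hygiene point, using a hypothetical KeyName class (not part of the answer) that recovers the symbol with GHC.TypeLits.symbolVal. OtherNamed stands in for another application's namespace: OtherNamed "A" and Named "A" are different types, so instances written for one never apply to the other.

```haskell
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, KindSignatures,
             ScopedTypeVariables, TypeApplications #-}
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

-- The answer's namespace of state names.
data Named (s :: Symbol)

-- Hypothetical second namespace, as another application might declare it.
data OtherNamed (s :: Symbol)

-- Hypothetical helper: recover a key's name at runtime (e.g. for debugging).
class KeyName (k :: Type) where
  keyName :: String

instance KnownSymbol s => KeyName (Named s) where
  keyName = "Named " ++ symbolVal (Proxy @s)

instance KnownSymbol s => KeyName (OtherNamed s) where
  keyName = "OtherNamed " ++ symbolVal (Proxy @s)

main :: IO ()
main = do
  putStrLn (keyName @(Named "A"))
  putStrLn (keyName @(OtherNamed "A"))
```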

Every type Named s is a "name" or "key" (k) that identifies a type of state, e.g., Named "A" or Named "B". We can use a type class to associate them with

  • the type of their contents (e.g., B contains an Int);
  • the set of possible output states, each given as a pair of its name and its contents.

This type class also contains the transition function to be defined for every state.

class State k where
  type Contents k :: *
  type Outputs k :: [(*, *)]
  transition :: Contents k -> S (Outputs k)

S is an extensible sum type. For example, S '[ '(Named "A", ()), '(Named "B", Int) ] is a sum of a unit tagged by "A" and an Int tagged by "B".

data S (u :: [(*, *)]) where
  Here :: forall k a u. a -> S ('(k, a) ': u)
  There :: forall u x. S u -> S (x ': u)
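To see how the position in the list acts as the tag, here is a sketch that builds the two values of the example sum and prints them. ShowSum is a hypothetical helper written for this illustration; it prints in the same "There Here 10" style as the output further down, but the answer's own printing code lives in the gist and may differ.

```haskell
{-# LANGUAGE DataKinds, EmptyCase, FlexibleInstances, GADTs,
             KindSignatures, TypeOperators #-}
import Data.Kind (Type)
import GHC.TypeLits (Symbol)

data Named (s :: Symbol)

data S (u :: [(Type, Type)]) where
  Here  :: a -> S ('(k, a) ': u)
  There :: S u -> S (x ': u)

-- Hypothetical printer: one "There" per skipped list entry.
class ShowSum u where
  showSum :: S u -> String

instance ShowSum '[] where
  showSum s = case s of {}  -- S '[] has no values

instance (Show a, ShowSum u) => ShowSum ('(k, a) ': u) where
  showSum (Here a)  = "Here " ++ show a
  showSum (There s) = "There " ++ showSum s

type AB = '[ '(Named "A", ()), '(Named "B", Int)]

stateA, stateB :: S AB
stateA = Here ()          -- first entry: tagged "A"
stateB = There (Here 10)  -- second entry: tagged "B"

main :: IO ()
main = mapM_ (putStrLn . showSum) [stateA, stateB]
```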

We can automate the injection of a value into a sum using a smart constructor inj1 @k, indexed by the key k.

-- v is a list containing the pair (k, a)
-- instances omitted
class Inj1 k a v where
  inj1 :: a -> S v
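The answer omits the instances. One possible way to write them, sketched here under the assumption that the list v is fully known where inj1 is used, is a type-level linear search with an OVERLAPPING instance for a matching head; the linked gist may well do it differently. The a ~ b in the first instance lets the list, rather than the argument, fix the payload type, so a bare literal such as 10 is accepted. describe is a hypothetical consumer of the two-state sum.

```haskell
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, EmptyCase, FlexibleContexts,
             FlexibleInstances, GADTs, KindSignatures, MultiParamTypeClasses,
             ScopedTypeVariables, TypeApplications, TypeFamilies,
             TypeOperators #-}
import Data.Kind (Type)
import GHC.TypeLits (Symbol)

data Named (s :: Symbol)

data S (u :: [(Type, Type)]) where
  Here  :: a -> S ('(k, a) ': u)
  There :: S u -> S (x ': u)

class Inj1 (k :: Type) a v where
  inj1 :: a -> S v

-- Key found at the head of the list: inject with Here.
instance {-# OVERLAPPING #-} (a ~ b) => Inj1 k a ('(k, b) ': u) where
  inj1 = Here

-- Otherwise skip the head and keep searching the tail.
instance Inj1 k a u => Inj1 k a (x ': u) where
  inj1 = There . inj1 @k

type AB = '[ '(Named "A", ()), '(Named "B", Int)]

describe :: S AB -> String
describe (Here ())         = "A"
describe (There (Here n))  = "B " ++ show n
describe (There (There s)) = case s of {}

main :: IO ()
main = do
  putStrLn (describe (inj1 @(Named "A") ()))
  putStrLn (describe (inj1 @(Named "B") 10))
```

This sketch only resolves when the keys in v are known; if the key at the head of the list is still an unsolved type variable, GHC will not commit to either instance.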

Skipping the whole setup, let's see what using this framework looks like.

Creating a new transition means declaring an instance of State. The only dependencies are the general framework definitions. As mentioned earlier, the file doesn't need to be aware of a predetermined set of states; it declares what it needs.

Module A

-- Transitions out of A
instance State (Named "A") where

  -- There is no meaningful value contained in the A state
  type Contents (Named "A") = ()

  -- The only transition is to "B"
  type Outputs (Named "A") = '[ '(Named "B", Int)]

  transition () = inj1 @(Named "B") 10

Module B

-- transitions out of B
instance State (Named "B") where
  type Contents (Named "B") = Int
  type Outputs (Named "B") = '[ '(Named "A", ()), '(Named "B", Int)]
  transition i
    | i < 0 = inj1 @(Named "A") ()
    | otherwise = inj1 @(Named "B") (i-1)

In the main module, we still need to import all the transitions, and to pick an initial state from which the reachable states can be computed.

import A
import B

type Initial = Named "A"

-- Initial state A
initial :: Inj1 Initial () u => S u
initial = inj1 @Initial ()

Given the name of the initial state, a general function produces the complete transition function, computing the complete list of reachable states along the way.

sm :: forall initial u ...
   .  (... {- all reachable states from 'initial' are in 'u' -})
   => S u -> S u

Thus we can define and use the transition as follows:

transition' = sm @Initial  -- everything inferred (S _ -> S _)

-- Run 14 steps from the initial state.
main = do
  let steps = 14
  mapM_ print . take (steps+1) . iterate transition' $ initial

Output:

Here ()
There Here 10
There Here 9
There Here 8
There Here 7
There Here 6
There Here 5
There Here 4
There Here 3
There Here 2
There Here 1
There Here 0
There Here -1
Here ()
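The elided sm does the dispatch and computes the global sum at the type level. For the two states here, a hand-written, monomorphic stand-in makes that concrete: step pattern-matches on the global sum AB, calls the matching transition, and re-injects that state's own Outputs sum into AB through a hypothetical Widen class. This is a sketch under those assumptions (the Inj1 instances, Widen and ShowSum are reconstructions, not the gist's code); it prints the same 15 lines as above.

```haskell
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, EmptyCase, FlexibleContexts,
             FlexibleInstances, GADTs, KindSignatures, MultiParamTypeClasses,
             ScopedTypeVariables, TypeApplications, TypeFamilies,
             TypeOperators, UndecidableInstances #-}
import Data.Kind (Type)
import GHC.TypeLits (Symbol)

data Named (s :: Symbol)

data S (u :: [(Type, Type)]) where
  Here  :: a -> S ('(k, a) ': u)
  There :: S u -> S (x ': u)

-- Hypothetical Inj1 instances: linear search for the key k.
class Inj1 (k :: Type) a v where
  inj1 :: a -> S v
instance {-# OVERLAPPING #-} (a ~ b) => Inj1 k a ('(k, b) ': u) where
  inj1 = Here
instance Inj1 k a u => Inj1 k a (x ': u) where
  inj1 = There . inj1 @k

class State (k :: Type) where
  type Contents k :: Type
  type Outputs k :: [(Type, Type)]
  transition :: Contents k -> S (Outputs k)

instance State (Named "A") where
  type Contents (Named "A") = ()
  type Outputs (Named "A") = '[ '(Named "B", Int)]
  transition () = inj1 @(Named "B") 10

instance State (Named "B") where
  type Contents (Named "B") = Int
  type Outputs (Named "B") = '[ '(Named "A", ()), '(Named "B", Int)]
  transition i
    | i < 0     = inj1 @(Named "A") ()
    | otherwise = inj1 @(Named "B") (i - 1)

-- Hypothetical: move a value from a state's own Outputs sum u into the global sum v.
class Widen u v where
  widen :: S u -> S v
instance Widen '[] v where
  widen s = case s of {}
instance (Inj1 k a v, Widen u v) => Widen ('(k, a) ': u) v where
  widen (Here a)  = inj1 @k a
  widen (There s) = widen s

-- Hypothetical printer in the "There Here 10" style.
class ShowSum u where
  showSum :: S u -> String
instance ShowSum '[] where
  showSum s = case s of {}
instance (Show a, ShowSum u) => ShowSum ('(k, a) ': u) where
  showSum (Here a)  = "Here " ++ show a
  showSum (There s) = "There " ++ showSum s

-- The global sum of the states reachable from "A", written by hand.
type AB = '[ '(Named "A", ()), '(Named "B", Int)]

-- Monomorphic stand-in for sm @(Named "A"): one case per state in AB.
step :: S AB -> S AB
step (Here ())         = widen (transition @(Named "A") ())
step (There (Here i))  = widen (transition @(Named "B") i)
step (There (There s)) = case s of {}

main :: IO ()
main = mapM_ (putStrLn . showSum) (take 15 (iterate step (inj1 @(Named "A") ())))
```

Note that transition @(Named "A") () on its own is Here 10, tagged by position in A's own Outputs list; widen is what moves it to There Here 10 in the global sum.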
There Here 10

Hopefully it is apparent that the State type class provides enough information at the type level to reconstruct the complete state machine. From there it's "just" a matter of type-level programming to make that intuition a reality. I can talk a bit more about that if prompted, but for now here is a complete example:

https://gist.github.com/Lysxia/769ee0d4eaa30004aa457eb809bd2786

This example uses INCOHERENT instances for simplicity, to generate the final set of states by unification, but a more robust solution with an explicit fixpoint iteration/graph search is certainly possible.
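As a value-level analogy only (the answer works at the type level), an explicit graph search amounts to a worklist closure over the Outputs relation. The outputs map below is a hypothetical mirror of the Outputs instances, with an extra state C that nothing transitions to, so it is declared but never reached from "A".

```haskell
import qualified Data.Map as Map
import qualified Data.Set as Set

-- Hypothetical value-level mirror of the Outputs type family:
-- each state name maps to the names it can transition to.
outputs :: Map.Map String [String]
outputs = Map.fromList
  [ ("A", ["B"])
  , ("B", ["A", "B"])
  , ("C", ["A"])  -- declared, but unreachable from "A"
  ]

-- Worklist search: the least set containing the initial state
-- and closed under 'outputs'.
reachable :: String -> Set.Set String
reachable initial = go Set.empty [initial]
  where
    go seen [] = seen
    go seen (k : ks)
      | k `Set.member` seen = go seen ks
      | otherwise = go (Set.insert k seen) (Map.findWithDefault [] k outputs ++ ks)

main :: IO ()
main = print (Set.toList (reachable "A"))
```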

Clubfoot answered 29/5, 2018 at 15:28
Comment by Umont: Thank you for taking the time to write this fantastically thorough answer! I really like how your solution infers the most general type for the collection of states reachable from a given initial state, that's very neat. (Well, of course it's only the most general given the static type information provided, but that's all one can ask for.) I'll try this approach with a flat representation for the extensible sum (instead of the linked representation you provided).
