Equality for GADTs which erase type parameter

I cannot implement an Eq instance for the following type-safe expression DSL, which is implemented with GADTs:

data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

Expressions can be either of type Bool or Int. The literal constructors Bool and Num have the corresponding types. Only Int expressions can be added (constructor Plus). The condition of an If expression must have type Bool, while both branches must have the same type. There is also an equality expression Equal whose operands must have the same type; the equality expression itself has type Bool.

I have no problems implementing the interpreter eval for this DSL. It compiles and works like a charm:

eval :: Expr a -> a
eval (Num x) = x
eval (Bool x) = x
eval (Plus x y) = eval x + eval y
eval (If c t e) = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y

However, I struggle to implement an instance of Eq for the DSL. I tried the simple syntactic equality:

instance Eq a => Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = x == x' && y == y'
  _ == _ = False

It does not typecheck with GHC 8.6.5; the error is:

[1 of 1] Compiling Main             ( Main.hs, Main.o )

Main.hs:17:35: error:
    • Could not deduce: a2 ~ a1
      from the context: (a ~ Bool, Eq a1)
        bound by a pattern with constructor:
                   Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
                 in an equation for ‘==’
        at Main.hs:17:3-11
      ‘a2’ is a rigid type variable bound by
        a pattern with constructor:
          Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
        in an equation for ‘==’
        at Main.hs:17:16-26
      ‘a1’ is a rigid type variable bound by
        a pattern with constructor:
          Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
        in an equation for ‘==’
        at Main.hs:17:3-11
      Expected type: Expr a1
        Actual type: Expr a2
    • In the second argument of ‘(==)’, namely ‘x'’
      In the first argument of ‘(&&)’, namely ‘x == x'’
      In the expression: x == x' && y == y'
    • Relevant bindings include
        y' :: Expr a2 (bound at Main.hs:17:25)
        x' :: Expr a2 (bound at Main.hs:17:22)
        y :: Expr a1 (bound at Main.hs:17:11)
        x :: Expr a1 (bound at Main.hs:17:9)
   |
17 |   Equal x y == Equal x' y' = x == x' && y == y'
   |  

I believe the reason is that the constructor Equal "forgets" the type parameter a of its subexpressions, so when two Equal expressions are compared there is no way for the typechecker to ensure that their subexpressions (x and x' in the error above) have the same type.

I tried adding one more type parameter to Expr a to keep track of the type of subexpressions:

data Expr a b where
  Num :: Int -> Expr Int b
  Bool :: Bool -> Expr Bool b
  Plus :: Expr Int b -> Expr Int b -> Expr Int b
  If :: Expr Bool b -> Expr a b -> Expr a b -> Expr a b
  Equal :: Eq a => Expr a a -> Expr a a -> Expr Bool a

instance Eq a => Eq (Expr a b) where
  -- same implementation

eval :: Expr a b -> a
  -- same implementation

This approach does not seem to scale once more constructors with subexpressions of different types are added.

All this makes me think that I am using GADTs incorrectly to implement this kind of DSL. Is there a way to implement Eq for this type? If not, what is the idiomatic way to express this kind of type constraint on expressions?

Complete code:

{-# LANGUAGE GADTs #-}
 
module Main where

data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

instance Eq a => Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = x == x' && y == y'
  _ == _ = False

eval :: Expr a -> a
eval (Num x) = x
eval (Bool x) = x
eval (Plus x y) = eval x + eval y
eval (If c t e) = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y

main :: IO ()
main = do
  let expr1 = If (Equal (Num 13) (Num 42)) (Bool True) (Bool False)
  let expr2 = If (Equal (Num 13) (Num 42)) (Num 42) (Num 777)
  print (eval expr1)
  print (eval expr2)
  print (expr1 == expr1)

Mcdougall answered 24/6, 2021 at 6:31

Your issue is that in

Equal x y == Equal x' y' = ...

it is possible that x and x' have different types. For example, Equal (Bool True) (Bool True) == Equal (Num 42) (Num 42) typechecks, but we can't then simply compare Bool True == Num 42, as we might try to do in the Eq instance.
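A minimal sketch of this point, using the original GADT. The names eqOnBools, eqOnInts, sameOuterType and hiddenOperandType are hypothetical, chosen only for illustration. Two Equal values with different hidden operand types share the outer type Expr Bool, so an Eq (Expr Bool) instance has to be able to compare them, and only pattern matching on an operand reveals which type is hidden inside:

```haskell
{-# LANGUAGE GADTs #-}
module Main where

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

-- Hypothetical examples: same outer type, different hidden operand types.
eqOnBools :: Expr Bool
eqOnBools = Equal (Bool True) (Bool True)   -- operands are Expr Bool

eqOnInts :: Expr Bool
eqOnInts = Equal (Num 42) (Num 42)          -- operands are Expr Int

-- Both have type Expr Bool, so they fit in one list, and an
-- Eq (Expr Bool) instance must be able to compare them with each other.
sameOuterType :: [Expr Bool]
sameOuterType = [eqOnBools, eqOnInts]

-- The hidden operand type is only recovered by matching on an operand.
hiddenOperandType :: Expr Bool -> String
hiddenOperandType (Equal (Num _) _)  = "Int"
hiddenOperandType (Equal (Bool _) _) = "Bool"
hiddenOperandType _                  = "unknown"

main :: IO ()
main = mapM_ (putStrLn . hiddenOperandType) sameOuterType
```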

Here are a few alternative solutions. The last one (generalizing == to eqExpr) seems the simplest to me, but the others are interesting as well.

Use a singleton and compute types

We start from your original type

{-# LANGUAGE GADTs #-}
module Main where

data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

and define a singleton GADT to represent the types you have

data Ty a where
  TyInt  :: Ty Int
  TyBool :: Ty Bool

We then prove that the type index of any expression can only be Int or Bool, and show how to compute it from the expression:

tyExpr :: Expr a -> Ty a
tyExpr (Num _)     = TyInt
tyExpr (Bool _)    = TyBool
tyExpr (Plus _ _)  = TyInt
tyExpr (If _ t _)  = tyExpr t
tyExpr (Equal _ _) = TyBool

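We can now exploit that and define the Eq instance. Pattern matching on both singletons refines the hidden types of x and x' to the same concrete type, so the recursive comparisons typecheck: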

instance Eq (Expr a) where
  Num x     == Num y       = x == y
  Bool x    == Bool y      = x == y
  Plus x y  == Plus x' y'  = x == x' && y == y'
  If c t e  == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = case (tyExpr x, tyExpr x') of
     (TyInt,  TyInt ) -> x == x' && y == y'
     (TyBool, TyBool) -> x == x' && y == y'
     _                -> False
  _ == _ = False
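The case analysis above needs one branch per object type wherever hidden types are compared. One possible refinement, sketched here under the assumption that base's Data.Type.Equality is available, is to compare the singletons once in a helper that returns a proof of type equality (eqTy is a hypothetical name). Matching on Refl tells GHC that the two hidden types are equal:

```haskell
{-# LANGUAGE GADTs, TypeOperators #-}
module Main where

import Data.Type.Equality ((:~:)(Refl))

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

data Ty a where
  TyInt  :: Ty Int
  TyBool :: Ty Bool

tyExpr :: Expr a -> Ty a
tyExpr (Num _)     = TyInt
tyExpr (Bool _)    = TyBool
tyExpr (Plus _ _)  = TyInt
tyExpr (If _ t _)  = tyExpr t
tyExpr (Equal _ _) = TyBool

-- Hypothetical helper: compare two singletons once; Just Refl is
-- evidence that a and b are the same type.
eqTy :: Ty a -> Ty b -> Maybe (a :~: b)
eqTy TyInt  TyInt  = Just Refl
eqTy TyBool TyBool = Just Refl
eqTy _      _      = Nothing

instance Eq (Expr a) where
  Num x     == Num y       = x == y
  Bool x    == Bool y      = x == y
  Plus x y  == Plus x' y'  = x == x' && y == y'
  If c t e  == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = case eqTy (tyExpr x) (tyExpr x') of
    Just Refl -> x == x' && y == y'   -- hidden types are now known equal
    Nothing   -> False
  _ == _ = False

expr1 :: Expr Bool
expr1 = If (Equal (Num 13) (Num 42)) (Bool True) (Bool False)

main :: IO ()
main = do
  print (expr1 == expr1)                                          -- True
  print (Equal (Num 1) (Num 1) == Equal (Bool True) (Bool True))  -- False
```

Adding a new object type then means one new Ty constructor and one new eqTy clause, while the Eq instance stays unchanged.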

Use Typeable

We slightly modify the original GADT:

import Data.Typeable
  
data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: (Typeable a, Eq a) => Expr a -> Expr a -> Expr Bool

We can then try to cast the values to the right types: if the cast fails, the two Equal expressions compare operands of different types, so we can return False.

instance Eq (Expr a) where
  Num x     == Num y       = x == y
  Bool x    == Bool y      = x == y
  Plus x y  == Plus x' y'  = x == x' && y == y'
  If c t e  == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = case cast (x,y) of
     Just (x2, y2) -> x2 == x' && y2 == y'
     Nothing       -> False
  _ == _ = False
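A variant of the same idea uses eqT from Data.Typeable instead of cast: rather than converting the values, it asks for evidence that the two hidden types are equal, and matching on Refl then lets the recursive comparisons typecheck. This is a sketch; the helper name sameIndex is hypothetical, while eqT and (:~:) come from base:

```haskell
{-# LANGUAGE GADTs, TypeOperators #-}
module Main where

import Data.Type.Equality ((:~:)(Refl))
import Data.Typeable (Typeable, eqT)

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: (Typeable a, Eq a) => Expr a -> Expr a -> Expr Bool

-- Hypothetical helper: the arguments are never inspected; they only
-- fix a and b so that eqT compares the right pair of types at runtime.
sameIndex :: (Typeable a, Typeable b) => Expr a -> Expr b -> Maybe (a :~: b)
sameIndex _ _ = eqT

instance Eq (Expr a) where
  Num x     == Num y       = x == y
  Bool x    == Bool y      = x == y
  Plus x y  == Plus x' y'  = x == x' && y == y'
  If c t e  == If c' t' e' = c == c' && t == t' && e == e'
  -- The Typeable constraints come from the two Equal patterns.
  Equal x y == Equal x' y' = case sameIndex x x' of
    Just Refl -> x == x' && y == y'
    Nothing   -> False
  _ == _ = False

main :: IO ()
main = do
  print (Equal (Num 1) (Num 1) == Equal (Num 1) (Num 1))          -- True
  print (Equal (Num 1) (Num 1) == Equal (Bool True) (Bool True))  -- False
```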

Generalize to heterogeneous equality

We can use the original GADT:

data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

and write a heterogeneous equality test that works even if the two expressions do not have the same type:

eqExpr :: Expr a -> Expr b -> Bool
eqExpr (Num x)     (Num y)       = x == y
eqExpr (Bool x)    (Bool y)      = x == y
eqExpr (Plus x y)  (Plus x' y')  = eqExpr x x' && eqExpr y y'
eqExpr (If c t e)  (If c' t' e') = eqExpr c c' && eqExpr t t' && eqExpr e e'
eqExpr (Equal x y) (Equal x' y') = eqExpr x x' && eqExpr y y'
eqExpr _           _             = False

The Eq instance is then a special case.

instance Eq (Expr a) where
  (==) = eqExpr
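Putting the pieces of this approach together, a runnable sketch might look as follows. The main function reuses the question's expr1 and also shows that eqExpr can compare expressions whose outer types differ, which == cannot:

```haskell
{-# LANGUAGE GADTs #-}
module Main where

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

-- Heterogeneous structural equality: the two type indices may differ.
eqExpr :: Expr a -> Expr b -> Bool
eqExpr (Num x)     (Num y)       = x == y
eqExpr (Bool x)    (Bool y)      = x == y
eqExpr (Plus x y)  (Plus x' y')  = eqExpr x x' && eqExpr y y'
eqExpr (If c t e)  (If c' t' e') = eqExpr c c' && eqExpr t t' && eqExpr e e'
eqExpr (Equal x y) (Equal x' y') = eqExpr x x' && eqExpr y y'
eqExpr _           _             = False

instance Eq (Expr a) where
  (==) = eqExpr

expr1 :: Expr Bool
expr1 = If (Equal (Num 13) (Num 42)) (Bool True) (Bool False)

main :: IO ()
main = do
  print (expr1 == expr1)                                                    -- True
  -- Hidden operand types differ, so the catch-all clause is reached:
  print (eqExpr (Equal (Bool True) (Bool True)) (Equal (Num 42) (Num 42)))  -- False
  -- Outer types differ (Expr Int vs Expr Bool); (==) would reject this:
  print (eqExpr (Num 1) (Bool True))                                        -- False
```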

A final note

As pointed out by Joseph Sible in the comments, in all these approaches we do not need the Eq a context in the instances. We can simply remove it:

instance {- Eq a => -} Eq (Expr a) where
   ...

Further, in principle we do not even need the Eq a constraint in the definition of Equal, so we could simplify our GADT:

data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Expr a -> Expr a -> Expr Bool

However, if we do that, the definition of eval :: Expr a -> a becomes more complex in the Equal case, where we would probably need something like tyExpr to recover the operand type so that we can use ==.
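A minimal sketch of that more complex eval, assuming the simplified GADT above (Equal without Eq a) together with the Ty singleton and tyExpr from the first approach. Matching on tyExpr x tells GHC the concrete operand type, and each concrete type has an Eq instance:

```haskell
{-# LANGUAGE GADTs #-}
module Main where

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Expr a -> Expr a -> Expr Bool   -- no Eq a constraint here

data Ty a where
  TyInt  :: Ty Int
  TyBool :: Ty Bool

tyExpr :: Expr a -> Ty a
tyExpr (Num _)     = TyInt
tyExpr (Bool _)    = TyBool
tyExpr (Plus _ _)  = TyInt
tyExpr (If _ t _)  = tyExpr t
tyExpr (Equal _ _) = TyBool

eval :: Expr a -> a
eval (Num x)     = x
eval (Bool x)    = x
eval (Plus x y)  = eval x + eval y
eval (If c t e)  = if eval c then eval t else eval e
-- Without Eq a from the constructor, recover the operand type first;
-- in each branch the type is concrete, so (==) is available.
eval (Equal x y) = case tyExpr x of
  TyInt  -> eval x == eval y
  TyBool -> eval x == eval y

main :: IO ()
main = do
  print (eval (If (Equal (Num 13) (Num 42)) (Num 42) (Num 777)))  -- 777
  print (eval (Equal (Bool True) (Bool True)))                     -- True
```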

Brainstorm answered 24/6, 2021 at 7:30

Comments:

Cooperation: I was going to post an answer saying you could implement Equal x y == Equal x' y' by doing case analysis on x, y, x', and y', since that gets you knowledge of what the type parameters are. But implementing that as a general heterogeneous equality function rather than burying the same code in one case of the Eq instance is much better than my idea.

Haik: This answer should be mandatory reading for anyone who says type checkers just get in their way. An excellent pragmatic example of using the type system to prove properties about your data types.

Hydroxide: Do you need instance Eq a => Eq (Expr a) where? Can't it just be instance Eq (Expr a) where?

Brainstorm: @JosephSible-ReinstateMonica Right! We indeed do not need that. We also do not need the Eq a in the GADT.
