-- Idris source file (300 lines, 7.7 KiB)
import Data.Fin
|
||
import Data.Vect
|
||
import Control.Monad.State
|
||
import Control.Monad.Reader
|
||
import Control.Monad.Either
|
||
|
||
%default total
|
||
|
||
{-- Basic data definitions:
  * types;
  * length-indexed contexts and lookup;
  * well-scoped terms;
  * constraint sets; and
  * solutions.
--}
|
||
|
||
-- Object-language types:
--   TyId name   a named type identifier (also used for generated variables)
--   TyArr a b   the function type from a to b
--   TyBool      booleans
--   TyNat       natural numbers
data Ty
  = TyId String
  | TyArr Ty Ty
  | TyBool
  | TyNat
|
||
|
||
-- String literals used where a Ty is expected denote the type identifier
-- of that name.
FromString Ty where
  fromString name = TyId name
|
||
|
||
-- A typing context holding exactly `size` types, one per variable in scope.
-- Compile-time only (quantity 0).
0 Ctxt : Nat -> Type
Ctxt size = Vect size Ty
|
||
|
||
-- Well-scoped terms: `Tm n` is a term with at most n free variables,
-- each referred to by a `Fin n` index.
data Tm : Nat -> Type where
  -- Variable, by index into the enclosing scope.
  Var : Fin n -> Tm n
  -- Abstraction with an annotated argument type; the body sees one more variable.
  Abs : Ty -> Tm (S n) -> Tm n
  -- Application of a function term to an argument term.
  App : Tm n -> Tm n -> Tm n
  -- Boolean constants.
  True : Tm n
  False : Tm n
  -- Conditional: condition, then-branch, else-branch.
  If : Tm n -> Tm n -> Tm n -> Tm n
  -- Natural-number constant and operations.
  Zero : Tm n
  Succ : Tm n -> Tm n
  Pred : Tm n -> Tm n
  IsZero : Tm n -> Tm n
|
||
|
||
-- A constraint set: a list of pairs of types that must be made equal.
0 Constrs : Type
Constrs = List (Pair Ty Ty)
|
||
|
||
-- Result of constraint typing: a type together with the constraints
-- under which that type is valid.
0 Solution : Type
Solution = Pair Ty Constrs
|
||
|
||
{-- Raw, unscoped terms to scoped terms --}
|
||
|
||
-- `Then` is used infix at the lowest of this file's custom precedence levels.
private infix 1 `Then`

-- Raw surface terms: variables are referred to by name and nothing is
-- checked for scope yet.
data TmRaw : Type where
  -- Named variable.
  ι : String -> TmRaw
  -- Abstraction: bound name, annotated type, body.
  λ : String -> Ty -> TmRaw -> TmRaw
  -- Application.
  (#) : TmRaw -> TmRaw -> TmRaw
  -- Boolean constants true and false.
  T : TmRaw
  F : TmRaw
  -- Conditional: condition, then the pair (then-branch, else-branch).
  Then : TmRaw -> (TmRaw, TmRaw) -> TmRaw
  -- Zero, successor, predecessor and zero test.
  ZZ : TmRaw
  SS : TmRaw -> TmRaw
  PP : TmRaw -> TmRaw
  IZ : TmRaw -> TmRaw
|
||
|
||
-- String literals used where a raw term is expected denote the variable
-- of that name.
FromString TmRaw where
  fromString name = ι name
|
||
|
||
-- Purely cosmetic: returns its argument unchanged, so a conditional can be
-- written as Iff c `Then` t `Else` e.
Iff : TmRaw -> TmRaw
Iff cond = cond
|
||
|
||
-- `Else` binds tighter than `Then` (level 2 vs. 1), so
-- c `Then` t `Else` e groups as c `Then` (t `Else` e).
private infix 2 `Else`

-- Pairs up the two branches of a conditional for consumption by `Then`.
Else : TmRaw -> TmRaw -> (TmRaw, TmRaw)
Else thenBranch elseBranch = (thenBranch, elseBranch)
|
||
|
||
-- Resolve a raw, name-based term into a well-scoped term.
--
-- The vector lists the names currently in scope, innermost binder first;
-- a variable resolves to the index of the first matching name, so inner
-- binders shadow outer ones. Returns Left with a message when a variable
-- name is not found in the vector.
scoped : forall n. Vect n String -> TmRaw -> Either String (Tm n)
scoped names (ι x) =
  case findIndex (== x) names of
    Nothing => Left "The variable \{x} is not in scope"
    Just ix => Right (Var ix)
-- The bound name is pushed to the front of the scope for the body.
scoped names (λ x ty body) = map (Abs ty) (scoped (x :: names) body)
scoped names (f # arg) = [| App (scoped names f) (scoped names arg) |]
scoped names T = Right True
scoped names F = Right False
scoped names (cond `Then` (thn, els)) =
  [| If (scoped names cond) (scoped names thn) (scoped names els) |]
scoped names ZZ = Right Zero
scoped names (SS m) = map Succ (scoped names m)
scoped names (PP m) = map Pred (scoped names m)
scoped names (IZ m) = map IsZero (scoped names m)
|
||
|
||
{-- Some instances for the above, Interpolation acting as pretty printer --}
|
||
|
||
-- Structural equality on types: identifiers compare by name, arrows
-- compare componentwise, and distinct constructors are never equal.
Eq Ty where
  TyBool == TyBool = True
  TyNat == TyNat = True
  TyId l == TyId r = l == r
  TyArr dom1 cod1 == TyArr dom2 cod2 = dom1 == dom2 && cod1 == cod2
  _ == _ = False
|
||
|
||
-- True for terms with no subterms that the pretty printer special-cases
-- (variables, the boolean constants and zero); the printer uses this to
-- decide whether a subterm needs parentheses.
isTerminal : forall n. Tm n -> Bool
isTerminal tm =
  case tm of
    Var _ => True
    True => True
    False => True
    Zero => True
    _ => False
|
||
|
||
-- Pretty printer for types.
--
-- Arrows are printed right-associatively: an arrow type in argument
-- position is parenthesised, anything else is printed bare. Both arrow
-- clauses use the same "→" glyph (the parenthesised clause previously
-- printed an ASCII "->", giving mixed output such as "(A → B) -> C").
Interpolation Ty where
  interpolate (TyId x) = x
  interpolate (TyArr ta@(TyArr _ _) tb) = "(\{ta}) → \{tb}"
  interpolate (TyArr ta tb) = "\{ta} → \{tb}"
  interpolate TyBool = "Bool"
  interpolate TyNat = "Nat"
|
||
|
||
-- Pretty printer for well-scoped terms.
--
-- Variables print as "v" followed by their index, binders print as "λ_"
-- since names are not kept, and non-terminal subterms are parenthesised
-- where needed.
Interpolation (Tm n) where
  interpolate (Var ix) = "v" ++ show ix
  interpolate (Abs ty body) = "λ_: \{ty}. \{body}"
  -- Application is left-nested, so a function that is itself an
  -- application is printed without parentheses.
  interpolate (App f x) =
    (case f of
       App _ _ => "\{f}"
       _ => if isTerminal f then "\{f}" else "(\{f})")
      ++ " " ++ (if isTerminal x then "\{x}" else "(\{x})")
  interpolate True = "true"
  interpolate False = "false"
  interpolate (If cond thn els) = "if \{cond} then \{thn} else \{els}"
  interpolate Zero = "0"
  -- Chains of Succ/Pred are printed flat: "0 + 1 + 1 - 1".
  interpolate (Succ m) =
    case m of
      Succ _ => "\{m} + 1"
      Pred _ => "\{m} + 1"
      _ => if isTerminal m then "\{m} + 1" else "(\{m}) + 1"
  interpolate (Pred m) =
    case m of
      Succ _ => "\{m} - 1"
      Pred _ => "\{m} - 1"
      _ => if isTerminal m then "\{m} - 1" else "(\{m}) - 1"
  interpolate (IsZero m) =
    if isTerminal m then "zero? \{m}" else "zero? (\{m})"
|
||
|
||
-- Pretty printer for contexts: the empty context is "·"; otherwise the
-- entries are comma-separated with the head of the vector printed last.
Interpolation (Ctxt n) where
  interpolate [] = "·"
  interpolate [ty] = "\{ty}"
  interpolate (ty :: rest) = interpolate rest ++ ", " ++ interpolate ty
|
||
|
||
-- Pretty printer for constraint sets: "lhs = rhs" entries separated by
-- commas, and the empty string for no constraints.
Interpolation Constrs where
  interpolate [] = ""
  interpolate [(l, r)] = "\{l} = \{r}"
  interpolate ((l, r) :: rest) = "\{l} = \{r}, " ++ interpolate rest
|
||
|
||
{-- Constraint-based type checking monad with:
  * state monad for generating fresh type variables; and
  * reader monad for extending type context.
--}
|
||
|
||
-- Constraint-typing monad for terms with `size` variables in scope:
-- reads a context of that size and threads a Nat counter as state.
0 CTM : Nat -> Type -> Type
CTM size = ReaderT (Ctxt size) (State Nat)
|
||
|
||
-- Run a computation that expects one more variable in scope, by pushing
-- the given type onto the front of the current context.
-- Reader.local doesn't work since extending the context changes its type
withTy : forall a, n. Ty -> CTM (S n) a -> CTM n a
withTy ty inner = MkReaderT $ \gamma => runReaderT (ty :: gamma) inner
|
||
|
||
-- Produce a fresh type variable named "?X_<k>", where k is the current
-- counter value, and increment the counter.
nextVar : forall n. CTM n Ty
nextVar = do
  counter <- get
  let fresh = TyId ("?X_" ++ show counter)
  put (S counter)
  pure fresh
|
||
|
||
{-- Constraint-based type checking algorithm --}
|
||
|
||
-- Constraint typing: compute a type for the term together with the
-- equality constraints that must hold for that type to be valid.
-- Nothing is solved here; constraints are only collected.
ctype : forall n. Tm n -> CTM n Solution
-- A variable has whatever type the context records at its index.
ctype (Var ix) = do
  gamma <- ask
  pure (index ix gamma, [])
-- The body is typed with the annotated argument type pushed on the context.
ctype (Abs argTy body) = do
  (bodyTy, bodyCs) <- withTy argTy (ctype body)
  pure (TyArr argTy bodyTy, bodyCs)
-- The function's type must be an arrow from the argument's type to a
-- fresh result variable.
ctype (App f x) = do
  (fTy, fCs) <- ctype f
  (xTy, xCs) <- ctype x
  res <- nextVar
  let here = (fTy, TyArr xTy res)
  pure (res, here :: (fCs ++ xCs))
ctype True = pure (TyBool, [])
ctype False = pure (TyBool, [])
-- The condition must be Bool and both branches must agree.
ctype (If cond thn els) = do
  (condTy, condCs) <- ctype cond
  (thnTy, thnCs) <- ctype thn
  (elsTy, elsCs) <- ctype els
  let here = [(condTy, TyBool), (thnTy, elsTy)]
  pure (thnTy, here ++ (condCs ++ thnCs ++ elsCs))
ctype Zero = pure (TyNat, [])
-- The three numeric operations all require a Nat argument.
ctype (Succ m) = do
  (mTy, mCs) <- ctype m
  pure (TyNat, (mTy, TyNat) :: mCs)
ctype (Pred m) = do
  (mTy, mCs) <- ctype m
  pure (TyNat, (mTy, TyNat) :: mCs)
ctype (IsZero m) = do
  (mTy, mCs) <- ctype m
  pure (TyBool, (mTy, TyNat) :: mCs)
|
||
|
||
-- Run constraint typing on a term in the given context, starting the
-- fresh-variable counter at 0, and return the resulting type and constraints.
solve : forall n. Ctxt n -> Tm n -> Solution
solve ctxt tm = runIdentity (evalStateT 0 (runReaderT ctxt (ctype tm)))
|
||
|
||
{-- Substitutions are mappings from strings to types
    and act on terms and types --}
|
||
|
||
-- A substitution maps every type-identifier name to a type.
0 Subst : Type
Subst = (name : String) -> Ty
|
||
|
||
-- Extend a substitution with one binding: the bound name now maps to the
-- given type, and every other name is looked up in the old substitution.
(:<) : Subst -> (String, Ty) -> Subst
s :< binding = \y => if fst binding == y then snd binding else s y
|
||
|
||
-- Apply a substitution to a type: identifiers are replaced by their
-- images, arrows are traversed, and base types are left untouched.
-- The substitution is applied once; images are not substituted again.
subst : Subst -> Ty -> Ty
subst s (TyId name) = s name
subst s (TyArr dom cod) = TyArr (subst s dom) (subst s cod)
subst _ TyBool = TyBool
subst _ TyNat = TyNat
|
||
|
||
-- Apply a substitution to every type annotation inside a term.
--
-- Only abstractions carry annotations, but they can occur under any term
-- former, so every constructor with subterms is traversed. (Succ, Pred and
-- IsZero previously fell through to a catch-all clause, leaving
-- annotations nested beneath them unsubstituted.) Leaves are listed
-- explicitly so that a newly added constructor is flagged by the coverage
-- checker instead of being silently skipped.
substTm : forall n. Subst -> Tm n -> Tm n
substTm s (Var i) = Var i
substTm s (Abs ta b) = Abs (subst s ta) (substTm s b)
substTm s (App b a) = App (substTm s b) (substTm s a)
substTm s True = True
substTm s False = False
substTm s (If b c d) = If (substTm s b) (substTm s c) (substTm s d)
substTm s Zero = Zero
substTm s (Succ m) = Succ (substTm s m)
substTm s (Pred m) = Pred (substTm s m)
substTm s (IsZero m) = IsZero (substTm s m)
|
||
|
||
{-- Constraint unification algorithm --}
|
||
|
||
-- Does the identifier with the given name occur anywhere in the type?
occurs : Ty -> String -> Bool
occurs (TyId name) x = x == name
occurs (TyArr dom cod) x = occurs dom x || occurs cod x
occurs TyBool _ = False
occurs TyNat _ = False
|
||
|
||
-- Unification monad: threads the substitution built so far as state and
-- can fail with a String error message.
0 UCM : Type -> Type
UCM = StateT Subst (EitherT String Identity)
|
||
|
||
-- Solve a list of equality constraints, accumulating the solution in the
-- state substitution; fails with a message on an occurs-check violation
-- or a constructor clash.
--
-- Each constraint is first rewritten with the current substitution, then:
--   * syntactically equal sides are dropped;
--   * an identifier on either side is bound to the other side, unless it
--     occurs in it;
--   * two arrows are decomposed into constraints on domains and codomains;
--   * anything else is a clash.
--
-- New bindings are composed into the existing substitution rather than
-- merely added, which keeps the substitution idempotent: no bound
-- identifier appears in any image. Without that, a constraint rewritten
-- by one application of the substitution could still mention a bound
-- identifier, and that binding would then be silently overwritten
-- (e.g. [(W, X), (X, Bool), (W, Nat)] used to unify). This relies on the
-- initial state being idempotent as well (`unifier` starts from TyId).
--
-- Marked covering, not total: decomposition of arrows grows the list.
covering
unify : Constrs -> UCM ()
unify [] = pure ()
unify ((lhs, rhs) :: cs) = do
  s <- get
  let ta = subst s lhs
  let tb = subst s rhs
  if ta == tb then unify cs else
    case (ta, tb) of
      (TyId x, t) =>
        if occurs t x
          then throwError "The variable \{x} occurs in the right hand side \{t}"
          else do
            put $ bind s x t
            unify cs
      (t, TyId x) =>
        if occurs t x
          then throwError "The variable \{x} occurs in the left hand side \{t}"
          else do
            put $ bind s x t
            unify cs
      (TyArr tal tbl, TyArr tar tbr) =>
        unify ((tal, tar) :: (tbl, tbr) :: cs)
      _ => throwError "Cannot unify \{ta} with \{tb}"
  where
    -- Bind x to t: rewrite x inside every existing image, then add x ↦ t.
    bind : Subst -> String -> Ty -> Subst
    bind cur x t = (subst (TyId :< (x, t)) . cur) :< (x, t)
|
||
|
||
-- Run unification on a constraint set, starting from the identity
-- substitution (every identifier maps to itself), and return either an
-- error message or the final substitution.
covering
unifier : Constrs -> Either String Subst
unifier cs = runIdentity (runEitherT (execStateT TyId (unify cs)))
|
||
|
||
{-- Example expressions from 22.5.2, 22.5.5 --}
|
||
|
||
-- λx: X. x y   (y is not bound anywhere)
test0 : TmRaw
test0 = λ "x" "X" ("x" # "y")

-- λx: X. x
test1 : TmRaw
test1 = λ "x" "X" "x"

-- λz: Z. λy: Y. z (y true)
test2 : TmRaw
test2 = λ "z" "Z" (λ "y" "Y" ("z" # ("y" # T)))

-- λw: W. if true then false else w false
test3 : TmRaw
test3 = λ "w" "W" (Iff T `Then` F `Else` ("w" # F))

-- λx: X. λy: Y. λz: Z. (x z) (y z)
test4 : TmRaw
test4 = λ "x" "X" (λ "y" "Y" (λ "z" "Z" (("x" # "z") # ("y" # "z"))))

-- λx: X. x x
test5 : TmRaw
test5 = λ "x" "X" ("x" # "x")
|
||
|
||
-- Run the whole pipeline on one raw term and print each stage:
-- scope resolution in the empty scope, constraint typing in the empty
-- context, then unification of the collected constraints.
-- A failure in scope resolution or unification is printed and ends the run.
covering
printSolveUnify : TmRaw -> IO ()
printSolveUnify raw =
  case scoped [] raw of
    Left e => putStrLn "scope failed: \{e}"
    Right tm => do
      let (ty, cs) = solve [] tm
      -- [context] ⊢ (term) : (type) | {constraints}
      putStrLn "solve: ⊢ (\{tm}) : (\{ty}) | {\{cs}}"
      case unifier cs of
        Left e => putStrLn "unify failed: \{e}"
        -- [context{s}] ⊢ (term{s}) : (type{s})
        Right s => putStrLn "unify: ⊢ (\{substTm s tm}) : (\{subst s ty})"
|
||
|
||
-- Print the typing and unification results for every example term, in order.
covering
main : IO ()
main =
  traverse_ printSolveUnify
    (the (List TmRaw) [test0, test1, test2, test3, test4, test5])