{-# LANGUAGE TupleSections #-}

import Control.Monad.State (State, evalState, get, gets, modify, put)
import Data.Function (on)
import Data.List (unionBy)
import Data.Map (Map)
import Data.Maybe (fromMaybe)
import qualified Data.Map as Map

type Name = String

-- | Literal values that can appear in source expressions.
data VLiteral
  = LInteger Integer
  | LBoolean Bool
  deriving (Show, Eq)

-- | Expressions: lambdas, applications, literals and variables.
data Expr
  = ELam Name Expr
  | EApp Expr Expr
  | ELit VLiteral
  | EVar Name
  deriving (Show, Eq)

-- | Build an integer literal expression.
intLit :: Integer -> Expr
intLit = ELit . LInteger

-- | Built-in base types.
data TLiteral = TInt | TBool
  deriving (Eq)

instance Show TLiteral where
  show TInt = "Int"
  show TBool = "Bool"

-- | Types: function types, type variables and base types.
data Type
  = TLam Type Type
  | TVar Name
  | TLit TLiteral
  deriving (Eq)

-- | Arrows are printed right-associatively, so only a function type in
-- argument position is parenthesised.
instance Show Type where
  show (TLam a@TLam{} b) = "(" ++ show a ++ ")" ++ " -> " ++ show b
  show (TLam a b) = show a ++ " -> " ++ show b
  show (TLit lit) = show lit
  show (TVar name) = name

intTy :: Type
intTy = TLit TInt

boolTy :: Type
boolTy = TLit TBool

-- | An equality constraint between two types, generated by 'constrain'
-- and discharged by 'solve'.
data Constraint = Constraint Type Type

-- | Apply a type transformation to both sides of a constraint.
cmap :: (Type -> Type) -> Constraint -> Constraint
cmap f (Constraint a b) = Constraint (f a) (f b)

instance Show Constraint where
  show (Constraint a b) = show a ++ " <=> " ++ show b

-- | State threaded through constraint generation.
data ConState = ConState
  { conFreshId :: Int
  -- ^ number used to name the next fresh type variable
  , conEnv :: Map Name Type
  -- ^ types of the names currently in scope
  }
  deriving (Show)

type Con a = State ConState a

-- | Allocate a new type variable named @a0@, @a1@, ... and advance the
-- counter.
freshTVar :: Con Type
freshTVar = do
  i <- gets conFreshId
  modify (\s -> s { conFreshId = 1 + i })
  pure $ TVar ("a" ++ show i)

-- | Look up the type of a name in the current environment.
-- An unbound name yields a value that calls 'error' when forced.
lookupType :: Name -> Con Type
lookupType name = gets (fromMaybe notFound . Map.lookup name . conEnv)
  where
    notFound = error ("name not found: " ++ name)

-- | Bind (or rebind) a name to a type in the current environment.
insertType :: Name -> Type -> Con ()
insertType name ty =
  modify (\s -> s { conEnv = Map.insert name ty (conEnv s) })

-- | The type of a literal value.
literalType :: VLiteral -> Type
literalType (LInteger _) = intTy
literalType (LBoolean _) = boolTy

-- | Run an action and then restore the ENTIRE state to what it was before.
-- Not suitable for scoping variable bindings in 'Con': it would also rewind
-- 'conFreshId' and cause type-variable names to be reused (see
-- 'withLocalEnv').
scoped :: State s a -> State s a
scoped action = do
  s <- get
  action <* put s

-- | Run an action with a local copy of the environment: bindings it adds
-- are discarded afterwards, while the fresh-variable counter keeps the value
-- the action left it at, so every allocated type variable stays unique.
withLocalEnv :: Con a -> Con a
withLocalEnv action = do
  savedEnv <- gets conEnv
  result <- action
  modify (\s -> s { conEnv = savedEnv })
  pure result

-- | Generate a type for an expression together with the equality
-- constraints that must hold for it to be well typed.
constrain :: Expr -> Con (Type, [Constraint])
constrain (ELit lit) = pure (literalType lit, [])
constrain (EVar name) = (, []) <$> lookupType name
constrain (ELam var body) = do
  tvar <- freshTVar
  -- Only the environment is restored: the parameter binding must not leak
  -- out of the body, but the fresh counter must keep advancing.
  (tbody, ctrs) <- withLocalEnv (insertType var tvar *> constrain body)
  pure (TLam tvar tbody, ctrs)
constrain (EApp left right) = do
  (lty, lctrs) <- constrain left
  (rty, rctrs) <- constrain right
  tvar <- freshTVar
  pure (tvar, lctrs ++ rctrs ++ [Constraint lty (TLam rty tvar)])

-- | A single binding of a type variable to a type.
data Substitution = Substitution Name Type

-- | Transform the type on the right-hand side of a binding.
subMap :: (Type -> Type) -> Substitution -> Substitution
subMap f (Substitution name ty) = Substitution name (f ty)

-- | The type variable a binding replaces.
subName :: Substitution -> Name
subName (Substitution name _) = name

instance Show Substitution where
  show (Substitution name ty) = name ++ " => " ++ show ty

-- | Replace every occurrence of one type variable in a type.
applySub :: Substitution -> Type -> Type
applySub (Substitution target ty) = go
  where
    go (TLam arg res) = TLam (go arg) (go res)
    go (TVar name)
      | name == target = ty
      | otherwise = TVar name
    go other = other

-- | Apply a list of bindings, starting from the LAST element (foldr).
applySubs :: [Substitution] -> Type -> Type
applySubs subs subject = foldr applySub subject subs

-- | Compose two substitutions: @left@ is applied to the types bound in
-- @right@, then the bindings of @left@ whose variable is not already bound
-- in @right@ are appended ('unionBy' keeps the first list's entries).
combine :: [Substitution] -> [Substitution] -> [Substitution]
combine left right = right' `merge` left
  where
    right' = map (subMap (applySubs left)) right
    merge = unionBy ((==) `on` subName)

-- | Whether a type variable appears anywhere inside a type.
occursIn :: Name -> Type -> Bool
occursIn name (TLam arg res) = occursIn name arg || occursIn name res
occursIn name (TVar other) = name == other
occursIn _ (TLit _) = False

-- | Compute a substitution making two types equal, or explain why none
-- exists (mismatched constructors or an infinite type).
unify :: Type -> Type -> Either String [Substitution]
unify tyA tyB
  | tyA == tyB = Right []
  | otherwise = case (tyA, tyB) of
      (TLam argA bodyA, TLam argB bodyB) -> do
        argSub <- unify argA argB
        bodySub <- unify (applySubs argSub bodyA) (applySubs argSub bodyB)
        Right (combine bodySub argSub)
      (TVar name, ty) -> bindVar name ty
      (ty, TVar name) -> bindVar name ty
      _ -> Left ("Cannot unify `" ++ show tyA ++ "` with `" ++ show tyB ++ "`")
  where
    -- Binding a variable to a type that contains it would describe an
    -- infinite type (e.g. a0 = a0 -> a1), so it must be rejected.
    bindVar name ty
      | name `occursIn` ty =
          Left
            ( "Occurs check: cannot construct infinite type `"
                ++ name
                ++ "` ~ `"
                ++ show ty
                ++ "`"
            )
      | otherwise = Right [Substitution name ty]

-- | Solve constraints left to right, applying each new substitution to the
-- remaining constraints and composing it into the accumulated result.
solve :: [Constraint] -> Either String [Substitution]
solve = go []
  where
    go final [] = Right final
    go final (Constraint a b : cs) = do
      sub <- unify a b
      go (combine sub final) (map (cmap (applySubs sub)) cs)

-- | Types of the built-in functions.
-- NOTE(review): the variables @a@ and @b@ are returned verbatim by
-- 'lookupType' and never instantiated, so all uses of e.g. @if@ within one
-- expression share the same @a@ (no let-polymorphism).
stdLib :: Map Name Type
stdLib =
  Map.fromList
    [ ("add", TLam intTy (TLam intTy intTy))
    , ("gt", TLam intTy (TLam intTy boolTy))
    , ("if", TLam boolTy (TLam (TVar "a") (TLam (TVar "a") (TVar "a"))))
    , ("fix", TLam (TLam (TVar "b") (TVar "b")) (TVar "b"))
    ]

-- | Initial state: counter at zero, environment holding 'stdLib'.
emptyConState :: ConState
emptyConState = ConState 0 stdLib

-- | Infer the type of a closed expression (built-ins from 'stdLib' are in
-- scope), or return a unification error message.
infer :: Expr -> Either String Type
infer expr =
  let (ty, ctrs) = evalState (constrain expr) emptyConState
   in (`applySubs` ty) <$> solve ctrs