Skip to content

Commit

Permalink
Support tuple patterns in saw-script 'let' declarations.
Browse files Browse the repository at this point in the history
Fixes #99.
  • Loading branch information
Brian Huffman committed Feb 8, 2016
1 parent 46c4fc4 commit 3bfd00c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 65 deletions.
8 changes: 4 additions & 4 deletions src/SAWScript/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ data DeclGroup
deriving (Eq, Show)

data Decl
= Decl { dName :: LName, dType :: Maybe Schema, dDef :: Expr }
= Decl { dPat :: Pattern, dType :: Maybe Schema, dDef :: Expr }
deriving (Eq, Show)

-- }}}
Expand Down Expand Up @@ -268,9 +268,9 @@ instance Pretty Stmt where
--ppName n = ppIdent (P.nameIdent n)

prettyDef :: Decl -> PP.Doc
prettyDef Decl{dName,dDef} =
PP.text (getVal dName) PP.<+>
let (args, body) = dissectLambda dDef
prettyDef (Decl pat _ def) =
PP.pretty pat PP.<+>
let (args, body) = dissectLambda def
in (if not (null args)
then PP.hsep (map PP.pretty args) PP.<> PP.space
else PP.empty) PP.<>
Expand Down
48 changes: 21 additions & 27 deletions src/SAWScript/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,25 @@ getMergedEnv :: LocalEnv -> TopLevel TopLevelRW
getMergedEnv env = mergeLocalEnv env `fmap` getTopLevelRW

bindPatternGeneric :: (SS.LName -> Maybe SS.Schema -> Maybe String -> Value -> e -> e)
-> SS.Pattern -> Value -> e -> e
bindPatternGeneric ext pat v env =
-> SS.Pattern -> Maybe SS.Schema -> Value -> e -> e
bindPatternGeneric ext pat ms v env =
case pat of
SS.PWild _ -> env
SS.PVar x mt -> ext x (fmap SS.tMono mt) Nothing v env
SS.PVar x _ -> ext x ms Nothing v env
SS.PTuple ps ->
case v of
VTuple vs -> foldr ($) env (zipWith (bindPatternGeneric ext) ps vs)
_ -> error "bindPattern: expected tuple value"

bindPatternLocal :: SS.Pattern -> Value -> LocalEnv -> LocalEnv
VTuple vs -> foldr ($) env (zipWith3 (bindPatternGeneric ext) ps mss vs)
where mss = case ms of
Nothing -> repeat Nothing
Just (SS.Forall ks (SS.TyCon (SS.TupleCon _) ts))
-> [ Just (SS.Forall ks t) | t <- ts ]
_ -> error "bindPattern: expected tuple value"
_ -> error "bindPattern: expected tuple value"

bindPatternLocal :: SS.Pattern -> Maybe SS.Schema -> Value -> LocalEnv -> LocalEnv
bindPatternLocal = bindPatternGeneric extendLocal

bindPatternEnv :: SS.Pattern -> Value -> TopLevelRW -> TopLevelRW
bindPatternEnv :: SS.Pattern -> Maybe SS.Schema -> Value -> TopLevelRW -> TopLevelRW
bindPatternEnv = bindPatternGeneric extendEnv

-- Interpretation of SAWScript -------------------------------------------------
Expand Down Expand Up @@ -184,7 +189,7 @@ interpret env expr =
case Map.lookup x (rwValues rw) of
Nothing -> fail $ "unknown variable: " ++ SS.getVal x
Just v -> return v
SS.Function pat e -> do let f v = interpret (bindPatternLocal pat v env) e
SS.Function pat e -> do let f v = interpret (bindPatternLocal pat Nothing v env) e
return $ VLambda f
SS.Application e1 e2 -> do v1 <- interpret env e1
v2 <- interpret env e2
Expand All @@ -196,15 +201,15 @@ interpret env expr =
SS.TSig e _ -> interpret env e

interpretDecl :: LocalEnv -> SS.Decl -> TopLevel LocalEnv
interpretDecl env (SS.Decl n mt expr) = do
interpretDecl env (SS.Decl pat mt expr) = do
v <- interpret env expr
return (extendLocal n mt Nothing v env)
return (bindPatternLocal pat mt v env)

interpretFunction :: LocalEnv -> SS.Expr -> Value
interpretFunction env expr =
case expr of
SS.Function pat e -> VLambda f
where f v = interpret (bindPatternLocal pat v env) e
where f v = interpret (bindPatternLocal pat Nothing v env) e
SS.TSig e _ -> interpretFunction env e
_ -> error "interpretFunction: not a function"

Expand All @@ -213,7 +218,7 @@ interpretDeclGroup env (SS.NonRecursive d) = interpretDecl env d
interpretDeclGroup env (SS.Recursive ds) = return env'
where
env' = foldr addDecl env ds
addDecl (SS.Decl n mty e) = extendLocal n mty Nothing (interpretFunction env' e)
addDecl (SS.Decl pat mty e) = bindPatternLocal pat mty (interpretFunction env' e)

interpretStmts :: LocalEnv -> [SS.Stmt] -> TopLevel Value
interpretStmts env stmts =
Expand All @@ -222,7 +227,7 @@ interpretStmts env stmts =
[SS.StmtBind (SS.PWild _) _ e] -> interpret env e
SS.StmtBind pat _ e : ss ->
do v1 <- interpret env e
let f v = interpretStmts (bindPatternLocal pat v env) ss
let f v = interpretStmts (bindPatternLocal pat Nothing v env) ss
bindValue v1 (VLambda f)
SS.StmtLet bs : ss -> interpret env (SS.Let bs (SS.Block ss))
SS.StmtCode s : ss ->
Expand Down Expand Up @@ -251,7 +256,7 @@ processStmtBind printBinds pat _mc expr = do -- mx mt
let expr' = case mt of
Nothing -> expr
Just t -> SS.TSig expr (SS.tBlock ctx t)
let decl = SS.Decl lname Nothing expr'
let decl = SS.Decl pat Nothing expr'
rw <- getTopLevelRW
let opts = rwPPOpts rw

Expand Down Expand Up @@ -284,18 +289,7 @@ processStmtBind printBinds pat _mc expr = do -- mx mt
_ -> return ()

rw' <- getTopLevelRW
let pat' = annotatePattern pat ty
putTopLevelRW $ bindPatternEnv pat' result rw'

annotatePattern :: SS.Pattern -> SS.Type -> SS.Pattern
annotatePattern pat ty =
case pat of
SS.PWild _ -> SS.PWild (Just ty)
SS.PVar x _ -> SS.PVar x (Just ty)
SS.PTuple ps ->
case ty of
SS.TyCon (SS.TupleCon _) ts -> SS.PTuple (zipWith annotatePattern ps ts)
_ -> pat
putTopLevelRW $ bindPatternEnv pat (Just (SS.tMono ty)) result rw'

-- | Interpret a block-level statement in the TopLevel monad.
interpretStmt :: Bool -> SS.Stmt -> TopLevel ()
Expand Down
80 changes: 50 additions & 30 deletions src/SAWScript/MGU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Identity
import Data.List (genericLength)
import Data.Map (Map)
import qualified Data.Map as M
import qualified Data.Set as S
Expand Down Expand Up @@ -178,14 +179,6 @@ newTypePattern pat =
PTuple ps -> do (ts, ps') <- unzip <$> mapM newTypePattern ps
return (tTuple ts, PTuple ps')

skolemType :: Name -> TI Type
skolemType n = TySkolemVar n <$> newTypeIndex

skolemInst :: Schema -> TI Type
skolemInst (Forall ns t) = do
nts <- mapM (\n -> (,) n <$> skolemType n) ns
return (instantiate nts t)

appSubstM :: AppSubst t => t -> TI t
appSubstM t = do
s <- TI $ gets subst
Expand Down Expand Up @@ -216,7 +209,7 @@ bindSchemas bs m = foldr (uncurry bindSchema) m bs

bindDecl :: Decl -> TI a -> TI a
bindDecl (Decl _ Nothing _) m = m
bindDecl (Decl n (Just s) _) m = bindSchema n s m
bindDecl (Decl p (Just s) _) m = bindPatternSchema p s m

bindDeclGroup :: DeclGroup -> TI a -> TI a
bindDeclGroup (NonRecursive d) m = bindDecl d m
Expand All @@ -233,6 +226,17 @@ patternBindings pat =
PVar x mt -> [(x, mt)]
PTuple ps -> concatMap patternBindings ps

bindPatternSchema :: Pattern -> Schema -> TI a -> TI a
bindPatternSchema pat s@(Forall vs t) m =
case pat of
PWild _ -> m
PVar n _ -> bindSchema n s m
PTuple ps ->
case t of
TyCon (TupleCon _) ts -> foldr ($) m
[ bindPatternSchema p (Forall vs t') | (p, t') <- zip ps ts ]
_ -> m

-- FIXME: This function may miss type variables that occur in the type
-- of a binding that has been shadowed by another value with the same
-- name. This could potentially cause a run-time type error if the
Expand Down Expand Up @@ -319,7 +323,7 @@ instance AppSubst DeclGroup where
appSubst s (NonRecursive d) = NonRecursive (appSubst s d)

instance AppSubst Decl where
appSubst s (Decl n mt e) = Decl n (appSubst s mt) (appSubst s e)
appSubst s (Decl p mt e) = Decl (appSubst s p) (appSubst s mt) (appSubst s e)

-- }}}

Expand Down Expand Up @@ -349,7 +353,6 @@ instance Instantiate Type where
type OutExpr = Expr
type OutStmt = Stmt


inferE :: (LName, Expr) -> TI (OutExpr,Type)
inferE (ln, expr) = case expr of
Bit b -> return (Bit b, tBool)
Expand Down Expand Up @@ -527,31 +530,48 @@ inferStmts m ctx (StmtImport imp : more) = do
(more', t) <- inferStmts m ctx more
return (StmtImport imp : more', t)

patternLNames :: Pattern -> [LName]
patternLNames pat =
case pat of
PWild _ -> []
PVar n _ -> [n]
PTuple ps -> concatMap patternLNames ps

constrainTypeWithPattern :: LName -> Type -> Pattern -> TI ()
constrainTypeWithPattern ln t pat =
case pat of
PWild Nothing -> return ()
PWild (Just t') -> unify ln t t'
PVar _ Nothing -> return ()
PVar _ (Just t') -> unify ln t t'
PTuple ps ->
case t of
TyCon (TupleCon k) ts
| k == genericLength ps ->
sequence_ $ zipWith (constrainTypeWithPattern ln) ts ps
_ -> recordError $ unlines
[ "type mismatch: " ++ pShow (TupleCon (genericLength ps)) ++ " and " ++ pShow t
, " at " ++ show ln
]

inferDecl :: Decl -> TI Decl
inferDecl (Decl n Nothing e) = do
inferDecl (Decl pat _ e) = do
let n = head (patternLNames pat)
(e',t) <- inferE (n, e)
constrainTypeWithPattern n t pat
[(e1,s)] <- generalize [e'] [t]
return (Decl n (Just s) e1)

inferDecl (Decl n (Just s) e) = do
(e', t) <- inferE (n, e)
t' <- skolemInst s
unify n t t'
-- FIXME: make sure the skolem variables didn't "leak" into the surrounding context
return (Decl n (Just s) e')
return (Decl pat (Just s) e1)

inferRecDecls :: [Decl] -> TI [Decl]
inferRecDecls ds =
do let names = map dName ds
guessedSchemas <- mapM (maybe (tMono <$> newType) return . dType) ds
(es,ts) <- unzip `fmap`
bindSchemas (zip names guessedSchemas)
(mapM inferE [ (n, e) | Decl n _ e <- ds ])
guessedTypes <- mapM skolemInst guessedSchemas
sequence_ $ zipWith3 unify names ts guessedTypes
(es1,ss) <- unzip `fmap` generalize es ts
return [ Decl n (Just s) e | (n, s, e) <- zip3 names ss es1 ]

do let pats = map dPat ds
(_ts, pats') <- unzip <$> mapM newTypePattern pats
(es, ts) <- fmap unzip
$ flip (foldr bindPattern) pats'
$ sequence [ inferE (head (patternLNames p), e) | Decl p _ e <- ds ]
sequence_ $ zipWith (constrainTypeWithPattern (error "FIXME")) ts pats'
ess <- generalize es ts
return [ Decl p (Just s) e1 | (p, (e1, s)) <- zip pats ess ]

generalize :: [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize es0 ts0 =
Expand Down
7 changes: 3 additions & 4 deletions src/SAWScript/Parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,9 @@ Stmt :: { Stmt }
| 'import' Import { StmtImport $2 }

Declaration :: { Decl }
: name list(Arg) '=' Expression { Decl (toLName $1) Nothing (buildFunction $2 $4) }
| name list(Arg) ':' Type '=' Expression
{ Decl (toLName $1) Nothing (buildFunction $2 (TSig $6 $4)) }

: Arg list(Arg) '=' Expression { Decl $1 Nothing (buildFunction $2 $4) }
| Arg list(Arg) ':' Type '=' Expression
{ Decl $1 Nothing (buildFunction $2 (TSig $6 $4)) }

Pattern :: { Pattern }
: Arg { $1 }
Expand Down

0 comments on commit 3bfd00c

Please sign in to comment.