From 3bfd00c5bdda3111e528f85292e6dc2f70d42bde Mon Sep 17 00:00:00 2001 From: Brian Huffman Date: Mon, 8 Feb 2016 13:31:02 -0800 Subject: [PATCH] Support tuple patterns in saw-script 'let' declarations. Fixes #99. --- src/SAWScript/AST.hs | 8 ++-- src/SAWScript/Interpreter.hs | 48 ++++++++++------------ src/SAWScript/MGU.hs | 80 ++++++++++++++++++++++-------------- src/SAWScript/Parser.y | 7 ++-- 4 files changed, 78 insertions(+), 65 deletions(-) diff --git a/src/SAWScript/AST.hs b/src/SAWScript/AST.hs index 257dc42f0f..a6eca47f14 100644 --- a/src/SAWScript/AST.hs +++ b/src/SAWScript/AST.hs @@ -128,7 +128,7 @@ data DeclGroup deriving (Eq, Show) data Decl - = Decl { dName :: LName, dType :: Maybe Schema, dDef :: Expr } + = Decl { dPat :: Pattern, dType :: Maybe Schema, dDef :: Expr } deriving (Eq, Show) -- }}} @@ -268,9 +268,9 @@ instance Pretty Stmt where --ppName n = ppIdent (P.nameIdent n) prettyDef :: Decl -> PP.Doc -prettyDef Decl{dName,dDef} = - PP.text (getVal dName) PP.<+> - let (args, body) = dissectLambda dDef +prettyDef (Decl pat _ def) = + PP.pretty pat PP.<+> + let (args, body) = dissectLambda def in (if not (null args) then PP.hsep (map PP.pretty args) PP.<> PP.space else PP.empty) PP.<> diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index db3276136f..b4e06d84a8 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -136,20 +136,25 @@ getMergedEnv :: LocalEnv -> TopLevel TopLevelRW getMergedEnv env = mergeLocalEnv env `fmap` getTopLevelRW bindPatternGeneric :: (SS.LName -> Maybe SS.Schema -> Maybe String -> Value -> e -> e) - -> SS.Pattern -> Value -> e -> e -bindPatternGeneric ext pat v env = + -> SS.Pattern -> Maybe SS.Schema -> Value -> e -> e +bindPatternGeneric ext pat ms v env = case pat of SS.PWild _ -> env - SS.PVar x mt -> ext x (fmap SS.tMono mt) Nothing v env + SS.PVar x _ -> ext x ms Nothing v env SS.PTuple ps -> case v of - VTuple vs -> foldr ($) env (zipWith (bindPatternGeneric ext) ps vs) - _ -> error "bindPattern: expected tuple value" - -bindPatternLocal :: SS.Pattern -> Value -> LocalEnv -> LocalEnv + VTuple vs -> foldr ($) env (zipWith3 (bindPatternGeneric ext) ps mss vs) + where mss = case ms of + Nothing -> repeat Nothing + Just (SS.Forall ks (SS.TyCon (SS.TupleCon _) ts)) + -> [ Just (SS.Forall ks t) | t <- ts ] + _ -> error "bindPattern: expected tuple value" + _ -> error "bindPattern: expected tuple value" + +bindPatternLocal :: SS.Pattern -> Maybe SS.Schema -> Value -> LocalEnv -> LocalEnv bindPatternLocal = bindPatternGeneric extendLocal -bindPatternEnv :: SS.Pattern -> Value -> TopLevelRW -> TopLevelRW +bindPatternEnv :: SS.Pattern -> Maybe SS.Schema -> Value -> TopLevelRW -> TopLevelRW bindPatternEnv = bindPatternGeneric extendEnv -- Interpretation of SAWScript ------------------------------------------------- @@ -184,7 +189,7 @@ interpret env expr = case Map.lookup x (rwValues rw) of Nothing -> fail $ "unknown variable: " ++ SS.getVal x Just v -> return v - SS.Function pat e -> do let f v = interpret (bindPatternLocal pat v env) e + SS.Function pat e -> do let f v = interpret (bindPatternLocal pat Nothing v env) e return $ VLambda f SS.Application e1 e2 -> do v1 <- interpret env e1 v2 <- interpret env e2 @@ -196,15 +201,15 @@ interpret env expr = SS.TSig e _ -> interpret env e interpretDecl :: LocalEnv -> SS.Decl -> TopLevel LocalEnv -interpretDecl env (SS.Decl n mt expr) = do +interpretDecl env (SS.Decl pat mt expr) = do v <- interpret env expr - return (extendLocal n mt Nothing v env) + return (bindPatternLocal pat mt v env) interpretFunction :: LocalEnv -> SS.Expr -> Value interpretFunction env expr = case expr of SS.Function pat e -> VLambda f - where f v = interpret (bindPatternLocal pat v env) e + where f v = interpret (bindPatternLocal pat Nothing v env) e SS.TSig e _ -> interpretFunction env e _ -> error "interpretFunction: not a function" @@ -213,7 +218,7 @@ interpretDeclGroup env (SS.NonRecursive d) = interpretDecl env d interpretDeclGroup env (SS.Recursive ds) = return env' where env' = foldr addDecl env ds - addDecl (SS.Decl n mty e) = extendLocal n mty Nothing (interpretFunction env' e) + addDecl (SS.Decl pat mty e) = bindPatternLocal pat mty (interpretFunction env' e) interpretStmts :: LocalEnv -> [SS.Stmt] -> TopLevel Value interpretStmts env stmts = @@ -222,7 +227,7 @@ interpretStmts env stmts = [SS.StmtBind (SS.PWild _) _ e] -> interpret env e SS.StmtBind pat _ e : ss -> do v1 <- interpret env e - let f v = interpretStmts (bindPatternLocal pat v env) ss + let f v = interpretStmts (bindPatternLocal pat Nothing v env) ss bindValue v1 (VLambda f) SS.StmtLet bs : ss -> interpret env (SS.Let bs (SS.Block ss)) SS.StmtCode s : ss -> @@ -251,7 +256,7 @@ processStmtBind printBinds pat _mc expr = do -- mx mt let expr' = case mt of Nothing -> expr Just t -> SS.TSig expr (SS.tBlock ctx t) - let decl = SS.Decl lname Nothing expr' + let decl = SS.Decl pat Nothing expr' rw <- getTopLevelRW let opts = rwPPOpts rw @@ -284,18 +289,7 @@ processStmtBind printBinds pat _mc expr = do -- mx mt _ -> return () rw' <- getTopLevelRW - let pat' = annotatePattern pat ty - putTopLevelRW $ bindPatternEnv pat' result rw' - -annotatePattern :: SS.Pattern -> SS.Type -> SS.Pattern -annotatePattern pat ty = - case pat of - SS.PWild _ -> SS.PWild (Just ty) - SS.PVar x _ -> SS.PVar x (Just ty) - SS.PTuple ps -> - case ty of - SS.TyCon (SS.TupleCon _) ts -> SS.PTuple (zipWith annotatePattern ps ts) - _ -> pat + putTopLevelRW $ bindPatternEnv pat (Just (SS.tMono ty)) result rw' -- | Interpret a block-level statement in the TopLevel monad. interpretStmt :: Bool -> SS.Stmt -> TopLevel () diff --git a/src/SAWScript/MGU.hs b/src/SAWScript/MGU.hs index 6371f043b5..9ead7e53fa 100644 --- a/src/SAWScript/MGU.hs +++ b/src/SAWScript/MGU.hs @@ -28,6 +28,7 @@ import Control.Monad import Control.Monad.Reader import Control.Monad.State import Control.Monad.Identity +import Data.List (genericLength) import Data.Map (Map) import qualified Data.Map as M import qualified Data.Set as S @@ -178,14 +179,6 @@ newTypePattern pat = PTuple ps -> do (ts, ps') <- unzip <$> mapM newTypePattern ps return (tTuple ts, PTuple ps') -skolemType :: Name -> TI Type -skolemType n = TySkolemVar n <$> newTypeIndex - -skolemInst :: Schema -> TI Type -skolemInst (Forall ns t) = do - nts <- mapM (\n -> (,) n <$> skolemType n) ns - return (instantiate nts t) - appSubstM :: AppSubst t => t -> TI t appSubstM t = do s <- TI $ gets subst @@ -216,7 +209,7 @@ bindSchemas bs m = foldr (uncurry bindSchema) m bs bindDecl :: Decl -> TI a -> TI a bindDecl (Decl _ Nothing _) m = m -bindDecl (Decl n (Just s) _) m = bindSchema n s m +bindDecl (Decl p (Just s) _) m = bindPatternSchema p s m bindDeclGroup :: DeclGroup -> TI a -> TI a bindDeclGroup (NonRecursive d) m = bindDecl d m @@ -233,6 +226,17 @@ patternBindings pat = PVar x mt -> [(x, mt)] PTuple ps -> concatMap patternBindings ps +bindPatternSchema :: Pattern -> Schema -> TI a -> TI a +bindPatternSchema pat s@(Forall vs t) m = + case pat of + PWild _ -> m + PVar n _ -> bindSchema n s m + PTuple ps -> + case t of + TyCon (TupleCon _) ts -> foldr ($) m + [ bindPatternSchema p (Forall vs t') | (p, t') <- zip ps ts ] + _ -> m + -- FIXME: This function may miss type variables that occur in the type -- of a binding that has been shadowed by another value with the same -- name. This could potentially cause a run-time type error if the @@ -319,7 +323,7 @@ instance AppSubst DeclGroup where appSubst s (NonRecursive d) = NonRecursive (appSubst s d) instance AppSubst Decl where - appSubst s (Decl n mt e) = Decl n (appSubst s mt) (appSubst s e) + appSubst s (Decl p mt e) = Decl (appSubst s p) (appSubst s mt) (appSubst s e) -- }}} @@ -349,7 +353,6 @@ instance Instantiate Type where type OutExpr = Expr type OutStmt = Stmt - inferE :: (LName, Expr) -> TI (OutExpr,Type) inferE (ln, expr) = case expr of Bit b -> return (Bit b, tBool) @@ -527,31 +530,48 @@ inferStmts m ctx (StmtImport imp : more) = do (more', t) <- inferStmts m ctx more return (StmtImport imp : more', t) +patternLNames :: Pattern -> [LName] +patternLNames pat = + case pat of + PWild _ -> [] + PVar n _ -> [n] + PTuple ps -> concatMap patternLNames ps + +constrainTypeWithPattern :: LName -> Type -> Pattern -> TI () +constrainTypeWithPattern ln t pat = + case pat of + PWild Nothing -> return () + PWild (Just t') -> unify ln t t' + PVar _ Nothing -> return () + PVar _ (Just t') -> unify ln t t' + PTuple ps -> + case t of + TyCon (TupleCon k) ts + | k == genericLength ps -> + sequence_ $ zipWith (constrainTypeWithPattern ln) ts ps + _ -> recordError $ unlines + [ "type mismatch: " ++ pShow (TupleCon (genericLength ps)) ++ " and " ++ pShow t + , " at " ++ show ln + ] + inferDecl :: Decl -> TI Decl -inferDecl (Decl n Nothing e) = do +inferDecl (Decl pat _ e) = do + let n = head (patternLNames pat) (e',t) <- inferE (n, e) + constrainTypeWithPattern n t pat [(e1,s)] <- generalize [e'] [t] - return (Decl n (Just s) e1) - -inferDecl (Decl n (Just s) e) = do - (e', t) <- inferE (n, e) - t' <- skolemInst s - unify n t t' - -- FIXME: make sure the skolem variables didn't "leak" into the surrounding context - return (Decl n (Just s) e') + return (Decl pat (Just s) e1) inferRecDecls :: [Decl] -> TI [Decl] inferRecDecls ds = - do let names = map dName ds - guessedSchemas <- mapM (maybe (tMono <$> newType) return . dType) ds - (es,ts) <- unzip `fmap` - bindSchemas (zip names guessedSchemas) - (mapM inferE [ (n, e) | Decl n _ e <- ds ]) - guessedTypes <- mapM skolemInst guessedSchemas - sequence_ $ zipWith3 unify names ts guessedTypes - (es1,ss) <- unzip `fmap` generalize es ts - return [ Decl n (Just s) e | (n, s, e) <- zip3 names ss es1 ] - + do let pats = map dPat ds + (_ts, pats') <- unzip <$> mapM newTypePattern pats + (es, ts) <- fmap unzip + $ flip (foldr bindPattern) pats' + $ sequence [ inferE (head (patternLNames p), e) | Decl p _ e <- ds ] + sequence_ $ zipWith (constrainTypeWithPattern (error "FIXME")) ts pats' + ess <- generalize es ts + return [ Decl p (Just s) e1 | (p, (e1, s)) <- zip pats ess ] generalize :: [OutExpr] -> [Type] -> TI [(OutExpr,Schema)] generalize es0 ts0 = diff --git a/src/SAWScript/Parser.y b/src/SAWScript/Parser.y index 4ff02811dc..c86cae3af9 100644 --- a/src/SAWScript/Parser.y +++ b/src/SAWScript/Parser.y @@ -116,10 +116,9 @@ Stmt :: { Stmt } | 'import' Import { StmtImport $2 } Declaration :: { Decl } - : name list(Arg) '=' Expression { Decl (toLName $1) Nothing (buildFunction $2 $4) } - | name list(Arg) ':' Type '=' Expression - { Decl (toLName $1) Nothing (buildFunction $2 (TSig $6 $4)) } - + : Arg list(Arg) '=' Expression { Decl $1 Nothing (buildFunction $2 $4) } + | Arg list(Arg) ':' Type '=' Expression + { Decl $1 Nothing (buildFunction $2 (TSig $6 $4)) } Pattern :: { Pattern } : Arg { $1 }