diff --git a/src/Verifier/SAW/Rewriter.hs b/src/Verifier/SAW/Rewriter.hs index aa25da6c..275bf62f 100644 --- a/src/Verifier/SAW/Rewriter.hs +++ b/src/Verifier/SAW/Rewriter.hs @@ -58,6 +58,7 @@ import Control.Applicative ((<$>), pure, (<*>)) import Data.Foldable (Foldable) #endif import Control.Applicative (Alternative) +import Control.Monad (guard) import Control.Monad.Identity import Control.Monad.State import Control.Monad.Trans.Maybe @@ -76,6 +77,7 @@ import Verifier.SAW.Cache import Verifier.SAW.Conversion import qualified Verifier.SAW.Recognizer as R import Verifier.SAW.SharedTerm +import Verifier.SAW.Term.Functor import Verifier.SAW.TypedAST import qualified Verifier.SAW.TermNet as Net @@ -146,12 +148,13 @@ bottom_convs convs t = do _ -> return t fromMaybe (return t') $ msum [ runConversion c t' | c <- convs ] --- | An enhanced matcher that can handle some patterns containing lambdas. +-- | An enhanced matcher that can handle higher-order patterns. scMatch :: SharedContext -> Term -> Term -> MaybeT IO (Map DeBruijnIndex Term) -scMatch sc pat term = do - MatchState inst cs <- match 0 pat term emptyMatchState - mapM_ (check inst) cs - return inst +scMatch sc pat term = + do --lift $ putStrLn $ "********** scMatch **********" + MatchState inst cs <- match 0 [] pat term emptyMatchState + mapM_ (check inst) cs + return inst where check :: Map DeBruijnIndex Term -> (Term, Integer) -> MaybeT IO () check inst (t, n) = do @@ -167,34 +170,70 @@ scMatch sc pat term = do FTermF (NatLit i) | i == n -> return () _ -> mzero - match :: Int -> Term -> Term -> MatchState -> MaybeT IO MatchState - match depth x y s@(MatchState m cs) = do - --lift $ putStrLn $ "matching (lhs): " ++ show x - --lift $ putStrLn $ "matching (rhs): " ++ show y - case (unwrapTermF x, unwrapTermF y) of - -- check that neither x nor y contains bound variables less than `depth` - (LocalVar i, _) | i >= depth && - (looseVars y `intersectBitSets` (completeBitSet depth) - == emptyBitSet) -> - do -- decrement loose variables in y by `depth` - y1 <- lift $ instantiateVarList sc 0 (replicate depth (error "scMatch: impossible")) y - let (my2, m') = insertLookup (i - depth) y1 m - case my2 of + asVarPat :: Int -> Term -> Maybe (DeBruijnIndex, [DeBruijnIndex]) + asVarPat depth = go [] + where + go js x = + case unwrapTermF x of + LocalVar i + | i >= depth -> Just (i, js) + | otherwise -> Nothing + App t (unwrapTermF -> LocalVar j) + | j < depth -> go (j : js) t + _ -> Nothing + + match :: Int -> [(String, Term)] -> Term -> Term -> MatchState -> MaybeT IO MatchState + match depth env x y s@(MatchState m cs) = + --do + --lift $ putStrLn $ "matching (lhs): " ++ scPrettyTerm defaultPPOpts x + --lift $ putStrLn $ "matching (rhs): " ++ scPrettyTerm defaultPPOpts y + case asVarPat depth x of + Just (i, js) -> + do -- ensure parameter variables are distinct + guard (Set.size (Set.fromList js) == length js) + -- ensure y mentions only variables that are in js + let fvj = foldl unionBitSets emptyBitSet (map singletonBitSet js) + let fvy = looseVars y `intersectBitSets` (completeBitSet depth) + guard (fvy `unionBitSets` fvj == fvj) + let fixVar t (nm, ty) = + do v <- scFreshGlobal sc nm ty + ec <- R.asExtCns v + t' <- instantiateVar sc 0 v t + return (t', ec) + let fixVars t [] = return (t, []) + fixVars t (ty : tys) = + do (t', ec) <- fixVar t ty + (t'', ecs) <- fixVars t' tys + return (t'', ec : ecs) + -- replace local bound variables with global ones + -- this also decrements loose variables in y by `depth` + (y1, ecs) <- lift $ fixVars y env + -- replace global variables with reindexed bound vars + -- y2 should have no more of the newly-created ExtCns vars + y2 <- lift $ scAbstractExts sc [ ecs !! j | j <- js ] y1 + let (my3, m') = insertLookup (i - depth) y2 m + case my3 of Nothing -> return (MatchState m' cs) - Just y2 -> if y == y2 then return (MatchState m' cs) else mzero - (App x1 x2, App y1 y2) -> - match depth x1 y1 s >>= match depth x2 y2 - (FTermF xf, FTermF yf) -> - case zipWithFlatTermF (match depth) xf yf of - Nothing -> mzero - Just zf -> Foldable.foldl (>=>) return zf s - (Lambda _ t1 x1, Lambda _ t2 x2) -> - match depth t1 t2 s >>= match (depth + 1) x1 x2 - (App _ _, FTermF (NatLit n)) -> - -- add deferred constraint - return (MatchState m ((x, n) : cs)) - (_, _) -> - if x == y then return s else mzero + Just y3 -> if y2 == y3 then return (MatchState m' cs) else mzero + Nothing -> + case (unwrapTermF x, unwrapTermF y) of + -- check that neither x nor y contains bound variables less than `depth` + (FTermF xf, FTermF yf) -> + case zipWithFlatTermF (match depth env) xf yf of + Nothing -> mzero + Just zf -> Foldable.foldl (>=>) return zf s + (App x1 x2, App y1 y2) -> + match depth env x1 y1 s >>= match depth env x2 y2 + (Lambda _ t1 x1, Lambda nm t2 x2) -> + match depth env t1 t2 s >>= match (depth + 1) ((nm, t2) : env) x1 x2 + (Pi _ t1 x1, Pi nm t2 x2) -> + match depth env t1 t2 s >>= match (depth + 1) ((nm, t2) : env) x1 x2 + (App _ _, FTermF (NatLit n)) -> + -- add deferred constraint + return (MatchState m ((x, n) : cs)) + (_, _) -> + -- other possible matches are local vars and constants + if x == y then return s else mzero ---------------------------------------------------------------------- -- Building rewrite rules