diff --git a/src/Verifier/SAW/Rewriter.hs b/src/Verifier/SAW/Rewriter.hs index 48b50a7c..1479e090 100644 --- a/src/Verifier/SAW/Rewriter.hs +++ b/src/Verifier/SAW/Rewriter.hs @@ -60,6 +60,7 @@ import Data.Foldable (Foldable) import Control.Lens import Control.Monad.Identity import Control.Monad.State +import Control.Monad.Trans.Maybe import Data.Bits import qualified Data.Foldable as Foldable import Data.IORef (IORef) @@ -119,6 +120,32 @@ first_order_match pat term = match pat term Map.empty -- occur as the 2nd argument of an @App@ constructor. This ensures -- that instantiations are well-typed. +-- | An enhanced matcher that can handle some patterns containing lambdas. +scMatch :: forall s. SharedContext s -> SharedTerm s -> SharedTerm s -> MaybeT IO (Map DeBruijnIndex (SharedTerm s)) +scMatch sc pat term = match 0 pat term Map.empty + where + match :: Int -> SharedTerm s -> SharedTerm s -> Map DeBruijnIndex (SharedTerm s) -> MaybeT IO (Map DeBruijnIndex (SharedTerm s)) + match depth x y m = + case (unwrapTermF x, unwrapTermF y) of + -- check that neither x nor y contains bound variables less than `depth` + (LocalVar i, _) | i >= depth && looseVars y .&. (bit depth - 1) == 0 -> + do -- decrement loose variables in y by `depth` + y1 <- lift $ S.instantiateVarList sc 0 (replicate depth (error "scMatch: impossible")) y + let (my2, m') = insertLookup (i - depth) y1 m + case my2 of + Nothing -> return m' + Just y2 -> if y == y2 then return m' else mzero + (App x1 x2, App y1 y2) -> + match depth x1 y1 m >>= match depth x2 y2 + (FTermF xf, FTermF yf) -> + case zipWithFlatTermF (match depth) xf yf of + Nothing -> mzero + Just zf -> Foldable.foldl (>=>) return zf m + (Lambda _ t1 x1, Lambda _ t2 x2) -> + match depth t1 t2 m >>= match (depth + 1) x1 x2 + (_, _) -> + if x == y then return m else mzero + ---------------------------------------------------------------------- -- Building rewrite rules @@ -356,8 +383,9 @@ rewriteSharedTerm sc ss t0 = [Either (RewriteRule (SharedTerm s)) (Conversion (SharedTerm s))] -> SharedTerm s -> IO (SharedTerm s) apply [] t = return t - apply (Left (RewriteRule {lhs, rhs}) : rules) t = - case first_order_match lhs t of + apply (Left (RewriteRule {lhs, rhs}) : rules) t = do + result <- runMaybeT (scMatch sc lhs t) + case result of Nothing -> apply rules t Just inst | lhs == rhs ->