diff --git a/saw-core/src/Verifier/SAW/SharedTerm.hs b/saw-core/src/Verifier/SAW/SharedTerm.hs index dc9063e696..41e8f9d39c 100644 --- a/saw-core/src/Verifier/SAW/SharedTerm.hs +++ b/saw-core/src/Verifier/SAW/SharedTerm.hs @@ -255,6 +255,7 @@ import Control.Lens import Control.Monad.State.Strict as State import Control.Monad.Reader import Data.Bits +import Data.List (inits) import Data.Maybe import qualified Data.Foldable as Fold import Data.Foldable (foldl', foldlM, foldrM, maximum) @@ -286,6 +287,7 @@ import Verifier.SAW.Term.CtxTerm import Verifier.SAW.Term.Pretty import Verifier.SAW.TypedAST import Verifier.SAW.Unique +import Verifier.SAW.Utils (panic) #if !MIN_VERSION_base(4,8,0) countTrailingZeros :: (FiniteBits b) => b -> Int @@ -1058,7 +1060,7 @@ instantiateLocalVars sc f initialLevel t0 = go' _ tf@(Constant {}) = scTermF sc tf instantiateVars :: SharedContext - -> (DeBruijnIndex -> Either (ExtCns Term) DeBruijnIndex -> IO Term) + -> ((Term -> IO Term) -> DeBruijnIndex -> Either (ExtCns Term) DeBruijnIndex -> IO Term) -> DeBruijnIndex -> Term -> IO Term instantiateVars sc f initialLevel t0 = do cache <- newCache @@ -1073,14 +1075,14 @@ instantiateVars sc f initialLevel t0 = go' :: (?cache :: Cache IO (TermIndex, DeBruijnIndex) Term) => DeBruijnIndex -> TermF Term -> IO Term - go' l (FTermF (ExtCns ec)) = f l (Left ec) + go' l (FTermF (ExtCns ec)) = f (go l) l (Left ec) go' l (FTermF tf) = scFlatTermF sc =<< (traverse (go l) tf) go' l (App x y) = scTermF sc =<< (App <$> go l x <*> go l y) go' l (Lambda i tp rhs) = scTermF sc =<< (Lambda i <$> go l tp <*> go (l+1) rhs) go' l (Pi i lhs rhs) = scTermF sc =<< (Pi i <$> go l lhs <*> go (l+1) rhs) go' l (LocalVar i) | i < l = scTermF sc (LocalVar i) - | otherwise = f l (Right i) + | otherwise = f (go l) l (Right i) go' _ tf@(Constant {}) = scTermF sc tf -- | @incVars k j t@ increments free variables at least @k@ by @j@. @@ -2185,17 +2187,20 @@ getConstantSet t = snd $ go (IntSet.empty, Map.empty) t Constant (EC vidx n ty) body -> (idxs, Map.insert vidx (n, ty, body) names) _ -> foldl' go acc tf --- | Instantiate some of the external constants +-- | Instantiate some of the external constants. +-- Note: this replacement is _not_ applied recursively +-- to the terms in the replacement map; so external constants +-- in those terms will not be replaced. scInstantiateExt :: SharedContext -> Map VarIndex Term -> Term -> IO Term scInstantiateExt sc vmap = instantiateVars sc fn 0 - where fn l (Left ec) = + where fn _rec l (Left ec) = case Map.lookup (ecVarIndex ec) vmap of Just t -> incVars sc 0 l t Nothing -> scFlatTermF sc $ ExtCns ec - fn _ (Right i) = scTermF sc $ LocalVar i + fn _ _ (Right i) = scTermF sc $ LocalVar i {- -- RWD: I'm pretty sure the following implementation gets incorrect results when @@ -2233,12 +2238,12 @@ scExtsToLocals _ [] x = return x scExtsToLocals sc exts x = instantiateVars sc fn 0 x where m = Map.fromList [ (ecVarIndex ec, k) | (ec, k) <- zip (reverse exts) [0 ..] ] - fn l e = + fn rec l e = case e of Left ec -> case Map.lookup (ecVarIndex ec) m of - Just k -> scLocalVar sc (l + k) - Nothing -> scFlatTermF sc (ExtCns ec) + Just k -> scLocalVar sc (l + k) + Nothing -> scFlatTermF sc . ExtCns =<< traverse rec ec Right i -> scLocalVar sc (i + length exts) @@ -2247,18 +2252,43 @@ scExtsToLocals sc exts x = instantiateVars sc fn 0 x -- occurrences with the appropriate local variables. scAbstractExts :: SharedContext -> [ExtCns Term] -> Term -> IO Term scAbstractExts _ [] x = return x -scAbstractExts sc exts x = - do let lams = [ (toShortName (ecName ec), ecType ec) | ec <- exts ] - scLambdaList sc lams =<< scExtsToLocals sc exts x +scAbstractExts sc exts x = loop (inits exts) exts + where + loop :: [[ExtCns Term]] -> [ExtCns Term] -> IO Term + loop ([]:bgs) (ec:ecs) = + do tm' <- loop bgs ecs + scLambda sc (toShortName (ecName ec)) (ecType ec) tm' + + loop (begin:bgs) (ec:ecs) = + do tm' <- loop bgs ecs + tp' <- scExtsToLocals sc begin (ecType ec) + scLambda sc (toShortName (ecName ec)) tp' tm' + + loop _ [] = scExtsToLocals sc exts x + + loop _ _ = panic "scAbstractExts" ["list size mismatch"] -- | Generalize over the given list of external constants by wrapping -- the given term with foralls and replacing the external constant -- occurrences with the appropriate local variables. scGeneralizeExts :: SharedContext -> [ExtCns Term] -> Term -> IO Term scGeneralizeExts _ [] x = return x -scGeneralizeExts sc exts x = - do let pis = [ (toShortName (ecName ec), ecType ec) | ec <- exts ] - scPiList sc pis =<< scExtsToLocals sc exts x +scGeneralizeExts sc exts x = loop (inits exts) exts + where + loop :: [[ExtCns Term]] -> [ExtCns Term] -> IO Term + loop ([]:bgs) (ec:ecs) = + do tm' <- loop bgs ecs + scPi sc (toShortName (ecName ec)) (ecType ec) tm' + + loop (begin:bgs) (ec:ecs) = + do tm' <- loop bgs ecs + tp' <- scExtsToLocals sc begin (ecType ec) + scPi sc (toShortName (ecName ec)) tp' tm' + + loop _ [] = scExtsToLocals sc exts x + + loop _ _ = panic "scGeneralizeExts" ["list size mismatch"] + scUnfoldConstants :: SharedContext -> [VarIndex] -> Term -> IO Term scUnfoldConstants sc names t0 = scUnfoldConstantSet sc True (Set.fromList names) t0