Skip to content

Commit

Permalink
Improve scAbstractExts and scGeneralizeExts
Browse files Browse the repository at this point in the history
Previously, these functions were not working correctly when
the _types_ of some `ExtCns` values mention some of the
`ExtCns` values being abstracted or generalized.  Now we are
careful to both evaluate inside the types of `ExtCns` values
occuring inside the term, as well as making sure that the
types of `ExtCns` values in the list are processed properly
in the context of the outer values.
  • Loading branch information
robdockins committed Jun 14, 2021
1 parent dd5a89b commit de30403
Showing 1 changed file with 59 additions and 15 deletions.
74 changes: 59 additions & 15 deletions saw-core/src/Verifier/SAW/SharedTerm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ import Control.Lens
import Control.Monad.State.Strict as State
import Control.Monad.Reader
import Data.Bits
import Data.List (inits)
import Data.Maybe
import qualified Data.Foldable as Fold
import Data.Foldable (foldl', foldlM, foldrM, maximum)
Expand Down Expand Up @@ -1058,7 +1059,7 @@ instantiateLocalVars sc f initialLevel t0 =
go' _ tf@(Constant {}) = scTermF sc tf

instantiateVars :: SharedContext
-> (DeBruijnIndex -> Either (ExtCns Term) DeBruijnIndex -> IO Term)
-> ((Term -> IO Term) -> DeBruijnIndex -> Either (ExtCns Term) DeBruijnIndex -> IO Term)
-> DeBruijnIndex -> Term -> IO Term
instantiateVars sc f initialLevel t0 =
do cache <- newCache
Expand All @@ -1073,14 +1074,14 @@ instantiateVars sc f initialLevel t0 =

go' :: (?cache :: Cache IO (TermIndex, DeBruijnIndex) Term) =>
DeBruijnIndex -> TermF Term -> IO Term
go' l (FTermF (ExtCns ec)) = f l (Left ec)
go' l (FTermF (ExtCns ec)) = f (go l) l (Left ec)
go' l (FTermF tf) = scFlatTermF sc =<< (traverse (go l) tf)
go' l (App x y) = scTermF sc =<< (App <$> go l x <*> go l y)
go' l (Lambda i tp rhs) = scTermF sc =<< (Lambda i <$> go l tp <*> go (l+1) rhs)
go' l (Pi i lhs rhs) = scTermF sc =<< (Pi i <$> go l lhs <*> go (l+1) rhs)
go' l (LocalVar i)
| i < l = scTermF sc (LocalVar i)
| otherwise = f l (Right i)
| otherwise = f (go l) l (Right i)
go' _ tf@(Constant {}) = scTermF sc tf

-- | @incVars k j t@ increments free variables at least @k@ by @j@.
Expand Down Expand Up @@ -2185,17 +2186,20 @@ getConstantSet t = snd $ go (IntSet.empty, Map.empty) t
Constant (EC vidx n ty) body -> (idxs, Map.insert vidx (n, ty, body) names)
_ -> foldl' go acc tf

-- | Instantiate some of the external constants
-- | Instantiate some of the external constants.
-- Note: this replacement is _not_ applied recursively
-- to the terms in the replacement map; so external constants
-- in those terms will not be replaced.
scInstantiateExt :: SharedContext
-> Map VarIndex Term
-> Term
-> IO Term
scInstantiateExt sc vmap = instantiateVars sc fn 0
where fn l (Left ec) =
where fn _rec l (Left ec) =
case Map.lookup (ecVarIndex ec) vmap of
Just t -> incVars sc 0 l t
Nothing -> scFlatTermF sc $ ExtCns ec
fn _ (Right i) = scTermF sc $ LocalVar i
fn _ _ (Right i) = scTermF sc $ LocalVar i

{-
-- RWD: I'm pretty sure the following implementation gets incorrect results when
Expand Down Expand Up @@ -2233,12 +2237,12 @@ scExtsToLocals _ [] x = return x
scExtsToLocals sc exts x = instantiateVars sc fn 0 x
where
m = Map.fromList [ (ecVarIndex ec, k) | (ec, k) <- zip (reverse exts) [0 ..] ]
fn l e =
fn rec l e =
case e of
Left ec ->
case Map.lookup (ecVarIndex ec) m of
Just k -> scLocalVar sc (l + k)
Nothing -> scFlatTermF sc (ExtCns ec)
Just k -> scLocalVar sc (l + k)
Nothing -> scFlatTermF sc . ExtCns =<< traverse rec ec
Right i ->
scLocalVar sc (i + length exts)

Expand All @@ -2247,18 +2251,58 @@ scExtsToLocals sc exts x = instantiateVars sc fn 0 x
-- occurrences with the appropriate local variables.
scAbstractExts :: SharedContext -> [ExtCns Term] -> Term -> IO Term
scAbstractExts _ [] x = return x
scAbstractExts sc exts x =
do let lams = [ (toShortName (ecName ec), ecType ec) | ec <- exts ]
scLambdaList sc lams =<< scExtsToLocals sc exts x
scAbstractExts sc exts x = loop (zip (inits exts) exts)
where
-- each pair contains a single ExtCns and a list of all
-- the ExtCns values that appear before it in the original list.
loop :: [([ExtCns Term], ExtCns Term)] -> IO Term

-- specical case: outermost variable, no need to abstract
-- inside the type of ec
loop (([],ec):ecs) =
do tm' <- loop ecs
scLambda sc (toShortName (ecName ec)) (ecType ec) tm'

-- ordinary case. We need to abstract over all the ExtCns in @begin@
-- before apply scLambda. This ensures any dependenices between the
-- types are handled correctly.
loop ((begin,ec):ecs) =
do tm' <- loop ecs
tp' <- scExtsToLocals sc begin (ecType ec)
scLambda sc (toShortName (ecName ec)) tp' tm'

-- base case, convert all the exts in the body of x into deBruijn variables
loop [] = scExtsToLocals sc exts x


-- | Generalize over the given list of external constants by wrapping
-- the given term with foralls and replacing the external constant
-- occurrences with the appropriate local variables.
scGeneralizeExts :: SharedContext -> [ExtCns Term] -> Term -> IO Term
scGeneralizeExts _ [] x = return x
scGeneralizeExts sc exts x =
do let pis = [ (toShortName (ecName ec), ecType ec) | ec <- exts ]
scPiList sc pis =<< scExtsToLocals sc exts x
scGeneralizeExts sc exts x = loop (zip (inits exts) exts)
where
-- each pair contains a single ExtCns and a list of all
-- the ExtCns values that appear before it in the original list.
loop :: [([ExtCns Term], ExtCns Term)] -> IO Term

-- specical case: outermost variable, no need to abstract
-- inside the type of ec
loop (([],ec):ecs) =
do tm' <- loop ecs
scPi sc (toShortName (ecName ec)) (ecType ec) tm'

-- ordinary case. We need to abstract over all the ExtCns in @begin@
-- before apply scLambda. This ensures any dependenices between the
-- types are handled correctly.
loop ((begin,ec):ecs) =
do tm' <- loop ecs
tp' <- scExtsToLocals sc begin (ecType ec)
scPi sc (toShortName (ecName ec)) tp' tm'

-- base case, convert all the exts in the body of x into deBruijn variables
loop [] = scExtsToLocals sc exts x


scUnfoldConstants :: SharedContext -> [VarIndex] -> Term -> IO Term
scUnfoldConstants sc names t0 = scUnfoldConstantSet sc True (Set.fromList names) t0
Expand Down

0 comments on commit de30403

Please sign in to comment.