Skip to content

Commit

Permalink
add heterogenous equality on return types in MRSolver
Browse files Browse the repository at this point in the history
  • Loading branch information
m-yac committed May 19, 2022
1 parent 5d9bc4b commit 45de464
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 76 deletions.
4 changes: 1 addition & 3 deletions heapster-saw/examples/sha512_mr_solver.saw
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,4 @@ monadify_term {{ Maj }};
monadify_term {{ round_00_15_spec }};

run_test "round_00_15 |= round_00_15_spec" (mr_solver round_00_15 {{ round_00_15_spec }}) true;

// FIXME: Need to add heterogenous equality on output types for this to work
// run_test "round_16_80 |= round_16_80_spec" (mr_solver_debug 0 round_16_80 {{ round_16_80_spec }}) true;
run_test "round_16_80 |= round_16_80_spec" (mr_solver round_16_80 {{ round_16_80_spec }}) true;
10 changes: 8 additions & 2 deletions src/SAWScript/Prover/MRSolver/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
Expand Down Expand Up @@ -64,7 +65,7 @@ data FailCtx

-- | That's MR. Failure to you
data MRFailure
= TermsNotEq Term Term
= TermsNotRel Bool Term Term
| TypesNotEq Type Type
| CompsDoNotRefine NormComp NormComp
| ReturnNotError Term
Expand All @@ -86,6 +87,9 @@ data MRFailure
| MRFailureDisj MRFailure MRFailure
deriving Show

pattern TermsNotEq :: Term -> Term -> MRFailure
pattern TermsNotEq t1 t2 = TermsNotRel False t1 t2

-- | Pretty-print an object prefixed with a 'String' that describes it
ppWithPrefix :: PrettyInCtx a => String -> a -> PPInCtxM SawDoc
ppWithPrefix str a = (pretty str <>) <$> nest 2 <$> (line <>) <$> prettyInCtx a
Expand All @@ -111,8 +115,10 @@ instance PrettyInCtx FailCtx where
prettyInCtx t]

instance PrettyInCtx MRFailure where
prettyInCtx (TermsNotEq t1 t2) =
prettyInCtx (TermsNotRel False t1 t2) =
ppWithPrefixSep "Could not prove terms equal:" t1 "and" t2
prettyInCtx (TermsNotRel True t1 t2) =
ppWithPrefixSep "Could not prove terms heterogeneously related:" t1 "and" t2
prettyInCtx (TypesNotEq tp1 tp2) =
ppWithPrefixSep "Types not equal:" tp1 "and" tp2
prettyInCtx (CompsDoNotRefine m1 m2) =
Expand Down
229 changes: 159 additions & 70 deletions src/SAWScript/Prover/MRSolver/SMT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -196,32 +196,48 @@ readBackValueNoConfig err_str sc tv v =
-- | Implementations of primitives for normalizing Mr Solver terms
smtNormPrims :: SharedContext -> Map Ident TmPrim
smtNormPrims sc = Map.fromList
[
[ -- Don't unfold @genBVVec@ when normalizing
("Prelude.genBVVec",
Prim (do tp <- scTypeOfGlobal sc "Prelude.genBVVec"
VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$>
scGlobalDef sc "Prelude.genBVVec")
),
-- Normalize applications of @genBVVecFromVec@ to a @genFromBVVec@ term or
-- a vector literal into the body of the @genFromBVVec@ term or @genBVVec@
-- of an sequence of @ite@s defined by the literal, respectively
("Prelude.genBVVecFromVec",
natFun $ \_m -> tvalFun $ \a -> primFromBVVecOrLit sc a $ \eith ->
PrimFun $ \_def -> natFun $ \n -> primBVTermFun sc $ \len ->
Prim (do n' <- scNat sc n
a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)"
sc a
a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" sc a
tp <- scGlobalApply sc "Prelude.BVVec" [n', len, a']
VExtra <$> VExtraTerm (VTyTerm (mkSort 0) tp) <$>
bvVecFromBVVecOrLit sc n n' len a' eith)
),
-- Don't normalize applications of @genFromBVVec@
("Prelude.genFromBVVec",
Prim (do tp <- scTypeOfGlobal sc "Prelude.genFromBVVec"
VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$>
scGlobalDef sc "Prelude.genFromBVVec")
natFun $ \n -> PrimStrict $ \len -> tvalFun $ \a -> PrimStrict $ \v ->
PrimStrict $ \def -> natFun $ \m ->
Prim (do n' <- scNat sc n
let len_tp = VVecType n VBoolType
len' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc len_tp len
a' <- readBackTValueNoConfig "smtNormPrims (genFromBVVec)" sc a
bvToNat_len <- scGlobalApply sc "Prelude.bvToNat" [n', len']
v_tp <- VTyTerm (mkSort 0) <$> scVecType sc bvToNat_len a'
v' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc v_tp v
def' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc a def
m' <- scNat sc m
tm <- scGlobalApply sc "Prelude.genFromBVVec" [n', len', a', v', def', m']
return $ VExtra $ VExtraTerm (VVecType m a) tm)
),
-- Normalize applications of @atBVVec@ to a @genBVVec@ term into an
-- application of the body of the @genBVVec@ term to the index
("Prelude.atBVVec",
PrimFun $ \_n -> PrimFun $ \_len -> tvalFun $ \a ->
primGenBVVec sc $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf ->
Prim (VExtra <$> VExtraTerm a <$> scApply sc f ix)
Prim (VExtra <$> VExtraTerm a <$> scApplyBeta sc f ix)
),
-- Don't normalize applications of @CompM@
("Prelude.CompM",
PrimFilterFun "CompM" (\case
TValue tv -> return tv
Expand Down Expand Up @@ -384,100 +400,173 @@ mrProveEqSimple eqf t1 t2 =
TermInCtx [] <$> eqf t1' t2'

-- | Prove that two terms are equal, instantiating evars if necessary,
-- returning true on success
-- returning true on success - the same as @mrProveRel False@
mrProveEq :: Term -> Term -> MRM Bool
mrProveEq t1 t2 =
do mrDebugPPPrefixSep 1 "mrProveEq" t1 "==" t2
tp <- mrTypeOf t1 >>= mrSubstEVars
varmap <- mrVars
cond_in_ctx <- mrProveEqH varmap tp t1 t2
res <- withTermInCtx cond_in_ctx mrProvable
debugPrint 1 $ "mrProveEq: " ++ if res then "Success" else "Failure"
return res
mrProveEq = mrProveRel False

-- | Prove that two terms are equal, instantiating evars if necessary, or
-- throwing an error if this is not possible
-- throwing an error if this is not possible - the same as
-- @mrAssertProveRel False@
mrAssertProveEq :: Term -> Term -> MRM ()
mrAssertProveEq t1 t2 =
do success <- mrProveEq t1 t2
if success then return () else
throwMRFailure (TermsNotEq t1 t2)

-- | The main workhorse for 'mrProveEq'. Build a Boolean term expressing that
-- the third and fourth arguments, whose type is given by the second.
mrProveEqH :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM TermInCtx
mrAssertProveEq = mrAssertProveRel False

-- | Prove that two terms are related, heterogeneously iff the first argument
-- is true, instantiating evars if necessary, returning true on success
mrProveRel :: Bool -> Term -> Term -> MRM Bool
mrProveRel het t1 t2 =
do let nm = if het then "mrProveRel" else "mrProveEq"
mrDebugPPPrefixSep 1 nm t1 (if het then "~=" else "==") t2
tp1 <- mrTypeOf t1 >>= mrSubstEVars
tp2 <- mrTypeOf t2 >>= mrSubstEVars
cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2
res <- withTermInCtx cond_in_ctx mrProvable
debugPrint 1 $ nm ++ ": " ++ if res then "Success" else "Failure"
return res

{-
mrProveEqH _ _ t1 t2
| trace ("mrProveEqH:\n" ++ showTerm t1 ++ "\n==\n" ++ showTerm t2) False = undefined
-}
-- | Prove that two terms are related, heterogeneously iff the first argument,
-- is true, instantiating evars if necessary, or throwing an error if this is
-- not possible
mrAssertProveRel :: Bool -> Term -> Term -> MRM ()
mrAssertProveRel het t1 t2 =
do success <- mrProveRel het t1 t2
if success then return () else
throwMRFailure (TermsNotRel het t1 t2)

-- | The main workhorse for 'mrProveEq' and 'mrProveRel'. Build a Boolean term
-- expressing that the fourth and fifth arguments are related, heterogeneously
-- iff the first argument is true, whose types are given by the second and
-- third arguments, respectively
mrProveRelH :: Bool -> Term -> Term -> Term -> Term -> MRM TermInCtx
mrProveRelH het tp1 tp2 t1 t2 =
do varmap <- mrVars
tp1' <- liftSC1 scWhnf tp1
tp2' <- liftSC1 scWhnf tp2
mrProveRelH' varmap het tp1' tp2' t1 t2

-- | The body of 'mrProveRelH'
-- NOTE: Don't call this function recursively, call 'mrProveRelH'
mrProveRelH' :: Map MRVar MRVarInfo -> Bool ->
Term -> Term -> Term -> Term -> MRM TermInCtx

-- If t1 is an instantiated evar, substitute and recurse
mrProveEqH var_map tp (asEVarApp var_map -> Just (_, args, Just f)) t2 =
mrApplyAll f args >>= \t1' -> mrProveEqH var_map tp t1' t2

-- If t1 is an uninstantiated evar, instantiate it with t2
mrProveEqH var_map _tp (asEVarApp var_map -> Just (evar, args, Nothing)) t2 =
do t2' <- mrSubstEVars t2
mrProveRelH' var_map het tp1 tp2 (asEVarApp var_map -> Just (_, args, Just f)) t2 =
mrApplyAll f args >>= \t1' -> mrProveRelH het tp1 tp2 t1' t2

-- If t1 is an uninstantiated evar, ensure the types are equal and instantiate
-- it with t2
mrProveRelH' var_map _ tp1 tp2 (asEVarApp var_map -> Just (evar, args, Nothing)) t2 =
do tps_are_eq <- mrConvertible tp1 tp2
if tps_are_eq then return () else
throwMRFailure (TypesNotEq (Type tp1) (Type tp2))
t2' <- mrSubstEVars t2
success <- mrTrySetAppliedEVar evar args t2'
when success $
mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t2
TermInCtx [] <$> liftSC1 scBool success

-- If t2 is an instantiated evar, substitute and recurse
mrProveEqH var_map tp t1 (asEVarApp var_map -> Just (_, args, Just f)) =
mrApplyAll f args >>= \t2' -> mrProveEqH var_map tp t1 t2'

-- If t2 is an uninstantiated evar, instantiate it with t1
mrProveEqH var_map _tp t1 (asEVarApp var_map -> Just (evar, args, Nothing)) =
do t1' <- mrSubstEVars t1
mrProveRelH' var_map het tp1 tp2 t1 (asEVarApp var_map -> Just (_, args, Just f)) =
mrApplyAll f args >>= \t2' -> mrProveRelH het tp1 tp2 t1 t2'

-- If t2 is an uninstantiated evar, ensure the types are equal and instantiate
-- it with t1
mrProveRelH' var_map _ tp1 tp2 t1 (asEVarApp var_map -> Just (evar, args, Nothing)) =
do tps_are_eq <- mrConvertible tp1 tp2
if tps_are_eq then return () else
throwMRFailure (TypesNotEq (Type tp1) (Type tp2))
t1' <- mrSubstEVars t1
success <- mrTrySetAppliedEVar evar args t1'
when success $
mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t1
TermInCtx [] <$> liftSC1 scBool success

-- For unit types, always return true
mrProveEqH _ (asTupleType -> Just []) _ _ =
mrProveRelH' _ _ (asTupleType -> Just []) (asTupleType -> Just []) _ _ =
TermInCtx [] <$> liftSC1 scBool True

-- For the nat, bitvector, Boolean, and integer types, call mrProveEqSimple
mrProveEqH _ (asDataType -> Just (pn, [])) t1 t2
| primName pn == "Prelude.Nat" =
mrProveEqSimple (liftSC2 scEqualNat) t1 t2
mrProveEqH _ (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 =
-- FIXME: make a better solver for bitvector equalities
mrProveEqSimple (liftSC3 scBvEq n) t1 t2
mrProveEqH _ (asBoolType -> Just _) t1 t2 =
-- For nat, bitvector, Boolean, and integer types, call mrProveEqSimple
mrProveRelH' _ _ (asNatType -> Just _) (asNatType -> Just _) t1 t2 =
mrProveEqSimple (liftSC2 scEqualNat) t1 t2
mrProveRelH' _ _ tp1@(asVectorType -> Just (n1, asBoolType -> Just ()))
tp2@(asVectorType -> Just (n2, asBoolType -> Just ())) t1 t2 =
do ns_are_eq <- mrConvertible n1 n2
if ns_are_eq then return () else
throwMRFailure (TypesNotEq (Type tp1) (Type tp2))
mrProveEqSimple (liftSC3 scBvEq n1) t1 t2
mrProveRelH' _ _ (asBoolType -> Just _) (asBoolType -> Just _) t1 t2 =
mrProveEqSimple (liftSC2 scBoolEq) t1 t2
mrProveEqH _ (asIntegerType -> Just _) t1 t2 =
mrProveRelH' _ _ (asIntegerType -> Just _) (asIntegerType -> Just _) t1 t2 =
mrProveEqSimple (liftSC2 scIntEq) t1 t2

-- For pair types, prove both the left and right projections are equal
mrProveEqH var_map (asPairType -> Just (tpL, tpR)) t1 t2 =
-- For pair types, prove both the left and right projections are related
mrProveRelH' _ het (asPairType -> Just (tpL1, tpR1))
(asPairType -> Just (tpL2, tpR2)) t1 t2 =
do t1L <- liftSC1 scPairLeft t1
t2L <- liftSC1 scPairLeft t2
t1R <- liftSC1 scPairRight t1
t2R <- liftSC1 scPairRight t2
condL <- mrProveEqH var_map tpL t1L t2L
condR <- mrProveEqH var_map tpR t1R t2R
condL <- mrProveRelH het tpL1 tpL2 t1L t2L
condR <- mrProveRelH het tpR1 tpR2 t1R t2R
andTermInCtx condL condR

-- For non-bitvector vector types, prove all projections are equal by
-- quantifying over a universal index variable and proving equality at that
-- index
mrProveEqH _ (asBVVecType -> Just (n, len, tp)) t1 t2 =
-- For BVVec types, prove all projections are related by quantifying over an
-- index variable and proving the projections at that index are related
mrProveRelH' _ het tp1@(asBVVecType -> Just (n1, len1, tpA1))
tp2@(asBVVecType -> Just (n2, len2, tpA2)) t1 t2 =
mrConvertible n1 n2 >>= \ns_are_eq ->
mrConvertible len1 len2 >>= \lens_are_eq ->
(if ns_are_eq && lens_are_eq then return () else
throwMRFailure (TypesNotEq (Type tp1) (Type tp2))) >>
liftSC0 scBoolType >>= \bool_tp ->
liftSC2 scVecType n bool_tp >>= \ix_tp ->
withUVarLift "eq_ix" (Type ix_tp) (n,(len,(tp,(t1,t2)))) $
\ix' (n',(len',(tp',(t1',t2')))) ->
liftSC2 scGlobalApply "Prelude.is_bvult" [n', ix', len'] >>= \pf_tp ->
withUVarLift "eq_pf" (Type pf_tp) (n',(len',(tp',(ix',(t1',t2'))))) $
\pf'' (n'',(len'',(tp'',(ix'',(t1'',t2''))))) ->
do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'',
liftSC2 scVecType n1 bool_tp >>= \ix_tp ->
withUVarLift "eq_ix" (Type ix_tp) (n1,(len1,(tpA1,(tpA2,(t1,t2))))) $
\ix' (n1',(len1',(tpA1',(tpA2',(t1',t2'))))) ->
liftSC2 scGlobalApply "Prelude.is_bvult" [n1', ix', len1'] >>= \pf_tp ->
withUVarLift "eq_pf" (Type pf_tp) (n1',(len1',(tpA1',(tpA2',(ix',(t1',t2')))))) $
\pf'' (n1'',(len1'',(tpA1'',(tpA2'',(ix'',(t1'',t2'')))))) ->
do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA1'',
t1'', ix'', pf'']
t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'',
t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA2'',
t2'', ix'', pf'']
var_map <- mrVars
extTermInCtx [("eq_ix",ix_tp),("eq_pf",pf_tp)] <$>
mrProveEqH var_map tp'' t1_prj t2_prj
mrProveRelH het tpA1'' tpA2'' t1_prj t2_prj

-- If our relation is heterogeneous and we have a BVVec on one side and a
-- non-BVVec vector on the other, wrap the non-BVVec vector term in
-- genBVVecFromVec and recurse
mrProveRelH' _ True tp1@(asBVVecType -> Just (n, len, _))
tp2@(asNonBVVecVectorType -> Just (m, tpA2)) t1 t2 =
do m' <- mrBvToNat n len
ms_are_eq <- mrConvertible m' m
if ms_are_eq then return () else
throwMRFailure (TypesNotEq (Type tp1) (Type tp2))
len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len]
tp2' <- liftSC2 scVecType len' tpA2
err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error"
err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA2, err_str_tm]
t2' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec"
[m, tpA2, t2, err_tm, n, len]
-- mrDebugPPPrefixSep 2 "mrProveRelH on BVVec/Vec: " t1 "and" t2'
mrProveRelH True tp1 tp2' t1 t2'
mrProveRelH' _ True tp1@(asNonBVVecVectorType -> Just (m, tpA1))
tp2@(asBVVecType -> Just (n, len, _)) t1 t2 =
do m' <- mrBvToNat n len
ms_are_eq <- mrConvertible m' m
if ms_are_eq then return () else
throwMRFailure (TypesNotEq (Type tp1) (Type tp2))
len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len]
tp1' <- liftSC2 scVecType len' tpA1
err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error"
err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA1, err_str_tm]
t1' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec"
[m, tpA1, t1, err_tm, n, len]
-- mrDebugPPPrefixSep 2 "mrProveRelH on Vec/BVVec: " t1' "and" t2
mrProveRelH True tp1' tp2 t1' t2

-- As a fallback, for types we can't handle, just check convertibility
mrProveEqH _ _ t1 t2 =
mrProveRelH' _ _ tp1 tp2 t1 t2 =
do success <- mrConvertible t1 t2
if success then return () else
mrDebugPPPrefixSep 2 "mrProveRelH could not match types: " tp1 "and" tp2 >>
mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2
TermInCtx [] <$> liftSC1 scBool success
2 changes: 1 addition & 1 deletion src/SAWScript/Prover/MRSolver/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ mrRefines t1 t2 =
-- | The main implementation of 'mrRefines'
mrRefines' :: NormComp -> NormComp -> MRM ()

mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveEq e1 e2
mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveRel True e1 e2
mrRefines' (ErrorM _) (ErrorM _) = return ()
mrRefines' (ReturnM e) (ErrorM _) = throwMRFailure (ReturnNotError e)
mrRefines' (ErrorM _) (ReturnM e) = throwMRFailure (ReturnNotError e)
Expand Down
8 changes: 8 additions & 0 deletions src/SAWScript/Prover/MRSolver/Term.hs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,14 @@ instance PrettyInCtx MRVar where
instance PrettyInCtx [Term] where
prettyInCtx xs = list <$> mapM prettyInCtx xs

instance PrettyInCtx a => PrettyInCtx (Maybe a) where
prettyInCtx (Just x) = (<+>) "Just" <$> prettyInCtx x
prettyInCtx Nothing = return "Nothing"

instance (PrettyInCtx a, PrettyInCtx b) => PrettyInCtx (a,b) where
prettyInCtx (x, y) = (\x' y' -> parens (x' <> "," <> y')) <$> prettyInCtx x
<*> prettyInCtx y

instance PrettyInCtx TermProj where
prettyInCtx TermProjLeft = return (pretty '.' <> "1")
prettyInCtx TermProjRight = return (pretty '.' <> "2")
Expand Down

0 comments on commit 45de464

Please sign in to comment.