From 45de4641368684a608e913ce79ec47ca7d4617a7 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Wed, 18 May 2022 17:17:28 -0700 Subject: [PATCH] add heterogenous equality on return types in MRSolver --- heapster-saw/examples/sha512_mr_solver.saw | 4 +- src/SAWScript/Prover/MRSolver/Monad.hs | 10 +- src/SAWScript/Prover/MRSolver/SMT.hs | 229 ++++++++++++++------- src/SAWScript/Prover/MRSolver/Solver.hs | 2 +- src/SAWScript/Prover/MRSolver/Term.hs | 8 + 5 files changed, 177 insertions(+), 76 deletions(-) diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index 289956d90d..372a3f0731 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -106,6 +106,4 @@ monadify_term {{ Maj }}; monadify_term {{ round_00_15_spec }}; run_test "round_00_15 |= round_00_15_spec" (mr_solver round_00_15 {{ round_00_15_spec }}) true; - -// FIXME: Need to add heterogenous equality on output types for this to work -// run_test "round_16_80 |= round_16_80_spec" (mr_solver_debug 0 round_16_80 {{ round_16_80_spec }}) true; +run_test "round_16_80 |= round_16_80_spec" (mr_solver round_16_80 {{ round_16_80_spec }}) true; diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index d4b5a5bde7..bb5a5b9148 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -4,6 +4,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DerivingStrategies #-} @@ -64,7 +65,7 @@ data FailCtx -- | That's MR. Failure to you data MRFailure - = TermsNotEq Term Term + = TermsNotRel Bool Term Term | TypesNotEq Type Type | CompsDoNotRefine NormComp NormComp | ReturnNotError Term @@ -86,6 +87,9 @@ data MRFailure | MRFailureDisj MRFailure MRFailure deriving Show +pattern TermsNotEq :: Term -> Term -> MRFailure +pattern TermsNotEq t1 t2 = TermsNotRel False t1 t2 + -- | Pretty-print an object prefixed with a 'String' that describes it ppWithPrefix :: PrettyInCtx a => String -> a -> PPInCtxM SawDoc ppWithPrefix str a = (pretty str <>) <$> nest 2 <$> (line <>) <$> prettyInCtx a @@ -111,8 +115,10 @@ instance PrettyInCtx FailCtx where prettyInCtx t] instance PrettyInCtx MRFailure where - prettyInCtx (TermsNotEq t1 t2) = + prettyInCtx (TermsNotRel False t1 t2) = ppWithPrefixSep "Could not prove terms equal:" t1 "and" t2 + prettyInCtx (TermsNotRel True t1 t2) = + ppWithPrefixSep "Could not prove terms heterogeneously related:" t1 "and" t2 prettyInCtx (TypesNotEq tp1 tp2) = ppWithPrefixSep "Types not equal:" tp1 "and" tp2 prettyInCtx (CompsDoNotRefine m1 m2) = diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 02dd4697a7..70bf769020 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -196,32 +196,48 @@ readBackValueNoConfig err_str sc tv v = -- | Implementations of primitives for normalizing Mr Solver terms smtNormPrims :: SharedContext -> Map Ident TmPrim smtNormPrims sc = Map.fromList - [ + [ -- Don't unfold @genBVVec@ when normalizing ("Prelude.genBVVec", Prim (do tp <- scTypeOfGlobal sc "Prelude.genBVVec" VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> scGlobalDef sc "Prelude.genBVVec") ), + -- Normalize applications of @genBVVecFromVec@ to a @genFromBVVec@ term or + -- a vector literal into the body of the @genFromBVVec@ term or @genBVVec@ + -- of an sequence of @ite@s defined by the literal, respectively ("Prelude.genBVVecFromVec", natFun $ \_m -> tvalFun $ \a -> primFromBVVecOrLit sc a $ \eith -> PrimFun $ \_def -> natFun $ \n -> primBVTermFun sc $ \len -> Prim (do n' <- scNat sc n - a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" - sc a + a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" sc a tp <- scGlobalApply sc "Prelude.BVVec" [n', len, a'] VExtra <$> VExtraTerm (VTyTerm (mkSort 0) tp) <$> bvVecFromBVVecOrLit sc n n' len a' eith) ), + -- Don't normalize applications of @genFromBVVec@ ("Prelude.genFromBVVec", - Prim (do tp <- scTypeOfGlobal sc "Prelude.genFromBVVec" - VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> - scGlobalDef sc "Prelude.genFromBVVec") + natFun $ \n -> PrimStrict $ \len -> tvalFun $ \a -> PrimStrict $ \v -> + PrimStrict $ \def -> natFun $ \m -> + Prim (do n' <- scNat sc n + let len_tp = VVecType n VBoolType + len' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc len_tp len + a' <- readBackTValueNoConfig "smtNormPrims (genFromBVVec)" sc a + bvToNat_len <- scGlobalApply sc "Prelude.bvToNat" [n', len'] + v_tp <- VTyTerm (mkSort 0) <$> scVecType sc bvToNat_len a' + v' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc v_tp v + def' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc a def + m' <- scNat sc m + tm <- scGlobalApply sc "Prelude.genFromBVVec" [n', len', a', v', def', m'] + return $ VExtra $ VExtraTerm (VVecType m a) tm) ), + -- Normalize applications of @atBVVec@ to a @genBVVec@ term into an + -- application of the body of the @genBVVec@ term to the index ("Prelude.atBVVec", PrimFun $ \_n -> PrimFun $ \_len -> tvalFun $ \a -> primGenBVVec sc $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf -> - Prim (VExtra <$> VExtraTerm a <$> scApply sc f ix) + Prim (VExtra <$> VExtraTerm a <$> scApplyBeta sc f ix) ), + -- Don't normalize applications of @CompM@ ("Prelude.CompM", PrimFilterFun "CompM" (\case TValue tv -> return tv @@ -384,100 +400,173 @@ mrProveEqSimple eqf t1 t2 = TermInCtx [] <$> eqf t1' t2' -- | Prove that two terms are equal, instantiating evars if necessary, --- returning true on success +-- returning true on success - the same as @mrProveRel False@ mrProveEq :: Term -> Term -> MRM Bool -mrProveEq t1 t2 = - do mrDebugPPPrefixSep 1 "mrProveEq" t1 "==" t2 - tp <- mrTypeOf t1 >>= mrSubstEVars - varmap <- mrVars - cond_in_ctx <- mrProveEqH varmap tp t1 t2 - res <- withTermInCtx cond_in_ctx mrProvable - debugPrint 1 $ "mrProveEq: " ++ if res then "Success" else "Failure" - return res +mrProveEq = mrProveRel False -- | Prove that two terms are equal, instantiating evars if necessary, or --- throwing an error if this is not possible +-- throwing an error if this is not possible - the same as +-- @mrAssertProveRel False@ mrAssertProveEq :: Term -> Term -> MRM () -mrAssertProveEq t1 t2 = - do success <- mrProveEq t1 t2 - if success then return () else - throwMRFailure (TermsNotEq t1 t2) - --- | The main workhorse for 'mrProveEq'. Build a Boolean term expressing that --- the third and fourth arguments, whose type is given by the second. -mrProveEqH :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM TermInCtx +mrAssertProveEq = mrAssertProveRel False + +-- | Prove that two terms are related, heterogeneously iff the first argument +-- is true, instantiating evars if necessary, returning true on success +mrProveRel :: Bool -> Term -> Term -> MRM Bool +mrProveRel het t1 t2 = + do let nm = if het then "mrProveRel" else "mrProveEq" + mrDebugPPPrefixSep 1 nm t1 (if het then "~=" else "==") t2 + tp1 <- mrTypeOf t1 >>= mrSubstEVars + tp2 <- mrTypeOf t2 >>= mrSubstEVars + cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2 + res <- withTermInCtx cond_in_ctx mrProvable + debugPrint 1 $ nm ++ ": " ++ if res then "Success" else "Failure" + return res -{- -mrProveEqH _ _ t1 t2 - | trace ("mrProveEqH:\n" ++ showTerm t1 ++ "\n==\n" ++ showTerm t2) False = undefined --} +-- | Prove that two terms are related, heterogeneously iff the first argument, +-- is true, instantiating evars if necessary, or throwing an error if this is +-- not possible +mrAssertProveRel :: Bool -> Term -> Term -> MRM () +mrAssertProveRel het t1 t2 = + do success <- mrProveRel het t1 t2 + if success then return () else + throwMRFailure (TermsNotRel het t1 t2) + +-- | The main workhorse for 'mrProveEq' and 'mrProveRel'. Build a Boolean term +-- expressing that the fourth and fifth arguments are related, heterogeneously +-- iff the first argument is true, whose types are given by the second and +-- third arguments, respectively +mrProveRelH :: Bool -> Term -> Term -> Term -> Term -> MRM TermInCtx +mrProveRelH het tp1 tp2 t1 t2 = + do varmap <- mrVars + tp1' <- liftSC1 scWhnf tp1 + tp2' <- liftSC1 scWhnf tp2 + mrProveRelH' varmap het tp1' tp2' t1 t2 + +-- | The body of 'mrProveRelH' +-- NOTE: Don't call this function recursively, call 'mrProveRelH' +mrProveRelH' :: Map MRVar MRVarInfo -> Bool -> + Term -> Term -> Term -> Term -> MRM TermInCtx -- If t1 is an instantiated evar, substitute and recurse -mrProveEqH var_map tp (asEVarApp var_map -> Just (_, args, Just f)) t2 = - mrApplyAll f args >>= \t1' -> mrProveEqH var_map tp t1' t2 - --- If t1 is an uninstantiated evar, instantiate it with t2 -mrProveEqH var_map _tp (asEVarApp var_map -> Just (evar, args, Nothing)) t2 = - do t2' <- mrSubstEVars t2 +mrProveRelH' var_map het tp1 tp2 (asEVarApp var_map -> Just (_, args, Just f)) t2 = + mrApplyAll f args >>= \t1' -> mrProveRelH het tp1 tp2 t1' t2 + +-- If t1 is an uninstantiated evar, ensure the types are equal and instantiate +-- it with t2 +mrProveRelH' var_map _ tp1 tp2 (asEVarApp var_map -> Just (evar, args, Nothing)) t2 = + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + t2' <- mrSubstEVars t2 success <- mrTrySetAppliedEVar evar args t2' + when success $ + mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t2 TermInCtx [] <$> liftSC1 scBool success -- If t2 is an instantiated evar, substitute and recurse -mrProveEqH var_map tp t1 (asEVarApp var_map -> Just (_, args, Just f)) = - mrApplyAll f args >>= \t2' -> mrProveEqH var_map tp t1 t2' - --- If t2 is an uninstantiated evar, instantiate it with t1 -mrProveEqH var_map _tp t1 (asEVarApp var_map -> Just (evar, args, Nothing)) = - do t1' <- mrSubstEVars t1 +mrProveRelH' var_map het tp1 tp2 t1 (asEVarApp var_map -> Just (_, args, Just f)) = + mrApplyAll f args >>= \t2' -> mrProveRelH het tp1 tp2 t1 t2' + +-- If t2 is an uninstantiated evar, ensure the types are equal and instantiate +-- it with t1 +mrProveRelH' var_map _ tp1 tp2 t1 (asEVarApp var_map -> Just (evar, args, Nothing)) = + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + t1' <- mrSubstEVars t1 success <- mrTrySetAppliedEVar evar args t1' + when success $ + mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t1 TermInCtx [] <$> liftSC1 scBool success -- For unit types, always return true -mrProveEqH _ (asTupleType -> Just []) _ _ = +mrProveRelH' _ _ (asTupleType -> Just []) (asTupleType -> Just []) _ _ = TermInCtx [] <$> liftSC1 scBool True --- For the nat, bitvector, Boolean, and integer types, call mrProveEqSimple -mrProveEqH _ (asDataType -> Just (pn, [])) t1 t2 - | primName pn == "Prelude.Nat" = - mrProveEqSimple (liftSC2 scEqualNat) t1 t2 -mrProveEqH _ (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = - -- FIXME: make a better solver for bitvector equalities - mrProveEqSimple (liftSC3 scBvEq n) t1 t2 -mrProveEqH _ (asBoolType -> Just _) t1 t2 = +-- For nat, bitvector, Boolean, and integer types, call mrProveEqSimple +mrProveRelH' _ _ (asNatType -> Just _) (asNatType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scEqualNat) t1 t2 +mrProveRelH' _ _ tp1@(asVectorType -> Just (n1, asBoolType -> Just ())) + tp2@(asVectorType -> Just (n2, asBoolType -> Just ())) t1 t2 = + do ns_are_eq <- mrConvertible n1 n2 + if ns_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + mrProveEqSimple (liftSC3 scBvEq n1) t1 t2 +mrProveRelH' _ _ (asBoolType -> Just _) (asBoolType -> Just _) t1 t2 = mrProveEqSimple (liftSC2 scBoolEq) t1 t2 -mrProveEqH _ (asIntegerType -> Just _) t1 t2 = +mrProveRelH' _ _ (asIntegerType -> Just _) (asIntegerType -> Just _) t1 t2 = mrProveEqSimple (liftSC2 scIntEq) t1 t2 --- For pair types, prove both the left and right projections are equal -mrProveEqH var_map (asPairType -> Just (tpL, tpR)) t1 t2 = +-- For pair types, prove both the left and right projections are related +mrProveRelH' _ het (asPairType -> Just (tpL1, tpR1)) + (asPairType -> Just (tpL2, tpR2)) t1 t2 = do t1L <- liftSC1 scPairLeft t1 t2L <- liftSC1 scPairLeft t2 t1R <- liftSC1 scPairRight t1 t2R <- liftSC1 scPairRight t2 - condL <- mrProveEqH var_map tpL t1L t2L - condR <- mrProveEqH var_map tpR t1R t2R + condL <- mrProveRelH het tpL1 tpL2 t1L t2L + condR <- mrProveRelH het tpR1 tpR2 t1R t2R andTermInCtx condL condR --- For non-bitvector vector types, prove all projections are equal by --- quantifying over a universal index variable and proving equality at that --- index -mrProveEqH _ (asBVVecType -> Just (n, len, tp)) t1 t2 = +-- For BVVec types, prove all projections are related by quantifying over an +-- index variable and proving the projections at that index are related +mrProveRelH' _ het tp1@(asBVVecType -> Just (n1, len1, tpA1)) + tp2@(asBVVecType -> Just (n2, len2, tpA2)) t1 t2 = + mrConvertible n1 n2 >>= \ns_are_eq -> + mrConvertible len1 len2 >>= \lens_are_eq -> + (if ns_are_eq && lens_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2))) >> liftSC0 scBoolType >>= \bool_tp -> - liftSC2 scVecType n bool_tp >>= \ix_tp -> - withUVarLift "eq_ix" (Type ix_tp) (n,(len,(tp,(t1,t2)))) $ - \ix' (n',(len',(tp',(t1',t2')))) -> - liftSC2 scGlobalApply "Prelude.is_bvult" [n', ix', len'] >>= \pf_tp -> - withUVarLift "eq_pf" (Type pf_tp) (n',(len',(tp',(ix',(t1',t2'))))) $ - \pf'' (n'',(len'',(tp'',(ix'',(t1'',t2''))))) -> - do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + liftSC2 scVecType n1 bool_tp >>= \ix_tp -> + withUVarLift "eq_ix" (Type ix_tp) (n1,(len1,(tpA1,(tpA2,(t1,t2))))) $ + \ix' (n1',(len1',(tpA1',(tpA2',(t1',t2'))))) -> + liftSC2 scGlobalApply "Prelude.is_bvult" [n1', ix', len1'] >>= \pf_tp -> + withUVarLift "eq_pf" (Type pf_tp) (n1',(len1',(tpA1',(tpA2',(ix',(t1',t2')))))) $ + \pf'' (n1'',(len1'',(tpA1'',(tpA2'',(ix'',(t1'',t2'')))))) -> + do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA1'', t1'', ix'', pf''] - t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA2'', t2'', ix'', pf''] - var_map <- mrVars extTermInCtx [("eq_ix",ix_tp),("eq_pf",pf_tp)] <$> - mrProveEqH var_map tp'' t1_prj t2_prj + mrProveRelH het tpA1'' tpA2'' t1_prj t2_prj + +-- If our relation is heterogeneous and we have a BVVec on one side and a +-- non-BVVec vector on the other, wrap the non-BVVec vector term in +-- genBVVecFromVec and recurse +mrProveRelH' _ True tp1@(asBVVecType -> Just (n, len, _)) + tp2@(asNonBVVecVectorType -> Just (m, tpA2)) t1 t2 = + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + if ms_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp2' <- liftSC2 scVecType len' tpA2 + err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA2, err_str_tm] + t2' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" + [m, tpA2, t2, err_tm, n, len] + -- mrDebugPPPrefixSep 2 "mrProveRelH on BVVec/Vec: " t1 "and" t2' + mrProveRelH True tp1 tp2' t1 t2' +mrProveRelH' _ True tp1@(asNonBVVecVectorType -> Just (m, tpA1)) + tp2@(asBVVecType -> Just (n, len, _)) t1 t2 = + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + if ms_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp1' <- liftSC2 scVecType len' tpA1 + err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA1, err_str_tm] + t1' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" + [m, tpA1, t1, err_tm, n, len] + -- mrDebugPPPrefixSep 2 "mrProveRelH on Vec/BVVec: " t1' "and" t2 + mrProveRelH True tp1' tp2 t1' t2 -- As a fallback, for types we can't handle, just check convertibility -mrProveEqH _ _ t1 t2 = +mrProveRelH' _ _ tp1 tp2 t1 t2 = do success <- mrConvertible t1 t2 + if success then return () else + mrDebugPPPrefixSep 2 "mrProveRelH could not match types: " tp1 "and" tp2 >> + mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 TermInCtx [] <$> liftSC1 scBool success diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index ecbb732c40..74b65cae1b 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -623,7 +623,7 @@ mrRefines t1 t2 = -- | The main implementation of 'mrRefines' mrRefines' :: NormComp -> NormComp -> MRM () -mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveEq e1 e2 +mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveRel True e1 e2 mrRefines' (ErrorM _) (ErrorM _) = return () mrRefines' (ReturnM e) (ErrorM _) = throwMRFailure (ReturnNotError e) mrRefines' (ErrorM _) (ReturnM e) = throwMRFailure (ReturnNotError e) diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index 03f2ee5a97..10f958f67b 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -433,6 +433,14 @@ instance PrettyInCtx MRVar where instance PrettyInCtx [Term] where prettyInCtx xs = list <$> mapM prettyInCtx xs +instance PrettyInCtx a => PrettyInCtx (Maybe a) where + prettyInCtx (Just x) = (<+>) "Just" <$> prettyInCtx x + prettyInCtx Nothing = return "Nothing" + +instance (PrettyInCtx a, PrettyInCtx b) => PrettyInCtx (a,b) where + prettyInCtx (x, y) = (\x' y' -> parens (x' <> "," <> y')) <$> prettyInCtx x + <*> prettyInCtx y + instance PrettyInCtx TermProj where prettyInCtx TermProjLeft = return (pretty '.' <> "1") prettyInCtx TermProjRight = return (pretty '.' <> "2")