diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5f08806864..08e80fc9a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -197,6 +197,40 @@ jobs: name: "saw-${{ runner.os }}-${{ matrix.ghc }}" path: "dist/bin/saw" + mr-solver-tests: + needs: [build] + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-10.15] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + with: + submodules: true + + - shell: bash + run: .github/ci.sh install_system_deps + env: + BUILD_TARGET_OS: ${{ matrix.os }} + + - uses: actions/download-artifact@v2 + with: + name: "${{ runner.os }}-bins" + path: dist/bin + + - name: Update PATH to include SAW + shell: bash + run: | + chmod +x dist/bin/* + echo $GITHUB_WORKSPACE/dist/bin >> $GITHUB_PATH + + - working-directory: examples/mr_solver + shell: bash + run: | + saw monadify.saw + saw mr_solver_unit_tests.saw + heapster-tests: needs: [build] strategy: @@ -464,7 +498,7 @@ jobs: s2n-tests: name: "Test s2n proofs" - timeout-minutes: 60 + timeout-minutes: 120 needs: build runs-on: ubuntu-18.04 strategy: diff --git a/CHANGES.md b/CHANGES.md index 61b4130864..3fe9085243 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -33,6 +33,20 @@ dealing with C `union` types, as the type information provided by LLVM is imprecise in these cases. +* A new `llvm_union` function has been added that uses debug + information to allow users to select fields from `union` types by + name. This automates the process of manually applying + `llvm_cast_pointer` with the type of the selected union field. Just + as with `llvm_field`, debug symbols are required for `llvm_union` to + work correctly. + +* A new highly experimental `llvm_verify_fixpoint_x86` function that + allows partial correctness verification of loops using loop + invariants instead of full symbolic unrolling. Only certain very simple + styles of loops can currently be accommodated, and the user is + required to provide a term that describes how the live variables in + the loop evolve over an iteration. + # Version 0.9 ## New Features diff --git a/README.md b/README.md index 64192769b0..9aed46475a 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ unix-time Much of the work on SAW has been funded by, and lots of design input was provided by the team at the [NSA's Laboratory for Advanced Cybersecurity -Research](https://www.nsa.gov/what-we-do/research/cybersecurity-research/), +Research](https://www.nsa.gov/Research/NSA-Mission-Oriented-Research/LAC/), including Brad Martin, Frank Taylor, and Sean Weaver. Portions of SAW are also based upon work supported by the Office diff --git a/crux-mir-comp/src/Mir/Compositional/Builder.hs b/crux-mir-comp/src/Mir/Compositional/Builder.hs index cf37cc5d88..faf09ae025 100644 --- a/crux-mir-comp/src/Mir/Compositional/Builder.hs +++ b/crux-mir-comp/src/Mir/Compositional/Builder.hs @@ -613,6 +613,7 @@ substMethodSpec sc sm ms = do MS.SetupElem b sv idx -> MS.SetupElem b <$> goSetupValue sv <*> pure idx MS.SetupField b sv name -> MS.SetupField b <$> goSetupValue sv <*> pure name MS.SetupCast v _ _ -> case v of {} + MS.SetupUnion v _ _ -> case v of {} MS.SetupGlobal _ _ -> return sv MS.SetupGlobalInitializer _ _ -> return sv diff --git a/crux-mir-comp/src/Mir/Compositional/MethodSpec.hs b/crux-mir-comp/src/Mir/Compositional/MethodSpec.hs index 28404f136e..9c5c2c275a 100644 --- a/crux-mir-comp/src/Mir/Compositional/MethodSpec.hs +++ b/crux-mir-comp/src/Mir/Compositional/MethodSpec.hs @@ -30,6 +30,7 @@ type instance MS.HasSetupArray MIR = 'True type instance MS.HasSetupElem MIR = 'True type instance MS.HasSetupField MIR = 'True type instance MS.HasSetupCast MIR = 'False +type instance MS.HasSetupUnion MIR = 'False type instance MS.HasSetupGlobalInitializer MIR = 'False type instance MS.HasGhostState MIR = 'False diff --git a/cryptol-saw-core/saw/Cryptol.sawcore b/cryptol-saw-core/saw/Cryptol.sawcore index fe13215f39..3a387e57c4 100644 --- a/cryptol-saw-core/saw/Cryptol.sawcore +++ b/cryptol-saw-core/saw/Cryptol.sawcore @@ -232,6 +232,9 @@ ecRatio x y = (); eqRational : Rational -> Rational -> Bool; eqRational x y = error Bool "Unimplemented: (==) Rational"; +leRational : Rational -> Rational -> Bool; +leRational x y = error Bool "Unimplemented: (<=) Rational"; + ltRational : Rational -> Rational -> Bool; ltRational x y = error Bool "Unimplemented: (<) Rational"; @@ -455,6 +458,9 @@ errorBinary s a _ _ = error a s; boolCmp : Bool -> Bool -> Bool -> Bool; boolCmp x y k = ite Bool x (and y k) (or y k); +boolLt : Bool -> Bool -> Bool; +boolLt x y = and (not x) y; + integerCmp : Integer -> Integer -> Bool -> Bool; integerCmp x y k = or (intLt x y) (and (intEq x y) k); @@ -473,14 +479,35 @@ vecCmp n a f xs ys k = foldr (Bool -> Bool) Bool n (\ (f : Bool -> Bool) -> f) k (zipWith a a (Bool -> Bool) f n xs ys); +vecLt : + (n : Nat) -> (a : isort 0) -> + (a -> a -> Bool -> Bool) -> + (a -> a -> Bool) -> + (Vec n a -> Vec n a -> Bool); +vecLt n a f g xs ys = + foldr (Bool -> Bool) Bool n (\ (f : Bool -> Bool) -> f) False + (zipWith a a (Bool -> Bool) f n xs ys); + unitCmp : #() -> #() -> Bool -> Bool; -unitCmp _ _ _ = False; +unitCmp _ _ k = k; + +unitLe : #() -> #() -> Bool; +unitLe _ _ = True; + +unitLt : #() -> #() -> Bool; +unitLt _ _ = False; pairCmp : (a b : sort 0) -> (a -> a -> Bool -> Bool) -> (b -> b -> Bool -> Bool) -> a * b -> a * b -> Bool -> Bool; pairCmp a b f g x12 y12 k = f (fst a b x12) (fst a b y12) (g (snd a b x12) (snd a b y12) k); +pairLt : + (a b : sort 0) -> (a -> a -> Bool -> Bool) -> (b -> b -> Bool) -> + a * b -> a * b -> Bool; +pairLt a b f g x y = + f (fst a b x) (fst a b y) (g (snd a b x) (snd a b y)); + -------------------------------------------------------------------------------- -- Dictionaries and overloading @@ -534,23 +561,31 @@ PEqPair a b pa pb = { eq = pairEq a b pa.eq pb.eq }; -- Cmp class +-- `cmp x y k` computes `if k then x <= y else x < y` PCmp : sort 0 -> sort 1; PCmp a = #{ cmpEq : PEq a , cmp : a -> a -> Bool -> Bool + , le : a -> a -> Bool + , lt : a -> a -> Bool }; PCmpBit : PCmp Bool; -PCmpBit = { cmpEq = PEqBit, cmp = boolCmp }; +PCmpBit = { cmpEq = PEqBit, cmp = boolCmp, le = implies, lt = boolLt }; PCmpInteger : PCmp Integer; -PCmpInteger = { cmpEq = PEqInteger, cmp = integerCmp }; +PCmpInteger = { cmpEq = PEqInteger, cmp = integerCmp, le = intLe, lt = intLt }; PCmpRational : PCmp Rational; -PCmpRational = { cmpEq = PEqRational, cmp = rationalCmp }; +PCmpRational = { cmpEq = PEqRational, cmp = rationalCmp, le = leRational, lt = ltRational }; PCmpVec : (n : Nat) -> (a : isort 0) -> PCmp a -> PCmp (Vec n a); -PCmpVec n a pa = { cmpEq = PEqVec n a pa.cmpEq, cmp = vecCmp n a pa.cmp }; +PCmpVec n a pa = + { cmpEq = PEqVec n a pa.cmpEq + , cmp = vecCmp n a pa.cmp + , le = \ (x : Vec n a) -> \ (y : Vec n a) -> vecCmp n a pa.cmp x y True + , lt = \ (x : Vec n a) -> \ (y : Vec n a) -> vecCmp n a pa.cmp x y False + }; PCmpSeq : (n : Num) -> (a : isort 0) -> PCmp a -> PCmp (seq n a); PCmpSeq n = @@ -560,7 +595,7 @@ PCmpSeq n = n; PCmpWord : (n : Nat) -> PCmp (Vec n Bool); -PCmpWord n = { cmpEq = PEqWord n, cmp = bvCmp n }; +PCmpWord n = { cmpEq = PEqWord n, cmp = bvCmp n, le = bvule n, lt = bvult n }; PCmpSeqBool : (n : Num) -> PCmp (seq n Bool); PCmpSeqBool n = @@ -570,26 +605,33 @@ PCmpSeqBool n = n; PCmpUnit : PCmp #(); -PCmpUnit = { cmpEq = PEqUnit, cmp = unitCmp }; +PCmpUnit = { cmpEq = PEqUnit, cmp = unitCmp, le = unitLe, lt = unitLt }; PCmpPair : (a b : sort 0) -> PCmp a -> PCmp b -> PCmp (a * b); PCmpPair a b pa pb = { cmpEq = PEqPair a b pa.cmpEq pb.cmpEq , cmp = pairCmp a b pa.cmp pb.cmp + , le = pairLt a b pa.cmp pb.le + , lt = pairLt a b pa.cmp pb.lt }; -- SignedCmp class +-- `scmp x y k` computes `if k then sle x y else slt x y` PSignedCmp : sort 0 -> sort 1; PSignedCmp a = #{ signedCmpEq : PEq a , scmp : a -> a -> Bool -> Bool + , sle : a -> a -> Bool + , slt : a -> a -> Bool }; PSignedCmpVec : (n : Nat) -> (a : isort 0) -> PSignedCmp a -> PSignedCmp (Vec n a); PSignedCmpVec n a pa = { signedCmpEq = PEqVec n a pa.signedCmpEq , scmp = vecCmp n a pa.scmp + , sle = \ (x : Vec n a) -> \ (y : Vec n a) -> vecCmp n a pa.scmp x y True + , slt = \ (x : Vec n a) -> \ (y : Vec n a) -> vecCmp n a pa.scmp x y False }; PSignedCmpSeq : (n : Num) -> (a : isort 0) -> PSignedCmp a -> PSignedCmp (seq n a); @@ -600,7 +642,7 @@ PSignedCmpSeq n = n; PSignedCmpWord : (n : Nat) -> PSignedCmp (Vec n Bool); -PSignedCmpWord n = { signedCmpEq = PEqWord n, scmp = bvSCmp n }; +PSignedCmpWord n = { signedCmpEq = PEqWord n, scmp = bvSCmp n, sle = bvsle n, slt = bvslt n }; PSignedCmpSeqBool : (n : Num) -> PSignedCmp (seq n Bool); PSignedCmpSeqBool n = @@ -610,12 +652,14 @@ PSignedCmpSeqBool n = n; PSignedCmpUnit : PSignedCmp #(); -PSignedCmpUnit = { signedCmpEq = PEqUnit, scmp = unitCmp }; +PSignedCmpUnit = { signedCmpEq = PEqUnit, scmp = unitCmp, sle = unitLe, slt = unitLt }; PSignedCmpPair : (a b : sort 0) -> PSignedCmp a -> PSignedCmp b -> PSignedCmp (a * b); PSignedCmpPair a b pa pb = { signedCmpEq = PEqPair a b pa.signedCmpEq pb.signedCmpEq , scmp = pairCmp a b pa.scmp pb.scmp + , sle = pairLt a b pa.scmp pb.sle + , slt = pairLt a b pa.scmp pb.slt }; @@ -1110,20 +1154,20 @@ ecNotEq a pa x y = not (ecEq a pa x y); -- Cmp ecLt : (a : sort 0) -> PCmp a -> a -> a -> Bool; -ecLt a pa x y = pa.cmp x y False; +ecLt a pa = pa.lt; ecGt : (a : sort 0) -> PCmp a -> a -> a -> Bool; ecGt a pa x y = ecLt a pa y x; ecLtEq : (a : sort 0) -> PCmp a -> a -> a -> Bool; -ecLtEq a pa x y = not (ecLt a pa y x); +ecLtEq a pa = pa.le; ecGtEq : (a : sort 0) -> PCmp a -> a -> a -> Bool; -ecGtEq a pa x y = not (ecLt a pa x y); +ecGtEq a pa x y = ecLtEq a pa y x; -- SignedCmp ecSLt : (a : sort 0) -> PSignedCmp a -> a -> a -> Bool; -ecSLt a pa x y = pa.scmp x y False; +ecSLt a pa = pa.slt; -- Logic ecAnd : (a : sort 0) -> PLogic a -> a -> a -> a; @@ -1592,6 +1636,8 @@ PCmpFloat : (e p : Num) -> PCmp (TCFloat e p); PCmpFloat e p = { cmpEq = PEqFloat e p , cmp = \(x y : TCFloat e p) (k : Bool) -> error Bool "Unimplemented: Cmp Float" + , le = \(x y : TCFloat e p) -> error Bool "Unimplemented: Cmp Float" + , lt = \(x y : TCFloat e p) -> error Bool "Unimplemented: Cmp Float" }; PZeroFloat : (e p : Num) -> PZero (TCFloat e p); diff --git a/cryptol-saw-core/saw/CryptolM.sawcore b/cryptol-saw-core/saw/CryptolM.sawcore index 452fb45cbc..61df5156f2 100644 --- a/cryptol-saw-core/saw/CryptolM.sawcore +++ b/cryptol-saw-core/saw/CryptolM.sawcore @@ -24,12 +24,20 @@ numAssertEqM n m = isFinite : Num -> Prop; isFinite = Num_rec (\ (_:Num) -> Prop) (\ (_:Nat) -> TrueProp) FalseProp; +-- Check whether a Num is finite +checkFinite : (n:Num) -> Maybe (isFinite n); +checkFinite = + Num_rec (\ (n:Num) -> Maybe (isFinite n)) + (\ (n:Nat) -> Just (isFinite (TCNum n)) (Refl Bool True)) + (Nothing (isFinite TCInf)); + -- Assert that a Num is finite, or fail assertFiniteM : (n:Num) -> CompM (isFinite n); -assertFiniteM = - Num_rec (\ (n:Num) -> CompM (isFinite n)) - (\ (_:Nat) -> returnM TrueProp TrueI) - (errorM FalseProp "assertFiniteM: Num not finite"); +assertFiniteM n = + maybe (isFinite n) (CompM (isFinite n)) + (errorM (isFinite n) "assertFiniteM: Num not finite") + (returnM (isFinite n)) + (checkFinite n); -- Recurse over a Num known to be finite Num_rec_fin : (p: Num -> sort 1) -> ((n:Nat) -> p (TCNum n)) -> diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs index 259f09366a..35484309fd 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs @@ -83,10 +83,13 @@ data Env = Env , envC :: Map C.Name C.Schema -- ^ Cryptol type environment , envS :: [Term] -- ^ SAW-Core bound variable environment (for type checking) , envRefPrims :: Map C.PrimIdent C.Expr + , envPrims :: Map C.PrimIdent Term -- ^ Translations for other primitives + , envPrimTypes :: Map C.PrimIdent Term -- ^ Translations for primitive types } emptyEnv :: Env -emptyEnv = Env Map.empty Map.empty Map.empty Map.empty [] Map.empty +emptyEnv = + Env Map.empty Map.empty Map.empty Map.empty [] Map.empty Map.empty Map.empty liftTerm :: (Term, Int) -> (Term, Int) liftTerm (t, j) = (t, j + 1) @@ -103,6 +106,8 @@ liftEnv env = , envC = envC env , envS = envS env , envRefPrims = envRefPrims env + , envPrims = envPrims env + , envPrimTypes = envPrimTypes env } bindTParam :: SharedContext -> C.TParam -> Env -> IO Env @@ -263,7 +268,11 @@ importType sc env ty = b <- go (tyargs !! 1) scFun sc a b C.TCTuple _n -> scTupleType sc =<< traverse go tyargs - C.TCAbstract{} -> panic "importType TODO: abstract type" [] + C.TCAbstract (C.UserTC n _) + | Just prim <- C.asPrim n + , Just t <- Map.lookup prim (envPrimTypes env) -> + scApplyAllBeta sc t =<< traverse go tyargs + | True -> panic ("importType: unknown primitive type: " ++ show n) [] C.PC pc -> case pc of C.PLiteral -> -- we omit first argument to class Literal @@ -669,6 +678,9 @@ importPrimitive sc primOpts env n sch nmi <- importName n scConstant' sc nmi e t + -- lookup primitive in the extra primitive lookup table + | Just nm <- C.asPrim n, Just t <- Map.lookup nm (envPrims env) = return t + -- Optionally, create an opaque constant representing the primitive -- if it doesn't match one of the ones we know about. | Just _ <- C.asPrim n, allowUnknownPrimitives primOpts = diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs index 526990782e..b296dabb4d 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs @@ -398,6 +398,7 @@ monadifyType ctx tp@(asPi -> Just (_, _, tp_out)) monadifyType ctx tp@(asPi -> Just (x, tp_in, tp_out)) = MTyArrow (monadifyType ctx tp_in) (monadifyType ((x,tp,Nothing):ctx) tp_out) +monadifyType _ (asTupleType -> Just []) = mkMonType0 unitTypeOpenTerm monadifyType ctx (asPairType -> Just (tp1, tp2)) = MTyPair (monadifyType ctx tp1) (monadifyType ctx tp2) monadifyType ctx (asRecordType -> Just tps) = @@ -529,6 +530,36 @@ fromCompTerm :: MonType -> OpenTerm -> MonTerm fromCompTerm mtp t | isBaseType mtp = CompMonTerm mtp t fromCompTerm mtp t = ArgMonTerm $ fromArgTerm mtp t +-- | Test if a monadification type @tp@ is pure, meaning @MT(tp)=tp@ +monTypeIsPure :: MonType -> Bool +monTypeIsPure (MTyForall _ _ _) = False -- NOTE: this could potentially be true +monTypeIsPure (MTyArrow _ _) = False +monTypeIsPure (MTySeq _ _) = False +monTypeIsPure (MTyPair mtp1 mtp2) = monTypeIsPure mtp1 && monTypeIsPure mtp2 +monTypeIsPure (MTyRecord fld_mtps) = all (monTypeIsPure . snd) fld_mtps +monTypeIsPure (MTyBase _ _) = True +monTypeIsPure (MTyNum _) = True + +-- | Test if a monadification type @tp@ is semi-pure, meaning @SemiP(tp) = tp@, +-- where @SemiP@ is defined in the documentation for 'fromSemiPureTermFun' below +monTypeIsSemiPure :: MonType -> Bool +monTypeIsSemiPure (MTyForall _ k tp_f) = + monTypeIsSemiPure $ tp_f $ MTyBase k $ + -- This dummy OpenTerm should never be inspected by the recursive call + error "monTypeIsSemiPure" +monTypeIsSemiPure (MTyArrow tp_in tp_out) = + monTypeIsPure tp_in && monTypeIsSemiPure tp_out +monTypeIsSemiPure (MTySeq _ _) = False +monTypeIsSemiPure (MTyPair mtp1 mtp2) = + -- NOTE: functions in pairs are not semi-pure; only pure types in pairs are + -- semi-pure + monTypeIsPure mtp1 && monTypeIsPure mtp2 +monTypeIsSemiPure (MTyRecord fld_mtps) = + -- Same as pairs, record types are only semi-pure if they are pure + all (monTypeIsPure . snd) fld_mtps +monTypeIsSemiPure (MTyBase _ _) = True +monTypeIsSemiPure (MTyNum _) = True + -- | Build a monadification term from a function on terms which, when viewed as -- a lambda, is a "semi-pure" function of the given monadification type, meaning -- it maps terms of argument type @MT(tp)@ to an output value of argument type; @@ -857,8 +888,13 @@ monadifyTerm' _ (asApplyAll -> (asTypedGlobalDef -> Just glob, args)) = do let (macro_args, reg_args) = splitAt (macroNumArgs macro) args mtrm_f <- macroApply macro glob macro_args monadifyApply mtrm_f reg_args - Nothing -> error ("Monadification failed: unhandled constant: " - ++ globalDefString glob) + Nothing -> + monadifyTypeM (globalDefType glob) >>= \glob_mtp -> + if monTypeIsSemiPure glob_mtp then + monadifyApply (ArgMonTerm $ fromSemiPureTerm glob_mtp $ + globalDefOpenTerm glob) args + else error ("Monadification failed: unhandled constant: " + ++ globalDefString glob) monadifyTerm' _ (asApp -> Just (f, arg)) = do mtrm_f <- monadifyTerm Nothing f monadifyApply mtrm_f [arg] @@ -959,6 +995,25 @@ iteMacro = MonMacro 4 $ \_ args -> [toCompType mtp, toArgTerm atrm_cond, toCompTerm mtrm1, toCompTerm mtrm2] +-- | The macro for the either elimination function, which converts the +-- application @either a b c@ to @either a b (CompM c)@ +eitherMacro :: MonMacro +eitherMacro = MonMacro 3 $ \_ args -> + do let (tp_a, tp_b, tp_c) = + case args of + [t1, t2, t3] -> (t1, t2, t3) + _ -> error "eitherMacro: wrong number of arguments!" + mtp_a <- monadifyTypeM tp_a + mtp_b <- monadifyTypeM tp_b + mtp_c <- monadifyTypeM tp_c + let eith_app = applyGlobalOpenTerm "Prelude.either" [toArgType mtp_a, + toArgType mtp_b, + toCompType mtp_c] + let tp_eith = dataTypeOpenTerm "Prelude.Either" [toArgType mtp_a, + toArgType mtp_b] + return $ fromCompTerm (MTyArrow (MTyArrow mtp_a mtp_c) + (MTyArrow (MTyArrow mtp_b mtp_c) + (MTyArrow (mkMonType0 tp_eith) mtp_c))) eith_app -- | Make a 'MonMacro' that maps a named global whose first argument is @n:Num@ -- to a global of semi-pure type that takes an additional argument of type @@ -1048,6 +1103,7 @@ defaultMonEnv = mmCustom "Prelude.unsafeAssert" unsafeAssertMacro , mmCustom "Prelude.ite" iteMacro , mmCustom "Prelude.fix" fixMacro + , mmCustom "Prelude.either" eitherMacro -- Top-level sequence functions , mmArg "Cryptol.seqMap" "CryptolM.seqMapM" diff --git a/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs b/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs index 7456b70fd0..ab10fb9921 100644 --- a/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs +++ b/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs @@ -126,6 +126,8 @@ data CryptolEnv = CryptolEnv , eExtraTypes :: Map T.Name T.Schema -- ^ Cryptol types for extra names in scope , eExtraTSyns :: Map T.Name T.TySyn -- ^ Extra Cryptol type synonyms in scope , eTermEnv :: Map T.Name Term -- ^ SAWCore terms for *all* names in scope + , ePrims :: Map C.PrimIdent Term -- ^ SAWCore terms for primitives + , ePrimTypes :: Map C.PrimIdent Term -- ^ SAWCore terms for primitive type names } @@ -217,6 +219,8 @@ initCryptolEnv sc = do , eExtraTypes = Map.empty , eExtraTSyns = Map.empty , eTermEnv = termEnv + , ePrims = Map.empty + , ePrimTypes = Map.empty } -- Parse ----------------------------------------------------------------------- @@ -297,6 +301,8 @@ mkCryEnv env = let cryEnv = C.emptyEnv { C.envE = fmap (\t -> (t, 0)) terms , C.envC = types' + , C.envPrims = ePrims env + , C.envPrimTypes = ePrimTypes env } return cryEnv diff --git a/deps/argo b/deps/argo index fd8529883c..904fb34872 160000 --- a/deps/argo +++ b/deps/argo @@ -1 +1 @@ -Subproject commit fd8529883cd462b5f666506ecce5802bbf6867df +Subproject commit 904fb34872fcef462030fe38978842aa5a9db903 diff --git a/deps/crucible b/deps/crucible index e1308319ee..d18505d1ad 160000 --- a/deps/crucible +++ b/deps/crucible @@ -1 +1 @@ -Subproject commit e1308319eef8e0fcb55ed04df7eb2e9d5e87aac5 +Subproject commit d18505d1ad1fe03e142371359852b5f52ea0b1f0 diff --git a/deps/cryptol b/deps/cryptol index 413788c578..fe0bd96ca7 160000 --- a/deps/cryptol +++ b/deps/cryptol @@ -1 +1 @@ -Subproject commit 413788c57877aa58b6656eb7757e711e91d499fc +Subproject commit fe0bd96ca72c493608ffed5bf7547f2ab2aad2bc diff --git a/deps/cryptol-specs b/deps/cryptol-specs index 031b6c4558..0365dca32d 160000 --- a/deps/cryptol-specs +++ b/deps/cryptol-specs @@ -1 +1 @@ -Subproject commit 031b6c45584150a33aa7c0b817703372b0189492 +Subproject commit 0365dca32d13d6fc12a93ded11b686e18e6490eb diff --git a/deps/llvm-pretty b/deps/llvm-pretty index ed904c679d..34c95e77fb 160000 --- a/deps/llvm-pretty +++ b/deps/llvm-pretty @@ -1 +1 @@ -Subproject commit ed904c679d1a10ff98d1968da3407ff56cfa06a2 +Subproject commit 34c95e77fb9fdc584c23208f81f6072cb0e05c3f diff --git a/deps/llvm-pretty-bc-parser b/deps/llvm-pretty-bc-parser index af0c6951b3..1bad3e43c7 160000 --- a/deps/llvm-pretty-bc-parser +++ b/deps/llvm-pretty-bc-parser @@ -1 +1 @@ -Subproject commit af0c6951b3eebffa3404ff116685a92ad8b0697e +Subproject commit 1bad3e43c7444e363ef4c3d9f954bc04b01b1795 diff --git a/deps/macaw b/deps/macaw index d1d71fd973..45f8af1e5a 160000 --- a/deps/macaw +++ b/deps/macaw @@ -1 +1 @@ -Subproject commit d1d71fd973f802483e93dffc968dfbdde12fab59 +Subproject commit 45f8af1e5a0023f00c8c1985834bdf3b1e8bcfbc diff --git a/deps/parameterized-utils b/deps/parameterized-utils index b0a84444c5..fea8c1ab6c 160000 --- a/deps/parameterized-utils +++ b/deps/parameterized-utils @@ -1 +1 @@ -Subproject commit b0a84444c5ce096255a54e07179f242ad3d5e9dd +Subproject commit fea8c1ab6c354485d065eb4764714b06b015ce93 diff --git a/deps/what4 b/deps/what4 index 629f9f1d6f..ea717ac94a 160000 --- a/deps/what4 +++ b/deps/what4 @@ -1 +1 @@ -Subproject commit 629f9f1d6fa586cef756e3cc65c130c14de34e17 +Subproject commit ea717ac94a186b5ee18f138b71d8b4b4b2f00955 diff --git a/doc/manual/manual.md b/doc/manual/manual.md index 088742583b..2661c3cf9a 100644 --- a/doc/manual/manual.md +++ b/doc/manual/manual.md @@ -2256,6 +2256,18 @@ flows into. This is especially useful for dealing with C `union` types, as the type information provided by LLVM is imprecise in these cases. +We can automate the process of applying pointer casts if we have debug +information avaliable: + +* `llvm_union : SetupValue -> String -> SetupValue` + +Given a pointer setup value, this attempts to select the named union +branch and cast the type of the pointer. For this to work, debug +symbols must be included; moreover, the process of correlating LLVM +type information with information contained in debug symbols is a bit +heuristic. If `llvm_union` cannot figure out how to cast a pointer, +one can fall back on the more manual `llvm_cast_pointer` instead. + In the experimental Java verification implementation, the following functions can be used to state the equivalent of a combination of diff --git a/doc/manual/manual.pdf b/doc/manual/manual.pdf index 6d4671cbde..5d121e0738 100644 Binary files a/doc/manual/manual.pdf and b/doc/manual/manual.pdf differ diff --git a/examples/mr_solver/monadify.saw b/examples/mr_solver/monadify.saw index 5b5ba8974e..37e0e54d28 100644 --- a/examples/mr_solver/monadify.saw +++ b/examples/mr_solver/monadify.saw @@ -2,50 +2,113 @@ enable_experimental; import "SpecPrims.cry" as SpecPrims; import "monadify.cry"; +load_sawcore_from_file "../../cryptol-saw-core/saw/CryptolM.sawcore"; set_monadification "SpecPrims::exists" "Prelude.existsM"; set_monadification "SpecPrims::forall" "Prelude.forallM"; +let run_test name cry_term mon_term_expected = + do { print (str_concat "Test: " name); + print "Original term:"; + print_term cry_term; + mon_term <- monadify_term cry_term; + print "Monadified term:"; + print_term mon_term; + success <- is_convertible mon_term mon_term_expected; + if success then print "Success - monadified term matched expected\n" else + do { print "Test failed - did not match expected monadified term:"; + print_term mon_term_expected; + exit 1; }; }; + my_abs <- unfold_term ["my_abs"] {{ my_abs }}; -print "[my_abs] original term:"; -print_term my_abs; -my_absM <- monadify_term my_abs; -print "[my_abs] monadified term:"; -print_term my_absM; +my_abs_M <- parse_core_mod "CryptolM" "\ +\ \\(x : (mseq (TCNum 64) Bool)) -> \ +\ bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x' : (isFinite (TCNum 64))) -> \ +\ bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x'' : (isFinite (TCNum 64))) -> \ +\ ite (CompM (mseq (TCNum 64) Bool)) \ +\ (ecLt (mseq (TCNum 64) Bool) (PCmpMSeqBool (TCNum 64) x') x \ +\ (ecNumber (TCNum 0) (mseq (TCNum 64) Bool) (PLiteralSeqBoolM (TCNum 64) x''))) \ +\ (bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x''' : (isFinite (TCNum 64))) -> \ +\ returnM (mseq (TCNum 64) Bool) (ecNeg (mseq (TCNum 64) Bool) (PRingMSeqBool (TCNum 64) x''') x))) \ +\ (returnM (mseq (TCNum 64) Bool) x)))"; +run_test "my_abs" my_abs my_abs_M; -/* err_if_lt0 <- unfold_term ["err_if_lt0"] {{ err_if_lt0 }}; -print "[err_if_lt0] original term:"; -err_if_lt0M <- monadify_term err_if_lt0; -print "[err_if_lt0] monadified term:"; -print_term err_if_lt0M; -*/ +err_if_lt0_M <- parse_core_mod "CryptolM" "\ +\ \\(x : (mseq (TCNum 64) Bool)) -> \ +\ bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x' : (isFinite (TCNum 64))) -> \ +\ bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x'' : (isFinite (TCNum 64))) -> \ +\ ite (CompM (mseq (TCNum 64) Bool)) \ +\ (ecLt (mseq (TCNum 64) Bool) (PCmpMSeqBool (TCNum 64) x') x \ +\ (ecNumber (TCNum 0) (mseq (TCNum 64) Bool) (PLiteralSeqBoolM (TCNum 64) x''))) \ +\ (bindM (isFinite (TCNum 8)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 8)) \ +\ (\\(x''' : (isFinite (TCNum 8))) -> \ +\ ecErrorM (mseq (TCNum 64) Bool) (TCNum 5) \ +\ (seqToMseq (TCNum 5) (mseq (TCNum 8) Bool) \ +\ [ ecNumber (TCNum 120) (mseq (TCNum 8) Bool) (PLiteralSeqBoolM (TCNum 8) x''') \ +\ , (ecNumber (TCNum 32) (mseq (TCNum 8) Bool) (PLiteralSeqBoolM (TCNum 8) x''')) \ +\ , ecNumber (TCNum 60) (mseq (TCNum 8) Bool) (PLiteralSeqBoolM (TCNum 8) x''') \ +\ , (ecNumber (TCNum 32) (mseq (TCNum 8) Bool) (PLiteralSeqBoolM (TCNum 8) x''')) \ +\ , ecNumber (TCNum 48) (mseq (TCNum 8) Bool) (PLiteralSeqBoolM (TCNum 8) x''') ]))) \ +\ (returnM (mseq (TCNum 64) Bool) x)))"; +run_test "err_if_lt0" err_if_lt0 err_if_lt0_M; /* sha1 <- {{ sha1 }}; -print "[SHA1] original term:"; +print "Test: sha1"; +print "Original term:"; print_term sha1; -mtrm <- monadify_term sha1; -print "[SHA1] monadified term:"; -print_term mtrm; +sha1M <- monadify_term sha1; +print "Monadified term:"; +print_term sha1M; */ fib <- unfold_term ["fib"] {{ fib }}; -print "[fib] original term:"; -print_term fib; -fibM <- monadify_term fib; -print "[fib] monadified term:"; -print_term fibM; +fibM <- parse_core_mod "CryptolM" "\ +\ \\(_x : (mseq (TCNum 64) Bool)) -> \ +\ multiArgFixM (LRT_Fun (mseq (TCNum 64) Bool) (\\(_ : (mseq (TCNum 64) Bool)) -> LRT_Ret (mseq (TCNum 64) Bool))) \ +\ (\\(fib : ((mseq (TCNum 64) Bool) -> (CompM (mseq (TCNum 64) Bool)))) -> \ +\ \\(x : (mseq (TCNum 64) Bool)) -> \ +\ bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x' : (isFinite (TCNum 64))) -> \ +\ bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x'' : (isFinite (TCNum 64))) -> \ +\ ite (CompM (mseq (TCNum 64) Bool)) \ +\ (ecEq (mseq (TCNum 64) Bool) (PEqMSeqBool (TCNum 64) x') x \ +\ (ecNumber (TCNum 0) (mseq (TCNum 64) Bool) (PLiteralSeqBoolM (TCNum 64) x''))) \ +\ (bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x''' : (isFinite (TCNum 64))) -> \ +\ returnM (mseq (TCNum 64) Bool) \ +\ (ecNumber (TCNum 1) (mseq (TCNum 64) Bool) \ +\ (PLiteralSeqBoolM (TCNum 64) x''')))) \ +\ (bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x''' : (isFinite (TCNum 64))) -> \ +\ bindM (isFinite (TCNum 64)) (mseq (TCNum 64) Bool) (assertFiniteM (TCNum 64)) \ +\ (\\(x'''' : (isFinite (TCNum 64))) -> \ +\ bindM (mseq (TCNum 64) Bool) (mseq (TCNum 64) Bool) \ +\ (fib \ +\ (ecMinus (mseq (TCNum 64) Bool) (PRingMSeqBool (TCNum 64) x''') x \ +\ (ecNumber (TCNum 1) (mseq (TCNum 64) Bool) \ +\ (PLiteralSeqBoolM (TCNum 64) x'''')))) \ +\ (\\(x''''' : (mseq (TCNum 64) Bool)) -> \ +\ returnM (mseq (TCNum 64) Bool) \ +\ (ecMul (mseq (TCNum 64) Bool) (PRingMSeqBool (TCNum 64) x''') x \ +\ x''''')))))))) \ +\ _x"; +run_test "fib" fib fibM; noErrors <- unfold_term ["noErrors"] {{ SpecPrims::noErrors }}; -print "[noErrors] original term:"; -print_term noErrors; -noErrorsM <- monadify_term noErrors; -print "[noErrors] monadified term:"; -print_term noErrorsM; +noErrorsM <- parse_core_mod "CryptolM" "\\(a : sort 0) -> existsM a a (\\(x : a) -> returnM a x)"; +run_test "noErrors" noErrors noErrorsM; fibSpecNoErrors <- unfold_term ["fibSpecNoErrors"] {{ fibSpecNoErrors }}; -print "[fibSpecNoErrors] original term:"; -print_term fibSpecNoErrors; -fibSpecNoErrorsM <- monadify_term fibSpecNoErrors; -print "[fibSpecNoErrors] monadified term:"; -print_term fibSpecNoErrorsM; +fibSpecNoErrorsM <- parse_core_mod "CryptolM" "\ +\ \\(__p1 : (mseq (TCNum 64) Bool)) -> \ +\ existsM (mseq (TCNum 64) Bool) (mseq (TCNum 64) Bool) \ +\ (\\(x : (mseq (TCNum 64) Bool)) -> \ +\ returnM (mseq (TCNum 64) Bool) x)"; +run_test "fibSpecNoErrors" fibSpecNoErrors fibSpecNoErrorsM; diff --git a/heapster-saw/examples/Either.cry b/heapster-saw/examples/Either.cry new file mode 100644 index 0000000000..6adf0f39e0 --- /dev/null +++ b/heapster-saw/examples/Either.cry @@ -0,0 +1,10 @@ + +/* The definition of the Either type as an abstract type in Cryptol */ + +module Either where + +primitive type Either : * -> * -> * + +primitive Left : {a, b} a -> Either a b +primitive Right : {a, b} b -> Either a b +primitive either : {a, b, c} (a -> c) -> (b -> c) -> Either a b -> c diff --git a/heapster-saw/examples/Makefile b/heapster-saw/examples/Makefile index dee6e52df7..d10f19fc5d 100644 --- a/heapster-saw/examples/Makefile +++ b/heapster-saw/examples/Makefile @@ -1,4 +1,4 @@ -all: Makefile.coq +all: Makefile.coq mr-solver-tests Makefile.coq: _CoqProject coq_makefile -f _CoqProject -o Makefile.coq @@ -32,3 +32,12 @@ rust_data.bc: rust_data.rs rust_lifetimes.bc: rust_lifetimes.rs rustc --crate-type=lib --emit=llvm-bc rust_lifetimes.rs + +# Lists all the Mr Solver tests, without their ".saw" suffix +MR_SOLVER_TESTS = arrays_mr_solver linked_list_mr_solver + +.PHONY: mr-solver-tests $(MR_SOLVER_TESTS) +mr-solver-tests: $(MR_SOLVER_TESTS) + +$(MR_SOLVER_TESTS): + $(SAW) $@.saw diff --git a/heapster-saw/examples/arrays.sawcore b/heapster-saw/examples/arrays.sawcore index 7c20a89268..6b1f16867b 100644 --- a/heapster-saw/examples/arrays.sawcore +++ b/heapster-saw/examples/arrays.sawcore @@ -2,3 +2,40 @@ module arrays where import Prelude; + +-- The helper function for noErrorsContains0 +-- +-- noErrorsContains0H len i v = +-- orM (exists x. returnM x) (noErrorsContains0H len (i+1) v) +noErrorsContains0H : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> + CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); +noErrorsContains0H len_top i_top v_top = + letRecM + (LRT_Cons + (LRT_Fun (Vec 64 Bool) (\ (len:Vec 64 Bool) -> + LRT_Fun (Vec 64 Bool) (\ (_:Vec 64 Bool) -> + LRT_Fun (BVVec 64 len (Vec 64 Bool)) (\ (_:BVVec 64 len (Vec 64 Bool)) -> + LRT_Ret (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool))))) + LRT_Nil) + (BVVec 64 len_top (Vec 64 Bool) * Vec 64 Bool) + (\ (f : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> + CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> + ((\ (len:Vec 64 Bool) (i:Vec 64 Bool) (v:BVVec 64 len (Vec 64 Bool)) -> + precondHint + (CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) + (and (bvsle 64 0x0000000000000000 i) + (bvsle 64 i 0x0fffffffffffffff)) + (orM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) + (existsM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) + (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) + (returnM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool))) + (f len (bvAdd 64 i 0x0000000000000001) v))), ())) + (\ (f : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> + CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> + f len_top i_top v_top); + +-- The specification that contains0 has no errors +noErrorsContains0 : (len:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> + CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); +noErrorsContains0 len v = + noErrorsContains0H len 0x0000000000000000 v; diff --git a/heapster-saw/examples/arrays_mr_solver.saw b/heapster-saw/examples/arrays_mr_solver.saw new file mode 100644 index 0000000000..eaa38a79f7 --- /dev/null +++ b/heapster-saw/examples/arrays_mr_solver.saw @@ -0,0 +1,23 @@ +include "arrays.saw"; + +let eq_bool b1 b2 = + if b1 then + if b2 then true else false + else + if b2 then false else true; + +let fail = do { print "Test failed"; exit 1; }; +let run_test name test expected = + do { if expected then print (str_concat "Test: " name) else + print (str_concat (str_concat "Test: " name) " (expecting failure)"); + actual <- test; + if eq_bool actual expected then print "Success\n" else + do { print "Test failed\n"; exit 1; }; }; + +// Test that contains0 |= contains0 +contains0 <- parse_core_mod "arrays" "contains0"; +// run_test "contains0 |= contains0" (mr_solver contains0 contains0) true; + +noErrorsContains0 <- parse_core_mod "arrays" "noErrorsContains0"; +run_test "contains0 |= noErrorsContains0" + (mr_solver_debug 0 contains0 noErrorsContains0) true; diff --git a/heapster-saw/examples/either.saw b/heapster-saw/examples/either.saw new file mode 100644 index 0000000000..b711c9a8fb --- /dev/null +++ b/heapster-saw/examples/either.saw @@ -0,0 +1,15 @@ +/* Helper SAW script for defining the Either type in Cryptol */ + +eith_tp <- parse_core "\\ (a b:sort 0) -> Either a b"; +cryptol_add_prim_type "Either" "Either" eith_tp; + +left_fun <- parse_core "\\ (a b:sort 0) (x:a) -> Left a b x"; +cryptol_add_prim "Either" "Left" left_fun; + +right_fun <- parse_core "\\ (a b:sort 0) (x:b) -> Right a b x"; +cryptol_add_prim "Either" "Right" right_fun; + +either_fun <- parse_core "either"; +cryptol_add_prim "Either" "either" either_fun; + +import "Either.cry"; diff --git a/heapster-saw/examples/linked_list.bc b/heapster-saw/examples/linked_list.bc index bfa56aa299..d8d8693240 100644 Binary files a/heapster-saw/examples/linked_list.bc and b/heapster-saw/examples/linked_list.bc differ diff --git a/heapster-saw/examples/linked_list.c b/heapster-saw/examples/linked_list.c index 3ade859db9..0643443c92 100644 --- a/heapster-saw/examples/linked_list.c +++ b/heapster-saw/examples/linked_list.c @@ -6,6 +6,17 @@ typedef struct list64_t { struct list64_t *next; } list64_t; +/* Test if a value is the head of a list, returning 1 if so and 0 otherwiese */ +int64_t is_head (int64_t x, list64_t *l) { + if (l == NULL) { + return 0; + } else if (l->data == x) { + return 1; + } else { + return 0; + } +} + /* Test if a specific value is in a list, returning 1 if so and 0 otherwise */ int64_t is_elem (int64_t x, list64_t *l) { if (l == NULL) { diff --git a/heapster-saw/examples/linked_list.cry b/heapster-saw/examples/linked_list.cry new file mode 100644 index 0000000000..85e63ec8c4 --- /dev/null +++ b/heapster-saw/examples/linked_list.cry @@ -0,0 +1,14 @@ + +module LinkedList where + +import Either + +primitive type List : * -> * + +primitive foldList : {a} Either () (a, List a) -> List a +primitive unfoldList : {a} List a -> Either () (a, List a) + +is_elem_spec : [64] -> List [64] -> [64] +is_elem_spec x l = + either (\ _ -> 0) (\ (y,l') -> if x == y then 1 else is_elem_spec x l') + (unfoldList l) diff --git a/heapster-saw/examples/linked_list_mr_solver.saw b/heapster-saw/examples/linked_list_mr_solver.saw new file mode 100644 index 0000000000..c741d7e890 --- /dev/null +++ b/heapster-saw/examples/linked_list_mr_solver.saw @@ -0,0 +1,55 @@ +include "linked_list.saw"; + +/*** + *** Testing infrastructure + ***/ + +let eq_bool b1 b2 = + if b1 then + if b2 then true else false + else + if b2 then false else true; + +let fail = do { print "Test failed"; exit 1; }; +let run_test name test expected = + do { if expected then print (str_concat "Test: " name) else + print (str_concat (str_concat "Test: " name) " (expecting failure)"); + actual <- test; + if eq_bool actual expected then print "Success\n" else + do { print "Test failed\n"; exit 1; }; }; + + +/*** + *** Setup Cryptol environment + ***/ + +include "either.saw"; + +list_tp <- parse_core "\\ (a:sort 0) -> List a"; +cryptol_add_prim_type "LinkedList" "List" list_tp; + +fold_fun <- parse_core "foldList"; +cryptol_add_prim "LinkedList" "foldList" fold_fun; + +unfold_fun <- parse_core "unfoldList"; +cryptol_add_prim "LinkedList" "unfoldList" unfold_fun; + +import "linked_list.cry"; + + +/*** + *** The actual tests + ***/ + +heapster_typecheck_fun env "is_head" + "(). arg0:int64<>, arg1:List,always,R> -o \ + \ arg0:true, arg1:true, ret:int64<>"; + +/* +is_head <- parse_core_mod "linked_list" "is_head"; +run_test "is_head |= is_head" (mr_solver is_head is_head) true; +*/ + +is_elem <- parse_core_mod "linked_list" "is_elem"; +run_test "is_elem |= is_elem_spec" (mr_solver_debug 2 is_elem {{ is_elem_spec }}) true; +//run_test "is_elem |= is_elem" (mr_solver_debug 1 is_elem is_elem) true; diff --git a/intTests/test_llvm_union2/Makefile b/intTests/test_llvm_union2/Makefile new file mode 100644 index 0000000000..35b13c7b81 --- /dev/null +++ b/intTests/test_llvm_union2/Makefile @@ -0,0 +1,2 @@ +test.bc : test.c + clang -O0 -c -g -emit-llvm -o test.bc test.c diff --git a/intTests/test_llvm_union2/README b/intTests/test_llvm_union2/README new file mode 100644 index 0000000000..08b8b08efd --- /dev/null +++ b/intTests/test_llvm_union2/README @@ -0,0 +1,4 @@ +This example is derived from an older union example +from `examples/llvm/union`. It is intended to demonstrate +the use of the `llvm_union` operation for selecting +the branches of unions. diff --git a/intTests/test_llvm_union2/test.bc b/intTests/test_llvm_union2/test.bc new file mode 100644 index 0000000000..faa00acb8f Binary files /dev/null and b/intTests/test_llvm_union2/test.bc differ diff --git a/intTests/test_llvm_union2/test.c b/intTests/test_llvm_union2/test.c new file mode 100644 index 0000000000..3a0d213e9d --- /dev/null +++ b/intTests/test_llvm_union2/test.c @@ -0,0 +1,35 @@ +#include + +typedef enum { INC_1 , INC_2 } alg; + +typedef struct { + uint32_t x; +} inc_1_st; + +typedef struct { + uint32_t x; + uint32_t y; +} inc_2_st; + +typedef struct { + alg alg; + union { + inc_1_st inc_1_st; + inc_2_st inc_2_st; + } inc_st; +} st; + +uint32_t inc(st *st) { + switch (st->alg) { + case INC_1: + st->inc_st.inc_1_st.x += 1; + break; + case INC_2: + st->inc_st.inc_2_st.x += 1; + st->inc_st.inc_2_st.y += 1; + break; + default: + return 1/0; + } + return 0; +} diff --git a/intTests/test_llvm_union2/test.saw b/intTests/test_llvm_union2/test.saw new file mode 100644 index 0000000000..612835067d --- /dev/null +++ b/intTests/test_llvm_union2/test.saw @@ -0,0 +1,52 @@ +m <- llvm_load_module "test.bc"; + +let {{ +INC_1 = 0 : [32] +INC_2 = 1 : [32] +}}; + + +// The argument 'INC' specifies which 'alg' enum to test. +let inc_spec INC = do { + + stp <- llvm_alloc (llvm_alias "struct.st"); + llvm_points_to (llvm_field stp "alg") (llvm_term {{ INC }}); + + if eval_bool {{ INC == INC_1 }} then + do { + let p = llvm_union (llvm_field stp "inc_st") "inc_1_st"; + + x0 <- llvm_fresh_var "x0" (llvm_int 32); + llvm_points_to (llvm_field p "x") (llvm_term x0); + + llvm_execute_func [stp]; + + llvm_points_to (llvm_field p "x") (llvm_term {{ x0 + 1 }}); + } + else if eval_bool {{ INC == INC_2 }} then + do { + let p = llvm_union (llvm_field stp "inc_st") "inc_2_st"; + + x0 <- llvm_fresh_var "x0" (llvm_int 32); + y0 <- llvm_fresh_var "y0" (llvm_int 32); + + llvm_points_to (llvm_field p "x") (llvm_term x0); + llvm_points_to (llvm_field p "y") (llvm_term y0); + + llvm_execute_func [stp]; + + llvm_points_to (llvm_field p "x") (llvm_term {{ x0 + 1 }}); + llvm_points_to (llvm_field p "y") (llvm_term {{ y0 + 1 }}); + } + else return (); // Unknown INC value + + llvm_return (llvm_term {{ 0 : [32] }}); +}; + +print "Verifying 'inc_1' using 'llvm_verify':"; +llvm_verify m "inc" [] true (inc_spec {{ INC_1 }}) abc; +print ""; + +print "Verifying 'inc_2' using 'llvm_verify':"; +llvm_verify m "inc" [] true (inc_spec {{ INC_2 }}) abc; +print ""; diff --git a/intTests/test_llvm_union2/test.sh b/intTests/test_llvm_union2/test.sh new file mode 100644 index 0000000000..0b864017cd --- /dev/null +++ b/intTests/test_llvm_union2/test.sh @@ -0,0 +1 @@ +$SAW test.saw diff --git a/s2nTests/docker/awslc.dockerfile b/s2nTests/docker/awslc.dockerfile index 7f5267fb45..d05b6f99ed 100644 --- a/s2nTests/docker/awslc.dockerfile +++ b/s2nTests/docker/awslc.dockerfile @@ -7,7 +7,7 @@ WORKDIR /saw-script RUN mkdir -p /saw-script && \ git clone https://github.com/GaloisInc/aws-lc-verification.git && \ cd aws-lc-verification && \ - git checkout 7acbcfadd2e040b63cc33e8143e3f8e972408288 && \ + git checkout 1dcf4258305ce17592fb5b90a1c7b638e6bdff9e && \ git config --file=.gitmodules submodule.src.url https://github.com/awslabs/aws-lc && \ git submodule sync && \ git submodule update --init diff --git a/s2nTests/scripts/awslc-entrypoint.sh b/s2nTests/scripts/awslc-entrypoint.sh index 62ca1eb26f..afe2953913 100755 --- a/s2nTests/scripts/awslc-entrypoint.sh +++ b/s2nTests/scripts/awslc-entrypoint.sh @@ -3,6 +3,7 @@ set -xe cd /saw-script/aws-lc-verification/SAW ./scripts/install.sh +rm bin/saw cp /saw-bin/saw bin/saw cp /saw-bin/abc bin/abc diff --git a/saw-core-coq/coq/generated/CryptolToCoq/CryptolPrimitivesForSAWCore.v b/saw-core-coq/coq/generated/CryptolToCoq/CryptolPrimitivesForSAWCore.v index a57aee7096..175ebcd520 100644 --- a/saw-core-coq/coq/generated/CryptolToCoq/CryptolPrimitivesForSAWCore.v +++ b/saw-core-coq/coq/generated/CryptolToCoq/CryptolPrimitivesForSAWCore.v @@ -131,6 +131,9 @@ Definition ecRatio : SAWCoreScaffolding.Integer -> SAWCoreScaffolding.Integer -> Definition eqRational : Rational -> Rational -> SAWCoreScaffolding.Bool := fun (x : unit : Type) (y : unit : Type) => SAWCoreScaffolding.error SAWCoreScaffolding.Bool "Unimplemented: (==) Rational"%string. +Definition leRational : Rational -> Rational -> SAWCoreScaffolding.Bool := + fun (x : unit : Type) (y : unit : Type) => SAWCoreScaffolding.error SAWCoreScaffolding.Bool "Unimplemented: (<=) Rational"%string. + Definition ltRational : Rational -> Rational -> SAWCoreScaffolding.Bool := fun (x : unit : Type) (y : unit : Type) => SAWCoreScaffolding.error SAWCoreScaffolding.Bool "Unimplemented: (<) Rational"%string. @@ -219,6 +222,9 @@ Definition errorBinary : forall (s : SAWCoreScaffolding.String), forall (a : Typ Definition boolCmp : SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool := fun (x : SAWCoreScaffolding.Bool) (y : SAWCoreScaffolding.Bool) (k : SAWCoreScaffolding.Bool) => if x then SAWCoreScaffolding.and y k else SAWCoreScaffolding.or y k. +Definition boolLt : SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool := + fun (x : SAWCoreScaffolding.Bool) (y : SAWCoreScaffolding.Bool) => SAWCoreScaffolding.and (SAWCoreScaffolding.not x) y. + Definition integerCmp : SAWCoreScaffolding.Integer -> SAWCoreScaffolding.Integer -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool := fun (x : SAWCoreScaffolding.Integer) (y : SAWCoreScaffolding.Integer) (k : SAWCoreScaffolding.Bool) => SAWCoreScaffolding.or (SAWCoreScaffolding.intLt x y) (SAWCoreScaffolding.and (SAWCoreScaffolding.intEq x y) k). @@ -235,12 +241,25 @@ Definition vecCmp : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), fora fun (n : SAWCoreScaffolding.Nat) (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (f : a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (xs : SAWCoreVectorsAsCoqVectors.Vec n a) (ys : SAWCoreVectorsAsCoqVectors.Vec n a) (k : SAWCoreScaffolding.Bool) => let var__0 := SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool in SAWCoreVectorsAsCoqVectors.foldr var__0 SAWCoreScaffolding.Bool n (fun (f1 : var__0) => f1) k (SAWCorePrelude.zipWith a a var__0 f n xs ys). +Definition vecLt : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) -> (a -> a -> SAWCoreScaffolding.Bool) -> SAWCoreVectorsAsCoqVectors.Vec n a -> SAWCoreVectorsAsCoqVectors.Vec n a -> SAWCoreScaffolding.Bool := + fun (n : SAWCoreScaffolding.Nat) (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (f : a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (g : a -> a -> SAWCoreScaffolding.Bool) (xs : SAWCoreVectorsAsCoqVectors.Vec n a) (ys : SAWCoreVectorsAsCoqVectors.Vec n a) => let var__0 := SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool in + SAWCoreVectorsAsCoqVectors.foldr var__0 SAWCoreScaffolding.Bool n (fun (f1 : var__0) => f1) SAWCoreScaffolding.false (SAWCorePrelude.zipWith a a var__0 f n xs ys). + Definition unitCmp : (unit : Type) -> (unit : Type) -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool := - fun (_1 : unit : Type) (_2 : unit : Type) (_3 : SAWCoreScaffolding.Bool) => SAWCoreScaffolding.false. + fun (_1 : unit : Type) (_2 : unit : Type) (k : SAWCoreScaffolding.Bool) => k. + +Definition unitLe : (unit : Type) -> (unit : Type) -> SAWCoreScaffolding.Bool := + fun (_1 : unit : Type) (_2 : unit : Type) => SAWCoreScaffolding.true. + +Definition unitLt : (unit : Type) -> (unit : Type) -> SAWCoreScaffolding.Bool := + fun (_1 : unit : Type) (_2 : unit : Type) => SAWCoreScaffolding.false. Definition pairCmp : forall (a : Type), forall (b : Type), (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) -> (b -> b -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) -> prod a b -> prod a b -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool := fun (a : Type) (b : Type) (f : a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (g : b -> b -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (x12 : prod a b) (y12 : prod a b) (k : SAWCoreScaffolding.Bool) => f (fst x12) (fst y12) (g (snd x12) (snd y12) k). +Definition pairLt : forall (a : Type), forall (b : Type), (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) -> (b -> b -> SAWCoreScaffolding.Bool) -> prod a b -> prod a b -> SAWCoreScaffolding.Bool := + fun (a : Type) (b : Type) (f : a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (g : b -> b -> SAWCoreScaffolding.Bool) (x : prod a b) (y : prod a b) => f (fst x) (fst y) (g (snd x) (snd y)). + Definition PEq : Type -> Type := fun (a : Type) => RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil. @@ -278,55 +297,61 @@ Definition PEqPair : forall (a : Type), forall (b : Type), PEq a -> PEq b -> PEq fun (a : Type) (b : Type) (pa : RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (pb : RecordTypeCons "eq" (b -> b -> SAWCoreScaffolding.Bool) RecordTypeNil) => RecordCons "eq" (SAWCorePrelude.pairEq a b (RecordProj pa "eq") (RecordProj pb "eq")) RecordNil. Definition PCmp : Type -> Type := - fun (a : Type) => RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (PEq a) RecordTypeNil). + fun (a : Type) => let var__0 := a -> a -> SAWCoreScaffolding.Bool in + RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (PEq a) (RecordTypeCons "le" var__0 (RecordTypeCons "lt" var__0 RecordTypeNil))). Definition PCmpBit : PCmp SAWCoreScaffolding.Bool := - RecordCons "cmp" boolCmp (RecordCons "cmpEq" PEqBit RecordNil). + RecordCons "cmp" boolCmp (RecordCons "cmpEq" PEqBit (RecordCons "le" implies (RecordCons "lt" boolLt RecordNil))). Definition PCmpInteger : PCmp SAWCoreScaffolding.Integer := - RecordCons "cmp" integerCmp (RecordCons "cmpEq" PEqInteger RecordNil). + RecordCons "cmp" integerCmp (RecordCons "cmpEq" PEqInteger (RecordCons "le" SAWCoreScaffolding.intLe (RecordCons "lt" SAWCoreScaffolding.intLt RecordNil))). Definition PCmpRational : PCmp Rational := - RecordCons "cmp" rationalCmp (RecordCons "cmpEq" PEqRational RecordNil). + RecordCons "cmp" rationalCmp (RecordCons "cmpEq" PEqRational (RecordCons "le" leRational (RecordCons "lt" ltRational RecordNil))). Definition PCmpVec : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PCmp a -> PCmp (SAWCoreVectorsAsCoqVectors.Vec n a) := - fun (n : SAWCoreScaffolding.Nat) (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) => RecordCons "cmp" (vecCmp n a (RecordProj pa "cmp")) (RecordCons "cmpEq" (PEqVec n a (RecordProj pa "cmpEq")) RecordNil). + fun (n : SAWCoreScaffolding.Nat) (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => let var__0 := SAWCoreVectorsAsCoqVectors.Vec n a in + RecordCons "cmp" (vecCmp n a (RecordProj pa "cmp")) (RecordCons "cmpEq" (PEqVec n a (RecordProj pa "cmpEq")) (RecordCons "le" (fun (x : var__0) (y : SAWCoreVectorsAsCoqVectors.Vec n a) => vecCmp n a (RecordProj pa "cmp") x y SAWCoreScaffolding.true) (RecordCons "lt" (fun (x : var__0) (y : SAWCoreVectorsAsCoqVectors.Vec n a) => vecCmp n a (RecordProj pa "cmp") x y SAWCoreScaffolding.false) RecordNil))). Definition PCmpSeq : forall (n : Num), forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PCmp a -> PCmp (seq n a) := - fun (n : Num) => CryptolPrimitivesForSAWCore.Num_rect (fun (n1 : Num) => forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PCmp a -> PCmp (seq n1 a)) (fun (n1 : SAWCoreScaffolding.Nat) => PCmpVec n1) (fun (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) => SAWCoreScaffolding.error (PCmp (SAWCorePrelude.Stream a)) "invalid Cmp instance"%string) n. + fun (n : Num) => CryptolPrimitivesForSAWCore.Num_rect (fun (n1 : Num) => forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PCmp a -> PCmp (seq n1 a)) (fun (n1 : SAWCoreScaffolding.Nat) => PCmpVec n1) (fun (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => SAWCoreScaffolding.error (PCmp (SAWCorePrelude.Stream a)) "invalid Cmp instance"%string) n. Definition PCmpWord : forall (n : SAWCoreScaffolding.Nat), PCmp (SAWCoreVectorsAsCoqVectors.Vec n SAWCoreScaffolding.Bool) := - fun (n : SAWCoreScaffolding.Nat) => RecordCons "cmp" (bvCmp n) (RecordCons "cmpEq" (PEqWord n) RecordNil). + fun (n : SAWCoreScaffolding.Nat) => RecordCons "cmp" (bvCmp n) (RecordCons "cmpEq" (PEqWord n) (RecordCons "le" (SAWCoreVectorsAsCoqVectors.bvule n) (RecordCons "lt" (SAWCoreVectorsAsCoqVectors.bvult n) RecordNil))). Definition PCmpSeqBool : forall (n : Num), PCmp (seq n SAWCoreScaffolding.Bool) := fun (n : Num) => CryptolPrimitivesForSAWCore.Num_rect (fun (n1 : Num) => PCmp (seq n1 SAWCoreScaffolding.Bool)) (fun (n1 : SAWCoreScaffolding.Nat) => PCmpWord n1) (SAWCoreScaffolding.error (PCmp (SAWCorePrelude.Stream SAWCoreScaffolding.Bool)) "invalid Cmp instance"%string) n. Definition PCmpUnit : PCmp (unit : Type) := - RecordCons "cmp" unitCmp (RecordCons "cmpEq" PEqUnit RecordNil). + RecordCons "cmp" unitCmp (RecordCons "cmpEq" PEqUnit (RecordCons "le" unitLe (RecordCons "lt" unitLt RecordNil))). Definition PCmpPair : forall (a : Type), forall (b : Type), PCmp a -> PCmp b -> PCmp (prod a b) := - fun (a : Type) (b : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (pb : RecordTypeCons "cmp" (b -> b -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (b -> b -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) => RecordCons "cmp" (pairCmp a b (RecordProj pa "cmp") (RecordProj pb "cmp")) (RecordCons "cmpEq" (PEqPair a b (RecordProj pa "cmpEq") (RecordProj pb "cmpEq")) RecordNil). + fun (a : Type) (b : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (pb : RecordTypeCons "cmp" (b -> b -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (b -> b -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (b -> b -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (b -> b -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => let var__0 := RecordProj pa "cmp" in + RecordCons "cmp" (pairCmp a b var__0 (RecordProj pb "cmp")) (RecordCons "cmpEq" (PEqPair a b (RecordProj pa "cmpEq") (RecordProj pb "cmpEq")) (RecordCons "le" (pairLt a b var__0 (RecordProj pb "le")) (RecordCons "lt" (pairLt a b var__0 (RecordProj pb "lt")) RecordNil))). Definition PSignedCmp : Type -> Type := - fun (a : Type) => RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (PEq a) RecordTypeNil). + fun (a : Type) => let var__0 := a -> a -> SAWCoreScaffolding.Bool in + RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (PEq a) (RecordTypeCons "sle" var__0 (RecordTypeCons "slt" var__0 RecordTypeNil))). Definition PSignedCmpVec : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PSignedCmp a -> PSignedCmp (SAWCoreVectorsAsCoqVectors.Vec n a) := - fun (n : SAWCoreScaffolding.Nat) (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) => RecordCons "scmp" (vecCmp n a (RecordProj pa "scmp")) (RecordCons "signedCmpEq" (PEqVec n a (RecordProj pa "signedCmpEq")) RecordNil). + fun (n : SAWCoreScaffolding.Nat) (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "sle" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "slt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => let var__0 := SAWCoreVectorsAsCoqVectors.Vec n a in + RecordCons "scmp" (vecCmp n a (RecordProj pa "scmp")) (RecordCons "signedCmpEq" (PEqVec n a (RecordProj pa "signedCmpEq")) (RecordCons "sle" (fun (x : var__0) (y : SAWCoreVectorsAsCoqVectors.Vec n a) => vecCmp n a (RecordProj pa "scmp") x y SAWCoreScaffolding.true) (RecordCons "slt" (fun (x : var__0) (y : SAWCoreVectorsAsCoqVectors.Vec n a) => vecCmp n a (RecordProj pa "scmp") x y SAWCoreScaffolding.false) RecordNil))). Definition PSignedCmpSeq : forall (n : Num), forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PSignedCmp a -> PSignedCmp (seq n a) := - fun (n : Num) => CryptolPrimitivesForSAWCore.Num_rect (fun (n1 : Num) => forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PSignedCmp a -> PSignedCmp (seq n1 a)) (fun (n1 : SAWCoreScaffolding.Nat) => PSignedCmpVec n1) (fun (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) => SAWCoreScaffolding.error (PSignedCmp (SAWCorePrelude.Stream a)) "invalid SignedCmp instance"%string) n. + fun (n : Num) => CryptolPrimitivesForSAWCore.Num_rect (fun (n1 : Num) => forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, PSignedCmp a -> PSignedCmp (seq n1 a)) (fun (n1 : SAWCoreScaffolding.Nat) => PSignedCmpVec n1) (fun (a : Type) {Inh_a : SAWCoreScaffolding.Inhabited a} (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "sle" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "slt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => SAWCoreScaffolding.error (PSignedCmp (SAWCorePrelude.Stream a)) "invalid SignedCmp instance"%string) n. Definition PSignedCmpWord : forall (n : SAWCoreScaffolding.Nat), PSignedCmp (SAWCoreVectorsAsCoqVectors.Vec n SAWCoreScaffolding.Bool) := - fun (n : SAWCoreScaffolding.Nat) => RecordCons "scmp" (bvSCmp n) (RecordCons "signedCmpEq" (PEqWord n) RecordNil). + fun (n : SAWCoreScaffolding.Nat) => RecordCons "scmp" (bvSCmp n) (RecordCons "signedCmpEq" (PEqWord n) (RecordCons "sle" (SAWCoreVectorsAsCoqVectors.bvsle n) (RecordCons "slt" (SAWCoreVectorsAsCoqVectors.bvslt n) RecordNil))). Definition PSignedCmpSeqBool : forall (n : Num), PSignedCmp (seq n SAWCoreScaffolding.Bool) := fun (n : Num) => CryptolPrimitivesForSAWCore.Num_rect (fun (n1 : Num) => PSignedCmp (seq n1 SAWCoreScaffolding.Bool)) (fun (n1 : SAWCoreScaffolding.Nat) => PSignedCmpWord n1) (SAWCoreScaffolding.error (PSignedCmp (SAWCorePrelude.Stream SAWCoreScaffolding.Bool)) "invalid SignedCmp instance"%string) n. Definition PSignedCmpUnit : PSignedCmp (unit : Type) := - RecordCons "scmp" unitCmp (RecordCons "signedCmpEq" PEqUnit RecordNil). + RecordCons "scmp" unitCmp (RecordCons "signedCmpEq" PEqUnit (RecordCons "sle" unitLe (RecordCons "slt" unitLt RecordNil))). Definition PSignedCmpPair : forall (a : Type), forall (b : Type), PSignedCmp a -> PSignedCmp b -> PSignedCmp (prod a b) := - fun (a : Type) (b : Type) (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (pb : RecordTypeCons "scmp" (b -> b -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (b -> b -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) => RecordCons "scmp" (pairCmp a b (RecordProj pa "scmp") (RecordProj pb "scmp")) (RecordCons "signedCmpEq" (PEqPair a b (RecordProj pa "signedCmpEq") (RecordProj pb "signedCmpEq")) RecordNil). + fun (a : Type) (b : Type) (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "sle" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "slt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (pb : RecordTypeCons "scmp" (b -> b -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (b -> b -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "sle" (b -> b -> SAWCoreScaffolding.Bool) (RecordTypeCons "slt" (b -> b -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => let var__0 := RecordProj pa "scmp" in + RecordCons "scmp" (pairCmp a b var__0 (RecordProj pb "scmp")) (RecordCons "signedCmpEq" (PEqPair a b (RecordProj pa "signedCmpEq") (RecordProj pb "signedCmpEq")) (RecordCons "sle" (pairLt a b var__0 (RecordProj pb "sle")) (RecordCons "slt" (pairLt a b var__0 (RecordProj pb "slt")) RecordNil))). Definition PZero : Type -> Type := fun (a : Type) => a. @@ -523,19 +548,19 @@ Definition ecFieldDiv : forall (a : Type), PField a -> a -> a -> a := fun (a : Type) (pf : RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) => RecordProj pf "fieldDiv". Definition ecCeiling : forall (a : Type), PRound a -> a -> SAWCoreScaffolding.Integer := - fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "ceiling". + fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "ceiling". Definition ecFloor : forall (a : Type), PRound a -> a -> SAWCoreScaffolding.Integer := - fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "floor". + fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "floor". Definition ecTruncate : forall (a : Type), PRound a -> a -> SAWCoreScaffolding.Integer := - fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "trunc". + fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "trunc". Definition ecRoundAway : forall (a : Type), PRound a -> a -> SAWCoreScaffolding.Integer := - fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "roundAway". + fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "roundAway". Definition ecRoundToEven : forall (a : Type), PRound a -> a -> SAWCoreScaffolding.Integer := - fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "roundToEven". + fun (a : Type) (pr : RecordTypeCons "ceiling" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "floor" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundAway" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "roundCmp" (RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (RecordTypeCons "roundField" (RecordTypeCons "fieldDiv" (a -> a -> a) (RecordTypeCons "fieldRing" (RecordTypeCons "add" (a -> a -> a) (RecordTypeCons "int" (SAWCoreScaffolding.Integer -> a) (RecordTypeCons "mul" (a -> a -> a) (RecordTypeCons "neg" (a -> a) (RecordTypeCons "ringZero" a (RecordTypeCons "sub" (a -> a -> a) RecordTypeNil)))))) (RecordTypeCons "recip" (a -> a) RecordTypeNil))) (RecordTypeCons "roundToEven" (a -> SAWCoreScaffolding.Integer) (RecordTypeCons "trunc" (a -> SAWCoreScaffolding.Integer) RecordTypeNil))))))) => RecordProj pr "roundToEven". Definition ecLg2 : forall (n : Num), seq n SAWCoreScaffolding.Bool -> seq n SAWCoreScaffolding.Bool := fun (n : Num) => CryptolPrimitivesForSAWCore.Num_rect (fun (n1 : Num) => seq n1 SAWCoreScaffolding.Bool -> seq n1 SAWCoreScaffolding.Bool) SAWCoreVectorsAsCoqVectors.bvLg2 (SAWCoreScaffolding.error (SAWCorePrelude.Stream SAWCoreScaffolding.Bool -> SAWCorePrelude.Stream SAWCoreScaffolding.Bool) "ecLg2: expected finite word"%string) n. @@ -556,19 +581,19 @@ Definition ecNotEq : forall (a : Type), PEq a -> a -> a -> SAWCoreScaffolding.Bo fun (a : Type) (pa : RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (x : a) (y : a) => SAWCoreScaffolding.not (ecEq a pa x y). Definition ecLt : forall (a : Type), PCmp a -> a -> a -> SAWCoreScaffolding.Bool := - fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (x : a) (y : a) => RecordProj pa "cmp" x y SAWCoreScaffolding.false. + fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => RecordProj pa "lt". Definition ecGt : forall (a : Type), PCmp a -> a -> a -> SAWCoreScaffolding.Bool := - fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (x : a) (y : a) => ecLt a pa y x. + fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (x : a) (y : a) => ecLt a pa y x. Definition ecLtEq : forall (a : Type), PCmp a -> a -> a -> SAWCoreScaffolding.Bool := - fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (x : a) (y : a) => SAWCoreScaffolding.not (ecLt a pa y x). + fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => RecordProj pa "le". Definition ecGtEq : forall (a : Type), PCmp a -> a -> a -> SAWCoreScaffolding.Bool := - fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (x : a) (y : a) => SAWCoreScaffolding.not (ecLt a pa x y). + fun (a : Type) (pa : RecordTypeCons "cmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "cmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "le" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "lt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) (x : a) (y : a) => ecLtEq a pa y x. Definition ecSLt : forall (a : Type), PSignedCmp a -> a -> a -> SAWCoreScaffolding.Bool := - fun (a : Type) (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) RecordTypeNil)) (x : a) (y : a) => RecordProj pa "scmp" x y SAWCoreScaffolding.false. + fun (a : Type) (pa : RecordTypeCons "scmp" (a -> a -> SAWCoreScaffolding.Bool -> SAWCoreScaffolding.Bool) (RecordTypeCons "signedCmpEq" (RecordTypeCons "eq" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil) (RecordTypeCons "sle" (a -> a -> SAWCoreScaffolding.Bool) (RecordTypeCons "slt" (a -> a -> SAWCoreScaffolding.Bool) RecordTypeNil)))) => RecordProj pa "slt". Definition ecAnd : forall (a : Type), PLogic a -> a -> a -> a := fun (a : Type) (pa : RecordTypeCons "and" (a -> a -> a) (RecordTypeCons "logicZero" a (RecordTypeCons "not" (a -> a) (RecordTypeCons "or" (a -> a -> a) (RecordTypeCons "xor" (a -> a -> a) RecordTypeNil))))) => RecordProj pa "and". @@ -694,7 +719,8 @@ Definition PEqFloat : forall (e : Num), forall (p : Num), PEq (TCFloat e p) := fun (e : Num) (p : Num) => RecordCons "eq" (fun (x : unit : Type) (y : unit : Type) => SAWCoreScaffolding.error SAWCoreScaffolding.Bool "Unimplemented: (==) Float"%string) RecordNil. Definition PCmpFloat : forall (e : Num), forall (p : Num), PCmp (TCFloat e p) := - fun (e : Num) (p : Num) => RecordCons "cmp" (fun (x : unit : Type) (y : unit : Type) (k : SAWCoreScaffolding.Bool) => SAWCoreScaffolding.error SAWCoreScaffolding.Bool "Unimplemented: Cmp Float"%string) (RecordCons "cmpEq" (PEqFloat e p) RecordNil). + fun (e : Num) (p : Num) => let var__0 := fun (x : unit : Type) (y : unit : Type) => SAWCoreScaffolding.error SAWCoreScaffolding.Bool "Unimplemented: Cmp Float"%string in + RecordCons "cmp" (fun (x : unit : Type) (y : unit : Type) (k : SAWCoreScaffolding.Bool) => SAWCoreScaffolding.error SAWCoreScaffolding.Bool "Unimplemented: Cmp Float"%string) (RecordCons "cmpEq" (PEqFloat e p) (RecordCons "le" var__0 (RecordCons "lt" var__0 RecordNil))). Definition PZeroFloat : forall (e : Num), forall (p : Num), PZero (TCFloat e p) := fun (e : Num) (p : Num) => SAWCoreScaffolding.error (TCFloat e p) "Unimplemented: Zero Float"%string. diff --git a/saw-core-coq/coq/generated/CryptolToCoq/SAWCorePrelude.v b/saw-core-coq/coq/generated/CryptolToCoq/SAWCorePrelude.v index 792b54e6e9..8ea02a2411 100644 --- a/saw-core-coq/coq/generated/CryptolToCoq/SAWCorePrelude.v +++ b/saw-core-coq/coq/generated/CryptolToCoq/SAWCorePrelude.v @@ -377,12 +377,12 @@ Definition and_triv2 : forall (x : SAWCoreScaffolding.Bool), SAWCoreScaffolding. fun (x : SAWCoreScaffolding.Bool) => let var__0 := SAWCoreScaffolding.not SAWCoreScaffolding.true in SAWCoreScaffolding.iteDep (fun (b : SAWCoreScaffolding.Bool) => SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool (SAWCoreScaffolding.and (SAWCoreScaffolding.not b) b) SAWCoreScaffolding.false) x (trans SAWCoreScaffolding.Bool (SAWCoreScaffolding.and var__0 SAWCoreScaffolding.true) var__0 SAWCoreScaffolding.false (and_True2 var__0) not_True) (and_False2 (SAWCoreScaffolding.not SAWCoreScaffolding.false)). -Definition FalseProp : Prop := - SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool SAWCoreScaffolding.true SAWCoreScaffolding.false. - Definition EqTrue : SAWCoreScaffolding.Bool -> Prop := fun (x : SAWCoreScaffolding.Bool) => SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool x SAWCoreScaffolding.true. +Definition TrueProp : Prop := + EqTrue SAWCoreScaffolding.true. + Definition TrueI : EqTrue SAWCoreScaffolding.true := SAWCoreScaffolding.Refl SAWCoreScaffolding.Bool SAWCoreScaffolding.true. @@ -534,6 +534,9 @@ Axiom head : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), SAWCoreVect Axiom tail : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), SAWCoreVectorsAsCoqVectors.Vec (SAWCoreScaffolding.Succ n) a -> SAWCoreVectorsAsCoqVectors.Vec n a . +Definition atWithDefault' : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), a -> SAWCoreVectorsAsCoqVectors.Vec n a -> SAWCoreScaffolding.Nat -> a := + fun (n_top : SAWCoreScaffolding.Nat) (a : Type) (d : a) => Nat__rec (fun (n : SAWCoreScaffolding.Nat) => SAWCoreVectorsAsCoqVectors.Vec n a -> SAWCoreScaffolding.Nat -> a) (fun (_1 : SAWCoreVectorsAsCoqVectors.Vec 0 a) (_2 : SAWCoreScaffolding.Nat) => d) (fun (n : SAWCoreScaffolding.Nat) (rec_f : SAWCoreVectorsAsCoqVectors.Vec n a -> SAWCoreScaffolding.Nat -> a) (v : SAWCoreVectorsAsCoqVectors.Vec (SAWCoreScaffolding.Succ n) a) (i : SAWCoreScaffolding.Nat) => Nat_cases a (head n a v) (fun (i_prev : SAWCoreScaffolding.Nat) (_1 : a) => rec_f (tail n a v) i_prev) i) n_top. + (* Prelude.atWithDefault was skipped *) Definition sawAt : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), forall {Inh_a : SAWCoreScaffolding.Inhabited a}, SAWCoreVectorsAsCoqVectors.Vec n a -> SAWCoreScaffolding.Nat -> a := @@ -1009,10 +1012,16 @@ Definition genBVVec : forall (n : SAWCoreScaffolding.Nat), forall (len : SAWCore Definition genBVVecFromVec : forall (m : SAWCoreScaffolding.Nat), forall (a : Type), SAWCoreVectorsAsCoqVectors.Vec m a -> a -> forall (n : SAWCoreScaffolding.Nat), forall (len : SAWCoreVectorsAsCoqVectors.Vec n SAWCoreScaffolding.Bool), BVVec n len a := fun (m : SAWCoreScaffolding.Nat) (a : Type) (v : SAWCoreVectorsAsCoqVectors.Vec m a) (def : a) (n : SAWCoreScaffolding.Nat) (len : SAWCoreVectorsAsCoqVectors.Vec n SAWCoreScaffolding.Bool) => genBVVec n len a (fun (i : SAWCoreVectorsAsCoqVectors.Vec n SAWCoreScaffolding.Bool) (_1 : SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool (SAWCoreVectorsAsCoqVectors.bvult n i len) SAWCoreScaffolding.true) => SAWCoreVectorsAsCoqVectors.atWithDefault m a def v (SAWCoreVectorsAsCoqVectors.bvToNat n i)). -Definition efq : forall (a : Type), SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool SAWCoreScaffolding.true SAWCoreScaffolding.false -> a := +Definition FalseProp : Prop := + SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool SAWCoreScaffolding.true SAWCoreScaffolding.false. + +Definition efq : forall (a : Type), FalseProp -> a := fun (a : Type) (contra : SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool SAWCoreScaffolding.true SAWCoreScaffolding.false) => let var__0 := if SAWCoreScaffolding.true then unit : Type else a in SAWCoreScaffolding.coerce (unit : Type) a (trans Type (unit : Type) var__0 a (sym Type var__0 (unit : Type) (ite_true Type (unit : Type) a)) (trans Type var__0 (if SAWCoreScaffolding.false then unit : Type else a) a (eq_cong SAWCoreScaffolding.Bool SAWCoreScaffolding.true SAWCoreScaffolding.false contra Type (fun (b : SAWCoreScaffolding.Bool) => if b then unit : Type else a)) (ite_false Type (unit : Type) a))) tt. +Definition efq1 : forall (a : Type), SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool SAWCoreScaffolding.true SAWCoreScaffolding.false -> a := + fun (a : Type) (contra : SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool SAWCoreScaffolding.true SAWCoreScaffolding.false) => SAWCoreScaffolding.Eq__rec Bit Bit1 (fun (b : Bit) (_1 : SAWCoreScaffolding.Eq Bit Bit1 b) => SAWCorePrelude.Bit_rect (fun (_2 : Bit) => Type) (unit : Type) a b) tt Bit0 (efq (SAWCoreScaffolding.Eq Bit Bit1 Bit0) contra). + Definition emptyBVVec : forall (n : SAWCoreScaffolding.Nat), forall (a : Type), BVVec n (SAWCoreVectorsAsCoqVectors.bvNat n 0) a := fun (n : SAWCoreScaffolding.Nat) (a : Type) => genBVVec n (SAWCoreVectorsAsCoqVectors.bvNat n 0) a (fun (i : SAWCoreVectorsAsCoqVectors.Vec n SAWCoreScaffolding.Bool) (pf : SAWCoreScaffolding.Eq SAWCoreScaffolding.Bool (SAWCoreVectorsAsCoqVectors.bvult n i (SAWCoreVectorsAsCoqVectors.bvNat n 0)) SAWCoreScaffolding.true) => let var__0 := SAWCoreVectorsAsCoqVectors.bvult n i (SAWCoreVectorsAsCoqVectors.bvNat n 0) in efq a (trans SAWCoreScaffolding.Bool SAWCoreScaffolding.true var__0 SAWCoreScaffolding.false (sym SAWCoreScaffolding.Bool var__0 SAWCoreScaffolding.true pf) (not_bvult_zero n i))). @@ -1130,8 +1139,6 @@ Definition foldIRT : forall (As : ListSort), forall (Ds : IRTSubsts As), forall (* Prelude.bindM was skipped *) -(* Prelude.existsM was skipped *) - (* Prelude.errorM was skipped *) Definition fmapM : forall (a : Type), forall (b : Type), (a -> b) -> CompM a -> CompM b := @@ -1147,6 +1154,21 @@ Definition fmapM2 : forall (a : Type), forall (b : Type), forall (c : Type), (a Definition fmapM3 : forall (a : Type), forall (b : Type), forall (c : Type), forall (d : Type), (a -> b -> c -> d) -> CompM a -> CompM b -> CompM c -> CompM d := fun (a : Type) (b : Type) (c : Type) (d : Type) (f : a -> b -> c -> d) (m1 : CompM a) (m2 : CompM b) (m3 : CompM c) => applyM c d (fmapM2 a b (c -> d) f m1 m2) m3. +Definition bindM2 : forall (a : Type), forall (b : Type), forall (c : Type), CompM a -> CompM b -> (a -> b -> CompM c) -> CompM c := + fun (a : Type) (b : Type) (c : Type) (m1 : CompM a) (m2 : CompM b) (f : a -> b -> CompM c) => @bindM CompM _ a c m1 (fun (x : a) => @bindM CompM _ b c m2 (f x)). + +Definition bindM3 : forall (a : Type), forall (b : Type), forall (c : Type), forall (d : Type), CompM a -> CompM b -> CompM c -> (a -> b -> c -> CompM d) -> CompM d := + fun (a : Type) (b : Type) (c : Type) (d : Type) (m1 : CompM a) (m2 : CompM b) (m3 : CompM c) (f : a -> b -> c -> CompM d) => @bindM CompM _ a d m1 (fun (x : a) => bindM2 b c d m2 m3 (f x)). + +Definition bindApplyM : forall (a : Type), forall (b : Type), (a -> CompM b) -> CompM a -> CompM b := + fun (a : Type) (b : Type) (f : a -> CompM b) (m : CompM a) => @bindM CompM _ a b m f. + +Definition bindApplyM2 : forall (a : Type), forall (b : Type), forall (c : Type), (a -> b -> CompM c) -> CompM a -> CompM b -> CompM c := + fun (a : Type) (b : Type) (c : Type) (f : a -> b -> CompM c) (m1 : CompM a) (m2 : CompM b) => @bindM CompM _ a c m1 (fun (x : a) => @bindM CompM _ b c m2 (f x)). + +Definition bindApplyM3 : forall (a : Type), forall (b : Type), forall (c : Type), forall (d : Type), (a -> b -> c -> CompM d) -> CompM a -> CompM b -> CompM c -> CompM d := + fun (a : Type) (b : Type) (c : Type) (d : Type) (f : a -> b -> c -> CompM d) (m1 : CompM a) (m2 : CompM b) (m3 : CompM c) => bindM3 a b c d m1 m2 m3 f. + Definition composeM : forall (a : Type), forall (b : Type), forall (c : Type), (a -> CompM b) -> (b -> CompM c) -> a -> CompM c := fun (a : Type) (b : Type) (c : Type) (f : a -> CompM b) (g : b -> CompM c) (x : a) => @bindM CompM _ b c (f x) g. @@ -1171,7 +1193,14 @@ Definition appendCastBVVecM : forall (n : SAWCoreScaffolding.Nat), forall (len1 let var__5 := BVVec n len3 a in @returnM CompM _ var__5 (SAWCoreScaffolding.coerce (BVVec n var__4 a) var__5 (eq_cong var__3 var__4 len3 pf Type (fun (l : var__3) => BVVec n l a)) (appendBVVec n len1 len2 a v1 v2))) (bvEqWithProof n var__0 len3). -(* Prelude.fixM was skipped *) +(* Prelude.existsM was skipped *) + +(* Prelude.orM was skipped *) + +(* Prelude.forallM was skipped *) + +Definition precondHint : forall (a : Type), SAWCoreScaffolding.Bool -> a -> a := + fun (_1 : Type) (_2 : SAWCoreScaffolding.Bool) (a : _1) => a. (* Prelude.LetRecType was skipped *) @@ -1183,14 +1212,19 @@ Definition appendCastBVVecM : forall (n : SAWCoreScaffolding.Nat), forall (len1 (* Prelude.lrtTupleType was skipped *) -(* Prelude.multiFixM was skipped *) - (* Prelude.letRecM was skipped *) Definition letRecM1 : forall (a : Type), forall (b : Type), forall (c : Type), ((a -> CompM b) -> a -> CompM b) -> ((a -> CompM b) -> CompM c) -> CompM c := fun (a : Type) (b : Type) (c : Type) (fn : (a -> CompM b) -> a -> CompM b) (body : (a -> CompM b) -> CompM c) => let var__0 := a -> CompM b in @CompM.letRecM (CompM.LRT_Cons (CompM.LRT_Fun a (fun (_1 : a) => CompM.LRT_Ret b)) CompM.LRT_Nil) c (fun (f : var__0) => pair (fn f) tt) (fun (f : var__0) => body f). +(* Prelude.fixM was skipped *) + +(* Prelude.multiFixM was skipped *) + +Definition multiArgFixM : forall (lrt : CompM.LetRecType), (CompM.lrtToType lrt -> CompM.lrtToType lrt) -> CompM.lrtToType lrt := + fun (lrt : CompM.LetRecType) (F : CompM.LetRecType_rect (fun (lrt1 : CompM.LetRecType) => Type) (fun (b : Type) => CompM b) (fun (a : Type) (_1 : a -> CompM.LetRecType) (b : a -> Type) => forall (x : a), b x) lrt -> CompM.LetRecType_rect (fun (lrt1 : CompM.LetRecType) => Type) (fun (b : Type) => CompM b) (fun (a : Type) (_2 : a -> CompM.LetRecType) (b : a -> Type) => forall (x : a), b x) lrt) => SAWCoreScaffolding.fst (@CompM.multiFixM (CompM.LRT_Cons lrt CompM.LRT_Nil) (fun (f : CompM.LetRecType_rect (fun (lrt1 : CompM.LetRecType) => Type) (fun (b : Type) => CompM b) (fun (a : Type) (_1 : a -> CompM.LetRecType) (b : a -> Type) => forall (x : a), b x) lrt) => pair (F f) tt)). + (* Prelude.test_fun0 was skipped *) (* Prelude.test_fun1 was skipped *) diff --git a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs index c772b7d01d..d8e61537de 100644 --- a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs +++ b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs @@ -451,11 +451,11 @@ sawCorePreludeSpecialTreatmentMap configuration = , ("errorM", replace (Coq.App (Coq.ExplVar "errorM") [Coq.Var "CompM", Coq.Var "_"])) , ("catchM", skip) - , ("existsM", mapsTo compMModule "existsM") - , ("forallM", mapsTo compMModule "forallM") + , ("existsM", mapsToExpl compMModule "existsM") + , ("forallM", mapsToExpl compMModule "forallM") + , ("orM", mapsToExpl compMModule "orM") , ("fixM", replace (Coq.App (Coq.ExplVar "fixM") [Coq.Var "CompM", Coq.Var "_"])) - , ("existsM", mapsToExpl compMModule "existsM") , ("LetRecType", mapsTo compMModule "LetRecType") , ("LRT_Ret", mapsTo compMModule "LRT_Ret") , ("LRT_Fun", mapsTo compMModule "LRT_Fun") diff --git a/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs b/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs index 232d685758..f106645b9e 100644 --- a/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs +++ b/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs @@ -961,6 +961,8 @@ applyUnintApp sym app0 v = VCtorApp i ps xv -> foldM (applyUnintApp sym) app' =<< traverse force (ps++xv) where app' = suffixUnintApp ("_" ++ (Text.unpack (identBaseName (primName i)))) app0 VNat n -> return (suffixUnintApp ("_" ++ show n) app0) + VBVToNat w v' -> applyUnintApp sym app' v' + where app' = suffixUnintApp ("_" ++ show w) app0 TValue (suffixTValue -> Just s) -> return (suffixUnintApp s app0) VFun _ _ -> @@ -1399,6 +1401,7 @@ data ArgTerm -- ^ length, element type, list, index | ArgTermPairLeft ArgTerm | ArgTermPairRight ArgTerm + | ArgTermBVToNat Natural ArgTerm -- | Reassemble a saw-core term from an 'ArgTerm' and a list of parts. -- The length of the list should be equal to the number of @@ -1468,6 +1471,10 @@ reconstructArgTerm atrm sc ts = do (x1, ts1) <- parse at1 ts0 x <- scPairRight sc x1 return (x, ts1) + ArgTermBVToNat w at1 -> + do (x1, ts1) <- parse at1 ts0 + x <- scBvToNat sc w x1 + pure (x, ts1) parseList :: [ArgTerm] -> [Term] -> IO ([Term], [Term]) parseList [] ts0 = return ([], ts0) @@ -1519,6 +1526,15 @@ mkArgTerm sc ty val = do x <- termOfTValue sc tval pure (ArgTermConst x) + (_, VNat n) -> + do x <- scNat sc n + pure (ArgTermConst x) + + (_, VBVToNat w v) -> + do let w' = fromIntegral w -- FIXME: make w :: Natural to avoid fromIntegral + x <- mkArgTerm sc (VVecType w' VBoolType) v + pure (ArgTermBVToNat w' x) + _ -> fail $ "could not create uninterpreted function argument of type " ++ show ty termOfTValue :: SharedContext -> TValue (What4 sym) -> IO Term diff --git a/saw-core/prelude/Prelude.sawcore b/saw-core/prelude/Prelude.sawcore index 7b0676ff67..8eeaaf033b 100644 --- a/saw-core/prelude/Prelude.sawcore +++ b/saw-core/prelude/Prelude.sawcore @@ -1718,9 +1718,9 @@ efq a contra = -- Ex Falso Quodlibet at sort 1 efq1 : (a : sort 1) -> Eq Bool True False -> a; efq1 a contra = - Eq#rec Bit Bit1 - (\ (b:Bit) (_:Eq Bit Bit1 b) -> Bit#rec (\ (_:Bit) -> sort 1) #() a b) - () Bit0 (efq (Eq Bit Bit1 Bit0) contra); + Eq__rec Bit Bit1 + (\ (b:Bit) (_:Eq Bit Bit1 b) -> Bit#rec (\ (_:Bit) -> sort 1) #() a b) + () Bit0 (efq (Eq Bit Bit1 Bit0) contra); -- Generate an empty BVVec emptyBVVec : (n : Nat) -> (a : sort 0) -> BVVec n (bvNat n 0) a; @@ -2158,6 +2158,10 @@ orM a m1 m2 = existsM Bool a (\ (b:Bool) -> ite (CompM a) b m1 m2); -- those computations diverge from each other. primitive forallM : (a b:sort 0) -> (a -> CompM b) -> CompM b; +-- A hint to Mr Solver that a recursive function has the given precondition +precondHint : (a : sort 0) -> Bool -> a -> a; +precondHint _ _ a = a; + -- NOTE: for the simplicity and efficiency of MR solver, we define all -- fixed-point computations in CompM via a primitive multiFixM, defined below. -- Thus, even though fixM is really the primitive operation, we write this file @@ -2262,27 +2266,6 @@ lrtPi lrts b = (\ (lrt:LetRecType) (_:LetRecTypes) (rest:sort 0) -> lrtToType lrt -> rest) lrts; --- Apply a function the the body of a multi-arity lrtPi function -lrtPiMap : (a b : sort 0) -> (f : a -> b) -> (lrts : LetRecTypes) -> - lrtPi lrts a -> lrtPi lrts b; -lrtPiMap a b f lrts_top = - LetRecTypes#rec - (\ (lrts:LetRecTypes) -> lrtPi lrts a -> lrtPi lrts b) - (\ (x:a) -> f x) - (\ (lrt:LetRecType) (lrts:LetRecTypes) (rec:lrtPi lrts a -> lrtPi lrts b) - (f:lrtToType lrt -> lrtPi lrts a) (g:lrtToType lrt) -> - rec (f g)) - lrts_top; - --- Convert a multi-arity lrtPi that returns a pair to a pair of lrtPi functions --- that return the individual arguments -lrtPiPair : (a b:sort 0) -> (lrts : LetRecTypes) -> lrtPi lrts #(a,b) -> - #(lrtPi lrts a, lrtPi lrts b); -lrtPiPair a b lrts f = - (lrtPiMap #(a,b) a (\ (tup:#(a,b)) -> tup.(1)) lrts f, - lrtPiMap #(a,b) b (\ (tup:#(a,b)) -> tup.(2)) lrts f); - - -- Build the product type (lrtToType lrt1, ..., lrtToType lrtn) from the -- LetRecTypes list [lrt1, ..., lrtn] lrtTupleType : LetRecTypes -> sort 0; @@ -2387,6 +2370,32 @@ fixM a b f x = (\ (g: (y:a) -> CompM (b y)) -> (f g, ())) (\ (g: (y:a) -> CompM (b y)) -> g x); + +-- The following commented block allows multiFixM to be defined in terms of and +-- to reduce to letRecM, which is useful if we want to define all our automated +-- reasoning in terms of letRecM instead of multiFixM + +-- Apply a function the the body of a multi-arity lrtPi function +{- +lrtPiMap : (a b : sort 0) -> (f : a -> b) -> (lrts : LetRecTypes) -> + lrtPi lrts a -> lrtPi lrts b; +lrtPiMap a b f lrts_top = + LetRecTypes#rec + (\ (lrts:LetRecTypes) -> lrtPi lrts a -> lrtPi lrts b) + (\ (x:a) -> f x) + (\ (lrt:LetRecType) (lrts:LetRecTypes) (rec:lrtPi lrts a -> lrtPi lrts b) + (f:lrtToType lrt -> lrtPi lrts a) (g:lrtToType lrt) -> + rec (f g)) + lrts_top; + +-- Convert a multi-arity lrtPi that returns a pair to a pair of lrtPi functions +-- that return the individual arguments +lrtPiPair : (a b:sort 0) -> (lrts : LetRecTypes) -> lrtPi lrts #(a,b) -> + #(lrtPi lrts a, lrtPi lrts b); +lrtPiPair a b lrts f = + (lrtPiMap #(a,b) a (\ (tup:#(a,b)) -> tup.(1)) lrts f, + lrtPiMap #(a,b) b (\ (tup:#(a,b)) -> tup.(2)) lrts f); + -- Build a monadic function that takes in its arguments and then calls letRecM. -- That is, build a function -- @@ -2439,6 +2448,21 @@ multiFixM lrts_top F_top = rec (lrtPiPair (lrtToType lrt) (lrtTupleType lrts) lrts_top F).(2))) lrts_top F_top; +-} + +-- Construct a fixed-point for a tuple of mutually-recursive functions +-- +-- NOTE: Currently, Mr Solver actually works better with a primitive multiFixM, +-- so that's what we are going to do for now... +primitive +multiFixM : (lrts:LetRecTypes) -> lrtPi lrts (lrtTupleType lrts) -> + lrtTupleType lrts; + +-- Build a multi-argument fixed-point of type A1 -> ... -> An -> CompM B +multiArgFixM : (lrt:LetRecType) -> (lrtToType lrt -> lrtToType lrt) -> + lrtToType lrt; +multiArgFixM lrt F = + (multiFixM (LRT_Cons lrt LRT_Nil) (\ (f:lrtToType lrt) -> (F f, ()))).(1); -- Test computations diff --git a/saw-core/src/Verifier/SAW/Grammar.y b/saw-core/src/Verifier/SAW/Grammar.y index 9ecdf4fecf..a37fee8d01 100644 --- a/saw-core/src/Verifier/SAW/Grammar.y +++ b/saw-core/src/Verifier/SAW/Grammar.y @@ -80,6 +80,7 @@ import Verifier.SAW.Lexer 'injectCode' { PosPair _ (TKey "injectCode") } nat { PosPair _ (TNat _) } + bvlit { PosPair _ (TBitvector _) } '_' { PosPair _ (TIdent "_") } ident { PosPair _ (TIdent _) } identrec { PosPair _ (TRecursor _) } @@ -177,6 +178,7 @@ AppTerm : AtomTerm { $1 } AtomTerm :: { Term } AtomTerm : nat { NatLit (pos $1) (tokNat (val $1)) } + | bvlit { BVLit (pos $1) (tokBits (val $1)) } | string { StringLit (pos $1) (Text.pack (tokString (val $1))) } | Ident { Name $1 } | IdentRec { Recursor Nothing $1 } diff --git a/saw-core/src/Verifier/SAW/Lexer.x b/saw-core/src/Verifier/SAW/Lexer.x index 53bc175153..ffb2e9f77b 100644 --- a/saw-core/src/Verifier/SAW/Lexer.x +++ b/saw-core/src/Verifier/SAW/Lexer.x @@ -37,6 +37,8 @@ import Control.Monad.State.Strict import qualified Data.ByteString.Lazy as B import Data.ByteString.Lazy.UTF8 (toString) import Data.Word (Word8) +import Data.Bits +import Data.Char (digitToInt) import Numeric.Natural import Verifier.SAW.Position @@ -86,6 +88,8 @@ $white+; "-}" { \_ -> TCmntE } \" @string* \" { TString . read } @num { TNat . read } +"0x"@hex { TBitvector . readHexBV . drop 2 } +"0b"[0-1]+ { TBitvector . readBinBV . drop 2 } @key { TKey } @ident { TIdent } @ident "#rec" { TRecursor . dropRecSuffix } @@ -96,6 +100,7 @@ data Token = TIdent { tokIdent :: String } -- ^ Identifier | TRecursor { tokRecursor :: String } -- ^ Recursor | TNat { tokNat :: Natural } -- ^ Natural number literal + | TBitvector { tokBits :: [Bool] } -- ^ Bitvector literal | TString { tokString :: String } -- ^ String literal | TKey String -- ^ Keyword or predefined symbol | TEnd -- ^ End of file. @@ -108,12 +113,24 @@ data Token dropRecSuffix :: String -> String dropRecSuffix str = take (length str - 4) str +-- | Convert a hexadecimal string to a big endian list of bits +readHexBV :: String -> [Bool] +readHexBV = + concatMap (\c -> let i = digitToInt c in + [testBit i 3, testBit i 2, testBit i 1, testBit i 0]) + +-- | Convert a binary string to a big endian list of bits +readBinBV :: String -> [Bool] +readBinBV = map (\c -> c == '1') + ppToken :: Token -> String ppToken tkn = case tkn of TIdent s -> s TRecursor s -> s ++ "#rec" TNat n -> show n + TBitvector bits -> + "0b" ++ map (\b -> if b then '1' else '0') bits TString s -> show s TKey s -> s TEnd -> "END" diff --git a/saw-core/src/Verifier/SAW/OpenTerm.hs b/saw-core/src/Verifier/SAW/OpenTerm.hs index d271153071..57c1fd7ad0 100644 --- a/saw-core/src/Verifier/SAW/OpenTerm.hs +++ b/saw-core/src/Verifier/SAW/OpenTerm.hs @@ -27,13 +27,14 @@ module Verifier.SAW.OpenTerm ( unitOpenTerm, unitTypeOpenTerm, stringLitOpenTerm, stringTypeOpenTerm, trueOpenTerm, falseOpenTerm, boolOpenTerm, boolTypeOpenTerm, - arrayValueOpenTerm, bvLitOpenTerm, bvTypeOpenTerm, + arrayValueOpenTerm, vectorTypeOpenTerm, bvLitOpenTerm, bvTypeOpenTerm, pairOpenTerm, pairTypeOpenTerm, pairLeftOpenTerm, pairRightOpenTerm, tupleOpenTerm, tupleTypeOpenTerm, projTupleOpenTerm, tupleOpenTerm', tupleTypeOpenTerm', recordOpenTerm, recordTypeOpenTerm, projRecordOpenTerm, ctorOpenTerm, dataTypeOpenTerm, globalOpenTerm, extCnsOpenTerm, - applyOpenTerm, applyOpenTermMulti, applyPiOpenTerm, piArgOpenTerm, + applyOpenTerm, applyOpenTermMulti, applyGlobalOpenTerm, + applyPiOpenTerm, piArgOpenTerm, lambdaOpenTerm, lambdaOpenTermMulti, piOpenTerm, piOpenTermMulti, arrowOpenTerm, letOpenTerm, sawLetOpenTerm, -- * Monadic operations for building terms with binders @@ -179,6 +180,10 @@ bvLitOpenTerm :: [Bool] -> OpenTerm bvLitOpenTerm bits = arrayValueOpenTerm boolTypeOpenTerm $ map boolOpenTerm bits +-- | Create a SAW core term for a vector type +vectorTypeOpenTerm :: OpenTerm -> OpenTerm -> OpenTerm +vectorTypeOpenTerm n a = applyGlobalOpenTerm "Prelude.Vec" [n,a] + -- | Create a SAW core term for the type of a bitvector bvTypeOpenTerm :: Integral a => a -> OpenTerm bvTypeOpenTerm n = @@ -287,6 +292,10 @@ applyOpenTerm (OpenTerm f) (OpenTerm arg) = applyOpenTermMulti :: OpenTerm -> [OpenTerm] -> OpenTerm applyOpenTermMulti = foldl applyOpenTerm +-- | Apply a named global to 0 or more arguments +applyGlobalOpenTerm :: Ident -> [OpenTerm] -> OpenTerm +applyGlobalOpenTerm ident = applyOpenTermMulti (globalOpenTerm ident) + -- | Compute the output type of applying a function of a given type to an -- argument. That is, given @tp@ and @arg@, compute the type of applying any @f@ -- of type @tp@ to @arg@. diff --git a/saw-core/src/Verifier/SAW/Recognizer.hs b/saw-core/src/Verifier/SAW/Recognizer.hs index cea4a0acae..ba3d81ead7 100644 --- a/saw-core/src/Verifier/SAW/Recognizer.hs +++ b/saw-core/src/Verifier/SAW/Recognizer.hs @@ -382,7 +382,14 @@ asEq t = _ -> Nothing asEqTrue :: Recognizer Term Term -asEqTrue = isGlobalDef "Prelude.EqTrue" @> return +asEqTrue t = + case (isGlobalDef "Prelude.EqTrue" @> return) t of + Just x -> Just x + Nothing -> + do (a,x,y) <- asEq t + isGlobalDef "Prelude.Bool" a + isGlobalDef "Prelude.True" y + return x asArrayType :: Recognizer Term (Term :*: Term) asArrayType = (isGlobalDef "Prelude.Array" @> return) <@> return diff --git a/saw-core/src/Verifier/SAW/Rewriter.hs b/saw-core/src/Verifier/SAW/Rewriter.hs index 0b5e47ba9a..4d267f3435 100644 --- a/saw-core/src/Verifier/SAW/Rewriter.hs +++ b/saw-core/src/Verifier/SAW/Rewriter.hs @@ -181,6 +181,12 @@ asConstantNat t = _ -> Nothing -- | An enhanced matcher that can handle higher-order patterns. +-- +-- This matching procedure will attempt to find an instantiation +-- for the dangling variables appearing in @pattern@. +-- The resulting instantation will return terms that are in the same +-- variable-scoping context as @term@. In particular, if @term@ +-- is closed, then the terms in the instantiation will also be closed. scMatch :: SharedContext -> Term {- ^ pattern -} -> diff --git a/saw-core/src/Verifier/SAW/SharedTerm.hs b/saw-core/src/Verifier/SAW/SharedTerm.hs index ef6b1146c8..4a8cde027a 100644 --- a/saw-core/src/Verifier/SAW/SharedTerm.hs +++ b/saw-core/src/Verifier/SAW/SharedTerm.hs @@ -108,6 +108,8 @@ module Verifier.SAW.SharedTerm -- *** Functions and function application , scApply , scApplyAll + , scApplyBeta + , scApplyAllBeta , scGlobalApply , scFun , scFunAll @@ -1283,6 +1285,17 @@ betaNormalize sc t0 = scApplyAll :: SharedContext -> Term -> [Term] -> IO Term scApplyAll sc = foldlM (scApply sc) +-- | Apply a function to an argument, beta-reducing if the function is a lambda +scApplyBeta :: SharedContext -> Term -> Term -> IO Term +scApplyBeta sc (asLambda -> Just (_, _, body)) arg = + instantiateVar sc 0 arg body +scApplyBeta sc f arg = scApply sc f arg + +-- | Apply a function 'Term' to zero or more arguments, beta reducing any time +-- the function is a lambda +scApplyAllBeta :: SharedContext -> Term -> [Term] -> IO Term +scApplyAllBeta sc = foldlM (scApplyBeta sc) + -- | Returns the defined constant with the given 'Ident'. Fails if no -- such constant exists in the module. scLookupDef :: SharedContext -> Ident -> IO Term diff --git a/saw-core/src/Verifier/SAW/Term/Functor.hs b/saw-core/src/Verifier/SAW/Term/Functor.hs index fb7ae57fef..59c3e3276d 100644 --- a/saw-core/src/Verifier/SAW/Term/Functor.hs +++ b/saw-core/src/Verifier/SAW/Term/Functor.hs @@ -58,6 +58,7 @@ module Verifier.SAW.Term.Functor -- * Sets of free variables , BitSet, emptyBitSet, inBitSet, unionBitSets, intersectBitSets , decrBitSet, multiDecrBitSet, completeBitSet, singletonBitSet, bitSetElems + , smallestBitSetElem , looseVars, smallestFreeVar ) where @@ -485,7 +486,7 @@ bitSetElems = go 0 where go shft bs = case smallestBitSetElem bs of Nothing -> [] Just i -> - shft + i : go (shft + i + 1) (multiDecrBitSet (shft + i + 1) bs) + shft + i : go (shft + i + 1) (multiDecrBitSet (i + 1) bs) -- | Compute the free variables of a term given free variables for its immediate -- subterms diff --git a/saw-core/src/Verifier/SAW/Term/Pretty.hs b/saw-core/src/Verifier/SAW/Term/Pretty.hs index 981cfbe663..58cf0de96e 100644 --- a/saw-core/src/Verifier/SAW/Term/Pretty.hs +++ b/saw-core/src/Verifier/SAW/Term/Pretty.hs @@ -6,6 +6,7 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE PatternGuards #-} {- | Module : Verifier.SAW.Term.Pretty @@ -40,6 +41,7 @@ module Verifier.SAW.Term.Pretty , ppName ) where +import Data.Char (intToDigit) import Data.Maybe (isJust) import Control.Monad.Reader import Control.Monad.State.Strict as State @@ -62,6 +64,7 @@ import qualified Data.IntMap.Strict as IntMap import Verifier.SAW.Name import Verifier.SAW.Term.Functor import Verifier.SAW.Utils (panic) +import Verifier.SAW.Recognizer -------------------------------------------------------------------------------- -- * Doc annotations @@ -469,11 +472,28 @@ ppFlatTermF prec tf = RecordProj e fld -> ppProj fld <$> ppTerm' PrecArg e Sort s h -> return ((if h then pretty ("i"::String) else mempty) <> viaShow s) NatLit i -> ppNat <$> (ppOpts <$> ask) <*> return (toInteger i) + ArrayValue (asBoolType -> Just _) args + | Just bits <- mapM asBool $ V.toList args -> + if length bits `mod` 4 == 0 then + return $ pretty ("0x" ++ ppBitsToHex bits) + else + return $ pretty ("0b" ++ map (\b -> if b then '1' else '0') bits) ArrayValue _ args -> ppArrayValue <$> mapM (ppTerm' PrecTerm) (V.toList args) StringLit s -> return $ viaShow s ExtCns cns -> annotate ExtCnsStyle <$> ppBestName (ecName cns) +-- | Pretty-print a big endian list of bit values as a hexadecimal number +ppBitsToHex :: [Bool] -> String +ppBitsToHex (b8:b4:b2:b1:bits') = + intToDigit (8 * toInt b8 + 4 * toInt b4 + 2 * toInt b2 + toInt b1) : + ppBitsToHex bits' + where toInt True = 1 + toInt False = 0 +ppBitsToHex [] = "" +ppBitsToHex bits = + panic "ppBitsToHex" ["length of bit list is not a multiple of 4", show bits] + -- | Pretty-print a name, using the best unambiguous alias from the -- naming environment. ppBestName :: NameInfo -> PPM SawDoc diff --git a/saw-core/src/Verifier/SAW/Typechecker.hs b/saw-core/src/Verifier/SAW/Typechecker.hs index cf5bc73700..c2ccb0903a 100644 --- a/saw-core/src/Verifier/SAW/Typechecker.hs +++ b/saw-core/src/Verifier/SAW/Typechecker.hs @@ -278,6 +278,11 @@ typeInferCompleteTerm (Un.VecLit _ ts) = type_of_tp <- typeInfer tp typeInferComplete (ArrayValue (TypedTerm tp type_of_tp) $ V.fromList typed_ts) +typeInferCompleteTerm (Un.BVLit _ []) = throwTCError EmptyVectorLit +typeInferCompleteTerm (Un.BVLit _ bits) = + do tp <- liftTCM scBoolType + bit_tms <- mapM (liftTCM scBool) bits + typeInferComplete $ ArrayValue tp $ V.fromList bit_tms typeInferCompleteTerm (Un.BadTerm _) = -- Should be unreachable, since BadTerms represent parse errors, that should diff --git a/saw-core/src/Verifier/SAW/UntypedAST.hs b/saw-core/src/Verifier/SAW/UntypedAST.hs index b09ee9a1ff..d889210fe2 100644 --- a/saw-core/src/Verifier/SAW/UntypedAST.hs +++ b/saw-core/src/Verifier/SAW/UntypedAST.hs @@ -84,6 +84,8 @@ data Term | StringLit Pos Text -- | Vector literal. | VecLit Pos [Term] + -- | Bitvector literal. + | BVLit Pos [Bool] | BadTerm Pos deriving (Show, TH.Lift) @@ -128,6 +130,7 @@ instance Positioned Term where NatLit p _ -> p StringLit p _ -> p VecLit p _ -> p + BVLit p _ -> p BadTerm p -> p instance Positioned TermVar where diff --git a/saw-remote-api/src/SAWServer.hs b/saw-remote-api/src/SAWServer.hs index cbc72e3845..b6ac03c7f6 100644 --- a/saw-remote-api/src/SAWServer.hs +++ b/saw-remote-api/src/SAWServer.hs @@ -59,6 +59,7 @@ import qualified Verifier.SAW.Cryptol.Prelude as CryptolSAW import Verifier.SAW.CryptolEnv (initCryptolEnv, bindTypedTerm) import qualified Cryptol.Utils.Ident as Cryptol import Verifier.SAW.Cryptol.Monadify (defaultMonEnv) +import SAWScript.Prover.MRSolver (emptyMREnv) import qualified Argo --import qualified CryptolServer (validateServerState, ServerState(..)) @@ -94,6 +95,7 @@ data CrucibleSetupVal ty e -- | RecordValue [(String, CrucibleSetupVal e)] | FieldLValue (CrucibleSetupVal ty e) String | CastLValue (CrucibleSetupVal ty e) ty + | UnionLValue (CrucibleSetupVal ty e) String | ElementLValue (CrucibleSetupVal ty e) Int | GlobalInitializer String | GlobalLValue String @@ -218,6 +220,7 @@ initialState readFileFn = , rwDocs = mempty , rwCryptol = cenv , rwMonadify = defaultMonEnv + , rwMRSolverEnv = emptyMREnv , rwPPOpts = defaultPPOpts , rwJVMTrans = jvmTrans , rwPrimsAvail = mempty diff --git a/saw-remote-api/src/SAWServer/Data/SetupValue.hs b/saw-remote-api/src/SAWServer/Data/SetupValue.hs index 18fb6836fd..ff973feb34 100644 --- a/saw-remote-api/src/SAWServer/Data/SetupValue.hs +++ b/saw-remote-api/src/SAWServer/Data/SetupValue.hs @@ -16,6 +16,7 @@ data SetupValTag | TagTupleValue | TagFieldLValue | TagCastLValue + | TagUnionLValue | TagElemLValue | TagGlobalInit | TagGlobalLValue @@ -31,6 +32,7 @@ instance FromJSON SetupValTag where "tuple" -> pure TagTupleValue "field" -> pure TagFieldLValue "cast" -> pure TagCastLValue + "union" -> pure TagUnionLValue "element lvalue" -> pure TagElemLValue "global initializer" -> pure TagGlobalInit "global lvalue" -> pure TagGlobalLValue @@ -49,6 +51,7 @@ instance (FromJSON ty, FromJSON cryptolExpr) => FromJSON (CrucibleSetupVal ty cr TagTupleValue -> TupleValue <$> o .: "elements" TagFieldLValue -> FieldLValue <$> o .: "base" <*> o .: "field" TagCastLValue -> CastLValue <$> o .: "base" <*> o .: "type" + TagUnionLValue -> UnionLValue <$> o .: "base" <*> o .: "field" TagElemLValue -> ElementLValue <$> o .: "base" <*> o .: "index" TagGlobalInit -> GlobalInitializer <$> o .: "name" TagGlobalLValue -> GlobalLValue <$> o .: "name" diff --git a/saw-remote-api/src/SAWServer/JVMCrucibleSetup.hs b/saw-remote-api/src/SAWServer/JVMCrucibleSetup.hs index d2e571e943..8811524901 100644 --- a/saw-remote-api/src/SAWServer/JVMCrucibleSetup.hs +++ b/saw-remote-api/src/SAWServer/JVMCrucibleSetup.hs @@ -188,6 +188,8 @@ compileJVMContract fileReader bic cenv0 c = JVMSetupM $ fail "Field l-values unsupported in JVM API." getSetupVal _ (CastLValue _ _) = JVMSetupM $ fail "Cast l-values unsupported in JVM API." + getSetupVal _ (UnionLValue _ _) = + JVMSetupM $ fail "Union l-values unsupported in JVM API." getSetupVal _ (ElementLValue _ _) = JVMSetupM $ fail "Element l-values unsupported in JVM API." getSetupVal _ (GlobalInitializer _) = diff --git a/saw-remote-api/src/SAWServer/LLVMCrucibleSetup.hs b/saw-remote-api/src/SAWServer/LLVMCrucibleSetup.hs index ab25b88d8d..8eab0dd225 100644 --- a/saw-remote-api/src/SAWServer/LLVMCrucibleSetup.hs +++ b/saw-remote-api/src/SAWServer/LLVMCrucibleSetup.hs @@ -188,6 +188,9 @@ compileLLVMContract fileReader bic ghostEnv cenv0 c = getSetupVal env (CastLValue base ty) = do base' <- getSetupVal env base LLVMCrucibleSetupM $ return $ CMS.anySetupCast base' (llvmType ty) + getSetupVal env (UnionLValue base fld) = + do base' <- getSetupVal env base + LLVMCrucibleSetupM $ return $ CMS.anySetupUnion base' fld getSetupVal env (ElementLValue base idx) = do base' <- getSetupVal env base LLVMCrucibleSetupM $ return $ CMS.anySetupElem base' idx diff --git a/saw-script.cabal b/saw-script.cabal index d59f8546ab..47496b4791 100644 --- a/saw-script.cabal +++ b/saw-script.cabal @@ -48,6 +48,7 @@ library , haskeline , heapster-saw , hobbits >= 1.3.1 + , galois-dwarf >= 0.2.2 , IfElse , jvm-parser , lens @@ -158,6 +159,10 @@ library SAWScript.Prover.Util SAWScript.Prover.SBV SAWScript.Prover.MRSolver + SAWScript.Prover.MRSolver.Monad + SAWScript.Prover.MRSolver.SMT + SAWScript.Prover.MRSolver.Solver + SAWScript.Prover.MRSolver.Term SAWScript.Prover.RME SAWScript.Prover.ABC SAWScript.Prover.What4 diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index d9c2244a42..4edc37ef2d 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -106,7 +106,8 @@ import qualified Cryptol.Backend.Monad as C (runEval) import qualified Cryptol.Eval.Type as C (evalType) import qualified Cryptol.Eval.Value as C (fromVBit, fromVWord) import qualified Cryptol.Eval.Concrete as C (Concrete(..), bvVal) -import qualified Cryptol.Utils.Ident as C (mkIdent, packModName) +import qualified Cryptol.Utils.Ident as C (mkIdent, packModName, + textToModName, PrimIdent(..)) import qualified Cryptol.Utils.RecordMap as C (recordFromFields) import qualified SAWScript.SBVParser as SBV @@ -323,15 +324,17 @@ hoistIfsPrim t = do return t{ ttTerm = t' } +isConvertiblePrim :: TypedTerm -> TypedTerm -> TopLevel Bool +isConvertiblePrim x y = do + sc <- getSharedContext + io $ scConvertible sc False (ttTerm x) (ttTerm y) + checkConvertiblePrim :: TypedTerm -> TypedTerm -> TopLevel () checkConvertiblePrim x y = do - sc <- getSharedContext - str <- io $ do - c <- scConvertible sc False (ttTerm x) (ttTerm y) - pure (if c - then "Convertible" - else "Not convertible") - printOutLnTop Info str + c <- isConvertiblePrim x y + printOutLnTop Info (if c + then "Convertible" + else "Not convertible") readCore :: FilePath -> TopLevel TypedTerm @@ -1385,8 +1388,8 @@ tailPrim :: [a] -> TopLevel [a] tailPrim [] = fail "tail: empty list" tailPrim (_ : xs) = return xs -parseCore :: String -> TopLevel Term -parseCore input = +parseCoreMod :: String -> String -> TopLevel Term +parseCoreMod mnm_str input = do sc <- getSharedContext let base = "" path = "" @@ -1397,18 +1400,29 @@ parseCore input = do let msg = show err printOutLnTop Opts.Error msg fail msg - let mnm = Just $ mkModuleName ["Cryptol"] - err_or_t <- io $ runTCM (typeInferComplete uterm) sc mnm [] + let mnm = + mkModuleName $ Text.splitOn (Text.pack ".") $ Text.pack mnm_str + _ <- io $ scFindModule sc mnm -- Check that mnm exists + err_or_t <- io $ runTCM (typeInferComplete uterm) sc (Just mnm) [] case err_or_t of Left err -> fail (show err) Right (TC.TypedTerm x _) -> return x +parseCore :: String -> TopLevel Term +parseCore = parseCoreMod "Cryptol" + parse_core :: String -> TopLevel TypedTerm parse_core input = do t <- parseCore input sc <- getSharedContext io $ mkTypedTerm sc t +parse_core_mod :: String -> String -> TopLevel TypedTerm +parse_core_mod mnm input = do + t <- parseCoreMod mnm input + sc <- getSharedContext + io $ mkTypedTerm sc t + prove_core :: ProofScript () -> String -> TopLevel Theorem prove_core script input = do sc <- getSharedContext @@ -1443,6 +1457,14 @@ core_thm input = thm <- io (proofByTerm sc db t pos "core_thm") SV.returnProof thm +specialize_theorem :: Theorem -> [TypedTerm] -> TopLevel Theorem +specialize_theorem thm ts = + do sc <- getSharedContext + db <- roTheoremDB <$> getTopLevelRO + pos <- SV.getPosition + thm' <- io (specializeTheorem sc db pos "specialize_theorem" thm (map ttTerm ts)) + SV.returnProof thm' + get_opt :: Int -> TopLevel String get_opt n = do prog <- io $ System.Environment.getProgName @@ -1505,6 +1527,27 @@ cryptol_add_path path = let rw' = rw { rwCryptol = ce' } putTopLevelRW rw' +cryptol_add_prim :: String -> String -> TypedTerm -> TopLevel () +cryptol_add_prim mnm nm trm = + do rw <- getTopLevelRW + let env = rwCryptol rw + let prim_name = + C.PrimIdent (C.textToModName $ Text.pack mnm) (Text.pack nm) + let env' = + env { CEnv.ePrims = + Map.insert prim_name (ttTerm trm) (CEnv.ePrims env) } + putTopLevelRW (rw { rwCryptol = env' }) + +cryptol_add_prim_type :: String -> String -> TypedTerm -> TopLevel () +cryptol_add_prim_type mnm nm tp = + do rw <- getTopLevelRW + let env = rwCryptol rw + let prim_name = + C.PrimIdent (C.textToModName $ Text.pack mnm) (Text.pack nm) + let env' = env { CEnv.ePrimTypes = + Map.insert prim_name (ttTerm tp) (CEnv.ePrimTypes env) } + putTopLevelRW (rw { rwCryptol = env' }) + -- | Call 'Cryptol.importSchema' using a 'CEnv.CryptolEnv' importSchemaCEnv :: SharedContext -> CEnv.CryptolEnv -> Cryptol.Schema -> IO Term @@ -1532,19 +1575,25 @@ monadifyTypedTerm sc t = -- | Ensure that a 'TypedTerm' has been monadified ensureMonadicTerm :: SharedContext -> TypedTerm -> TopLevel TypedTerm -ensureMonadicTerm _ t - | TypedTermOther tp <- ttType t - , Prover.isCompFunType tp = return t +ensureMonadicTerm sc t + | TypedTermOther tp <- ttType t = + io (Prover.isCompFunType sc tp) >>= \case + True -> return t + False -> monadifyTypedTerm sc t ensureMonadicTerm sc t = monadifyTypedTerm sc t +-- | Run Mr Solver with the given debug level to prove that the first term +-- refines the second mrSolver :: SharedContext -> Int -> TypedTerm -> TypedTerm -> TopLevel Bool mrSolver sc dlvl t1 t2 = - do m1 <- ttTerm <$> ensureMonadicTerm sc t1 + do rw <- get + m1 <- ttTerm <$> ensureMonadicTerm sc t1 m2 <- ttTerm <$> ensureMonadicTerm sc t2 - res <- liftIO $ Prover.askMRSolver sc dlvl SBV.z3 Nothing m1 m2 + let env = rwMRSolverEnv rw + res <- liftIO $ Prover.askMRSolver sc dlvl env Nothing m1 m2 case res of - Just err -> io (putStrLn $ Prover.showMRFailure err) >> return False - Nothing -> return True + Left err -> io (putStrLn $ Prover.showMRFailure err) >> return False + Right env' -> put (rw { rwMRSolverEnv = env' }) >> return True setMonadification :: SharedContext -> String -> String -> TopLevel () setMonadification sc cry_str saw_str = diff --git a/src/SAWScript/Crucible/Common/MethodSpec.hs b/src/SAWScript/Crucible/Common/MethodSpec.hs index 680fff2c48..5d9249a824 100644 --- a/src/SAWScript/Crucible/Common/MethodSpec.hs +++ b/src/SAWScript/Crucible/Common/MethodSpec.hs @@ -112,6 +112,7 @@ type family HasSetupElem ext :: Bool type family HasSetupField ext :: Bool type family HasSetupGlobal ext :: Bool type family HasSetupCast ext :: Bool +type family HasSetupUnion ext :: Bool type family HasSetupGlobalInitializer ext :: Bool -- | From the manual: \"The SetupValue type corresponds to values that can occur @@ -127,6 +128,7 @@ data SetupValue ext where SetupElem :: B (HasSetupElem ext) -> SetupValue ext -> Int -> SetupValue ext SetupField :: B (HasSetupField ext) -> SetupValue ext -> String -> SetupValue ext SetupCast :: B (HasSetupCast ext) -> SetupValue ext -> CastType ext -> SetupValue ext + SetupUnion :: B (HasSetupUnion ext) -> SetupValue ext -> String -> SetupValue ext -- | A pointer to a global variable SetupGlobal :: B (HasSetupGlobal ext) -> String -> SetupValue ext @@ -144,6 +146,7 @@ type SetupValueHas (c :: Type -> Constraint) ext = , c (B (HasSetupElem ext)) , c (B (HasSetupField ext)) , c (B (HasSetupCast ext)) + , c (B (HasSetupUnion ext)) , c (B (HasSetupGlobal ext)) , c (B (HasSetupGlobalInitializer ext)) , c (CastType ext) @@ -170,6 +173,7 @@ ppSetupValue setupval = case setupval of SetupArray _ vs -> PP.brackets (commaList (map ppSetupValue vs)) SetupElem _ v i -> PP.parens (ppSetupValue v) PP.<> PP.pretty ("." ++ show i) SetupField _ v f -> PP.parens (ppSetupValue v) PP.<> PP.pretty ("." ++ f) + SetupUnion _ v u -> PP.parens (ppSetupValue v) PP.<> PP.pretty ("." ++ u) SetupCast _ v tp -> PP.parens (ppSetupValue v) PP.<> PP.pretty (" AS " ++ show tp) SetupGlobal _ nm -> PP.pretty ("global(" ++ nm ++ ")") SetupGlobalInitializer _ nm -> PP.pretty ("global_initializer(" ++ nm ++ ")") diff --git a/src/SAWScript/Crucible/JVM/MethodSpecIR.hs b/src/SAWScript/Crucible/JVM/MethodSpecIR.hs index 957452f383..136d66156f 100644 --- a/src/SAWScript/Crucible/JVM/MethodSpecIR.hs +++ b/src/SAWScript/Crucible/JVM/MethodSpecIR.hs @@ -59,6 +59,7 @@ type instance MS.HasSetupArray CJ.JVM = 'False type instance MS.HasSetupElem CJ.JVM = 'False type instance MS.HasSetupField CJ.JVM = 'False type instance MS.HasSetupCast CJ.JVM = 'False +type instance MS.HasSetupUnion CJ.JVM = 'False type instance MS.HasSetupGlobalInitializer CJ.JVM = 'False type instance MS.HasGhostState CJ.JVM = 'False diff --git a/src/SAWScript/Crucible/JVM/Override.hs b/src/SAWScript/Crucible/JVM/Override.hs index ce464d92aa..da0cff4cdf 100644 --- a/src/SAWScript/Crucible/JVM/Override.hs +++ b/src/SAWScript/Crucible/JVM/Override.hs @@ -964,6 +964,7 @@ instantiateSetupValue sc s v = MS.SetupElem empty _ _ -> absurd empty MS.SetupField empty _ _ -> absurd empty MS.SetupCast empty _ _ -> absurd empty + MS.SetupUnion empty _ _ -> absurd empty MS.SetupGlobalInitializer empty _ -> absurd empty where doTerm (TypedTerm schema t) = TypedTerm schema <$> scInstantiateExt sc s t diff --git a/src/SAWScript/Crucible/JVM/ResolveSetupValue.hs b/src/SAWScript/Crucible/JVM/ResolveSetupValue.hs index ddf904b0a5..33e4311225 100644 --- a/src/SAWScript/Crucible/JVM/ResolveSetupValue.hs +++ b/src/SAWScript/Crucible/JVM/ResolveSetupValue.hs @@ -145,6 +145,7 @@ typeOfSetupValue _cc env _nameEnv val = MS.SetupElem empty _ _ -> absurd empty MS.SetupField empty _ _ -> absurd empty MS.SetupCast empty _ _ -> absurd empty + MS.SetupUnion empty _ _ -> absurd empty MS.SetupGlobalInitializer empty _ -> absurd empty lookupAllocIndex :: Map AllocIndex a -> AllocIndex -> a @@ -175,6 +176,7 @@ resolveSetupVal cc env _tyenv _nameEnv val = MS.SetupElem empty _ _ -> absurd empty MS.SetupField empty _ _ -> absurd empty MS.SetupCast empty _ _ -> absurd empty + MS.SetupUnion empty _ _ -> absurd empty MS.SetupGlobalInitializer empty _ -> absurd empty where sym = cc^.jccSym diff --git a/src/SAWScript/Crucible/LLVM/Builtins.hs b/src/SAWScript/Crucible/LLVM/Builtins.hs index f31cfab88b..ba1ad5e1a1 100644 --- a/src/SAWScript/Crucible/LLVM/Builtins.hs +++ b/src/SAWScript/Crucible/LLVM/Builtins.hs @@ -697,7 +697,7 @@ checkSpecArgumentTypes cc mspec = mapM_ resolveArg [0..(nArgs-1)] resolveArg i = case Map.lookup i (mspec ^. MS.csArgBindings) of Just (mt, sv) -> do - mt' <- typeOfSetupValue cc tyenv nameEnv sv + mt' <- exceptToFail (typeOfSetupValue cc tyenv nameEnv sv) checkArgTy i mt mt' Nothing -> throwMethodSpec mspec $ unwords ["Argument", show i, "unspecified when verifying", show nm] @@ -721,7 +721,7 @@ checkSpecReturnType cc mspec = " has void return type" ] (Just sv, Just retTy) -> - do retTy' <- + do retTy' <- exceptToFail $ typeOfSetupValue cc (MS.csAllocations mspec) -- map allocation indices to allocations (mspec ^. MS.csPreState . MS.csVarTypeNames) -- map alloc indices to var names @@ -2206,7 +2206,7 @@ llvm_points_to_internal mbCheckType cond (getAllLLVM -> ptr) (getAllLLVM -> val) let path = [] lhsTy <- llvm_points_to_check_lhs_validity ptr loc path - valTy <- typeOfSetupValue cc env nameEnv val + valTy <- exceptToFail $ typeOfSetupValue cc env nameEnv val case mbCheckType of Nothing -> pure () Just CheckAgainstPointerType -> checkMemTypeCompatibility loc lhsTy valTy @@ -2243,9 +2243,9 @@ llvm_points_to_bitfield (getAllLLVM -> ptr) fieldName (getAllLLVM -> val) = let path = [ResolvedField fieldName] _ <- llvm_points_to_check_lhs_validity ptr loc path - bfIndex <- resolveSetupBitfieldIndexOrFail cc env nameEnv ptr fieldName + bfIndex <- exceptToFail $ resolveSetupBitfield cc env nameEnv ptr fieldName let lhsFieldTy = Crucible.IntType $ fromIntegral $ biFieldSize bfIndex - valTy <- typeOfSetupValue cc env nameEnv val + valTy <- exceptToFail $ typeOfSetupValue cc env nameEnv val -- Currently, we require the type of the RHS value to precisely match -- the type of the field within the bitfield. One could imagine -- having finer-grained control over this (e.g., @@ -2279,7 +2279,7 @@ llvm_points_to_check_lhs_validity ptr loc path = else Setup.csResolvedState %= markResolved ptr path let env = MS.csAllocations (st ^. Setup.csMethodSpec) nameEnv = MS.csTypeNames (st ^. Setup.csMethodSpec) - ptrTy <- typeOfSetupValue cc env nameEnv ptr + ptrTy <- exceptToFail $ typeOfSetupValue cc env nameEnv ptr case ptrTy of Crucible.PtrType symTy -> case Crucible.asMemType symTy of @@ -2326,7 +2326,7 @@ llvm_points_to_array_prefix (getAllLLVM -> ptr) arr sz = else Setup.csResolvedState %= markResolved ptr [] let env = MS.csAllocations (st ^. Setup.csMethodSpec) nameEnv = MS.csTypeNames (st ^. Setup.csMethodSpec) - ptrTy <- typeOfSetupValue cc env nameEnv ptr + ptrTy <- exceptToFail $ typeOfSetupValue cc env nameEnv ptr _ <- case ptrTy of Crucible.PtrType symTy -> case Crucible.asMemType symTy of @@ -2351,8 +2351,8 @@ llvm_equal (getAllLLVM -> val1) (getAllLLVM -> val2) = st <- get let env = MS.csAllocations (st ^. Setup.csMethodSpec) nameEnv = MS.csTypeNames (st ^. Setup.csMethodSpec) - ty1 <- typeOfSetupValue cc env nameEnv val1 - ty2 <- typeOfSetupValue cc env nameEnv val2 + ty1 <- exceptToFail $ typeOfSetupValue cc env nameEnv val1 + ty2 <- exceptToFail $ typeOfSetupValue cc env nameEnv val2 b <- liftIO $ checkRegisterCompatibility ty1 ty2 unless b $ throwCrucibleSetup loc $ unlines diff --git a/src/SAWScript/Crucible/LLVM/MethodSpecIR.hs b/src/SAWScript/Crucible/LLVM/MethodSpecIR.hs index f7a64dbf0c..66b3598ce5 100644 --- a/src/SAWScript/Crucible/LLVM/MethodSpecIR.hs +++ b/src/SAWScript/Crucible/LLVM/MethodSpecIR.hs @@ -98,6 +98,7 @@ module SAWScript.Crucible.LLVM.MethodSpecIR , anySetupStruct , anySetupElem , anySetupField + , anySetupUnion , anySetupNull , anySetupGlobal , anySetupGlobalInitializer @@ -173,6 +174,7 @@ type instance MS.HasSetupArray (LLVM _) = 'True type instance MS.HasSetupElem (LLVM _) = 'True type instance MS.HasSetupField (LLVM _) = 'True type instance MS.HasSetupCast (LLVM _) = 'True +type instance MS.HasSetupUnion (LLVM _) = 'True type instance MS.HasSetupGlobal (LLVM _) = 'True type instance MS.HasSetupGlobalInitializer (LLVM _) = 'True @@ -582,6 +584,9 @@ anySetupCast val ty = mkAllLLVM (MS.SetupCast () (getAllLLVM val) ty) anySetupField :: AllLLVM MS.SetupValue -> String -> AllLLVM MS.SetupValue anySetupField val field = mkAllLLVM (MS.SetupField () (getAllLLVM val) field) +anySetupUnion :: AllLLVM MS.SetupValue -> String -> AllLLVM MS.SetupValue +anySetupUnion val uname = mkAllLLVM (MS.SetupUnion () (getAllLLVM val) uname) + anySetupNull :: AllLLVM MS.SetupValue anySetupNull = mkAllLLVM (MS.SetupNull ()) diff --git a/src/SAWScript/Crucible/LLVM/Override.hs b/src/SAWScript/Crucible/LLVM/Override.hs index 694f8bccc1..f29bf1731a 100644 --- a/src/SAWScript/Crucible/LLVM/Override.hs +++ b/src/SAWScript/Crucible/LLVM/Override.hs @@ -64,8 +64,8 @@ import Control.Lens.Lens import Control.Lens.Setter import Control.Lens.TH import Control.Exception as X -import Control.Monad.IO.Class (liftIO) import Control.Monad +import Control.Monad.Except import Data.Either (partitionEithers) import Data.Foldable (for_, traverse_, toList) import Data.List @@ -192,11 +192,11 @@ mkStructuralMismatch :: mkStructuralMismatch _opts cc _sc spec llvmval setupval memTy = let tyEnv = MS.csAllocations spec nameEnv = MS.csTypeNames spec - maybeTy = typeOfSetupValue cc tyEnv nameEnv setupval + maybeMsgTy = either (const Nothing) Just $ runExcept (typeOfSetupValue cc tyEnv nameEnv setupval) in pure $ StructuralMismatch (PP.pretty llvmval) (MS.ppSetupValue setupval) - maybeTy + maybeMsgTy memTy -- | Instead of using 'ppPointsTo', which prints 'SetupValue', translate @@ -1023,6 +1023,7 @@ matchPointsTos opts sc cc spec prepost = go False [] SetupElem _ x _ -> setupVars x SetupField _ x _ -> setupVars x SetupCast _ x _ -> setupVars x + SetupUnion _ x _ -> setupVars x SetupTerm _ -> Set.empty SetupNull _ -> Set.empty SetupGlobal _ _ -> Set.empty @@ -1194,15 +1195,21 @@ matchArg opts sc cc cs prepost actual expectedTy expected = (Crucible.LLVMValInt blk off, Crucible.PtrType _, SetupElem () v i) -> do let tyenv = MS.csAllocations cs nameEnv = MS.csTypeNames cs - i' <- resolveSetupElemIndexOrFail cc tyenv nameEnv v i + delta <- exceptToFail $ resolveSetupElemOffset cc tyenv nameEnv v i off' <- liftIO $ W4.bvSub sym off - =<< W4.bvLit sym (W4.bvWidth off) (Crucible.bytesToBV (W4.bvWidth off) i') + =<< W4.bvLit sym (W4.bvWidth off) (Crucible.bytesToBV (W4.bvWidth off) delta) matchArg opts sc cc cs prepost (Crucible.LLVMValInt blk off') expectedTy v - (_, Crucible.PtrType _, SetupField () v n) -> + + (Crucible.LLVMValInt blk off, Crucible.PtrType _, SetupField () v n) -> do let tyenv = MS.csAllocations cs nameEnv = MS.csTypeNames cs - i <- resolveSetupFieldIndexOrFail cc tyenv nameEnv v n - matchArg opts sc cc cs prepost actual expectedTy (SetupElem () v i) + fld <- exceptToFail $ + do info <- resolveSetupValueInfo cc tyenv nameEnv v + recoverStructFieldInfo cc tyenv nameEnv v info n + let delta = fromIntegral $ Crucible.fiOffset fld + off' <- liftIO $ W4.bvSub sym off + =<< W4.bvLit sym (W4.bvWidth off) (Crucible.bytesToBV (W4.bvWidth off) delta) + matchArg opts sc cc cs prepost (Crucible.LLVMValInt blk off') expectedTy v (_, _, SetupGlobalInitializer () _) -> resolveAndMatch @@ -1478,7 +1485,7 @@ matchPointsToValue opts sc cc spec prepost loc maybe_cond ptr val = case val of ConcreteSizeValue val' -> - do memTy <- liftIO $ typeOfSetupValue cc tyenv nameEnv val' + do memTy <- exceptToFail $ typeOfSetupValue cc tyenv nameEnv val' -- In case the types are different (from llvm_points_to_untyped) -- then the load type should be determined by the rhs. storTy <- Crucible.toStorableType memTy @@ -1840,14 +1847,14 @@ invalidateMutableAllocs opts sc cc cs = _ -> pure Nothing -- set of (concrete base pointer, size) for each postcondition memory write - postPtrs <- Set.fromList <$> catMaybes <$> mapM + postPtrs <- Set.fromList <$> catMaybes <$> traverse (\case LLVMPointsTo _loc _cond ptr val -> case val of ConcreteSizeValue val' -> do (_, Crucible.LLVMPointer blk _) <- resolveSetupValue opts cc sc cs Crucible.PtrRepr ptr - sz <- (return . Crucible.storageTypeSize) - =<< Crucible.toStorableType - =<< typeOfSetupValue cc (MS.csAllocations cs) (MS.csTypeNames cs) val' + memTy <- exceptToFail $ + typeOfSetupValue cc (MS.csAllocations cs) (MS.csTypeNames cs) val' + sz <- Crucible.storageTypeSize <$> Crucible.toStorableType memTy return $ Just (W4.asNat blk, sz) SymbolicSizeValue{} -> return Nothing LLVMPointsToBitfield _loc ptr fieldName _val -> do @@ -2063,7 +2070,7 @@ storePointsToValue opts cc env tyenv nameEnv base_mem maybe_cond ptr val maybe_i let store_op = \mem -> case val of ConcreteSizeValue val' -> do - memTy <- typeOfSetupValue cc tyenv nameEnv val' + memTy <- exceptToFail $ typeOfSetupValue cc tyenv nameEnv val' storTy <- Crucible.toStorableType memTy case val' of SetupTerm tm @@ -2103,7 +2110,7 @@ storePointsToValue opts cc env tyenv nameEnv base_mem maybe_cond ptr val maybe_i let invalidate_op = \mem -> do sz <- case val of ConcreteSizeValue val' -> do - memTy <- typeOfSetupValue cc tyenv nameEnv val' + memTy <- exceptToFail $ typeOfSetupValue cc tyenv nameEnv val' storTy <- Crucible.toStorableType memTy W4.bvLit sym @@ -2353,6 +2360,7 @@ instantiateSetupValue sc s v = SetupElem{} -> return v SetupField{} -> return v SetupCast{} -> return v + SetupUnion{} -> return v SetupNull{} -> return v SetupGlobal{} -> return v SetupGlobalInitializer{} -> return v @@ -2375,7 +2383,7 @@ resolveSetupValueLLVM opts cc sc spec sval = mem <- readGlobal (Crucible.llvmMemVar (ccLLVMContext cc)) let tyenv = MS.csAllocations spec nameEnv = MS.csTypeNames spec - memTy <- liftIO $ typeOfSetupValue cc tyenv nameEnv sval + memTy <- exceptToFail $ typeOfSetupValue cc tyenv nameEnv sval sval' <- liftIO $ instantiateSetupValue sc s sval lval <- liftIO $ resolveSetupVal cc mem m tyenv nameEnv sval' `X.catch` handleException opts return (memTy, lval) diff --git a/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs b/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs index db16fb3b00..fb37eae978 100644 --- a/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs +++ b/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs @@ -20,15 +20,15 @@ module SAWScript.Crucible.LLVM.ResolveSetupValue , resolveSetupVal , resolveSetupValBitfield , typeOfSetupValue + , exceptToFail , resolveTypedTerm , resolveSAWPred , resolveSAWSymBV - , resolveSetupFieldIndex - , resolveSetupFieldIndexOrFail + , recoverStructFieldInfo + , resolveSetupValueInfo , BitfieldIndex(..) - , resolveSetupBitfieldIndex - , resolveSetupBitfieldIndexOrFail - , resolveSetupElemIndexOrFail + , resolveSetupBitfield + , resolveSetupElemOffset , equalValsPred , memArrayToSawCoreTerm , scPtrWidthBvNat @@ -37,11 +37,12 @@ module SAWScript.Crucible.LLVM.ResolveSetupValue import Control.Lens ((^.)) import Control.Monad -import qualified Control.Monad.Fail as Fail +import Control.Monad.Except import Control.Monad.State import qualified Data.BitVector.Sized as BV -import Data.Maybe (fromMaybe, listToMaybe, fromJust) +import Data.Maybe (fromMaybe, fromJust) +import qualified Data.Dwarf as Dwarf import Data.Map (Map) import qualified Data.Map as Map import qualified Data.Set as Set @@ -85,90 +86,219 @@ import SAWScript.Crucible.Common.MethodSpec (AllocIndex(..), SetupValu import SAWScript.Crucible.LLVM.MethodSpecIR import qualified SAWScript.Proof as SP ---import qualified SAWScript.LLVMBuiltins as LB type LLVMVal = Crucible.LLVMVal Sym type LLVMPtr wptr = Crucible.LLVMPtr Sym wptr --- | Use the LLVM metadata to determine the struct field index --- corresponding to the given field name. + + +exceptToFail :: MonadFail m => Except String a -> m a +exceptToFail m = either fail pure $ runExcept m + +-- | Attempt to look up LLVM debug metadata regarding the type of the +-- given setup value. This is a best-effort procedure, as the +-- necessary debug information may not be avaliable. Even if this +-- procedure succeeds, the returned information may be partial, in +-- the sense that it may contain `Unknown` nodes. resolveSetupValueInfo :: - LLVMCrucibleContext wptr {- ^ crucible context -} -> - Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> + LLVMCrucibleContext wptr {- ^ crucible context -} -> + Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> Map AllocIndex Crucible.Ident {- ^ allocation type names -} -> - SetupValue (LLVM arch) {- ^ pointer to struct -} -> - L.Info {- ^ field index -} + SetupValue (LLVM arch) {- ^ pointer value -} -> + Except String L.Info {- ^ debug type info of pointed-to type -} resolveSetupValueInfo cc env nameEnv v = case v of SetupGlobal _ name -> case lookup (L.Symbol name) globalTys of - Just (L.Alias alias) -> L.Pointer (L.guessAliasInfo mdMap alias) - _ -> L.Unknown + Just (L.Alias alias) -> pure (L.guessAliasInfo mdMap alias) + _ -> throwError $ "Debug info for global name '"++name++"' not found." - SetupVar i - | Just alias <- Map.lookup i nameEnv - -> L.Pointer (L.guessAliasInfo mdMap alias) + SetupVar i -> + case Map.lookup i nameEnv of + Just alias -> pure (L.guessAliasInfo mdMap alias) + Nothing -> + -- TODO? is this a panic situation? + throwError $ "Type information for local allocation value not found: " ++ show i - SetupCast () _ (L.Alias alias) - -> L.Pointer (L.guessAliasInfo mdMap alias) + SetupCast () _ (L.Alias alias) -> pure (L.guessAliasInfo mdMap alias) SetupField () a n -> - fromMaybe L.Unknown $ - do L.Pointer (L.Structure xs) <- return (resolveSetupValueInfo cc env nameEnv a) - listToMaybe [L.Pointer i | L.StructFieldInfo{L.sfiName = n', L.sfiInfo = i} <- xs, n == n' ] + do i <- resolveSetupValueInfo cc env nameEnv a + case findStruct i of + Nothing -> + throwError $ unlines $ + [ "Unable to resolve struct field name: '" ++ n ++ "'" + , "Could not resolve setup value debug information into a struct type." + , case i of + L.Unknown -> "Perhaps you need to compile with debug symbols enabled." + _ -> show i + ] + Just (snm, xs) -> + case [ i' | L.StructFieldInfo{L.sfiName = n', L.sfiInfo = i' } <- xs, n == n' ] of + [] -> throwError $ unlines $ + [ "Unable to resolve struct field name: '" ++ n ++ "'"] ++ + [ "Struct with name '" ++ str ++ "' found." | Just str <- [snm] ] ++ + [ "The following field names were found for this struct:" ] ++ + map ("- "++) [n' | L.StructFieldInfo{L.sfiName = n'} <- xs] + i':_ -> pure i' + + SetupUnion () a u -> + do i <- resolveSetupValueInfo cc env nameEnv a + case findUnion i of + Nothing -> + throwError $ unlines $ + [ "Unable to resolve union field name: '" ++ u ++ "'" + , "Could not resolve setup value debug information into a union type." + , case i of + L.Unknown -> "Perhaps you need to compile with debug symbols enabled." + _ -> show i + ] + Just (unm, xs) -> + case [ i' | L.UnionFieldInfo{L.ufiName = n', L.ufiInfo = i'} <- xs, u == n' ] of + [] -> throwError $ unlines $ + [ "Unable to resolve union field name: '" ++ u ++ "'"] ++ + [ "Union with name '" ++ str ++ "' found." | Just str <- [unm] ] ++ + [ "The following field names were found for this union:" ] ++ + map ("- "++) [n' | L.UnionFieldInfo{L.ufiName = n'} <- xs] + i':_ -> pure i' + + _ -> pure L.Unknown - _ -> L.Unknown where globalTys = [ (L.globalSym g, L.globalType g) | g <- L.modGlobals (ccLLVMModuleAST cc) ] mdMap = Crucible.llvmMetadataMap (ccTypeCtx cc) --- | Use the LLVM metadata to determine the struct field index --- corresponding to the given field name. -resolveSetupFieldIndex :: - LLVMCrucibleContext arch {- ^ crucible context -} -> - Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> - Map AllocIndex Crucible.Ident {- ^ allocation type names -} -> - SetupValue (LLVM arch) {- ^ pointer to struct -} -> - String {- ^ field name -} -> - Maybe Int {- ^ field index -} -resolveSetupFieldIndex cc env nameEnv v n = - case resolveSetupValueInfo cc env nameEnv v of - L.Pointer (L.Structure xs) -> +-- | Given DWARF type information that is expected to describe a +-- struct, find its name (if any) and information about its fields. +-- This procedure handles the common case where a typedef is used to +-- give a name to an anonymous struct. If a struct both has a direct +-- name and is included in a typedef, the direct name will be preferred. +findStruct :: L.Info -> Maybe (Maybe String, [L.StructFieldInfo]) +findStruct = loop Nothing + where loop _ (L.Typedef nm i) = loop (Just nm) i + loop nm (L.Structure nm' xs) = Just (nm' <> nm, xs) + loop _ _ = Nothing + +-- | Given DWARF type information that is expected to describe a +-- union, find its name (if any) and information about its fields. +-- This procedure handles the common case where a typedef is used to +-- give a name to an anonymous union. If a union both has a direct +-- name and is included in a typedef, the direct name will be preferred. +findUnion :: L.Info -> Maybe (Maybe String, [L.UnionFieldInfo]) +findUnion = loop Nothing + where loop _ (L.Typedef nm i) = loop (Just nm) i + loop nm (L.Union nm' xs) = Just (nm' <> nm, xs) + loop _ _ = Nothing + +-- | Given LLVM debug information about a setup value, attempt to +-- find the corresponding @FieldInfo@ structure for the named +-- field. +recoverStructFieldInfo :: + LLVMCrucibleContext arch {- ^ crucible context -} -> + Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> + Map AllocIndex Crucible.Ident {- ^ allocation type names -} -> + SetupValue (LLVM arch) {- ^ the value to examine -} -> + L.Info {- ^ extracted LLVM debug information about the type of the value -} -> + String {- ^ the name of the field -} -> + Except String Crucible.FieldInfo +recoverStructFieldInfo cc env nameEnv v info n = + case findStruct info of + Nothing -> + throwError $ unlines $ + [ "Unable to resolve struct field name: '" ++ show n ++ "'" + , "Could not resolve setup value debug information into a struct type." + , case info of + L.Unknown -> "Perhaps you need to compile with debug symbols enabled." + _ -> show info + ] + Just (snm,xs) -> case [o | L.StructFieldInfo{L.sfiName = n', L.sfiOffset = o} <- xs, n == n' ] of - [] -> Nothing + [] -> throwError $ unlines $ + [ "Unable to resolve struct field name: '" ++ n ++ "'"] ++ + [ "Struct with name '" ++ str ++ "' found." | Just str <- [snm] ] ++ + [ "The following field names were found for this struct:" ] ++ + map ("- "++) [n' | L.StructFieldInfo{L.sfiName = n'} <- xs] o:_ -> - do Crucible.PtrType symTy <- typeOfSetupValue cc env nameEnv v - Crucible.StructType si <- - let ?lc = lc - in either (\_ -> Nothing) Just $ Crucible.asMemType symTy - V.findIndex (\fi -> Crucible.bytesToBits (Crucible.fiOffset fi) == fromIntegral o) (Crucible.siFields si) - - _ -> Nothing + do vty <- typeOfSetupValue cc env nameEnv v + case do Crucible.PtrType symTy <- pure vty + Crucible.StructType si <- let ?lc = ccTypeCtx cc + in either (\_ -> Nothing) Just $ Crucible.asMemType symTy + V.find (\fi -> Crucible.bytesToBits (Crucible.fiOffset fi) == fromIntegral o) + (Crucible.siFields si) + of + Nothing -> + throwError $ unlines $ + [ "Found struct field name: '" ++ n ++ "'"] ++ + [ "in struct with name '" ++ str ++ "'." | Just str <- [snm] ] ++ + [ "However, the offset of this field found in the debug information could not" + , "be correlated with the computed LLVM type of the setup value:" + , show vty + ] + Just fld -> return fld + +-- | Attempt to turn type information from DWARF debug data back into +-- the corresponding LLVM type. This is a best-effort procedure, as +-- we may have to make educated guesses about names, and there might +-- not be enough data to succeed. +reverseDebugInfoType :: L.Info -> Maybe L.Type +reverseDebugInfoType = loop Nothing where - lc = ccTypeCtx cc + loop n i = case i of + L.Unknown -> + case n of + Just nm -> Just (L.Alias (L.Ident nm)) + Nothing -> Nothing + + L.Pointer i' -> L.PtrTo <$> loop Nothing i' + + L.Union n' _ -> + case n' <> n of + Just nm -> Just (L.Alias (L.Ident ("union."++ nm))) + Nothing -> Nothing + + L.Structure n' xs -> + case n' <> n of + Just nm -> Just (L.Alias (L.Ident ("struct." ++ nm))) + Nothing -> L.Struct <$> mapM (reverseDebugInfoType . L.sfiInfo) xs + + L.Typedef nm x -> loop (Just nm) x + + L.ArrInfo x -> L.Array 0 <$> loop Nothing x + + L.BaseType _nm bt -> reverseBaseTypeInfo bt + +-- | Attempt to turn DWARF basic type information back into +-- LLVM type syntax. This process is currently rather +-- ad-hoc, and may miss cases. +reverseBaseTypeInfo :: L.DIBasicType -> Maybe L.Type +reverseBaseTypeInfo dibt = + case Dwarf.DW_ATE (fromIntegral (L.dibtEncoding dibt)) of + Dwarf.DW_ATE_boolean -> Just $ L.PrimType $ L.Integer 1 + + Dwarf.DW_ATE_float -> + case L.dibtSize dibt of + 16 -> Just $ L.PrimType $ L.FloatType $ L.Half + 32 -> Just $ L.PrimType $ L.FloatType $ L.Float + 64 -> Just $ L.PrimType $ L.FloatType $ L.Double + 80 -> Just $ L.PrimType $ L.FloatType $ L.X86_fp80 + 128 -> Just $ L.PrimType $ L.FloatType $ L.Fp128 + _ -> Nothing + + Dwarf.DW_ATE_signed -> + Just $ L.PrimType $ L.Integer (fromIntegral (L.dibtSize dibt)) + + Dwarf.DW_ATE_signed_char -> + Just $ L.PrimType $ L.Integer 8 + + Dwarf.DW_ATE_unsigned -> + Just $ L.PrimType $ L.Integer (fromIntegral (L.dibtSize dibt)) + + Dwarf.DW_ATE_unsigned_char -> + Just $ L.PrimType $ L.Integer 8 + + _ -> Nothing + -resolveSetupFieldIndexOrFail :: - Fail.MonadFail m => - LLVMCrucibleContext arch {- ^ crucible context -} -> - Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> - Map AllocIndex Crucible.Ident {- ^ allocation type names -} -> - SetupValue (LLVM arch) {- ^ pointer to struct -} -> - String {- ^ field name -} -> - m Int {- ^ field index -} -resolveSetupFieldIndexOrFail cc env nameEnv v n = - case resolveSetupFieldIndex cc env nameEnv v n of - Just i -> pure i - Nothing -> - let msg = "Unable to resolve field name: " ++ show n - in - fail $ - -- Show the user what fields were available (if any) - case resolveSetupValueInfo cc env nameEnv v of - L.Pointer (L.Structure xs) -> unlines $ - [ msg - , "The following field names were found for this struct:" - ] ++ map ("- "++) [n' | L.StructFieldInfo{L.sfiName = n'} <- xs] - _ -> unlines [msg, "No field names were found for this struct"] -- | Information about a field within a bitfield in a struct. For example, -- given the following C struct: @@ -190,7 +320,7 @@ resolveSetupFieldIndexOrFail cc env nameEnv v n = -- 'BitfieldIndex' -- { 'biFieldSize' = 1 -- , 'biFieldOffset' = 0 --- , 'biBitfieldIndex' = 4 +-- , 'biBitfieldByteOffset' = 4 -- , 'biBitfieldType' = i8 -- } -- @@ -198,7 +328,7 @@ resolveSetupFieldIndexOrFail cc env nameEnv v n = -- 'BitfieldIndex' -- { 'biFieldSize' = 2 -- , 'biFieldOffset' = 1 --- , 'biBitfieldIndex' = 4 +-- , 'biBitfieldByteOffset' = 4 -- , 'biBitfieldType' = i8 -- } -- @@ -206,13 +336,13 @@ resolveSetupFieldIndexOrFail cc env nameEnv v n = -- 'BitfieldIndex' -- { 'biFieldSize' = 1 -- , 'biFieldOffset' = 3 --- , 'biBitfieldIndex' = 4 +-- , 'biBitfieldByteOffset' = 4 -- , 'biBitfieldType' = i8 -- } -- @ -- -- Note that the 'biFieldSize's and 'biFieldOffset's are specific to each --- individual field, while the 'biBitfieldIndex'es and 'biBitfieldType's are +-- individual field, while the 'biBitfieldByteOffest's and 'biBitfieldType's are -- all the same, as the latter two all describe the same bitfield. data BitfieldIndex = BitfieldIndex { biFieldSize :: Word64 @@ -220,147 +350,145 @@ data BitfieldIndex = BitfieldIndex , biFieldOffset :: Word64 -- ^ The offset (in bits) of the field from the start of the bitfield, -- counting from the least significant bit. - , biBitfieldIndex :: Int - -- ^ The struct field index corresponding to the overall bitfield, where - -- the index represents the number of bytes the bitfield is from the - -- start of the struct. + , biFieldByteOffset :: Crucible.Bytes + -- ^ The offset (in bytes) of the struct member in which this bitfield resides. , biBitfieldType :: Crucible.MemType -- ^ The 'Crucible.MemType' of the overall bitfield. } deriving Show --- | Returns @'Just' bi@ if SAW is able to find a field within a bitfield with --- the supplied name in the LLVM debug metadata. Returns 'Nothing' otherwise. -resolveSetupBitfieldIndex :: +-- | Given a pointer setup value and the name of a bitfield, attempt to +-- determine were in the struct that bitfield resides by examining +-- DWARF type metadata. +resolveSetupBitfield :: LLVMCrucibleContext arch {- ^ crucible context -} -> Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> Map AllocIndex Crucible.Ident {- ^ allocation type names -} -> SetupValue (LLVM arch) {- ^ pointer to struct -} -> String {- ^ field name -} -> - Maybe BitfieldIndex {- ^ information about bitfield -} -resolveSetupBitfieldIndex cc env nameEnv v n = - case resolveSetupValueInfo cc env nameEnv v of - L.Pointer (L.Structure xs) - | (fieldOffsetStartingFromStruct, bfInfo):_ <- - [ (fieldOffsetStartingFromStruct, bfInfo) - | L.StructFieldInfo - { L.sfiName = n' - , L.sfiOffset = fieldOffsetStartingFromStruct - , L.sfiBitfield = Just bfInfo - } <- xs - , n == n' + Except String BitfieldIndex {- ^ information about bitfield -} +resolveSetupBitfield cc env nameEnv v n = + do info <- resolveSetupValueInfo cc env nameEnv v + case findStruct info of + Nothing -> + throwError $ unlines $ + [ "Unable to resolve struct bitfield name: '" ++ show n ++ "'" + , "Could not resolve setup value debug information into a struct type." + , case info of + L.Unknown -> "Perhaps you need to compile with debug symbols enabled." + _ -> show info ] - -> do Crucible.PtrType symTy <- typeOfSetupValue cc env nameEnv v - Crucible.StructType si <- - let ?lc = lc - in either (\_ -> Nothing) Just $ Crucible.asMemType symTy - bfIndex <- - V.findIndex (\fi -> Crucible.bytesToBits (Crucible.fiOffset fi) - == fromIntegral (L.biBitfieldOffset bfInfo)) - (Crucible.siFields si) - let bfType = Crucible.fiType $ Crucible.siFields si V.! bfIndex - fieldOffsetStartingFromBitfield = - fieldOffsetStartingFromStruct - L.biBitfieldOffset bfInfo - pure $ BitfieldIndex { biFieldSize = L.biFieldSize bfInfo - , biFieldOffset = fieldOffsetStartingFromBitfield - , biBitfieldIndex = bfIndex - , biBitfieldType = bfType - } - - _ -> Nothing - where - lc = ccTypeCtx cc - --- | Like 'resolveSetupBitfieldIndex', but if SAW cannot find the supplied --- name, fail instead of returning 'Nothing'. -resolveSetupBitfieldIndexOrFail :: - Fail.MonadFail m => - LLVMCrucibleContext arch {- ^ crucible context -} -> - Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> + Just (snm, xs) -> + case [ (fieldOffsetStartingFromStruct, bfInfo) | L.StructFieldInfo + { L.sfiName = n' + , L.sfiOffset = fieldOffsetStartingFromStruct + , L.sfiBitfield = Just bfInfo + } <- xs, n == n' ] of + + [] -> throwError $ unlines $ + [ "Unable to resolve struct bitfield name: '" ++ n ++ "'"] ++ + [ "Struct with name '" ++ str ++ "' found." | Just str <- [snm] ] ++ + [ "The following bitfield names were found for this struct:" ] ++ + map ("- "++) [n' | L.StructFieldInfo{L.sfiName = n', L.sfiBitfield = Just{}} <- xs] + + ((fieldOffsetStartingFromStruct, bfInfo):_) -> + do memTy <- typeOfSetupValue cc env nameEnv v + case do Crucible.PtrType symTy <- pure memTy + Crucible.StructType si <- let ?lc = ccTypeCtx cc + in either (\_ -> Nothing) Just $ Crucible.asMemType symTy + fi <- V.find (\fi -> Crucible.bytesToBits (Crucible.fiOffset fi) + == fromIntegral (L.biBitfieldOffset bfInfo)) + (Crucible.siFields si) + let fieldOffsetStartingFromBitfield = + fieldOffsetStartingFromStruct - L.biBitfieldOffset bfInfo + pure $ BitfieldIndex { biFieldSize = L.biFieldSize bfInfo + , biFieldOffset = fieldOffsetStartingFromBitfield + , biBitfieldType = Crucible.fiType fi + , biFieldByteOffset = Crucible.fiOffset fi + } + of + Nothing -> + throwError $ unlines $ + [ "Found struct field name: '" ++ n ++ "'"] ++ + [ "in struct with name '" ++ str ++ "'." | Just str <- [snm] ] ++ + [ "However, the offset of this field found in the debug information could not" + , "be correlated with the computed LLVM type of the setup value, or the field" + , "is not a bitfield." + , show memTy + ] + + Just bfi -> return bfi + +-- | Attempt to compute the @MemType@ of a setup value. +typeOfSetupValue :: forall arch. + LLVMCrucibleContext arch {- ^ crucible context -} -> + Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> Map AllocIndex Crucible.Ident {- ^ allocation type names -} -> - SetupValue (LLVM arch) {- ^ pointer to struct -} -> - String {- ^ field name -} -> - m BitfieldIndex {- ^ field index -} -resolveSetupBitfieldIndexOrFail cc env nameEnv v n = - case resolveSetupBitfieldIndex cc env nameEnv v n of - Just i -> pure i - Nothing -> - let msg = "Unable to resolve field name: " ++ show n - in - fail $ - -- Show the user what fields were available (if any) - case resolveSetupValueInfo cc env nameEnv v of - L.Pointer (L.Structure xs) -> unlines $ - [ msg - , "The following bitfield names were found for this struct:" - ] ++ map ("- "++) [n' | L.StructFieldInfo{ L.sfiName = n' - , L.sfiBitfield = Just{} - } <- xs] - _ -> unlines [msg, "No field names were found for this struct"] - -typeOfSetupValue :: - Fail.MonadFail m => - LLVMCrucibleContext arch -> - Map AllocIndex LLVMAllocSpec -> - Map AllocIndex Crucible.Ident -> - SetupValue (LLVM arch) -> - m Crucible.MemType + SetupValue (LLVM arch) {- ^ value to compute the type of -} -> + Except String Crucible.MemType typeOfSetupValue cc env nameEnv val = - do let ?lc = ccTypeCtx cc - typeOfSetupValue' cc env nameEnv val - -typeOfSetupValue' :: forall m arch. - Fail.MonadFail m => - LLVMCrucibleContext arch -> - Map AllocIndex LLVMAllocSpec -> - Map AllocIndex Crucible.Ident -> - SetupValue (LLVM arch) -> - m Crucible.MemType -typeOfSetupValue' cc env nameEnv val = case val of SetupVar i -> case Map.lookup i env of - Nothing -> fail ("typeOfSetupValue: Unresolved prestate variable:" ++ show i) + Nothing -> throwError ("typeOfSetupValue: Unresolved prestate variable:" ++ show i) Just spec -> return (Crucible.PtrType (Crucible.MemType (spec ^. allocSpecType))) + SetupTerm tt -> case ttType tt of TypedTermSchema (Cryptol.Forall [] [] ty) -> case toLLVMType dl (Cryptol.evalValType mempty ty) of - Left err -> fail (toLLVMTypeErrToString err) + Left err -> throwError (toLLVMTypeErrToString err) Right memTy -> return memTy - tp -> fail $ unlines [ "typeOfSetupValue: expected monomorphic term" - , "instead got:" - , show (ppTypedTermType tp) - ] - SetupCast () v ltp -> - do memTy <- typeOfSetupValue cc env nameEnv v - case memTy of - Crucible.PtrType _symTy -> - case let ?lc = lc in Crucible.liftMemType (L.PtrTo ltp) of - Left err -> fail $ unlines [ "typeOfSetupValue: invalid type " ++ show ltp - , "Details:" - , err - ] - Right mt -> return mt + tp -> throwError $ unlines + [ "typeOfSetupValue: expected monomorphic term" + , "instead got:" + , show (ppTypedTermType tp) + ] - _ -> fail $ unwords $ - [ "typeOfSetupValue: tried to cast the type of a non-pointer value" - , "actual type of value: " ++ show memTy - ] SetupStruct () packed vs -> do memTys <- traverse (typeOfSetupValue cc env nameEnv) vs let si = Crucible.mkStructInfo dl packed memTys return (Crucible.StructType si) - SetupArray () [] -> fail "typeOfSetupValue: invalid empty llvm_array_value" + + SetupArray () [] -> throwError "typeOfSetupValue: invalid empty llvm_array_value" SetupArray () (v : vs) -> do memTy <- typeOfSetupValue cc env nameEnv v _memTys <- traverse (typeOfSetupValue cc env nameEnv) vs -- TODO: check that all memTys are compatible with memTy return (Crucible.ArrayType (fromIntegral (length (v:vs))) memTy) - SetupField () v n -> do - i <- resolveSetupFieldIndexOrFail cc env nameEnv v n - typeOfSetupValue' cc env nameEnv (SetupElem () v i) - SetupElem () v i -> + + SetupField () v n -> + do info <- resolveSetupValueInfo cc env nameEnv v + fld <- recoverStructFieldInfo cc env nameEnv v info n + pure $ Crucible.PtrType $ Crucible.MemType $ Crucible.fiType fld + + SetupUnion () v n -> + do info <- resolveSetupValueInfo cc env nameEnv (SetupUnion () v n) + case reverseDebugInfoType info of + Nothing -> throwError $ unlines + [ "Could not determine LLVM type from computed debug type information:" + , show info + ] + Just ltp -> typeOfSetupValue cc env nameEnv (SetupCast () v ltp) + + SetupCast () v ltp -> + do memTy <- typeOfSetupValue cc env nameEnv v + case memTy of + Crucible.PtrType _symTy -> + case let ?lc = lc in Crucible.liftMemType (L.PtrTo ltp) of + Left err -> throwError $ unlines + [ "typeOfSetupValue: invalid type " ++ show ltp + , "Details:" + , err + ] + Right mt -> pure mt + + _ -> throwError $ unwords $ + [ "typeOfSetupValue: tried to cast the type of a non-pointer value" + , "actual type of value: " ++ show memTy + ] + + SetupElem () v i -> do do memTy <- typeOfSetupValue cc env nameEnv v let msg = "typeOfSetupValue: llvm_elem requires pointer to struct or array, found " ++ show memTy case memTy of @@ -370,7 +498,7 @@ typeOfSetupValue' cc env nameEnv val = case memTy' of Crucible.ArrayType n memTy'' | fromIntegral i <= n -> return (Crucible.PtrType (Crucible.MemType memTy'')) - | otherwise -> fail $ unwords $ + | otherwise -> throwError $ unwords $ [ "typeOfSetupValue: array type index out of bounds" , "(index: " ++ show i ++ ")" , "(array length: " ++ show n ++ ")" @@ -378,16 +506,18 @@ typeOfSetupValue' cc env nameEnv val = Crucible.StructType si -> case Crucible.siFieldInfo si i of Just fi -> return (Crucible.PtrType (Crucible.MemType (Crucible.fiType fi))) - Nothing -> fail $ "typeOfSetupValue: struct type index out of bounds: " ++ show i - _ -> fail msg - Left err -> fail (unlines [msg, "Details:", err]) - _ -> fail msg + Nothing -> throwError $ "typeOfSetupValue: struct type index out of bounds: " ++ show i + _ -> throwError msg + Left err -> throwError (unlines [msg, "Details:", err]) + _ -> throwError msg + SetupNull () -> -- We arbitrarily set the type of NULL to void*, because a) it -- is memory-compatible with any type that NULL can be used at, -- and b) it prevents us from doing a type-safe dereference -- operation. return (Crucible.PtrType Crucible.VoidType) + -- A global and its initializer have the same type. SetupGlobal () name -> do let m = ccLLVMModuleAST cc @@ -395,37 +525,42 @@ typeOfSetupValue' cc env nameEnv val = [ (L.decName d, L.decFunType d) | d <- L.modDeclares m ] ++ [ (L.defName d, L.defFunType d) | d <- L.modDefines m ] case lookup (L.Symbol name) tys of - Nothing -> fail $ "typeOfSetupValue: unknown global " ++ show name + Nothing -> throwError $ "typeOfSetupValue: unknown global " ++ show name Just ty -> case let ?lc = lc in Crucible.liftType ty of - Left err -> fail $ unlines [ "typeOfSetupValue: invalid type " ++ show ty - , "Details:" - , err - ] + Left err -> throwError $ unlines + [ "typeOfSetupValue: invalid type " ++ show ty + , "Details:" + , err + ] Right symTy -> return (Crucible.PtrType symTy) + SetupGlobalInitializer () name -> do case Map.lookup (L.Symbol name) (Crucible.globalInitMap $ ccLLVMModuleTrans cc) of Just (g, _) -> case let ?lc = lc in Crucible.liftMemType (L.globalType g) of - Left err -> fail $ unlines [ "typeOfSetupValue: invalid type " ++ show (L.globalType g) - , "Details:" - , err - ] + Left err -> throwError $ unlines + [ "typeOfSetupValue: invalid type " ++ show (L.globalType g) + , "Details:" + , err + ] Right memTy -> return memTy - Nothing -> fail $ "resolveSetupVal: global not found: " ++ name + Nothing -> throwError $ "resolveSetupVal: global not found: " ++ name where lc = ccTypeCtx cc dl = Crucible.llvmDataLayout lc -resolveSetupElemIndexOrFail :: - Fail.MonadFail m => - LLVMCrucibleContext arch {- ^ crucible context -} -> - Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> +-- | Given a pointer setup value that points to an aggregate +-- type (struct or array), attempt to compute the byte offset of +-- the nth element of that aggregate structure. +resolveSetupElemOffset :: + LLVMCrucibleContext arch {- ^ crucible context -} -> + Map AllocIndex LLVMAllocSpec {- ^ allocation types -} -> Map AllocIndex Crucible.Ident {- ^ allocation type names -} -> - SetupValue (LLVM arch) {- ^ base pointer -} -> + SetupValue (LLVM arch) {- ^ base pointer -} -> Int {- ^ element index -} -> - m Crucible.Bytes {- ^ element offset -} -resolveSetupElemIndexOrFail cc env nameEnv v i = do + Except String Crucible.Bytes {- ^ element offset -} +resolveSetupElemOffset cc env nameEnv v i = do do memTy <- typeOfSetupValue cc env nameEnv v let msg = "resolveSetupVal: llvm_elem requires pointer to struct or array, found " ++ show memTy case memTy of @@ -438,10 +573,10 @@ resolveSetupElemIndexOrFail cc env nameEnv v i = do Crucible.StructType si -> case Crucible.siFieldOffset si i of Just d -> return d - Nothing -> fail $ "resolveSetupVal: struct type index out of bounds: " ++ show (i, memTy') - _ -> fail msg - Left err -> fail $ unlines [msg, "Details:", err] - _ -> fail msg + Nothing -> throwError $ "resolveSetupVal: struct type index out of bounds: " ++ show (i, memTy') + _ -> throwError msg + Left err -> throwError $ unlines [msg, "Details:", err] + _ -> throwError msg where lc = ccTypeCtx cc dl = Crucible.llvmDataLayout lc @@ -455,7 +590,7 @@ newtype W4EvalTactic = W4EvalTactic { doW4Eval :: Bool } deriving (Eq, Ord, Show) -- | Translate a SetupValue into a Crucible LLVM value, resolving --- references +-- references. resolveSetupVal :: forall arch. (?w4EvalTactic :: W4EvalTactic, Crucible.HasPtrWidth (Crucible.ArchWidth arch)) => LLVMCrucibleContext arch -> @@ -479,6 +614,9 @@ resolveSetupVal cc mem env tyenv nameEnv val = -- NB, SetupCast values should always be pointers. Pointer casts have no -- effect on the actual computed LLVMVal. SetupCast () v _lty -> resolveSetupVal cc mem env tyenv nameEnv v + -- NB, SetupUnion values should always be pointers. Pointer casts have no + -- effect on the actual computed LLVMVal. + SetupUnion () v _n -> resolveSetupVal cc mem env tyenv nameEnv v SetupStruct () packed vs -> do vals <- mapM (resolveSetupVal cc mem env tyenv nameEnv) vs let tps = map Crucible.llvmValStorableType vals @@ -493,10 +631,19 @@ resolveSetupVal cc mem env tyenv nameEnv val = let tp = Crucible.llvmValStorableType (V.head vals) return $ Crucible.LLVMValArray tp vals SetupField () v n -> do - i <- resolveSetupFieldIndexOrFail cc tyenv nameEnv v n - resolveSetupVal cc mem env tyenv nameEnv (SetupElem () v i) + do fld <- exceptToFail $ + do info <- resolveSetupValueInfo cc tyenv nameEnv v + recoverStructFieldInfo cc tyenv nameEnv v info n + ptr <- resolveSetupVal cc mem env tyenv nameEnv v + case ptr of + Crucible.LLVMValInt blk off -> + do delta <- W4.bvLit sym (W4.bvWidth off) (Crucible.bytesToBV (W4.bvWidth off) (Crucible.fiOffset fld)) + off' <- W4.bvAdd sym off delta + return (Crucible.LLVMValInt blk off') + _ -> fail "resolveSetupVal: llvm_field requires pointer value" + SetupElem () v i -> - do delta <- resolveSetupElemIndexOrFail cc tyenv nameEnv v i + do delta <- exceptToFail (resolveSetupElemOffset cc tyenv nameEnv v i) ptr <- resolveSetupVal cc mem env tyenv nameEnv v case ptr of Crucible.LLVMValInt blk off -> @@ -547,21 +694,21 @@ resolveSetupValBitfield :: IO (BitfieldIndex, LLVMVal) resolveSetupValBitfield cc mem env tyenv nameEnv val fieldName = do let sym = cc^.ccSym - lval <- resolveSetupVal cc mem env tyenv nameEnv val - bfIndex <- resolveSetupBitfieldIndexOrFail cc tyenv nameEnv val fieldName - delta <- resolveSetupElemIndexOrFail cc tyenv nameEnv val (biBitfieldIndex bfIndex) + lval <- resolveSetupVal cc mem env tyenv nameEnv val + bfIndex <- exceptToFail (resolveSetupBitfield cc tyenv nameEnv val fieldName) + let delta = biFieldByteOffset bfIndex offsetLval <- case lval of - Crucible.LLVMValInt blk off -> - do deltaBV <- W4.bvLit sym (W4.bvWidth off) (Crucible.bytesToBV (W4.bvWidth off) delta) - off' <- W4.bvAdd sym off deltaBV - return (Crucible.LLVMValInt blk off') - _ -> fail "resolveSetupValBitfield: expected a pointer value" - pure (bfIndex, offsetLval) + Crucible.LLVMValInt blk off -> + do deltaBV <- W4.bvLit sym (W4.bvWidth off) (Crucible.bytesToBV (W4.bvWidth off) delta) + off' <- W4.bvAdd sym off deltaBV + return (Crucible.LLVMValInt blk off') + _ -> fail "resolveSetupValBitfield: expected a pointer value" + return (bfIndex, offsetLval) resolveTypedTerm :: (?w4EvalTactic :: W4EvalTactic, Crucible.HasPtrWidth (Crucible.ArchWidth arch)) => LLVMCrucibleContext arch -> - TypedTerm -> + TypedTerm -> IO LLVMVal resolveTypedTerm cc tm = case ttType tm of diff --git a/src/SAWScript/Crucible/LLVM/X86.hs b/src/SAWScript/Crucible/LLVM/X86.hs index 0d0a02c3be..e95786b426 100644 --- a/src/SAWScript/Crucible/LLVM/X86.hs +++ b/src/SAWScript/Crucible/LLVM/X86.hs @@ -22,9 +22,11 @@ Stability : provisional {-# Language ConstraintKinds #-} {-# Language GeneralizedNewtypeDeriving #-} {-# Language TemplateHaskell #-} +{-# Language ViewPatterns #-} module SAWScript.Crucible.LLVM.X86 ( llvm_verify_x86 + , llvm_verify_fixpoint_x86 , defaultStackBaseAlign ) where @@ -46,6 +48,7 @@ import qualified Data.Set as Set import Data.Text (Text) import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Time.Clock (getCurrentTime, diffUTCTime) +import qualified Data.List as List import qualified Data.Map as Map import Data.Map (Map) import Data.Maybe @@ -53,12 +56,15 @@ import Data.Maybe import qualified Text.LLVM.AST as LLVM import Data.Parameterized.Some +import qualified Data.Parameterized.Map as MapF import Data.Parameterized.NatRepr -import Data.Parameterized.Context hiding (view) +import Data.Parameterized.Context hiding (view, zipWithM) import Verifier.SAW.CryptolEnv import Verifier.SAW.FiniteValue import Verifier.SAW.Name (toShortName) +import Verifier.SAW.Prelude +import Verifier.SAW.Recognizer import Verifier.SAW.SharedTerm import Verifier.SAW.TypedTerm @@ -113,6 +119,7 @@ import qualified Lang.Crucible.LLVM.Extension as C.LLVM import qualified Lang.Crucible.LLVM.Intrinsics as C.LLVM import qualified Lang.Crucible.LLVM.MemModel as C.LLVM import qualified Lang.Crucible.LLVM.MemType as C.LLVM +import qualified Lang.Crucible.LLVM.SimpleLoopFixpoint as Crucible.LLVM.Fixpoint import qualified Lang.Crucible.LLVM.Translation as C.LLVM import qualified Lang.Crucible.LLVM.TypeContext as C.LLVM @@ -297,7 +304,36 @@ llvm_verify_x86 :: LLVMCrucibleSetupM () {- ^ Specification to verify against -} -> ProofScript () {- ^ Tactic used to use when discharging goals -} -> TopLevel (SomeLLVM MS.ProvedSpec) -llvm_verify_x86 (Some (llvmModule :: LLVMModule x)) path nm globsyms checkSat setup tactic +llvm_verify_x86 llvmModule path nm globsyms checkSat = + llvm_verify_x86_common llvmModule path nm globsyms checkSat Nothing + +-- | Verify that an x86_64 function (following the System V AMD64 ABI) conforms +-- to an LLVM specification. This allows for compositional verification of LLVM +-- functions that call x86_64 functions (but not the other way around). +llvm_verify_fixpoint_x86 :: + Some LLVMModule {- ^ Module to associate with method spec -} -> + FilePath {- ^ Path to ELF file -} -> + String {- ^ Function's symbol in ELF file -} -> + [(String, Integer)] {- ^ Global variable symbol names and sizes (in bytes) -} -> + Bool {- ^ Whether to enable path satisfiability checking -} -> + TypedTerm {- ^ Function specifying the loop -} -> + LLVMCrucibleSetupM () {- ^ Specification to verify against -} -> + ProofScript () {- ^ Tactic used to use when discharging goals -} -> + TopLevel (SomeLLVM MS.ProvedSpec) +llvm_verify_fixpoint_x86 llvmModule path nm globsyms checkSat f = + llvm_verify_x86_common llvmModule path nm globsyms checkSat (Just f) + +llvm_verify_x86_common :: + Some LLVMModule {- ^ Module to associate with method spec -} -> + FilePath {- ^ Path to ELF file -} -> + String {- ^ Function's symbol in ELF file -} -> + [(String, Integer)] {- ^ Global variable symbol names and sizes (in bytes) -} -> + Bool {- ^ Whether to enable path satisfiability checking -} -> + Maybe TypedTerm -> + LLVMCrucibleSetupM () {- ^ Specification to verify against -} -> + ProofScript () {- ^ Tactic used to use when discharging goals -} -> + TopLevel (SomeLLVM MS.ProvedSpec) +llvm_verify_x86_common (Some (llvmModule :: LLVMModule x)) path nm globsyms checkSat maybeFixpointFunc setup tactic | Just Refl <- testEquality (C.LLVM.X86Repr $ knownNat @64) . C.LLVM.llvmArch $ modTrans llvmModule ^. C.LLVM.transContext = do start <- io getCurrentTime @@ -460,7 +496,14 @@ llvm_verify_x86 (Some (llvmModule :: LLVMModule x)) path nm globsyms checkSat se else pure [] - let execFeatures = psatf + simpleLoopFixpointFeature <- + case maybeFixpointFunc of + Nothing -> return [] + Just func -> + do f <- liftIO (setupSimpleLoopFixpointFeature sym sc sawst cfg mvar func) + return [f] + + let execFeatures = simpleLoopFixpointFeature ++ psatf liftIO $ C.executeCrucible execFeatures initial >>= \case C.FinishedResult{} -> pure () @@ -484,6 +527,79 @@ llvm_verify_x86 (Some (llvmModule :: LLVMModule x)) path nm globsyms checkSat se | otherwise = fail "LLVM module must be 64-bit" + + +setupSimpleLoopFixpointFeature :: + ( sym ~ W4.B.ExprBuilder n st fs + , C.IsSymInterface sym + , ?memOpts::C.LLVM.MemOptions + , C.LLVM.HasLLVMAnn sym + ) => + sym -> + SharedContext -> + SAWCoreState n -> + C.CFG ext blocks init ret -> + C.GlobalVar C.LLVM.Mem -> + TypedTerm -> + IO (C.ExecutionFeature p sym ext rtp) + +setupSimpleLoopFixpointFeature sym sc sawst cfg mvar func = + Crucible.LLVM.Fixpoint.simpleLoopFixpoint sym cfg mvar fixpoint_func + + where + fixpoint_func fixpoint_substitution condition = + do let fixpoint_substitution_as_list = reverse $ MapF.toList fixpoint_substitution + let body_exprs = map (mapSome $ Crucible.LLVM.Fixpoint.bodyValue) (MapF.elems fixpoint_substitution) + let uninterpreted_constants = foldMap + (viewSome $ Set.map (mapSome $ W4.varExpr sym) . W4.exprUninterpConstants sym) + (Some condition : body_exprs) + let filtered_uninterpreted_constants = Set.toList $ Set.filter + (\(Some variable) -> + not (List.isPrefixOf "creg_join_var" $ show $ W4.printSymExpr variable) + && not (List.isPrefixOf "cmem_join_var" $ show $ W4.printSymExpr variable) + && not (List.isPrefixOf "cundefined" $ show $ W4.printSymExpr variable) + && not (List.isPrefixOf "calign_amount" $ show $ W4.printSymExpr variable)) + uninterpreted_constants + body_tms <- mapM (viewSome $ toSC sym sawst) filtered_uninterpreted_constants + implicit_parameters <- mapM (scExtCns sc) $ Set.toList $ foldMap getAllExtSet body_tms + + arguments <- forM fixpoint_substitution_as_list $ \(MapF.Pair _ fixpoint_entry) -> + toSC sym sawst $ Crucible.LLVM.Fixpoint.headerValue fixpoint_entry + applied_func <- scApplyAll sc (ttTerm func) $ implicit_parameters ++ arguments + applied_func_selectors <- forM [1 .. (length fixpoint_substitution_as_list)] $ \i -> + scTupleSelector sc applied_func i (length fixpoint_substitution_as_list) + result_substitution <- MapF.fromList <$> zipWithM + (\(MapF.Pair variable _) applied_func_selector -> + MapF.Pair variable <$> bindSAWTerm sym sawst (W4.exprType variable) applied_func_selector) + fixpoint_substitution_as_list + applied_func_selectors + + explicit_parameters <- forM fixpoint_substitution_as_list $ \(MapF.Pair variable _) -> + toSC sym sawst variable + inner_func <- case asConstant (ttTerm func) of + Just (_, Just (asApplyAll -> (isGlobalDef "Prelude.fix" -> Just (), [_, inner_func]))) -> + return inner_func + _ -> fail $ "not Prelude.fix: " ++ showTerm (ttTerm func) + func_body <- betaNormalize sc + =<< scApplyAll sc inner_func ((ttTerm func) : (implicit_parameters ++ explicit_parameters)) + + step_arguments <- forM fixpoint_substitution_as_list $ \(MapF.Pair _ fixpoint_entry) -> + toSC sym sawst $ Crucible.LLVM.Fixpoint.bodyValue fixpoint_entry + tail_applied_func <- scApplyAll sc (ttTerm func) $ implicit_parameters ++ step_arguments + explicit_parameters_tuple <- scTuple sc explicit_parameters + let lhs = Prelude.last step_arguments + w <- scNat sc 64 + rhs <- scBvMul sc w (head implicit_parameters) =<< scBvNat sc w =<< scNat sc 128 + loop_condition <- scBvULt sc w lhs rhs + output_tuple_type <- scTupleType sc =<< mapM (scTypeOf sc) explicit_parameters + loop_body <- scIte sc output_tuple_type loop_condition tail_applied_func explicit_parameters_tuple + + induction_step_condition <- scEq sc loop_body func_body + result_condition <- bindSAWTerm sym sawst W4.BaseBoolRepr induction_step_condition + + return (result_substitution, result_condition) + + -------------------------------------------------------------------------------- -- ** Computing the CFG @@ -863,8 +979,8 @@ setArgs env tyenv nameEnv args cc <- use x86CrucibleContext mem <- use x86Mem let - setRegSetupValue rs (reg, sval) = typeOfSetupValue cc tyenv nameEnv sval >>= \ty -> - case ty of + setRegSetupValue rs (reg, sval) = + exceptToFail (typeOfSetupValue cc tyenv nameEnv sval) >>= \case C.LLVM.PtrType _ -> do val <- C.LLVM.unpackMemValue sym (C.LLVM.LLVMPointerRepr $ knownNat @64) =<< resolveSetupVal cc mem env tyenv nameEnv sval diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index 7df26f6d31..93dc3e3118 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -67,6 +67,7 @@ import SAWScript.Value import SAWScript.Proof (newTheoremDB) import SAWScript.Prover.Rewrite(basic_ss) import SAWScript.Prover.Exporter +import SAWScript.Prover.MRSolver (emptyMREnv) import Verifier.SAW.Conversion --import Verifier.SAW.PrettySExp import Verifier.SAW.Prim (rethrowEvalError) @@ -474,6 +475,7 @@ buildTopLevelEnv proxy opts = , rwDocs = primDocEnv primsAvail , rwCryptol = ce0 , rwMonadify = Monadify.defaultMonEnv + , rwMRSolverEnv = emptyMREnv , rwProofs = [] , rwPPOpts = SAWScript.Value.defaultPPOpts , rwJVMTrans = jvmTrans @@ -1071,10 +1073,15 @@ primitives = Map.fromList , "object that can be passed to 'read_sbv'." ] + , prim "is_convertible" "Term -> Term -> TopLevel Bool" + (pureVal isConvertiblePrim) + Current + [ "Returns true iff the two terms are convertible." ] + , prim "check_convertible" "Term -> Term -> TopLevel ()" (pureVal checkConvertiblePrim) Current - [ "Check if two terms are convertible." ] + [ "Check if two terms are convertible and print the result." ] , prim "replace" "Term -> Term -> Term -> TopLevel Term" (pureVal replacePrim) @@ -1871,6 +1878,20 @@ primitives = Map.fromList , "Cryptol source files." ] + , prim "cryptol_add_prim" "String -> String -> Term -> TopLevel ()" + (pureVal cryptol_add_prim) + Experimental + [ "cryptol_add_prim mod nm trm sets the translation of Cryptol primitive" + , "nm in module mod to trm" + ] + + , prim "cryptol_add_prim_type" "String -> String -> Term -> TopLevel ()" + (pureVal cryptol_add_prim_type) + Experimental + [ "cryptol_add_prim_type mod nm tp sets the translation of Cryptol" + , "primitive type nm in module mod to tp" + ] + -- Java stuff , prim "java_bool" "JavaType" @@ -2225,6 +2246,13 @@ primitives = Map.fromList [ "Parse a Term from a String in SAWCore syntax." ] + , prim "parse_core_mod" "String -> String -> Term" + (funVal2 parse_core_mod) + Current + [ "Parse a Term from the second supplied String in SAWCore syntax," + , "relative to the module specified by the first String" + ] + , prim "prove_core" "ProofScript () -> String -> TopLevel Theorem" (pureVal prove_core) Current @@ -2247,6 +2275,13 @@ primitives = Map.fromList Current [ "Create a theorem from the type of the given core expression." ] + , prim "specialize_theorem" "Theorem -> [Term] -> TopLevel Theorem" + (pureVal specialize_theorem) + Experimental + [ "Specialize a theorem by instantiating universal quantifiers" + , "with the given list of terms." + ] + , prim "get_opt" "Int -> String" (funVal1 get_opt) Current @@ -2693,6 +2728,15 @@ primitives = Map.fromList Experimental [ "Legacy alternative name for `llvm_verify_x86`." ] + , prim "llvm_verify_fixpoint_x86" + "LLVMModule -> String -> String -> [(String, Int)] -> Bool -> Term -> LLVMSetup () -> ProofScript () -> TopLevel LLVMSpec" + (pureVal llvm_verify_fixpoint_x86) + Experimental + [ "An experimental variant of 'llvm_verify_x86'. This variant can prove some properties" + , "involving simple loops with the help of a user-provided term that describes how" + , "the live variables in the loop evolve as the loop computes." + ] + , prim "enable_x86_what4_hash_consing" "TopLevel ()" (pureVal enable_x86_what4_hash_consing) Experimental @@ -2815,12 +2859,23 @@ primitives = Map.fromList Current [ "Legacy alternative name for `llvm_elem`." ] + , prim "llvm_union" + "SetupValue -> String -> SetupValue" + (pureVal CIR.anySetupUnion) + Current + [ "Turn a SetupValue representing a union pointer into" + , "a pointer to one of the branches of the union by field name." + , "Requires debug symbols to resolve union field names." + ] + , prim "llvm_field" "SetupValue -> String -> SetupValue" (pureVal CIR.anySetupField) Current [ "Turn a SetupValue representing a struct pointer into" - , "a pointer to an element of the struct by field name." ] + , "a pointer to an element of the struct by field name." + , "Requires debug symbols to resolve struct field names." + ] , prim "crucible_field" "SetupValue -> String -> SetupValue" (pureVal CIR.anySetupField) diff --git a/src/SAWScript/Proof.hs b/src/SAWScript/Proof.hs index 906c92ddee..75dc6d81b2 100644 --- a/src/SAWScript/Proof.hs +++ b/src/SAWScript/Proof.hs @@ -9,6 +9,7 @@ Stability : provisional {-# LANGUAGE BlockArguments #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ParallelListComp #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} @@ -22,6 +23,7 @@ module SAWScript.Proof , betaReduceProp , falseProp , termToProp + , termToMaybeProp , propToTerm , propToRewriteRule , propSize @@ -52,6 +54,7 @@ module SAWScript.Proof , proofByTerm , constructTheorem , validateTheorem + , specializeTheorem , Evidence(..) , checkEvidence @@ -112,7 +115,7 @@ import Verifier.SAW.TypedAST import Verifier.SAW.TypedTerm import Verifier.SAW.FiniteValue (FirstOrderValue) import Verifier.SAW.Term.Pretty (SawDoc) -import Verifier.SAW.SCTypeCheck (scTypeCheckError) +import qualified Verifier.SAW.SCTypeCheck as TC import Verifier.SAW.Simulator.Concrete (evalSharedTerm) import Verifier.SAW.Simulator.Value (asFirstOrderTypeValue, Value(..), TValue(..)) @@ -141,10 +144,9 @@ unProp (Prop tm) = tm -- is a sort. termToProp :: SharedContext -> Term -> IO Prop termToProp sc tm = - do mmap <- scGetModuleMap sc - ty <- scTypeOf sc tm - case evalSharedTerm mmap mempty mempty ty of - TValue (VSort s) | s == propSort -> return (Prop tm) + do ty <- scWhnf sc =<< scTypeOf sc tm + case asSort ty of + Just s | s == propSort -> return (Prop tm) _ -> case asLambda tm of Just _ -> @@ -155,6 +157,15 @@ termToProp sc tm = Nothing -> fail $ unlines [ "termToProp: Term is not a proposition", showTerm tm, showTerm ty ] +-- | Turn a saw-core term into a proposition under the type-as-propositions +-- regime. The given term must be a type, which means that its own type +-- is a sort. If it is not, return @Nothing@. +termToMaybeProp :: SharedContext -> Term -> IO (Maybe Prop) +termToMaybeProp sc tm = + do ty <- scWhnf sc =<< scTypeOf sc tm + case asSort ty of + Just s | s == propSort -> return (Just (Prop tm)) + _ -> return Nothing -- | Turn a boolean-valued saw-core term into a proposition by asserting -- that it is equal to the true boolean value. Generalize the proposition @@ -231,7 +242,7 @@ evalProp sc unints (Prop p) = body' <- case asEqTrue body of Just t -> pure t - Nothing -> fail "goal_eval: expected EqTrue" + Nothing -> fail ("goal_eval: expected EqTrue\n" ++ scPrettyTerm defaultPPOpts p) ecs <- traverse (\(nm, ty) -> scFreshEC sc nm ty) args vars <- traverse (scExtCns sc) ecs @@ -341,6 +352,7 @@ reachableTheorems db roots = | otherwise = panic "reachableTheorems" ["Could not find theorem with identifier", show (indexValue curr)] + -- | Check that the purported theorem is valid. -- -- This checks that the given theorem object does not correspond @@ -415,10 +427,12 @@ data Evidence | SplitEvidence Evidence Evidence -- | This type of evidence is produced when a previously-proved theorem is - -- applied via backward reasoning to prove a goal. Some of the hypotheses - -- of the theorem may be discharged via the included list of evidence, and - -- then the proposition must match the conclusion of the theorem. - | ApplyEvidence Theorem [Evidence] + -- applied via backward reasoning to prove a goal. Pi-quantified variables + -- of the theorem may be specialized either by giving an explicit @Term@ to + -- instantiate the variable, or by giving @Evidence@ for @Prop@ hypotheses. + -- After specializing the given @Theorem@ the result must match the + -- current goal. + | ApplyEvidence Theorem [Either Term Evidence] -- | This type of evidence is used to prove an implication. The included -- proposition must match the hypothesis of the goal, and the included @@ -572,6 +586,29 @@ constructTheorem sc db p e loc ploc rsn elapsed = , _thmSummary = sy } + +-- | Given a theorem with quantified variables, build a new theorem that +-- specializes the leading quantifiers with the given terms. +-- This will fail if the given terms to not match the quantifier structure +-- of the given theorem. +specializeTheorem :: SharedContext -> TheoremDB -> Pos -> Text -> Theorem -> [Term] -> IO Theorem +specializeTheorem _sc _db _loc _rsn thm [] = return thm +specializeTheorem sc db loc rsn thm ts0 = + do let p0 = unProp (_thmProp thm) + res <- TC.runTCM (loop p0 ts0) sc Nothing [] + case res of + Left err -> fail (unlines (["specialize_theorem: failed to specialize"] ++ TC.prettyTCError err)) + Right p' -> + constructTheorem sc db (Prop p') (ApplyEvidence thm (map Left ts0)) loc Nothing rsn 0 + + where + loop p [] = return p + loop p (t:ts) = + do prop <- liftIO (scSort sc propSort) + t' <- TC.typeInferComplete t + p' <- TC.applyPiTyped (TC.NotFuncTypeInApp (TC.TypedTerm p prop) t') p t' + loop p' ts + -- | Admit the given theorem without evidence. -- The provided message allows the user to -- explain why this proposition is being admitted. @@ -688,14 +725,19 @@ psStats :: ProofState -> SolverStats psStats = _psStats -- | Verify that the given evidence in fact supports the given proposition. --- Returns the identifers of all the theorems depened on while checking evidence. +-- Returns the identifers of all the theorems depended on while checking evidence. checkEvidence :: SharedContext -> TheoremDB -> Evidence -> Prop -> IO (Set TheoremNonce, TheoremSummary) checkEvidence sc db = \e p -> do hyps <- Map.keysSet <$> readIORef (theoremMap db) check hyps e p where checkApply _hyps (Prop p) [] = return (mempty, mempty, p) - checkApply hyps (Prop p) (e:es) + + -- Check a theorem applied to "Evidence". + -- The given prop must be an implication + -- (i.e., nondependent Pi quantifying over a Prop) + -- and the given evidence must match the expected prop. + checkApply hyps (Prop p) (Right e:es) | Just (_lnm, tp, body) <- asPi p , looseVars body == emptyBitSet = do (d1,sy1) <- check hyps e =<< termToProp sc tp @@ -706,6 +748,18 @@ checkEvidence sc db = \e p -> do hyps <- Map.keysSet <$> readIORef (theoremMap d , showTerm p ] + -- Check a theorem applied to a term. This explicity instantiates + -- a Pi binder with the given term. + checkApply hyps (Prop p) (Left tm:es) = + do propTerm <- scSort sc propSort + let m = do tm' <- TC.typeInferComplete tm + let err = TC.NotFuncTypeInApp (TC.TypedTerm p propTerm) tm' + TC.applyPiTyped err p tm' + res <- TC.runTCM m sc Nothing [] + case res of + Left msg -> fail (unlines (TC.prettyTCError msg)) + Right p' -> checkApply hyps (Prop p') es + checkTheorem :: Set TheoremNonce -> Theorem -> IO () checkTheorem hyps (LocalAssumption p loc n) = unless (Set.member n hyps) $ fail $ unlines @@ -722,7 +776,7 @@ checkEvidence sc db = \e p -> do hyps <- Map.keysSet <$> readIORef (theoremMap d IO (Set TheoremNonce, TheoremSummary) check hyps e p@(Prop ptm) = case e of ProofTerm tm -> - do ty <- scTypeCheckError sc tm + do ty <- TC.scTypeCheckError sc tm ok <- scConvertible sc True ptm ty unless ok $ fail $ unlines [ "Proof term does not prove the required proposition" @@ -1003,34 +1057,44 @@ propToSATQuery sc unintSet prop = Just fot -> filterFirstOrderVars mmap (Map.insert e fot fovars) absvars es processTerm mmap vars xs tm = - case asPi tm of - Just (lnm, tp, body) - | Just x <- asEqTrue tp - , looseVars body == emptyBitSet -> - do processTerm mmap vars (x:xs) body - - -- TODO? Allow universal hypotheses... - - | otherwise -> - case evalFOT mmap tp of - Nothing -> fail ("propToSATQuery: expected first order type: " ++ showTerm tp) - Just fot -> - do ec <- scFreshEC sc lnm tp - etm <- scExtCns sc ec - body' <- instantiateVar sc 0 etm body - processTerm mmap (Map.insert ec fot vars) xs body' - - Nothing -> - case asEqTrue tm of - Nothing -> fail $ "propToSATQuery: expected EqTrue, actual " ++ showTerm tm - Just tmBool -> - do tmNeg <- scNot sc tmBool - return (vars, reverse (tmNeg:xs)) + do -- TODO: I would like to WHNF here, but that evalutes too aggressively + -- because scWhnf evaluates strictly through the `Eq` datatype former. + -- This breaks some proof examples by unfolding things that need to + -- be uninterpreted. + -- tm' <- scWhnf sc tm + let tm' = tm + + case asPi tm' of + Just (lnm, tp, body) -> + do -- same issue with WHNF + -- tp' <- scWhnf sc tp + let tp' = tp + case asEqTrue tp' of + Just x | looseVars body == emptyBitSet -> + processTerm mmap vars (x:xs) body + + -- TODO? Allow universal hypotheses... + + _ -> + case evalFOT mmap tp' of + Nothing -> fail ("propToSATQuery: expected first order type: " ++ showTerm tp') + Just fot -> + do ec <- scFreshEC sc lnm tp' + etm <- scExtCns sc ec + body' <- instantiateVar sc 0 etm body + processTerm mmap (Map.insert ec fot vars) xs body' + + Nothing -> + case asEqTrue tm' of + Nothing -> fail $ "propToSATQuery: expected EqTrue, actual " ++ showTerm tm' + Just tmBool -> + do tmNeg <- scNot sc tmBool + return (vars, reverse (tmNeg:xs)) -- | Given a goal to prove, attempt to apply the given proposition, producing -- new subgoals for any necessary hypotheses of the proposition. Returns -- @Nothing@ if the given proposition does not apply to the goal. -goalApply :: SharedContext -> Prop-> ProofGoal -> IO (Maybe [ProofGoal]) +goalApply :: SharedContext -> Prop -> ProofGoal -> IO (Maybe [Either Term Prop]) goalApply sc rule goal = applyFirst (asPiLists (unProp rule)) where @@ -1042,17 +1106,22 @@ goalApply sc rule goal = applyFirst (asPiLists (unProp rule)) Just inst -> do let inst' = [ Map.lookup i inst | i <- take (length ruleArgs) [0..] ] dummy <- scUnitType sc - let mkNewGoals (Nothing : mts) ((_, prop) : args) = + let mkNewGoals (Nothing : mts) ((nm, prop) : args) = do c0 <- instantiateVarList sc 0 (map (fromMaybe dummy) mts) prop - cs <- mkNewGoals mts args - return (Prop c0 : cs) - mkNewGoals (Just _ : mts) (_ : args) = - mkNewGoals mts args + mp <- termToMaybeProp sc c0 + case mp of + Nothing -> + fail ("goal_apply: could not find instantiation for " ++ show nm) + Just p -> + do cs <- mkNewGoals mts args + return (Right p : cs) + mkNewGoals (Just tm : mts) (_ : args) = + do cs <- mkNewGoals mts args + return (Left tm : cs) mkNewGoals _ _ = return [] + newgoalterms <- mkNewGoals inst' (reverse ruleArgs) - -- TODO, change the "ty" field to list the hypotheses? - let newgoals = reverse [ goal { goalProp = t } | t <- newgoalterms ] - return (Just newgoals) + return (Just (reverse newgoalterms)) asPiLists :: Term -> [([(Text, Term)], Term)] asPiLists t = @@ -1112,8 +1181,20 @@ tacticApply :: (F.MonadFail m, MonadIO m) => SharedContext -> Theorem -> Tactic tacticApply sc thm = Tactic \goal -> liftIO (goalApply sc (thmProp thm) goal) >>= \case Nothing -> fail "apply tactic failed: no match" - Just newgoals -> - return ((), mempty, newgoals, pure . ApplyEvidence thm) + Just newterms -> + let newgoals = + [ goal{ goalProp = p, goalType = goalType goal ++ ".subgoal" ++ show i } + | Right p <- newterms + | i <- [0::Integer ..] + ] in + return ((), mempty, newgoals, \es -> ApplyEvidence thm <$> processEvidence newterms es) + + where + processEvidence :: [Either Term Prop] -> [Evidence] -> IO [Either Term Evidence] + processEvidence (Left tm : xs) es = (Left tm :) <$> processEvidence xs es + processEvidence (Right _ : xs) (e:es) = (Right e :) <$> processEvidence xs es + processEvidence [] [] = pure [] + processEvidence _ _ = fail "apply tactic failed: evidence mismatch" -- | Attempt to simplify a goal by splitting it along conjunctions. If successful, -- two subgoals will be produced, representing the two conjuncts to be proved. @@ -1133,7 +1214,7 @@ tacticTrivial sc = Tactic \goal -> Left err -> fail err Right pf -> do let gp = unProp (goalProp goal) - ty <- liftIO $ scTypeCheckError sc pf + ty <- liftIO $ TC.scTypeCheckError sc pf ok <- liftIO $ scConvertible sc True gp ty unless ok $ fail $ unlines [ "The trivial tactic cannot prove this equality" @@ -1144,7 +1225,7 @@ tacticTrivial sc = Tactic \goal -> tacticExact :: (F.MonadFail m, MonadIO m) => SharedContext -> Term -> Tactic m () tacticExact sc tm = Tactic \goal -> do let gp = unProp (goalProp goal) - ty <- liftIO $ scTypeCheckError sc tm + ty <- liftIO $ TC.scTypeCheckError sc tm ok <- liftIO $ scConvertible sc True gp ty unless ok $ fail $ unlines [ "Proof term does not prove the required proposition" diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index d29bfa0617..759116dedf 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -1,1548 +1,17 @@ -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeSynonymInstances #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TupleSections #-} - {- | Module : SAWScript.Prover.MRSolver -Copyright : Galois, Inc. 2021 +Description : The SAW monadic-recursive solver (Mr. Solver) +Copyright : Galois, Inc. 2022 License : BSD3 Maintainer : westbrook@galois.com Stability : experimental Portability : non-portable (language extensions) - -This module implements a monadic-recursive solver, for proving that one monadic -term refines another. The algorithm works on the "monadic normal form" of -computations, which uses the following laws to simplify binds in computations, -where either is the sum elimination function defined in the SAW core prelude: - -returnM x >>= k = k x -errorM str >>= k = errorM -(m >>= k1) >>= k2 = m >>= \x -> k1 x >>= k2 -(existsM f) >>= k = existsM (\x -> f x >>= k) -(forallM f) >>= k = forallM (\x -> f x >>= k) -(orM m1 m2) >>= k = orM (m1 >>= k) (m2 >>= k) -(if b then m1 else m2) >>= k = if b then m1 >>= k else m2 >>1 k -(either f1 f2 e) >>= k = either (\x -> f1 x >= k) (\x -> f2 x >= k) e -(letrecM funs body) >>= k = letrecM funs (\F1 ... Fn -> body F1 ... Fn >>= k) - -The resulting computations of one of the following forms: - -returnM e | errorM str | existsM f | forallM f | orM m1 m2 | -if b then m1 else m2 | either f1 f2 e | F e1 ... en | F e1 ... en >>= k | -letrecM lrts B (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> m) - -The form F e1 ... en refers to a recursively-defined function or a function -variable that has been locally bound by a letrecM. Either way, monadic -normalization does not attempt to normalize these functions. - -The algorithm maintains a context of three sorts of variables: letrec-bound -variables, existential variables, and universal variables. Universal variables -are represented as free SAW core variables, while the other two forms of -variable are represented as SAW core 'ExtCns's terms, which are essentially -axioms that have been generated internally. These 'ExtCns's are Skolemized, -meaning that they take in as arguments all universal variables that were in -scope when they were created. The context also maintains a partial substitution -for the existential variables, as they become instantiated with values, and it -additionally remembers the bodies / unfoldings of the letrec-bound variables. - -The goal of the solver at any point is of the form C |- m1 |= m2, meaning that -we are trying to prove m1 refines m2 in context C. This proceed by cases: - -C |- returnM e1 |= returnM e2: prove C |- e1 = e2 - -C |- errorM str1 |= errorM str2: vacuously true - -C |- if b then m1' else m1'' |= m2: prove C,b=true |- m1' |= m2 and -C,b=false |- m1'' |= m2, skipping either case where C,b=X is unsatisfiable; - -C |- m1 |= if b then m2' else m2'': similar to the above - -C |- either T U (CompM V) f1 f2 e |= m: prove C,x:T,e=inl x |- f1 x |= m and -C,y:U,e=inl y |- f2 y |= m, again skippping any case with unsatisfiable context; - -C |- m |= either T U (CompM V) f1 f2 e: similar to previous - -C |- m |= forallM f: make a new universal variable x and recurse - -C |- existsM f |= m: make a new universal variable x and recurse (existential -elimination uses universal variables and vice-versa) - -C |- m |= existsM f: make a new existential variable x and recurse - -C |- forall f |= m: make a new existential variable x and recurse - -C |- m |= orM m1 m2: try to prove C |- m |= m1, and if that fails, backtrack and -prove C |- m |= m2 - -C |- orM m1 m2 |= m: prove both C |- m1 |= m and C |- m2 |= m - -C |- letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body) |= m: create -letrec-bound variables F1 through Fn in the context bound to their unfoldings f1 -through fn, respectively, and recurse on body |= m - -C |- m |= letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body): similar to -previous case - -C |- F e1 ... en >>= k |= F e1' ... en' >>= k': prove C |- ei = ei' for each i -and then prove k x |= k' x for new universal variable x - -C |- F e1 ... en >>= k |= F' e1' ... em' >>= k': - -* If we have an assumption that forall x1 ... xj, F a1 ... an |= F' a1' .. am', - prove ei = ai and ei' = ai' and then that C |- k x |= k' x for fresh uvar x - -* If we have an assumption that forall x1, ..., xn, F e1'' ... en'' |= m' for - some ei'' and m', match the ei'' against the ei by instantiating the xj with - fresh evars, and if this succeeds then recursively prove C |- m' >>= k |= RHS - -(We don't do this one right now) -* If we have an assumption that forall x1', ..., xn', m |= F e1'' ... en'' for - some ei'' and m', match the ei'' against the ei by instantiating the xj with - fresh evars, and if this succeeds then recursively prove C |- LHS |= m' >>= k' - -* If either side is a definition whose unfolding does not contain letrecM, fixM, - or any related operations, unfold it - -* If F and F' have the same return type, add an assumption forall uvars in scope - that F e1 ... en |= F' e1' ... em' and unfold both sides, recursively proving - that F_body e1 ... en |= F_body' e1' ... em'. Then also prove k x |= k' x for - fresh uvar x. - -* Otherwise we don't know to "split" one of the sides into a bind whose - components relate to the two components on the other side, so just fail -} module SAWScript.Prover.MRSolver - (askMRSolver, MRFailure(..), showMRFailure, isCompFunType - , SBV.SMTConfig - , SBV.z3, SBV.cvc4, SBV.yices, SBV.mathSAT, SBV.boolector - ) where - -import Data.List (find, findIndex) -import qualified Data.Text as T -import Data.IORef -import System.IO (hPutStrLn, stderr) -import Control.Monad.Reader -import Control.Monad.State -import Control.Monad.Except -import Control.Monad.Trans.Maybe - -import qualified Data.IntMap as IntMap -import Data.Map (Map) -import qualified Data.Map as Map - -import Prettyprinter - -import Verifier.SAW.Term.Functor -import Verifier.SAW.Term.CtxTerm (MonadTerm(..)) -import Verifier.SAW.Term.Pretty -import Verifier.SAW.SCTypeCheck -import Verifier.SAW.SharedTerm -import Verifier.SAW.Recognizer -import Verifier.SAW.Cryptol.Monadify - -import SAWScript.Proof (termToProp) -import qualified SAWScript.Prover.SBV as SBV - - ----------------------------------------------------------------------- --- * Utility Functions for Transforming 'Term's ----------------------------------------------------------------------- - --- | Transform the immediate subterms of a term using the supplied function -traverseSubterms :: MonadTerm m => (Term -> m Term) -> Term -> m Term -traverseSubterms f (unwrapTermF -> tf) = traverse f tf >>= mkTermF - --- | Build a recursive memoized function for tranforming 'Term's. Take in a --- function @f@ that intuitively performs one step of the transformation and --- allow it to recursively call the memoized function being defined by passing --- it as the first argument to @f@. -memoFixTermFun :: MonadIO m => ((Term -> m a) -> Term -> m a) -> Term -> m a -memoFixTermFun f term_top = - do table_ref <- liftIO $ newIORef IntMap.empty - let go t@(STApp { stAppIndex = ix }) = - liftIO (readIORef table_ref) >>= \table -> - case IntMap.lookup ix table of - Just ret -> return ret - Nothing -> - do ret <- f go t - liftIO $ modifyIORef' table_ref (IntMap.insert ix ret) - return ret - go t = f go t - go term_top - --- | Recursively test if a 'Term' contains @letRecM@ -_containsLetRecM :: Term -> Bool -_containsLetRecM (asGlobalDef -> Just "Prelude.letRecM") = True -_containsLetRecM (unwrapTermF -> tf) = any _containsLetRecM tf - - ----------------------------------------------------------------------- --- * MR Solver Term Representation ----------------------------------------------------------------------- - --- | A variable used by the MR solver -newtype MRVar = MRVar { unMRVar :: ExtCns Term } deriving (Eq, Show, Ord) - --- | Get the type of an 'MRVar' -mrVarType :: MRVar -> Term -mrVarType = ecType . unMRVar - --- | Names of functions to be used in computations, which are either names bound --- by letrec to for recursive calls to fixed-points, existential variables, or --- global named constants -data FunName - = LetRecName MRVar | EVarFunName MRVar | GlobalName GlobalDef - deriving (Eq, Ord, Show) - --- | Get the type of a 'FunName' -funNameType :: FunName -> Term -funNameType (LetRecName var) = mrVarType var -funNameType (EVarFunName var) = mrVarType var -funNameType (GlobalName gd) = globalDefType gd - --- | A term specifically known to be of type @sort i@ for some @i@ -newtype Type = Type Term deriving Show - --- | A Haskell representation of a @CompM@ in "monadic normal form" -data NormComp - = ReturnM Term -- ^ A term @returnM a x@ - | ErrorM Term -- ^ A term @errorM a str@ - | Ite Term Comp Comp -- ^ If-then-else computation - | Either CompFun CompFun Term -- ^ A sum elimination - | OrM Comp Comp -- ^ an @orM@ computation - | ExistsM Type CompFun -- ^ an @existsM@ computation - | ForallM Type CompFun -- ^ a @forallM@ computation - | FunBind FunName [Term] CompFun - -- ^ Bind a monadic function with @N@ arguments in an @a -> CompM b@ term - deriving Show - --- | A computation function of type @a -> CompM b@ for some @a@ and @b@ -data CompFun - -- | An arbitrary term - = CompFunTerm Term - -- | A special case for the term @\ (x:a) -> returnM a x@ - | CompFunReturn - -- | The monadic composition @f >=> g@ - | CompFunComp CompFun CompFun - deriving Show - --- | Compose two 'CompFun's, simplifying if one is a 'CompFunReturn' -compFunComp :: CompFun -> CompFun -> CompFun -compFunComp CompFunReturn f = f -compFunComp f CompFunReturn = f -compFunComp f g = CompFunComp f g - --- | If a 'CompFun' contains an explicit lambda-abstraction, then return the --- textual name bound by that lambda -compFunVarName :: CompFun -> Maybe LocalName -compFunVarName (CompFunTerm (asLambda -> Just (nm, _, _))) = Just nm -compFunVarName (CompFunComp f _) = compFunVarName f -compFunVarName _ = Nothing - --- | If a 'CompFun' contains an explicit lambda-abstraction, then return the --- input type for it -compFunInputType :: CompFun -> Maybe Type -compFunInputType (CompFunTerm (asLambda -> Just (_, tp, _))) = Just $ Type tp -compFunInputType (CompFunComp f _) = compFunInputType f -compFunInputType _ = Nothing - --- | A computation of type @CompM a@ for some @a@ -data Comp = CompTerm Term | CompBind Comp CompFun | CompReturn Term - deriving Show - - ----------------------------------------------------------------------- --- * Pretty-Printing MR Solver Terms ----------------------------------------------------------------------- - --- | The monad for pretty-printing in a context of SAW core variables -type PPInCtxM = Reader [LocalName] - --- | Pretty-print an object in a SAW core context and render to a 'String' -showInCtx :: PrettyInCtx a => [LocalName] -> a -> String -showInCtx ctx a = - renderSawDoc defaultPPOpts $ runReader (prettyInCtx a) ctx - --- | A generic function for pretty-printing an object in a SAW core context of --- locally-bound names -class PrettyInCtx a where - prettyInCtx :: a -> PPInCtxM SawDoc - -instance PrettyInCtx Term where - prettyInCtx t = flip (ppTermInCtx defaultPPOpts) t <$> ask - --- | Combine a list of pretty-printed documents that represent an application -prettyAppList :: [PPInCtxM SawDoc] -> PPInCtxM SawDoc -prettyAppList = fmap (group . hang 2 . vsep) . sequence - -instance PrettyInCtx Type where - prettyInCtx (Type t) = prettyInCtx t - -instance PrettyInCtx MRVar where - prettyInCtx (MRVar ec) = return $ ppName $ ecName ec - -instance PrettyInCtx FunName where - prettyInCtx (LetRecName var) = prettyInCtx var - prettyInCtx (EVarFunName var) = prettyInCtx var - prettyInCtx (GlobalName i) = return $ viaShow i - -instance PrettyInCtx Comp where - prettyInCtx (CompTerm t) = prettyInCtx t - prettyInCtx (CompBind c f) = - prettyAppList [prettyInCtx c, return ">>=", prettyInCtx f] - prettyInCtx (CompReturn t) = - prettyAppList [ return "returnM", return "_", parens <$> prettyInCtx t] - -instance PrettyInCtx CompFun where - prettyInCtx (CompFunTerm t) = prettyInCtx t - prettyInCtx CompFunReturn = return "returnM" - prettyInCtx (CompFunComp f g) = - prettyAppList [prettyInCtx f, return ">=>", prettyInCtx g] - -instance PrettyInCtx NormComp where - prettyInCtx (ReturnM t) = - prettyAppList [return "returnM", return "_", parens <$> prettyInCtx t] - prettyInCtx (ErrorM str) = - prettyAppList [return "errorM", return "_", parens <$> prettyInCtx str] - prettyInCtx (Ite cond t1 t2) = - prettyAppList [return "ite", return "_", prettyInCtx cond, - parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] - prettyInCtx (Either f g eith) = - prettyAppList [return "either", return "_", return "_", return "_", - prettyInCtx f, prettyInCtx g, prettyInCtx eith] - prettyInCtx (OrM t1 t2) = - prettyAppList [return "orM", return "_", - parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] - prettyInCtx (ExistsM tp f) = - prettyAppList [return "existsM", prettyInCtx tp, return "_", - parens <$> prettyInCtx f] - prettyInCtx (ForallM tp f) = - prettyAppList [return "forallM", prettyInCtx tp, return "_", - parens <$> prettyInCtx f] - prettyInCtx (FunBind f args CompFunReturn) = - prettyAppList (prettyInCtx f : map prettyInCtx args) - prettyInCtx (FunBind f [] k) = - prettyAppList [prettyInCtx f, return ">>=", prettyInCtx k] - prettyInCtx (FunBind f args k) = - prettyAppList - [parens <$> prettyAppList (prettyInCtx f : map prettyInCtx args), - return ">>=", prettyInCtx k] - - ----------------------------------------------------------------------- --- * Lifting MR Solver Terms ----------------------------------------------------------------------- - --- | A term-like object is one that supports lifting and substitution -class TermLike a where - liftTermLike :: MonadTerm m => DeBruijnIndex -> DeBruijnIndex -> a -> m a - substTermLike :: MonadTerm m => DeBruijnIndex -> [Term] -> a -> m a - -instance (TermLike a, TermLike b) => TermLike (a,b) where - liftTermLike n i (a,b) = (,) <$> liftTermLike n i a <*> liftTermLike n i b - substTermLike n s (a,b) = (,) <$> substTermLike n s a <*> substTermLike n s b - -instance TermLike a => TermLike [a] where - liftTermLike n i l = mapM (liftTermLike n i) l - substTermLike n s l = mapM (substTermLike n s) l - -instance TermLike Term where - liftTermLike = liftTerm - substTermLike = substTerm - -instance TermLike Type where - liftTermLike n i (Type tp) = Type <$> liftTerm n i tp - substTermLike n s (Type tp) = Type <$> substTerm n s tp - -instance TermLike NormComp where - liftTermLike n i (ReturnM t) = ReturnM <$> liftTermLike n i t - liftTermLike n i (ErrorM str) = ErrorM <$> liftTermLike n i str - liftTermLike n i (Ite cond t1 t2) = - Ite <$> liftTermLike n i cond <*> liftTermLike n i t1 <*> liftTermLike n i t2 - liftTermLike n i (Either f g eith) = - Either <$> liftTermLike n i f <*> liftTermLike n i g <*> liftTermLike n i eith - liftTermLike n i (OrM t1 t2) = OrM <$> liftTermLike n i t1 <*> liftTermLike n i t2 - liftTermLike n i (ExistsM tp f) = - ExistsM <$> liftTermLike n i tp <*> liftTermLike n i f - liftTermLike n i (ForallM tp f) = - ForallM <$> liftTermLike n i tp <*> liftTermLike n i f - liftTermLike n i (FunBind nm args f) = - FunBind nm <$> mapM (liftTermLike n i) args <*> liftTermLike n i f - - substTermLike n s (ReturnM t) = ReturnM <$> substTermLike n s t - substTermLike n s (ErrorM str) = ErrorM <$> substTermLike n s str - substTermLike n s (Ite cond t1 t2) = - Ite <$> substTermLike n s cond <*> substTermLike n s t1 - <*> substTermLike n s t2 - substTermLike n s (Either f g eith) = - Either <$> substTermLike n s f <*> substTermLike n s g - <*> substTermLike n s eith - substTermLike n s (OrM t1 t2) = - OrM <$> substTermLike n s t1 <*> substTermLike n s t2 - substTermLike n s (ExistsM tp f) = - ExistsM <$> substTermLike n s tp <*> substTermLike n s f - substTermLike n s (ForallM tp f) = - ForallM <$> substTermLike n s tp <*> substTermLike n s f - substTermLike n s (FunBind nm args f) = - FunBind nm <$> mapM (substTermLike n s) args <*> substTermLike n s f - -instance TermLike CompFun where - liftTermLike n i (CompFunTerm t) = CompFunTerm <$> liftTermLike n i t - liftTermLike _ _ CompFunReturn = return CompFunReturn - liftTermLike n i (CompFunComp f g) = - CompFunComp <$> liftTermLike n i f <*> liftTermLike n i g - - substTermLike n s (CompFunTerm t) = CompFunTerm <$> substTermLike n s t - substTermLike _ _ CompFunReturn = return CompFunReturn - substTermLike n s (CompFunComp f g) = - CompFunComp <$> substTermLike n s f <*> substTermLike n s g - -instance TermLike Comp where - liftTermLike n i (CompTerm t) = CompTerm <$> liftTermLike n i t - liftTermLike n i (CompBind m f) = - CompBind <$> liftTermLike n i m <*> liftTermLike n i f - liftTermLike n i (CompReturn t) = CompReturn <$> liftTermLike n i t - substTermLike n s (CompTerm t) = CompTerm <$> substTermLike n s t - substTermLike n s (CompBind m f) = - CompBind <$> substTermLike n s m <*> substTermLike n s f - substTermLike n s (CompReturn t) = CompReturn <$> substTermLike n s t - - ----------------------------------------------------------------------- --- * MR Solver Errors ----------------------------------------------------------------------- - --- | The context in which a failure occurred -data FailCtx - = FailCtxRefines NormComp NormComp - | FailCtxMNF Term - deriving Show - --- | That's MR. Failure to you -data MRFailure - = TermsNotEq Term Term - | TypesNotEq Type Type - | CompsDoNotRefine NormComp NormComp - | ReturnNotError Term - | FunsNotEq FunName FunName - | CannotLookupFunDef FunName - | RecursiveUnfold FunName - | MalformedLetRecTypes Term - | MalformedDefsFun Term - | MalformedComp Term - | NotCompFunType Term - -- | A local variable binding - | MRFailureLocalVar LocalName MRFailure - -- | Information about the context of the failure - | MRFailureCtx FailCtx MRFailure - -- | Records a disjunctive branch we took, where both cases failed - | MRFailureDisj MRFailure MRFailure - deriving Show - --- | Pretty-print an object prefixed with a 'String' that describes it -ppWithPrefix :: PrettyInCtx a => String -> a -> PPInCtxM SawDoc -ppWithPrefix str a = (pretty str <>) <$> nest 2 <$> (line <>) <$> prettyInCtx a - --- | Pretty-print two objects, prefixed with a 'String' and with a separator -ppWithPrefixSep :: PrettyInCtx a => String -> a -> String -> a -> - PPInCtxM SawDoc -ppWithPrefixSep d1 t2 d3 t4 = - prettyInCtx t2 >>= \d2 -> prettyInCtx t4 >>= \d4 -> - return $ group (pretty d1 <> nest 2 (line <> d2) <> line <> - pretty d3 <> nest 2 (line <> d4)) - --- | Apply 'vsep' to a list of pretty-printing computations -vsepM :: [PPInCtxM SawDoc] -> PPInCtxM SawDoc -vsepM = fmap vsep . sequence - -instance PrettyInCtx FailCtx where - prettyInCtx (FailCtxRefines m1 m2) = - group <$> nest 2 <$> - ppWithPrefixSep "When proving refinement:" m1 "|=" m2 - prettyInCtx (FailCtxMNF t) = - group <$> nest 2 <$> vsepM [return "When normalizing computation:", - prettyInCtx t] - -instance PrettyInCtx MRFailure where - prettyInCtx (TermsNotEq t1 t2) = - ppWithPrefixSep "Could not prove terms equal:" t1 "and" t2 - prettyInCtx (TypesNotEq tp1 tp2) = - ppWithPrefixSep "Types not equal:" tp1 "and" tp2 - prettyInCtx (CompsDoNotRefine m1 m2) = - ppWithPrefixSep "Could not prove refinement: " m1 "|=" m2 - prettyInCtx (ReturnNotError t) = - ppWithPrefix "errorM computation not equal to:" (ReturnM t) - prettyInCtx (FunsNotEq nm1 nm2) = - vsepM [return "Named functions not equal:", - prettyInCtx nm1, prettyInCtx nm2] - prettyInCtx (CannotLookupFunDef nm) = - ppWithPrefix "Could not find definition for function:" nm - prettyInCtx (RecursiveUnfold nm) = - ppWithPrefix "Recursive unfolding of function inside its own body:" nm - prettyInCtx (MalformedLetRecTypes t) = - ppWithPrefix "Not a ground LetRecTypes list:" t - prettyInCtx (MalformedDefsFun t) = - ppWithPrefix "Cannot handle letRecM recursive definitions term:" t - prettyInCtx (MalformedComp t) = - ppWithPrefix "Could not handle computation:" t - prettyInCtx (NotCompFunType tp) = - ppWithPrefix "Not a computation or computational function type:" tp - prettyInCtx (MRFailureLocalVar x err) = - local (x:) $ prettyInCtx err - prettyInCtx (MRFailureCtx ctx err) = - do pp1 <- prettyInCtx ctx - pp2 <- prettyInCtx err - return (pp1 <> line <> pp2) - prettyInCtx (MRFailureDisj err1 err2) = - ppWithPrefixSep "Tried two comparisons:" err1 "Backtracking..." err2 - --- | Render a 'MRFailure' to a 'String' -showMRFailure :: MRFailure -> String -showMRFailure = showInCtx [] - - ----------------------------------------------------------------------- --- * MR Monad ----------------------------------------------------------------------- - --- | Classification info for what sort of variable an 'MRVar' is -data MRVarInfo - -- | An existential variable, that might be instantiated - = EVarInfo (Maybe Term) - -- | A letrec-bound function, with its body - | FunVarInfo Term - --- | A map from 'MRVar's to their info -type MRVarMap = Map MRVar MRVarInfo - --- | Test if a 'Term' is an application of an 'ExtCns' to some arguments -asExtCnsApp :: Recognizer Term (ExtCns Term, [Term]) -asExtCnsApp (asApplyAll -> (asExtCns -> Just ec, args)) = - return (ec, args) -asExtCnsApp _ = Nothing - --- | Recognize an evar applied to 0 or more arguments relative to a 'MRVarMap' --- along with its instantiation, if any -asEVarApp :: MRVarMap -> Recognizer Term (MRVar, [Term], Maybe Term) -asEVarApp var_map (asExtCnsApp -> Just (ec, args)) - | Just (EVarInfo maybe_inst) <- Map.lookup (MRVar ec) var_map = - Just (MRVar ec, args, maybe_inst) -asEVarApp _ _ = Nothing - --- | An assumption that a named function refines some specificaiton. This has --- the form --- --- > forall x1, ..., xn. F e1 ... ek |= m --- --- for some universal context @x1:T1, .., xn:Tn@, some list of argument --- expressions @ei@ over the universal @xj@ variables, and some right-hand side --- computation expression @m@. -data FunAssump = FunAssump { - -- | The uvars that were in scope when this assmption was created, in order - -- from outermost to innermost; that is, the uvars as "seen from outside their - -- scope", which is the reverse of the order of 'mrUVars', below - fassumpCtx :: [(LocalName,Term)], - -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars - fassumpArgs :: [Term], - -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars - fassumpRHS :: NormComp } - --- | State maintained by MR. Solver -data MRState = MRState { - -- | Global shared context for building terms, etc. - mrSC :: SharedContext, - -- | Global SMT configuration for the duration of the MR. Solver call - mrSMTConfig :: SBV.SMTConfig, - -- | SMT timeout for SMT calls made by Mr. Solver - mrSMTTimeout :: Maybe Integer, - -- | The context of universal variables, which are free SAW core variables, in - -- order from innermost to outermost, i.e., where element @0@ corresponds to - -- deBruijn index @0@ - mrUVars :: [(LocalName,Type)], - -- | The existential and letrec-bound variables - mrVars :: MRVarMap, - -- | The current assumptions of function refinement - mrFunAssumps :: Map FunName FunAssump, - -- | The current assumptions, which are conjoined into a single Boolean term - mrAssumptions :: Term, - -- | The debug level, which controls debug printing - mrDebugLevel :: Int -} - --- | Build a default, empty state from SMT configuration parameters and a set of --- function refinement assumptions -mkMRState :: SharedContext -> Map FunName FunAssump -> SBV.SMTConfig -> - Maybe Integer -> Int -> IO MRState -mkMRState sc fun_assumps smt_config timeout dlvl = - scBool sc True >>= \true_tm -> - return $ MRState { mrSC = sc, mrSMTConfig = smt_config, - mrSMTTimeout = timeout, mrUVars = [], mrVars = Map.empty, - mrFunAssumps = fun_assumps, mrAssumptions = true_tm, - mrDebugLevel = dlvl } - --- | Mr. Monad, the monad used by MR. Solver, which is the state-exception monad -newtype MRM a = MRM { unMRM :: StateT MRState (ExceptT MRFailure IO) a } - deriving (Functor, Applicative, Monad, MonadIO, - MonadState MRState, MonadError MRFailure) - -instance MonadTerm MRM where - mkTermF = liftSC1 scTermF - liftTerm = liftSC3 incVars - whnfTerm = liftSC1 scWhnf - substTerm = liftSC3 instantiateVarList - --- | Run an 'MRM' computation and return a result or an error -runMRM :: MRState -> MRM a -> IO (Either MRFailure a) -runMRM init_st m = runExceptT $ flip evalStateT init_st $ unMRM m - --- | Apply a function to any failure thrown by an 'MRM' computation -mapFailure :: (MRFailure -> MRFailure) -> MRM a -> MRM a -mapFailure f m = catchError m (throwError . f) - --- | Try two different 'MRM' computations, combining their failures if needed. --- Note that the 'MRState' will reset if the first computation fails. -mrOr :: MRM a -> MRM a -> MRM a -mrOr m1 m2 = - catchError m1 $ \err1 -> - catchError m2 $ \err2 -> - throwError $ MRFailureDisj err1 err2 - --- | Run an 'MRM' computation in an extended failure context -withFailureCtx :: FailCtx -> MRM a -> MRM a -withFailureCtx ctx = mapFailure (MRFailureCtx ctx) - -{- --- | Catch any errors thrown by a computation and coerce them to a 'Left' -catchErrorEither :: MonadError e m => m a -> m (Either e a) -catchErrorEither m = catchError (Right <$> m) (return . Left) --} - --- FIXME: replace these individual lifting functions with a more general --- typeclass like LiftTCM - -{- --- | Lift a nullary SharedTerm computation into 'MRM' -liftSC0 :: (SharedContext -> IO a) -> MRM a -liftSC0 f = (mrSC <$> get) >>= \sc -> liftIO (f sc) --} - --- | Lift a unary SharedTerm computation into 'MRM' -liftSC1 :: (SharedContext -> a -> IO b) -> a -> MRM b -liftSC1 f a = (mrSC <$> get) >>= \sc -> liftIO (f sc a) - --- | Lift a binary SharedTerm computation into 'MRM' -liftSC2 :: (SharedContext -> a -> b -> IO c) -> a -> b -> MRM c -liftSC2 f a b = (mrSC <$> get) >>= \sc -> liftIO (f sc a b) - --- | Lift a ternary SharedTerm computation into 'MRM' -liftSC3 :: (SharedContext -> a -> b -> c -> IO d) -> a -> b -> c -> MRM d -liftSC3 f a b c = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c) - --- | Lift a quaternary SharedTerm computation into 'MRM' -liftSC4 :: (SharedContext -> a -> b -> c -> d -> IO e) -> a -> b -> c -> d -> - MRM e -liftSC4 f a b c d = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c d) - --- | Apply a 'Term' to a list of arguments and beta-reduce in Mr. Monad -mrApplyAll :: Term -> [Term] -> MRM Term -mrApplyAll f args = liftSC2 scApplyAll f args >>= liftSC1 betaNormalize - --- | Get the current context of uvars as a list of variable names and their --- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in --- the order as seen "from the outside" -mrUVarCtx :: MRM [(LocalName,Term)] -mrUVarCtx = reverse <$> map (\(nm,Type tp) -> (nm,tp)) <$> mrUVars <$> get - --- | Get the type of a 'Term' in the current uvar context -mrTypeOf :: Term -> MRM Term -mrTypeOf t = mrUVarCtx >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t - --- | Check if two 'Term's are convertible in the 'MRM' monad -mrConvertible :: Term -> Term -> MRM Bool -mrConvertible = liftSC4 scConvertibleEval scTypeCheckWHNF True - --- | Take a 'FunName' @f@ for a monadic function of type @vars -> CompM a@ and --- compute the type @CompM [args/vars]a@ of @f@ applied to @args@. Return the --- type @[args/vars]a@ that @CompM@ is applied to. -mrFunOutType :: FunName -> [Term] -> MRM Term -mrFunOutType ((asPiList . funNameType) -> (vars, asCompM -> Just tp)) args - | length vars == length args = - substTermLike 0 args tp -mrFunOutType _ _ = - -- NOTE: this is an error because we should only ever call mrFunOutType with a - -- well-formed application at a CompM type - error "mrFunOutType" - --- | Turn a 'LocalName' into one not in a list, adding a suffix if necessary -uniquifyName :: LocalName -> [LocalName] -> LocalName -uniquifyName nm nms | notElem nm nms = nm -uniquifyName nm nms = - case find (flip notElem nms) $ - map (T.append nm . T.pack . show) [(0::Int) ..] of - Just nm' -> nm' - Nothing -> error "uniquifyName" - --- | Run a MR Solver computation in a context extended with a universal --- variable, which is passed as a 'Term' to the sub-computation -withUVar :: LocalName -> Type -> (Term -> MRM a) -> MRM a -withUVar nm tp m = - do st <- get - let nm' = uniquifyName nm (map fst $ mrUVars st) - put (st { mrUVars = (nm',tp) : mrUVars st }) - ret <- mapFailure (MRFailureLocalVar nm') (liftSC1 scLocalVar 0 >>= m) - modify (\st' -> st' { mrUVars = mrUVars st }) - return ret - --- | Run a MR Solver computation in a context extended with a universal variable --- and pass it the lifting (in the sense of 'incVars') of an MR Solver term -withUVarLift :: TermLike tm => LocalName -> Type -> tm -> - (Term -> tm -> MRM a) -> MRM a -withUVarLift nm tp t m = - withUVar nm tp (\x -> liftTermLike 0 1 t >>= m x) - --- | Run a MR Solver computation in a context extended with a list of universal --- variables, passing 'Term's for those variables to the supplied computation. --- The variables are bound "outside in", meaning the first variable in the list --- is bound outermost, and so will have the highest deBruijn index. -withUVars :: [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a -withUVars = helper [] where - -- The extra input list gives the variables that have already been bound, in - -- order from most to least recently bound - helper :: [Term] -> [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a - helper vars [] m = m $ reverse vars - helper vars ((nm,tp):ctx) m = - substTerm 0 vars tp >>= \tp' -> - withUVar nm (Type tp') $ \var -> helper (var:vars) ctx m - --- | Build 'Term's for all the uvars currently in scope, ordered from least to --- most recently bound -getAllUVarTerms :: MRM [Term] -getAllUVarTerms = - (length <$> mrUVars <$> get) >>= \len -> - mapM (liftSC1 scLocalVar) [len-1, len-2 .. 0] - --- | Lambda-abstract all the current uvars out of a 'Term', with the least --- recently bound variable being abstracted first -lambdaUVarsM :: Term -> MRM Term -lambdaUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scLambdaList ctx t - --- | Pi-abstract all the current uvars out of a 'Term', with the least recently --- bound variable being abstracted first -piUVarsM :: Term -> MRM Term -piUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scPiList ctx t - --- | Convert an 'MRVar' to a 'Term', applying it to all the uvars in scope -mrVarTerm :: MRVar -> MRM Term -mrVarTerm (MRVar ec) = - do var_tm <- liftSC1 scExtCns ec - vars <- getAllUVarTerms - liftSC2 scApplyAll var_tm vars - --- | Get the 'VarInfo' associated with a 'MRVar' -mrVarInfo :: MRVar -> MRM (Maybe MRVarInfo) -mrVarInfo var = Map.lookup var <$> mrVars <$> get - --- | Convert an 'ExtCns' to a 'FunName' -extCnsToFunName :: ExtCns Term -> MRM FunName -extCnsToFunName ec = let var = MRVar ec in mrVarInfo var >>= \case - Just (EVarInfo _) -> return $ EVarFunName var - Just (FunVarInfo _) -> return $ LetRecName var - Nothing - | Just glob <- asTypedGlobalDef (Unshared $ FTermF $ ExtCns ec) -> - return $ GlobalName glob - _ -> error "extCnsToFunName: unreachable" - --- | Get the body of a function @f@ if it has one -mrFunNameBody :: FunName -> MRM (Maybe Term) -mrFunNameBody (LetRecName var) = - mrVarInfo var >>= \case - Just (FunVarInfo body) -> return $ Just body - _ -> error "mrFunBody: unknown letrec var" -mrFunNameBody (GlobalName glob) = return $ globalDefBody glob -mrFunNameBody (EVarFunName _) = return Nothing - --- | Get the body of a function @f@ applied to some arguments, if possible -mrFunBody :: FunName -> [Term] -> MRM (Maybe Term) -mrFunBody f args = mrFunNameBody f >>= \case - Just body -> Just <$> mrApplyAll body args - Nothing -> return Nothing - --- | Get the body of a function @f@ applied to some arguments, as per --- 'mrFunBody', and also return whether its body recursively calls itself, as --- per 'mrCallsFun' -mrFunBodyRecInfo :: FunName -> [Term] -> MRM (Maybe (Term, Bool)) -mrFunBodyRecInfo f args = - mrFunBody f args >>= \case - Just f_body -> Just <$> (f_body,) <$> mrCallsFun f f_body - Nothing -> return Nothing - --- | Test if a 'Term' contains, after possibly unfolding some functions, a call --- to a given function @f@ again -mrCallsFun :: FunName -> Term -> MRM Bool -mrCallsFun f = memoFixTermFun $ \recurse t -> case t of - (asExtCns -> Just ec) -> - do g <- extCnsToFunName ec - maybe_body <- mrFunNameBody g - case maybe_body of - _ | f == g -> return True - Just body -> recurse body - Nothing -> return False - (asTypedGlobalDef -> Just gdef) -> - case globalDefBody gdef of - _ | f == GlobalName gdef -> return True - Just body -> recurse body - Nothing -> return False - (unwrapTermF -> tf) -> - foldM (\b t' -> if b then return b else recurse t') False tf - --- | Make a fresh 'MRVar' of a given type, which must be closed, i.e., have no --- free uvars -mrFreshVar :: LocalName -> Term -> MRM MRVar -mrFreshVar nm tp = MRVar <$> liftSC2 scFreshEC nm tp - --- | Set the info associated with an 'MRVar', assuming it has not been set -mrSetVarInfo :: MRVar -> MRVarInfo -> MRM () -mrSetVarInfo var info = - modify $ \st -> - st { mrVars = - Map.alter (\case - Just _ -> error "mrSetVarInfo" - Nothing -> Just info) - var (mrVars st) } - --- | Make a fresh existential variable of the given type, abstracting out all --- the current uvars and returning the new evar applied to all current uvars -mrFreshEVar :: LocalName -> Type -> MRM Term -mrFreshEVar nm (Type tp) = - do tp' <- piUVarsM tp - var <- mrFreshVar nm tp' - mrSetVarInfo var (EVarInfo Nothing) - mrVarTerm var - --- | Return a fresh sequence of existential variables for a context of variable --- names and types, assuming each variable is free in the types that occur after --- it in the list. Return the new evars all applied to the current uvars. -mrFreshEVars :: [(LocalName,Term)] -> MRM [Term] -mrFreshEVars = helper [] where - -- Return fresh evars for the suffix of a context of variable names and types, - -- where the supplied Terms are evars that have already been generated for the - -- earlier part of the context, and so must be substituted into the remaining - -- types in the context - helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] - helper evars [] = return evars - helper evars ((nm,tp):ctx) = - do evar <- substTerm 0 evars tp >>= mrFreshEVar nm . Type - helper (evar:evars) ctx - --- | Set the value of an evar to a closed term -mrSetEVarClosed :: MRVar -> Term -> MRM () -mrSetEVarClosed var val = - do val_tp <- mrTypeOf val - -- FIXME: catch subtyping errors and report them as being evar failures - liftSC3 scCheckSubtype Nothing (TypedTerm val val_tp) (mrVarType var) - modify $ \st -> - st { mrVars = - Map.alter - (\case - Just (EVarInfo Nothing) -> Just $ EVarInfo (Just val) - Just (EVarInfo (Just _)) -> - error "Setting existential variable: variable already set!" - _ -> error "Setting existential variable: not an evar!") - var (mrVars st) } - - --- | Try to set the value of the application @X e1 .. en@ of evar @X@ to an --- expression @e@ by trying to set @X@ to @\ x1 ... xn -> e@. This only works if --- each free uvar @xi@ in @e@ is one of the arguments @ej@ to @X@ (though it --- need not be the case that @i=j@). Return whether this succeeded. -mrTrySetAppliedEVar :: MRVar -> [Term] -> Term -> MRM Bool -mrTrySetAppliedEVar evar args t = - -- Get the complete list of argument variables of the type of evar - let (evar_vars, _) = asPiList (mrVarType evar) in - -- Get all the free variables of t - let free_vars = bitSetElems (looseVars t) in - -- For each free var of t, find an arg equal to it - case mapM (\i -> findIndex (\case - (asLocalVar -> Just j) -> i == j - _ -> False) args) free_vars of - Just fv_arg_ixs - -- Check to make sure we have the right number of args - | length args == length evar_vars -> do - -- Build a list of the input vars x1 ... xn as terms, noting that the - -- first variable is the least recently bound and so has the highest - -- deBruijn index - let arg_ixs = [length args - 1, length args - 2 .. 0] - arg_vars <- mapM (liftSC1 scLocalVar) arg_ixs - - -- For free variable of t, we substitute the corresponding variable - -- xi, substituting error terms for the variables that are not free - -- (since we have nothing else to substitute for them) - let var_map = zip free_vars fv_arg_ixs - let subst = flip map [0 .. length args - 1] $ \i -> - maybe (error "mrTrySetAppliedEVar: unexpected free variable") - (arg_vars !!) (lookup i var_map) - body <- substTerm 0 subst t - - -- Now instantiate evar to \x1 ... xn -> body - evar_inst <- liftSC2 scLambdaList evar_vars body - mrSetEVarClosed evar evar_inst - return True - - _ -> return False - - --- | Replace all evars in a 'Term' with their instantiations when they have one -mrSubstEVars :: Term -> MRM Term -mrSubstEVars = memoFixTermFun $ \recurse t -> - do var_map <- mrVars <$> get - case t of - -- If t is an instantiated evar, recurse on its instantiation - (asEVarApp var_map -> Just (_, args, Just t')) -> - mrApplyAll t' args >>= recurse - -- If t is anything else, recurse on its immediate subterms - _ -> traverseSubterms recurse t - --- | Replace all evars in a 'Term' with their instantiations, returning --- 'Nothing' if we hit an uninstantiated evar -mrSubstEVarsStrict :: Term -> MRM (Maybe Term) -mrSubstEVarsStrict top_t = - runMaybeT $ flip memoFixTermFun top_t $ \recurse t -> - do var_map <- mrVars <$> get - case t of - -- If t is an instantiated evar, recurse on its instantiation - (asEVarApp var_map -> Just (_, args, Just t')) -> - lift (mrApplyAll t' args) >>= recurse - -- If t is an uninstantiated evar, return Nothing - (asEVarApp var_map -> Just (_, _, Nothing)) -> - mzero - -- If t is anything else, recurse on its immediate subterms - _ -> traverseSubterms recurse t - --- | Makes 'mrSubstEVarsStrict' be marked as used -_mrSubstEVarsStrict :: Term -> MRM (Maybe Term) -_mrSubstEVarsStrict = mrSubstEVarsStrict - --- | Look up the 'FunAssump' for a 'FunName', if there is one -mrGetFunAssump :: FunName -> MRM (Maybe FunAssump) -mrGetFunAssump nm = Map.lookup nm <$> mrFunAssumps <$> get - --- | Run a computation under the additional assumption that a named function --- applied to a list of arguments refines a given right-hand side, all of which --- are 'Term's that can have the current uvars free -withFunAssump :: FunName -> [Term] -> NormComp -> MRM a -> MRM a -withFunAssump fname args rhs m = - do mrDebugPPPrefixSep 1 "withFunAssump" (FunBind - fname args CompFunReturn) "|=" rhs - ctx <- mrUVarCtx - assumps <- mrFunAssumps <$> get - let assumps' = Map.insert fname (FunAssump ctx args rhs) assumps - modify (\s -> s { mrFunAssumps = assumps' }) - ret <- m - modify (\s -> s { mrFunAssumps = assumps }) - return ret - --- | Generate fresh evars for the context of a 'FunAssump' and substitute them --- into its arguments and right-hand side -instantiateFunAssump :: FunAssump -> MRM ([Term], NormComp) -instantiateFunAssump fassump = - do evars <- mrFreshEVars $ fassumpCtx fassump - args <- substTermLike 0 evars $ fassumpArgs fassump - rhs <- substTermLike 0 evars $ fassumpRHS fassump - return (args, rhs) - --- | Add an assumption of type @Bool@ to the current path condition while --- executing a sub-computation -withAssumption :: Term -> MRM a -> MRM a -withAssumption phi m = - do assumps <- mrAssumptions <$> get - assumps' <- liftSC2 scAnd phi assumps - modify (\s -> s { mrAssumptions = assumps' }) - ret <- m - modify (\s -> s { mrAssumptions = assumps }) - return ret - --- | Print a 'String' if the debug level is at least the supplied 'Int' -debugPrint :: Int -> String -> MRM () -debugPrint i str = - (mrDebugLevel <$> get) >>= \lvl -> - if lvl >= i then liftIO (hPutStrLn stderr str) else return () - --- | Print a document if the debug level is at least the supplied 'Int' -debugPretty :: Int -> SawDoc -> MRM () -debugPretty i pp = debugPrint i $ renderSawDoc defaultPPOpts pp - --- | Pretty-print an object in the current context if the current debug level is --- at least the supplied 'Int' -_debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () -_debugPrettyInCtx i a = - (mrUVars <$> get) >>= \ctx -> debugPrint i (showInCtx (map fst ctx) a) - --- | Pretty-print an object relative to the current context -_mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc -_mrPPInCtx a = - runReader (prettyInCtx a) <$> map fst <$> mrUVars <$> get - --- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar --- context to 'stderr' if the debug level is at least the 'Int' provided -mrDebugPPPrefixSep :: PrettyInCtx a => Int -> String -> a -> String -> a -> - MRM () -mrDebugPPPrefixSep i pre a1 sp a2 = - (mrUVars <$> get) >>= \ctx -> - debugPretty i $ - flip runReader (map fst ctx) (group <$> nest 2 <$> - ppWithPrefixSep pre a1 sp a2) - - ----------------------------------------------------------------------- --- * Calling Out to SMT ----------------------------------------------------------------------- - --- | Test if a closed Boolean term is "provable", i.e., its negation is --- unsatisfiable, using an SMT solver. By "closed" we mean that it contains no --- uvars or 'MRVar's. -mrProvableRaw :: Term -> MRM Bool -mrProvableRaw prop_term = - do smt_conf <- mrSMTConfig <$> get - timeout <- mrSMTTimeout <$> get - prop <- liftSC1 termToProp prop_term - (smt_res, _) <- liftSC4 SBV.proveUnintSBVIO smt_conf mempty timeout prop - case smt_res of - Just _ -> return False - Nothing -> return True - --- | Test if a Boolean term over the current uvars is provable given the current --- assumptions -mrProvable :: Term -> MRM Bool -mrProvable bool_tm = - do assumps <- mrAssumptions <$> get - prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue - forall_prop <- piUVarsM prop - mrProvableRaw forall_prop - --- | Build a Boolean 'Term' stating that two 'Term's are equal. This is like --- 'scEq' except that it works on open terms. -mrEq :: Term -> Term -> MRM Term -mrEq t1 t2 = mrTypeOf t1 >>= \tp -> mrEq' tp t1 t2 - --- | Build a Boolean 'Term' stating that the second and third 'Term' arguments --- are equal, where the first 'Term' gives their type (which we assume is the --- same for both). This is like 'scEq' except that it works on open terms. -mrEq' :: Term -> Term -> Term -> MRM Term -mrEq' (asDataType -> Just (pn, [])) t1 t2 - | primName pn == "Prelude.Nat" = liftSC2 scEqualNat t1 t2 -mrEq' (asBoolType -> Just _) t1 t2 = liftSC2 scBoolEq t1 t2 -mrEq' (asIntegerType -> Just _) t1 t2 = liftSC2 scIntEq t1 t2 -mrEq' (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = - liftSC3 scBvEq n t1 t2 -mrEq' _ _ _ = error "mrEq': unsupported type" - --- | A "simple" strategy for proving equality between two terms, which we assume --- are of the same type. This strategy first checks if either side is an --- uninstantiated evar, in which case it set that evar to the other side. If --- not, it builds an equality proposition by applying the supplied function to --- both sides, and passes this proposition to an SMT solver. -mrProveEqSimple :: (Term -> Term -> MRM Term) -> MRVarMap -> Term -> Term -> - MRM () - --- If t1 is an instantiated evar, substitute and recurse -mrProveEqSimple eqf var_map (asEVarApp var_map -> Just (_, args, Just f)) t2 = - mrApplyAll f args >>= \t1' -> mrProveEqSimple eqf var_map t1' t2 - --- If t1 is an uninstantiated evar, instantiate it with t2 -mrProveEqSimple _ var_map t1@(asEVarApp var_map -> - Just (evar, args, Nothing)) t2 = - do t2' <- mrSubstEVars t2 - success <- mrTrySetAppliedEVar evar args t2' - if success then return () else throwError (TermsNotEq t1 t2) - --- If t2 is an instantiated evar, substitute and recurse -mrProveEqSimple eqf var_map t1 (asEVarApp var_map -> Just (_, args, Just f)) = - mrApplyAll f args >>= \t2' -> mrProveEqSimple eqf var_map t1 t2' - --- If t2 is an uninstantiated evar, instantiate it with t1 -mrProveEqSimple _ var_map t1 t2@(asEVarApp var_map -> - Just (evar, args, Nothing)) = - do t1' <- mrSubstEVars t1 - success <- mrTrySetAppliedEVar evar args t1' - if success then return () else throwError (TermsNotEq t1 t2) - --- Otherwise, try to prove both sides are equal. The use of mrSubstEVars instead --- of mrSubstEVarsStrict means that we allow evars in the terms we send to the --- SMT solver, but we treat them as uvars. -mrProveEqSimple eqf _ t1 t2 = - do t1' <- mrSubstEVars t1 - t2' <- mrSubstEVars t2 - prop <- eqf t1' t2' - success <- mrProvable prop - if success then return () else - throwError (TermsNotEq t1 t2) - - --- | Prove that two terms are equal, instantiating evars if necessary, or --- throwing an error if this is not possible -mrProveEq :: Term -> Term -> MRM () -mrProveEq t1_top t2_top = - (do mrDebugPPPrefixSep 1 "mrProveEq" t1_top "==" t2_top - tp <- mrTypeOf t1_top - varmap <- mrVars <$> get - proveEq varmap tp t1_top t2_top) - where - proveEq :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM () - proveEq var_map (asDataType -> Just (pn, [])) t1 t2 - | primName pn == "Prelude.Nat" = - mrProveEqSimple (liftSC2 scEqualNat) var_map t1 t2 - proveEq var_map (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = - -- FIXME: make a better solver for bitvector equalities - mrProveEqSimple (liftSC3 scBvEq n) var_map t1 t2 - proveEq var_map (asBoolType -> Just _) t1 t2 = - mrProveEqSimple (liftSC2 scBoolEq) var_map t1 t2 - proveEq var_map (asIntegerType -> Just _) t1 t2 = - mrProveEqSimple (liftSC2 scIntEq) var_map t1 t2 - proveEq _ _ t1 t2 = - -- As a fallback, for types we can't handle, just check convertibility - mrConvertible t1 t2 >>= \case - True -> return () - False -> throwError (TermsNotEq t1 t2) - - ----------------------------------------------------------------------- --- * Normalizing and Matching on Terms ----------------------------------------------------------------------- - --- | Match a type as being of the form @CompM a@ for some @a@ -asCompM :: Term -> Maybe Term -asCompM (asApp -> Just (isGlobalDef "Prelude.CompM" -> Just (), tp)) = - return tp -asCompM _ = fail "not a CompM type!" - --- | Test if a type is a monadic function type of 0 or more arguments -isCompFunType :: Term -> Bool -isCompFunType (asPiList -> (_, asCompM -> Just _)) = True -isCompFunType _ = False - --- | Pattern-match on a @LetRecTypes@ list in normal form and return a list of --- the types it specifies, each in normal form and with uvars abstracted out -asLRTList :: Term -> MRM [Term] -asLRTList (asCtor -> Just (primName -> "Prelude.LRT_Nil", [])) = - return [] -asLRTList (asCtor -> Just (primName -> "Prelude.LRT_Cons", [lrt, lrts])) = - do tp <- liftSC2 scGlobalApply "Prelude.lrtToType" [lrt] - tp_norm_closed <- liftSC1 scWhnf tp >>= piUVarsM - (tp_norm_closed :) <$> asLRTList lrts -asLRTList t = throwError (MalformedLetRecTypes t) - --- | Match a right-nested series of pairs. This is similar to 'asTupleValue' --- except that it expects a unit value to always be at the end. -asNestedPairs :: Recognizer Term [Term] -asNestedPairs (asPairValue -> Just (x, asNestedPairs -> Just xs)) = Just (x:xs) -asNestedPairs (asFTermF -> Just UnitValue) = Just [] -asNestedPairs _ = Nothing - --- | Normalize a 'Term' of monadic type to monadic normal form -normCompTerm :: Term -> MRM NormComp -normCompTerm = normComp . CompTerm - --- | Normalize a computation to monadic normal form, assuming any 'Term's it --- contains have already been normalized with respect to beta and projections --- (but constants need not be unfolded) -normComp :: Comp -> MRM NormComp -normComp (CompReturn t) = return $ ReturnM t -normComp (CompBind m f) = - do norm <- normComp m - normBind norm f -normComp (CompTerm t) = - withFailureCtx (FailCtxMNF t) $ - case asApplyAll t of - (isGlobalDef "Prelude.returnM" -> Just (), [_, x]) -> - return $ ReturnM x - (isGlobalDef "Prelude.bindM" -> Just (), [_, _, m, f]) -> - do norm <- normComp (CompTerm m) - normBind norm (CompFunTerm f) - (isGlobalDef "Prelude.errorM" -> Just (), [_, str]) -> - return (ErrorM str) - (isGlobalDef "Prelude.ite" -> Just (), [_, cond, then_tm, else_tm]) -> - return $ Ite cond (CompTerm then_tm) (CompTerm else_tm) - (isGlobalDef "Prelude.either" -> Just (), [_, _, _, f, g, eith]) -> - return $ Either (CompFunTerm f) (CompFunTerm g) eith - (isGlobalDef "Prelude.orM" -> Just (), [_, m1, m2]) -> - return $ OrM (CompTerm m1) (CompTerm m2) - (isGlobalDef "Prelude.existsM" -> Just (), [tp, _, body_tm]) -> - return $ ExistsM (Type tp) (CompFunTerm body_tm) - (isGlobalDef "Prelude.forallM" -> Just (), [tp, _, body_tm]) -> - return $ ForallM (Type tp) (CompFunTerm body_tm) - (isGlobalDef "Prelude.letRecM" -> Just (), [lrts, _, defs_f, body_f]) -> - do - -- First, make fresh function constants for all the bound functions, - -- using the names bound by body_f and just "F" if those run out - let fun_var_names = - map fst (fst $ asLambdaList body_f) ++ repeat "F" - fun_tps <- asLRTList lrts - funs <- zipWithM mrFreshVar fun_var_names fun_tps - fun_tms <- mapM mrVarTerm funs - - -- Next, apply the definition function defs_f to our function vars, - -- yielding the definitions of the individual letrec-bound functions in - -- terms of the new function constants - defs_tm <- mrApplyAll defs_f fun_tms - defs <- case asNestedPairs defs_tm of - Just defs -> return defs - Nothing -> throwError (MalformedDefsFun defs_f) - - -- Remember the body associated with each fresh function constant - zipWithM_ (\f body -> - lambdaUVarsM body >>= \cl_body -> - mrSetVarInfo f (FunVarInfo cl_body)) funs defs - - -- Finally, apply the body function to our function vars and recursively - -- normalize the resulting computation - body_tm <- mrApplyAll body_f fun_tms - normComp (CompTerm body_tm) - - -- Only unfold constants that are not recursive functions, i.e., whose - -- bodies do not contain letrecs - {- FIXME: this should be handled by mrRefines; we want it to be handled there - so that we use refinement assumptions before unfolding constants, to give - the user control over refinement proofs - ((asConstant -> Just (_, body)), args) - | not (containsLetRecM body) -> - mrApplyAll body args >>= normCompTerm - -} - - -- For an ExtCns, we have to check what sort of variable it is - -- FIXME: substitute for evars if they have been instantiated - ((asExtCns -> Just ec), args) -> - do fun_name <- extCnsToFunName ec - return $ FunBind fun_name args CompFunReturn - - ((asTypedGlobalDef -> Just gdef), args) -> - return $ FunBind (GlobalName gdef) args CompFunReturn - - _ -> throwError (MalformedComp t) - - --- | Bind a computation in whnf with a function, and normalize -normBind :: NormComp -> CompFun -> MRM NormComp -normBind (ReturnM t) k = applyNormCompFun k t -normBind (ErrorM msg) _ = return (ErrorM msg) -normBind (Ite cond comp1 comp2) k = - return $ Ite cond (CompBind comp1 k) (CompBind comp2 k) -normBind (Either f g t) k = - return $ Either (compFunComp f k) (compFunComp g k) t -normBind (OrM comp1 comp2) k = - return $ OrM (CompBind comp1 k) (CompBind comp2 k) -normBind (ExistsM tp f) k = return $ ExistsM tp (compFunComp f k) -normBind (ForallM tp f) k = return $ ForallM tp (compFunComp f k) -normBind (FunBind f args k1) k2 = - return $ FunBind f args (compFunComp k1 k2) - --- | Bind a 'Term' for a computation with a function and normalize -normBindTerm :: Term -> CompFun -> MRM NormComp -normBindTerm t f = normCompTerm t >>= \m -> normBind m f - --- | Apply a computation function to a term argument to get a computation -applyCompFun :: CompFun -> Term -> MRM Comp -applyCompFun (CompFunComp f g) t = - -- (f >=> g) t == f t >>= g - do comp <- applyCompFun f t - return $ CompBind comp g -applyCompFun CompFunReturn t = - return $ CompReturn t -applyCompFun (CompFunTerm f) t = CompTerm <$> mrApplyAll f [t] - --- | Apply a 'CompFun' to a term and normalize the resulting computation -applyNormCompFun :: CompFun -> Term -> MRM NormComp -applyNormCompFun f arg = applyCompFun f arg >>= normComp - --- | Apply a 'Comp - -{- FIXME: do these go away? --- | Lookup the definition of a function or throw a 'CannotLookupFunDef' if this is --- not allowed, either because it is a global function we are treating as opaque --- or because it is a locally-bound function variable -mrLookupFunDef :: FunName -> MRM Term -mrLookupFunDef f@(GlobalName _) = throwError (CannotLookupFunDef f) -mrLookupFunDef f@(LocalName var) = - mrVarInfo var >>= \case - Just (FunVarInfo body) -> return body - Just _ -> throwError (CannotLookupFunDef f) - Nothing -> error "mrLookupFunDef: unknown variable!" - --- | Unfold a call to function @f@ in term @f args >>= g@ -mrUnfoldFunBind :: FunName -> [Term] -> Mark -> CompFun -> MRM Comp -mrUnfoldFunBind f _ mark _ | inMark f mark = throwError (RecursiveUnfold f) -mrUnfoldFunBind f args mark g = - do f_def <- mrLookupFunDef f - CompBind <$> - (CompMark <$> (CompTerm <$> liftSC2 scApplyAll f_def args) - <*> (return $ singleMark f `mappend` mark)) - <*> return g --} - -{- -FIXME HERE NOW: maybe each FunName should stipulate whether it is recursive or -not, so that mrRefines can unfold the non-recursive ones early but wait on -handling the recursive ones --} - ----------------------------------------------------------------------- --- * Mr Solver Himself (He Identifies as Male) ----------------------------------------------------------------------- - --- | An object that can be converted to a normalized computation -class ToNormComp a where - toNormComp :: a -> MRM NormComp - -instance ToNormComp NormComp where - toNormComp = return -instance ToNormComp Comp where - toNormComp = normComp -instance ToNormComp Term where - toNormComp = normComp . CompTerm - --- | Prove that the left-hand computation refines the right-hand one. See the --- rules described at the beginning of this module. -mrRefines :: (ToNormComp a, ToNormComp b) => a -> b -> MRM () -mrRefines t1 t2 = - do m1 <- toNormComp t1 - m2 <- toNormComp t2 - mrDebugPPPrefixSep 1 "mrRefines" m1 "|=" m2 - withFailureCtx (FailCtxRefines m1 m2) $ mrRefines' m1 m2 - --- | The main implementation of 'mrRefines' -mrRefines' :: NormComp -> NormComp -> MRM () -mrRefines' (ReturnM e1) (ReturnM e2) = mrProveEq e1 e2 -mrRefines' (ErrorM _) (ErrorM _) = return () -mrRefines' (ReturnM e) (ErrorM _) = throwError (ReturnNotError e) -mrRefines' (ErrorM _) (ReturnM e) = throwError (ReturnNotError e) -mrRefines' (Ite cond1 m1 m1') m2_all@(Ite cond2 m2 m2') = - liftSC1 scNot cond1 >>= \not_cond1 -> - (mrEq cond1 cond2 >>= mrProvable) >>= \case - True -> - -- If we can prove cond1 == cond2, then we just need to prove m1 |= m2 and - -- m1' |= m2'; further, we need only add assumptions about cond1, because it - -- is provably equal to cond2 - withAssumption cond1 (mrRefines m1 m2) >> - withAssumption not_cond1 (mrRefines m1' m2') - False -> - -- Otherwise, prove each branch of the LHS refines the whole RHS - withAssumption cond1 (mrRefines m1 m2_all) >> - withAssumption not_cond1 (mrRefines m1' m2_all) -mrRefines' (Ite cond1 m1 m1') m2 = - do not_cond1 <- liftSC1 scNot cond1 - withAssumption cond1 (mrRefines m1 m2) - withAssumption not_cond1 (mrRefines m1' m2) -mrRefines' m1 (Ite cond2 m2 m2') = - do not_cond2 <- liftSC1 scNot cond2 - withAssumption cond2 (mrRefines m1 m2) - withAssumption not_cond2 (mrRefines m1 m2') --- FIXME: handle sum elimination --- mrRefines (Either f1 g1 e1) (Either f2 g2 e2) = -mrRefines' m1 (ForallM tp f2) = - let nm = maybe "x" id (compFunVarName f2) in - withUVarLift nm tp (m1,f2) $ \x (m1',f2') -> - applyNormCompFun f2' x >>= \m2' -> - mrRefines m1' m2' -mrRefines' (ExistsM tp f1) m2 = - let nm = maybe "x" id (compFunVarName f1) in - withUVarLift nm tp (f1,m2) $ \x (f1',m2') -> - applyNormCompFun f1' x >>= \m1' -> - mrRefines m1' m2' -mrRefines' m1 (OrM m2 m2') = - mrOr (mrRefines m1 m2) (mrRefines m1 m2') -mrRefines' (OrM m1 m1') m2 = - mrRefines m1 m2 >> mrRefines m1' m2 - --- FIXME: the following cases don't work unless we either allow evars to be set --- to NormComps or we can turn NormComps back into terms -mrRefines' m1@(FunBind (EVarFunName _) _ _) m2 = - throwError (CompsDoNotRefine m1 m2) -mrRefines' m1 m2@(FunBind (EVarFunName _) _ _) = - throwError (CompsDoNotRefine m1 m2) -{- -mrRefines' (FunBind (EVarFunName evar) args CompFunReturn) m2 = - mrGetEVar evar >>= \case - Just f -> - (mrApplyAll f args >>= normCompTerm) >>= \m1' -> - mrRefines m1' m2 - Nothing -> mrTrySetAppliedEVar evar args m2 --} - -mrRefines' (FunBind (LetRecName f) args1 k1) (FunBind (LetRecName f') args2 k2) - | f == f' && length args1 == length args2 = - zipWithM_ mrProveEq args1 args2 >> - mrRefinesFun k1 k2 - -mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = - mrFunOutType f1 args1 >>= \tp1 -> - mrFunOutType f2 args2 >>= \tp2 -> - mrConvertible tp1 tp2 >>= \tps_eq -> - mrFunBodyRecInfo f1 args1 >>= \maybe_f1_body -> - mrFunBodyRecInfo f2 args2 >>= \maybe_f2_body -> - mrGetFunAssump f1 >>= \case - - -- If we have an assumption that f1 args' refines some rhs, then prove that - -- args1 = args' and then that rhs refines m2 - Just fassump -> - do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrProveEq assump_args args1 - m1' <- normBind assump_rhs k1 - mrRefines m1' m2 - - -- If f1 unfolds and is not recursive in itself, unfold it and recurse - _ | Just (f1_body, False) <- maybe_f1_body -> - normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 - - -- If f2 unfolds and is not recursive in itself, unfold it and recurse - _ | Just (f2_body, False) <- maybe_f2_body -> - normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' - - -- If we do not already have an assumption that f1 refines some specification, - -- and both f1 and f2 are recursive but have the same return type, then try to - -- coinductively prove that f1 args1 |= f2 args2 under the assumption that f1 - -- args1 |= f2 args2, and then try to prove that k1 |= k2 - Nothing - | tps_eq - , Just (f1_body, _) <- maybe_f1_body - , Just (f2_body, _) <- maybe_f2_body -> - do withFunAssump f1 args1 (FunBind f2 args2 CompFunReturn) $ - mrRefines f1_body f2_body - mrRefinesFun k1 k2 - - -- If we cannot line up f1 and f2, then making progress here would require us - -- to somehow split either m1 or m2 into some bind m' >>= k' such that m' is - -- related to the function call on the other side and k' is related to the - -- continuation on the other side, but we don't know how to do that, so give - -- up - Nothing -> - throwError (CompsDoNotRefine m1 m2) - -{- FIXME: handle FunBind on just one side -mrRefines' m1@(FunBind f@(GlobalName _) args k1) m2 = - mrGetFunAssump f >>= \case - Just fassump -> - -- If we have an assumption that f args' refines some rhs, then prove that - -- args = args' and then that rhs refines m2 - do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrProveEq assump_args args - m1' <- normBind assump_rhs k1 - mrRefines m1' m2 - Nothing -> - -- We don't want to do inter-procedural proofs, so if we don't know anything - -- about f already then give up - throwError (CompsDoNotRefine m1 m2) --} - - -mrRefines' m1@(FunBind f1 args1 k1) m2 = - mrGetFunAssump f1 >>= \case - - -- If we have an assumption that f1 args' refines some rhs, then prove that - -- args1 = args' and then that rhs refines m2 - Just fassump -> - do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrProveEq assump_args args1 - m1' <- normBind assump_rhs k1 - mrRefines m1' m2 - - -- Otherwise, see if we can unfold f1 - Nothing -> - mrFunBodyRecInfo f1 args1 >>= \case - - -- If f1 unfolds and is not recursive in itself, unfold it and recurse - Just (f1_body, False) -> - normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 - - -- Otherwise we would have to somehow split m2 into some computation of the - -- form m2' >>= k2 where f1 args1 |= m2' and k1 |= k2, but we don't know how - -- to do this splitting, so give up - _ -> - throwError (CompsDoNotRefine m1 m2) - - -mrRefines' m1 m2@(FunBind f2 args2 k2) = - mrFunBodyRecInfo f2 args2 >>= \case - - -- If f2 unfolds and is not recursive in itself, unfold it and recurse - Just (f2_body, False) -> - normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' - - -- If f2 unfolds but is recursive, and k2 is the trivial continuation, meaning - -- m2 is just f2 args2, use the law of coinduction to prove m1 |= f2 args2 by - -- proving m1 |= f2_body under the assumption that m1 |= f2 args2 - {- FIXME: implement something like this - Just (f2_body, True) - | CompFunReturn <- k2 -> - withFunAssumpR m1 f2 args2 $ - -} - - -- Otherwise we would have to somehow split m1 into some computation of the - -- form m1' >>= k1 where m1' |= f2 args2 and k1 |= k2, but we don't know how - -- to do this splitting, so give up - _ -> - throwError (CompsDoNotRefine m1 m2) - - --- NOTE: the rules that introduce existential variables need to go last, so that --- they can quantify over as many universals as possible -mrRefines' m1 (ExistsM tp f2) = - do let nm = maybe "x" id (compFunVarName f2) - evar <- mrFreshEVar nm tp - m2' <- applyNormCompFun f2 evar - mrRefines m1 m2' -mrRefines' (ForallM tp f1) m2 = - do let nm = maybe "x" id (compFunVarName f1) - evar <- mrFreshEVar nm tp - m1' <- applyNormCompFun f1 evar - mrRefines m1' m2 - --- If none of the above cases match, then fail -mrRefines' m1 m2 = throwError (CompsDoNotRefine m1 m2) - - --- | Prove that one function refines another for all inputs -mrRefinesFun :: CompFun -> CompFun -> MRM () -mrRefinesFun CompFunReturn CompFunReturn = return () -mrRefinesFun f1 f2 - | Just nm <- compFunVarName f1 `mplus` compFunVarName f2 - , Just tp <- compFunInputType f1 `mplus` compFunInputType f2 = - withUVarLift nm tp (f1,f2) $ \x (f1', f2') -> - do m1' <- applyNormCompFun f1' x - m2' <- applyNormCompFun f2' x - mrRefines m1' m2' -mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" - - ----------------------------------------------------------------------- --- * External Entrypoints ----------------------------------------------------------------------- - --- | Test two monadic, recursive terms for equivalence -askMRSolver :: - SharedContext -> - Int {- ^ The debug level -} -> - SBV.SMTConfig {- ^ SBV configuration -} -> - Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> - Term -> Term -> IO (Maybe MRFailure) + (askMRSolver, MRFailure(..), showMRFailure, isCompFunType, + MREnv(..), emptyMREnv) where -askMRSolver sc dlvl smt_conf timeout t1 t2 = - do tp1 <- scTypeOf sc t1 - tp2 <- scTypeOf sc t2 - init_st <- mkMRState sc Map.empty smt_conf timeout dlvl - case asPiList tp1 of - (uvar_ctx, asCompM -> Just _) -> - fmap (either Just (const Nothing)) $ runMRM init_st $ - withUVars uvar_ctx $ \vars -> - do tps_are_eq <- mrConvertible tp1 tp2 - if tps_are_eq then return () else - throwError (TypesNotEq (Type tp1) (Type tp2)) - mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 - m1 <- mrApplyAll t1 vars >>= normCompTerm - m2 <- mrApplyAll t2 vars >>= normCompTerm - mrRefines m1 m2 - _ -> return $ Just $ NotCompFunType tp1 +import SAWScript.Prover.MRSolver.Term +import SAWScript.Prover.MRSolver.Monad +import SAWScript.Prover.MRSolver.Solver diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs new file mode 100644 index 0000000000..71e79735ba --- /dev/null +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -0,0 +1,967 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} + +{- | +Module : SAWScript.Prover.MRSolver.Monad +Copyright : Galois, Inc. 2022 +License : BSD3 +Maintainer : westbrook@galois.com +Stability : experimental +Portability : non-portable (language extensions) + +This module defines the monad used by Mr. Solver ('MRM') as well as the core +monadic combinators for operating on terms. +-} + +module SAWScript.Prover.MRSolver.Monad where + +import Data.List (find, findIndex) +import qualified Data.Text as T +import System.IO (hPutStrLn, stderr) +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Except +import Control.Monad.Trans.Maybe +import GHC.Generics + +import Data.Map (Map) +import qualified Data.Map as Map + +import Data.HashMap.Lazy (HashMap) +import qualified Data.HashMap.Lazy as HashMap + +import Prettyprinter + +import Verifier.SAW.Term.Functor +import Verifier.SAW.Term.CtxTerm (MonadTerm(..)) +import Verifier.SAW.Term.Pretty +import Verifier.SAW.SCTypeCheck +import Verifier.SAW.SharedTerm +import Verifier.SAW.Recognizer +import Verifier.SAW.Cryptol.Monadify + +import SAWScript.Prover.MRSolver.Term + + +---------------------------------------------------------------------- +-- * MR Solver Errors +---------------------------------------------------------------------- + +-- | The context in which a failure occurred +data FailCtx + = FailCtxRefines NormComp NormComp + | FailCtxMNF Term + deriving Show + +-- | That's MR. Failure to you +data MRFailure + = TermsNotEq Term Term + | TypesNotEq Type Type + | CompsDoNotRefine NormComp NormComp + | ReturnNotError Term + | FunsNotEq FunName FunName + | CannotLookupFunDef FunName + | RecursiveUnfold FunName + | MalformedLetRecTypes Term + | MalformedDefsFun Term + | MalformedComp Term + | NotCompFunType Term + | PrecondNotProvable FunName FunName Term + -- | A local variable binding + | MRFailureLocalVar LocalName MRFailure + -- | Information about the context of the failure + | MRFailureCtx FailCtx MRFailure + -- | Records a disjunctive branch we took, where both cases failed + | MRFailureDisj MRFailure MRFailure + deriving Show + +-- | Pretty-print an object prefixed with a 'String' that describes it +ppWithPrefix :: PrettyInCtx a => String -> a -> PPInCtxM SawDoc +ppWithPrefix str a = (pretty str <>) <$> nest 2 <$> (line <>) <$> prettyInCtx a + +-- | Pretty-print two objects, prefixed with a 'String' and with a separator +ppWithPrefixSep :: (PrettyInCtx a, PrettyInCtx b) => + String -> a -> String -> b -> PPInCtxM SawDoc +ppWithPrefixSep d1 t2 d3 t4 = + prettyInCtx t2 >>= \d2 -> prettyInCtx t4 >>= \d4 -> + return $ group (pretty d1 <> nest 2 (line <> d2) <> line <> + pretty d3 <> nest 2 (line <> d4)) + +-- | Apply 'vsep' to a list of pretty-printing computations +vsepM :: [PPInCtxM SawDoc] -> PPInCtxM SawDoc +vsepM = fmap vsep . sequence + +instance PrettyInCtx FailCtx where + prettyInCtx (FailCtxRefines m1 m2) = + group <$> nest 2 <$> + ppWithPrefixSep "When proving refinement:" m1 "|=" m2 + prettyInCtx (FailCtxMNF t) = + group <$> nest 2 <$> vsepM [return "When normalizing computation:", + prettyInCtx t] + +instance PrettyInCtx MRFailure where + prettyInCtx (TermsNotEq t1 t2) = + ppWithPrefixSep "Could not prove terms equal:" t1 "and" t2 + prettyInCtx (TypesNotEq tp1 tp2) = + ppWithPrefixSep "Types not equal:" tp1 "and" tp2 + prettyInCtx (CompsDoNotRefine m1 m2) = + ppWithPrefixSep "Could not prove refinement: " m1 "|=" m2 + prettyInCtx (ReturnNotError t) = + ppWithPrefix "errorM computation not equal to:" (ReturnM t) + prettyInCtx (FunsNotEq nm1 nm2) = + vsepM [return "Named functions not equal:", + prettyInCtx nm1, prettyInCtx nm2] + prettyInCtx (CannotLookupFunDef nm) = + ppWithPrefix "Could not find definition for function:" nm + prettyInCtx (RecursiveUnfold nm) = + ppWithPrefix "Recursive unfolding of function inside its own body:" nm + prettyInCtx (MalformedLetRecTypes t) = + ppWithPrefix "Not a ground LetRecTypes list:" t + prettyInCtx (MalformedDefsFun t) = + ppWithPrefix "Cannot handle letRecM recursive definitions term:" t + prettyInCtx (MalformedComp t) = + ppWithPrefix "Could not handle computation:" t + prettyInCtx (NotCompFunType tp) = + ppWithPrefix "Not a computation or computational function type:" tp + prettyInCtx (PrecondNotProvable f g pre) = + prettyAppList [return "Could not prove precondition for functions", + prettyInCtx f, return "and", prettyInCtx g, + return ":", prettyInCtx pre] + prettyInCtx (MRFailureLocalVar x err) = + local (x:) $ prettyInCtx err + prettyInCtx (MRFailureCtx ctx err) = + do pp1 <- prettyInCtx ctx + pp2 <- prettyInCtx err + return (pp1 <> line <> pp2) + prettyInCtx (MRFailureDisj err1 err2) = + ppWithPrefixSep "Tried two comparisons:" err1 "Backtracking..." err2 + +-- | Render a 'MRFailure' to a 'String' +showMRFailure :: MRFailure -> String +showMRFailure = showInCtx [] + + +---------------------------------------------------------------------- +-- * MR Monad +---------------------------------------------------------------------- + +-- | Classification info for what sort of variable an 'MRVar' is +data MRVarInfo + -- | An existential variable, that might be instantiated + = EVarInfo (Maybe Term) + -- | A letrec-bound function, with its body + | FunVarInfo Term + +-- | A map from 'MRVar's to their info +type MRVarMap = Map MRVar MRVarInfo + +-- | Test if a 'Term' is an application of an 'ExtCns' to some arguments +asExtCnsApp :: Recognizer Term (ExtCns Term, [Term]) +asExtCnsApp (asApplyAll -> (asExtCns -> Just ec, args)) = + return (ec, args) +asExtCnsApp _ = Nothing + +-- | Recognize an evar applied to 0 or more arguments relative to a 'MRVarMap' +-- along with its instantiation, if any +asEVarApp :: MRVarMap -> Recognizer Term (MRVar, [Term], Maybe Term) +asEVarApp var_map (asExtCnsApp -> Just (ec, args)) + | Just (EVarInfo maybe_inst) <- Map.lookup (MRVar ec) var_map = + Just (MRVar ec, args, maybe_inst) +asEVarApp _ _ = Nothing + +-- | A co-inductive hypothesis of the form: +-- +-- > forall x1, ..., xn. F y1 ... ym |= G z1 ... zl +-- +-- for some universal context @x1:T1, ..., xn:Tn@ and some lists of argument +-- expressions @y1, ..., ym@ and @z1, ..., zl@ over the universal context. +data CoIndHyp = CoIndHyp { + -- | The uvars that were in scope when this assmption was created, in order + -- from outermost to innermost; that is, the uvars as "seen from outside their + -- scope", which is the reverse of the order of 'mrUVars', below + coIndHypCtx :: [(LocalName,Term)], + -- | The LHS function name + coIndHypLHSFun :: FunName, + -- | The RHS function name + coIndHypRHSFun :: FunName, + -- | The LHS argument expressions @y1, ..., ym@ over the 'coIndHypCtx' uvars + coIndHypLHS :: [Term], + -- | The RHS argument expressions @y1, ..., ym@ over the 'coIndHypCtx' uvars + coIndHypRHS :: [Term], + -- | The precondition for the left-hand arguments, as a closed function from + -- the left-hand arguments to @Bool@ + coIndHypPrecondLHS :: Maybe Term, + -- | The precondition for the right-hand arguments, as a closed function from + -- the left-hand arguments to @Bool@ + coIndHypPrecondRHS :: Maybe Term +} deriving Show + +-- | Extract the @i@th argument on either the left- or right-hand side of a +-- coinductive hypothesis +coIndHypArg :: CoIndHyp -> Either Int Int -> Term +coIndHypArg hyp (Left i) = (coIndHypLHS hyp) !! i +coIndHypArg hyp (Right i) = (coIndHypRHS hyp) !! i + +-- | A map from pairs of function names to co-inductive hypotheses over those +-- names +type CoIndHyps = Map (FunName, FunName) CoIndHyp + +instance PrettyInCtx CoIndHyp where + prettyInCtx (CoIndHyp ctx f1 f2 args1 args2 pre1 pre2) = + local (const $ map fst $ reverse ctx) $ + prettyAppList [return (ppCtx ctx <> "."), + (case pre1 of + Just f -> prettyTermApp f args1 + Nothing -> return "True"), return "=>", + (case pre2 of + Just f -> prettyTermApp f args2 + Nothing -> return "True"), return "=>", + prettyInCtx (FunBind f1 args1 CompFunReturn), + return "|=", + prettyInCtx (FunBind f2 args2 CompFunReturn)] + +-- | An assumption that something is equal to one of the constructors of a +-- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term' +data DataTypeAssump + = IsLeft Term | IsRight Term | IsNum Term | IsInf + deriving (Generic, Show, TermLike) + +instance PrettyInCtx DataTypeAssump where + prettyInCtx (IsLeft x) = prettyInCtx x >>= ppWithPrefix "Left _ _" + prettyInCtx (IsRight x) = prettyInCtx x >>= ppWithPrefix "Right _ _" + prettyInCtx (IsNum x) = prettyInCtx x >>= ppWithPrefix "TCNum" + prettyInCtx IsInf = return "TCInf" + +-- | Recognize a term as a @Left@ or @Right@ +asEither :: Recognizer Term (Either Term Term) +asEither (asCtor -> Just (c, [_, _, x])) + | primName c == "Prelude.Left" = return $ Left x + | primName c == "Prelude.Right" = return $ Right x +asEither _ = Nothing + +-- | Recognize a term as a @TCNum n@ or @TCInf@ +asNum :: Recognizer Term (Either Term ()) +asNum (asCtor -> Just (c, [n])) + | primName c == "Cryptol.TCNum" = return $ Left n +asNum (asCtor -> Just (c, [])) + | primName c == "Cryptol.TCInf" = return $ Right () +asNum _ = Nothing + +-- | Recognize a term as being of the form @isFinite n@ +asIsFinite :: Recognizer Term Term +asIsFinite (asApp -> Just (isGlobalDef "CryptolM.isFinite" -> Just (), n)) = + Just n +asIsFinite _ = Nothing + +-- | A map from 'Term's to 'DataTypeAssump's over that term +type DataTypeAssumps = HashMap Term DataTypeAssump + +-- | Parameters and locals for MR. Solver +data MRInfo = MRInfo { + -- | Global shared context for building terms, etc. + mriSC :: SharedContext, + -- | SMT timeout for SMT calls made by Mr. Solver + mriSMTTimeout :: Maybe Integer, + -- | The current context of universal variables, which are free SAW core + -- variables, in order from innermost to outermost, i.e., where element @0@ + -- corresponds to deBruijn index @0@ + mriUVars :: [(LocalName,Type)], + -- | The top-level Mr Solver environment + mriEnv :: MREnv, + -- | The current set of co-inductive hypotheses + mriCoIndHyps :: CoIndHyps, + -- | The current assumptions, which are conjoined into a single Boolean term; + -- note that these have the current UVars free + mriAssumptions :: Term, + -- | The current set of 'DataTypeAssump's + mriDataTypeAssumps :: DataTypeAssumps, + -- | The debug level, which controls debug printing + mriDebugLevel :: Int +} + +-- | State maintained by MR. Solver +data MRState = MRState { + -- | The existential and letrec-bound variables + mrsVars :: MRVarMap +} + +-- | The exception type for MR. Solver, which is either a 'MRFailure' or a +-- widening request +data MRExn = MRExnFailure MRFailure + | MRExnWiden FunName FunName [Either Int Int] + deriving Show + +-- | Mr. Monad, the monad used by MR. Solver, which has 'MRInfo' as as a +-- shared environment, 'MRState' as state, and 'MRFailure' as an exception +-- type, all over an 'IO' monad +newtype MRM a = MRM { unMRM :: ReaderT MRInfo (StateT MRState + (ExceptT MRExn IO)) a } + deriving newtype (Functor, Applicative, Monad, MonadIO, + MonadReader MRInfo, MonadState MRState, + MonadError MRExn) + +instance MonadTerm MRM where + mkTermF = liftSC1 scTermF + liftTerm = liftSC3 incVars + whnfTerm = liftSC1 scWhnf + substTerm = liftSC3 instantiateVarList + +-- | Get the current value of 'mriSC' +mrSC :: MRM SharedContext +mrSC = mriSC <$> ask + +-- | Get the current value of 'mriSMTTimeout' +mrSMTTimeout :: MRM (Maybe Integer) +mrSMTTimeout = mriSMTTimeout <$> ask + +-- | Get the current value of 'mriUVars' +mrUVars :: MRM [(LocalName,Type)] +mrUVars = mriUVars <$> ask + +-- | Get the current function assumptions +mrFunAssumps :: MRM FunAssumps +mrFunAssumps = mreFunAssumps <$> mriEnv <$> ask + +-- | Get the current value of 'mriCoIndHyps' +mrCoIndHyps :: MRM CoIndHyps +mrCoIndHyps = mriCoIndHyps <$> ask + +-- | Get the current value of 'mriAssumptions' +mrAssumptions :: MRM Term +mrAssumptions = mriAssumptions <$> ask + +-- | Get the current value of 'mriDataTypeAssumps' +mrDataTypeAssumps :: MRM DataTypeAssumps +mrDataTypeAssumps = mriDataTypeAssumps <$> ask + +-- | Get the current value of 'mriDebugLevel' +mrDebugLevel :: MRM Int +mrDebugLevel = mriDebugLevel <$> ask + +-- | Get the current value of 'mrsVars' +mrVars :: MRM MRVarMap +mrVars = mrsVars <$> get + +-- | Run an 'MRM' computation and return a result or an error +runMRM :: SharedContext -> Maybe Integer -> Int -> MREnv -> + MRM a -> IO (Either MRFailure a) +runMRM sc timeout debug env m = + do true_tm <- scBool sc True + let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout, + mriDebugLevel = debug, mriEnv = env, + mriUVars = [], mriCoIndHyps = Map.empty, + mriAssumptions = true_tm, + mriDataTypeAssumps = HashMap.empty } + let init_st = MRState { mrsVars = Map.empty } + res <- runExceptT $ flip evalStateT init_st $ + flip runReaderT init_info $ unMRM m + case res of + Right a -> return $ Right a + Left (MRExnFailure failure) -> return $ Left failure + Left exn -> fail ("runMRM: unexpected internal exception: " ++ show exn) + +-- | Throw an 'MRFailure' +throwMRFailure :: MRFailure -> MRM a +throwMRFailure = throwError . MRExnFailure + +-- | Apply a function to any failure thrown by an 'MRM' computation +mapMRFailure :: (MRFailure -> MRFailure) -> MRM a -> MRM a +mapMRFailure f m = catchError m $ \case + MRExnFailure failure -> throwError $ MRExnFailure $ f failure + e -> throwError e + +-- | Catch any 'MRFailure' raised by a computation +catchFailure :: MRM a -> (MRFailure -> MRM a) -> MRM a +catchFailure m f = + m `catchError` \case + MRExnFailure failure -> f failure + e -> throwError e + +-- | Try two different 'MRM' computations, combining their failures if needed. +-- Note that the 'MRState' will reset if the first computation fails. +mrOr :: MRM a -> MRM a -> MRM a +mrOr m1 m2 = + catchFailure m1 $ \err1 -> + catchFailure m2 $ \err2 -> + throwMRFailure $ MRFailureDisj err1 err2 + +-- | Run an 'MRM' computation in an extended failure context +withFailureCtx :: FailCtx -> MRM a -> MRM a +withFailureCtx ctx = mapMRFailure (MRFailureCtx ctx) + +{- +-- | Catch any errors thrown by a computation and coerce them to a 'Left' +catchErrorEither :: MonadError e m => m a -> m (Either e a) +catchErrorEither m = catchError (Right <$> m) (return . Left) +-} + +-- FIXME: replace these individual lifting functions with a more general +-- typeclass like LiftTCM + +-- | Lift a nullary SharedTerm computation into 'MRM' +liftSC0 :: (SharedContext -> IO a) -> MRM a +liftSC0 f = mrSC >>= \sc -> liftIO (f sc) + +-- | Lift a unary SharedTerm computation into 'MRM' +liftSC1 :: (SharedContext -> a -> IO b) -> a -> MRM b +liftSC1 f a = mrSC >>= \sc -> liftIO (f sc a) + +-- | Lift a binary SharedTerm computation into 'MRM' +liftSC2 :: (SharedContext -> a -> b -> IO c) -> a -> b -> MRM c +liftSC2 f a b = mrSC >>= \sc -> liftIO (f sc a b) + +-- | Lift a ternary SharedTerm computation into 'MRM' +liftSC3 :: (SharedContext -> a -> b -> c -> IO d) -> a -> b -> c -> MRM d +liftSC3 f a b c = mrSC >>= \sc -> liftIO (f sc a b c) + +-- | Lift a quaternary SharedTerm computation into 'MRM' +liftSC4 :: (SharedContext -> a -> b -> c -> d -> IO e) -> a -> b -> c -> d -> + MRM e +liftSC4 f a b c d = mrSC >>= \sc -> liftIO (f sc a b c d) + +-- | Lift a quinary SharedTerm computation into 'MRM' +liftSC5 :: (SharedContext -> a -> b -> c -> d -> e -> IO f) -> + a -> b -> c -> d -> e -> MRM f +liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) + + +---------------------------------------------------------------------- +-- * Monadic Operations on Terms +---------------------------------------------------------------------- + +-- | Apply a 'TermProj' to perform a projection on a 'Term' +doTermProj :: Term -> TermProj -> MRM Term +doTermProj t TermProjLeft = liftSC1 scPairLeft t +doTermProj t TermProjRight = liftSC1 scPairRight t +doTermProj t (TermProjRecord fld) = liftSC2 scRecordSelect t fld + +-- | Apply a 'TermProj' to a type to get the output type of the projection, +-- assuming that the type is already normalized +doTypeProj :: Term -> TermProj -> MRM Term +doTypeProj (asPairType -> Just (tp1, _)) TermProjLeft = return tp1 +doTypeProj (asPairType -> Just (_, tp2)) TermProjRight = return tp2 +doTypeProj (asRecordType -> Just tp_map) (TermProjRecord fld) + | Just tp <- Map.lookup fld tp_map + = return tp +doTypeProj _ _ = + -- FIXME: better error message? This is an error and not an MRFailure because + -- we should only be projecting types for terms that we have already seen... + error "doTypeProj" + +-- | Get and normalize the type of a 'FunName' +funNameType :: FunName -> MRM Term +funNameType (LetRecName var) = liftSC1 scWhnf $ mrVarType var +funNameType (EVarFunName var) = liftSC1 scWhnf $ mrVarType var +funNameType (GlobalName gd projs) = + liftSC1 scWhnf (globalDefType gd) >>= \gd_tp -> + foldM doTypeProj gd_tp projs + +-- | Apply a 'Term' to a list of arguments and beta-reduce in Mr. Monad +mrApplyAll :: Term -> [Term] -> MRM Term +mrApplyAll f args = liftSC2 scApplyAllBeta f args + +-- | Get the current context of uvars as a list of variable names and their +-- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in +-- the order as seen "from the outside" +mrUVarCtx :: MRM [(LocalName,Term)] +mrUVarCtx = reverse <$> mrUVarCtxRev + +-- | Get the current context of uvars as a list of variable names and their +-- types as SAW core 'Term's, with the most recently bound uvar first, i.e., in +-- the order as seen "from the inside" +mrUVarCtxRev :: MRM [(LocalName,Term)] +mrUVarCtxRev = map (\(nm,Type tp) -> (nm,tp)) <$> mrUVars + +-- | Get the type of a 'Term' in the current uvar context +mrTypeOf :: Term -> MRM Term +mrTypeOf t = + -- NOTE: scTypeOf' wants the type context in the most recently bound var + -- first, i.e., in the mrUVarCtxRev order + mrDebugPPPrefix 2 "mrTypeOf:" t >> + mrUVarCtxRev >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t + +-- | Check if two 'Term's are convertible in the 'MRM' monad +mrConvertible :: Term -> Term -> MRM Bool +mrConvertible = liftSC4 scConvertibleEval scTypeCheckWHNF True + +-- | Take a 'FunName' @f@ for a monadic function of type @vars -> CompM a@ and +-- compute the type @CompM [args/vars]a@ of @f@ applied to @args@. Return the +-- type @[args/vars]a@ that @CompM@ is applied to. +mrFunOutType :: FunName -> [Term] -> MRM Term +mrFunOutType fname args = + funNameType fname >>= \case + (asPiList -> (vars, asCompM -> Just tp)) + | length vars == length args -> substTermLike 0 (reverse args) tp + ftype@(asPiList -> (vars, _)) -> + do pp_ftype <- mrPPInCtx ftype + pp_fname <- mrPPInCtx fname + debugPrint 0 "mrFunOutType: function applied to the wrong number of args" + debugPrint 0 ("Expected: " ++ show (length vars) ++ + ", found: " ++ show (length args)) + debugPretty 0 ("For function: " <> pp_fname <> " with type: " <> pp_ftype) + error "mrFunOutType" + +-- | Turn a 'LocalName' into one not in a list, adding a suffix if necessary +uniquifyName :: LocalName -> [LocalName] -> LocalName +uniquifyName nm nms | notElem nm nms = nm +uniquifyName nm nms = + case find (flip notElem nms) $ + map (T.append nm . T.pack . show) [(0::Int) ..] of + Just nm' -> nm' + Nothing -> error "uniquifyName" + +-- | Turn a list of 'LocalName's into one names not in a list, adding suffixes +-- if necessary +uniquifyNames :: [LocalName] -> [LocalName] -> [LocalName] +uniquifyNames [] _ = [] +uniquifyNames (nm:nms) nms_other = + let nm' = uniquifyName nm nms_other in + nm' : uniquifyNames nms (nm' : nms_other) + +-- | Run a MR Solver computation in a context extended with a universal +-- variable, which is passed as a 'Term' to the sub-computation. Note that any +-- assumptions made in the sub-computation will be lost when it completes. +withUVar :: LocalName -> Type -> (Term -> MRM a) -> MRM a +withUVar nm (Type tp) m = withUVars [(nm,tp)] (\[v] -> m v) + +-- | Run a MR Solver computation in a context extended with a universal variable +-- and pass it the lifting (in the sense of 'incVars') of an MR Solver term +withUVarLift :: TermLike tm => LocalName -> Type -> tm -> + (Term -> tm -> MRM a) -> MRM a +withUVarLift nm tp t m = + withUVar nm tp (\x -> liftTermLike 0 1 t >>= m x) + +-- | Run a MR Solver computation in a context extended with a list of universal +-- variables, passing 'Term's for those variables to the supplied computation. +-- The variables are bound "outside in", meaning the first variable in the list +-- is bound outermost, and so will have the highest deBruijn index. +withUVars :: [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a +withUVars [] f = f [] +withUVars ctx f = + do nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVars + let ctx_u = zip nms $ map (Type . snd) ctx + assumps' <- mrAssumptions >>= liftTerm 0 (length ctx) + dataTypeAssumps' <- mrDataTypeAssumps >>= mapM (liftTermLike 0 (length ctx)) + vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. length ctx - 1] + local (\info -> info { mriUVars = reverse ctx_u ++ mriUVars info, + mriAssumptions = assumps', + mriDataTypeAssumps = dataTypeAssumps' }) $ + foldr (\nm m -> mapMRFailure (MRFailureLocalVar nm) m) (f vars) nms + +-- | Run a MR Solver in a top-level context, i.e., with no uvars or assumptions +withNoUVars :: MRM a -> MRM a +withNoUVars m = + do true_tm <- liftSC1 scBool True + local (\info -> info { mriUVars = [], mriAssumptions = true_tm, + mriDataTypeAssumps = HashMap.empty }) m + +-- | Run a MR Solver in a context of only the specified UVars, no others +withOnlyUVars :: [(LocalName,Term)] -> MRM a -> MRM a +withOnlyUVars vars m = withNoUVars $ withUVars vars $ const m + +-- | Build 'Term's for all the uvars currently in scope, ordered from least to +-- most recently bound +getAllUVarTerms :: MRM [Term] +getAllUVarTerms = + (length <$> mrUVars) >>= \len -> + mapM (liftSC1 scLocalVar) [len-1, len-2 .. 0] + +-- | Lambda-abstract all the current uvars out of a 'Term', with the least +-- recently bound variable being abstracted first +lambdaUVarsM :: Term -> MRM Term +lambdaUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scLambdaList ctx t + +-- | Pi-abstract all the current uvars out of a 'Term', with the least recently +-- bound variable being abstracted first +piUVarsM :: Term -> MRM Term +piUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scPiList ctx t + +-- | Instantiate all uvars in a term using the supplied function +instantiateUVarsM :: TermLike a => (LocalName -> Term -> MRM Term) -> a -> MRM a +instantiateUVarsM f a = + do ctx <- mrUVarCtx + -- Remember: the uvar context is outermost to innermost, so we bind + -- variables from left to right, substituting earlier ones into the types + -- of later ones, but all substitutions are in reverse order, since + -- substTerm and friends like innermost bindings first + let helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] + helper tms [] = return tms + helper tms ((nm,tp):vars) = + do tp' <- substTerm 0 tms tp + tm <- f nm tp' + helper (tm:tms) vars + ecs <- helper [] ctx + substTermLike 0 ecs a + +-- | Convert an 'MRVar' to a 'Term', applying it to all the uvars in scope +mrVarTerm :: MRVar -> MRM Term +mrVarTerm (MRVar ec) = + do var_tm <- liftSC1 scExtCns ec + vars <- getAllUVarTerms + liftSC2 scApplyAll var_tm vars + +-- | Create a dummy proof term of the specified type, which can be open but +-- should be of @Prop@ sort, by creating an 'ExtCns' axiom. This is sound as +-- long as we only use the resulting term in computation branches where we know +-- the proposition holds. +mrDummyProof :: Term -> MRM Term +mrDummyProof tp = piUVarsM tp >>= mrFreshVar "pf" >>= mrVarTerm + +-- | Get the 'VarInfo' associated with a 'MRVar' +mrVarInfo :: MRVar -> MRM (Maybe MRVarInfo) +mrVarInfo var = Map.lookup var <$> mrVars + +-- | Convert an 'ExtCns' to a 'FunName' +extCnsToFunName :: ExtCns Term -> MRM FunName +extCnsToFunName ec = let var = MRVar ec in mrVarInfo var >>= \case + Just (EVarInfo _) -> return $ EVarFunName var + Just (FunVarInfo _) -> return $ LetRecName var + Nothing + | Just glob <- asTypedGlobalDef (Unshared $ FTermF $ ExtCns ec) -> + return $ GlobalName glob [] + _ -> error "extCnsToFunName: unreachable" + +-- | Get the body of a function @f@ if it has one +mrFunNameBody :: FunName -> MRM (Maybe Term) +mrFunNameBody (LetRecName var) = + mrVarInfo var >>= \case + Just (FunVarInfo body) -> return $ Just body + _ -> error "mrFunBody: unknown letrec var" +mrFunNameBody (GlobalName glob projs) + | Just body <- globalDefBody glob + = Just <$> foldM doTermProj body projs +mrFunNameBody (GlobalName _ _) = return Nothing +mrFunNameBody (EVarFunName _) = return Nothing + +-- | Get the body of a function @f@ applied to some arguments, if possible +mrFunBody :: FunName -> [Term] -> MRM (Maybe Term) +mrFunBody f args = mrFunNameBody f >>= \case + Just body -> Just <$> mrApplyAll body args + Nothing -> return Nothing + +-- | Get the body of a function @f@ applied to some arguments, as per +-- 'mrFunBody', and also return whether its body recursively calls itself, as +-- per 'mrCallsFun' +mrFunBodyRecInfo :: FunName -> [Term] -> MRM (Maybe (Term, Bool)) +mrFunBodyRecInfo f args = + mrFunBody f args >>= \case + Just f_body -> Just <$> (f_body,) <$> mrCallsFun f f_body + Nothing -> return Nothing + +-- | Test if a 'Term' contains, after possibly unfolding some functions, a call +-- to a given function @f@ again +mrCallsFun :: FunName -> Term -> MRM Bool +mrCallsFun f = memoFixTermFun $ \recurse t -> case t of + (asExtCns -> Just ec) -> + do g <- extCnsToFunName ec + maybe_body <- mrFunNameBody g + case maybe_body of + _ | f == g -> return True + Just body -> recurse body + Nothing -> return False + (asTypedGlobalProj -> Just (gdef, projs)) -> + case globalDefBody gdef of + _ | f == GlobalName gdef projs -> return True + Just body -> recurse body + Nothing -> return False + (unwrapTermF -> tf) -> + foldM (\b t' -> if b then return b else recurse t') False tf + +-- | Make a fresh 'MRVar' of a given type, which must be closed, i.e., have no +-- free uvars +mrFreshVar :: LocalName -> Term -> MRM MRVar +mrFreshVar nm tp = MRVar <$> liftSC2 scFreshEC nm tp + +-- | Set the info associated with an 'MRVar', assuming it has not been set +mrSetVarInfo :: MRVar -> MRVarInfo -> MRM () +mrSetVarInfo var info = + modify $ \st -> + st { mrsVars = + Map.alter (\case + Just _ -> error "mrSetVarInfo" + Nothing -> Just info) + var (mrsVars st) } + +-- | Make a fresh existential variable of the given type, abstracting out all +-- the current uvars and returning the new evar applied to all current uvars +mrFreshEVar :: LocalName -> Type -> MRM Term +mrFreshEVar nm (Type tp) = + do tp' <- piUVarsM tp + var <- mrFreshVar nm tp' + mrSetVarInfo var (EVarInfo Nothing) + mrVarTerm var + +-- | Return a fresh sequence of existential variables for a context of variable +-- names and types, assuming each variable is free in the types that occur after +-- it in the list. Return the new evars all applied to the current uvars. +mrFreshEVars :: [(LocalName,Term)] -> MRM [Term] +mrFreshEVars = helper [] where + -- Return fresh evars for the suffix of a context of variable names and types, + -- where the supplied Terms are evars that have already been generated for the + -- earlier part of the context, and so must be substituted into the remaining + -- types in the context + helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] + helper evars [] = return evars + helper evars ((nm,tp):ctx) = + do evar <- substTerm 0 evars tp >>= mrFreshEVar nm . Type + helper (evar:evars) ctx + +-- | Set the value of an evar to a closed term +mrSetEVarClosed :: MRVar -> Term -> MRM () +mrSetEVarClosed var val = + do val_tp <- mrTypeOf val + -- NOTE: need to instantiate any evars in the type of var, to ensure the + -- following subtyping check will succeed + var_tp <- mrSubstEVars $ mrVarType var + -- FIXME: catch subtyping errors and report them as being evar failures + eith_err <- + liftSC2 (runTCM $ checkSubtype (TypedTerm val val_tp) var_tp) Nothing [] + case eith_err of + Left _ -> + error ("mrSetEVarClosed: incorrect instantiation for evar " ++ + showMRVar var ++ + "\nexpected type:\n" ++ showTerm var_tp ++ + "\nactual type:\n" ++ showTerm val_tp) + Right _ -> return () + modify $ \st -> + st { mrsVars = + Map.alter + (\case + Just (EVarInfo Nothing) -> Just $ EVarInfo (Just val) + Just (EVarInfo (Just _)) -> + error "Setting existential variable: variable already set!" + _ -> error "Setting existential variable: not an evar!") + var (mrsVars st) } + + +-- | Try to set the value of the application @X e1 .. en@ of evar @X@ to an +-- expression @e@ by trying to set @X@ to @\ x1 ... xn -> e@. This only works if +-- each free uvar @xi@ in @e@ is one of the arguments @ej@ to @X@ (though it +-- need not be the case that @i=j@). Return whether this succeeded. +mrTrySetAppliedEVar :: MRVar -> [Term] -> Term -> MRM Bool +mrTrySetAppliedEVar evar args t = + -- Get the complete list of argument variables of the type of evar + let (evar_vars, _) = asPiList (mrVarType evar) in + -- Get all the free variables of t + let free_vars = bitSetElems (looseVars t) in + -- For each free var of t, find an arg equal to it + case mapM (\i -> findIndex (\case + (asLocalVar -> Just j) -> i == j + _ -> False) args) free_vars of + Just fv_arg_ixs + -- Check to make sure we have the right number of args + | length args == length evar_vars -> do + -- Build a list of the input vars x1 ... xn as terms, noting that the + -- first variable is the least recently bound and so has the highest + -- deBruijn index + let arg_ixs = reverse [0 .. length args - 1] + arg_vars <- mapM (liftSC1 scLocalVar) arg_ixs + + -- For each free variable of t, we substitute the corresponding + -- variable xi, substituting error terms for the variables that are + -- not free (since we have nothing else to substitute for them) + let var_map = zip free_vars fv_arg_ixs + let subst_vars = if free_vars == [] then [] else + [0 .. maximum free_vars] + let subst = flip map subst_vars $ \i -> + maybe (error + ("mrTrySetAppliedEVar: unexpected free variable " + ++ show i ++ " in term\n" ++ showTerm t)) + (arg_vars !!) (lookup i var_map) + body <- substTerm 0 subst t + + -- Now instantiate evar to \x1 ... xn -> body + evar_inst <- liftSC2 scLambdaList evar_vars body + mrSetEVarClosed evar evar_inst + return True + + _ -> return False + + +-- | Replace all evars in a 'Term' with their instantiations when they have one +mrSubstEVars :: Term -> MRM Term +mrSubstEVars = memoFixTermFun $ \recurse t -> + do var_map <- mrVars + case t of + -- If t is an instantiated evar, recurse on its instantiation + (asEVarApp var_map -> Just (_, args, Just t')) -> + mrApplyAll t' args >>= recurse + -- If t is anything else, recurse on its immediate subterms + _ -> traverseSubterms recurse t + +-- | Replace all evars in a 'Term' with their instantiations, returning +-- 'Nothing' if we hit an uninstantiated evar +mrSubstEVarsStrict :: Term -> MRM (Maybe Term) +mrSubstEVarsStrict top_t = + runMaybeT $ flip memoFixTermFun top_t $ \recurse t -> + do var_map <- lift mrVars + case t of + -- If t is an instantiated evar, recurse on its instantiation + (asEVarApp var_map -> Just (_, args, Just t')) -> + lift (mrApplyAll t' args) >>= recurse + -- If t is an uninstantiated evar, return Nothing + (asEVarApp var_map -> Just (_, _, Nothing)) -> + mzero + -- If t is anything else, recurse on its immediate subterms + _ -> traverseSubterms recurse t + +-- | Makes 'mrSubstEVarsStrict' be marked as used +_mrSubstEVarsStrict :: Term -> MRM (Maybe Term) +_mrSubstEVarsStrict = mrSubstEVarsStrict + +-- | Get the 'CoIndHyp' for a pair of 'FunName's, if there is one +mrGetCoIndHyp :: FunName -> FunName -> MRM (Maybe CoIndHyp) +mrGetCoIndHyp nm1 nm2 = Map.lookup (nm1, nm2) <$> mrCoIndHyps + +-- | Run a compuation under an additional co-inductive assumption +withCoIndHyp :: CoIndHyp -> MRM a -> MRM a +withCoIndHyp hyp m = + do debugPretty 2 ("withCoIndHyp" <+> ppInEmptyCtx hyp) + hyps' <- Map.insert (coIndHypLHSFun hyp, + coIndHypRHSFun hyp) hyp <$> mrCoIndHyps + local (\info -> info { mriCoIndHyps = hyps' }) m + +-- | Generate fresh evars for the context of a 'CoIndHyp' and +-- substitute them into its arguments and right-hand side +instantiateCoIndHyp :: CoIndHyp -> MRM ([Term], [Term]) +instantiateCoIndHyp (CoIndHyp {..}) = + do evars <- mrFreshEVars coIndHypCtx + lhs <- substTermLike 0 evars coIndHypLHS + rhs <- substTermLike 0 evars coIndHypRHS + return (lhs, rhs) + +-- | Apply the preconditions of a 'CoIndHyp' to their respective arguments, +-- yielding @Bool@ conditions, using the constant @True@ value when a +-- precondition is absent +applyCoIndHypPreconds :: CoIndHyp -> MRM (Term, Term) +applyCoIndHypPreconds hyp = + let apply_precond :: Maybe Term -> [Term] -> MRM Term + apply_precond (Just (asLambdaList -> (vars, phi))) args + | length vars == length args + -- NOTE: applying to a list of arguments == substituting the reverse + -- of that list, because the first argument corresponds to the + -- greatest deBruijn index + = substTerm 0 (reverse args) phi + apply_precond (Just _) _ = + error "applyCoIndHypPreconds: wrong number of arguments for precondition!" + apply_precond Nothing _ = liftSC1 scBool True in + do pre1 <- apply_precond (coIndHypPrecondLHS hyp) (coIndHypLHS hyp) + pre2 <- apply_precond (coIndHypPrecondRHS hyp) (coIndHypRHS hyp) + return (pre1, pre2) + +-- | Look up the 'FunAssump' for a 'FunName', if there is one +mrGetFunAssump :: FunName -> MRM (Maybe FunAssump) +mrGetFunAssump nm = Map.lookup nm <$> mrFunAssumps + +-- | Run a computation under the additional assumption that a named function +-- applied to a list of arguments refines a given right-hand side, all of which +-- are 'Term's that can have the current uvars free +withFunAssump :: FunName -> [Term] -> NormComp -> MRM a -> MRM a +withFunAssump fname args rhs m = + do mrDebugPPPrefixSep 1 "withFunAssump" (FunBind + fname args CompFunReturn) "|=" rhs + ctx <- mrUVarCtx + assumps <- mrFunAssumps + let assumps' = Map.insert fname (FunAssump ctx args rhs) assumps + local (\info -> + let env' = (mriEnv info) { mreFunAssumps = assumps' } in + info { mriEnv = env' }) m + +-- | Generate fresh evars for the context of a 'FunAssump' and substitute them +-- into its arguments and right-hand side +instantiateFunAssump :: FunAssump -> MRM ([Term], NormComp) +instantiateFunAssump fassump = + do evars <- mrFreshEVars $ fassumpCtx fassump + args <- substTermLike 0 evars $ fassumpArgs fassump + rhs <- substTermLike 0 evars $ fassumpRHS fassump + return (args, rhs) + +-- | Get the precondition hint associated with a function name, by unfolding the +-- name and checking if its body has the form +-- +-- > \ x1 ... xn -> precondHint a phi m +-- +-- If so, return @\ x1 ... xn -> phi@ as a term with the @xi@ variables free. +-- Otherwise, return 'Nothing'. +mrGetPrecond :: FunName -> MRM (Maybe Term) +mrGetPrecond nm = + mrFunNameBody nm >>= \case + Just (asLambdaList -> + (args, + asApplyAll -> (isGlobalDef "Prelude.precondHint" -> Just (), + [_, phi, _]))) -> + Just <$> liftSC2 scLambdaList args phi + _ -> return Nothing + +-- | Add an assumption of type @Bool@ to the current path condition while +-- executing a sub-computation +withAssumption :: Term -> MRM a -> MRM a +withAssumption phi m = + do mrDebugPPPrefix 1 "withAssumption" phi + assumps <- mrAssumptions + assumps' <- liftSC2 scAnd phi assumps + local (\info -> info { mriAssumptions = assumps' }) m + +-- | Remove any existing assumptions and replace them with a Boolean term +withOnlyAssumption :: Term -> MRM a -> MRM a +withOnlyAssumption phi m = + do mrDebugPPPrefix 1 "withOnlyAssumption" phi + local (\info -> info { mriAssumptions = phi }) m + +-- | Add a 'DataTypeAssump' to the current context while executing a +-- sub-computations +withDataTypeAssump :: Term -> DataTypeAssump -> MRM a -> MRM a +withDataTypeAssump x assump m = + do mrDebugPPPrefixSep 1 "withDataTypeAssump" x "==" assump + dataTypeAssumps' <- HashMap.insert x assump <$> mrDataTypeAssumps + local (\info -> info { mriDataTypeAssumps = dataTypeAssumps' }) m + +-- | Get the 'DataTypeAssump' associated to the given term, if one exists +mrGetDataTypeAssump :: Term -> MRM (Maybe DataTypeAssump) +mrGetDataTypeAssump x = HashMap.lookup x <$> mrDataTypeAssumps + +-- | Print a 'String' if the debug level is at least the supplied 'Int' +debugPrint :: Int -> String -> MRM () +debugPrint i str = + mrDebugLevel >>= \lvl -> + if lvl >= i then liftIO (hPutStrLn stderr str) else return () + +-- | Print a document if the debug level is at least the supplied 'Int' +debugPretty :: Int -> SawDoc -> MRM () +debugPretty i pp = debugPrint i $ renderSawDoc defaultPPOpts pp + +-- | Pretty-print an object in the current context if the current debug level is +-- at least the supplied 'Int' +debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () +debugPrettyInCtx i a = + mrUVars >>= \ctx -> debugPrint i (showInCtx (map fst ctx) a) + +-- | Pretty-print an object relative to the current context +mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc +mrPPInCtx a = + runReader (prettyInCtx a) <$> map fst <$> mrUVars + +-- | Pretty-print the result of 'ppWithPrefix' relative to the current uvar +-- context to 'stderr' if the debug level is at least the 'Int' provided +mrDebugPPPrefix :: PrettyInCtx a => Int -> String -> a -> MRM () +mrDebugPPPrefix i pre a = + mrUVars >>= \ctx -> + debugPretty i $ + flip runReader (map fst ctx) (group <$> nest 2 <$> ppWithPrefix pre a) + +-- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar +-- context to 'stderr' if the debug level is at least the 'Int' provided +mrDebugPPPrefixSep :: (PrettyInCtx a, PrettyInCtx b) => + Int -> String -> a -> String -> b -> MRM () +mrDebugPPPrefixSep i pre a1 sp a2 = + mrUVars >>= \ctx -> + debugPretty i $ + flip runReader (map fst ctx) (group <$> nest 2 <$> + ppWithPrefixSep pre a1 sp a2) diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs new file mode 100644 index 0000000000..efb196f0bc --- /dev/null +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -0,0 +1,393 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE ImplicitParams #-} + +{- | +Module : SAWScript.Prover.MRSolver.SMT +Copyright : Galois, Inc. 2022 +License : BSD3 +Maintainer : westbrook@galois.com +Stability : experimental +Portability : non-portable (language extensions) + +This module implements the interface between Mr. Solver and an SMT solver, +namely 'mrProvable' and 'mrProveEq'. +-} + +module SAWScript.Prover.MRSolver.SMT where + +import qualified Data.Vector as V +import Control.Monad.Except + +import Data.Map (Map) +import qualified Data.Map as Map +import qualified Data.Set as Set + +import Verifier.SAW.Term.Functor +import Verifier.SAW.Term.Pretty +import Verifier.SAW.SharedTerm +import Verifier.SAW.Recognizer +import Verifier.SAW.OpenTerm + +import qualified Verifier.SAW.Prim as Prim +import Verifier.SAW.Simulator.TermModel +import Verifier.SAW.Simulator.Prims +import Verifier.SAW.Simulator.MonadLazy + +import SAWScript.Proof (termToProp, propToTerm, prettyProp) +import What4.Solver +import SAWScript.Prover.What4 + +import SAWScript.Prover.MRSolver.Term +import SAWScript.Prover.MRSolver.Monad + + +---------------------------------------------------------------------- +-- * Various SMT-specific Functions on Terms +---------------------------------------------------------------------- + +-- | Test if a 'Term' is a 'BVVec' type +asBVVecType :: Recognizer Term (Term, Term, Term) +asBVVecType (asApplyAll -> + (isGlobalDef "Prelude.Vec" -> Just _, + [(asApplyAll -> + (isGlobalDef "Prelude.bvToNat" -> Just _, [n, len])), a])) = + Just (n, len, a) +asBVVecType _ = Nothing + +-- | Apply @genBVVec@ to arguments @n@, @len@, and @a@, along with a function of +-- type @Vec n Bool -> a@ +genBVVecTerm :: SharedContext -> Term -> Term -> Term -> Term -> IO Term +genBVVecTerm sc n_tm len_tm a_tm f_tm = + let n = closedOpenTerm n_tm + len = closedOpenTerm len_tm + a = closedOpenTerm a_tm + f = closedOpenTerm f_tm in + completeOpenTerm sc $ + applyOpenTermMulti (globalOpenTerm "Prelude.genBVVec") + [n, len, a, + lambdaOpenTerm "i" (vectorTypeOpenTerm n boolTypeOpenTerm) $ \i -> + lambdaOpenTerm "_" (applyGlobalOpenTerm "Prelude.is_bvult" [n, i, len]) $ \_ -> + applyOpenTerm f i] + +-- | Match a term of the form @genBVVec n len a (\ i _ -> e)@, i.e., where @e@ +-- does not have the proof variable (the underscore) free +asGenBVVecTerm :: Recognizer Term (Term, Term, Term, Term) +asGenBVVecTerm (asApplyAll -> + (isGlobalDef "Prelude.genBVVec" -> Just _, + [n, len, a, + (asLambdaList -> ([_,_], e))])) + | not $ inBitSet 0 $ looseVars e + = Just (n, len, a, e) +asGenBVVecTerm _ = Nothing + +type TmPrim = Prim TermModel + +-- | Convert a Boolean value to a 'Term'; like 'readBackValue' but that function +-- requires a 'SimulatorConfig' which we cannot easily generate here... +boolValToTerm :: SharedContext -> Value TermModel -> IO Term +boolValToTerm _ (VBool (Left tm)) = return tm +boolValToTerm sc (VBool (Right b)) = scBool sc b +boolValToTerm _ (VExtra (VExtraTerm _tp tm)) = return tm +boolValToTerm _ v = error ("boolValToTerm: unexpected value: " ++ show v) + +-- | An implementation of a primitive function that expects a @genBVVec@ term +primGenBVVec :: SharedContext -> (Term -> TmPrim) -> TmPrim +primGenBVVec sc f = + PrimFilterFun "genBVVecPrim" + (\case + VExtra (VExtraTerm _ (asGenBVVecTerm -> Just (n, _, _, e))) -> + -- Generate the function \i -> [i/1,error/0]e + lift $ + do i_tp <- scBoolType sc >>= scVecType sc n + let err_tm = error "primGenBVVec: unexpected variable occurrence" + i_tm <- scLocalVar sc 0 + body <- instantiateVarList sc 0 [err_tm,i_tm] e + scLambda sc "i" i_tp body + _ -> mzero) + f + +-- | An implementation of a primitive function that expects a bitvector term +primBVTermFun :: SharedContext -> (Term -> TmPrim) -> TmPrim +primBVTermFun sc = + PrimFilterFun "primBVTermFun" $ + \case + VExtra (VExtraTerm _ w_tm) -> return w_tm + VWord (Left (_,w_tm)) -> return w_tm + VWord (Right bv) -> + lift $ scBvConst sc (fromIntegral (Prim.width bv)) (Prim.unsigned bv) + VVector vs -> + lift $ + do tms <- traverse (boolValToTerm sc <=< force) (V.toList vs) + tp <- scBoolType sc + scVectorReduced sc tp tms + v -> lift (putStrLn ("primBVTermFun: unhandled value: " ++ show v)) >> mzero + +-- | Implementations of primitives for normalizing Mr Solver terms +smtNormPrims :: SharedContext -> Map Ident TmPrim +smtNormPrims sc = Map.fromList + [ + ("Prelude.genBVVec", + Prim (do tp <- scTypeOfGlobal sc "Prelude.genBVVec" + VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> + scGlobalDef sc "Prelude.genBVVec")), + + ("Prelude.atBVVec", + PrimFun $ \_n -> PrimFun $ \_len -> tvalFun $ \a -> + primGenBVVec sc $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf -> + Prim (VExtra <$> VExtraTerm a <$> scApply sc f ix) + ), + ("Prelude.CompM", + PrimFilterFun "CompM" (\case + TValue tv -> return tv + _ -> mzero) $ \tv -> + Prim (do let ?recordEC = \_ec -> return () + let cfg = error "FIXME: smtNormPrims: need the simulator config" + tv_trm <- readBackTValue sc cfg tv + TValue <$> VTyTerm (mkSort 0) <$> + scGlobalApply sc "Prelude.CompM" [tv_trm])) + ] + +-- | Normalize a 'Term' using some Mr Solver specific primitives +mrNormTerm :: Term -> MRM Term +mrNormTerm t = + debugPrint 2 "Normalizing term:" >> + debugPrettyInCtx 2 t >> + liftSC0 return >>= \sc -> + liftSC0 scGetModuleMap >>= \modmap -> + liftSC5 normalizeSharedTerm modmap (smtNormPrims sc) Map.empty Set.empty t + +-- | Normalize an open term by wrapping it in lambdas, normalizing, and then +-- removing those lambdas +mrNormOpenTerm :: Term -> MRM Term +mrNormOpenTerm body = + do ctx <- mrUVarCtx + fun_term <- liftSC2 scLambdaList ctx body + normed_fun <- mrNormTerm fun_term + return (peel_lambdas (length ctx) normed_fun) + where + peel_lambdas :: Int -> Term -> Term + peel_lambdas 0 t = t + peel_lambdas i (asLambda -> Just (_, _, t)) = peel_lambdas (i-1) t + peel_lambdas _ _ = error "mrNormOpenTerm: unexpected non-lambda term!" + + +---------------------------------------------------------------------- +-- * Checking Provability with SMT +---------------------------------------------------------------------- + +-- | Test if a closed Boolean term is "provable", i.e., its negation is +-- unsatisfiable, using an SMT solver. By "closed" we mean that it contains no +-- uvars or 'MRVar's. +-- +-- FIXME: use the timeout! +mrProvableRaw :: Term -> MRM Bool +mrProvableRaw prop_term = + do sc <- mrSC + prop <- liftSC1 termToProp prop_term + unints <- Set.map ecVarIndex <$> getAllExtSet <$> liftSC1 propToTerm prop + debugPrint 2 ("Calling SMT solver with proposition: " ++ + prettyProp defaultPPOpts prop) + sym <- liftIO $ setupWhat4_sym True + (smt_res, _) <- + liftIO $ proveWhat4_solver z3Adapter sym unints sc prop (return ()) + case smt_res of + Just _ -> + debugPrint 2 "SMT solver response: not provable" >> return False + Nothing -> + debugPrint 2 "SMT solver response: provable" >> return True + +-- | Test if a Boolean term over the current uvars is provable given the current +-- assumptions +mrProvable :: Term -> MRM Bool +mrProvable (asBool -> Just b) = return b +mrProvable bool_tm = + do assumps <- mrAssumptions + prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue + prop_inst <- instantiateUVarsM instUVar prop + mrNormTerm prop_inst >>= mrProvableRaw + where -- | Given a UVar name and type, generate a 'Term' to be passed to + -- SMT, with special cases for BVVec and pair types + instUVar :: LocalName -> Term -> MRM Term + instUVar nm tp = liftSC1 scWhnf tp >>= \case + -- For variables of type BVVec, create a @Vec n Bool -> a@ function + -- as an ExtCns and apply genBVVec to it + (asBVVecType -> Just (n, len, a)) -> do + ec_tp <- + liftSC1 completeOpenTerm $ + arrowOpenTerm "_" (applyOpenTermMulti (globalOpenTerm "Prelude.Vec") + [closedOpenTerm n, boolTypeOpenTerm]) + (closedOpenTerm a) + ec <- instUVar nm ec_tp + liftSC4 genBVVecTerm n len a ec + -- For pairs, recurse on both sides and combine the result as a pair + (asPairType -> Just (tp1, tp2)) -> do + e1 <- instUVar nm tp1 + e2 <- instUVar nm tp2 + liftSC2 scPairValue e1 e2 + -- Otherwise, create a global variable with the given name and type + tp' -> liftSC2 scFreshEC nm tp' >>= liftSC1 scExtCns + + +---------------------------------------------------------------------- +-- * Checking Equality with SMT +---------------------------------------------------------------------- + +-- | Build a Boolean 'Term' stating that two 'Term's are equal. This is like +-- 'scEq' except that it works on open terms. +mrEq :: Term -> Term -> MRM Term +mrEq t1 t2 = mrTypeOf t1 >>= \tp -> mrEq' tp t1 t2 + +-- | Build a Boolean 'Term' stating that the second and third 'Term' arguments +-- are equal, where the first 'Term' gives their type (which we assume is the +-- same for both). This is like 'scEq' except that it works on open terms. +mrEq' :: Term -> Term -> Term -> MRM Term +mrEq' (asDataType -> Just (pn, [])) t1 t2 + | primName pn == "Prelude.Nat" = liftSC2 scEqualNat t1 t2 +mrEq' (asBoolType -> Just _) t1 t2 = liftSC2 scBoolEq t1 t2 +mrEq' (asIntegerType -> Just _) t1 t2 = liftSC2 scIntEq t1 t2 +mrEq' (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = + liftSC3 scBvEq n t1 t2 +mrEq' _ _ _ = error "mrEq': unsupported type" + +-- | A 'Term' in an extended context of universal variables, which are listed +-- "outside in", meaning the highest deBruijn index comes first +data TermInCtx = TermInCtx [(LocalName,Term)] Term + +-- | Conjoin two 'TermInCtx's, assuming they both have Boolean type +andTermInCtx :: TermInCtx -> TermInCtx -> MRM TermInCtx +andTermInCtx (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = + do + -- Insert the variables in ctx2 into the context of t1 starting at index 0, + -- by lifting its variables starting at 0 by length ctx2 + t1' <- liftTermLike 0 (length ctx2) t1 + -- Insert the variables in ctx1 into the context of t1 starting at index + -- length ctx2, by lifting its variables starting at length ctx2 by length + -- ctx1 + t2' <- liftTermLike (length ctx2) (length ctx1) t2 + TermInCtx (ctx1++ctx2) <$> liftSC2 scAnd t1' t2' + +-- | Extend the context of a 'TermInCtx' with additional universal variables +-- bound "outside" the 'TermInCtx' +extTermInCtx :: [(LocalName,Term)] -> TermInCtx -> TermInCtx +extTermInCtx ctx (TermInCtx ctx' t) = TermInCtx (ctx++ctx') t + +-- | Run an 'MRM' computation in the context of a 'TermInCtx', passing in the +-- 'Term' +withTermInCtx :: TermInCtx -> (Term -> MRM a) -> MRM a +withTermInCtx (TermInCtx [] tm) f = f tm +withTermInCtx (TermInCtx ((nm,tp):ctx) tm) f = + withUVar nm (Type tp) $ const $ withTermInCtx (TermInCtx ctx tm) f + +-- | A "simple" strategy for proving equality between two terms, which we assume +-- are of the same type, which builds an equality proposition by applying the +-- supplied function to both sides and passes this proposition to an SMT solver. +mrProveEqSimple :: (Term -> Term -> MRM Term) -> Term -> Term -> + MRM TermInCtx +-- NOTE: The use of mrSubstEVars instead of mrSubstEVarsStrict means that we +-- allow evars in the terms we send to the SMT solver, but we treat them as +-- uvars. +mrProveEqSimple eqf t1 t2 = + do t1' <- mrSubstEVars t1 + t2' <- mrSubstEVars t2 + TermInCtx [] <$> eqf t1' t2' + +-- | Prove that two terms are equal, instantiating evars if necessary, +-- returning true on success +mrProveEq :: Term -> Term -> MRM Bool +mrProveEq t1 t2 = + do mrDebugPPPrefixSep 1 "mrProveEq" t1 "==" t2 + tp <- mrTypeOf t1 >>= mrSubstEVars + varmap <- mrVars + cond_in_ctx <- mrProveEqH varmap tp t1 t2 + res <- withTermInCtx cond_in_ctx mrProvable + debugPrint 1 $ "mrProveEq: " ++ if res then "Success" else "Failure" + return res + +-- | Prove that two terms are equal, instantiating evars if necessary, or +-- throwing an error if this is not possible +mrAssertProveEq :: Term -> Term -> MRM () +mrAssertProveEq t1 t2 = + do success <- mrProveEq t1 t2 + if success then return () else + throwMRFailure (TermsNotEq t1 t2) + +-- | The main workhorse for 'mrProveEq'. Build a Boolean term expressing that +-- the third and fourth arguments, whose type is given by the second. +mrProveEqH :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM TermInCtx + +{- +mrProveEqH _ _ t1 t2 + | trace ("mrProveEqH:\n" ++ showTerm t1 ++ "\n==\n" ++ showTerm t2) False = undefined +-} + +-- If t1 is an instantiated evar, substitute and recurse +mrProveEqH var_map tp (asEVarApp var_map -> Just (_, args, Just f)) t2 = + mrApplyAll f args >>= \t1' -> mrProveEqH var_map tp t1' t2 + +-- If t1 is an uninstantiated evar, instantiate it with t2 +mrProveEqH var_map _tp (asEVarApp var_map -> Just (evar, args, Nothing)) t2 = + do t2' <- mrSubstEVars t2 + success <- mrTrySetAppliedEVar evar args t2' + TermInCtx [] <$> liftSC1 scBool success + +-- If t2 is an instantiated evar, substitute and recurse +mrProveEqH var_map tp t1 (asEVarApp var_map -> Just (_, args, Just f)) = + mrApplyAll f args >>= \t2' -> mrProveEqH var_map tp t1 t2' + +-- If t2 is an uninstantiated evar, instantiate it with t1 +mrProveEqH var_map _tp t1 (asEVarApp var_map -> Just (evar, args, Nothing)) = + do t1' <- mrSubstEVars t1 + success <- mrTrySetAppliedEVar evar args t1' + TermInCtx [] <$> liftSC1 scBool success + +-- For unit types, always return true +mrProveEqH _ (asTupleType -> Just []) _ _ = + TermInCtx [] <$> liftSC1 scBool True + +-- For the nat, bitvector, Boolean, and integer types, call mrProveEqSimple +mrProveEqH _ (asDataType -> Just (pn, [])) t1 t2 + | primName pn == "Prelude.Nat" = + mrProveEqSimple (liftSC2 scEqualNat) t1 t2 +mrProveEqH _ (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = + -- FIXME: make a better solver for bitvector equalities + mrProveEqSimple (liftSC3 scBvEq n) t1 t2 +mrProveEqH _ (asBoolType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scBoolEq) t1 t2 +mrProveEqH _ (asIntegerType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scIntEq) t1 t2 + +-- For pair types, prove both the left and right projections are equal +mrProveEqH var_map (asPairType -> Just (tpL, tpR)) t1 t2 = + do t1L <- liftSC1 scPairLeft t1 + t2L <- liftSC1 scPairLeft t2 + t1R <- liftSC1 scPairRight t1 + t2R <- liftSC1 scPairRight t2 + condL <- mrProveEqH var_map tpL t1L t2L + condR <- mrProveEqH var_map tpR t1R t2R + andTermInCtx condL condR + +-- For non-bitvector vector types, prove all projections are equal by +-- quantifying over a universal index variable and proving equality at that +-- index +mrProveEqH _ (asBVVecType -> Just (n, len, tp)) t1 t2 = + liftSC0 scBoolType >>= \bool_tp -> + liftSC2 scVecType n bool_tp >>= \ix_tp -> + withUVarLift "eq_ix" (Type ix_tp) (n,(len,(tp,(t1,t2)))) $ + \ix' (n',(len',(tp',(t1',t2')))) -> + liftSC2 scGlobalApply "Prelude.is_bvult" [n', ix', len'] >>= \pf_tp -> + withUVarLift "eq_pf" (Type pf_tp) (n',(len',(tp',(ix',(t1',t2'))))) $ + \pf'' (n'',(len'',(tp'',(ix'',(t1'',t2''))))) -> + do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + t1'', ix'', pf''] + t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + t2'', ix'', pf''] + var_map <- mrVars + extTermInCtx [("eq_ix",ix_tp),("eq_pf",pf_tp)] <$> + mrProveEqH var_map tp'' t1_prj t2_prj + +-- As a fallback, for types we can't handle, just check convertibility +mrProveEqH _ _ t1 t2 = + do success <- mrConvertible t1 t2 + TermInCtx [] <$> liftSC1 scBool success diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs new file mode 100644 index 0000000000..f6b711a1e7 --- /dev/null +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -0,0 +1,826 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} + +{- | +Module : SAWScript.Prover.MRSolver.Solver +Copyright : Galois, Inc. 2022 +License : BSD3 +Maintainer : westbrook@galois.com +Stability : experimental +Portability : non-portable (language extensions) + +This module implements a monadic-recursive solver, for proving that one monadic +term refines another. The algorithm works on the "monadic normal form" of +computations, which uses the following laws to simplify binds in computations, +where either is the sum elimination function defined in the SAW core prelude: + +returnM x >>= k = k x +errorM str >>= k = errorM +(m >>= k1) >>= k2 = m >>= \x -> k1 x >>= k2 +(existsM f) >>= k = existsM (\x -> f x >>= k) +(forallM f) >>= k = forallM (\x -> f x >>= k) +(orM m1 m2) >>= k = orM (m1 >>= k) (m2 >>= k) +(if b then m1 else m2) >>= k = if b then m1 >>= k else m2 >>1 k +(either f1 f2 e) >>= k = either (\x -> f1 x >= k) (\x -> f2 x >= k) e +(letrecM funs body) >>= k = letrecM funs (\F1 ... Fn -> body F1 ... Fn >>= k) + +The resulting computations of one of the following forms: + +returnM e | errorM str | existsM f | forallM f | orM m1 m2 | +if b then m1 else m2 | either f1 f2 e | F e1 ... en | F e1 ... en >>= k | +letrecM lrts B (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> m) + +The form F e1 ... en refers to a recursively-defined function or a function +variable that has been locally bound by a letrecM. Either way, monadic +normalization does not attempt to normalize these functions. + +The algorithm maintains a context of three sorts of variables: letrec-bound +variables, existential variables, and universal variables. Universal variables +are represented as free SAW core variables, while the other two forms of +variable are represented as SAW core 'ExtCns's terms, which are essentially +axioms that have been generated internally. These 'ExtCns's are Skolemized, +meaning that they take in as arguments all universal variables that were in +scope when they were created. The context also maintains a partial substitution +for the existential variables, as they become instantiated with values, and it +additionally remembers the bodies / unfoldings of the letrec-bound variables. + +The goal of the solver at any point is of the form C |- m1 |= m2, meaning that +we are trying to prove m1 refines m2 in context C. This proceed by cases: + +C |- returnM e1 |= returnM e2: prove C |- e1 = e2 + +C |- errorM str1 |= errorM str2: vacuously true + +C |- if b then m1' else m1'' |= m2: prove C,b=true |- m1' |= m2 and +C,b=false |- m1'' |= m2, skipping either case where C,b=X is unsatisfiable; + +C |- m1 |= if b then m2' else m2'': similar to the above + +C |- either T U (CompM V) f1 f2 e |= m: prove C,x:T,e=inl x |- f1 x |= m and +C,y:U,e=inl y |- f2 y |= m, again skippping any case with unsatisfiable context; + +C |- m |= either T U (CompM V) f1 f2 e: similar to previous + +C |- m |= forallM f: make a new universal variable x and recurse + +C |- existsM f |= m: make a new universal variable x and recurse (existential +elimination uses universal variables and vice-versa) + +C |- m |= existsM f: make a new existential variable x and recurse + +C |- forall f |= m: make a new existential variable x and recurse + +C |- m |= orM m1 m2: try to prove C |- m |= m1, and if that fails, backtrack and +prove C |- m |= m2 + +C |- orM m1 m2 |= m: prove both C |- m1 |= m and C |- m2 |= m + +C |- letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body) |= m: create +letrec-bound variables F1 through Fn in the context bound to their unfoldings f1 +through fn, respectively, and recurse on body |= m + +C |- m |= letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body): similar to +previous case + +C |- F e1 ... en >>= k |= F e1' ... en' >>= k': prove C |- ei = ei' for each i +and then prove k x |= k' x for new universal variable x + +C |- F e1 ... en >>= k |= F' e1' ... em' >>= k': + +* If we have an assumption that forall x1 ... xj, F a1 ... an |= F' a1' .. am', + prove ei = ai and ei' = ai' and then that C |- k x |= k' x for fresh uvar x + +* If we have an assumption that forall x1, ..., xn, F e1'' ... en'' |= m' for + some ei'' and m', match the ei'' against the ei by instantiating the xj with + fresh evars, and if this succeeds then recursively prove C |- m' >>= k |= RHS + +(We don't do this one right now) +* If we have an assumption that forall x1', ..., xn', m |= F e1'' ... en'' for + some ei'' and m', match the ei'' against the ei by instantiating the xj with + fresh evars, and if this succeeds then recursively prove C |- LHS |= m' >>= k' + +* If either side is a definition whose unfolding does not contain letrecM, fixM, + or any related operations, unfold it + +* If F and F' have the same return type, add an assumption forall uvars in scope + that F e1 ... en |= F' e1' ... em' and unfold both sides, recursively proving + that F_body e1 ... en |= F_body' e1' ... em'. Then also prove k x |= k' x for + fresh uvar x. + +* Otherwise we don't know to "split" one of the sides into a bind whose + components relate to the two components on the other side, so just fail +-} + +module SAWScript.Prover.MRSolver.Solver where + +import Data.Maybe +import Data.Either +import Data.List (findIndices, intercalate) +import Control.Monad.Except + +import Prettyprinter + +import Verifier.SAW.Term.Functor +import Verifier.SAW.SharedTerm +import Verifier.SAW.Recognizer + +import SAWScript.Prover.MRSolver.Term +import SAWScript.Prover.MRSolver.Monad +import SAWScript.Prover.MRSolver.SMT + + +---------------------------------------------------------------------- +-- * Normalizing and Matching on Terms +---------------------------------------------------------------------- + +-- | Pattern-match on a @LetRecTypes@ list in normal form and return a list of +-- the types it specifies, each in normal form and with uvars abstracted out +asLRTList :: Term -> MRM [Term] +asLRTList (asCtor -> Just (primName -> "Prelude.LRT_Nil", [])) = + return [] +asLRTList (asCtor -> Just (primName -> "Prelude.LRT_Cons", [lrt, lrts])) = + do tp <- liftSC2 scGlobalApply "Prelude.lrtToType" [lrt] + tp_norm_closed <- liftSC1 scWhnf tp >>= piUVarsM + (tp_norm_closed :) <$> asLRTList lrts +asLRTList t = throwMRFailure (MalformedLetRecTypes t) + +-- | Match a right-nested series of pairs. This is similar to 'asTupleValue' +-- except that it expects a unit value to always be at the end. +asNestedPairs :: Recognizer Term [Term] +asNestedPairs (asPairValue -> Just (x, asNestedPairs -> Just xs)) = Just (x:xs) +asNestedPairs (asFTermF -> Just UnitValue) = Just [] +asNestedPairs _ = Nothing + +-- | Bind fresh function variables for a @letRecM@ or @multiFixM@ with the given +-- @LetRecTypes@ and definitions for the function bodies as a lambda +mrFreshLetRecVars :: Term -> Term -> MRM [Term] +mrFreshLetRecVars lrts defs_f = + do + -- First, make fresh function constants for all the bound functions, using + -- the names bound by defs_f and just "F" if those run out + let fun_var_names = + map fst (fst $ asLambdaList defs_f) ++ repeat "F" + fun_tps <- asLRTList lrts + funs <- zipWithM mrFreshVar fun_var_names fun_tps + fun_tms <- mapM mrVarTerm funs + + -- Next, apply the definition function defs_f to our function vars, yielding + -- the definitions of the individual letrec-bound functions in terms of the + -- new function constants + defs_tm <- mrApplyAll defs_f fun_tms + defs <- case asNestedPairs defs_tm of + Just defs -> return defs + Nothing -> throwMRFailure (MalformedDefsFun defs_f) + + -- Remember the body associated with each fresh function constant + zipWithM_ (\f body -> + lambdaUVarsM body >>= \cl_body -> + mrSetVarInfo f (FunVarInfo cl_body)) funs defs + + -- Finally, return the terms for the fresh function variables + return fun_tms + + +-- | Normalize a 'Term' of monadic type to monadic normal form +normCompTerm :: Term -> MRM NormComp +normCompTerm = normComp . CompTerm + +-- | Normalize a computation to monadic normal form, assuming any 'Term's it +-- contains have already been normalized with respect to beta and projections +-- (but constants need not be unfolded) +normComp :: Comp -> MRM NormComp +normComp (CompReturn t) = return $ ReturnM t +normComp (CompBind m f) = + do norm <- normComp m + normBind norm f +normComp (CompTerm t) = + withFailureCtx (FailCtxMNF t) $ + case asApplyAll t of + (f@(asLambda -> Just _), args) -> + mrApplyAll f args >>= normCompTerm + (isGlobalDef "Prelude.returnM" -> Just (), [_, x]) -> + return $ ReturnM x + (isGlobalDef "Prelude.bindM" -> Just (), [_, _, m, f]) -> + do norm <- normComp (CompTerm m) + normBind norm (CompFunTerm f) + (isGlobalDef "Prelude.errorM" -> Just (), [_, str]) -> + return (ErrorM str) + (isGlobalDef "Prelude.ite" -> Just (), [_, cond, then_tm, else_tm]) -> + return $ Ite cond (CompTerm then_tm) (CompTerm else_tm) + (isGlobalDef "Prelude.either" -> Just (), [ltp, rtp, _, f, g, eith]) -> + return $ Either (Type ltp) (Type rtp) (CompFunTerm f) (CompFunTerm g) eith + (isGlobalDef "Prelude.maybe" -> Just (), [tp, _, m, f, mayb]) -> + return $ MaybeElim (Type tp) (CompTerm m) (CompFunTerm f) mayb + (isGlobalDef "Prelude.orM" -> Just (), [_, m1, m2]) -> + return $ OrM (CompTerm m1) (CompTerm m2) + (isGlobalDef "Prelude.existsM" -> Just (), [tp, _, body_tm]) -> + return $ ExistsM (Type tp) (CompFunTerm body_tm) + (isGlobalDef "Prelude.forallM" -> Just (), [tp, _, body_tm]) -> + return $ ForallM (Type tp) (CompFunTerm body_tm) + (isGlobalDef "Prelude.precondHint" -> Just (), [_, _, body_tm]) -> + normCompTerm body_tm + (isGlobalDef "Prelude.letRecM" -> Just (), [lrts, _, defs_f, body_f]) -> + do + -- Bind fresh function vars for the letrec-bound functions + fun_tms <- mrFreshLetRecVars lrts defs_f + -- Apply the body function to our function vars and recursively + -- normalize the resulting computation + body_tm <- mrApplyAll body_f fun_tms + normComp (CompTerm body_tm) + + -- Only unfold constants that are not recursive functions, i.e., whose + -- bodies do not contain letrecs + {- FIXME: this should be handled by mrRefines; we want it to be handled there + so that we use refinement assumptions before unfolding constants, to give + the user control over refinement proofs + ((asConstant -> Just (_, body)), args) + | not (containsLetRecM body) -> + mrApplyAll body args >>= normCompTerm + -} + + -- Recognize and unfold a multiArgFixM + (f@(isGlobalDef "Prelude.multiArgFixM" -> Just ()), args) + | Just (_, Just body) <- asConstant f -> + mrApplyAll body args >>= normCompTerm + + -- Recognize (multiFixM lrts (\ f1 ... fn -> (body1, ..., bodyn))).i args + (asTupleSelector -> + Just (asApplyAll -> (isGlobalDef "Prelude.multiFixM" -> Just (), + [lrts, defs_f]), + i), args) -> + do + -- Bind fresh function variables for the functions f1 ... fn + fun_tms <- mrFreshLetRecVars lrts defs_f + -- Apply fi to the top-level arguments, keeping in mind that tuple + -- selectors are one-based, not zero-based, so we subtract 1 from i + body_tm <- + if i > 0 && i <= length fun_tms then + mrApplyAll (fun_tms !! (i-1)) args + else throwMRFailure (MalformedComp t) + normComp (CompTerm body_tm) + + + -- For an ExtCns, we have to check what sort of variable it is + -- FIXME: substitute for evars if they have been instantiated + ((asExtCns -> Just ec), args) -> + do fun_name <- extCnsToFunName ec + return $ FunBind fun_name args CompFunReturn + + ((asGlobalFunName -> Just f), args) -> + return $ FunBind f args CompFunReturn + + _ -> throwMRFailure (MalformedComp t) + + +-- | Bind a computation in whnf with a function, and normalize +normBind :: NormComp -> CompFun -> MRM NormComp +normBind (ReturnM t) k = applyNormCompFun k t +normBind (ErrorM msg) _ = return (ErrorM msg) +normBind (Ite cond comp1 comp2) k = + return $ Ite cond (CompBind comp1 k) (CompBind comp2 k) +normBind (Either ltp rtp f g t) k = + return $ Either ltp rtp (compFunComp f k) (compFunComp g k) t +normBind (MaybeElim tp m f t) k = + return $ MaybeElim tp (CompBind m k) (compFunComp f k) t +normBind (OrM comp1 comp2) k = + return $ OrM (CompBind comp1 k) (CompBind comp2 k) +normBind (ExistsM tp f) k = return $ ExistsM tp (compFunComp f k) +normBind (ForallM tp f) k = return $ ForallM tp (compFunComp f k) +normBind (FunBind f args k1) k2 = + return $ FunBind f args (compFunComp k1 k2) + +-- | Bind a 'Term' for a computation with a function and normalize +normBindTerm :: Term -> CompFun -> MRM NormComp +normBindTerm t f = normCompTerm t >>= \m -> normBind m f + +-- | Apply a computation function to a term argument to get a computation +applyCompFun :: CompFun -> Term -> MRM Comp +applyCompFun (CompFunComp f g) t = + -- (f >=> g) t == f t >>= g + do comp <- applyCompFun f t + return $ CompBind comp g +applyCompFun CompFunReturn t = + return $ CompReturn t +applyCompFun (CompFunTerm f) t = CompTerm <$> mrApplyAll f [t] + +-- | Apply a 'CompFun' to a term and normalize the resulting computation +applyNormCompFun :: CompFun -> Term -> MRM NormComp +applyNormCompFun f arg = applyCompFun f arg >>= normComp + +-- | Apply a 'Comp + +{- FIXME: do these go away? +-- | Lookup the definition of a function or throw a 'CannotLookupFunDef' if this is +-- not allowed, either because it is a global function we are treating as opaque +-- or because it is a locally-bound function variable +mrLookupFunDef :: FunName -> MRM Term +mrLookupFunDef f@(GlobalName _) = throwMRFailure (CannotLookupFunDef f) +mrLookupFunDef f@(LocalName var) = + mrVarInfo var >>= \case + Just (FunVarInfo body) -> return body + Just _ -> throwMRFailure (CannotLookupFunDef f) + Nothing -> error "mrLookupFunDef: unknown variable!" + +-- | Unfold a call to function @f@ in term @f args >>= g@ +mrUnfoldFunBind :: FunName -> [Term] -> Mark -> CompFun -> MRM Comp +mrUnfoldFunBind f _ mark _ | inMark f mark = throwMRFailure (RecursiveUnfold f) +mrUnfoldFunBind f args mark g = + do f_def <- mrLookupFunDef f + CompBind <$> + (CompMark <$> (CompTerm <$> liftSC2 scApplyAll f_def args) + <*> (return $ singleMark f `mappend` mark)) + <*> return g +-} + +{- +FIXME HERE NOW: maybe each FunName should stipulate whether it is recursive or +not, so that mrRefines can unfold the non-recursive ones early but wait on +handling the recursive ones +-} + + +---------------------------------------------------------------------- +-- * Handling Coinductive Hypotheses +---------------------------------------------------------------------- + +-- | Prove the precondition of a coinductive hypothesis +proveCoIndHypPreCond :: CoIndHyp -> MRM () +proveCoIndHypPreCond hyp = + do (pre1, pre2) <- applyCoIndHypPreconds hyp + pre <- liftSC2 scAnd pre1 pre2 + success <- mrProvable pre + if success then return () else + throwMRFailure $ + PrecondNotProvable (coIndHypLHSFun hyp) (coIndHypRHSFun hyp) pre + +-- | Co-inductively prove the refinement +-- +-- > forall x1, ..., xn. preF y1 ... ym -> preG z1 ... zl -> +-- > F y1 ... ym |= G z1 ... zl@ +-- +-- where @F@ and @G@ are the given 'FunName's, @y1, ..., ym@ and @z1, ..., zl@ +-- are the given argument lists, @x1, ..., xn@ is the current context of uvars, +-- and @preF@ and @preG@ are the preconditions associated with @F@ and @G@, +-- respectively. This proof is performed by coinductively assuming the +-- refinement holds and proving the refinement with the definitions of @F@ and +-- @G@ unfolded to their bodies. Note that this refinement is performed with +-- /only/ the preconditions @preF@ and @preG@ as assumptions; all other +-- assumptions are thrown away. If while running the refinement computation a +-- 'CoIndHypMismatchWidened' error is reached with the given names, the state is +-- restored and the computation is re-run with the widened hypothesis. +mrRefinesCoInd :: FunName -> [Term] -> FunName -> [Term] -> MRM () +mrRefinesCoInd f1 args1 f2 args2 = + do ctx <- mrUVarCtx + preF1 <- mrGetPrecond f1 + preF2 <- mrGetPrecond f2 + let hyp = CoIndHyp ctx f1 f2 args1 args2 preF1 preF2 + proveCoIndHypPreCond hyp + proveCoIndHyp hyp + +-- | Prove the refinement represented by a 'CoIndHyp' coinductively. This is the +-- main loop implementing 'mrRefinesCoInd'. See that function for documentation. +proveCoIndHyp :: CoIndHyp -> MRM () +proveCoIndHyp hyp = + do let f1 = coIndHypLHSFun hyp + f2 = coIndHypRHSFun hyp + args1 = coIndHypLHS hyp + args2 = coIndHypRHS hyp + debugPretty 1 ("proveCoIndHyp" <+> ppInEmptyCtx hyp) + lhs <- fromMaybe (error "proveCoIndHyp") <$> mrFunBody f1 args1 + rhs <- fromMaybe (error "proveCoIndHyp") <$> mrFunBody f2 args2 + (pre1, pre2) <- applyCoIndHypPreconds hyp + pre <- liftSC2 scAnd pre1 pre2 + (withOnlyUVars (coIndHypCtx hyp) $ withOnlyAssumption pre $ + withCoIndHyp hyp $ mrRefines lhs rhs) `catchError` \case + MRExnWiden nm1' nm2' new_vars + | f1 == nm1' && f2 == nm2' -> + -- NOTE: the state automatically gets reset here because we defined + -- MRM with ExceptT at a lower level than StateT + do mrDebugPPPrefixSep 1 "Widening recursive assumption for" nm1' "|=" nm2' + debugPrint 2 ("Widening indices: " ++ + intercalate ", " (map show new_vars)) + hyp' <- generalizeCoIndHyp hyp new_vars + proveCoIndHyp hyp' + e -> throwError e + + +-- | Test that a coinductive hypothesis for the given function names matches the +-- given arguments, otherwise throw an exception saying that widening is needed +matchCoIndHyp :: CoIndHyp -> [Term] -> [Term] -> MRM () +matchCoIndHyp hyp args1 args2 = + do (args1', args2') <- instantiateCoIndHyp hyp + eqs1 <- zipWithM mrProveEq args1' args1 + eqs2 <- zipWithM mrProveEq args2' args2 + if and (eqs1 ++ eqs2) then return () else + throwError $ MRExnWiden (coIndHypLHSFun hyp) (coIndHypRHSFun hyp) + (map Left (findIndices not eqs1) ++ map Right (findIndices not eqs2)) + proveCoIndHypPreCond hyp + + +-- | Generalize some of the arguments of a coinductive hypothesis +generalizeCoIndHyp :: CoIndHyp -> [Either Int Int] -> MRM CoIndHyp +generalizeCoIndHyp hyp [] = return hyp +generalizeCoIndHyp hyp all_specs@(arg_spec:arg_specs) = + withOnlyUVars (coIndHypCtx hyp) $ do + mrDebugPPPrefixSep 2 "generalizeCoIndHyp" hyp "with arg specs" (show all_specs) + -- Get the arg and type associated with arg_spec + let arg = coIndHypArg hyp arg_spec + arg_tp <- mrTypeOf arg + ctx <- mrUVarCtx + debugPretty 2 ("Current context: " <> ppCtx ctx) + -- Sort out the other args that equal arg + eq_uneq_specs <- forM arg_specs $ \spec' -> + do let arg' = coIndHypArg hyp spec' + tp' <- mrTypeOf arg' + mrDebugPPPrefixSep 2 "generalizeCoIndHyp: the type of" arg' "is" tp' + tps_eq <- mrConvertible arg_tp tp' + args_eq <- if tps_eq then mrProveEq arg arg' else return False + return $ if args_eq then Left spec' else Right spec' + let (eq_specs, uneq_specs) = partitionEithers eq_uneq_specs + -- Add a new variable of type arg_tp, set all eq_specs plus our original + -- arg_spec to it, and recurse + hyp' <- generalizeCoIndHypArgs hyp arg_tp (arg_spec:eq_specs) + generalizeCoIndHyp hyp' uneq_specs + +-- | Add a new variable of the given type to the context of a coinductive +-- hypothesis and set the specified arguments to that new variable +generalizeCoIndHypArgs :: CoIndHyp -> Term -> [Either Int Int] -> MRM CoIndHyp +generalizeCoIndHypArgs (CoIndHyp ctx f1 f2 args1 args2 pre1 pre2) tp specs = + do let set_arg i args = + take i args ++ (Unshared $ LocalVar 0) : drop (i+1) args + let (specs1, specs2) = partitionEithers specs + -- NOTE: need to lift the arguments because we are adding a variable + args1' <- liftTermLike 0 1 args1 + args2' <- liftTermLike 0 1 args2 + let args1'' = foldr set_arg args1' specs1 + args2'' = foldr set_arg args2' specs2 + return $ CoIndHyp (ctx ++ [("z",tp)]) f1 f2 args1'' args2'' pre1 pre2 + + +---------------------------------------------------------------------- +-- * Mr Solver Himself (He Identifies as Male) +---------------------------------------------------------------------- + +-- | An object that can be converted to a normalized computation +class ToNormComp a where + toNormComp :: a -> MRM NormComp + +instance ToNormComp NormComp where + toNormComp = return +instance ToNormComp Comp where + toNormComp = normComp +instance ToNormComp Term where + toNormComp = normComp . CompTerm + +-- | Prove that the left-hand computation refines the right-hand one. See the +-- rules described at the beginning of this module. +mrRefines :: (ToNormComp a, ToNormComp b) => a -> b -> MRM () +mrRefines t1 t2 = + do m1 <- toNormComp t1 + m2 <- toNormComp t2 + mrDebugPPPrefixSep 1 "mrRefines" m1 "|=" m2 + withFailureCtx (FailCtxRefines m1 m2) $ mrRefines' m1 m2 + +-- | The main implementation of 'mrRefines' +mrRefines' :: NormComp -> NormComp -> MRM () + +mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveEq e1 e2 +mrRefines' (ErrorM _) (ErrorM _) = return () +mrRefines' (ReturnM e) (ErrorM _) = throwMRFailure (ReturnNotError e) +mrRefines' (ErrorM _) (ReturnM e) = throwMRFailure (ReturnNotError e) + +-- A maybe eliminator on an equality type on the left +mrRefines' (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m1 f1 _) m2 = + do cond <- mrEq' tp e1 e2 + not_cond <- liftSC1 scNot cond + cond_pf <- liftSC1 scEqTrue cond >>= mrDummyProof + m1' <- applyNormCompFun f1 cond_pf + cond_holds <- mrProvable cond + if cond_holds then mrRefines m1' m2 else + withAssumption cond (mrRefines m1' m2) >> + withAssumption not_cond (mrRefines m1 m2) + +-- A maybe eliminator on an equality type on the right +mrRefines' m1 (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m2 f2 _) = + do cond <- mrEq' tp e1 e2 + not_cond <- liftSC1 scNot cond + cond_pf <- liftSC1 scEqTrue cond >>= mrDummyProof + m2' <- applyNormCompFun f2 cond_pf + cond_holds <- mrProvable cond + if cond_holds then mrRefines m1 m2' else + withAssumption cond (mrRefines m1 m2') >> + withAssumption not_cond (mrRefines m1 m2) + +-- A maybe eliminator on an isFinite type on the left +mrRefines' (MaybeElim (Type (asIsFinite -> Just n1)) m1 f1 _) m2 = + do n1_norm <- mrNormOpenTerm n1 + maybe_assump <- mrGetDataTypeAssump n1_norm + fin_pf <- + liftSC2 scGlobalApply "CryptolM.isFinite" [n1_norm] >>= mrDummyProof + case (maybe_assump, asNum n1_norm) of + (_, Just (Left _)) -> applyNormCompFun f1 fin_pf >>= flip mrRefines m2 + (_, Just (Right _)) -> mrRefines m1 m2 + (Just (IsNum _), _) -> applyNormCompFun f1 fin_pf >>= flip mrRefines m2 + (Just IsInf, _) -> mrRefines m1 m2 + _ -> + withDataTypeAssump n1_norm IsInf (mrRefines m1 m2) >> + liftSC0 scNatType >>= \nat_tp -> + (withUVarLift "n" (Type nat_tp) (n1_norm, f1, m2) $ \ n (n1', f1', m2') -> + withDataTypeAssump n1' (IsNum n) + (applyNormCompFun f1' n >>= flip mrRefines m2')) + +-- A maybe eliminator on an isFinite type on the right +mrRefines' m1 (MaybeElim (Type (asIsFinite -> Just n2)) m2 f2 _) = + do n2_norm <- mrNormOpenTerm n2 + maybe_assump <- mrGetDataTypeAssump n2_norm + fin_pf <- + liftSC2 scGlobalApply "CryptolM.isFinite" [n2_norm] >>= mrDummyProof + case (maybe_assump, asNum n2_norm) of + (_, Just (Left _)) -> applyNormCompFun f2 fin_pf >>= mrRefines m1 + (_, Just (Right _)) -> mrRefines m1 m2 + (Just (IsNum _), _) -> applyNormCompFun f2 fin_pf >>= mrRefines m1 + (Just IsInf, _) -> mrRefines m1 m2 + _ -> + withDataTypeAssump n2_norm IsInf (mrRefines m1 m2) >> + liftSC0 scNatType >>= \nat_tp -> + (withUVarLift "n" (Type nat_tp) (n2_norm, f2, m1) $ \ n (n2', f2', m1') -> + withDataTypeAssump n2' (IsNum n) + (applyNormCompFun f2' n >>= mrRefines m1')) + +mrRefines' (Ite cond1 m1 m1') m2 = + liftSC1 scNot cond1 >>= \not_cond1 -> + mrProvable cond1 >>= \cond1_true_pv-> + mrProvable not_cond1 >>= \cond1_false_pv -> + case (cond1_true_pv, cond1_false_pv) of + (True, _) -> mrRefines m1 m2 + (_, True) -> mrRefines m1' m2 + _ -> withAssumption cond1 (mrRefines m1 m2) >> + withAssumption not_cond1 (mrRefines m1' m2) +mrRefines' m1 (Ite cond2 m2 m2') = + liftSC1 scNot cond2 >>= \not_cond2 -> + mrProvable cond2 >>= \cond2_true_pv-> + mrProvable not_cond2 >>= \cond2_false_pv -> + case (cond2_true_pv, cond2_false_pv) of + (True, _) -> mrRefines m1 m2 + (_, True) -> mrRefines m1 m2' + _ -> withAssumption cond2 (mrRefines m1 m2) >> + withAssumption not_cond2 (mrRefines m1 m2') + +mrRefines' (Either ltp1 rtp1 f1 g1 t1) m2 = + mrNormOpenTerm t1 >>= \t1' -> + mrGetDataTypeAssump t1' >>= \mb_assump -> + case (mb_assump, asEither t1') of + (_, Just (Left x)) -> applyNormCompFun f1 x >>= flip mrRefines m2 + (_, Just (Right x)) -> applyNormCompFun g1 x >>= flip mrRefines m2 + (Just (IsLeft x), _) -> applyNormCompFun f1 x >>= flip mrRefines m2 + (Just (IsRight x), _) -> applyNormCompFun g1 x >>= flip mrRefines m2 + _ -> let lnm = maybe "x_left" id (compFunVarName f1) + rnm = maybe "x_right" id (compFunVarName g1) + in withUVarLift lnm ltp1 (f1, t1', m2) (\x (f1', t1'', m2') -> + applyNormCompFun f1' x >>= withDataTypeAssump t1'' (IsLeft x) + . flip mrRefines m2') >> + withUVarLift rnm rtp1 (g1, t1', m2) (\x (g1', t1'', m2') -> + applyNormCompFun g1' x >>= withDataTypeAssump t1'' (IsRight x) + . flip mrRefines m2') +mrRefines' m1 (Either ltp2 rtp2 f2 g2 t2) = + mrNormOpenTerm t2 >>= \t2' -> + mrGetDataTypeAssump t2' >>= \mb_assump -> + case (mb_assump, asEither t2') of + (_, Just (Left x)) -> applyNormCompFun f2 x >>= mrRefines m1 + (_, Just (Right x)) -> applyNormCompFun g2 x >>= mrRefines m1 + (Just (IsLeft x), _) -> applyNormCompFun f2 x >>= mrRefines m1 + (Just (IsRight x), _) -> applyNormCompFun g2 x >>= mrRefines m1 + _ -> let lnm = maybe "x_left" id (compFunVarName f2) + rnm = maybe "x_right" id (compFunVarName g2) + in withUVarLift lnm ltp2 (f2, t2', m1) (\x (f2', t2'', m1') -> + applyNormCompFun f2' x >>= withDataTypeAssump t2'' (IsLeft x) + . mrRefines m1') >> + withUVarLift rnm rtp2 (g2, t2', m1) (\x (g2', t2'', m1') -> + applyNormCompFun g2' x >>= withDataTypeAssump t2'' (IsRight x) + . mrRefines m1') + +mrRefines' m1 (ForallM tp f2) = + let nm = maybe "x" id (compFunVarName f2) in + withUVarLift nm tp (m1,f2) $ \x (m1',f2') -> + applyNormCompFun f2' x >>= \m2' -> + mrRefines m1' m2' +mrRefines' (ExistsM tp f1) m2 = + let nm = maybe "x" id (compFunVarName f1) in + withUVarLift nm tp (f1,m2) $ \x (f1',m2') -> + applyNormCompFun f1' x >>= \m1' -> + mrRefines m1' m2' + +mrRefines' m1 (OrM m2 m2') = + mrOr (mrRefines m1 m2) (mrRefines m1 m2') +mrRefines' (OrM m1 m1') m2 = + mrRefines m1 m2 >> mrRefines m1' m2 + +-- FIXME: the following cases don't work unless we either allow evars to be set +-- to NormComps or we can turn NormComps back into terms +mrRefines' m1@(FunBind (EVarFunName _) _ _) m2 = + throwMRFailure (CompsDoNotRefine m1 m2) +mrRefines' m1 m2@(FunBind (EVarFunName _) _ _) = + throwMRFailure (CompsDoNotRefine m1 m2) +{- +mrRefines' (FunBind (EVarFunName evar) args CompFunReturn) m2 = + mrGetEVar evar >>= \case + Just f -> + (mrApplyAll f args >>= normCompTerm) >>= \m1' -> + mrRefines m1' m2 + Nothing -> mrTrySetAppliedEVar evar args m2 +-} + +mrRefines' (FunBind (LetRecName f) args1 k1) (FunBind (LetRecName f') args2 k2) + | f == f' && length args1 == length args2 = + zipWithM_ mrAssertProveEq args1 args2 >> + mrRefinesFun k1 k2 + +mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = + mrFunOutType f1 args1 >>= \tp1 -> + mrFunOutType f2 args2 >>= \tp2 -> + mrConvertible tp1 tp2 >>= \tps_eq -> + mrFunBodyRecInfo f1 args1 >>= \maybe_f1_body -> + mrFunBodyRecInfo f2 args2 >>= \maybe_f2_body -> + mrGetCoIndHyp f1 f2 >>= \maybe_coIndHyp -> + mrGetFunAssump f1 >>= \maybe_fassump -> + case (maybe_coIndHyp, maybe_fassump) of + + -- If we have a co-inductive assumption that f1 args1' |= f2 args2': + -- * If it is convertible to our goal, continue and prove that k1 |= k2 + -- * If it can be widened with our goal, restart the current proof branch + -- with the widened hypothesis (done by throwing a + -- 'CoIndHypMismatchWidened' error for 'proveCoIndHyp' to catch) + -- * Otherwise, throw a 'CoIndHypMismatchFailure' error. + (Just hyp, _) -> + matchCoIndHyp hyp args1 args2 >> + mrRefinesFun k1 k2 + + -- If we have an assumption that f1 args' refines some rhs, then prove that + -- args1 = args' and then that rhs refines m2 + (_, Just fassump) -> + do (assump_args, assump_rhs) <- instantiateFunAssump fassump + zipWithM_ mrAssertProveEq assump_args args1 + m1' <- normBind assump_rhs k1 + mrRefines m1' m2 + + -- If f1 unfolds and is not recursive in itself, unfold it and recurse + _ | Just (f1_body, False) <- maybe_f1_body -> + normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 + + -- If f2 unfolds and is not recursive in itself, unfold it and recurse + _ | Just (f2_body, False) <- maybe_f2_body -> + normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + + -- If we don't have a co-inducitve hypothesis for f1 and f2, don't have an + -- assumption that f1 refines some specification, and both f1 and f2 are + -- recursive and have the same return type, then try to coinductively prove + -- that f1 args1 |= f2 args2 under the assumption that f1 args1 |= f2 args2, + -- and then try to prove that k1 |= k2 + _ | tps_eq + , Just _ <- maybe_f1_body + , Just _ <- maybe_f2_body -> + mrRefinesCoInd f1 args1 f2 args2 >> mrRefinesFun k1 k2 + + -- If we cannot line up f1 and f2, then making progress here would require us + -- to somehow split either m1 or m2 into some bind m' >>= k' such that m' is + -- related to the function call on the other side and k' is related to the + -- continuation on the other side, but we don't know how to do that, so give + -- up + _ -> + mrDebugPPPrefixSep 1 "mrRefines: bind types not equal:" tp1 "/=" tp2 >> + throwMRFailure (CompsDoNotRefine m1 m2) + +{- FIXME: handle FunBind on just one side +mrRefines' m1@(FunBind f@(GlobalName _) args k1) m2 = + mrGetFunAssump f >>= \case + Just fassump -> + -- If we have an assumption that f args' refines some rhs, then prove that + -- args = args' and then that rhs refines m2 + do (assump_args, assump_rhs) <- instantiateFunAssump fassump + zipWithM_ mrAssertProveEq assump_args args + m1' <- normBind assump_rhs k1 + mrRefines m1' m2 + Nothing -> + -- We don't want to do inter-procedural proofs, so if we don't know anything + -- about f already then give up + throwMRFailure (CompsDoNotRefine m1 m2) +-} + + +mrRefines' m1@(FunBind f1 args1 k1) m2 = + mrGetFunAssump f1 >>= \case + + -- If we have an assumption that f1 args' refines some rhs, then prove that + -- args1 = args' and then that rhs refines m2 + Just fassump -> + do (assump_args, assump_rhs) <- instantiateFunAssump fassump + zipWithM_ mrAssertProveEq assump_args args1 + m1' <- normBind assump_rhs k1 + mrRefines m1' m2 + + -- Otherwise, see if we can unfold f1 + Nothing -> + mrFunBodyRecInfo f1 args1 >>= \case + + -- If f1 unfolds and is not recursive in itself, unfold it and recurse + Just (f1_body, False) -> + normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 + + -- Otherwise we would have to somehow split m2 into some computation of the + -- form m2' >>= k2 where f1 args1 |= m2' and k1 |= k2, but we don't know how + -- to do this splitting, so give up + _ -> + throwMRFailure (CompsDoNotRefine m1 m2) + + +mrRefines' m1 m2@(FunBind f2 args2 k2) = + mrFunBodyRecInfo f2 args2 >>= \case + + -- If f2 unfolds and is not recursive in itself, unfold it and recurse + Just (f2_body, False) -> + normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + + -- If f2 unfolds but is recursive, and k2 is the trivial continuation, meaning + -- m2 is just f2 args2, use the law of coinduction to prove m1 |= f2 args2 by + -- proving m1 |= f2_body under the assumption that m1 |= f2 args2 + {- FIXME: implement something like this + Just (f2_body, True) + | CompFunReturn <- k2 -> + withFunAssumpR m1 f2 args2 $ + -} + + -- Otherwise we would have to somehow split m1 into some computation of the + -- form m1' >>= k1 where m1' |= f2 args2 and k1 |= k2, but we don't know how + -- to do this splitting, so give up + _ -> + throwMRFailure (CompsDoNotRefine m1 m2) + + +-- NOTE: the rules that introduce existential variables need to go last, so that +-- they can quantify over as many universals as possible +mrRefines' m1 (ExistsM tp f2) = + do let nm = maybe "x" id (compFunVarName f2) + evar <- mrFreshEVar nm tp + m2' <- applyNormCompFun f2 evar + mrRefines m1 m2' +mrRefines' (ForallM tp f1) m2 = + do let nm = maybe "x" id (compFunVarName f1) + evar <- mrFreshEVar nm tp + m1' <- applyNormCompFun f1 evar + mrRefines m1' m2 + +-- If none of the above cases match, then fail +mrRefines' m1 m2 = throwMRFailure (CompsDoNotRefine m1 m2) + + +-- | Prove that one function refines another for all inputs +mrRefinesFun :: CompFun -> CompFun -> MRM () +mrRefinesFun CompFunReturn CompFunReturn = return () +mrRefinesFun f1 f2 + | Just nm <- compFunVarName f1 `mplus` compFunVarName f2 + , Just tp <- compFunInputType f1 `mplus` compFunInputType f2 = + withUVarLift nm tp (f1,f2) $ \x (f1', f2') -> + do m1' <- applyNormCompFun f1' x + m2' <- applyNormCompFun f2' x + mrRefines m1' m2' +mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" + + +---------------------------------------------------------------------- +-- * External Entrypoints +---------------------------------------------------------------------- + +-- | Test two monadic, recursive terms for refinement. On success, if the +-- left-hand term is a named function, add the refinement to the 'MREnv' +-- environment. +askMRSolver :: + SharedContext -> + Int {- ^ The debug level -} -> + MREnv {- ^ The Mr Solver environment -} -> + Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + Term -> Term -> IO (Either MRFailure MREnv) + +askMRSolver sc dlvl env timeout t1 t2 = + do tp1 <- scTypeOf sc t1 >>= scWhnf sc + tp2 <- scTypeOf sc t2 >>= scWhnf sc + case asPiList tp1 of + (uvar_ctx, asCompM -> Just _) -> + runMRM sc timeout dlvl env $ + withUVars uvar_ctx $ \vars -> + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 + m1 <- mrApplyAll t1 vars >>= normCompTerm + m2 <- mrApplyAll t2 vars >>= normCompTerm + mrRefines m1 m2 + -- If t1 is a named function, add forall xs. f1 xs |= m2 to the env + case asGlobalFunName t1 of + Just f1 -> + let fassump = FunAssump { fassumpCtx = uvar_ctx, + fassumpArgs = vars, + fassumpRHS = m2 } in + return $ mrEnvAddFunAssump f1 fassump env + Nothing -> return env + _ -> return $ Left $ NotCompFunType tp1 diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs new file mode 100644 index 0000000000..cd7a10c86d --- /dev/null +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -0,0 +1,441 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-} + +{- | +Module : SAWScript.Prover.MRSolver.Term +Copyright : Galois, Inc. 2022 +License : BSD3 +Maintainer : westbrook@galois.com +Stability : experimental +Portability : non-portable (language extensions) + +This module defines the representation of terms used in Mr. Solver and various +utility functions for operating on terms and term representations. The main +datatype is 'NormComp', which represents the result of one step of monadic +normalization - see @Solver.hs@ for the description of this normalization. +-} + +module SAWScript.Prover.MRSolver.Term where + +import Data.String +import Data.IORef +import Control.Monad.Reader +import qualified Data.IntMap as IntMap +import GHC.Generics + +import Prettyprinter + +import Data.Map (Map) +import qualified Data.Map as Map + +import Verifier.SAW.Term.Functor +import Verifier.SAW.Term.CtxTerm (MonadTerm(..)) +import Verifier.SAW.Term.Pretty +import Verifier.SAW.SharedTerm +import Verifier.SAW.Recognizer hiding ((:*:)) +import Verifier.SAW.Cryptol.Monadify + + +---------------------------------------------------------------------- +-- * MR Solver Term Representation +---------------------------------------------------------------------- + +-- | A variable used by the MR solver +newtype MRVar = MRVar { unMRVar :: ExtCns Term } deriving (Eq, Show, Ord) + +-- | Get the type of an 'MRVar' +mrVarType :: MRVar -> Term +mrVarType = ecType . unMRVar + +-- | Print the string name of an 'MRVar' +showMRVar :: MRVar -> String +showMRVar = show . ppName . ecName . unMRVar + +-- | A tuple or record projection of a 'Term' +data TermProj = TermProjLeft | TermProjRight | TermProjRecord FieldName + deriving (Eq, Ord, Show) + +-- | Recognize a 'Term' as 0 or more projections +asProjAll :: Term -> (Term, [TermProj]) +asProjAll (asRecordSelector -> Just ((asProjAll -> (t, projs)), fld)) = + (t, TermProjRecord fld:projs) +asProjAll (asPairSelector -> Just ((asProjAll -> (t, projs)), isRight)) + | isRight = (t, TermProjRight:projs) + | not isRight = (t, TermProjLeft:projs) +asProjAll t = (t, []) + +-- | Names of functions to be used in computations, which are either names bound +-- by letrec to for recursive calls to fixed-points, existential variables, or +-- (possibly projections of) of global named constants +data FunName + = LetRecName MRVar | EVarFunName MRVar | GlobalName GlobalDef [TermProj] + deriving (Eq, Ord, Show) + +-- | Recognize a 'Term' as (possibly a projection of) a global name +asTypedGlobalProj :: Recognizer Term (GlobalDef, [TermProj]) +asTypedGlobalProj (asProjAll -> ((asTypedGlobalDef -> Just glob), projs)) = + Just (glob, projs) +asTypedGlobalProj _ = Nothing + +-- | Recognize a 'Term' as (possibly a projection of) a global name +asGlobalFunName :: Recognizer Term FunName +asGlobalFunName (asTypedGlobalProj -> Just (glob, projs)) = + Just $ GlobalName glob projs +asGlobalFunName _ = Nothing + +-- | Convert a 'FunName' to an unshared term, for printing +funNameTerm :: FunName -> Term +funNameTerm (LetRecName var) = Unshared $ FTermF $ ExtCns $ unMRVar var +funNameTerm (EVarFunName var) = Unshared $ FTermF $ ExtCns $ unMRVar var +funNameTerm (GlobalName gdef []) = globalDefTerm gdef +funNameTerm (GlobalName gdef (TermProjLeft:projs)) = + Unshared $ FTermF $ PairLeft $ funNameTerm (GlobalName gdef projs) +funNameTerm (GlobalName gdef (TermProjRight:projs)) = + Unshared $ FTermF $ PairRight $ funNameTerm (GlobalName gdef projs) +funNameTerm (GlobalName gdef (TermProjRecord fname:projs)) = + Unshared $ FTermF $ RecordProj (funNameTerm (GlobalName gdef projs)) fname + +-- | A term specifically known to be of type @sort i@ for some @i@ +newtype Type = Type Term deriving (Generic, Show) + +-- | A Haskell representation of a @CompM@ in "monadic normal form" +data NormComp + = ReturnM Term -- ^ A term @returnM a x@ + | ErrorM Term -- ^ A term @errorM a str@ + | Ite Term Comp Comp -- ^ If-then-else computation + | Either Type Type CompFun CompFun Term -- ^ A sum elimination + | MaybeElim Type Comp CompFun Term -- ^ A maybe elimination + | OrM Comp Comp -- ^ an @orM@ computation + | ExistsM Type CompFun -- ^ an @existsM@ computation + | ForallM Type CompFun -- ^ a @forallM@ computation + | FunBind FunName [Term] CompFun + -- ^ Bind a monadic function with @N@ arguments in an @a -> CompM b@ term + deriving (Generic, Show) + +-- | A computation function of type @a -> CompM b@ for some @a@ and @b@ +data CompFun + -- | An arbitrary term + = CompFunTerm Term + -- | A special case for the term @\ (x:a) -> returnM a x@ + | CompFunReturn + -- | The monadic composition @f >=> g@ + | CompFunComp CompFun CompFun + deriving (Generic, Show) + +-- | Compose two 'CompFun's, simplifying if one is a 'CompFunReturn' +compFunComp :: CompFun -> CompFun -> CompFun +compFunComp CompFunReturn f = f +compFunComp f CompFunReturn = f +compFunComp f g = CompFunComp f g + +-- | If a 'CompFun' contains an explicit lambda-abstraction, then return the +-- textual name bound by that lambda +compFunVarName :: CompFun -> Maybe LocalName +compFunVarName (CompFunTerm (asLambda -> Just (nm, _, _))) = Just nm +compFunVarName (CompFunComp f _) = compFunVarName f +compFunVarName _ = Nothing + +-- | If a 'CompFun' contains an explicit lambda-abstraction, then return the +-- input type for it +compFunInputType :: CompFun -> Maybe Type +compFunInputType (CompFunTerm (asLambda -> Just (_, tp, _))) = Just $ Type tp +compFunInputType (CompFunComp f _) = compFunInputType f +compFunInputType _ = Nothing + +-- | A computation of type @CompM a@ for some @a@ +data Comp = CompTerm Term | CompBind Comp CompFun | CompReturn Term + deriving (Generic, Show) + +-- | Match a type as being of the form @CompM a@ for some @a@ +asCompM :: Term -> Maybe Term +asCompM (asApp -> Just (isGlobalDef "Prelude.CompM" -> Just (), tp)) = + return tp +asCompM _ = fail "not a CompM type!" + +-- | Test if a type normalizes to a monadic function type of 0 or more arguments +isCompFunType :: SharedContext -> Term -> IO Bool +isCompFunType sc t = scWhnf sc t >>= \case + (asPiList -> (_, asCompM -> Just _)) -> return True + _ -> return False + + +---------------------------------------------------------------------- +-- * Mr Solver Environments +---------------------------------------------------------------------- + +-- | An assumption that a named function refines some specification. This has +-- the form +-- +-- > forall x1, ..., xn. F e1 ... ek |= m +-- +-- for some universal context @x1:T1, .., xn:Tn@, some list of argument +-- expressions @ei@ over the universal @xj@ variables, and some right-hand side +-- computation expression @m@. +data FunAssump = FunAssump { + -- | The uvars that were in scope when this assmption was created, in order + -- from outermost to innermost; that is, the uvars as "seen from outside their + -- scope", which is the reverse of the order of 'mrUVars', below + fassumpCtx :: [(LocalName,Term)], + -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars + fassumpArgs :: [Term], + -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars + fassumpRHS :: NormComp +} + +-- | A map from function names to function refinement assumptions over that +-- name +-- +-- FIXME: this should probably be an 'IntMap' on the 'VarIndex' of globals +type FunAssumps = Map FunName FunAssump + +-- | A global MR Solver environment +data MREnv = MREnv { + -- | The set of function refinements to be assumed by to Mr. Solver (which + -- have hopefully been proved previously...) + mreFunAssumps :: FunAssumps + } + +-- | The empty 'MREnv' +emptyMREnv :: MREnv +emptyMREnv = MREnv { mreFunAssumps = Map.empty } + +-- | Add a 'FunAssump' to a Mr Solver environment +mrEnvAddFunAssump :: FunName -> FunAssump -> MREnv -> MREnv +mrEnvAddFunAssump f fassump env = + env { mreFunAssumps = Map.insert f fassump (mreFunAssumps env) } + + +---------------------------------------------------------------------- +-- * Utility Functions for Transforming 'Term's +---------------------------------------------------------------------- + +-- | Transform the immediate subterms of a term using the supplied function +traverseSubterms :: MonadTerm m => (Term -> m Term) -> Term -> m Term +traverseSubterms f (unwrapTermF -> tf) = traverse f tf >>= mkTermF + +-- | Build a recursive memoized function for tranforming 'Term's. Take in a +-- function @f@ that intuitively performs one step of the transformation and +-- allow it to recursively call the memoized function being defined by passing +-- it as the first argument to @f@. +memoFixTermFun :: MonadIO m => ((Term -> m a) -> Term -> m a) -> Term -> m a +memoFixTermFun f term_top = + do table_ref <- liftIO $ newIORef IntMap.empty + let go t@(STApp { stAppIndex = ix }) = + liftIO (readIORef table_ref) >>= \table -> + case IntMap.lookup ix table of + Just ret -> return ret + Nothing -> + do ret <- f go t + liftIO $ modifyIORef' table_ref (IntMap.insert ix ret) + return ret + go t = f go t + go term_top + + +---------------------------------------------------------------------- +-- * Lifting MR Solver Terms +---------------------------------------------------------------------- + +-- | A term-like object is one that supports lifting and substitution. This +-- class can be derived using @DeriveAnyClass@. +class TermLike a where + liftTermLike :: MonadTerm m => DeBruijnIndex -> DeBruijnIndex -> a -> m a + substTermLike :: MonadTerm m => DeBruijnIndex -> [Term] -> a -> m a + + -- Default instances for @DeriveAnyClass@ + default liftTermLike :: (Generic a, GTermLike (Rep a), MonadTerm m) => + DeBruijnIndex -> DeBruijnIndex -> a -> m a + liftTermLike n i = fmap to . gLiftTermLike n i . from + default substTermLike :: (Generic a, GTermLike (Rep a), MonadTerm m) => + DeBruijnIndex -> [Term] -> a -> m a + substTermLike n i = fmap to . gSubstTermLike n i . from + +-- | A generic version of 'TermLike' for @DeriveAnyClass@, based on: +-- https://hackage.haskell.org/package/base-4.16.0.0/docs/GHC-Generics.html#g:12 +class GTermLike f where + gLiftTermLike :: MonadTerm m => DeBruijnIndex -> DeBruijnIndex -> f p -> m (f p) + gSubstTermLike :: MonadTerm m => DeBruijnIndex -> [Term] -> f p -> m (f p) + +-- | 'TermLike' on empty types +instance GTermLike V1 where + gLiftTermLike _ _ = \case {} + gSubstTermLike _ _ = \case {} + +-- | 'TermLike' on unary types +instance GTermLike U1 where + gLiftTermLike _ _ U1 = return U1 + gSubstTermLike _ _ U1 = return U1 + +-- | 'TermLike' on sums +instance (GTermLike f, GTermLike g) => GTermLike (f :+: g) where + gLiftTermLike n i (L1 a) = L1 <$> gLiftTermLike n i a + gLiftTermLike n i (R1 b) = R1 <$> gLiftTermLike n i b + gSubstTermLike n s (L1 a) = L1 <$> gSubstTermLike n s a + gSubstTermLike n s (R1 b) = R1 <$> gSubstTermLike n s b + +-- | 'TermLike' on products +instance (GTermLike f, GTermLike g) => GTermLike (f :*: g) where + gLiftTermLike n i (a :*: b) = (:*:) <$> gLiftTermLike n i a <*> gLiftTermLike n i b + gSubstTermLike n s (a :*: b) = (:*:) <$> gSubstTermLike n s a <*> gSubstTermLike n s b + +-- | 'TermLike' on fields +instance TermLike a => GTermLike (K1 i a) where + gLiftTermLike n i (K1 a) = K1 <$> liftTermLike n i a + gSubstTermLike n i (K1 a) = K1 <$> substTermLike n i a + +-- | 'GTermLike' ignores meta-information +instance GTermLike a => GTermLike (M1 i c a) where + gLiftTermLike n i (M1 a) = M1 <$> gLiftTermLike n i a + gSubstTermLike n i (M1 a) = M1 <$> gSubstTermLike n i a + +deriving instance _ => TermLike (a,b) +deriving instance _ => TermLike (a,b,c) +deriving instance _ => TermLike (a,b,c,d) +deriving instance _ => TermLike (a,b,c,d,e) +deriving instance _ => TermLike (a,b,c,d,e,f) +deriving instance _ => TermLike (a,b,c,d,e,f,g) +deriving instance _ => TermLike [a] + +instance TermLike Term where + liftTermLike = liftTerm + substTermLike = substTerm + +instance TermLike FunName where + liftTermLike _ _ = return + substTermLike _ _ = return + +deriving instance TermLike Type +deriving instance TermLike NormComp +deriving instance TermLike CompFun +deriving instance TermLike Comp + + +---------------------------------------------------------------------- +-- * Pretty-Printing MR Solver Terms +---------------------------------------------------------------------- + +-- | The monad for pretty-printing in a context of SAW core variables +type PPInCtxM = Reader [LocalName] + +-- | Pretty-print an object in a SAW core context and render to a 'String' +showInCtx :: PrettyInCtx a => [LocalName] -> a -> String +showInCtx ctx a = + renderSawDoc defaultPPOpts $ runReader (prettyInCtx a) ctx + +-- | Pretty-print an object in the empty SAW core context +ppInEmptyCtx :: PrettyInCtx a => a -> SawDoc +ppInEmptyCtx a = runReader (prettyInCtx a) [] + +-- | A generic function for pretty-printing an object in a SAW core context of +-- locally-bound names +class PrettyInCtx a where + prettyInCtx :: a -> PPInCtxM SawDoc + +instance PrettyInCtx Term where + prettyInCtx t = flip (ppTermInCtx defaultPPOpts) t <$> ask + +-- | Combine a list of pretty-printed documents like applications are combined +prettyAppList :: [PPInCtxM SawDoc] -> PPInCtxM SawDoc +prettyAppList = fmap (group . hang 2 . vsep) . sequence + +-- | Pretty-print the application of a 'Term' +prettyTermApp :: Term -> [Term] -> PPInCtxM SawDoc +prettyTermApp f_top args = + prettyInCtx $ foldl (\f arg -> Unshared $ App f arg) f_top args + +-- | FIXME: move this helper function somewhere better... +ppCtx :: [(LocalName,Term)] -> SawDoc +ppCtx = helper [] where + helper :: [LocalName] -> [(LocalName,Term)] -> SawDoc + helper _ [] = "" + helper ns ((n,tp):ctx) = + let ns' = n:ns in + ppTermInCtx defaultPPOpts ns' (Unshared $ LocalVar 0) <> ":" <> + ppTermInCtx defaultPPOpts ns tp <> ", " <> helper ns' ctx + +instance PrettyInCtx String where + prettyInCtx str = return $ fromString str + +instance PrettyInCtx SawDoc where + prettyInCtx pp = return pp + +instance PrettyInCtx Type where + prettyInCtx (Type t) = prettyInCtx t + +instance PrettyInCtx MRVar where + prettyInCtx (MRVar ec) = return $ ppName $ ecName ec + +instance PrettyInCtx [Term] where + prettyInCtx xs = list <$> mapM prettyInCtx xs + +instance PrettyInCtx TermProj where + prettyInCtx TermProjLeft = return (pretty '.' <> "1") + prettyInCtx TermProjRight = return (pretty '.' <> "2") + prettyInCtx (TermProjRecord fld) = return (pretty '.' <> pretty fld) + +instance PrettyInCtx FunName where + prettyInCtx (LetRecName var) = prettyInCtx var + prettyInCtx (EVarFunName var) = prettyInCtx var + prettyInCtx (GlobalName g projs) = + foldM (\pp proj -> (pp <>) <$> prettyInCtx proj) (ppName $ + globalDefName g) projs + +instance PrettyInCtx Comp where + prettyInCtx (CompTerm t) = prettyInCtx t + prettyInCtx (CompBind c f) = + prettyAppList [prettyInCtx c, return ">>=", prettyInCtx f] + prettyInCtx (CompReturn t) = + prettyAppList [ return "returnM", return "_", parens <$> prettyInCtx t] + +instance PrettyInCtx CompFun where + prettyInCtx (CompFunTerm t) = prettyInCtx t + prettyInCtx CompFunReturn = return "returnM" + prettyInCtx (CompFunComp f g) = + prettyAppList [prettyInCtx f, return ">=>", prettyInCtx g] + +instance PrettyInCtx NormComp where + prettyInCtx (ReturnM t) = + prettyAppList [return "returnM", return "_", parens <$> prettyInCtx t] + prettyInCtx (ErrorM str) = + prettyAppList [return "errorM", return "_", parens <$> prettyInCtx str] + prettyInCtx (Ite cond t1 t2) = + prettyAppList [return "ite", return "_", parens <$> prettyInCtx cond, + parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] + prettyInCtx (Either ltp rtp f g eith) = + prettyAppList [return "either", + parens <$> prettyInCtx ltp, parens <$> prettyInCtx rtp, + return (parens "CompM _"), + parens <$> prettyInCtx f, parens <$> prettyInCtx g, + parens <$> prettyInCtx eith] + prettyInCtx (MaybeElim tp m f mayb) = + prettyAppList [return "maybe", parens <$> prettyInCtx tp, + return (parens "CompM _"), parens <$> prettyInCtx m, + parens <$> prettyInCtx f, parens <$> prettyInCtx mayb] + prettyInCtx (OrM t1 t2) = + prettyAppList [return "orM", return "_", + parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] + prettyInCtx (ExistsM tp f) = + prettyAppList [return "existsM", prettyInCtx tp, return "_", + parens <$> prettyInCtx f] + prettyInCtx (ForallM tp f) = + prettyAppList [return "forallM", prettyInCtx tp, return "_", + parens <$> prettyInCtx f] + prettyInCtx (FunBind f args CompFunReturn) = + prettyTermApp (funNameTerm f) args + prettyInCtx (FunBind f [] k) = + prettyAppList [prettyInCtx f, return ">>=", prettyInCtx k] + prettyInCtx (FunBind f args k) = + prettyAppList [parens <$> prettyTermApp (funNameTerm f) args, + return ">>=", prettyInCtx k] diff --git a/src/SAWScript/Value.hs b/src/SAWScript/Value.hs index ca1eac5f59..9da3910473 100644 --- a/src/SAWScript/Value.hs +++ b/src/SAWScript/Value.hs @@ -76,6 +76,7 @@ import SAWScript.JavaPretty (prettyClass) import SAWScript.Options (Options(printOutFn),printOutLn,Verbosity) import SAWScript.Proof import SAWScript.Prover.SolverStats +import SAWScript.Prover.MRSolver.Term as MRSolver import SAWScript.Crucible.LLVM.Skeleton import SAWScript.X86 (X86Unsupported(..), X86Error(..)) @@ -431,6 +432,7 @@ data TopLevelRW = , rwDocs :: Map SS.Name String , rwCryptol :: CEnv.CryptolEnv , rwMonadify :: Monadify.MonadifyEnv + , rwMRSolverEnv :: MRSolver.MREnv , rwProofs :: [Value] {- ^ Values, generated anywhere, that represent proofs. -} , rwPPOpts :: PPOpts -- , rwCrucibleLLVMCtx :: Crucible.LLVMContext diff --git a/src/SAWScript/VerificationSummary.hs b/src/SAWScript/VerificationSummary.hs index 60799cdd1a..35e8d10613 100644 --- a/src/SAWScript/VerificationSummary.hs +++ b/src/SAWScript/VerificationSummary.hs @@ -172,9 +172,10 @@ prettyVerificationSummary vs@(VerificationSummary jspecs lspecs thms) = prettyTheorems ts = sectionWithItems "Theorems Proved or Assumed" (item . prettyTheorem) ts prettyTheorem t = - vsep [ if Set.null (solverStatsSolvers (thmStats t)) - then "Axiom:" - else "Theorem:" + vsep [ case thmSummary t of + ProvedTheorem{} -> "Theorem:" + TestedTheorem n -> "Theorem (randomly tested on" <+> viaShow n <+> "samples):" + AdmittedTheorem{} -> "Axiom:" , code (indent 2 (ppProp PP.defaultPPOpts (thmProp t))) , "" ] diff --git a/src/SAWScript/X86.hs b/src/SAWScript/X86.hs index 578d2cbf35..ce90faef4b 100644 --- a/src/SAWScript/X86.hs +++ b/src/SAWScript/X86.hs @@ -185,7 +185,7 @@ data Fun = Fun { funName :: ByteString, funSpec :: FunSpec } -------------------------------------------------------------------------------- -type CallHandler = Sym -> Macaw.LookupFunctionHandle Sym X86_64 +type CallHandler = Sym -> Macaw.LookupFunctionHandle (MacawSimulatorState Sym) Sym X86_64 -- | Run a top-level proof. -- Should be used when making a standalone proof script. diff --git a/src/SAWScript/X86Spec.hs b/src/SAWScript/X86Spec.hs index f160e6bc00..d11e3ef2ae 100644 --- a/src/SAWScript/X86Spec.hs +++ b/src/SAWScript/X86Spec.hs @@ -123,7 +123,10 @@ import Lang.Crucible.Types import Verifier.SAW.SharedTerm (Term,scApplyAll,scVector,scBitvector,scAt,scNat) import Data.Macaw.Memory(RegionIndex) -import Data.Macaw.Symbolic(GlobalMap(..), ToCrucibleType, LookupFunctionHandle(..), MacawCrucibleRegTypes) +import Data.Macaw.Symbolic + ( GlobalMap(..), ToCrucibleType, LookupFunctionHandle(..) + , MacawCrucibleRegTypes, MacawSimulatorState + ) import Data.Macaw.Symbolic.Backend ( crucArchRegTypes ) import Data.Macaw.X86.X86Reg import Data.Macaw.X86.Symbolic @@ -1203,7 +1206,7 @@ _debugDumpGoals opts = sh (ProofGoal _hyps g) = print (view labeledPredMsg g) -type Overrides = Map (Natural,Integer) (Sym -> LookupFunctionHandle Sym X86_64) +type Overrides = Map (Natural,Integer) (Sym -> LookupFunctionHandle (MacawSimulatorState Sym) Sym X86_64) -- | Use a specification to verify a function. -- Returns the initial state for the function, and a post-condition.