diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs index 76de3b0318..e1b1a06a4d 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs @@ -90,6 +90,7 @@ import Verifier.SAW.Recognizer -- import Verifier.SAW.Position import Verifier.SAW.Cryptol.PreludeM +import GHC.Stack import Debug.Trace @@ -345,6 +346,12 @@ typeclassMonMap = ("Cryptol.PIntegral", "Cryptol.PIntegral"), ("Cryptol.PLiteral", "Cryptol.PLiteral")] +-- | The list of functions that are monadified as themselves in types +typeLevelOpMonList :: [Ident] +typeLevelOpMonList = ["Cryptol.tcAdd", "Cryptol.tcSub", "Cryptol.tcMul", + "Cryptol.tcDiv", "Cryptol.tcMod", "Cryptol.tcExp", + "Cryptol.tcMin", "Cryptol.tcMax"] + -- | A context of local variables used for monadifying types, which includes the -- variable names, their original types (before monadification), and, if their -- types corespond to 'MonKind's, a local 'MonType' that quantifies over them. @@ -364,25 +371,20 @@ ppTermInTypeCtx ctx t = typeCtxPureCtx :: MonadifyTypeCtx -> [(LocalName,Term)] typeCtxPureCtx = map (\(x,tp,_) -> (x,tp)) --- | Make a monadification type that is to be considered a base type -mkTermBaseType :: MonadifyTypeCtx -> MonKind -> Term -> MonType -mkTermBaseType ctx k t = - MTyBase k $ openOpenTerm (typeCtxPureCtx ctx) t - -- | Monadify a type and convert it to its corresponding argument type -monadifyTypeArgType :: MonadifyTypeCtx -> Term -> OpenTerm +monadifyTypeArgType :: HasCallStack => MonadifyTypeCtx -> Term -> OpenTerm monadifyTypeArgType ctx t = toArgType $ monadifyType ctx t -- | Apply a monadified type to a type or term argument in the sense of -- 'applyPiOpenTerm', meaning give the type of applying @f@ of a type to a -- particular argument @arg@ -applyMonType :: MonType -> Either MonType ArgMonTerm -> MonType +applyMonType :: HasCallStack => MonType -> Either MonType ArgMonTerm -> MonType applyMonType (MTyArrow _ tp_ret) (Right _) = tp_ret applyMonType (MTyForall _ _ f) (Left mtp) = f mtp applyMonType _ _ = error "applyMonType: application at incorrect type" -- | Convert a SAW core 'Term' to a monadification type -monadifyType :: MonadifyTypeCtx -> Term -> MonType +monadifyType :: HasCallStack => MonadifyTypeCtx -> Term -> MonType {- monadifyType ctx t | trace ("\nmonadifyType:\n" ++ ppTermInTypeCtx ctx t) False = undefined @@ -417,15 +419,12 @@ monadifyType ctx (asDataType -> Just (pn, args)) -- and/or Nums MTyBase k_out $ dataTypeOpenTerm (primName pn) (map toArgType margs) monadifyType ctx (asVectorType -> Just (len, tp)) = - let lenOT = openOpenTerm (typeCtxPureCtx ctx) len in + let lenOT = monadifyTypeNat ctx len in MTySeq (ctorOpenTerm "Cryptol.TCNum" [lenOT]) $ monadifyType ctx tp -monadifyType ctx tp@(asApplyAll -> ((asGlobalDef -> Just seq_id), [n, a])) +monadifyType ctx (asApplyAll -> ((asGlobalDef -> Just seq_id), [n, a])) | seq_id == "Cryptol.seq" = - case monTypeNum (monadifyType ctx n) of - Just n_trm -> MTySeq n_trm (monadifyType ctx a) - Nothing -> - error ("Monadify type: not a number: " ++ ppTermInTypeCtx ctx n - ++ " in type: " ++ ppTermInTypeCtx ctx tp) + let nOT = monadifyTypeArgType ctx n in + MTySeq nOT $ monadifyType ctx a monadifyType ctx (asApp -> Just ((asGlobalDef -> Just f), arg)) | Just f_trans <- lookup f typeclassMonMap = MTyBase (MKType $ mkSort 1) $ @@ -442,9 +441,16 @@ monadifyType ctx (asApplyAll -> (f, args)) MTyBase k_out (applyOpenTermMulti (globalDefOpenTerm glob) $ map toArgType margs) -} -monadifyType ctx tp@(asCtor -> Just (pn, _)) - | primName pn == "Cryptol.TCNum" || primName pn == "Cryptol.TCInf" = - MTyNum $ openOpenTerm (typeCtxPureCtx ctx) tp +monadifyType _ (asCtor -> Just (pn, [])) + | primName pn == "Cryptol.TCInf" + = MTyNum $ ctorOpenTerm "Cryptol.TCInf" [] +monadifyType ctx (asCtor -> Just (pn, [n])) + | primName pn == "Cryptol.TCNum" + = MTyNum $ ctorOpenTerm "Cryptol.TCNum" [monadifyTypeNat ctx n] +monadifyType ctx (asApplyAll -> ((asGlobalDef -> Just f), args)) + | f `elem` typeLevelOpMonList = + MTyNum $ + applyOpenTermMulti (globalOpenTerm f) $ map (monadifyTypeArgType ctx) args monadifyType ctx (asLocalVar -> Just i) | i < length ctx , (_,_,Just tp) <- ctx!!i = tp @@ -452,6 +458,16 @@ monadifyType ctx tp = error ("monadifyType: not a valid type for monadification: " ++ ppTermInTypeCtx ctx tp) +-- | Monadify a type-level natural number +monadifyTypeNat :: HasCallStack => MonadifyTypeCtx -> Term -> OpenTerm +monadifyTypeNat _ (asNat -> Just n) = natOpenTerm n +monadifyTypeNat ctx (asLocalVar -> Just i) + | i < length ctx + , (_,_,Just tp) <- ctx!!i = toArgType tp +monadifyTypeNat ctx tp = + error ("monadifyTypeNat: not a valid natural number for monadification: " + ++ ppTermInTypeCtx ctx tp) + ---------------------------------------------------------------------- -- * Monadified Terms @@ -591,13 +607,21 @@ failArgMonTerm :: MonType -> String -> ArgMonTerm failArgMonTerm tp str = fromArgTerm tp (failOpenTerm str) -- | Apply a monadified term to a type or term argument -applyMonTerm :: MonTerm -> Either MonType ArgMonTerm -> MonTerm +applyMonTerm :: HasCallStack => MonTerm -> Either MonType ArgMonTerm -> MonTerm applyMonTerm (ArgMonTerm (FunMonTerm _ _ _ f)) (Right arg) = f arg applyMonTerm (ArgMonTerm (ForallMonTerm _ _ f)) (Left mtp) = f mtp -applyMonTerm _ _ = error "applyMonTerm: application at incorrect type" +applyMonTerm (ArgMonTerm (FunMonTerm _ _ _ _)) (Left _) = + error "applyMonTerm: application of term-level function to type-level argument" +applyMonTerm (ArgMonTerm (ForallMonTerm _ _ _)) (Right _) = + error "applyMonTerm: application of type-level function to term-level argument" +applyMonTerm (ArgMonTerm (BaseMonTerm _ _)) _ = + error "applyMonTerm: application of non-function base term" +applyMonTerm (CompMonTerm _ _) _ = + error "applyMonTerm: application of computational term" -- | Apply a monadified term to 0 or more arguments -applyMonTermMulti :: MonTerm -> [Either MonType ArgMonTerm] -> MonTerm +applyMonTermMulti :: HasCallStack => MonTerm -> [Either MonType ArgMonTerm] -> + MonTerm applyMonTermMulti = foldl applyMonTerm -- | Build a 'MonTerm' from a global of a given argument type @@ -814,13 +838,13 @@ assertIsFinite _ = ---------------------------------------------------------------------- -- | Monadify a type in the context of the 'MonadifyM' monad -monadifyTypeM :: Term -> MonadifyM MonType +monadifyTypeM :: HasCallStack => Term -> MonadifyM MonType monadifyTypeM tp = do ctx <- monStCtx <$> ask return $ monadifyType (ctxToTypeCtx ctx) tp -- | Monadify a term to a monadified term of argument type -monadifyArg :: Maybe MonType -> Term -> MonadifyM ArgMonTerm +monadifyArg :: HasCallStack => Maybe MonType -> Term -> MonadifyM ArgMonTerm {- monadifyArg _ t | trace ("Monadifying term of argument type: " ++ showTerm t) False @@ -832,7 +856,7 @@ monadifyArg mtp t = monadifyTerm' mtp t >>= argifyMonTerm -- | Monadify a term to argument type and convert back to a term -monadifyArgTerm :: Maybe MonType -> Term -> MonadifyM OpenTerm +monadifyArgTerm :: HasCallStack => Maybe MonType -> Term -> MonadifyM OpenTerm monadifyArgTerm mtp t = toArgTerm <$> monadifyArg mtp t -- | Monadify a term @@ -852,7 +876,7 @@ monadifyTerm mtp t = -- (i.e.,, lambdas, pairs, and records), but is optional for elimination forms -- (i.e., applications, projections, and also in this case variables). Note that -- this means monadification will fail on terms with beta or tuple redexes. -monadifyTerm' :: Maybe MonType -> Term -> MonadifyM MonTerm +monadifyTerm' :: HasCallStack => Maybe MonType -> Term -> MonadifyM MonTerm monadifyTerm' (Just mtp) t@(asLambda -> Just _) = ask >>= \(MonadifyROState { monStEnv = env, monStCtx = ctx }) -> return $ monadifyLambdas env ctx mtp t @@ -938,7 +962,7 @@ monadifyTerm' _ t = -- | Monadify the application of a monadified term to a list of terms, using the -- type of the already monadified to monadify the arguments -monadifyApply :: MonTerm -> [Term] -> MonadifyM MonTerm +monadifyApply :: HasCallStack => MonTerm -> [Term] -> MonadifyM MonTerm monadifyApply f (t : ts) | MTyArrow tp_in _ <- getMonType f = do mtrm <- monadifyArg (Just tp_in) t @@ -953,7 +977,8 @@ monadifyApply f [] = return f -- | FIXME: documentation; get our type down to a base type before going into -- the MonadifyM monad -monadifyLambdas :: MonadifyEnv -> MonadifyCtx -> MonType -> Term -> MonTerm +monadifyLambdas :: HasCallStack => MonadifyEnv -> MonadifyCtx -> + MonType -> Term -> MonTerm monadifyLambdas env ctx (MTyForall _ k tp_f) (asLambda -> Just (x, x_tp, body)) = -- FIXME: check that monadifyKind x_tp == k @@ -968,7 +993,8 @@ monadifyLambdas env ctx tp t = monadifyEtaExpand env ctx tp tp t [] -- | FIXME: documentation -monadifyEtaExpand :: MonadifyEnv -> MonadifyCtx -> MonType -> MonType -> Term -> +monadifyEtaExpand :: HasCallStack => MonadifyEnv -> MonadifyCtx -> + MonType -> MonType -> Term -> [Either MonType ArgMonTerm] -> MonTerm monadifyEtaExpand env ctx top_mtp (MTyForall x k tp_f) t args = ArgMonTerm $ ForallMonTerm x k $ \mtp -> diff --git a/heapster-saw/examples/arrays_mr_solver.saw b/heapster-saw/examples/arrays_mr_solver.saw index 5dfbb1fa9b..3100b82983 100644 --- a/heapster-saw/examples/arrays_mr_solver.saw +++ b/heapster-saw/examples/arrays_mr_solver.saw @@ -11,5 +11,5 @@ include "specPrims.saw"; import "arrays.cry"; zero_array <- parse_core_mod "arrays" "zero_array"; -// mr_solver_prove zero_array {{ zero_array_loop_spec }}; +mr_solver_test zero_array {{ zero_array_loop_spec }}; mr_solver_prove zero_array {{ zero_array_spec }}; diff --git a/heapster-saw/examples/sha512.bc b/heapster-saw/examples/sha512.bc index 711867222c..a8b922d8df 100644 Binary files a/heapster-saw/examples/sha512.bc and b/heapster-saw/examples/sha512.bc differ diff --git a/heapster-saw/examples/sha512.c b/heapster-saw/examples/sha512.c index a467b2ffaa..b5a9690c00 100644 --- a/heapster-saw/examples/sha512.c +++ b/heapster-saw/examples/sha512.c @@ -235,7 +235,7 @@ static void processBlock(uint64_t *a, uint64_t *b, uint64_t *c, uint64_t *d, const uint8_t *in) { uint64_t s0, s1, T1; uint64_t X[16]; - int i; + uint64_t i; T1 = X[0] = CRYPTO_load_u64_be(in); round_00_15(0, a, b, c, d, e, f, g, h, &T1); diff --git a/heapster-saw/examples/sha512.cry b/heapster-saw/examples/sha512.cry index c118a7c12b..75591adc34 100644 --- a/heapster-saw/examples/sha512.cry +++ b/heapster-saw/examples/sha512.cry @@ -1,6 +1,8 @@ module SHA512 where +import SpecPrims + // ============================================================================ // Definitions from cryptol-specs/Primitive/Keyless/Hash/SHA512.cry, with some // type annotations added to SIGMA_0, SIGMA_1, sigma_0, and sigma_1 to get @@ -83,3 +85,70 @@ round_16_80_spec i j a b c d e f g h X T1 = X' = update X (j && 15) T1' (a', b', c', d', e', f', g', h', T1'') = round_00_15_spec (i + j) a b c d e f g h T1' + +processBlock_spec : [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> + [16][w] -> + ([w], [w], [w], [w], [w], [w], [w], [w], [16][w]) +processBlock_spec a b c d e f g h in = + processBlock_loop_spec 16 aF bF cF dF eF fF gF hF X T1 in + where (a0,b0,c0,d0,e0,f0,g0,h0,_) = round_00_15_spec 0 a b c d e f g h (in @ ( 0 : [w])) + (h1,a1,b1,c1,d1,e1,f1,g1,_) = round_00_15_spec 1 h0 a0 b0 c0 d0 e0 f0 g0 (in @ ( 1 : [w])) + (g2,h2,a2,b2,c2,d2,e2,f2,_) = round_00_15_spec 2 g1 h1 a1 b1 c1 d1 e1 f1 (in @ ( 2 : [w])) + (f3,g3,h3,a3,b3,c3,d3,e3,_) = round_00_15_spec 3 f2 g2 h2 a2 b2 c2 d2 e2 (in @ ( 3 : [w])) + (e4,f4,g4,h4,a4,b4,c4,d4,_) = round_00_15_spec 4 e3 f3 g3 h3 a3 b3 c3 d3 (in @ ( 4 : [w])) + (d5,e5,f5,g5,h5,a5,b5,c5,_) = round_00_15_spec 5 d4 e4 f4 g4 h4 a4 b4 c4 (in @ ( 5 : [w])) + (c6,d6,e6,f6,g6,h6,a6,b6,_) = round_00_15_spec 6 c5 d5 e5 f5 g5 h5 a5 b5 (in @ ( 6 : [w])) + (b7,c7,d7,e7,f7,g7,h7,a7,_) = round_00_15_spec 7 b6 c6 d6 e6 f6 g6 h6 a6 (in @ ( 7 : [w])) + (a8,b8,c8,d8,e8,f8,g8,h8,_) = round_00_15_spec 8 a7 b7 c7 d7 e7 f7 g7 h7 (in @ ( 8 : [w])) + (h9,a9,b9,c9,d9,e9,f9,g9,_) = round_00_15_spec 9 h8 a8 b8 c8 d8 e8 f8 g8 (in @ ( 9 : [w])) + (gA,hA,aA,bA,cA,dA,eA,fA,_) = round_00_15_spec 10 g9 h9 a9 b9 c9 d9 e9 f9 (in @ (10 : [w])) + (fB,gB,hB,aB,bB,cB,dB,eB,_) = round_00_15_spec 11 fA gA hA aA bA cA dA eA (in @ (11 : [w])) + (eC,fC,gC,hC,aC,bC,cC,dC,_) = round_00_15_spec 12 eB fB gB hB aB bB cB dB (in @ (12 : [w])) + (dD,eD,fD,gD,hD,aD,bD,cD,_) = round_00_15_spec 13 dC eC fC gC hC aC bC cC (in @ (13 : [w])) + (cE,dE,eE,fE,gE,hE,aE,bE,_) = round_00_15_spec 14 cD dD eD fD gD hD aD bD (in @ (14 : [w])) + (bF,cF,dF,eF,fF,gF,hF,aF,T1) = round_00_15_spec 15 bE cE dE eE fE gE hE aE (in @ (15 : [w])) + X = [in @ ( 0 : [w]), in @ ( 1 : [w]), in @ ( 2 : [w]), in @ ( 3 : [w]), + in @ ( 4 : [w]), in @ ( 5 : [w]), in @ ( 6 : [w]), in @ ( 7 : [w]), + in @ ( 8 : [w]), in @ ( 9 : [w]), in @ (10 : [w]), in @ (11 : [w]), + in @ (12 : [w]), in @ (13 : [w]), in @ (14 : [w]), in @ (15 : [w])] + +processBlock_loop_spec : [w] -> + [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> + [16][w] -> [w] -> [16][w] -> + ([w], [w], [w], [w], [w], [w], [w], [w], [16][w]) +processBlock_loop_spec i a b c d e f g h X T1 in = + if i < 80 then processBlock_loop_spec (i+16) aF bF cF dF eF fF gF hF XF T1F in + else (a,b,c,d,e,f,g,h,in) + where (a0,b0,c0,d0,e0,f0,g0,h0,X0,_,_,T10) = round_16_80_spec i 0 a b c d e f g h X T1 + (h1,a1,b1,c1,d1,e1,f1,g1,X1,_,_,T11) = round_16_80_spec i 1 h0 a0 b0 c0 d0 e0 f0 g0 X0 T10 + (g2,h2,a2,b2,c2,d2,e2,f2,X2,_,_,T12) = round_16_80_spec i 2 g1 h1 a1 b1 c1 d1 e1 f1 X1 T11 + (f3,g3,h3,a3,b3,c3,d3,e3,X3,_,_,T13) = round_16_80_spec i 3 f2 g2 h2 a2 b2 c2 d2 e2 X2 T12 + (e4,f4,g4,h4,a4,b4,c4,d4,X4,_,_,T14) = round_16_80_spec i 4 e3 f3 g3 h3 a3 b3 c3 d3 X3 T13 + (d5,e5,f5,g5,h5,a5,b5,c5,X5,_,_,T15) = round_16_80_spec i 5 d4 e4 f4 g4 h4 a4 b4 c4 X4 T14 + (c6,d6,e6,f6,g6,h6,a6,b6,X6,_,_,T16) = round_16_80_spec i 6 c5 d5 e5 f5 g5 h5 a5 b5 X5 T15 + (b7,c7,d7,e7,f7,g7,h7,a7,X7,_,_,T17) = round_16_80_spec i 7 b6 c6 d6 e6 f6 g6 h6 a6 X6 T16 + (a8,b8,c8,d8,e8,f8,g8,h8,X8,_,_,T18) = round_16_80_spec i 8 a7 b7 c7 d7 e7 f7 g7 h7 X7 T17 + (h9,a9,b9,c9,d9,e9,f9,g9,X9,_,_,T19) = round_16_80_spec i 9 h8 a8 b8 c8 d8 e8 f8 g8 X8 T18 + (gA,hA,aA,bA,cA,dA,eA,fA,XA,_,_,T1A) = round_16_80_spec i 10 g9 h9 a9 b9 c9 d9 e9 f9 X9 T19 + (fB,gB,hB,aB,bB,cB,dB,eB,XB,_,_,T1B) = round_16_80_spec i 11 fA gA hA aA bA cA dA eA XA T1A + (eC,fC,gC,hC,aC,bC,cC,dC,XC,_,_,T1C) = round_16_80_spec i 12 eB fB gB hB aB bB cB dB XB T1B + (dD,eD,fD,gD,hD,aD,bD,cD,XD,_,_,T1D) = round_16_80_spec i 13 dC eC fC gC hC aC bC cC XC T1C + (cE,dE,eE,fE,gE,hE,aE,bE,XE,_,_,T1E) = round_16_80_spec i 14 cD dD eD fD gD hD aD bD XD T1D + (bF,cF,dF,eF,fF,gF,hF,aF,XF,_,_,T1F) = round_16_80_spec i 15 bE cE dE eE fE gE hE aE XE T1E + +processBlocks_spec : {n} Literal n [64] => [8][w] -> [16*n][w] -> + ([8][w], [16*n][w]) +processBlocks_spec state in = processBlocks_loop_spec 0 `n state in + +processBlocks_loop_spec : {n} Literal n [64] => [w] -> [w] -> [8][w] -> + [16*n][w] -> ([8][w], [16*n][w]) +processBlocks_loop_spec i j state in = invariantHint (i + j == `n) ( + if j != 0 then processBlocks_loop_spec (i+1) (j-1) state' in + else (state, in)) + where (a,b,c,d,e,f,g,h) = (state @ ( 0 : [w]), state @ ( 1 : [w]), + state @ ( 2 : [w]), state @ ( 3 : [w]), + state @ ( 4 : [w]), state @ ( 5 : [w]), + state @ ( 6 : [w]), state @ ( 7 : [w])) + in_i = split in @ i + (a',b',c',d',e',f',g',h',_) = processBlock_spec a b c d e f g h in_i + state' = [a', b', c', d', e', f', g', h'] diff --git a/heapster-saw/examples/sha512.saw b/heapster-saw/examples/sha512.saw index f5b470e4a5..6624a9f6fc 100644 --- a/heapster-saw/examples/sha512.saw +++ b/heapster-saw/examples/sha512.saw @@ -7,11 +7,17 @@ heapster_define_perm env "int64" " " "llvmptr 64" "exists x:bv 64.eq(llvmword(x) heapster_define_perm env "int32" " " "llvmptr 32" "exists x:bv 32.eq(llvmword(x))"; heapster_define_perm env "int8" " " "llvmptr 8" "exists x:bv 8.eq(llvmword(x))"; +// FIXME: We always have rw=W, but without the rw arguments below Heapster +// doesn't realize the perm is not copyable (it needs to unfold named perms). +heapster_define_perm env "int64_ptr" "rw:rwmodality" "llvmptr 64" "ptr((rw,0) |-> int64<>)"; +heapster_define_perm env "true_ptr" "rw:rwmodality" "llvmptr 64" "ptr((rw,0) |-> true)"; + heapster_assume_fun env "CRYPTO_load_u64_be" "(). arg0:ptr((R,0) |-> int64<>) -o \ \ arg0:ptr((R,0) |-> int64<>), ret:int64<>" "\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool * Vec 64 Bool) (x, x)"; +/* heapster_typecheck_fun env "return_state" "(). arg0:array(W,0,<8,*8,fieldsh(int64<>)) -o \ \ arg0:array(W,0,<8,*8,fieldsh(int64<>))"; @@ -24,5 +30,50 @@ heapster_typecheck_fun env "sha512_block_data_order" \ arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ \ arg2:true, ret:true"; +*/ + +heapster_typecheck_fun env "round_00_15" + "(). arg0:int64<>, \ + \ arg1:int64_ptr, arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, \ + \ arg5:int64_ptr, arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, \ + \ arg9:int64_ptr -o \ + \ arg1:int64_ptr, arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, \ + \ arg5:int64_ptr, arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, \ + \ arg9:int64_ptr, ret:true"; + +heapster_typecheck_fun env "round_16_80" + "(). arg0:int64<>, arg1:int64<>, \ + \ arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, arg9:int64_ptr, \ + \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ + \ arg11:true_ptr, arg12:true_ptr, arg13:int64_ptr -o \ + \ arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, arg9:int64_ptr, \ + \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ + \ arg11:int64_ptr, arg12:int64_ptr, arg13:int64_ptr, ret:true"; + +heapster_typecheck_fun env "return_X" + "(). arg0:array(W,0,<16,*8,fieldsh(int64<>)) -o \ + \ arg0:array(W,0,<16,*8,fieldsh(int64<>))"; + +heapster_set_translation_checks env false; +heapster_typecheck_fun env "processBlock" + "(). arg0:int64_ptr, arg1:int64_ptr, arg2:int64_ptr, \ + \ arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, \ + \ arg8:array(R,0,<16,*8,fieldsh(int64<>)) -o \ + \ arg0:int64_ptr, arg1:int64_ptr, arg2:int64_ptr, \ + \ arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, \ + \ arg8:array(R,0,<16,*8,fieldsh(int64<>)), ret:true"; + +heapster_set_translation_checks env false; +heapster_typecheck_fun env "processBlocks" + "(num:bv 64). arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ + \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ + \ arg2:eq(llvmword(num)) -o \ + \ arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ + \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ + \ arg2:true, ret:true"; heapster_export_coq env "sha512_gen.v"; diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index 928e7ab40f..bd7ea87192 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -1,84 +1,4 @@ -enable_experimental; -env <- heapster_init_env "SHA512" "sha512.bc"; - -// Heapster - -heapster_define_perm env "int64" " " "llvmptr 64" "exists x:bv 64.eq(llvmword(x))"; -heapster_define_perm env "int32" " " "llvmptr 32" "exists x:bv 32.eq(llvmword(x))"; -heapster_define_perm env "int8" " " "llvmptr 8" "exists x:bv 8.eq(llvmword(x))"; - -// FIXME: We always have rw=W, but without the rw arguments below Heapster -// doesn't realize the perm is not copyable (it needs to unfold named perms). -heapster_define_perm env "int64_ptr" "rw:rwmodality" "llvmptr 64" "ptr((rw,0) |-> int64<>)"; -heapster_define_perm env "true_ptr" "rw:rwmodality" "llvmptr 64" "ptr((rw,0) |-> true)"; - -heapster_assume_fun env "CRYPTO_load_u64_be" - "(). arg0:ptr((R,0) |-> int64<>) -o \ - \ arg0:ptr((R,0) |-> int64<>), ret:int64<>" - "\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool * Vec 64 Bool) (x, x)"; - -heapster_typecheck_fun env "round_00_15" - "(). arg0:int64<>, \ - \ arg1:int64_ptr, arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, \ - \ arg5:int64_ptr, arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, \ - \ arg9:int64_ptr -o \ - \ arg1:int64_ptr, arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, \ - \ arg5:int64_ptr, arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, \ - \ arg9:int64_ptr, ret:true"; - -heapster_typecheck_fun env "round_16_80" - "(). arg0:int64<>, arg1:int64<>, \ - \ arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ - \ arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, arg9:int64_ptr, \ - \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ - \ arg11:true_ptr, arg12:true_ptr, arg13:int64_ptr -o \ - \ arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ - \ arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, arg9:int64_ptr, \ - \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ - \ arg11:int64_ptr, arg12:int64_ptr, arg13:int64_ptr, ret:true"; - -heapster_typecheck_fun env "return_X" - "(). arg0:array(W,0,<16,*8,fieldsh(int64<>)) -o \ - \ arg0:array(W,0,<16,*8,fieldsh(int64<>))"; - -heapster_set_translation_checks env false; -heapster_typecheck_fun env "processBlock" - "(). arg0:int64_ptr, arg1:int64_ptr, arg2:int64_ptr, \ - \ arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ - \ arg6:int64_ptr, arg7:int64_ptr, \ - \ arg8:array(R,0,<16,*8,fieldsh(int64<>)) -o \ - \ arg0:int64_ptr, arg1:int64_ptr, arg2:int64_ptr, \ - \ arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ - \ arg6:int64_ptr, arg7:int64_ptr, \ - \ arg8:array(R,0,<16,*8,fieldsh(int64<>)), ret:true"; - -// FIXME: This translation contains errors -heapster_set_translation_checks env false; -heapster_typecheck_fun env "processBlocks" - "(num:bv 64). arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ - \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ - \ arg2:eq(llvmword(num)) -o \ - \ arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ - \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ - \ arg2:true, ret:true"; - -heapster_export_coq env "sha512_mr_solver_gen.v"; - -// Mr. Solver - -let eq_bool b1 b2 = - if b1 then - if b2 then true else false - else - if b2 then false else true; - -let fail = do { print "Test failed"; exit 1; }; -let run_test name test expected = - do { if expected then print (str_concat "Test: " name) else - print (str_concat (str_concat "Test: " name) " (expecting failure)"); - actual <- test; - if eq_bool actual expected then print "Success\n" else - do { print "Test failed\n"; exit 1; }; }; +include "sha512.saw"; round_00_15 <- parse_core_mod "SHA512" "round_00_15"; round_16_80 <- parse_core_mod "SHA512" "round_16_80"; @@ -100,10 +20,14 @@ monadify_term {{ sigma_0 }}; monadify_term {{ sigma_1 }}; monadify_term {{ Ch }}; monadify_term {{ Maj }}; - -// FIXME: Why does monadification fail without this line while running -// "round_16_80 |= round_16_80_spec"? monadify_term {{ round_00_15_spec }}; +monadify_term {{ round_16_80_spec }}; +monadify_term {{ processBlock_loop_spec }}; +monadify_term {{ processBlock_spec }}; +monadify_term {{ processBlocks_loop_spec }}; +monadify_term {{ processBlocks_spec }}; mr_solver_prove round_00_15 {{ round_00_15_spec }}; mr_solver_prove round_16_80 {{ round_16_80_spec }}; +mr_solver_prove processBlock {{ processBlock_spec }}; +// mr_solver_prove processBlocks {{ processBlocks_spec }}; diff --git a/heapster-saw/examples/sha512_proofs.v b/heapster-saw/examples/sha512_proofs.v index 143fa9204c..ad4fd7101b 100644 --- a/heapster-saw/examples/sha512_proofs.v +++ b/heapster-saw/examples/sha512_proofs.v @@ -13,15 +13,3 @@ Import SAWCorePrelude. Require Import Examples.sha512_gen. Import SHA512. - -Definition sha512_block_data_order_precond num := isBvslt 64 (intToBv 64 0) num. - -Lemma no_errors_sha512_block_data_order : - refinesFun sha512_block_data_order - (fun num _ _ => assumingM (sha512_block_data_order_precond num) noErrorsSpec). -Proof. - unfold sha512_block_data_order, sha512_block_data_order__tuple_fun. - (* time "sha512_block_data_order (1)" prove_refinement_match_letRecM_l. *) - (* 1-2: intros; apply noErrorsSpec. *) - (* time "sha512_block_data_order (2)" prove_refinement. *) -Admitted. diff --git a/saw-core/src/Verifier/SAW/SharedTerm.hs b/saw-core/src/Verifier/SAW/SharedTerm.hs index 4a8cde027a..d0cfae71e7 100644 --- a/saw-core/src/Verifier/SAW/SharedTerm.hs +++ b/saw-core/src/Verifier/SAW/SharedTerm.hs @@ -212,6 +212,7 @@ module Verifier.SAW.SharedTerm , scBvToNat , scBvAt , scBvConst + , scBvLit , scFinVal , scBvForall , scUpdBvFun @@ -2019,14 +2020,23 @@ scBvToNat sc n x = do n' <- scNat sc n scGlobalApply sc "Prelude.bvToNat" [n',x] --- | Create a term computing a bitvector of the given length representing the --- given 'Integer' value (if possible). +-- | Create a @bvNat@ term computing a bitvector of the given length +-- representing the given 'Integer' value (if possible). scBvConst :: SharedContext -> Natural -> Integer -> IO Term scBvConst sc w v = assert (w <= fromIntegral (maxBound :: Int)) $ do x <- scNat sc w y <- scNat sc $ fromInteger $ v .&. (1 `shiftL` fromIntegral w - 1) scGlobalApply sc "Prelude.bvNat" [x, y] +-- | Create a vector literal term computing a bitvector of the given length +-- representing the given 'Integer' value (if possible). +scBvLit :: SharedContext -> Natural -> Integer -> IO Term +scBvLit sc w v = assert (w <= fromIntegral (maxBound :: Int)) $ do + do bool_tp <- scBoolType sc + bits <- mapM (scBool sc . testBit v) + [(fromIntegral w - 1), (fromIntegral w - 2) .. 0] + scVector sc bool_tp bits + -- TODO: This doesn't appear to be used anywhere, and "FinVal" doesn't appear -- in Prelude.sawcore... can this be deleted? -- | FinVal :: (x r :: Nat) -> Fin (Succ (addNat r x)); diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index b2dafb2d43..36044e1bb7 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -1681,13 +1681,14 @@ ensureMonadicTerm sc t False -> monadifyTypedTerm sc t ensureMonadicTerm sc t = monadifyTypedTerm sc t --- | A wrapper for 'Prover.askMRSolver' from @MRSolver.hs@ which if the first --- argument is @Just str@, prints out @str@ followed by an abridged version --- of the refinement being asked -askMRSolver :: Maybe SawDoc -> SharedContext -> TypedTerm -> TypedTerm -> - TopLevel (NominalDiffTime, - Either Prover.MRFailure Prover.MRSolverResult) -askMRSolver printStr sc t1 t2 = +-- | A wrapper for either 'Prover.askMRSolver' or 'Prover.assumeMRSolver' from +-- @MRSolver.hs@: if the first argument is @Just str@, prints out @str@ +-- followed by an abridged version of the refinement being asked, then calls +-- the given function, returning how long it took to execute +mrSolver :: (SharedContext -> Prover.MREnv -> Maybe Integer -> Term -> Term -> IO a) -> + Maybe SawDoc -> SharedContext -> TypedTerm -> TypedTerm -> + TopLevel (NominalDiffTime, a) +mrSolver f printStr sc t1 t2 = do env <- rwMRSolverEnv <$> get m1 <- collapseEta <$> ttTerm <$> ensureMonadicTerm sc t1 m2 <- collapseEta <$> ttTerm <$> ensureMonadicTerm sc t2 @@ -1697,7 +1698,7 @@ askMRSolver printStr sc t1 t2 = "[MRSolver] " <> str <> ": " <> ppTmHead m1 <> " |= " <> ppTmHead m2 time1 <- liftIO getCurrentTime - res <- io $ Prover.askMRSolver sc env Nothing m1 m2 + res <- io $ f sc env Nothing m1 m2 time2 <- liftIO getCurrentTime return (diffUTCTime time2 time1, res) where -- Turn a term of the form @\x1 ... xn -> f x1 ... xn@ into @f@ @@ -1726,7 +1727,7 @@ mrSolverProve :: Bool -> SharedContext -> TypedTerm -> TypedTerm -> TopLevel () mrSolverProve addToEnv sc t1 t2 = do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get let printStr = if addToEnv then "Proving" else "Testing" - (diff, res) <- askMRSolver (Just printStr) sc t1 t2 + (diff, res) <- mrSolver Prover.askMRSolver (Just printStr) sc t1 t2 case res of Left err | dlvl == 0 -> io (putStrLn $ Prover.showMRFailure err) >> @@ -1755,7 +1756,7 @@ mrSolverProve addToEnv sc t1 t2 = mrSolverQuery :: SharedContext -> TypedTerm -> TypedTerm -> TopLevel Bool mrSolverQuery sc t1 t2 = do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get - (diff, res) <- askMRSolver (Just "Querying") sc t1 t2 + (diff, res) <- mrSolver Prover.askMRSolver (Just "Querying") sc t1 t2 case res of Left _ | dlvl == 0 -> printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> @@ -1770,6 +1771,33 @@ mrSolverQuery sc t1 t2 = printOutLnTop Info (printf "[MRSolver] Success in %s" (show diff)) >> return True +-- | Generate the 'Prover.FunAssump' which corresponds to the given refinement +-- and add it to the 'Prover.MREnv' +mrSolverAssume :: SharedContext -> TypedTerm -> TypedTerm -> TopLevel () +mrSolverAssume sc t1 t2 = + do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get + (_, res) <- mrSolver Prover.assumeMRSolver (Just "Assuming") sc t1 t2 + case res of + Left err | dlvl == 0 -> + io (putStrLn $ Prover.showMRFailure err) >> + printOutLnTop Info (printf "[MRSolver] Failure") >> + io (Exit.exitWith $ Exit.ExitFailure 1) + Left err -> + -- we ignore the MRFailure context here since it will have already + -- been printed by the debug trace + io (putStrLn $ Prover.showMRFailureNoCtx err) >> + printOutLnTop Info (printf "[MRSolver] Failure") >> + io (Exit.exitWith $ Exit.ExitFailure 1) + Right (Just (fnm, fassump)) -> + printOutLnTop Info ( + printf "[MRSolver] Success, added as an opaque assumption") >> + modify (\rw -> rw { rwMRSolverEnv = + Prover.mrEnvAddFunAssump fnm fassump (rwMRSolverEnv rw) }) + _ -> + printOutLnTop Info $ printf $ + "[MRSolver] Failure, given refinement cannot be interpreted as" ++ + " an assumption" + -- | Set the debug level of the 'Prover.MREnv' mrSolverSetDebug :: Int -> TopLevel () mrSolverSetDebug dlvl = diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index f5e6c289d4..4b76061075 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -3350,6 +3350,12 @@ primitives = Map.fromList , " be considered in future calls to Mr. Solver, and unlike both," , " this command will never fail." ] + , prim "mr_solver_assume" "Term -> Term -> TopLevel Bool" + (scVal mrSolverAssume) + Experimental + [ "Add the refinement of the two given expressions as an assumption" + , " which will be used in future calls to Mr. Solver." ] + , prim "mr_solver_set_debug_level" "Int -> TopLevel ()" (pureVal mrSolverSetDebug) Experimental diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index b422cfd996..32760eb2db 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -9,7 +9,7 @@ Portability : non-portable (language extensions) -} module SAWScript.Prover.MRSolver - (askMRSolver, MRSolverResult, + (askMRSolver, assumeMRSolver, MRSolverResult, MRFailure(..), showMRFailure, showMRFailureNoCtx, FunAssump(..), FunAssumpRHS(..), MREnv(..), emptyMREnv, mrEnvAddFunAssump, mrEnvSetDebugLevel, diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index dbb20fd7e7..c7755839ea 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -26,7 +26,6 @@ module SAWScript.Prover.MRSolver.Monad where import Data.List (find, findIndex, foldl') import qualified Data.Text as T import Numeric.Natural (Natural) -import Data.Bits (testBit) import System.IO (hPutStrLn, stderr) import Control.Monad.Reader import Control.Monad.State @@ -60,13 +59,14 @@ import SAWScript.Prover.MRSolver.Term -- | The context in which a failure occurred data FailCtx = FailCtxRefines NormComp NormComp + | FailCtxCoIndHyp CoIndHyp | FailCtxMNF Term deriving Show -- | That's MR. Failure to you data MRFailure = TermsNotRel Bool Term Term - | TypesNotEq Type Type + | TypesNotRel Bool Type Type | CompsDoNotRefine NormComp NormComp | ReturnNotError Term | FunsNotEq FunName FunName @@ -90,10 +90,14 @@ data MRFailure pattern TermsNotEq :: Term -> Term -> MRFailure pattern TermsNotEq t1 t2 = TermsNotRel False t1 t2 +pattern TypesNotEq :: Type -> Type -> MRFailure +pattern TypesNotEq t1 t2 = TypesNotRel False t1 t2 + -- | Remove the context from a 'MRFailure', i.e. remove all applications of the -- 'MRFailureLocalVar' and 'MRFailureCtx' constructors mrFailureWithoutCtx :: MRFailure -> MRFailure -mrFailureWithoutCtx (MRFailureLocalVar _ err) = mrFailureWithoutCtx err +mrFailureWithoutCtx (MRFailureLocalVar x err) = + MRFailureLocalVar x (mrFailureWithoutCtx err) mrFailureWithoutCtx (MRFailureCtx _ err) = mrFailureWithoutCtx err mrFailureWithoutCtx (MRFailureDisj err1 err2) = MRFailureDisj (mrFailureWithoutCtx err1) (mrFailureWithoutCtx err2) @@ -119,6 +123,9 @@ instance PrettyInCtx FailCtx where prettyInCtx (FailCtxRefines m1 m2) = group <$> nest 2 <$> ppWithPrefixSep "When proving refinement:" m1 "|=" m2 + prettyInCtx (FailCtxCoIndHyp hyp) = + group <$> nest 2 <$> + ppWithPrefix "When doing co-induction with hypothesis:" hyp prettyInCtx (FailCtxMNF t) = group <$> nest 2 <$> vsepM [return "When normalizing computation:", prettyInCtx t] @@ -128,8 +135,10 @@ instance PrettyInCtx MRFailure where ppWithPrefixSep "Could not prove terms equal:" t1 "and" t2 prettyInCtx (TermsNotRel True t1 t2) = ppWithPrefixSep "Could not prove terms heterogeneously related:" t1 "and" t2 - prettyInCtx (TypesNotEq tp1 tp2) = + prettyInCtx (TypesNotRel False tp1 tp2) = ppWithPrefixSep "Types not equal:" tp1 "and" tp2 + prettyInCtx (TypesNotRel True tp1 tp2) = + ppWithPrefixSep "Types not heterogeneously related:" tp1 "and" tp2 prettyInCtx (CompsDoNotRefine m1 m2) = ppWithPrefixSep "Could not prove refinement: " m1 "|=" m2 prettyInCtx (ReturnNotError t) = @@ -237,6 +246,28 @@ coIndHypArg :: CoIndHyp -> Either Int Int -> Term coIndHypArg hyp (Left i) = (coIndHypLHS hyp) !! i coIndHypArg hyp (Right i) = (coIndHypRHS hyp) !! i +-- | Set the @i@th argument on either the left- or right-hand side of a +-- coinductive hypothesis to the given value +coIndHypSetArg :: CoIndHyp -> Either Int Int -> Term -> CoIndHyp +coIndHypSetArg hyp@(CoIndHyp {..}) (Left i) x = + hyp { coIndHypLHS = take i coIndHypLHS ++ x : drop (i+1) coIndHypLHS } +coIndHypSetArg hyp@(CoIndHyp {..}) (Right i) x = + hyp { coIndHypRHS = take i coIndHypRHS ++ x : drop (i+1) coIndHypRHS } + +-- | Set all of the arguments in the given list to the given value in a +-- coinductive hypothesis, using 'coIndHypSetArg' +coIndHypSetArgs :: CoIndHyp -> [Either Int Int] -> Term -> CoIndHyp +coIndHypSetArgs hyp specs x = + foldl' (\hyp' spec -> coIndHypSetArg hyp' spec x) hyp specs + +-- | Add a variable to the context of a coinductive hypothesis, returning the +-- updated coinductive hypothesis and a 'Term' which is the new variable +coIndHypWithVar :: CoIndHyp -> LocalName -> Type -> MRM (CoIndHyp, Term) +coIndHypWithVar (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) nm (Type tp) = + do var <- liftSC1 scLocalVar 0 + (args1', args2') <- liftTermLike 0 1 (args1, args2) + return (CoIndHyp (ctx ++ [(nm,tp)]) f1 f2 args1' args2' invar1 invar2, var) + -- | A map from pairs of function names to co-inductive hypotheses over those -- names type CoIndHyps = Map (FunName, FunName) CoIndHyp @@ -251,9 +282,9 @@ instance PrettyInCtx CoIndHyp where (case invar2 of Just f -> prettyTermApp f args2 Nothing -> return "True"), return "=>", - prettyInCtx (FunBind f1 args1 CompFunReturn), + prettyTermApp (funNameTerm f1) args1, return "|=", - prettyInCtx (FunBind f2 args2 CompFunReturn)] + prettyTermApp (funNameTerm f2) args2] -- | An assumption that something is equal to one of the constructors of a -- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term' @@ -267,10 +298,6 @@ instance PrettyInCtx DataTypeAssump where prettyInCtx (IsNum x) = prettyInCtx x >>= ppWithPrefix "TCNum" prettyInCtx IsInf = return "TCInf" --- | Create a term representing the type @IsFinite n@ -mrIsFinite :: Term -> MRM Term -mrIsFinite n = liftSC2 scGlobalApply "CryptolM.isFinite" [n] - -- | A map from 'Term's to 'DataTypeAssump's over that term type DataTypeAssumps = HashMap Term DataTypeAssump @@ -445,6 +472,84 @@ liftSC5 :: (SharedContext -> a -> b -> c -> d -> e -> IO f) -> liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) +---------------------------------------------------------------------- +-- * Relating Types Heterogeneously +---------------------------------------------------------------------- + +-- | A datatype encapsulating all the way in which we consider two types to +-- be heterogeneously related: either one is a @Num@ and the other is a @Nat@, +-- one is a @BVVec@ and the other is a non-@BVVec@ vector (of the same length, +-- which must be checked where 'matchHet' is used), or both sides are pairs +-- (whose components are respectively heterogeneously related, which must be +-- checked where 'matchHet' is used). See 'typesHetRelated' for an example. +data HetRelated = HetBVNum Natural + | HetNumBV Natural + | HetBVVecVec (Term, Term, Term) (Term, Term) + | HetVecBVVec (Term, Term) (Term, Term, Term) + | HetPair (Term, Term) (Term, Term) + +-- | Check to see if the given types match one of the cases of 'HetRelated' +matchHet :: Term -> Term -> Maybe HetRelated +matchHet (asBitvectorType -> Just n) + (asDataType -> Just (primName -> "Cryptol.Num", _)) = + Just $ HetBVNum n +matchHet (asDataType -> Just (primName -> "Cryptol.Num", _)) + (asBitvectorType -> Just n) = + Just $ HetNumBV n +matchHet (asBVVecType -> Just (n, len, a)) + (asNonBVVecVectorType -> Just (m, a')) = + Just $ HetBVVecVec (n, len, a) (m, a') +matchHet (asNonBVVecVectorType -> Just (m, a')) + (asBVVecType -> Just (n, len, a)) = + Just $ HetVecBVVec (m, a') (n, len, a) +matchHet (asPairType -> Just (tpL1, tpR1)) + (asPairType -> Just (tpL2, tpR2)) = + Just $ HetPair (tpL1, tpR1) (tpL2, tpR2) +matchHet _ _ = Nothing + +-- | Return true iff the given types are heterogeneously related +typesHetRelated :: Term -> Term -> MRM Bool +typesHetRelated tp1 tp2 = case matchHet tp1 tp2 of + Just (HetBVNum _) -> return True + Just (HetNumBV _) -> return True + Just (HetBVVecVec (n, len, a) (m, a')) -> mrBvToNat n len >>= \m' -> + (&&) <$> mrConvertible m m' <*> typesHetRelated a a' + Just (HetVecBVVec (m, a') (n, len, a)) -> mrBvToNat n len >>= \m' -> + (&&) <$> mrConvertible m m' <*> typesHetRelated a a' + Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) -> + (&&) <$> typesHetRelated tpL1 tpL2 <*> typesHetRelated tpR1 tpR2 + Nothing -> mrConvertible tp1 tp2 + + +---------------------------------------------------------------------- +-- * Functions for Building Terms +---------------------------------------------------------------------- + +-- | Create a term representing the type @IsFinite n@ +mrIsFinite :: Term -> MRM Term +mrIsFinite n = liftSC2 scGlobalApply "CryptolM.isFinite" [n] + +-- | Create a term representing an application of @Prelude.error@ +mrErrorTerm :: Term -> T.Text -> MRM Term +mrErrorTerm a str = + do err_str <- liftSC1 scString str + liftSC2 scGlobalApply "Prelude.error" [a, err_str] + +-- | Create a term representing an application of @Prelude.genBVVecFromVec@, +-- where the default value argument is @Prelude.error@ of the given 'T.Text' +mrGenBVVecFromVec :: Term -> Term -> Term -> T.Text -> Term -> Term -> MRM Term +mrGenBVVecFromVec m a v def_err_str n len = + do err_tm <- mrErrorTerm a def_err_str + liftSC2 scGlobalApply "Prelude.genBVVecFromVec" [m, a, v, err_tm, n, len] + +-- | Create a term representing an application of @Prelude.genFromBVVec@, +-- where the default value argument is @Prelude.error@ of the given 'T.Text' +mrGenFromBVVec :: Term -> Term -> Term -> Term -> T.Text -> Term -> MRM Term +mrGenFromBVVec n len a v def_err_str m = + do err_tm <- mrErrorTerm a def_err_str + liftSC2 scGlobalApply "Prelude.genFromBVVec" [n, len, a, v, err_tm, m] + + ---------------------------------------------------------------------- -- * Monadic Operations on Terms ---------------------------------------------------------------------- @@ -480,6 +585,10 @@ funNameType (GlobalName gd projs) = mrApplyAll :: Term -> [Term] -> MRM Term mrApplyAll f args = liftSC2 scApplyAllBeta f args +-- | Apply a 'Term' to a single argument and beta-reduce in Mr. Monad +mrApply :: Term -> Term -> MRM Term +mrApply f arg = mrApplyAll f [arg] + -- | Like 'scBvNat', but if given a bitvector literal it is converted to a -- natural number literal mrBvToNat :: Term -> Term -> MRM Term @@ -488,14 +597,6 @@ mrBvToNat _ (asArrayValue -> Just (asBoolType -> Just _, liftSC1 scNat $ foldl' (\n bit -> if bit then 2*n+1 else 2*n) 0 bits mrBvToNat n len = liftSC2 scBvNat n len --- | Like 'scBvConst', but returns a bitvector literal -mrBvConst :: Natural -> Integer -> MRM Term -mrBvConst n x = - do bool_tp <- liftSC0 scBoolType - bits <- mapM (liftSC1 scBool . testBit x) - [(fromIntegral n - 1), (fromIntegral n - 2) .. 0] - liftSC2 scVector bool_tp bits - -- | Get the current context of uvars as a list of variable names and their -- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in -- the order as seen "from the outside" @@ -550,6 +651,18 @@ uniquifyNames (nm:nms) nms_other = let nm' = uniquifyName nm nms_other in nm' : uniquifyNames nms (nm' : nms_other) +-- | Build a lambda term with the lifting (in the sense of 'incVars') of an +-- MR Solver term +mrLambdaLift :: TermLike tm => [(LocalName,Term)] -> tm -> + ([Term] -> tm -> MRM Term) -> MRM Term +mrLambdaLift [] t f = f [] t +mrLambdaLift ctx t f = + do nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVars + let ctx' = zipWith (\nm (_,tp) -> (nm,tp)) nms ctx + vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. length ctx - 1] + t' <- liftTermLike 0 (length ctx) t + f vars t' >>= liftSC2 scLambdaList ctx' + -- | Run a MR Solver computation in a context extended with a universal -- variable, which is passed as a 'Term' to the sub-computation. Note that any -- assumptions made in the sub-computation will be lost when it completes. @@ -712,6 +825,11 @@ mrCallsFun f = memoFixTermFun $ \recurse t -> case t of (unwrapTermF -> tf) -> foldM (\b t' -> if b then return b else recurse t') False tf + +---------------------------------------------------------------------- +-- * Monadic Operations on Mr. Solver State +---------------------------------------------------------------------- + -- | Make a fresh 'MRVar' of a given type, which must be closed, i.e., have no -- free uvars mrFreshVar :: LocalName -> Term -> MRM MRVar @@ -903,8 +1021,8 @@ mrGetFunAssump nm = Map.lookup nm <$> mrFunAssumps -- are 'Term's that can have the current uvars free withFunAssump :: FunName -> [Term] -> NormComp -> MRM a -> MRM a withFunAssump fname args rhs m = - do mrDebugPPPrefixSep 1 "withFunAssump" (FunBind - fname args CompFunReturn) "|=" rhs + do k <- CompFunReturn <$> Type <$> mrFunOutType fname args + mrDebugPPPrefixSep 1 "withFunAssump" (FunBind fname args k) "|=" rhs ctx <- mrUVarCtx assumps <- mrFunAssumps let assump = FunAssump ctx args (RewriteFunAssump rhs) @@ -980,6 +1098,17 @@ withDataTypeAssump x assump m = mrGetDataTypeAssump :: Term -> MRM (Maybe DataTypeAssump) mrGetDataTypeAssump x = HashMap.lookup x <$> mrDataTypeAssumps +-- | Convert a 'FunAssumpRHS' to a 'NormComp' +mrFunAssumpRHSAsNormComp :: FunAssumpRHS -> MRM NormComp +mrFunAssumpRHSAsNormComp (OpaqueFunAssump f args) = + FunBind f args <$> CompFunReturn <$> Type <$> mrFunOutType f args +mrFunAssumpRHSAsNormComp (RewriteFunAssump rhs) = return rhs + + +---------------------------------------------------------------------- +-- * Functions for Debug Output +---------------------------------------------------------------------- + -- | Print a 'String' if the debug level is at least the supplied 'Int' debugPrint :: Int -> String -> MRM () debugPrint i str = diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 70bf769020..29142dc90c 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -118,7 +118,7 @@ primBVTermFun sc = VExtra (VExtraTerm _ w_tm) -> return w_tm VWord (Left (_,w_tm)) -> return w_tm VWord (Right bv) -> - lift $ scBvConst sc (fromIntegral (Prim.width bv)) (Prim.unsigned bv) + lift $ scBvLit sc (fromIntegral (Prim.width bv)) (Prim.unsigned bv) VVector vs -> lift $ do tms <- traverse (boolValToTerm sc <=< force) (V.toList vs) @@ -350,20 +350,30 @@ mrEq t1 t2 = mrTypeOf t1 >>= \tp -> mrEq' tp t1 t2 -- are equal, where the first 'Term' gives their type (which we assume is the -- same for both). This is like 'scEq' except that it works on open terms. mrEq' :: Term -> Term -> Term -> MRM Term +-- FIXME: For this Nat case, the definition of 'equalNat' in @Prims.hs@ means +-- that if both sides do not have immediately clear bit-widths (e.g. either +-- side is is an application of @mulNat@) this will 'error'... mrEq' (asNatType -> Just _) t1 t2 = liftSC2 scEqualNat t1 t2 mrEq' (asBoolType -> Just _) t1 t2 = liftSC2 scBoolEq t1 t2 mrEq' (asIntegerType -> Just _) t1 t2 = liftSC2 scIntEq t1 t2 mrEq' (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = liftSC3 scBvEq n t1 t2 +mrEq' (asDataType -> Just (primName -> "Cryptol.Num", _)) t1 t2 = + liftSC1 scWhnf t1 >>= \t1' -> liftSC1 scWhnf t2 >>= \t2' -> case (t1', t2') of + (asCtor -> Just (primName -> "Cryptol.TCNum", [t1'']), + asCtor -> Just (primName -> "Cryptol.TCNum", [t2''])) -> + liftSC0 scNatType >>= \nat_tp -> mrEq' nat_tp t1'' t2'' + _ -> error "mrEq': Num terms do not normalize to TCNum constructors" mrEq' _ _ _ = error "mrEq': unsupported type" -- | A 'Term' in an extended context of universal variables, which are listed -- "outside in", meaning the highest deBruijn index comes first data TermInCtx = TermInCtx [(LocalName,Term)] Term --- | Conjoin two 'TermInCtx's, assuming they both have Boolean type -andTermInCtx :: TermInCtx -> TermInCtx -> MRM TermInCtx -andTermInCtx (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = +-- | Lift a binary operation on 'Term's to one on 'TermInCtx's +liftTermInCtx2 :: (SharedContext -> Term -> Term -> IO Term) -> + TermInCtx -> TermInCtx -> MRM TermInCtx +liftTermInCtx2 op (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = do -- Insert the variables in ctx2 into the context of t1 starting at index 0, -- by lifting its variables starting at 0 by length ctx2 @@ -372,7 +382,7 @@ andTermInCtx (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = -- length ctx2, by lifting its variables starting at length ctx2 by length -- ctx1 t2' <- liftTermLike (length ctx2) (length ctx1) t2 - TermInCtx (ctx1++ctx2) <$> liftSC2 scAnd t1' t2' + TermInCtx (ctx1++ctx2) <$> liftSC2 op t1' t2' -- | Extend the context of a 'TermInCtx' with additional universal variables -- bound "outside" the 'TermInCtx' @@ -415,12 +425,12 @@ mrAssertProveEq = mrAssertProveRel False mrProveRel :: Bool -> Term -> Term -> MRM Bool mrProveRel het t1 t2 = do let nm = if het then "mrProveRel" else "mrProveEq" - mrDebugPPPrefixSep 1 nm t1 (if het then "~=" else "==") t2 + mrDebugPPPrefixSep 2 nm t1 (if het then "~=" else "==") t2 tp1 <- mrTypeOf t1 >>= mrSubstEVars tp2 <- mrTypeOf t2 >>= mrSubstEVars cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2 res <- withTermInCtx cond_in_ctx mrProvable - debugPrint 1 $ nm ++ ": " ++ if res then "Success" else "Failure" + debugPrint 2 $ nm ++ ": " ++ if res then "Success" else "Failure" return res -- | Prove that two terms are related, heterogeneously iff the first argument, @@ -461,7 +471,7 @@ mrProveRelH' var_map _ tp1 tp2 (asEVarApp var_map -> Just (evar, args, Nothing)) t2' <- mrSubstEVars t2 success <- mrTrySetAppliedEVar evar args t2' when success $ - mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t2 + mrDebugPPPrefixSep 1 "setting evar" evar "to" t2 TermInCtx [] <$> liftSC1 scBool success -- If t2 is an instantiated evar, substitute and recurse @@ -477,7 +487,7 @@ mrProveRelH' var_map _ tp1 tp2 t1 (asEVarApp var_map -> Just (evar, args, Nothin t1' <- mrSubstEVars t1 success <- mrTrySetAppliedEVar evar args t1' when success $ - mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t1 + mrDebugPPPrefixSep 1 "setting evar" evar "to" t1 TermInCtx [] <$> liftSC1 scBool success -- For unit types, always return true @@ -498,17 +508,6 @@ mrProveRelH' _ _ (asBoolType -> Just _) (asBoolType -> Just _) t1 t2 = mrProveRelH' _ _ (asIntegerType -> Just _) (asIntegerType -> Just _) t1 t2 = mrProveEqSimple (liftSC2 scIntEq) t1 t2 --- For pair types, prove both the left and right projections are related -mrProveRelH' _ het (asPairType -> Just (tpL1, tpR1)) - (asPairType -> Just (tpL2, tpR2)) t1 t2 = - do t1L <- liftSC1 scPairLeft t1 - t2L <- liftSC1 scPairLeft t2 - t1R <- liftSC1 scPairRight t1 - t2R <- liftSC1 scPairRight t2 - condL <- mrProveRelH het tpL1 tpL2 t1L t2L - condR <- mrProveRelH het tpR1 tpR2 t1R t2R - andTermInCtx condL condR - -- For BVVec types, prove all projections are related by quantifying over an -- index variable and proving the projections at that index are related mrProveRelH' _ het tp1@(asBVVecType -> Just (n1, len1, tpA1)) @@ -520,53 +519,116 @@ mrProveRelH' _ het tp1@(asBVVecType -> Just (n1, len1, tpA1)) liftSC0 scBoolType >>= \bool_tp -> liftSC2 scVecType n1 bool_tp >>= \ix_tp -> withUVarLift "eq_ix" (Type ix_tp) (n1,(len1,(tpA1,(tpA2,(t1,t2))))) $ - \ix' (n1',(len1',(tpA1',(tpA2',(t1',t2'))))) -> - liftSC2 scGlobalApply "Prelude.is_bvult" [n1', ix', len1'] >>= \pf_tp -> - withUVarLift "eq_pf" (Type pf_tp) (n1',(len1',(tpA1',(tpA2',(ix',(t1',t2')))))) $ - \pf'' (n1'',(len1'',(tpA1'',(tpA2'',(ix'',(t1'',t2'')))))) -> - do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA1'', - t1'', ix'', pf''] - t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA2'', - t2'', ix'', pf''] - extTermInCtx [("eq_ix",ix_tp),("eq_pf",pf_tp)] <$> - mrProveRelH het tpA1'' tpA2'' t1_prj t2_prj - --- If our relation is heterogeneous and we have a BVVec on one side and a --- non-BVVec vector on the other, wrap the non-BVVec vector term in --- genBVVecFromVec and recurse -mrProveRelH' _ True tp1@(asBVVecType -> Just (n, len, _)) - tp2@(asNonBVVecVectorType -> Just (m, tpA2)) t1 t2 = - do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m' m + \ix (n1',(len1',(tpA1',(tpA2',(t1',t2'))))) -> + do ix_bound <- liftSC2 scGlobalApply "Prelude.bvult" [n1', ix, len1'] + pf <- liftSC2 scGlobalApply "Prelude.unsafeAssertBVULt" [n1', ix, len1'] + t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1', len1', tpA1', + t1', ix, pf] + t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1', len1', tpA2', + t2', ix, pf] + cond <- mrProveRelH het tpA1' tpA2' t1_prj t2_prj + extTermInCtx [("eq_ix",ix_tp)] <$> + liftTermInCtx2 scImplies (TermInCtx [] ix_bound) cond + +-- For non-BVVec vector types where at least one side is an application of +-- genFromBVVec, wrap both sides in genBVVecFromVec and recurse +mrProveRelH' _ het tp1@(asNonBVVecVectorType -> Just (m1, tpA1)) + tp2@(asNonBVVecVectorType -> Just (m2, tpA2)) + t1@(asGenFromBVVecTerm -> Just (n, len, _, _, _, _)) t2 = + do ms_are_eq <- mrConvertible m1 m2 if ms_are_eq then return () else throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp1' <- liftSC2 scVecType len' tpA1 tp2' <- liftSC2 scVecType len' tpA2 - err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error" - err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA2, err_str_tm] - t2' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" - [m, tpA2, t2, err_tm, n, len] - -- mrDebugPPPrefixSep 2 "mrProveRelH on BVVec/Vec: " t1 "and" t2' - mrProveRelH True tp1 tp2' t1 t2' -mrProveRelH' _ True tp1@(asNonBVVecVectorType -> Just (m, tpA1)) - tp2@(asBVVecType -> Just (n, len, _)) t1 t2 = - do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m' m + t1' <- mrGenBVVecFromVec m1 tpA1 t1 "mrProveRelH (BVVec/BVVec)" n len + t2' <- mrGenBVVecFromVec m2 tpA2 t2 "mrProveRelH (BVVec/BVVec)" n len + mrProveRelH het tp1' tp2' t1' t2' +mrProveRelH' _ het tp1@(asNonBVVecVectorType -> Just (m1, tpA1)) + tp2@(asNonBVVecVectorType -> Just (m2, tpA2)) + t1 t2@(asGenFromBVVecTerm -> Just (n, len, _, _, _, _)) = + do ms_are_eq <- mrConvertible m1 m2 if ms_are_eq then return () else throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] tp1' <- liftSC2 scVecType len' tpA1 - err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error" - err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA1, err_str_tm] - t1' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" - [m, tpA1, t1, err_tm, n, len] - -- mrDebugPPPrefixSep 2 "mrProveRelH on Vec/BVVec: " t1' "and" t2 - mrProveRelH True tp1' tp2 t1' t2 + tp2' <- liftSC2 scVecType len' tpA2 + t1' <- mrGenBVVecFromVec m1 tpA1 t1 "mrProveRelH (BVVec/BVVec)" n len + t2' <- mrGenBVVecFromVec m2 tpA2 t2 "mrProveRelH (BVVec/BVVec)" n len + mrProveRelH het tp1' tp2' t1' t2' + +mrProveRelH' _ True tp1 tp2 t1 t2 | Just mh <- matchHet tp1 tp2 = case mh of + + -- If our relation is heterogeneous and we have a bitvector on one side and + -- a Num on the other, ensure that the Num term is TCNum of some Nat, wrap + -- the Nat with bvNat, and recurse + HetBVNum n + | Just (primName -> "Cryptol.TCNum", [t2']) <- asCtor t2 -> + do n_tm <- liftSC1 scNat n + t2'' <- liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t2'] + mrProveRelH True tp1 tp1 t1 t2'' + | otherwise -> throwMRFailure (TermsNotEq t1 t2) + HetNumBV n + | Just (primName -> "Cryptol.TCNum", [t1']) <- asCtor t1 -> + do n_tm <- liftSC1 scNat n + t1'' <- liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t1'] + mrProveRelH True tp1 tp1 t1'' t2 + | otherwise -> throwMRFailure (TermsNotEq t1 t2) + + -- If our relation is heterogeneous and we have a BVVec on one side and a + -- non-BVVec vector on the other, wrap the non-BVVec vector term in + -- genBVVecFromVec and recurse + HetBVVecVec (n, len, _) (m, tpA2) -> + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + if ms_are_eq then return () else + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp2' <- liftSC2 scVecType len' tpA2 + t2' <- mrGenBVVecFromVec m tpA2 t2 "mrProveRelH (BVVec/Vec)" n len + -- mrDebugPPPrefixSep 2 "mrProveRelH on BVVec/Vec: " t1 "an`d" t2' + mrProveRelH True tp1 tp2' t1 t2' + HetVecBVVec (m, tpA1) (n, len, _) -> + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + if ms_are_eq then return () else + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp1' <- liftSC2 scVecType len' tpA1 + t1' <- mrGenBVVecFromVec m tpA1 t1 "mrProveRelH (Vec/BVVec)" n len + -- mrDebugPPPrefixSep 2 "mrProveRelH on Vec/BVVec: " t1' "and" t2 + mrProveRelH True tp1' tp2 t1' t2 + + -- For pair types, prove both the left and right projections are related + -- (this should be the same as the pair case below - we have to split them + -- up because otherwise GHC 9.0's pattern match checker complains...) + HetPair (tpL1, tpR1) (tpL2, tpR2) -> + do t1L <- liftSC1 scPairLeft t1 + t2L <- liftSC1 scPairLeft t2 + t1R <- liftSC1 scPairRight t1 + t2R <- liftSC1 scPairRight t2 + condL <- mrProveRelH True tpL1 tpL2 t1L t2L + condR <- mrProveRelH True tpR1 tpR2 t1R t2R + liftTermInCtx2 scAnd condL condR + +-- For pair types, prove both the left and right projections are related +-- (this should be the same as the pair case below - we have to split them +-- up because otherwise GHC 9.0's pattern match checker complains...) +mrProveRelH' _ False tp1 tp2 t1 t2 + | Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) <- matchHet tp1 tp2 = + do t1L <- liftSC1 scPairLeft t1 + t2L <- liftSC1 scPairLeft t2 + t1R <- liftSC1 scPairRight t1 + t2R <- liftSC1 scPairRight t2 + condL <- mrProveRelH False tpL1 tpL2 t1L t2L + condR <- mrProveRelH False tpR1 tpR2 t1R t2R + liftTermInCtx2 scAnd condL condR -- As a fallback, for types we can't handle, just check convertibility -mrProveRelH' _ _ tp1 tp2 t1 t2 = +mrProveRelH' _ het tp1 tp2 t1 t2 = do success <- mrConvertible t1 t2 if success then return () else - mrDebugPPPrefixSep 2 "mrProveRelH could not match types: " tp1 "and" tp2 >> - mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 + if het then mrDebugPPPrefixSep 2 "mrProveRelH' could not match types: " tp1 "and" tp2 >> + mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 + else mrDebugPPPrefixSep 2 "mrProveEq could not prove convertible: " t1 "and" t2 TermInCtx [] <$> liftSC1 scBool success diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 002317237c..15915e1271 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -125,7 +125,8 @@ module SAWScript.Prover.MRSolver.Solver where import Data.Maybe import Data.Either -import Data.List (findIndices, intercalate) +import Data.List (find, findIndices) +import Data.Foldable (foldlM) import Data.Bits (shiftL) import Control.Monad.Except import qualified Data.Map as Map @@ -134,7 +135,6 @@ import qualified Data.Text as Text import Prettyprinter import Verifier.SAW.Term.Functor -import Verifier.SAW.Term.CtxTerm (substTerm) import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer import Verifier.SAW.Cryptol.Monadify @@ -212,7 +212,7 @@ normComp (CompTerm t) = (>>) (mrDebugPPPrefix 3 "normCompTerm:" t) $ withFailureCtx (FailCtxMNF t) $ case asApplyAll t of - (f@(asLambda -> Just _), args) -> + (f@(asLambda -> Just _), args@(_:_)) -> mrApplyAll f args >>= normCompTerm (isGlobalDef "Prelude.returnM" -> Just (), [_, x]) -> return $ ReturnM x @@ -296,11 +296,8 @@ normComp (CompTerm t) = i)]) -> do body <- mrGlobalDefBody "CryptolM.bvVecAtM" if n < 1 `shiftL` fromIntegral w then do - n' <- mrBvConst w (toInteger n) - err_str <- liftSC1 scString "FIXME: normComp (atM) error" - err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str] - xs' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" - [n_tm, a, xs, err_tm, w_tm, n'] + n' <- liftSC2 scBvLit w (toInteger n) + xs' <- mrGenBVVecFromVec n_tm a xs "normComp (atM)" w_tm n' mrApplyAll body [w_tm, n', a, xs', i] >>= normCompTerm else throwMRFailure (MalformedComp t) @@ -323,11 +320,9 @@ normComp (CompTerm t) = i), x]) -> do body <- mrGlobalDefBody "CryptolM.fromBVVecUpdateM" if n < 1 `shiftL` fromIntegral w then do - n' <- mrBvConst w (toInteger n) - err_str <- liftSC1 scString "FIXME: normComp (updateM) error" - err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str] - xs' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" - [n_tm, a, xs, err_tm, w_tm, n'] + n' <- liftSC2 scBvLit w (toInteger n) + xs' <- mrGenBVVecFromVec n_tm a xs "normComp (updateM)" w_tm n' + err_tm <- mrErrorTerm a "normComp (updateM)" mrApplyAll body [w_tm, n', a, xs', i, x, err_tm, n_tm] >>= normCompTerm else throwMRFailure (MalformedComp t) @@ -358,10 +353,12 @@ normComp (CompTerm t) = -- FIXME: substitute for evars if they have been instantiated ((asExtCns -> Just ec), args) -> do fun_name <- extCnsToFunName ec - return $ FunBind fun_name args CompFunReturn + fun_tp <- Type <$> mrFunOutType fun_name args + return $ FunBind fun_name args (CompFunReturn fun_tp) ((asGlobalFunName -> Just f), args) -> - return $ FunBind f args CompFunReturn + do fun_tp <- Type <$> mrFunOutType f args + return $ FunBind f args (CompFunReturn fun_tp) _ -> throwMRFailure (MalformedComp t) @@ -385,34 +382,42 @@ normBind (ForallM tp f) k = return $ ForallM tp (compFunComp f k) normBind (FunBind f args k1) k2 -- Turn `bvVecMapInvarM ... >>= k` into `bvVecMapInvarBindM ... k` | GlobalName (globalDefString -> "CryptolM.bvVecMapInvarM") [] <- f - , not (isCompFunReturn (compFunComp k1 k2)) = + , (a:b:args_rest) <- args = do f' <- mrGlobalDef "CryptolM.bvVecMapInvarBindM" cont <- compFunToTerm (compFunComp k1 k2) - return $ FunBind f' (args ++ [cont]) CompFunReturn + c <- compFunReturnType k2 + return $ FunBind f' ((a:b:c:args_rest) ++ [cont]) + (CompFunReturn (Type c)) -- Turn `bvVecMapInvarBindM ... k1 >>= k2` into -- `bvVecMapInvarBindM ... (composeM ... k1 k2)` | GlobalName (globalDefString -> "CryptolM.bvVecMapInvarBindM") [] <- f - , (args_pre, [cont]) <- splitAt 8 args - , not (isCompFunReturn (compFunComp k1 k2)) = + , (args_pre, [cont]) <- splitAt 8 args = do cont' <- compFunToTerm (compFunComp (compFunComp (CompFunTerm cont) k1) k2) - return $ FunBind f (args_pre ++ [cont']) CompFunReturn + c <- compFunReturnType k2 + return $ FunBind f (args_pre ++ [cont']) (CompFunReturn (Type c)) | otherwise = return $ FunBind f args (compFunComp k1 k2) -- | Bind a 'Term' for a computation with a function and normalize normBindTerm :: Term -> CompFun -> MRM NormComp normBindTerm t f = normCompTerm t >>= \m -> normBind m f +-- | Get the return type of a 'CompFun' +compFunReturnType :: CompFun -> MRM Term +compFunReturnType (CompFunTerm t) = mrTypeOf t +compFunReturnType (CompFunComp _ g) = compFunReturnType g +compFunReturnType (CompFunReturn (Type t)) = return t + -- | Apply a computation function to a term argument to get a computation applyCompFun :: CompFun -> Term -> MRM Comp applyCompFun (CompFunComp f g) t = -- (f >=> g) t == f t >>= g do comp <- applyCompFun f t return $ CompBind comp g -applyCompFun CompFunReturn t = +applyCompFun (CompFunReturn _) t = return $ CompReturn t applyCompFun (CompFunTerm f) t = CompTerm <$> mrApplyAll f [t] --- | Convert a 'CompFun' which is not a 'CompFunReturn' into a 'Term' +-- | Convert a 'CompFun' into a 'Term' compFunToTerm :: CompFun -> MRM Term compFunToTerm (CompFunTerm t) = return t compFunToTerm (CompFunComp f g) = @@ -423,9 +428,16 @@ compFunToTerm (CompFunComp f g) = case (f_tp, g_tp) of (asPi -> Just (_, a, asCompM -> Just b), asPi -> Just (_, _, asCompM -> Just c)) -> - liftSC2 scGlobalApply "Prelude.composeM" [a, b, c, f', g'] + -- we explicitly unfold @Prelude.composeM@ here so @mrApplyAll@ will + -- beta-reduce + let nm = maybe "ret_val" id (compFunVarName f) in + mrLambdaLift [(nm, a)] (b, c, f', g') $ \[arg] (b', c', f'', g'') -> + do app <- mrApplyAll f'' [arg] + liftSC2 scGlobalApply "Prelude.bindM" [b', c', app, g''] _ -> error "compFunToTerm: type(s) not of the form: a -> CompM b" -compFunToTerm CompFunReturn = error "compFunToTerm: got a CompFunReturn" +compFunToTerm (CompFunReturn (Type a)) = + mrLambdaLift [("ret_val", a)] a $ \[ret_val] (a') -> + liftSC2 scGlobalApply "Prelude.returnM" [a', ret_val] -- | Convert a 'Comp' into a 'Term' compToTerm :: Comp -> MRM Term @@ -433,7 +445,7 @@ compToTerm (CompTerm t) = return t compToTerm (CompReturn t) = do tp <- mrTypeOf t liftSC2 scGlobalApply "Prelude.returnM" [tp, t] -compToTerm (CompBind m CompFunReturn) = compToTerm m +compToTerm (CompBind m (CompFunReturn _)) = compToTerm m compToTerm (CompBind m f) = do m' <- compToTerm m f' <- compFunToTerm f @@ -519,7 +531,7 @@ mrRefinesCoInd f1 args1 f2 args2 = -- | Prove the refinement represented by a 'CoIndHyp' coinductively. This is the -- main loop implementing 'mrRefinesCoInd'. See that function for documentation. proveCoIndHyp :: CoIndHyp -> MRM () -proveCoIndHyp hyp = +proveCoIndHyp hyp = withFailureCtx (FailCtxCoIndHyp hyp) $ do let f1 = coIndHypLHSFun hyp f2 = coIndHypRHSFun hyp args1 = coIndHypLHS hyp @@ -536,8 +548,6 @@ proveCoIndHyp hyp = -- NOTE: the state automatically gets reset here because we defined -- MRM with ExceptT at a lower level than StateT do mrDebugPPPrefixSep 1 "Widening recursive assumption for" nm1' "|=" nm2' - debugPrint 2 ("Widening indices: " ++ - intercalate ", " (map show new_vars)) hyp' <- generalizeCoIndHyp hyp new_vars proveCoIndHyp hyp' e -> throwError e @@ -547,7 +557,8 @@ proveCoIndHyp hyp = -- given arguments, otherwise throw an exception saying that widening is needed matchCoIndHyp :: CoIndHyp -> [Term] -> [Term] -> MRM () matchCoIndHyp hyp args1 args2 = - do (args1', args2') <- instantiateCoIndHyp hyp + do mrDebugPPPrefix 1 "matchCoIndHyp" hyp + (args1', args2') <- instantiateCoIndHyp hyp eqs1 <- zipWithM mrProveEq args1' args1 eqs2 <- zipWithM mrProveEq args2' args2 if and (eqs1 ++ eqs2) then return () else @@ -561,39 +572,106 @@ generalizeCoIndHyp :: CoIndHyp -> [Either Int Int] -> MRM CoIndHyp generalizeCoIndHyp hyp [] = return hyp generalizeCoIndHyp hyp all_specs@(arg_spec:arg_specs) = withOnlyUVars (coIndHypCtx hyp) $ do - mrDebugPPPrefixSep 2 "generalizeCoIndHyp" hyp "with arg specs" (show all_specs) + withNoUVars $ mrDebugPPPrefixSep 2 "generalizeCoIndHyp with indices" + all_specs "on" hyp -- Get the arg and type associated with arg_spec let arg = coIndHypArg hyp arg_spec arg_tp <- mrTypeOf arg - ctx <- mrUVarCtx - debugPretty 2 ("Current context: " <> ppCtx ctx) - -- Sort out the other args that equal arg + -- Sort out the other args that are heterogeneously related to arg eq_uneq_specs <- forM arg_specs $ \spec' -> do let arg' = coIndHypArg hyp spec' tp' <- mrTypeOf arg' - mrDebugPPPrefixSep 2 "generalizeCoIndHyp: the type of" arg' "is" tp' - tps_eq <- mrConvertible arg_tp tp' - args_eq <- if tps_eq then mrProveEq arg arg' else return False - return $ if args_eq then Left spec' else Right spec' + tps_rel <- typesHetRelated arg_tp tp' + args_rel <- if tps_rel then mrProveRel True arg arg' else return False + return $ if args_rel then Left (spec', tp') else Right spec' let (eq_specs, uneq_specs) = partitionEithers eq_uneq_specs - -- Add a new variable of type arg_tp, set all eq_specs plus our original - -- arg_spec to it, and recurse - hyp' <- generalizeCoIndHypArgs hyp arg_tp (arg_spec:eq_specs) + -- Group the eq_specs by their type, i.e. turn a list @[(Idx, Type)]@ into a + -- list @[([Idx], Type)]@, where all the indices in each pair share the same + -- type (as in 'mrConvertible') + let addArgByTp :: [([a], Term)] -> (a, Term) -> MRM [([a], Term)] + addArgByTp [] (x, tp) = return [([x], tp)] + addArgByTp ((xs, tp):xstps) (x, tp') = + do tps_eq <- mrConvertible tp' tp + if tps_eq then return ((x:xs, tp):xstps) + else ((xs, tp):) <$> addArgByTp xstps (x, tp') + eq_specs_gpd <- foldlM addArgByTp [] ((arg_spec,arg_tp):eq_specs) + -- Add a new variable, set all the indices in @eq_specs_gpd@ to it as in + -- 'generalizeCoIndHypArgs', and recurse + hyp' <- generalizeCoIndHypArgs hyp eq_specs_gpd generalizeCoIndHyp hyp' uneq_specs --- | Add a new variable of the given type to the context of a coinductive --- hypothesis and set the specified arguments to that new variable -generalizeCoIndHypArgs :: CoIndHyp -> Term -> [Either Int Int] -> MRM CoIndHyp -generalizeCoIndHypArgs (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) tp specs = - do let set_arg i args = - take i args ++ (Unshared $ LocalVar 0) : drop (i+1) args - let (specs1, specs2) = partitionEithers specs - -- NOTE: need to lift the arguments because we are adding a variable - args1' <- liftTermLike 0 1 args1 - args2' <- liftTermLike 0 1 args2 - let args1'' = foldr set_arg args1' specs1 - args2'' = foldr set_arg args2' specs2 - return $ CoIndHyp (ctx ++ [("z",tp)]) f1 f2 args1'' args2'' invar1 invar2 +-- | Assuming all the types in the given list are related by 'typesHetRelated' +-- and no two of them are convertible, add a new variable and set all of +-- indices in the given list to it, modulo possibly some wrapper functions +-- determined by how the types are heterogeneously related +generalizeCoIndHypArgs :: CoIndHyp -> [([Either Int Int], Term)] -> MRM CoIndHyp + +-- If all the arguments we need to generalize have the same type, introduce a +-- new variable and set all of the given arguments to it +generalizeCoIndHypArgs hyp [(specs, tp)] = + do (hyp', var) <- coIndHypWithVar hyp "z" (Type tp) + return $ coIndHypSetArgs hyp' specs var + +generalizeCoIndHypArgs hyp [(specs1, tp1), (specs2, tp2)] = case matchHet tp1 tp2 of + + -- If we need to generalize bitvector arguments with Num arguments, introduce + -- a bitvector variable and set all of the bitvector arguments to it and + -- all of the Num arguments to `TCNum` of `bvToNat` of it + Just (HetBVNum n) -> + do (hyp', var) <- coIndHypWithVar hyp "z" (Type tp1) + nat_tm <- liftSC2 scBvToNat n var + num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] + let hyp'' = coIndHypSetArgs hyp' specs1 var + return $ coIndHypSetArgs hyp'' specs2 num_tm + Just (HetNumBV n) -> + do (hyp', var) <- coIndHypWithVar hyp "z" (Type tp2) + nat_tm <- liftSC2 scBvToNat n var + num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] + let hyp'' = coIndHypSetArgs hyp' specs1 num_tm + return $ coIndHypSetArgs hyp'' specs2 var + + -- If we need to generalize BVVec arguments with Vec arguments, introduce a + -- BVVec variable and set all of the BVVec arguments to it and all of the + -- Vec arguments to `genBVVecFromVec` of it + -- FIXME: Could we handle the a /= a' case here and in mrRefinesFunH? + Just (HetBVVecVec (n, len, a) (m, a')) -> + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m m' + as_are_eq <- mrConvertible a a' + if ms_are_eq && as_are_eq then return () else + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + (hyp', var) <- coIndHypWithVar hyp "z" (Type tp1) + bvv_tm <- mrGenFromBVVec n len a var "generalizeCoIndHypArgs (BVVec/Vec)" m + let hyp'' = coIndHypSetArgs hyp' specs1 var + return $ coIndHypSetArgs hyp'' specs2 bvv_tm + Just (HetVecBVVec (m, a') (n, len, a)) -> + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m m' + as_are_eq <- mrConvertible a a' + if ms_are_eq && as_are_eq then return () else + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + (hyp', var) <- coIndHypWithVar hyp "z" (Type tp2) + bvv_tm <- mrGenFromBVVec n len a var "generalizeCoIndHypArgs (Vec/BVVec)" m + let hyp'' = coIndHypSetArgs hyp' specs1 bvv_tm + return $ coIndHypSetArgs hyp'' specs2 var + + -- This case should be unreachable because in 'mrRefinesFunH' we always + -- expand all tuples - though in principle we could handle it + Just (HetPair _ _) -> + debugPrint 0 "generalizeCoIndHypArgs: trying to widen distinct tuple types:" >> + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + + Nothing -> throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + +generalizeCoIndHypArgs _ specs = map fst <$> mrUVars >>= \uvar_ctx -> + -- Being in this case implies we have types @tp1, tp2, tp3@ which are related + -- by 'typesHetRelated' but no two of them are convertible. As of the time of + -- writing, the only way this could be possible is if the types are pair + -- types related in different components (e.g. @(a,b), (a',b), (a,b')@). In + -- 'mrRefinesFunH' we always expand all tuples, so when we hit this function + -- no such types should remain. + error $ "generalizeCoIndHypArgs: too many distinct types to widen: " + ++ showInCtx uvar_ctx specs ---------------------------------------------------------------------- @@ -618,6 +696,8 @@ mrRefines t1 t2 = do m1 <- toNormComp t1 m2 <- toNormComp t2 mrDebugPPPrefixSep 1 "mrRefines" m1 "|=" m2 + -- ctx <- reverse . map (\(a,Type b) -> (a,b)) <$> mrUVars + -- mrDebugPPPrefix 2 "in context:" $ ppCtx ctx withFailureCtx (FailCtxRefines m1 m2) $ mrRefines' m1 m2 -- | The main implementation of 'mrRefines' @@ -766,7 +846,7 @@ mrRefines' m1@(FunBind (EVarFunName _) _ _) m2 = mrRefines' m1 m2@(FunBind (EVarFunName _) _ _) = throwMRFailure (CompsDoNotRefine m1 m2) {- -mrRefines' (FunBind (EVarFunName evar) args CompFunReturn) m2 = +mrRefines' (FunBind (EVarFunName evar) args (CompFunReturn _)) m2 = mrGetEVar evar >>= \case Just f -> (mrApplyAll f args >>= normCompTerm) >>= \m1' -> @@ -777,12 +857,13 @@ mrRefines' (FunBind (EVarFunName evar) args CompFunReturn) m2 = mrRefines' (FunBind (LetRecName f) args1 k1) (FunBind (LetRecName f') args2 k2) | f == f' && length args1 == length args2 = zipWithM_ mrAssertProveEq args1 args2 >> - mrRefinesFun k1 k2 + mrFunOutType (LetRecName f) args1 >>= \tp -> + mrRefinesFun tp k1 tp k2 mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = mrFunOutType f1 args1 >>= \tp1 -> mrFunOutType f2 args2 >>= \tp2 -> - mrConvertible tp1 tp2 >>= \tps_eq -> + typesHetRelated tp1 tp2 >>= \tps_rel -> mrFunBodyRecInfo f1 args1 >>= \maybe_f1_body -> mrFunBodyRecInfo f2 args2 >>= \maybe_f2_body -> mrGetCoIndHyp f1 f2 >>= \maybe_coIndHyp -> @@ -797,7 +878,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- * Otherwise, throw a 'CoIndHypMismatchFailure' error. (Just hyp, _) -> matchCoIndHyp hyp args1 args2 >> - mrRefinesFun k1 k2 + mrRefinesFun tp1 k1 tp2 k2 -- If we have an opaque FunAssump that f1 args1' refines f2 args2', then -- prove that args1 = args1', args2 = args2', and then that k1 refines k2 @@ -806,7 +887,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = (args1'', args2'') <- substTermLike 0 evars (args1', args2') zipWithM_ mrAssertProveEq args1'' args1 zipWithM_ mrAssertProveEq args2'' args2 - mrRefinesFun k1 k2 + mrRefinesFun tp1 k1 tp2 k2 -- If we have an opaque FunAssump that f1 refines some f /= f2, and f2 -- unfolds and is not recursive in itself, unfold f2 and recurse @@ -818,11 +899,12 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- f1 args1' refines some f args where f /= f2 and f2 does not match the -- case above, treat either case like we have a rewrite FunAssump and prove -- that args1 = args1' and then that f args refines m2 - (_, Just (FunAssump ctx args1' (funAssumpRHSAsNormComp -> rhs))) -> - do evars <- mrFreshEVars ctx - (args1'', rhs') <- substTermLike 0 evars (args1', rhs) + (_, Just (FunAssump ctx args1' rhs)) -> + do rhs' <- mrFunAssumpRHSAsNormComp rhs + evars <- mrFreshEVars ctx + (args1'', rhs'') <- substTermLike 0 evars (args1', rhs') zipWithM_ mrAssertProveEq args1'' args1 - m1' <- normBind rhs' k1 + m1' <- normBind rhs'' k1 mrRefines m1' m2 -- If f1 unfolds and is not recursive in itself, unfold it and recurse @@ -835,13 +917,13 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- If we don't have a co-inducitve hypothesis for f1 and f2, don't have an -- assumption that f1 refines some specification, and both f1 and f2 are - -- recursive and have the same return type, then try to coinductively prove - -- that f1 args1 |= f2 args2 under the assumption that f1 args1 |= f2 args2, - -- and then try to prove that k1 |= k2 - _ | tps_eq + -- recursive and have return types which are heterogeneously related, then + -- try to coinductively prove that f1 args1 |= f2 args2 under the assumption + -- that f1 args1 |= f2 args2, and then try to prove that k1 |= k2 + _ | tps_rel , Just _ <- maybe_f1_body , Just _ <- maybe_f2_body -> - mrRefinesCoInd f1 args1 f2 args2 >> mrRefinesFun k1 k2 + mrRefinesCoInd f1 args1 f2 args2 >> mrRefinesFun tp1 k1 tp2 k2 -- If we cannot line up f1 and f2, then making progress here would require us -- to somehow split either m1 or m2 into some bind m' >>= k' such that m' is @@ -857,11 +939,12 @@ mrRefines' m1@(FunBind f1 args1 k1) m2 = -- If we have an assumption that f1 args' refines some rhs, then prove that -- args1 = args' and then that rhs refines m2 - Just (FunAssump ctx args1' (funAssumpRHSAsNormComp -> rhs)) -> - do evars <- mrFreshEVars ctx - (args1'', rhs') <- substTermLike 0 evars (args1', rhs) + Just (FunAssump ctx args1' rhs) -> + do rhs' <- mrFunAssumpRHSAsNormComp rhs + evars <- mrFreshEVars ctx + (args1'', rhs'') <- substTermLike 0 evars (args1', rhs') zipWithM_ mrAssertProveEq args1'' args1 - m1' <- normBind rhs' k1 + m1' <- normBind rhs'' k1 mrRefines m1' m2 -- Otherwise, see if we can unfold f1 @@ -889,7 +972,7 @@ mrRefines' m1 m2@(FunBind f2 args2 k2) = -- proving m1 |= f2_body under the assumption that m1 |= f2 args2 {- FIXME: implement something like this Just (f2_body, True) - | CompFunReturn <- k2 -> + | CompFunReturn _ <- k2 -> withFunAssumpR m1 f2 args2 $ -} @@ -928,18 +1011,139 @@ mrRefines'' (ForallM tp f1) m2 = -- If none of the above cases match, then fail mrRefines'' m1 m2 = throwMRFailure (CompsDoNotRefine m1 m2) - -- | Prove that one function refines another for all inputs -mrRefinesFun :: CompFun -> CompFun -> MRM () -mrRefinesFun CompFunReturn CompFunReturn = return () -mrRefinesFun f1 f2 - | Just nm <- compFunVarName f1 `mplus` compFunVarName f2 - , Just tp <- compFunInputType f1 `mplus` compFunInputType f2 = - withUVarLift nm tp (f1,f2) $ \x (f1', f2') -> - do m1' <- applyNormCompFun f1' x - m2' <- applyNormCompFun f2' x - mrRefines m1' m2' -mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" +mrRefinesFun :: Term -> CompFun -> Term -> CompFun -> MRM () +mrRefinesFun tp1 f1 tp2 f2 = + do mrDebugPPPrefixSep 1 "mrRefinesFun on types:" tp1 "," tp2 + mrDebugPPPrefixSep 1 "mrRefinesFun" f1 "|=" f2 + f1' <- compFunToTerm f1 >>= liftSC1 scWhnf + f2' <- compFunToTerm f2 >>= liftSC1 scWhnf + let lnm = maybe "call_ret_val" id (compFunVarName f1) + rnm = maybe "call_ret_val" id (compFunVarName f2) + mrRefinesFunH mrRefines [] [(lnm, tp1)] f1' [(rnm, tp2)] f2' + +-- | The main loop of 'mrRefinesFun' and 'askMRSolver': given a continuation, +-- two terms of function type, and two equal-length lists representing the +-- argument types of the two terms, add a uvar for each corresponding pair of +-- types (assuming the types are either equal or are heterogeneously related, +-- as in 'HetRelated'), apply the terms to these uvars (modulo possibly some +-- wrapper functions determined by how the types are heterogeneously related), +-- and call the continuation on the resulting terms. The second argument is +-- an accumulator of variables to introduce, innermost first. +mrRefinesFunH :: (Term -> Term -> MRM a) -> [Term] -> + [(LocalName,Term)] -> Term -> [(LocalName,Term)] -> Term -> + MRM a + +mrRefinesFunH k vars ((nm1, tp1):tps1) t1 ((nm2, tp2):tps2) t2 = case matchHet tp1 tp2 of + + -- If we need to introduce a bitvector on one side and a Num on the other, + -- introduce a bitvector variable and substitute `TCNum` of `bvToNat` of that + -- variable on the Num side + Just (HetBVNum n) -> + let nm = maybe "_" id $ find ((/=) '_' . Text.head) + $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 + , asLambdaName t2 ] in + withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> + do nat_tm <- liftSC2 scBvToNat n var + num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] + tps2' <- zipWithM (\i tp -> liftTermLike 0 i num_tm >>= \num_tm' -> + substTermLike i (num_tm' : vars') tp >>= + mapM (liftSC1 scWhnf)) + [0..] tps2 + t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [num_tm] + mrRefinesFunH k (var : vars') tps1 t1'' tps2' t2'' + Just (HetNumBV n) -> + let nm = maybe "_" id $ find ((/=) '_' . Text.head) + $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 + , asLambdaName t2 ] in + withUVarLift nm (Type tp2) (vars, t1, t2) $ \var (vars', t1', t2') -> + do nat_tm <- liftSC2 scBvToNat n var + num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] + tps1' <- zipWithM (\i tp -> liftTermLike 0 i num_tm >>= \num_tm' -> + substTermLike i (num_tm' : vars') tp >>= + mapM (liftSC1 scWhnf)) + [0..] tps1 + t1'' <- mrApplyAll t1' [num_tm] + t2'' <- mrApplyAll t2' [var] + mrRefinesFunH k (var : vars') tps1' t1'' tps2 t2'' + + -- If we need to introduce a BVVec on one side and a non-BVVec vector on the + -- other, introduce a BVVec variable and substitute `genBVVecFromVec` of that + -- variable on the non-BVVec side + -- FIXME: Could we handle the a /= a' case here and in generalizeCoIndHypArgs? + Just (HetBVVecVec (n, len, a) (m, a')) -> + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m m' + as_are_eq <- mrConvertible a a' + if ms_are_eq && as_are_eq then return () else + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + let nm = maybe "_" id $ find ((/=) '_' . Text.head) + $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 + , asLambdaName t2 ] + withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> + do bvv_tm <- mrGenFromBVVec n len a var "mrRefinesFunH (BVVec/Vec)" m + tps2' <- zipWithM (\i tp -> liftTermLike 0 i bvv_tm >>= \bvv_tm' -> + substTermLike i (bvv_tm' : vars') tp >>= + mapM (liftSC1 scWhnf)) + [0..] tps2 + t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [bvv_tm] + mrRefinesFunH k (var : vars') tps1 t1'' tps2' t2'' + Just (HetVecBVVec (m, a') (n, len, a)) -> + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m m' + as_are_eq <- mrConvertible a a' + if ms_are_eq && as_are_eq then return () else + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + let nm = maybe "_" id $ find ((/=) '_' . Text.head) + $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 + , asLambdaName t2 ] + withUVarLift nm (Type tp2) (vars, t1, t2) $ \var (vars', t1', t2') -> + do bvv_tm <- mrGenFromBVVec n len a var "mrRefinesFunH (BVVec/Vec)" m + tps1' <- zipWithM (\i tp -> liftTermLike 0 i bvv_tm >>= \bvv_tm' -> + substTermLike i (bvv_tm' : vars') tp >>= + mapM (liftSC1 scWhnf)) + [0..] tps1 + t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [bvv_tm] + mrRefinesFunH k (var : vars') tps1' t1'' tps2 t2'' + + -- We always curry pair values before introducing them (NOTE: we do this even + -- when the have the same types to ensure we never have to unify a projection + -- of an evar with a non-projected value, i.e. evar.1 == val ) + Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) -> + do let tps1' = (nm1 <> "_1", tpL1):(nm1 <> "_2", tpR1):tps1 + tps2' = (nm2 <> "_1", tpL2):(nm2 <> "_2", tpR2):tps2 + t1'' <- mrLambdaLift [(nm1, tpL1), (nm1, tpR1)] t1 $ \[prj1, prj2] t1' -> + liftSC2 scPairValue prj1 prj2 >>= mrApply t1' + t2'' <- mrLambdaLift [(nm2, tpL2), (nm2, tpR2)] t2 $ \[prj1, prj2] t2' -> + liftSC2 scPairValue prj1 prj2 >>= mrApply t2' + mrRefinesFunH k vars tps1' t1'' tps2' t2'' + + -- Introduce variables of the same type together + Nothing -> + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) + let nm = maybe "_" id $ find ((/=) '_' . Text.head) + $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 + , asLambdaName t2 ] + withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> + do t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [var] + mrRefinesFunH k (var : vars') tps1 t1'' tps2 t2'' + +-- Error if we don't have the same number of arguments on both sides +-- FIXME: Add a specific error for this case +mrRefinesFunH _ _ ((_,tp1):_) _ [] _ = + liftSC0 scUnitType >>= \utp -> + throwMRFailure (TypesNotEq (Type tp1) (Type utp)) +mrRefinesFunH _ _ [] _ ((_,tp2):_) _ = + liftSC0 scUnitType >>= \utp -> + throwMRFailure (TypesNotEq (Type utp) (Type tp2)) + +mrRefinesFunH k _ [] t1 [] t2 = k t1 t2 ---------------------------------------------------------------------- @@ -951,131 +1155,68 @@ mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" -- a function name type MRSolverResult = Maybe (FunName, FunAssump) --- | The main loop of 'askMRSolver'. The first argument is an accumulator of --- variables to introduce, innermost first. -askMRSolverH :: [Term] -> Term -> Term -> Term -> Term -> MRM MRSolverResult - --- If we need to introduce a bitvector on one side and a Num on the other, --- introduce a bitvector variable and substitute `TCNum` of `bvToNat` of that --- variable on the Num side -askMRSolverH vars (asPi -> Just (nm1, tp@(asBitvectorType -> Just n), body1)) t1 - (asPi -> Just (nm2, asDataType -> Just (primName -> "Cryptol.Num", _), body2)) t2 = - let nm = if Text.head nm2 == '_' then nm1 else nm2 in - withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> - do nat_tm <- liftSC2 scBvToNat n var - num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - body2' <- substTerm 0 (num_tm : vars') body2 - t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [num_tm] - askMRSolverH (var : vars') body1 t1'' body2' t2'' -askMRSolverH vars (asPi -> Just (nm1, asDataType -> Just (primName -> "Cryptol.Num", _), body1)) t1 - (asPi -> Just (nm2, tp@(asBitvectorType -> Just n), body2)) t2 = - let nm = if Text.head nm2 == '_' then nm1 else nm2 in - withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> - do nat_tm <- liftSC2 scBvToNat n var - num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - body1' <- substTerm 0 (num_tm : vars') body1 - t1'' <- mrApplyAll t1' [num_tm] - t2'' <- mrApplyAll t2' [var] - askMRSolverH (var : vars') body1' t1'' body2 t2'' - --- If we need to introduce a BVVec on one side and a non-BVVec vector on the --- other, introduce a BVVec variable and substitute `genBVVecFromVec` of that --- variable on the non-BVVec side -askMRSolverH vars tp1@(asPi -> Just (nm1, tp@(asBVVecType -> Just (n, len, a)), body1)) t1 - tp2@(asPi -> Just (nm2, asNonBVVecVectorType -> Just (m, a'), body2)) t2 = - do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m' m - as_are_eq <- mrConvertible a a' - if ms_are_eq && as_are_eq then return () else - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) - let nm = if Text.head nm2 == '_' then nm1 else nm2 - withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> - do err_str_tm <- liftSC1 scString "FIXME: askMRSolverH error" - err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str_tm] - bvvec_tm <- liftSC2 scGlobalApply "Prelude.genFromBVVec" - [n, len, a, var, err_tm, m] - body2' <- substTerm 0 (bvvec_tm : vars') body2 - t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [bvvec_tm] - askMRSolverH (var : vars') body1 t1'' body2' t2'' -askMRSolverH vars tp1@(asPi -> Just (nm1, asNonBVVecVectorType -> Just (m, a'), body2)) t1 - tp2@(asPi -> Just (nm2, tp@(asBVVecType -> Just (n, len, a)), body1)) t2 = - do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m' m - as_are_eq <- mrConvertible a a' - if ms_are_eq && as_are_eq then return () else - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) - let nm = if Text.head nm2 == '_' then nm1 else nm2 - withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> - do err_str_tm <- liftSC1 scString "FIXME: askMRSolverH error" - err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str_tm] - bvvec_tm <- liftSC2 scGlobalApply "Prelude.genFromBVVec" - [n, len, a, var, err_tm, m] - body1' <- substTerm 0 (bvvec_tm : vars') body1 - t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [bvvec_tm] - askMRSolverH (var : vars') body1' t1'' body2 t2'' - --- Introduce variables of the same type together -askMRSolverH vars tp11@(asPi -> Just (nm1, tp1, body1)) t1 - tp22@(asPi -> Just (nm2, tp2, body2)) t2 = - do tps_are_eq <- mrConvertible tp1 tp2 - if tps_are_eq then return () else - throwMRFailure (TypesNotEq (Type tp11) (Type tp22)) - let nm = if Text.head nm2 == '_' then nm1 else nm2 - withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> - do t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [var] - askMRSolverH (var : vars') body1 t1'' body2 t2'' - --- Error if we don't have the same number of arguments on both sides -askMRSolverH _ tp1@(asPi -> Just _) _ tp2 _ = - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) -askMRSolverH _ tp1 _ tp2@(asPi -> Just _) _ = - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) - --- The base case: both sides are CompM of the same type -askMRSolverH _ (asCompM -> Just _) t1 (asCompM -> Just _) t2 = - do m1 <- normCompTerm t1 +-- | The continuation passed to 'mrRefinesFunH' in 'askMRSolver' and +-- 'assumeMRSolver': normalizes both resulting terms using 'normCompTerm', +-- calls the given monadic function, then returns a 'FunAssump', if possible +askMRSolverH :: (NormComp -> NormComp -> MRM ()) -> + Term -> Term -> MRM MRSolverResult +askMRSolverH f t1 t2 = + do m1 <- normCompTerm t1 m2 <- normCompTerm t2 - mrRefines m1 m2 + f m1 m2 case (m1, m2) of -- If t1 and t2 are both named functions, our result is the opaque -- FunAssump that forall xs. f1 xs |= f2 xs' - (FunBind f1 args1 CompFunReturn, FunBind f2 args2 CompFunReturn) -> + (FunBind f1 args1 (CompFunReturn _), FunBind f2 args2 (CompFunReturn _)) -> mrUVarCtx >>= \uvar_ctx -> return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, fassumpArgs = args1, fassumpRHS = OpaqueFunAssump f2 args2 }) -- If just t1 is a named function, our result is the rewrite FunAssump -- that forall xs. f1 xs |= m2 - (FunBind f1 args1 CompFunReturn, _) -> + (FunBind f1 args1 (CompFunReturn _), _) -> mrUVarCtx >>= \uvar_ctx -> return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, fassumpArgs = args1, fassumpRHS = RewriteFunAssump m2 }) _ -> return Nothing --- Error if we don't have CompM at the end -askMRSolverH _ (asCompM -> Just _) _ tp2 _ = - throwMRFailure (NotCompFunType tp2) -askMRSolverH _ tp1 _ _ _ = - throwMRFailure (NotCompFunType tp1) - - -- | Test two monadic, recursive terms for refinement. On success, if the --- left-hand term is a named function, add the refinement to the 'MREnv' --- environment. +-- left-hand term is a named function, returning a 'FunAssump' to add to the +-- 'MREnv'. askMRSolver :: SharedContext -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> Term -> Term -> IO (Either MRFailure MRSolverResult) - askMRSolver sc env timeout t1 t2 = do tp1 <- scTypeOf sc t1 >>= scWhnf sc tp2 <- scTypeOf sc t2 >>= scWhnf sc runMRM sc timeout env $ mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 >> - askMRSolverH [] tp1 t1 tp2 t2 + case (asPiList tp1, asPiList tp2) of + ((tps1, asCompM -> Just _), (tps2, asCompM -> Just _)) -> + mrRefinesFunH (askMRSolverH mrRefines) [] tps1 t1 tps2 t2 + ((_, asCompM -> Just _), (_, tp2')) -> + throwMRFailure (NotCompFunType tp2') + ((_, tp1'), _) -> + throwMRFailure (NotCompFunType tp1') + +-- | Return the 'FunAssump' to add to the 'MREnv' that would be generated if +-- 'askMRSolver' succeeded on the given terms. +assumeMRSolver :: + SharedContext -> + MREnv {- ^ The Mr Solver environment -} -> + Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + Term -> Term -> IO (Either MRFailure MRSolverResult) +assumeMRSolver sc env timeout t1 t2 = + do tp1 <- scTypeOf sc t1 >>= scWhnf sc + tp2 <- scTypeOf sc t2 >>= scWhnf sc + runMRM sc timeout env $ + case (asPiList tp1, asPiList tp2) of + ((tps1, asCompM -> Just _), (tps2, asCompM -> Just _)) -> + mrRefinesFunH (askMRSolverH (\_ _ -> return ())) [] tps1 t1 tps2 t2 + ((_, asCompM -> Just _), (_, tp2')) -> + throwMRFailure (NotCompFunType tp2') + ((_, tp1'), _) -> + throwMRFailure (NotCompFunType tp1') diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index ef093df317..5ce92c0cb6 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -131,21 +131,21 @@ data CompFun -- | An arbitrary term = CompFunTerm Term -- | A special case for the term @\ (x:a) -> returnM a x@ - | CompFunReturn + | CompFunReturn Type -- | The monadic composition @f >=> g@ | CompFunComp CompFun CompFun deriving (Generic, Show) -- | Compose two 'CompFun's, simplifying if one is a 'CompFunReturn' compFunComp :: CompFun -> CompFun -> CompFun -compFunComp CompFunReturn f = f -compFunComp f CompFunReturn = f +compFunComp (CompFunReturn _) f = f +compFunComp f (CompFunReturn _) = f compFunComp f g = CompFunComp f g -- | If a 'CompFun' contains an explicit lambda-abstraction, then return the -- textual name bound by that lambda compFunVarName :: CompFun -> Maybe LocalName -compFunVarName (CompFunTerm (asLambda -> Just (nm, _, _))) = Just nm +compFunVarName (CompFunTerm t) = asLambdaName t compFunVarName (CompFunComp f _) = compFunVarName f compFunVarName _ = Nothing @@ -154,13 +154,9 @@ compFunVarName _ = Nothing compFunInputType :: CompFun -> Maybe Type compFunInputType (CompFunTerm (asLambda -> Just (_, tp, _))) = Just $ Type tp compFunInputType (CompFunComp f _) = compFunInputType f +compFunInputType (CompFunReturn t) = Just t compFunInputType _ = Nothing --- | Returns true iff the given 'CompFun' is 'CompFunReturn' -isCompFunReturn :: CompFun -> Bool -isCompFunReturn CompFunReturn = True -isCompFunReturn _ = False - -- | A computation of type @CompM a@ for some @a@ data Comp = CompTerm Term | CompBind Comp CompFun | CompReturn Term deriving (Generic, Show) @@ -223,6 +219,11 @@ asNonBVVecVectorType :: Recognizer Term (Term, Term) asNonBVVecVectorType (asBVVecType -> Just _) = Nothing asNonBVVecVectorType t = asVectorType t +-- | Like 'asLambda', but only return's the lambda-bound variable's 'LocalName' +asLambdaName :: Recognizer Term LocalName +asLambdaName (asLambda -> Just (nm, _, _)) = Just nm +asLambdaName _ = Nothing + ---------------------------------------------------------------------- -- * Mr Solver Environments @@ -233,11 +234,6 @@ asNonBVVecVectorType t = asVectorType t data FunAssumpRHS = OpaqueFunAssump FunName [Term] | RewriteFunAssump NormComp --- | Convert a 'FunAssumpRHS' to a 'NormComp' -funAssumpRHSAsNormComp :: FunAssumpRHS -> NormComp -funAssumpRHSAsNormComp (OpaqueFunAssump f args) = FunBind f args CompFunReturn -funAssumpRHSAsNormComp (RewriteFunAssump rhs) = rhs - -- | An assumption that a named function refines some specification. This has -- the form -- @@ -384,6 +380,9 @@ instance TermLike Term where instance TermLike FunName where liftTermLike _ _ = return substTermLike _ _ = return +instance TermLike LocalName where + liftTermLike _ _ = return + substTermLike _ _ = return deriving instance TermLike Type deriving instance TermLike NormComp @@ -426,16 +425,15 @@ prettyTermApp f_top args = -- | FIXME: move this helper function somewhere better... ppCtx :: [(LocalName,Term)] -> SawDoc -ppCtx = helper [] where - helper :: [LocalName] -> [(LocalName,Term)] -> SawDoc - helper _ [] = "" +ppCtx = align . sep . helper [] where + helper :: [LocalName] -> [(LocalName,Term)] -> [SawDoc] + helper _ [] = [] + helper ns [(n,tp)] = + [ppTermInCtx defaultPPOpts (n:ns) (Unshared $ LocalVar 0) <> ":" <> + ppTermInCtx defaultPPOpts ns tp] helper ns ((n,tp):ctx) = - let ns' = n:ns in - ppTermInCtx defaultPPOpts ns' (Unshared $ LocalVar 0) <> ":" <> - ppTermInCtx defaultPPOpts ns tp <> ", " <> helper ns' ctx - -instance PrettyInCtx String where - prettyInCtx str = return $ fromString str + (ppTermInCtx defaultPPOpts (n:ns) (Unshared $ LocalVar 0) <> ":" <> + ppTermInCtx defaultPPOpts ns tp <> ",") : (helper (n:ns) ctx) instance PrettyInCtx SawDoc where prettyInCtx pp = return pp @@ -446,13 +444,23 @@ instance PrettyInCtx Type where instance PrettyInCtx MRVar where prettyInCtx (MRVar ec) = return $ ppName $ ecName ec -instance PrettyInCtx [Term] where +instance PrettyInCtx a => PrettyInCtx [a] where prettyInCtx xs = list <$> mapM prettyInCtx xs +instance {-# OVERLAPPING #-} PrettyInCtx String where + prettyInCtx str = return $ fromString str + +instance PrettyInCtx Int where + prettyInCtx i = return $ viaShow i + instance PrettyInCtx a => PrettyInCtx (Maybe a) where prettyInCtx (Just x) = (<+>) "Just" <$> prettyInCtx x prettyInCtx Nothing = return "Nothing" +instance (PrettyInCtx a, PrettyInCtx b) => PrettyInCtx (Either a b) where + prettyInCtx (Left a) = (<+>) "Left" <$> prettyInCtx a + prettyInCtx (Right b) = (<+>) "Right" <$> prettyInCtx b + instance (PrettyInCtx a, PrettyInCtx b) => PrettyInCtx (a,b) where prettyInCtx (x, y) = (\x' y' -> parens (x' <> "," <> y')) <$> prettyInCtx x <*> prettyInCtx y @@ -478,7 +486,8 @@ instance PrettyInCtx Comp where instance PrettyInCtx CompFun where prettyInCtx (CompFunTerm t) = prettyInCtx t - prettyInCtx CompFunReturn = return "returnM" + prettyInCtx (CompFunReturn t) = + prettyAppList [return "returnM", parens <$> prettyInCtx t] prettyInCtx (CompFunComp f g) = prettyAppList [prettyInCtx f, return ">=>", prettyInCtx g] @@ -515,7 +524,7 @@ instance PrettyInCtx NormComp where prettyInCtx (ForallM tp f) = prettyAppList [return "forallM", prettyInCtx tp, return "_", parens <$> prettyInCtx f] - prettyInCtx (FunBind f args CompFunReturn) = + prettyInCtx (FunBind f args (CompFunReturn _)) = prettyTermApp (funNameTerm f) args prettyInCtx (FunBind f [] k) = prettyAppList [prettyInCtx f, return ">>=", prettyInCtx k]