diff --git a/cryptol-remote-api/cryptol-eval-server/Main.hs b/cryptol-remote-api/cryptol-eval-server/Main.hs index d11a732a9..037e0667e 100644 --- a/cryptol-remote-api/cryptol-eval-server/Main.hs +++ b/cryptol-remote-api/cryptol-eval-server/Main.hs @@ -18,6 +18,7 @@ import Options.Applicative strOption, value ) +import CryptolServer.Check ( check, checkDescr ) import CryptolServer.ClearState ( clearState, clearStateDescr, clearAllStates, clearAllStatesDescr) import Cryptol.Eval (EvalOpts(..), defaultPPOpts) @@ -120,6 +121,10 @@ initMod = StartingFile <$> (Left <$> filename <|> Right . textToModName <$> modu cryptolEvalMethods :: [AppMethod ServerState] cryptolEvalMethods = [ command + "check" + checkDescr + check + , command "focused module" focusedModuleDescr focusedModule diff --git a/cryptol-remote-api/cryptol-remote-api.cabal b/cryptol-remote-api/cryptol-remote-api.cabal index b6cb9ce9c..7ee65548b 100644 --- a/cryptol-remote-api/cryptol-remote-api.cabal +++ b/cryptol-remote-api/cryptol-remote-api.cabal @@ -47,6 +47,7 @@ common deps mtl ^>= 2.2, scientific ^>= 0.3, text ^>= 1.2.3, + tf-random, unordered-containers ^>= 0.2, vector ^>= 0.12, @@ -59,6 +60,7 @@ library exposed-modules: CryptolServer CryptolServer.Call + CryptolServer.Check CryptolServer.ClearState CryptolServer.Data.Expression CryptolServer.Data.Type diff --git a/cryptol-remote-api/cryptol-remote-api/Main.hs b/cryptol-remote-api/cryptol-remote-api/Main.hs index 9f9dd7bb8..c681f67bf 100644 --- a/cryptol-remote-api/cryptol-remote-api/Main.hs +++ b/cryptol-remote-api/cryptol-remote-api/Main.hs @@ -15,6 +15,7 @@ import qualified Argo.Doc as Doc import CryptolServer ( command, notification, initialState, extendSearchPath, ServerState ) import CryptolServer.Call ( call, callDescr ) +import CryptolServer.Check ( check, checkDescr ) import CryptolServer.ClearState ( clearState, clearStateDescr, clearAllStates, clearAllStatesDescr ) import CryptolServer.EvalExpr @@ -58,7 +59,11 @@ getSearchPaths = cryptolMethods :: [AppMethod ServerState] cryptolMethods = - [ notification + [ command + "check" + checkDescr + check + , notification "clear state" clearStateDescr clearState diff --git a/cryptol-remote-api/python/cryptol/__init__.py b/cryptol-remote-api/python/cryptol/__init__.py index c993eb4d2..3a2ceecb6 100644 --- a/cryptol-remote-api/python/cryptol/__init__.py +++ b/cryptol-remote-api/python/cryptol/__init__.py @@ -4,11 +4,10 @@ import base64 import os -import types -import sys +from dataclasses import dataclass from distutils.spawn import find_executable -from typing import Any, Dict, Iterable, List, NoReturn, Optional, Union, Callable -from mypy_extensions import TypedDict +from typing import Any, List, NoReturn, Optional, Union +from typing_extensions import Literal import argo_client.interaction as argo from argo_client.interaction import HasProtocolState @@ -126,6 +125,54 @@ def __init__(self, connection : HasProtocolState, fun : str, args : List[Any]) - def process_result(self, res : Any) -> Any: return from_cryptol_arg(res['value']) + +@dataclass +class CheckReport: + """Class for describing ``check`` test results.""" + success: bool + args: List[Any] + error_msg: Optional[str] + tests_run: int + tests_possible: Optional[int] + +class CryptolCheck(argo.Command): + def __init__(self, connection : HasProtocolState, expr : Any, num_tests : Union[Literal['all'],int, None]) -> None: + if num_tests: + args = {'expression': expr, 'number of tests':num_tests} + else: + args = {'expression': expr} + super(CryptolCheck, self).__init__( + 'check', + args, + connection + ) + + def process_result(self, res : Any) -> 'CheckReport': + if res['result'] == 'pass': + return CheckReport( + success=True, + args=[], + error_msg = None, + tests_run=res['tests run'], + tests_possible=res['tests possible']) + elif res['result'] == 'fail': + return CheckReport( + success=False, + args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']], + error_msg = None, + tests_run=res['tests run'], + tests_possible=res['tests possible']) + elif res['result'] == 'error': + return CheckReport( + success=False, + args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']], + error_msg = res['error message'], + tests_run=res['tests run'], + tests_possible=res['tests possible']) + else: + raise ValueError("Unknown check result " + str(res)) + + class CryptolCheckType(argo.Command): def __init__(self, connection : HasProtocolState, expr : Any) -> None: super(CryptolCheckType, self).__init__( @@ -369,6 +416,21 @@ def call(self, fun : str, *args : List[Any]) -> argo.Command: self.most_recent_result = CryptolCall(self, fun, encoded_args) return self.most_recent_result + def check(self, expr : Any, *, num_tests : Union[Literal['all'], int, None] = None) -> argo.Command: + """Tests the validity of a Cryptol expression with random inputs. The expression must be a function with + return type ``Bit``. + + If ``num_tests`` is ``"all"`` then the expression is tested exhaustively (i.e., against all possible inputs). + + If ``num_tests`` is omitted, Cryptol defaults to running 100 tests. + """ + if num_tests == "all" or isinstance(num_tests, int) or num_tests is None: + self.most_recent_result = CryptolCheck(self, expr, num_tests) + return self.most_recent_result + else: + raise ValueError('``num_tests`` must be an integer, ``None``, or the string literall ``"all"``') + + def check_type(self, code : Any) -> argo.Command: """Check the type of a Cryptol expression, represented according to :ref:`cryptol-json-expression`, with Python datatypes standing for @@ -406,7 +468,7 @@ def focused_module(self) -> argo.Command: def reset(self) -> None: """Resets the connection, causing its unique state on the server to be freed (if applicable). - + After a reset a connection may be treated as if it were a fresh connection with the server if desired.""" CryptolReset(self) self.most_recent_result = None diff --git a/cryptol-remote-api/python/tests/cryptol/test-files/examples b/cryptol-remote-api/python/tests/cryptol/test-files/examples new file mode 120000 index 000000000..288bcb059 --- /dev/null +++ b/cryptol-remote-api/python/tests/cryptol/test-files/examples @@ -0,0 +1 @@ +../../../../../examples \ No newline at end of file diff --git a/cryptol-remote-api/python/tests/cryptol/test_AES.py b/cryptol-remote-api/python/tests/cryptol/test_AES.py new file mode 100644 index 000000000..67e9cd9f3 --- /dev/null +++ b/cryptol-remote-api/python/tests/cryptol/test_AES.py @@ -0,0 +1,39 @@ +import unittest +from pathlib import Path +import unittest +import cryptol +from cryptol.bitvector import BV + + +class TestAES(unittest.TestCase): + def test_AES(self): + c = cryptol.connect() + c.load_file(str(Path('tests','cryptol','test-files', 'examples','AES.cry'))) + + pt = BV(size=128, value=0x3243f6a8885a308d313198a2e0370734) + key = BV(size=128, value=0x2b7e151628aed2a6abf7158809cf4f3c) + ct = c.call("aesEncrypt", (pt, key)).result() + expected_ct = BV(size=128, value=0x3925841d02dc09fbdc118597196a0b32) + self.assertEqual(ct, expected_ct) + + decrypted_ct = c.call("aesDecrypt", (ct, key)).result() + self.assertEqual(pt, decrypted_ct) + + pt = BV(size=128, value=0x00112233445566778899aabbccddeeff) + key = BV(size=128, value=0x000102030405060708090a0b0c0d0e0f) + ct = c.call("aesEncrypt", (pt, key)).result() + expected_ct = BV(size=128, value=0x69c4e0d86a7b0430d8cdb78070b4c55a) + self.assertEqual(ct, expected_ct) + + decrypted_ct = c.call("aesDecrypt", (ct, key)).result() + self.assertEqual(pt, decrypted_ct) + + # c.safe("aesEncrypt") + # c.safe("aesDecrypt") + self.assertTrue(c.check("AESCorrect").result().success) + # c.prove("AESCorrect") # probably takes too long for this script...? + + + +if __name__ == "__main__": + unittest.main() diff --git a/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py b/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py index 3e6a95b43..c51500325 100644 --- a/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py +++ b/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py @@ -82,6 +82,62 @@ def test_sat(self): # check for a valid condition self.assertTrue(c.prove('\\x -> isSqrtOf9 x ==> elem x [3,131,125,253]').result()) + def test_check(self): + c = self.c + res = c.check("\\x -> x==(x:[8])").result() + self.assertTrue(res.success) + self.assertEqual(res.tests_run, 100) + self.assertEqual(res.tests_possible, 256) + self.assertFalse(len(res.args), 0) + self.assertEqual(res.error_msg, None) + + res = c.check("\\x -> x==(x:[8])", num_tests=1).result() + self.assertTrue(res.success) + self.assertEqual(res.tests_run, 1) + self.assertEqual(res.tests_possible, 256) + self.assertEqual(len(res.args), 0) + self.assertEqual(res.error_msg, None) + + res = c.check("\\x -> x==(x:[8])", num_tests=42).result() + self.assertTrue(res.success) + self.assertEqual(res.tests_run, 42) + self.assertEqual(res.tests_possible, 256) + self.assertEqual(len(res.args), 0) + self.assertEqual(res.error_msg, None) + + res = c.check("\\x -> x==(x:[8])", num_tests=1000).result() + self.assertTrue(res.success) + self.assertEqual(res.tests_run, 256) + self.assertEqual(res.tests_possible, 256) + self.assertEqual(len(res.args), 0) + self.assertEqual(res.error_msg, None) + + res = c.check("\\x -> x==(x:[8])", num_tests='all').result() + self.assertTrue(res.success) + self.assertEqual(res.tests_run, 256) + self.assertEqual(res.tests_possible, 256) + self.assertEqual(len(res.args), 0) + self.assertEqual(res.error_msg, None) + + res = c.check("\\x -> x==(x:Integer)", num_tests=1024).result() + self.assertTrue(res.success) + self.assertEqual(res.tests_run, 1024) + self.assertEqual(res.tests_possible, None) + self.assertEqual(len(res.args), 0) + self.assertEqual(res.error_msg, None) + + res = c.check("\\x -> (x + 1)==(x:[8])").result() + self.assertFalse(res.success) + self.assertEqual(res.tests_possible, 256) + self.assertEqual(len(res.args), 1) + self.assertEqual(res.error_msg, None) + + res = c.check("\\x -> (x / 0)==(x:[8])").result() + self.assertFalse(res.success) + self.assertEqual(res.tests_possible, 256) + self.assertEqual(len(res.args), 1) + self.assertIsInstance(res.error_msg, str) + def test_many_usages_one_connection(self): c = self.c for i in range(0,100): diff --git a/cryptol-remote-api/python/tests/cryptol_eval/test-files/M.cry b/cryptol-remote-api/python/tests/cryptol_eval/test-files/M.cry index fa72c30c9..81e394b39 100644 --- a/cryptol-remote-api/python/tests/cryptol_eval/test-files/M.cry +++ b/cryptol-remote-api/python/tests/cryptol_eval/test-files/M.cry @@ -26,3 +26,6 @@ id x = x type Word8 = [8] type Twenty a = [20]a + +isSqrtOf9 : [8] -> Bit +isSqrtOf9 x = x*x == 9 diff --git a/cryptol-remote-api/python/tests/cryptol_eval/test_basics.py b/cryptol-remote-api/python/tests/cryptol_eval/test_basics.py index f7af7f16b..a32932443 100644 --- a/cryptol-remote-api/python/tests/cryptol_eval/test_basics.py +++ b/cryptol-remote-api/python/tests/cryptol_eval/test_basics.py @@ -28,6 +28,21 @@ def test_evaluation(self): res = c.call('f', BV(size=8,value=0xff)).result() self.assertEqual(res, [BV(size=8,value=0xff), BV(size=8,value=0xff)]) + + # more thorough testing of backend functionality found in standard server's tests + def test_sat(self): + c = self.c + # test a single sat model can be returned + rootsOf9 = c.sat('isSqrtOf9').result() + self.assertEqual(len(rootsOf9), 1) + self.assertTrue(int(rootsOf9[0]) ** 2 % 256, 9) + + # more thorough testing of backend functionality found in standard server's tests + def test_check(self): + c = self.c + res = c.check("\\x -> x==(x:[8])").result() + self.assertTrue(res.success) + # def test_disallowed_ops(self): # pass # TODO/FIXME diff --git a/cryptol-remote-api/src/CryptolServer/Check.hs b/cryptol-remote-api/src/CryptolServer/Check.hs new file mode 100644 index 000000000..794956f14 --- /dev/null +++ b/cryptol-remote-api/src/CryptolServer/Check.hs @@ -0,0 +1,184 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} +module CryptolServer.Check + ( check + , checkDescr + , CheckParams(..) + , CheckResult(..) + ) + where + +import qualified Argo.Doc as Doc +import Control.Applicative +import Control.Monad.IO.Class +import Data.Aeson ((.=), (.:), (.:?), (.!=), FromJSON, ToJSON) +import qualified Data.Aeson as JSON +import Data.Scientific (floatingOrInteger) +import Data.Text (Text) +import qualified Data.Text as T +import System.Random.TF(newTFGen) + +import qualified Cryptol.Eval.Concrete as CEC +import qualified Cryptol.Eval.Env as CEE +import qualified Cryptol.Eval.Type as CET +import qualified Cryptol.ModuleSystem as CM +import Cryptol.ModuleSystem.Env (DynamicEnv(..), meDynEnv) +import qualified Cryptol.Testing.Random as R +import qualified Cryptol.TypeCheck.AST as AST +import Cryptol.TypeCheck.Subst (apSubst, listParamSubst) +import Cryptol.TypeCheck.Solve (defaultReplExpr) + + +import CryptolServer + ( getTCSolver, + getModuleEnv, + runModuleCmd, + CryptolMethod(raise), + CryptolCommand ) +import CryptolServer.Exceptions (evalPolyErr) +import CryptolServer.Data.Expression + ( readBack, observe, getExpr, Expression ) +import CryptolServer.Data.Type +import Cryptol.Utils.PP (pretty) + +checkDescr :: Doc.Block +checkDescr = + Doc.Paragraph + [ Doc.Text "Tests a property against random values to give quick feedback."] + +-- | Check a property a la quickcheck (see `:check` at the Cryptol REPL) +check :: CheckParams -> CryptolCommand CheckResult +check (CheckParams jsonExpr cMode) = + do e <- getExpr jsonExpr + (_expr, ty, schema) <- runModuleCmd (CM.checkExpr e) + -- TODO? validEvalContext expr, ty, schema + s <- getTCSolver + perhapsDef <- liftIO (defaultReplExpr s ty schema) + case perhapsDef of + Nothing -> raise (evalPolyErr schema) + Just (tys, checked) -> do + (val,tyv) <- do -- TODO: warnDefaults here + let su = listParamSubst tys + let theType = apSubst su (AST.sType schema) + tenv <- CEE.envTypes . deEnv . meDynEnv <$> getModuleEnv + let tval = CET.evalValType tenv theType + val <- runModuleCmd (CM.evalExpr checked) + pure (val,tval) + let (isExaustive, randomTestNum) = case cMode of + ExhaustiveCheckMode -> (True, 0) + RandomCheckMode n -> (False, n) + case R.testableType tyv of + Just (Just sz, argTys, vss, _gens) | isExaustive || sz <= randomTestNum -> do + -- TODO? catch interruptions in testing like `qcExpr` in REPL + (res,num) <- R.exhaustiveTests (const $ pure ()) val vss + args <- convertTestResult argTys res + pure $ CheckResult args num (Just sz) + Just (sz,argTys,_,gens) | isExaustive ==False -> do + g <- liftIO $ newTFGen + (res,num) <- R.randomTests (const $ pure ()) randomTestNum val gens g + args <- convertTestResult argTys res + return $ CheckResult args num sz + _ -> error $ "type is not testable: " ++ (pretty ty) + +convertTestArg :: (CET.TValue, CEC.Value) -> CryptolCommand (JSONType, Expression) +convertTestArg (t, v) = do + e <- observe $ readBack t v + return (JSONType mempty (CET.tValTy t), e) + +convertTestResult :: + [CET.TValue] {- ^ Argument types -} -> + R.TestResult {- ^ Result to convert -} -> + CryptolCommand ServerTestResult +convertTestResult _ R.Pass = pure Pass +convertTestResult ts (R.FailFalse vals) = do + args <- mapM convertTestArg $ zip ts vals + pure $ FailFalse args +convertTestResult ts (R.FailError exn vals) = do + args <- mapM convertTestArg $ zip ts vals + pure $ FailError (T.pack (pretty exn)) args + + + + +data ServerTestResult + = Pass + | FailFalse [(JSONType, Expression)] + | FailError Text [(JSONType, Expression)] + +data CheckResult = + CheckResult + { checkTestResult :: ServerTestResult + , checkTestsRun :: Integer + , checkTestsPossible :: Maybe Integer + } + +convertServerTestResult :: ServerTestResult -> [(Text, JSON.Value)] +convertServerTestResult Pass = ["result" .= ("pass" :: Text)] +convertServerTestResult (FailFalse args) = + [ "result" .= ("fail" :: Text) + , "arguments" .= + [ JSON.object [ "type" .= t, "expr" .= e] | (t, e) <- args ] + ] +convertServerTestResult (FailError err args) = + [ "result" .= ("error" :: Text) + , "error message" .= (pretty err) + , "arguments" .= + [ JSON.object [ "type" .= t, "expr" .= e] | (t, e) <- args ] + ] + + +instance ToJSON CheckResult where + toJSON res = JSON.object $ [ "tests run" .= (checkTestsRun res) + , "tests possible" .= (checkTestsPossible res) + ] ++ (convertServerTestResult (checkTestResult res)) + +data CheckMode + = ExhaustiveCheckMode + | RandomCheckMode Integer + deriving (Eq, Show) + +data CheckParams = + CheckParams + { checkExpression :: Expression + , checkMode :: CheckMode + } + + +instance FromJSON CheckParams where + parseJSON = + JSON.withObject "check parameters" $ + \o -> + do e <- o .: "expression" + m <- (o .:? "number of tests" >>= num) .!= (RandomCheckMode 100) + pure CheckParams {checkExpression = e, checkMode = m} + where + num (Just v) = + ((JSON.withText "all" $ + \t -> if t == "all" then pure $ Just ExhaustiveCheckMode else empty) v) + <|> + ((JSON.withScientific "number of tests" $ + \s -> + case floatingOrInteger s of + Left (_float :: Double) -> empty + Right n -> pure $ Just $ RandomCheckMode $ (toInteger :: Int -> Integer) n) v) + num Nothing = pure Nothing + +instance Doc.DescribedParams CheckParams where + parameterFieldDescription = + [ ("expression", + Doc.Paragraph [Doc.Text "The predicate (i.e., function) to check; " + , Doc.Text "must be a monomorphic function with return type Bit." ]) + , ("number of tests", + Doc.Paragraph [Doc.Text "The number of random inputs to test the property with, or " + , Doc.Literal "all" + , Doc.Text " to exhaustively check the property (defaults to " + , Doc.Literal "100" + , Doc.Text " if not provided). If " + , Doc.Literal "all" + , Doc.Text " is specified and the property's argument types are not" + , Doc.Text " sufficiently small, checking may take longer than you are willing to wait!" + ]) + ] diff --git a/cryptol-remote-api/src/CryptolServer/Sat.hs b/cryptol-remote-api/src/CryptolServer/Sat.hs index a8174f1a2..029e1171d 100644 --- a/cryptol-remote-api/src/CryptolServer/Sat.hs +++ b/cryptol-remote-api/src/CryptolServer/Sat.hs @@ -175,7 +175,7 @@ instance Doc.DescribedParams ProveSatParams where ++ [Doc.Text "."])) , ("expression", Doc.Paragraph [Doc.Text "The predicate (i.e., function) to check for satisfiability; " - , Doc.Text "must be a monomorphic function type with return type Bit." ]) + , Doc.Text "must be a monomorphic function with return type Bit." ]) , ("result count", Doc.Paragraph [Doc.Text "How many satisfying results to search for; either a positive integer or " , Doc.Literal "all", Doc.Text"."]) diff --git a/cryptol-remote-api/test_docker.sh b/cryptol-remote-api/test_docker.sh index 770631e56..5106bd628 100755 --- a/cryptol-remote-api/test_docker.sh +++ b/cryptol-remote-api/test_docker.sh @@ -6,6 +6,9 @@ TAG=${1:-cryptol-remote-api} pushd $DIR +rm $PWD/python/tests/cryptol/test-files/examples +mv $PWD/../examples $PWD/python/tests/cryptol/test-files/ + docker run --name=cryptol-remote-api -d \ -v $PWD/python/tests/cryptol/test-files:/home/cryptol/tests/cryptol/test-files \ -p 8080:8080 \