diff --git a/cryptol-remote-api/python/cryptol/__init__.py b/cryptol-remote-api/python/cryptol/__init__.py index 3a2ceecb6..b07c2d045 100644 --- a/cryptol-remote-api/python/cryptol/__init__.py +++ b/cryptol-remote-api/python/cryptol/__init__.py @@ -4,6 +4,7 @@ import base64 import os +from enum import Enum from dataclasses import dataclass from distutils.spawn import find_executable from typing import Any, List, NoReturn, Optional, Union @@ -184,8 +185,13 @@ def __init__(self, connection : HasProtocolState, expr : Any) -> None: def process_result(self, res : Any) -> Any: return res['type schema'] +class SmtQueryType(str, Enum): + PROVE = 'prove' + SAFE = 'safe' + SAT = 'sat' + class CryptolProveSat(argo.Command): - def __init__(self, connection : HasProtocolState, qtype : str, expr : Any, solver : solver.Solver, count : Optional[int]) -> None: + def __init__(self, connection : HasProtocolState, qtype : SmtQueryType, expr : Any, solver : solver.Solver, count : Optional[int]) -> None: super(CryptolProveSat, self).__init__( 'prove or satisfy', {'query type': qtype, @@ -198,12 +204,12 @@ def __init__(self, connection : HasProtocolState, qtype : str, expr : Any, solve def process_result(self, res : Any) -> Any: if res['result'] == 'unsatisfiable': - if self.qtype == 'sat': + if self.qtype == SmtQueryType.SAT: return False - elif self.qtype == 'prove': + elif self.qtype == SmtQueryType.PROVE or self.qtype == SmtQueryType.SAFE: return True else: - raise ValueError("Unknown prove/sat query type: " + self.qtype) + raise ValueError("Unknown SMT query type: " + self.qtype) elif res['result'] == 'invalid': return [from_cryptol_arg(arg['expr']) for arg in res['counterexample']] @@ -212,15 +218,19 @@ def process_result(self, res : Any) -> Any: for m in res['models'] for arg in m] else: - raise ValueError("Unknown sat result " + str(res)) + raise ValueError("Unknown SMT result: " + str(res)) class CryptolProve(CryptolProveSat): def __init__(self, connection : HasProtocolState, expr : Any, solver : solver.Solver) -> None: - super(CryptolProve, self).__init__(connection, 'prove', expr, solver, 1) + super(CryptolProve, self).__init__(connection, SmtQueryType.PROVE, expr, solver, 1) class CryptolSat(CryptolProveSat): def __init__(self, connection : HasProtocolState, expr : Any, solver : solver.Solver, count : int) -> None: - super(CryptolSat, self).__init__(connection, 'sat', expr, solver, count) + super(CryptolSat, self).__init__(connection, SmtQueryType.SAT, expr, solver, count) + +class CryptolSafe(CryptolProveSat): + def __init__(self, connection : HasProtocolState, expr : Any, solver : solver.Solver) -> None: + super(CryptolSafe, self).__init__(connection, SmtQueryType.SAFE, expr, solver, 1) class CryptolNames(argo.Command): def __init__(self, connection : HasProtocolState) -> None: @@ -269,7 +279,7 @@ def connect(command : Optional[str]=None, :param cryptol_path: A replacement for the contents of the ``CRYPTOLPATH`` environment variable (if provided). - :param url: A URL at which to connect to an already running Cryptol + :param url: A URL at which to connect to an already running Cryptol HTTP server. :param reset_server: If ``True``, the server that is connected to will be @@ -456,6 +466,13 @@ def prove(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command: self.most_recent_result = CryptolProve(self, expr, solver) return self.most_recent_result + def safe(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command: + """Check via an external SMT solver that the given term is safe for all inputs, + which means it cannot encounter a run-time error. + """ + self.most_recent_result = CryptolSafe(self, expr, solver) + return self.most_recent_result + def names(self) -> argo.Command: """Discover the list of names currently in scope in the current context.""" self.most_recent_result = CryptolNames(self) diff --git a/cryptol-remote-api/python/tests/cryptol/test_AES.py b/cryptol-remote-api/python/tests/cryptol/test_AES.py index 67e9cd9f3..9db6c6971 100644 --- a/cryptol-remote-api/python/tests/cryptol/test_AES.py +++ b/cryptol-remote-api/python/tests/cryptol/test_AES.py @@ -28,8 +28,8 @@ def test_AES(self): decrypted_ct = c.call("aesDecrypt", (ct, key)).result() self.assertEqual(pt, decrypted_ct) - # c.safe("aesEncrypt") - # c.safe("aesDecrypt") + self.assertTrue(c.safe("aesEncrypt")) + self.assertTrue(c.safe("aesDecrypt")) self.assertTrue(c.check("AESCorrect").result().success) # c.prove("AESCorrect") # probably takes too long for this script...? diff --git a/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py b/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py index c51500325..208fb7cde 100644 --- a/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py +++ b/cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py @@ -138,6 +138,15 @@ def test_check(self): self.assertEqual(len(res.args), 1) self.assertIsInstance(res.error_msg, str) + def test_safe(self): + c = self.c + res = c.safe("\\x -> x==(x:[8])").result() + self.assertTrue(res) + + res = c.safe("\\x -> x / (x:[8])").result() + self.assertEqual(res, [BV(size=8, value=0)]) + + def test_many_usages_one_connection(self): c = self.c for i in range(0,100): diff --git a/cryptol-remote-api/src/CryptolServer/Sat.hs b/cryptol-remote-api/src/CryptolServer/Sat.hs index 029e1171d..a4224762d 100644 --- a/cryptol-remote-api/src/CryptolServer/Sat.hs +++ b/cryptol-remote-api/src/CryptolServer/Sat.hs @@ -157,6 +157,7 @@ instance FromJSON ProveSatParams where \case "sat" -> pure (SatQuery numResults) "prove" -> pure ProveQuery + "safe" -> pure SafetyQuery _ -> empty) num v = ((JSON.withText "all" $ \t -> if t == "all" then pure AllSat else empty) v) <|> @@ -174,17 +175,23 @@ instance Doc.DescribedParams ProveSatParams where ++ (concat (map (\p -> [Doc.Literal (T.pack p), Doc.Text ", "]) proverNames)) ++ [Doc.Text "."])) , ("expression", - Doc.Paragraph [Doc.Text "The predicate (i.e., function) to check for satisfiability; " - , Doc.Text "must be a monomorphic function with return type Bit." ]) + Doc.Paragraph [ Doc.Text "The function to check for validity, satisfiability, or safety" + , Doc.Text " depending on the specified value for " + , Doc.Literal "query type" + , Doc.Text ". For validity and satisfiability checks, the function must be a predicate" + , Doc.Text " (i.e., monomorphic function with return type Bit)." + ]) , ("result count", Doc.Paragraph [Doc.Text "How many satisfying results to search for; either a positive integer or " - , Doc.Literal "all", Doc.Text"."]) + , Doc.Literal "all", Doc.Text". Only affects satisfiability checks."]) , ("query type", - Doc.Paragraph [ Doc.Text "Whether to attempt to prove (" + Doc.Paragraph [ Doc.Text "Whether to attempt to prove the predicate is true for all possible inputs (" , Doc.Literal "prove" - , Doc.Text ") or satisfy (" + , Doc.Text "), find some inputs which make the predicate true (" , Doc.Literal "sat" - , Doc.Text ") the predicate." + , Doc.Text "), or prove a function is safe (" + , Doc.Literal "safe" + , Doc.Text ")." ] ) ]