Skip to content

Commit

Permalink
feat(rpc): safe for python api (#1168)
Browse files Browse the repository at this point in the history
* feat(rpc): safe for python api

* refactor: use enum for smt query type

* Update cryptol-remote-api/python/cryptol/__init__.py

Co-authored-by: Ryan Scott <[email protected]>

* Update cryptol-remote-api/python/cryptol/__init__.py

Co-authored-by: Ryan Scott <[email protected]>

Co-authored-by: Ryan Scott <[email protected]>
  • Loading branch information
Andrew Kent and RyanGlScott authored Apr 22, 2021
1 parent 867096c commit 966b343
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 16 deletions.
33 changes: 25 additions & 8 deletions cryptol-remote-api/python/cryptol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import base64
import os
from enum import Enum
from dataclasses import dataclass
from distutils.spawn import find_executable
from typing import Any, List, NoReturn, Optional, Union
Expand Down Expand Up @@ -172,8 +173,13 @@ def __init__(self, connection : HasProtocolState, expr : Any) -> None:
def process_result(self, res : Any) -> Any:
return res['type schema']

class SmtQueryType(str, Enum):
PROVE = 'prove'
SAFE = 'safe'
SAT = 'sat'

class CryptolProveSat(argo.Command):
def __init__(self, connection : HasProtocolState, qtype : str, expr : Any, solver : solver.Solver, count : Optional[int]) -> None:
def __init__(self, connection : HasProtocolState, qtype : SmtQueryType, expr : Any, solver : solver.Solver, count : Optional[int]) -> None:
super(CryptolProveSat, self).__init__(
'prove or satisfy',
{'query type': qtype,
Expand All @@ -186,12 +192,12 @@ def __init__(self, connection : HasProtocolState, qtype : str, expr : Any, solve

def process_result(self, res : Any) -> Any:
if res['result'] == 'unsatisfiable':
if self.qtype == 'sat':
if self.qtype == SmtQueryType.SAT:
return False
elif self.qtype == 'prove':
elif self.qtype == SmtQueryType.PROVE or self.qtype == SmtQueryType.SAFE:
return True
else:
raise ValueError("Unknown prove/sat query type: " + self.qtype)
raise ValueError("Unknown SMT query type: " + self.qtype)
elif res['result'] == 'invalid':
return [from_cryptol_arg(arg['expr'])
for arg in res['counterexample']]
Expand All @@ -200,15 +206,19 @@ def process_result(self, res : Any) -> Any:
for m in res['models']
for arg in m]
else:
raise ValueError("Unknown sat result " + str(res))
raise ValueError("Unknown SMT result: " + str(res))

class CryptolProve(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : solver.Solver) -> None:
super(CryptolProve, self).__init__(connection, 'prove', expr, solver, 1)
super(CryptolProve, self).__init__(connection, SmtQueryType.PROVE, expr, solver, 1)

class CryptolSat(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : solver.Solver, count : int) -> None:
super(CryptolSat, self).__init__(connection, 'sat', expr, solver, count)
super(CryptolSat, self).__init__(connection, SmtQueryType.SAT, expr, solver, count)

class CryptolSafe(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : solver.Solver) -> None:
super(CryptolSafe, self).__init__(connection, SmtQueryType.SAFE, expr, solver, 1)

class CryptolNames(argo.Command):
def __init__(self, connection : HasProtocolState) -> None:
Expand Down Expand Up @@ -257,7 +267,7 @@ def connect(command : Optional[str]=None,
:param cryptol_path: A replacement for the contents of
the ``CRYPTOLPATH`` environment variable (if provided).
:param url: A URL at which to connect to an already running Cryptol
:param url: A URL at which to connect to an already running Cryptol
HTTP server.
:param reset_server: If ``True``, the server that is connected to will be
Expand Down Expand Up @@ -439,6 +449,13 @@ def prove(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command:
self.most_recent_result = CryptolProve(self, expr, solver)
return self.most_recent_result

def safe(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command:
"""Check via an external SMT solver that the given term is safe for all inputs,
which means it cannot encounter a run-time error.
"""
self.most_recent_result = CryptolSafe(self, expr, solver)
return self.most_recent_result

def names(self) -> argo.Command:
"""Discover the list of names currently in scope in the current context."""
self.most_recent_result = CryptolNames(self)
Expand Down
4 changes: 2 additions & 2 deletions cryptol-remote-api/python/tests/cryptol/test_AES.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def test_AES(self):
decrypted_ct = c.call("aesDecrypt", (ct, key)).result()
self.assertEqual(pt, decrypted_ct)

# c.safe("aesEncrypt")
# c.safe("aesDecrypt")
self.assertTrue(c.safe("aesEncrypt"))
self.assertTrue(c.safe("aesDecrypt"))
self.assertTrue(c.check("AESCorrect").result().success)
# c.prove("AESCorrect") # probably takes too long for this script...?

Expand Down
9 changes: 9 additions & 0 deletions cryptol-remote-api/python/tests/cryptol/test_cryptol_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ def test_check(self):
self.assertEqual(len(res.args), 1)
self.assertIsInstance(res.error_msg, str)

def test_safe(self):
c = self.c
res = c.safe("\\x -> x==(x:[8])").result()
self.assertTrue(res)

res = c.safe("\\x -> x / (x:[8])").result()
self.assertEqual(res, [BV(size=8, value=0)])


def test_many_usages_one_connection(self):
c = self.c
for i in range(0,100):
Expand Down
19 changes: 13 additions & 6 deletions cryptol-remote-api/src/CryptolServer/Sat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ instance FromJSON ProveSatParams where
\case
"sat" -> pure (SatQuery numResults)
"prove" -> pure ProveQuery
"safe" -> pure SafetyQuery
_ -> empty)
num v = ((JSON.withText "all" $
\t -> if t == "all" then pure AllSat else empty) v) <|>
Expand All @@ -174,17 +175,23 @@ instance Doc.DescribedParams ProveSatParams where
++ (concat (map (\p -> [Doc.Literal (T.pack p), Doc.Text ", "]) proverNames))
++ [Doc.Text "."]))
, ("expression",
Doc.Paragraph [Doc.Text "The predicate (i.e., function) to check for satisfiability; "
, Doc.Text "must be a monomorphic function with return type Bit." ])
Doc.Paragraph [ Doc.Text "The function to check for validity, satisfiability, or safety"
, Doc.Text " depending on the specified value for "
, Doc.Literal "query type"
, Doc.Text ". For validity and satisfiability checks, the function must be a predicate"
, Doc.Text " (i.e., monomorphic function with return type Bit)."
])
, ("result count",
Doc.Paragraph [Doc.Text "How many satisfying results to search for; either a positive integer or "
, Doc.Literal "all", Doc.Text"."])
, Doc.Literal "all", Doc.Text". Only affects satisfiability checks."])
, ("query type",
Doc.Paragraph [ Doc.Text "Whether to attempt to prove ("
Doc.Paragraph [ Doc.Text "Whether to attempt to prove the predicate is true for all possible inputs ("
, Doc.Literal "prove"
, Doc.Text ") or satisfy ("
, Doc.Text "), find some inputs which make the predicate true ("
, Doc.Literal "sat"
, Doc.Text ") the predicate."
, Doc.Text "), or prove a function is safe ("
, Doc.Literal "safe"
, Doc.Text ")."
]
)
]

0 comments on commit 966b343

Please sign in to comment.