Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/bytes hex #163

Merged
merged 7 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 52 additions & 13 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,56 @@
USub: {IntegerInstanceType: lambda x: plt.SubtractInteger(plt.Integer(0), x)},
}

ConstantMap = {
str: plt.Text,
bytes: plt.ByteString,
int: plt.Integer,
bool: plt.Bool,
type(None): lambda _: plt.Unit(),
}

def rec_constant_map_data(c):
if isinstance(c, bool):
return uplc.PlutusInteger(int(c))
if isinstance(c, int):
return uplc.PlutusInteger(c)
if isinstance(c, type(None)):
return uplc.PlutusConstr(0, [])
if isinstance(c, bytes):
return uplc.PlutusByteString(c)
if isinstance(c, str):
return uplc.PlutusByteString(c.encode())
if isinstance(c, list):
return uplc.PlutusList([rec_constant_map_data(ce) for ce in c])
if isinstance(c, dict):
return uplc.PlutusMap(
dict(
zip(
(rec_constant_map_data(ce) for ce in c.keys()),
(rec_constant_map_data(ce) for ce in c.values()),
)
)
)
raise NotImplementedError(f"Unsupported constant type {type(c)}")


def rec_constant_map(c):
if isinstance(c, bool):
return uplc.BuiltinBool(c)
if isinstance(c, int):
return uplc.BuiltinInteger(c)
if isinstance(c, type(None)):
return uplc.BuiltinUnit()
if isinstance(c, bytes):
return uplc.BuiltinByteString(c)
if isinstance(c, str):
return uplc.BuiltinString(c)
if isinstance(c, list):
return uplc.BuiltinList([rec_constant_map(ce) for ce in c])
if isinstance(c, dict):
return uplc.BuiltinList(
[
uplc.BuiltinPair(*p)
for p in zip(
(rec_constant_map_data(ce) for ce in c.keys()),
(rec_constant_map_data(ce) for ce in c.values()),
)
]
)
raise NotImplementedError(f"Unsupported constant type {type(c)}")


def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
Expand Down Expand Up @@ -310,12 +353,8 @@ def visit_Module(self, node: TypedModule) -> plt.AST:
return cp

def visit_Constant(self, node: TypedConstant) -> plt.AST:
plt_type = ConstantMap.get(type(node.value))
if plt_type is None:
raise NotImplementedError(
f"Constants of type {type(node.value)} are not supported"
)
return plt.Lambda([STATEMONAD], plt_type(node.value))
plt_val = plt.UPLCConstant(rec_constant_map(node.value))
return plt.Lambda([STATEMONAD], plt_val)

def visit_NoneType(self, _: typing.Optional[typing.Any]) -> plt.AST:
return plt.Lambda([STATEMONAD], plt.Unit())
Expand Down
18 changes: 3 additions & 15 deletions opshin/optimize/optimize_const_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,8 @@ def err():
except Exception as e:
return node

def rec_dump(c):
if any(isinstance(c, a) for a in ACCEPTED_ATOMIC_TYPES):
new_node = Constant(c, None)
copy_location(new_node, node)
return new_node
# TODO dump these values directly as plutus constants in the code (they will be built on chain with this)
if isinstance(c, list):
return List([rec_dump(ce) for ce in c], Load())
if isinstance(c, dict):
return Dict(
[rec_dump(ce) for ce in c.keys()],
[rec_dump(ce) for ce in c.values()],
)

if any(isinstance(node_eval, t) for t in ACCEPTED_ATOMIC_TYPES + [list, dict]):
return rec_dump(node_eval)
new_node = Constant(node_eval, None)
copy_location(new_node, node)
return new_node
return node
52 changes: 52 additions & 0 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,11 +972,63 @@ def validator(_: None) -> List[int]:
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast).compile()
self.assertIn("(con list<integer> [0, 2, 4, 6, 8])", code.dumps())
res = uplc_eval(uplc.Apply(code, uplc.PlutusConstr(0, [])))
self.assertEqual(
res, uplc.PlutusList([uplc.PlutusInteger(i) for i in range(0, 10, 2)])
)

def test_constant_folding_dict(self):
source_code = """
from opshin.prelude import *

def validator(_: None) -> Dict[str, bool]:
return {"s": True, "m": False}
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast).compile()
self.assertIn(
"(con list<pair<data, data>> [[#4173, #01], [#416d, #00]]))", code.dumps()
)
res = uplc_eval(uplc.Apply(code, uplc.PlutusConstr(0, [])))
self.assertEqual(
res,
uplc.PlutusMap(
{
uplc.PlutusByteString("s".encode()): uplc.PlutusInteger(1),
uplc.PlutusByteString("m".encode()): uplc.PlutusInteger(0),
}
),
)

def test_constant_folding_complex(self):
source_code = """
from opshin.prelude import *

def validator(_: None) -> Dict[str, List[Dict[bytes, int]]]:
return {"s": [{b"": 0}, {b"0": 1}]}
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast).compile()
res = uplc_eval(uplc.Apply(code, uplc.PlutusConstr(0, [])))
self.assertEqual(
res,
uplc.PlutusMap(
{
uplc.PlutusByteString("s".encode()): uplc.PlutusList(
[
uplc.PlutusMap(
{uplc.PlutusByteString(b""): uplc.PlutusInteger(0)}
),
uplc.PlutusMap(
{uplc.PlutusByteString(b"0"): uplc.PlutusInteger(1)}
),
]
),
}
),
)

def test_constant_folding_math(self):
source_code = """
from opshin.prelude import *
Expand Down
26 changes: 25 additions & 1 deletion opshin/tests/test_stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def validator(x: str) -> bytes:
self.assertEqual(ret, xs.encode(), "str.encode returned wrong value")

@given(xs=st.binary())
def test_str_decode(self, xs):
def test_bytes_decode(self, xs):
# this tests that errors that are caused by assignments are actually triggered at the time of assigning
source_code = """
def validator(x: bytes) -> str:
Expand All @@ -200,6 +200,30 @@ def validator(x: bytes) -> str:
ret = None
self.assertEqual(ret, exp, "bytes.decode returned wrong value")

@given(xs=st.binary())
def test_bytes_hex(self, xs):
# this tests that errors that are caused by assignments are actually triggered at the time of assigning
source_code = """
def validator(x: bytes) -> str:
return x.hex()
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast)
code = code.compile()
f = code.term
try:
exp = xs.hex()
except UnicodeDecodeError:
exp = None
# UPLC lambdas may only take one argument at a time, so we evaluate by repeatedly applying
for d in [uplc.PlutusByteString(xs)]:
f = uplc.Apply(f, d)
try:
ret = uplc_eval(f).value.decode()
except UnicodeDecodeError:
ret = None
self.assertEqual(ret, exp, "bytes.hex returned wrong value")

@given(xs=st.binary())
@example(b"dc315c289fee4484eda07038393f21dc4e572aff292d7926018725c2")
def test_constant_bytestring(self, xs):
Expand Down
33 changes: 32 additions & 1 deletion opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,37 @@
)


def constant_type(c):
if isinstance(c, bool):
return BoolInstanceType
if isinstance(c, int):
return IntegerInstanceType
if isinstance(c, type(None)):
return UnitInstanceType
if isinstance(c, bytes):
return ByteStringInstanceType
if isinstance(c, str):
return StringInstanceType
if isinstance(c, list):
assert len(c) > 0, "Lists must be non-empty"
first_typ = constant_type(c[0])
assert all(
constant_type(ce) == first_typ for ce in c[1:]
), "Constant lists must contain elements of a single type only"
return InstanceType(ListType(first_typ))
if isinstance(c, dict):
assert len(c) > 0, "Lists must be non-empty"
first_key_typ = constant_type(next(iter(c.keys())))
first_value_typ = constant_type(next(iter(c.values())))
assert all(
constant_type(ce) == first_key_typ for ce in c.keys()
), "Constant dicts must contain keys of a single type only"
assert all(
constant_type(ce) == first_value_typ for ce in c.values()
), "Constant dicts must contain values of a single type only"
return InstanceType(DictType(first_key_typ, first_value_typ))


class AggressiveTypeInferencer(CompilingNodeTransformer):
step = "Static Type Inference"

Expand Down Expand Up @@ -154,7 +185,7 @@ def visit_Constant(self, node: Constant) -> TypedConstant:
complex,
type(...),
], "Float, complex numbers and ellipsis currently not supported"
tc.typ = InstanceType(ATOMIC_TYPES[type(node.value).__name__])
tc.typ = constant_type(node.value)
return tc

def visit_Tuple(self, node: Tuple) -> TypedTuple:
Expand Down
91 changes: 91 additions & 0 deletions opshin/typed_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,12 +748,103 @@ def constr(self) -> plt.AST:
def attribute_type(self, attr) -> Type:
if attr == "decode":
return InstanceType(FunctionType([], StringInstanceType))
if attr == "hex":
return InstanceType(FunctionType([], StringInstanceType))
return super().attribute_type(attr)

def attribute(self, attr) -> plt.AST:
if attr == "decode":
# No codec -> only the default (utf8) is allowed
return plt.Lambda(["x", "_"], plt.DecodeUtf8(plt.Var("x")))
if attr == "hex":
return plt.Lambda(
["x", "_"],
plt.DecodeUtf8(
plt.Let(
[
(
"hexlist",
plt.RecFun(
plt.Lambda(
["f", "i"],
plt.Ite(
plt.LessThanInteger(
plt.Var("i"), plt.Integer(0)
),
plt.EmptyIntegerList(),
plt.MkCons(
plt.IndexByteString(
plt.Var("x"), plt.Var("i")
),
plt.Apply(
plt.Var("f"),
plt.Var("f"),
plt.SubtractInteger(
plt.Var("i"), plt.Integer(1)
),
),
),
),
),
),
),
(
"map_str",
plt.Lambda(
["i"],
plt.AddInteger(
plt.Var("i"),
plt.IfThenElse(
plt.LessThanInteger(
plt.Var("i"), plt.Integer(10)
),
plt.Integer(ord("0")),
plt.Integer(ord("a") - 10),
),
),
),
),
(
"mkstr",
plt.Lambda(
["i"],
plt.FoldList(
plt.Apply(plt.Var("hexlist"), plt.Var("i")),
plt.Lambda(
["b", "i"],
plt.ConsByteString(
plt.Apply(
plt.Var("map_str"),
plt.DivideInteger(
plt.Var("i"), plt.Integer(16)
),
),
plt.ConsByteString(
plt.Apply(
plt.Var("map_str"),
plt.ModInteger(
plt.Var("i"),
plt.Integer(16),
),
),
plt.Var("b"),
),
),
),
plt.ByteString(b""),
),
),
),
],
plt.Apply(
plt.Var("mkstr"),
plt.SubtractInteger(
plt.LengthOfByteString(plt.Var("x")), plt.Integer(1)
),
),
),
),
)
return super().attribute(attr)

def cmp(self, op: cmpop, o: "Type") -> plt.AST:
Expand Down