Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/merging functions with type capturing #328

Merged
merged 10 commits into from
Feb 2, 2024
44 changes: 14 additions & 30 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .rewrite.rewrite_subscript38 import RewriteSubscript38
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
from .optimize.optimize_remove_pass import OptimizeRemovePass
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars, NameLoadCollector
from .type_inference import *
from .util import (
CompilingNodeTransformer,
Expand Down Expand Up @@ -136,27 +136,6 @@ def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
# Type alias for a callable that takes one plt.AST and returns a plt.AST.
CallAST = typing.Callable[[plt.AST], plt.AST]


class FunctionBoundVarsCollector(NodeVisitor):
    """
    Records, for every visited function definition, the sorted and
    de-duplicated names returned by ``externally_bound_vars``, keyed by the
    function's type (``node.typ.typ``).
    """

    def __init__(self):
        # function type -> list of variable names; unseen types start empty
        self.functions_bound_vars: typing.Dict[
            FunctionType, typing.List[str]
        ] = defaultdict(list)

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        function_type = node.typ.typ
        # Merge with whatever was already recorded for the same function type
        already_known = self.functions_bound_vars[function_type]
        merged = set(already_known + externally_bound_vars(node))
        self.functions_bound_vars[function_type] = sorted(merged)
        # Continue into the body so nested function definitions are visited too
        self.generic_visit(node)


def extract_function_bound_vars(
    node: AST,
) -> typing.Dict[FunctionType, typing.List[str]]:
    """
    Run a ``FunctionBoundVarsCollector`` over ``node`` and return the mapping
    it collected (function type -> list of variable names).
    """
    collector = FunctionBoundVarsCollector()
    collector.visit(node)
    return collector.functions_bound_vars


class PlutoCompiler(CompilingNodeTransformer):
"""
Expects a TypedAST and returns UPLC/Pluto like code
Expand All @@ -170,9 +149,6 @@ def __init__(self, force_three_params=False, validator_function_name="validator"
self.validator_function_name = validator_function_name
# marked knowledge during compilation
self.current_function_typ: typing.List[FunctionType] = []
self.function_bound_vars: typing.Dict[
FunctionType, typing.List[str]
] = defaultdict(list)

def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST:
def g(s: plt.AST):
Expand Down Expand Up @@ -223,7 +199,6 @@ def visit_Compare(self, node: TypedCompare) -> plt.AST:

def visit_Module(self, node: TypedModule) -> plt.AST:
# extract actually read variables by each function
self.function_bound_vars = extract_function_bound_vars(node)
if self.validator_function_name is not None:
# for validators find main function
# TODO can use more sophisticated procedure here i.e. functions marked by comment
Expand Down Expand Up @@ -292,7 +267,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST:
]
)
self.current_function_typ.append(FunctionType([], InstanceType(AnyType())))
all_vs = sorted(set(all_vars(node)))
name_load_visitor = NameLoadCollector()
name_load_visitor.visit(node)
all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

# write all variables that are ever read
# once at the beginning so that we can always access them (only potentially causing a nameerror at runtime)
Expand Down Expand Up @@ -324,7 +301,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST:
"The contract can not always detect if it was passed three or two parameters on-chain."
)
else:
all_vs = sorted(set(all_vars(node)))
name_load_visitor = NameLoadCollector()
name_load_visitor.visit(node)
all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

body = node.body
# write all variables that are ever read
Expand Down Expand Up @@ -434,12 +413,14 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
node.func.typ.typ.argtyps
)
)
bind_self = None
else:
assert isinstance(node.func.typ, InstanceType) and isinstance(
node.func.typ.typ, FunctionType
)
func_plt = self.visit(node.func)
bound_vs = self.function_bound_vars[node.func.typ.typ]
bind_self = node.func.typ.typ.bind_self
bound_vs = sorted(list(node.func.typ.typ.bound_vars.keys()))
args = []
for a, t in zip(node.args, node.func.typ.typ.argtyps):
assert isinstance(t, InstanceType)
Expand All @@ -457,6 +438,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
[(f"p{i}", a) for i, a in enumerate(args)],
SafeApply(
func_plt,
*([plt.Var(bind_self)] if bind_self is not None else []),
*[plt.Var(n) for n in bound_vs],
*[plt.Delay(OVar(f"p{i}")) for i in range(len(args))],
),
Expand All @@ -469,7 +451,9 @@ def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST:
ret_val = plt.ConstrData(plt.Integer(0), plt.EmptyDataList())
else:
ret_val = plt.Unit()
read_vs = self.function_bound_vars[node.typ.typ]
read_vs = sorted(list(node.typ.typ.bound_vars.keys()))
if node.typ.typ.bind_self is not None:
read_vs.insert(0, node.typ.typ.bind_self)
self.current_function_typ.append(node.typ.typ)
compiled_body = self.visit_sequence(body)(ret_val)
self.current_function_typ.pop()
Expand Down
12 changes: 8 additions & 4 deletions opshin/optimize/optimize_remove_deadvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..util import CompilingNodeVisitor, CompilingNodeTransformer
from ..type_inference import INITIAL_SCOPE
from ..typed_ast import TypedAnnAssign
from ..typed_ast import TypedAnnAssign, TypedFunctionDef, TypedClassDef, TypedName

"""
Removes assignments to variables that are never read
Expand All @@ -19,18 +19,22 @@ class NameLoadCollector(CompilingNodeVisitor):
def __init__(self):
    """Create the collector with an empty load counter."""
    # name -> number of recorded loads; names not seen yet read as 0
    self.loaded = defaultdict(int)

def visit_Name(self, node: TypedName) -> None:
    """
    Count one load of ``node.id`` when the name is used in ``Load`` context.

    Names in any other context are not counted.
    """
    if isinstance(node.ctx, Load):
        self.loaded[node.id] += 1

def visit_ClassDef(self, node: TypedClassDef):
    """Deliberately do nothing for class definitions."""
    # ignore the content (i.e. attribute names) of class definitions
    pass

def visit_FunctionDef(self, node: TypedFunctionDef):
    """
    Visit the statements of the function body and additionally count every
    variable the function type lists as bound (``bound_vars`` and, when set,
    ``bind_self``) as loaded once.
    """
    # ignore the type hints of function arguments
    for statement in node.body:
        self.visit(statement)
    function_type = node.typ.typ
    # Variables recorded on the function type count as read by the definition
    for bound_name in function_type.bound_vars:
        self.loaded[bound_name] += 1
    if function_type.bind_self is not None:
        self.loaded[function_type.bind_self] += 1


class SafeOperationVisitor(CompilingNodeVisitor):
Expand Down
29 changes: 22 additions & 7 deletions opshin/rewrite/rewrite_remove_type_stuff.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
from typing import Optional

from ..typed_ast import TypedAssign, ClassType
from ..typed_ast import (
TypedAssign,
ClassType,
InstanceType,
PolymorphicFunctionType,
TypeInferenceError,
)
from ..util import CompilingNodeTransformer

"""
Remove class reassignments without constructors
Remove class reassignments without constructors and polymorphic function reassignments

Both of these are only present during the type inference and are discarded or generated in-place during compilation.
"""


class RewriteRemoveTypeStuff(CompilingNodeTransformer):
step = "Removing class re-assignments"
step = "Removing class and polymorphic function re-assignments"

def visit_Assign(self, node: TypedAssign) -> Optional[TypedAssign]:
assert (
len(node.targets) == 1
), "Assignments to more than one variable not supported yet"
try:
if isinstance(node.value.typ, ClassType):
node.value.typ.constr()
except NotImplementedError:
# The type does not have a constructor and the constructor can hence not be passed on
return None
try:
typ = node.value.typ.constr_type()
except TypeInferenceError:
# no constr_type is also fine
return None
else:
typ = node.value.typ
if isinstance(typ, InstanceType) and isinstance(
typ.typ, PolymorphicFunctionType
):
return None
except AttributeError:
# untyped attributes are fine too
pass
Expand Down
109 changes: 109 additions & 0 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2776,3 +2776,112 @@ def validator(x: Union[A, B, C]) -> int:
return fun(x)
"""
builder._compile(source_code)

# Marked as an expected failure: the embedded contract defines `foo` twice,
# once reading `y.foo` after `y = A(0)` and once reading `y.bar` after
# `y = B(0)`, then reassigns `y = A(0)` and returns `foo()`.
# NOTE(review): the indentation of the embedded contract is not visible in
# this view; confirm the exact nesting against the repository.
@unittest.expectedFailure
def test_merge_function_same_capture_different_type(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass

@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int

@dataclass()
class B(PlutusData):
CONSTR_ID = 1
bar: int

def validator(x: bool) -> int:
if x:
y = A(0)
def foo() -> int:
return y.foo
else:
y = B(0)
def foo() -> int:
return y.bar
y = A(0)
return foo()
"""
# Only compilation is attempted here; the contract is not evaluated.
builder._compile(source_code)

# The embedded contract defines `foo` twice; one definition returns `y.foo`,
# the other prints `y` and returns 2. The contract is evaluated with the
# arguments 1 and 0 and the results are compared against 0 and 2.
# NOTE(review): the indentation of the embedded contract is not visible in
# this view; confirm the exact nesting against the repository.
def test_merge_function_same_capture_same_type(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass

@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int

@dataclass()
class B(PlutusData):
CONSTR_ID = 1
bar: int

def validator(x: bool) -> int:
if x:
y = A(0)
def foo() -> int:
print(2)
return y.foo
else:
y = A(0) if x else B(0)
def foo() -> int:
print(y)
return 2
y = A(0)
return foo()
"""
res_true = eval_uplc_value(source_code, 1)
res_false = eval_uplc_value(source_code, 0)
# Argument 1 is expected to yield 0, argument 0 is expected to yield 2.
self.assertEqual(res_true, 0)
self.assertEqual(res_false, 2)

# The embedded contract binds `print` to `a` either directly or through `b`,
# then returns `a(x)`. It is evaluated with the arguments 1 and 0.
# NOTE(review): the results are not asserted on; presumably the test only
# checks that evaluation does not raise. Confirm.
def test_merge_print(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass

def validator(x: bool) -> None:
if x:
a = print
else:
b = print
a = b
return a(x)
"""
res_true = eval_uplc(source_code, 1)
res_false = eval_uplc(source_code, 0)

# The embedded contract assigns `print` to `a` and returns `a(x)`. It is
# evaluated with the arguments 1 and 0.
# NOTE(review): the results are not asserted on; presumably the test only
# checks that evaluation does not raise. Confirm.
def test_print_reassign(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass

def validator(x: bool) -> None:
a = print
return a(x)
"""
res_true = eval_uplc(source_code, 1)
res_false = eval_uplc(source_code, 0)

# The embedded contract assigns `str` to `a` and returns `a(x)`. It is
# evaluated with the arguments 1 and 0.
# NOTE(review): the returned values are not asserted on; presumably the test
# only checks that evaluation does not raise. Confirm.
def test_str_constr_reassign(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass

def validator(x: bool) -> str:
a = str
return a(x)
"""
res_true = eval_uplc_value(source_code, 1)
res_false = eval_uplc_value(source_code, 0)
33 changes: 33 additions & 0 deletions opshin/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,49 @@
from ..types import *


def test_record_type_order():
    """Check the ``>=`` relation between ``RecordType`` values."""
    # Three record types that differ in name, third argument and field name
    a = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)]))
    b = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)]))
    c = RecordType(Record("C", "C", 2, [("baz", IntegerInstanceType)]))

    # A record type is >= itself ...
    assert a >= a
    # ... but none of the three is >= any of the others
    for left, right in ((a, b), (b, a), (a, c), (c, a), (b, c), (c, b)):
        assert not left >= right

    # Record types that all pass 0 as third argument and start with a field "foo"
    single_int = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)]))
    two_ints = RecordType(
        Record(
            "B", "B", 0, [("foo", IntegerInstanceType), ("bar", IntegerInstanceType)]
        )
    )
    single_any = RecordType(Record("C", "C", 0, [("foo", InstanceType(AnyType()))]))
    assert not single_int >= two_ints
    assert not single_any >= two_ints
    assert single_any >= single_int


def test_union_type_order():
    """Check the ``>=`` relation between ``UnionType`` and ``RecordType`` values."""
    a = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)]))
    b = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)]))
    c = RecordType(Record("C", "C", 2, [("baz", IntegerInstanceType)]))
    abc = UnionType([a, b, c])
    ab = UnionType([a, b])

    # (left, right, whether ``left >= right`` is expected to be truthy)
    expectations = (
        (a, a, True),
        (ab, a, True),
        (a, ab, False),
        (abc, ab, True),
        (ab, abc, False),
        (c, a, False),
        (a, c, False),
        (abc, c, True),
        (ab, c, False),
    )
    for left, right, expected in expectations:
        assert bool(left >= right) is expected
3 changes: 2 additions & 1 deletion opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
functyp = FunctionType(
frozenlist([t.typ for t in tfd.args.args]),
InstanceType(self.type_from_annotation(tfd.returns)),
externally_bound_vars(node),
bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)},
bind_self=node.name if node.name in read_vars(node) else None,
)
tfd.typ = InstanceType(functyp)
if wraps_builtin:
Expand Down
Loading