Skip to content

Commit

Permalink
Merge pull request #328 from OpShin/fix/merging_functions_with_type_c…
Browse files Browse the repository at this point in the history
…apturing

Fix/merging functions with type capturing
  • Loading branch information
nielstron authored Feb 2, 2024
2 parents cede27a + e0a4be5 commit 7de70fb
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 52 deletions.
44 changes: 14 additions & 30 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .rewrite.rewrite_subscript38 import RewriteSubscript38
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
from .optimize.optimize_remove_pass import OptimizeRemovePass
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars, NameLoadCollector
from .type_inference import *
from .util import (
CompilingNodeTransformer,
Expand Down Expand Up @@ -136,27 +136,6 @@ def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
CallAST = typing.Callable[[plt.AST], plt.AST]


class FunctionBoundVarsCollector(NodeVisitor):
def __init__(self):
self.functions_bound_vars: typing.Dict[
FunctionType, typing.List[str]
] = defaultdict(list)

def visit_FunctionDef(self, node: FunctionDef) -> None:
self.functions_bound_vars[node.typ.typ] = sorted(
set(self.functions_bound_vars[node.typ.typ] + externally_bound_vars(node))
)
self.generic_visit(node)


def extract_function_bound_vars(
node: AST,
) -> typing.Dict[FunctionType, typing.List[str]]:
e = FunctionBoundVarsCollector()
e.visit(node)
return e.functions_bound_vars


class PlutoCompiler(CompilingNodeTransformer):
"""
Expects a TypedAST and returns UPLC/Pluto like code
Expand All @@ -170,9 +149,6 @@ def __init__(self, force_three_params=False, validator_function_name="validator"
self.validator_function_name = validator_function_name
# marked knowledge during compilation
self.current_function_typ: typing.List[FunctionType] = []
self.function_bound_vars: typing.Dict[
FunctionType, typing.List[str]
] = defaultdict(list)

def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST:
def g(s: plt.AST):
Expand Down Expand Up @@ -223,7 +199,6 @@ def visit_Compare(self, node: TypedCompare) -> plt.AST:

def visit_Module(self, node: TypedModule) -> plt.AST:
# extract actually read variables by each function
self.function_bound_vars = extract_function_bound_vars(node)
if self.validator_function_name is not None:
# for validators find main function
# TODO can use more sophisiticated procedure here i.e. functions marked by comment
Expand Down Expand Up @@ -292,7 +267,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST:
]
)
self.current_function_typ.append(FunctionType([], InstanceType(AnyType())))
all_vs = sorted(set(all_vars(node)))
name_load_visitor = NameLoadCollector()
name_load_visitor.visit(node)
all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

# write all variables that are ever read
# once at the beginning so that we can always access them (only potentially causing a nameerror at runtime)
Expand Down Expand Up @@ -324,7 +301,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST:
"The contract can not always detect if it was passed three or two parameters on-chain."
)
else:
all_vs = sorted(set(all_vars(node)))
name_load_visitor = NameLoadCollector()
name_load_visitor.visit(node)
all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

body = node.body
# write all variables that are ever read
Expand Down Expand Up @@ -434,12 +413,14 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
node.func.typ.typ.argtyps
)
)
bind_self = None
else:
assert isinstance(node.func.typ, InstanceType) and isinstance(
node.func.typ.typ, FunctionType
)
func_plt = self.visit(node.func)
bound_vs = self.function_bound_vars[node.func.typ.typ]
bind_self = node.func.typ.typ.bind_self
bound_vs = sorted(list(node.func.typ.typ.bound_vars.keys()))
args = []
for a, t in zip(node.args, node.func.typ.typ.argtyps):
assert isinstance(t, InstanceType)
Expand All @@ -457,6 +438,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
[(f"p{i}", a) for i, a in enumerate(args)],
SafeApply(
func_plt,
*([plt.Var(bind_self)] if bind_self is not None else []),
*[plt.Var(n) for n in bound_vs],
*[plt.Delay(OVar(f"p{i}")) for i in range(len(args))],
),
Expand All @@ -469,7 +451,9 @@ def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST:
ret_val = plt.ConstrData(plt.Integer(0), plt.EmptyDataList())
else:
ret_val = plt.Unit()
read_vs = self.function_bound_vars[node.typ.typ]
read_vs = sorted(list(node.typ.typ.bound_vars.keys()))
if node.typ.typ.bind_self is not None:
read_vs.insert(0, node.typ.typ.bind_self)
self.current_function_typ.append(node.typ.typ)
compiled_body = self.visit_sequence(body)(ret_val)
self.current_function_typ.pop()
Expand Down
12 changes: 8 additions & 4 deletions opshin/optimize/optimize_remove_deadvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..util import CompilingNodeVisitor, CompilingNodeTransformer
from ..type_inference import INITIAL_SCOPE
from ..typed_ast import TypedAnnAssign
from ..typed_ast import TypedAnnAssign, TypedFunctionDef, TypedClassDef, TypedName

"""
Removes assignments to variables that are never read
Expand All @@ -19,18 +19,22 @@ class NameLoadCollector(CompilingNodeVisitor):
def __init__(self):
self.loaded = defaultdict(int)

def visit_Name(self, node: Name) -> None:
def visit_Name(self, node: TypedName) -> None:
if isinstance(node.ctx, Load):
self.loaded[node.id] += 1

def visit_ClassDef(self, node: ClassDef):
def visit_ClassDef(self, node: TypedClassDef):
# ignore the content (i.e. attribute names) of class definitions
pass

def visit_FunctionDef(self, node: FunctionDef):
def visit_FunctionDef(self, node: TypedFunctionDef):
# ignore the type hints of function arguments
for s in node.body:
self.visit(s)
for v in node.typ.typ.bound_vars.keys():
self.loaded[v] += 1
if node.typ.typ.bind_self is not None:
self.loaded[node.typ.typ.bind_self] += 1


class SafeOperationVisitor(CompilingNodeVisitor):
Expand Down
29 changes: 22 additions & 7 deletions opshin/rewrite/rewrite_remove_type_stuff.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
from typing import Optional

from ..typed_ast import TypedAssign, ClassType
from ..typed_ast import (
TypedAssign,
ClassType,
InstanceType,
PolymorphicFunctionType,
TypeInferenceError,
)
from ..util import CompilingNodeTransformer

"""
Remove class reassignments without constructors
Remove class reassignments without constructors and polymorphic function reassignments
Both of these are only present during the type inference and are discarded or generated in-place during compilation.
"""


class RewriteRemoveTypeStuff(CompilingNodeTransformer):
step = "Removing class re-assignments"
step = "Removing class and polymorphic function re-assignments"

def visit_Assign(self, node: TypedAssign) -> Optional[TypedAssign]:
assert (
len(node.targets) == 1
), "Assignments to more than one variable not supported yet"
try:
if isinstance(node.value.typ, ClassType):
node.value.typ.constr()
except NotImplementedError:
# The type does not have a constructor and the constructor can hence not be passed on
return None
try:
typ = node.value.typ.constr_type()
except TypeInferenceError:
# no constr_type is also fine
return None
else:
typ = node.value.typ
if isinstance(typ, InstanceType) and isinstance(
typ.typ, PolymorphicFunctionType
):
return None
except AttributeError:
# untyped attributes are fine too
pass
Expand Down
109 changes: 109 additions & 0 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2776,3 +2776,112 @@ def validator(x: Union[A, B, C]) -> int:
return fun(x)
"""
builder._compile(source_code)

@unittest.expectedFailure
def test_merge_function_same_capture_different_type(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
@dataclass()
class B(PlutusData):
CONSTR_ID = 1
bar: int
def validator(x: bool) -> int:
if x:
y = A(0)
def foo() -> int:
return y.foo
else:
y = B(0)
def foo() -> int:
return y.bar
y = A(0)
return foo()
"""
builder._compile(source_code)

def test_merge_function_same_capture_same_type(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
@dataclass()
class B(PlutusData):
CONSTR_ID = 1
bar: int
def validator(x: bool) -> int:
if x:
y = A(0)
def foo() -> int:
print(2)
return y.foo
else:
y = A(0) if x else B(0)
def foo() -> int:
print(y)
return 2
y = A(0)
return foo()
"""
res_true = eval_uplc_value(source_code, 1)
res_false = eval_uplc_value(source_code, 0)
self.assertEqual(res_true, 0)
self.assertEqual(res_false, 2)

def test_merge_print(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass
def validator(x: bool) -> None:
if x:
a = print
else:
b = print
a = b
return a(x)
"""
res_true = eval_uplc(source_code, 1)
res_false = eval_uplc(source_code, 0)

def test_print_reassign(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass
def validator(x: bool) -> None:
a = print
return a(x)
"""
res_true = eval_uplc(source_code, 1)
res_false = eval_uplc(source_code, 0)

def test_str_constr_reassign(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass
def validator(x: bool) -> str:
a = str
return a(x)
"""
res_true = eval_uplc_value(source_code, 1)
res_false = eval_uplc_value(source_code, 0)
33 changes: 33 additions & 0 deletions opshin/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,49 @@
from ..types import *


def test_record_type_order():
A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)]))
B = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)]))
C = RecordType(Record("C", "C", 2, [("baz", IntegerInstanceType)]))
a = A
b = B
c = C

assert a >= a
assert not a >= b
assert not b >= a
assert not a >= c
assert not c >= a
assert not b >= c
assert not c >= b

A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)]))
B = RecordType(
Record(
"B", "B", 0, [("foo", IntegerInstanceType), ("bar", IntegerInstanceType)]
)
)
C = RecordType(Record("C", "C", 0, [("foo", InstanceType(AnyType()))]))
assert not A >= B
assert not C >= B
assert C >= A


def test_union_type_order():
A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)]))
B = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)]))
C = RecordType(Record("C", "C", 2, [("baz", IntegerInstanceType)]))
abc = UnionType([A, B, C])
ab = UnionType([A, B])
a = A
c = C

assert a >= a
assert ab >= a
assert not a >= ab
assert abc >= ab
assert not ab >= abc
assert not c >= a
assert not a >= c
assert abc >= c
assert not ab >= c
3 changes: 2 additions & 1 deletion opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
functyp = FunctionType(
frozenlist([t.typ for t in tfd.args.args]),
InstanceType(self.type_from_annotation(tfd.returns)),
externally_bound_vars(node),
bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)},
bind_self=node.name if node.name in read_vars(node) else None,
)
tfd.typ = InstanceType(functyp)
if wraps_builtin:
Expand Down
Loading

0 comments on commit 7de70fb

Please sign in to comment.