diff --git a/opshin/compiler.py b/opshin/compiler.py index 3fd174b9..928ba6ab 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -25,7 +25,7 @@ from .rewrite.rewrite_subscript38 import RewriteSubscript38 from .rewrite.rewrite_tuple_assign import RewriteTupleAssign from .optimize.optimize_remove_pass import OptimizeRemovePass -from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars +from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars, NameLoadCollector from .type_inference import * from .util import ( CompilingNodeTransformer, @@ -136,27 +136,6 @@ def wrap_validator_double_function(x: plt.AST, pass_through: int = 0): CallAST = typing.Callable[[plt.AST], plt.AST] -class FunctionBoundVarsCollector(NodeVisitor): - def __init__(self): - self.functions_bound_vars: typing.Dict[ - FunctionType, typing.List[str] - ] = defaultdict(list) - - def visit_FunctionDef(self, node: FunctionDef) -> None: - self.functions_bound_vars[node.typ.typ] = sorted( - set(self.functions_bound_vars[node.typ.typ] + externally_bound_vars(node)) - ) - self.generic_visit(node) - - -def extract_function_bound_vars( - node: AST, -) -> typing.Dict[FunctionType, typing.List[str]]: - e = FunctionBoundVarsCollector() - e.visit(node) - return e.functions_bound_vars - - class PlutoCompiler(CompilingNodeTransformer): """ Expects a TypedAST and returns UPLC/Pluto like code @@ -170,9 +149,6 @@ def __init__(self, force_three_params=False, validator_function_name="validator" self.validator_function_name = validator_function_name # marked knowledge during compilation self.current_function_typ: typing.List[FunctionType] = [] - self.function_bound_vars: typing.Dict[ - FunctionType, typing.List[str] - ] = defaultdict(list) def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST: def g(s: plt.AST): @@ -223,7 +199,6 @@ def visit_Compare(self, node: TypedCompare) -> plt.AST: def visit_Module(self, node: TypedModule) -> plt.AST: # extract actually read variables by each function - self.function_bound_vars = extract_function_bound_vars(node) if self.validator_function_name is not None: # for validators find main function # TODO can use more sophisiticated procedure here i.e. functions marked by comment @@ -292,7 +267,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST: ] ) self.current_function_typ.append(FunctionType([], InstanceType(AnyType()))) - all_vs = sorted(set(all_vars(node))) + name_load_visitor = NameLoadCollector() + name_load_visitor.visit(node) + all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys())) # write all variables that are ever read # once at the beginning so that we can always access them (only potentially causing a nameerror at runtime) @@ -324,7 +301,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST: "The contract can not always detect if it was passed three or two parameters on-chain." ) else: - all_vs = sorted(set(all_vars(node))) + name_load_visitor = NameLoadCollector() + name_load_visitor.visit(node) + all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys())) body = node.body # write all variables that are ever read @@ -434,12 +413,14 @@ def visit_Call(self, node: TypedCall) -> plt.AST: node.func.typ.typ.argtyps ) ) + bind_self = None else: assert isinstance(node.func.typ, InstanceType) and isinstance( node.func.typ.typ, FunctionType ) func_plt = self.visit(node.func) - bound_vs = self.function_bound_vars[node.func.typ.typ] + bind_self = node.func.typ.typ.bind_self + bound_vs = sorted(list(node.func.typ.typ.bound_vars.keys())) args = [] for a, t in zip(node.args, node.func.typ.typ.argtyps): assert isinstance(t, InstanceType) @@ -457,6 +438,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST: [(f"p{i}", a) for i, a in enumerate(args)], SafeApply( func_plt, + *([plt.Var(bind_self)] if bind_self is not None else []), *[plt.Var(n) for n in bound_vs], *[plt.Delay(OVar(f"p{i}")) for i in range(len(args))], ), @@ -469,7 +451,9 @@ def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST: ret_val = plt.ConstrData(plt.Integer(0), plt.EmptyDataList()) else: ret_val = plt.Unit() - read_vs = self.function_bound_vars[node.typ.typ] + read_vs = sorted(list(node.typ.typ.bound_vars.keys())) + if node.typ.typ.bind_self is not None: + read_vs.insert(0, node.typ.typ.bind_self) self.current_function_typ.append(node.typ.typ) compiled_body = self.visit_sequence(body)(ret_val) self.current_function_typ.pop() diff --git a/opshin/optimize/optimize_remove_deadvars.py b/opshin/optimize/optimize_remove_deadvars.py index e52f3873..b1dfbc36 100644 --- a/opshin/optimize/optimize_remove_deadvars.py +++ b/opshin/optimize/optimize_remove_deadvars.py @@ -6,7 +6,7 @@ from ..util import CompilingNodeVisitor, CompilingNodeTransformer from ..type_inference import INITIAL_SCOPE -from ..typed_ast import TypedAnnAssign +from ..typed_ast import TypedAnnAssign, TypedFunctionDef, TypedClassDef, TypedName """ Removes assignments to variables that are never read @@ -19,18 +19,22 @@ class NameLoadCollector(CompilingNodeVisitor): def __init__(self): self.loaded = defaultdict(int) - def visit_Name(self, node: Name) -> None: + def visit_Name(self, node: TypedName) -> None: if isinstance(node.ctx, Load): self.loaded[node.id] += 1 - def visit_ClassDef(self, node: ClassDef): + def visit_ClassDef(self, node: TypedClassDef): # ignore the content (i.e. attribute names) of class definitions pass - def visit_FunctionDef(self, node: FunctionDef): + def visit_FunctionDef(self, node: TypedFunctionDef): # ignore the type hints of function arguments for s in node.body: self.visit(s) + for v in node.typ.typ.bound_vars.keys(): + self.loaded[v] += 1 + if node.typ.typ.bind_self is not None: + self.loaded[node.typ.typ.bind_self] += 1 class SafeOperationVisitor(CompilingNodeVisitor): diff --git a/opshin/rewrite/rewrite_remove_type_stuff.py b/opshin/rewrite/rewrite_remove_type_stuff.py index d3050b3e..508587e6 100644 --- a/opshin/rewrite/rewrite_remove_type_stuff.py +++ b/opshin/rewrite/rewrite_remove_type_stuff.py @@ -1,15 +1,23 @@ from typing import Optional -from ..typed_ast import TypedAssign, ClassType +from ..typed_ast import ( + TypedAssign, + ClassType, + InstanceType, + PolymorphicFunctionType, + TypeInferenceError, +) from ..util import CompilingNodeTransformer """ -Remove class reassignments without constructors +Remove class reassignments without constructors and polymorphic function reassignments + +Both of these are only present during the type inference and are discarded or generated in-place during compilation. """ class RewriteRemoveTypeStuff(CompilingNodeTransformer): - step = "Removing class re-assignments" + step = "Removing class and polymorphic function re-assignments" def visit_Assign(self, node: TypedAssign) -> Optional[TypedAssign]: assert ( @@ -17,10 +25,17 @@ def visit_Assign(self, node: TypedAssign) -> Optional[TypedAssign]: ), "Assignments to more than one variable not supported yet" try: if isinstance(node.value.typ, ClassType): - node.value.typ.constr() - except NotImplementedError: - # The type does not have a constructor and the constructor can hence not be passed on - return None + try: + typ = node.value.typ.constr_type() + except TypeInferenceError: + # no constr_type is also fine + return None + else: + typ = node.value.typ + if isinstance(typ, InstanceType) and isinstance( + typ.typ, PolymorphicFunctionType + ): + return None except AttributeError: # untyped attributes are fine too pass diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index b8119761..3fa37b4b 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -2776,3 +2776,112 @@ def validator(x: Union[A, B, C]) -> int: return fun(x) """ builder._compile(source_code) + + @unittest.expectedFailure + def test_merge_function_same_capture_different_type(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + bar: int + +def validator(x: bool) -> int: + if x: + y = A(0) + def foo() -> int: + return y.foo + else: + y = B(0) + def foo() -> int: + return y.bar + y = A(0) + return foo() + """ + builder._compile(source_code) + + def test_merge_function_same_capture_same_type(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + bar: int + +def validator(x: bool) -> int: + if x: + y = A(0) + def foo() -> int: + print(2) + return y.foo + else: + y = A(0) if x else B(0) + def foo() -> int: + print(y) + return 2 + y = A(0) + return foo() + """ + res_true = eval_uplc_value(source_code, 1) + res_false = eval_uplc_value(source_code, 0) + self.assertEqual(res_true, 0) + self.assertEqual(res_false, 2) + + def test_merge_print(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def validator(x: bool) -> None: + if x: + a = print + else: + b = print + a = b + return a(x) + """ + res_true = eval_uplc(source_code, 1) + res_false = eval_uplc(source_code, 0) + + def test_print_reassign(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def validator(x: bool) -> None: + a = print + return a(x) + """ + res_true = eval_uplc(source_code, 1) + res_false = eval_uplc(source_code, 0) + + def test_str_constr_reassign(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def validator(x: bool) -> str: + a = str + return a(x) + """ + res_true = eval_uplc_value(source_code, 1) + res_false = eval_uplc_value(source_code, 0) diff --git a/opshin/tests/test_types.py b/opshin/tests/test_types.py index 7d1284ca..692a2f7c 100644 --- a/opshin/tests/test_types.py +++ b/opshin/tests/test_types.py @@ -1,6 +1,34 @@ from ..types import * +def test_record_type_order(): + A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)])) + B = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)])) + C = RecordType(Record("C", "C", 2, [("baz", IntegerInstanceType)])) + a = A + b = B + c = C + + assert a >= a + assert not a >= b + assert not b >= a + assert not a >= c + assert not c >= a + assert not b >= c + assert not c >= b + + A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)])) + B = RecordType( + Record( + "B", "B", 0, [("foo", IntegerInstanceType), ("bar", IntegerInstanceType)] + ) + ) + C = RecordType(Record("C", "C", 0, [("foo", InstanceType(AnyType()))])) + assert not A >= B + assert not C >= B + assert C >= A + + def test_union_type_order(): A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)])) B = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)])) @@ -8,9 +36,14 @@ def test_union_type_order(): abc = UnionType([A, B, C]) ab = UnionType([A, B]) a = A + c = C assert a >= a assert ab >= a assert not a >= ab assert abc >= ab assert not ab >= abc + assert not c >= a + assert not a >= c + assert abc >= c + assert not ab >= c diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 0fb309c1..72cda84c 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -574,7 +574,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: functyp = FunctionType( frozenlist([t.typ for t in tfd.args.args]), InstanceType(self.type_from_annotation(tfd.returns)), - externally_bound_vars(node), + bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)}, + bind_self=node.name if node.name in read_vars(node) else None, ) tfd.typ = InstanceType(functyp) if wraps_builtin: diff --git a/opshin/types.py b/opshin/types.py index 4af4cd8d..71b38afd 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -1354,18 +1354,24 @@ def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: class FunctionType(ClassType): argtyps: typing.List[Type] rettyp: Type - bound_vars: typing.List[str] = dataclasses.field(default_factory=frozenlist) + # A map from external variable names to their types when the function is defined + bound_vars: typing.Dict[str, Type] = dataclasses.field(default_factory=frozendict) + # Whether and under which name the function binds itself + # The type of this variable is "self" + bind_self: typing.Optional[str] = None def __post_init__(self): object.__setattr__(self, "argtyps", frozenlist(self.argtyps)) - object.__setattr__(self, "bound_vars", frozenlist(self.bound_vars)) + object.__setattr__(self, "bound_vars", frozendict(self.bound_vars)) def __ge__(self, other): return ( isinstance(other, FunctionType) and len(self.argtyps) == len(other.argtyps) and all(a >= oa for a, oa in zip(self.argtyps, other.argtyps)) - and self.bound_vars == other.bound_vars + and self.bound_vars.keys() == other.bound_vars.keys() + and all(sbv >= other.bound_vars[k] for k, sbv in self.bound_vars.items()) + and self.bind_self == other.bind_self and other.rettyp >= self.rettyp ) @@ -2485,6 +2491,12 @@ class PolymorphicFunctionType(ClassType): polymorphic_function: PolymorphicFunction + def __ge__(self, other): + return ( + isinstance(other, PolymorphicFunctionType) + and self.polymorphic_function == other.polymorphic_function + ) + @dataclass(frozen=True, unsafe_hash=True) class PolymorphicFunctionInstanceType(InstanceType): diff --git a/opshin/util.py b/opshin/util.py index 37bcd82e..218dc45c 100644 --- a/opshin/util.py +++ b/opshin/util.py @@ -228,7 +228,6 @@ def visit_AnnAssign(self, node) -> None: def visit_FunctionDef(self, node) -> None: # ignore annotations of paramters and return - self.visit(node.args) for b in node.body: self.visit(b) @@ -240,11 +239,6 @@ def visit_ClassDef(self, node: ClassDef): # ignore the content (i.e. attribute names) of class definitions pass - def visit_FunctionDef(self, node: FunctionDef): - # ignore the type hints of function arguments - for s in node.body: - self.visit(s) - def read_vars(node): """ @@ -261,7 +255,7 @@ def all_vars(node): def externally_bound_vars(node: FunctionDef): """A superset of the variables bound from an outer scope""" - return sorted(set(read_vars(node)) - (set(written_vars(node)) - {node.name})) + return sorted(set(read_vars(node)) - set(written_vars(node)) - {"isinstance"}) def opshin_name_scheme_compatible_varname(n: str):