From c9431c011ae3823666e9989d11608f5601d8fc4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Thu, 1 Feb 2024 21:39:06 +0100 Subject: [PATCH 01/10] More test cases --- opshin/tests/test_types.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/opshin/tests/test_types.py b/opshin/tests/test_types.py index 7d1284ca..21d955b7 100644 --- a/opshin/tests/test_types.py +++ b/opshin/tests/test_types.py @@ -8,9 +8,14 @@ def test_union_type_order(): abc = UnionType([A, B, C]) ab = UnionType([A, B]) a = A + c = C assert a >= a assert ab >= a assert not a >= ab assert abc >= ab assert not ab >= abc + assert not c >= a + assert not a >= c + assert abc >= c + assert not ab >= c From 887417dc8ebba1f2672e64652b4bd49424942cd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 10:26:04 +0100 Subject: [PATCH 02/10] Discover a case where the program allows invalid types This happens because the name of captured variables is the same but the type is not. The type is merge after the branch but not checked to be compatible with the function type. A simple fix should disallow merging the functions because they capture variables of different types. --- opshin/tests/test_misc.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index b8119761..ac8c2464 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -2776,3 +2776,34 @@ def validator(x: Union[A, B, C]) -> int: return fun(x) """ builder._compile(source_code) + + @unittest.expectedFailure + def test_merge_function_same_capture_different_type(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + bar: int + +def validator(x: bool) -> int: + if x: + y = A(0) + def foo() -> int: + return y.foo + else: + y = B(0) + def foo() -> int: + return y.bar + y = A(0) + return foo() + """ + builder._compile(source_code) From 44b1f70fddcb98118cb0cc6cbf60a40486aa19e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 10:35:32 +0100 Subject: [PATCH 03/10] Fix the issue in the type system --- opshin/type_inference.py | 2 +- opshin/types.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 0fb309c1..067058b1 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -574,7 +574,7 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: functyp = FunctionType( frozenlist([t.typ for t in tfd.args.args]), InstanceType(self.type_from_annotation(tfd.returns)), - externally_bound_vars(node), + {v: self.variable_type(v) for v in externally_bound_vars(node)}, ) tfd.typ = InstanceType(functyp) if wraps_builtin: diff --git a/opshin/types.py b/opshin/types.py index 4af4cd8d..5e02889b 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -1354,18 +1354,20 @@ def _unop_fun(self, unop: unaryop) -> Callable[[plt.AST], plt.AST]: class FunctionType(ClassType): argtyps: typing.List[Type] rettyp: Type - bound_vars: typing.List[str] = dataclasses.field(default_factory=frozenlist) + # A map from external variable names to their types when the function is defined + bound_vars: typing.Dict[str, Type] = dataclasses.field(default_factory=frozendict) def __post_init__(self): object.__setattr__(self, "argtyps", frozenlist(self.argtyps)) - object.__setattr__(self, "bound_vars", frozenlist(self.bound_vars)) + object.__setattr__(self, "bound_vars", frozendict(self.bound_vars)) def __ge__(self, other): return ( isinstance(other, FunctionType) and len(self.argtyps) == len(other.argtyps) and all(a >= oa for a, oa in zip(self.argtyps, other.argtyps)) - and self.bound_vars == other.bound_vars + and self.bound_vars.keys() == other.bound_vars.keys() + and all(sbv >= other.bound_vars[k] for k, sbv in self.bound_vars.items()) and other.rettyp >= self.rettyp ) From 39b81b375e2c95635283d83dffc8f68c3659e161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 10:56:37 +0100 Subject: [PATCH 04/10] Add more failing test cases --- opshin/tests/test_misc.py | 65 +++++++++++++++++++++++++++++++++++++++ opshin/util.py | 4 ++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index ac8c2464..728732ab 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -2807,3 +2807,68 @@ def foo() -> int: return foo() """ builder._compile(source_code) + + def test_merge_function_same_capture_same_type(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + foo: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + bar: int + +def validator(x: bool) -> int: + if x: + y = A(0) + def foo() -> int: + print(2) + return y.foo + else: + y = A(0) if x else B(0) + def foo() -> int: + print(y.foo if isinstance(y, A) else y.bar) + return 2 + y = A(0) + return foo() + """ + res_true = eval_uplc_value(source_code, 1) + res_false = eval_uplc_value(source_code, 0) + self.assertEqual(res_true, 0) + self.assertEqual(res_false, 2) + + def test_merge_print(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def validator(x: bool) -> None: + if x: + a = print + else: + b = print + a = b + return a(x) + """ + res_true = eval_uplc_value(source_code, 1) + res_false = eval_uplc_value(source_code, 0) + + def test_print_reassign(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def validator(x: bool) -> None: + a = print + return a(x) + """ + res_true = eval_uplc_value(source_code, 1) + res_false = eval_uplc_value(source_code, 0) diff --git a/opshin/util.py b/opshin/util.py index 37bcd82e..ad03c8cc 100644 --- a/opshin/util.py +++ b/opshin/util.py @@ -261,7 +261,9 @@ def all_vars(node): def externally_bound_vars(node: FunctionDef): """A superset of the variables bound from an outer scope""" - return sorted(set(read_vars(node)) - (set(written_vars(node)) - {node.name})) + return sorted( + set(read_vars(node)) - (set(written_vars(node)) - {node.name}) - {"isinstance"} + ) def opshin_name_scheme_compatible_varname(n: str): From 399e788902a3fb4e31703528ada4637c1de61618 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 11:28:01 +0100 Subject: [PATCH 05/10] Also fix the polymorphic type reassignment error This is a particularly nasty bug because it does not surface until running the code. It may cause strictly less validation than expected but never more (i.e. funds may get locked but not be stolen) --- opshin/rewrite/rewrite_remove_type_stuff.py | 29 ++++++++++++++++----- opshin/tests/test_misc.py | 19 +++++++++++--- opshin/types.py | 6 +++++ 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/opshin/rewrite/rewrite_remove_type_stuff.py b/opshin/rewrite/rewrite_remove_type_stuff.py index d3050b3e..508587e6 100644 --- a/opshin/rewrite/rewrite_remove_type_stuff.py +++ b/opshin/rewrite/rewrite_remove_type_stuff.py @@ -1,15 +1,23 @@ from typing import Optional -from ..typed_ast import TypedAssign, ClassType +from ..typed_ast import ( + TypedAssign, + ClassType, + InstanceType, + PolymorphicFunctionType, + TypeInferenceError, +) from ..util import CompilingNodeTransformer """ -Remove class reassignments without constructors +Remove class reassignments without constructors and polymorphic function reassignments + +Both of these are only present during the type inference and are discarded or generated in-place during compilation. """ class RewriteRemoveTypeStuff(CompilingNodeTransformer): - step = "Removing class re-assignments" + step = "Removing class and polymorphic function re-assignments" def visit_Assign(self, node: TypedAssign) -> Optional[TypedAssign]: assert ( @@ -17,10 +25,17 @@ def visit_Assign(self, node: TypedAssign) -> Optional[TypedAssign]: ), "Assignments to more than one variable not supported yet" try: if isinstance(node.value.typ, ClassType): - node.value.typ.constr() - except NotImplementedError: - # The type does not have a constructor and the constructor can hence not be passed on - return None + try: + typ = node.value.typ.constr_type() + except TypeInferenceError: + # no constr_type is also fine + return None + else: + typ = node.value.typ + if isinstance(typ, InstanceType) and isinstance( + typ.typ, PolymorphicFunctionType + ): + return None except AttributeError: # untyped attributes are fine too pass diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index 728732ab..3fa37b4b 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -2833,7 +2833,7 @@ def foo() -> int: else: y = A(0) if x else B(0) def foo() -> int: - print(y.foo if isinstance(y, A) else y.bar) + print(y) return 2 y = A(0) return foo() @@ -2857,8 +2857,8 @@ def validator(x: bool) -> None: a = b return a(x) """ - res_true = eval_uplc_value(source_code, 1) - res_false = eval_uplc_value(source_code, 0) + res_true = eval_uplc(source_code, 1) + res_false = eval_uplc(source_code, 0) def test_print_reassign(self): source_code = """ @@ -2868,6 +2868,19 @@ def test_print_reassign(self): def validator(x: bool) -> None: a = print + return a(x) + """ + res_true = eval_uplc(source_code, 1) + res_false = eval_uplc(source_code, 0) + + def test_str_constr_reassign(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def validator(x: bool) -> str: + a = str return a(x) """ res_true = eval_uplc_value(source_code, 1) diff --git a/opshin/types.py b/opshin/types.py index 5e02889b..ebc3d84e 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -2487,6 +2487,12 @@ class PolymorphicFunctionType(ClassType): polymorphic_function: PolymorphicFunction + def __ge__(self, other): + return ( + isinstance(other, PolymorphicFunctionType) + and self.polymorphic_function == other.polymorphic_function + ) + @dataclass(frozen=True, unsafe_hash=True) class PolymorphicFunctionInstanceType(InstanceType): From d337eff08074f64dbf6ad41c2c1bdcb3a68c5196 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 11:51:07 +0100 Subject: [PATCH 06/10] Attempted fix for recursive binding --- opshin/compiler.py | 5 +++++ opshin/type_inference.py | 3 ++- opshin/types.py | 4 ++++ opshin/util.py | 4 +--- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index 3fd174b9..0d4a9be8 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -434,11 +434,13 @@ def visit_Call(self, node: TypedCall) -> plt.AST: node.func.typ.typ.argtyps ) ) + bind_self = None else: assert isinstance(node.func.typ, InstanceType) and isinstance( node.func.typ.typ, FunctionType ) func_plt = self.visit(node.func) + bind_self = node.func.typ.typ.bind_self bound_vs = self.function_bound_vars[node.func.typ.typ] args = [] for a, t in zip(node.args, node.func.typ.typ.argtyps): @@ -457,6 +459,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST: [(f"p{i}", a) for i, a in enumerate(args)], SafeApply( func_plt, + *([plt.Var(bind_self)] if bind_self is not None else []), *[plt.Var(n) for n in bound_vs], *[plt.Delay(OVar(f"p{i}")) for i in range(len(args))], ), @@ -470,6 +473,8 @@ def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST: else: ret_val = plt.Unit() read_vs = self.function_bound_vars[node.typ.typ] + if node.typ.typ.bind_self is not None: + read_vs.insert(0, node.typ.typ.bind_self) self.current_function_typ.append(node.typ.typ) compiled_body = self.visit_sequence(body)(ret_val) self.current_function_typ.pop() diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 067058b1..58a999c2 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -574,7 +574,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: functyp = FunctionType( frozenlist([t.typ for t in tfd.args.args]), InstanceType(self.type_from_annotation(tfd.returns)), - {v: self.variable_type(v) for v in externally_bound_vars(node)}, + bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)}, + bind_self=node.name if node.name in read_vars(node.body) else None, ) tfd.typ = InstanceType(functyp) if wraps_builtin: diff --git a/opshin/types.py b/opshin/types.py index ebc3d84e..71b38afd 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -1356,6 +1356,9 @@ class FunctionType(ClassType): rettyp: Type # A map from external variable names to their types when the function is defined bound_vars: typing.Dict[str, Type] = dataclasses.field(default_factory=frozendict) + # Whether and under which name the function binds itself + # The type of this variable is "self" + bind_self: typing.Optional[str] = None def __post_init__(self): object.__setattr__(self, "argtyps", frozenlist(self.argtyps)) @@ -1368,6 +1371,7 @@ def __ge__(self, other): and all(a >= oa for a, oa in zip(self.argtyps, other.argtyps)) and self.bound_vars.keys() == other.bound_vars.keys() and all(sbv >= other.bound_vars[k] for k, sbv in self.bound_vars.items()) + and self.bind_self == other.bind_self and other.rettyp >= self.rettyp ) diff --git a/opshin/util.py b/opshin/util.py index ad03c8cc..c22b00e4 100644 --- a/opshin/util.py +++ b/opshin/util.py @@ -261,9 +261,7 @@ def all_vars(node): def externally_bound_vars(node: FunctionDef): """A superset of the variables bound from an outer scope""" - return sorted( - set(read_vars(node)) - (set(written_vars(node)) - {node.name}) - {"isinstance"} - ) + return sorted(set(read_vars(node)) - set(written_vars(node)) - {"isinstance"}) def opshin_name_scheme_compatible_varname(n: str): From 88cef996259437a5a40a6015544d5c99dbe91031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 12:07:23 +0100 Subject: [PATCH 07/10] Fixing around things without reaching no failed tests - need a seperate deep session --- opshin/compiler.py | 4 ++-- opshin/type_inference.py | 2 +- opshin/util.py | 6 ------ 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index 0d4a9be8..9bcdce07 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -441,7 +441,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST: ) func_plt = self.visit(node.func) bind_self = node.func.typ.typ.bind_self - bound_vs = self.function_bound_vars[node.func.typ.typ] + bound_vs = sorted(list(node.func.typ.typ.bound_vars.keys())) args = [] for a, t in zip(node.args, node.func.typ.typ.argtyps): assert isinstance(t, InstanceType) @@ -472,7 +472,7 @@ def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST: ret_val = plt.ConstrData(plt.Integer(0), plt.EmptyDataList()) else: ret_val = plt.Unit() - read_vs = self.function_bound_vars[node.typ.typ] + read_vs = sorted(list(node.typ.typ.bound_vars.keys())) if node.typ.typ.bind_self is not None: read_vs.insert(0, node.typ.typ.bind_self) self.current_function_typ.append(node.typ.typ) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 58a999c2..72cda84c 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -575,7 +575,7 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: frozenlist([t.typ for t in tfd.args.args]), InstanceType(self.type_from_annotation(tfd.returns)), bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)}, - bind_self=node.name if node.name in read_vars(node.body) else None, + bind_self=node.name if node.name in read_vars(node) else None, ) tfd.typ = InstanceType(functyp) if wraps_builtin: diff --git a/opshin/util.py b/opshin/util.py index c22b00e4..218dc45c 100644 --- a/opshin/util.py +++ b/opshin/util.py @@ -228,7 +228,6 @@ def visit_AnnAssign(self, node) -> None: def visit_FunctionDef(self, node) -> None: # ignore annotations of paramters and return - self.visit(node.args) for b in node.body: self.visit(b) @@ -240,11 +239,6 @@ def visit_ClassDef(self, node: ClassDef): # ignore the content (i.e. attribute names) of class definitions pass - def visit_FunctionDef(self, node: FunctionDef): - # ignore the type hints of function arguments - for s in node.body: - self.visit(s) - def read_vars(node): """ From 37208717a03bce2f403d56ef3822e14ed9c24fbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 22:34:24 +0100 Subject: [PATCH 08/10] Fix classes being optimized away as variables --- opshin/optimize/optimize_remove_deadvars.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/opshin/optimize/optimize_remove_deadvars.py b/opshin/optimize/optimize_remove_deadvars.py index e52f3873..52b555d1 100644 --- a/opshin/optimize/optimize_remove_deadvars.py +++ b/opshin/optimize/optimize_remove_deadvars.py @@ -31,6 +31,10 @@ def visit_FunctionDef(self, node: FunctionDef): # ignore the type hints of function arguments for s in node.body: self.visit(s) + for v in node.typ.typ.bound_vars.keys(): + self.loaded[v] += 1 + if node.typ.typ.bind_self is not None: + self.loaded[node.typ.typ.bind_self] += 1 class SafeOperationVisitor(CompilingNodeVisitor): From 56f049119feb6e1822e968b12c938e88a48c2031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 23:22:23 +0100 Subject: [PATCH 09/10] Fixes for unassigned variables --- opshin/compiler.py | 35 +++++---------------- opshin/optimize/optimize_remove_deadvars.py | 8 ++--- 2 files changed, 11 insertions(+), 32 deletions(-) diff --git a/opshin/compiler.py b/opshin/compiler.py index 9bcdce07..928ba6ab 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -25,7 +25,7 @@ from .rewrite.rewrite_subscript38 import RewriteSubscript38 from .rewrite.rewrite_tuple_assign import RewriteTupleAssign from .optimize.optimize_remove_pass import OptimizeRemovePass -from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars +from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars, NameLoadCollector from .type_inference import * from .util import ( CompilingNodeTransformer, @@ -136,27 +136,6 @@ def wrap_validator_double_function(x: plt.AST, pass_through: int = 0): CallAST = typing.Callable[[plt.AST], plt.AST] -class FunctionBoundVarsCollector(NodeVisitor): - def __init__(self): - self.functions_bound_vars: typing.Dict[ - FunctionType, typing.List[str] - ] = defaultdict(list) - - def visit_FunctionDef(self, node: FunctionDef) -> None: - self.functions_bound_vars[node.typ.typ] = sorted( - set(self.functions_bound_vars[node.typ.typ] + externally_bound_vars(node)) - ) - self.generic_visit(node) - - -def extract_function_bound_vars( - node: AST, -) -> typing.Dict[FunctionType, typing.List[str]]: - e = FunctionBoundVarsCollector() - e.visit(node) - return e.functions_bound_vars - - class PlutoCompiler(CompilingNodeTransformer): """ Expects a TypedAST and returns UPLC/Pluto like code @@ -170,9 +149,6 @@ def __init__(self, force_three_params=False, validator_function_name="validator" self.validator_function_name = validator_function_name # marked knowledge during compilation self.current_function_typ: typing.List[FunctionType] = [] - self.function_bound_vars: typing.Dict[ - FunctionType, typing.List[str] - ] = defaultdict(list) def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST: def g(s: plt.AST): @@ -223,7 +199,6 @@ def visit_Compare(self, node: TypedCompare) -> plt.AST: def visit_Module(self, node: TypedModule) -> plt.AST: # extract actually read variables by each function - self.function_bound_vars = extract_function_bound_vars(node) if self.validator_function_name is not None: # for validators find main function # TODO can use more sophisiticated procedure here i.e. functions marked by comment @@ -292,7 +267,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST: ] ) self.current_function_typ.append(FunctionType([], InstanceType(AnyType()))) - all_vs = sorted(set(all_vars(node))) + name_load_visitor = NameLoadCollector() + name_load_visitor.visit(node) + all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys())) # write all variables that are ever read # once at the beginning so that we can always access them (only potentially causing a nameerror at runtime) @@ -324,7 +301,9 @@ def visit_Module(self, node: TypedModule) -> plt.AST: "The contract can not always detect if it was passed three or two parameters on-chain." ) else: - all_vs = sorted(set(all_vars(node))) + name_load_visitor = NameLoadCollector() + name_load_visitor.visit(node) + all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys())) body = node.body # write all variables that are ever read diff --git a/opshin/optimize/optimize_remove_deadvars.py b/opshin/optimize/optimize_remove_deadvars.py index 52b555d1..b1dfbc36 100644 --- a/opshin/optimize/optimize_remove_deadvars.py +++ b/opshin/optimize/optimize_remove_deadvars.py @@ -6,7 +6,7 @@ from ..util import CompilingNodeVisitor, CompilingNodeTransformer from ..type_inference import INITIAL_SCOPE -from ..typed_ast import TypedAnnAssign +from ..typed_ast import TypedAnnAssign, TypedFunctionDef, TypedClassDef, TypedName """ Removes assignments to variables that are never read @@ -19,15 +19,15 @@ class NameLoadCollector(CompilingNodeVisitor): def __init__(self): self.loaded = defaultdict(int) - def visit_Name(self, node: Name) -> None: + def visit_Name(self, node: TypedName) -> None: if isinstance(node.ctx, Load): self.loaded[node.id] += 1 - def visit_ClassDef(self, node: ClassDef): + def visit_ClassDef(self, node: TypedClassDef): # ignore the content (i.e. attribute names) of class definitions pass - def visit_FunctionDef(self, node: FunctionDef): + def visit_FunctionDef(self, node: TypedFunctionDef): # ignore the type hints of function arguments for s in node.body: self.visit(s) From e0a4be55fa6bcecba034142f7f73d60cf13c7ffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 2 Feb 2024 23:35:12 +0100 Subject: [PATCH 10/10] Add type test --- opshin/tests/test_types.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/opshin/tests/test_types.py b/opshin/tests/test_types.py index 21d955b7..692a2f7c 100644 --- a/opshin/tests/test_types.py +++ b/opshin/tests/test_types.py @@ -1,6 +1,34 @@ from ..types import * +def test_record_type_order(): + A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)])) + B = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)])) + C = RecordType(Record("C", "C", 2, [("baz", IntegerInstanceType)])) + a = A + b = B + c = C + + assert a >= a + assert not a >= b + assert not b >= a + assert not a >= c + assert not c >= a + assert not b >= c + assert not c >= b + + A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)])) + B = RecordType( + Record( + "B", "B", 0, [("foo", IntegerInstanceType), ("bar", IntegerInstanceType)] + ) + ) + C = RecordType(Record("C", "C", 0, [("foo", InstanceType(AnyType()))])) + assert not A >= B + assert not C >= B + assert C >= A + + def test_union_type_order(): A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)])) B = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)]))