Skip to content

Commit

Permalink
Fix casting to Builtin
Browse files Browse the repository at this point in the history
  • Loading branch information
SCMusson committed Sep 3, 2024
1 parent d1f7791 commit 8b69eec
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
2 changes: 2 additions & 0 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def visit_Name(self, node: TypedName) -> plt.AST:
if isinstance(node.typ, ClassType):
# if this is not an instance but a class, call the constructor
return node.typ.constr()
if hasattr(node, "is_wrapped"):

This comment has been minimized.

Copy link
@nielstron

nielstron Sep 3, 2024

Contributor

I know this is only set iff it's set to True but could you use getattr here please 😅 a bit unnerving, if someone sets this actively to false it will exactly not behave as expected

This comment has been minimized.

Copy link
@SCMusson

SCMusson Sep 4, 2024

Author Contributor

Done, see last commit

This comment has been minimized.

Copy link
@nielstron

nielstron Sep 4, 2024

Contributor

Thank you!

return transform_ext_params_map(node.typ)(plt.Force(plt.Var(node.id)))
return plt.Force(plt.Var(node.id))

def visit_Expr(self, node: TypedExpr) -> CallAST:
Expand Down
3 changes: 2 additions & 1 deletion opshin/std/fractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __add__(self, other: Union[Self, int]) -> Self:
)
else:
return Fraction(
self.numerator + (other * self.denominator), self.denominator
(self.numerator) + (other * self.denominator),
self.denominator,
)

def __neg__(
Expand Down
19 changes: 19 additions & 0 deletions opshin/tests/test_Unions.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,22 @@ def validator(x: Union[int, bytes, bool]) -> int:
with self.assertRaises(CompilerError) as ce:
res = eval_uplc_value(source_code, True)
self.assertIsInstance(ce.exception.orig_err, AssertionError)

@hypothesis.given(st.sampled_from([14, b""]))
def test_Union_builtin_cast(self, x):
source_code = """
from dataclasses import dataclass
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
def validator(x: Union[int,bytes]) -> int:
k: int = 0
if isinstance(x, int):
k = x+5
elif isinstance(x, bytes):
k = 7
return k
"""
res = eval_uplc_value(source_code, x)
real = x + 5 if isinstance(x, int) else 7
self.assertEqual(res, real)
12 changes: 9 additions & 3 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ class AggressiveTypeInferencer(CompilingNodeTransformer):
def __init__(self, allow_isinstance_anything=False):
self.allow_isinstance_anything = allow_isinstance_anything
self.FUNCTION_ARGUMENT_REGISTRY = {}
self.wrapped = []

# A stack of dictionaries for storing scoped knowledge of variable types
self.scopes = [INITIAL_SCOPE]
Expand Down Expand Up @@ -625,15 +626,19 @@ def visit_If(self, node: If) -> TypedIf:
).visit(typed_if.test)
# for the time of the branch, these types are cast
initial_scope = copy(self.scopes[-1])
self.implement_typechecks(typchecks)
wrapped = self.implement_typechecks(typchecks)
self.wrapped.extend(wrapped.keys())
typed_if.body = self.visit_sequence(node.body)
self.wrapped = [x for x in self.wrapped if x not in wrapped.keys()]

# save resulting types
final_scope_body = copy(self.scopes[-1])
# reverse typechecks and remove typing of one branch
self.scopes[-1] = initial_scope
# for the time of the else branch, the inverse types hold
self.implement_typechecks(inv_typchecks)
wrapped = self.implement_typechecks(inv_typchecks)
typed_if.orelse = self.visit_sequence(node.orelse)
self.wrapped = [x for x in self.wrapped if x not in wrapped.keys()]
final_scope_else = self.scopes[-1]
# unify the resulting branch scopes
self.scopes[-1] = merge_scope(final_scope_body, final_scope_else)
Expand Down Expand Up @@ -702,6 +707,8 @@ def visit_Name(self, node: Name) -> TypedName:
else:
# Make sure that the rhs of an assign is evaluated first
tn.typ = self.variable_type(node.id)
if node.id in self.wrapped:
tn.is_wrapped = True
return tn

def visit_keyword(self, node: keyword) -> Typedkeyword:
Expand Down Expand Up @@ -864,7 +871,6 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
"Dict",
"List",
]:

ts.value = ts.typ = self.type_from_annotation(ts)
return ts

Expand Down

0 comments on commit 8b69eec

Please sign in to comment.