Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/merging functions with type capturing #328

Merged
merged 10 commits into from
Feb 2, 2024
Prev Previous commit
Next Next commit
Attempted fix for recursive binding
nielstron committed Feb 2, 2024
commit d337eff08074f64dbf6ad41c2c1bdcb3a68c5196
5 changes: 5 additions & 0 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
@@ -434,11 +434,13 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
node.func.typ.typ.argtyps
)
)
bind_self = None
else:
assert isinstance(node.func.typ, InstanceType) and isinstance(
node.func.typ.typ, FunctionType
)
func_plt = self.visit(node.func)
bind_self = node.func.typ.typ.bind_self
bound_vs = self.function_bound_vars[node.func.typ.typ]
args = []
for a, t in zip(node.args, node.func.typ.typ.argtyps):
@@ -457,6 +459,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
[(f"p{i}", a) for i, a in enumerate(args)],
SafeApply(
func_plt,
*([plt.Var(bind_self)] if bind_self is not None else []),
*[plt.Var(n) for n in bound_vs],
*[plt.Delay(OVar(f"p{i}")) for i in range(len(args))],
),
@@ -470,6 +473,8 @@ def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST:
else:
ret_val = plt.Unit()
read_vs = self.function_bound_vars[node.typ.typ]
if node.typ.typ.bind_self is not None:
read_vs.insert(0, node.typ.typ.bind_self)
self.current_function_typ.append(node.typ.typ)
compiled_body = self.visit_sequence(body)(ret_val)
self.current_function_typ.pop()
3 changes: 2 additions & 1 deletion opshin/type_inference.py
Original file line number Diff line number Diff line change
@@ -574,7 +574,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
functyp = FunctionType(
frozenlist([t.typ for t in tfd.args.args]),
InstanceType(self.type_from_annotation(tfd.returns)),
{v: self.variable_type(v) for v in externally_bound_vars(node)},
bound_vars={v: self.variable_type(v) for v in externally_bound_vars(node)},
bind_self=node.name if node.name in read_vars(node.body) else None,
)
tfd.typ = InstanceType(functyp)
if wraps_builtin:
4 changes: 4 additions & 0 deletions opshin/types.py
Original file line number Diff line number Diff line change
@@ -1356,6 +1356,9 @@ class FunctionType(ClassType):
rettyp: Type
# A map from external variable names to their types when the function is defined
bound_vars: typing.Dict[str, Type] = dataclasses.field(default_factory=frozendict)
# Whether and under which name the function binds itself
# The type of this variable is "self"
bind_self: typing.Optional[str] = None

def __post_init__(self):
object.__setattr__(self, "argtyps", frozenlist(self.argtyps))
@@ -1368,6 +1371,7 @@ def __ge__(self, other):
and all(a >= oa for a, oa in zip(self.argtyps, other.argtyps))
and self.bound_vars.keys() == other.bound_vars.keys()
and all(sbv >= other.bound_vars[k] for k, sbv in self.bound_vars.items())
and self.bind_self == other.bind_self
and other.rettyp >= self.rettyp
)

4 changes: 1 addition & 3 deletions opshin/util.py
Original file line number Diff line number Diff line change
@@ -261,9 +261,7 @@ def all_vars(node):

def externally_bound_vars(node: FunctionDef):
"""A superset of the variables bound from an outer scope"""
return sorted(
set(read_vars(node)) - (set(written_vars(node)) - {node.name}) - {"isinstance"}
)
return sorted(set(read_vars(node)) - set(written_vars(node)) - {"isinstance"})


def opshin_name_scheme_compatible_varname(n: str):