Skip to content

Commit

Permalink
refactor(common): add sanity checks for creating ENodes and Patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Apr 20, 2023
1 parent 6099af9 commit 553d63e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
39 changes: 19 additions & 20 deletions ibis/common/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ class Variable(Slotted):

__slots__ = ("name",)

def __init__(self, name: str):
if name is None:
raise ValueError("Variable name cannot be None")
super().__init__(name)

def __repr__(self):
return f"${self.name}"

Expand Down Expand Up @@ -322,6 +327,8 @@ class Pattern(Slotted):

# TODO(kszucs): consider to raise if the pattern matches none
def __init__(self, head, args, name=None, conditions=None):
# TODO(kszucs): ensure that args are either patterns, variables or leaf values
assert all(not isinstance(arg, (ENode, Node)) for arg in args)
super().__init__(head, tuple(args), name)

def matches_none(self):
Expand Down Expand Up @@ -354,19 +361,6 @@ def __rmatmul__(self, name):
"""Syntax sugar to create a named pattern."""
return self.__class__(self.head, self.args, name)

def to_enode(self):
"""Convert the pattern to an ENode.
None of the arguments can be a pattern or a variable.
Returns
-------
enode : ENode
The pattern converted to an ENode.
"""
# TODO(kszucs): ensure that self is a ground term
return ENode(self.head, self.args)

def flatten(self, var=None, counter=None):
"""Recursively flatten the pattern to a join of selections.
Expand Down Expand Up @@ -447,7 +441,9 @@ class DynamicApplier(Slotted):
def substitute(self, egraph, enode, subst):
kwargs = {k: v for k, v in subst.items() if isinstance(k, str)}
result = self.func(egraph, enode, **kwargs)
return result.to_enode() if isinstance(result, Pattern) else result
if not isinstance(result, ENode):
raise TypeError(f"applier must return an ENode, got {type(result)}")
return result


class Rewrite(Slotted):
Expand Down Expand Up @@ -482,6 +478,8 @@ class ENode(Slotted, Node):
__slots__ = ("head", "args")

def __init__(self, head, args):
# TODO(kszucs): ensure that it is a ground term, this check should be removed
assert all(not isinstance(arg, (Pattern, Variable)) for arg in args)
super().__init__(head, tuple(args))

@property
Expand Down Expand Up @@ -631,15 +629,16 @@ def _match_args(self, args, patargs):
subst = {}
for arg, patarg in zip(args, patargs):
if isinstance(patarg, Variable):
if patarg.name is None:
pass
elif isinstance(arg, ENode):
if isinstance(arg, ENode):
subst[patarg.name] = self._eclasses.find(arg)
else:
subst[patarg.name] = arg
elif isinstance(arg, ENode):
if self._eclasses.find(arg) != self._eclasses.find(arg):
return None
# TODO(kszucs): this is not needed since patarg is either a variable or a
# leaf value due to the pattern flattening, though we may choose to
# support this in the future
# elif isinstance(arg, ENode):
# if self._eclasses.find(arg) != self._eclasses.find(arg):
# return None
elif patarg != arg:
return None
return subst
Expand Down
2 changes: 1 addition & 1 deletion ibis/common/tests/test_egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def test_egraph_rewrite_to_pattern():

def test_egraph_rewrite_dynamic():
def applier(egraph, match, a, mul, times):
return p.Add(a, a).to_enode()
return ENode(ops.Add, (a, a))

node = (one * 2).op()

Expand Down

0 comments on commit 553d63e

Please sign in to comment.