Skip to content

Commit

Permalink
chore: add docstrings to egraph classes
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Apr 19, 2023
1 parent b4439fd commit 54ae015
Showing 1 changed file with 142 additions and 0 deletions.
142 changes: 142 additions & 0 deletions ibis/common/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,77 @@ def __setattr__(self, name, value):


class Variable(Slotted):
    """A named capture in a pattern.

    Parameters
    ----------
    name : str
        The name of the variable.
    """

    __slots__ = ("name",)

    def __repr__(self):
        # Rendered as the name prefixed with a dollar sign, e.g. ``$x``.
        return f"${self.name}"

    def substitute(self, egraph, enode, subst):
        """Look up the value bound to this variable in the substitution.

        Parameters
        ----------
        egraph : EGraph
            The egraph instance (not used by this method).
        enode : ENode
            The matched enode (not used by this method).
        subst : dict
            The substitution dictionary, indexed by variable name.

        Returns
        -------
        value : Any
            The entry of ``subst`` stored under this variable's name.
        """
        key = self.name
        return subst[key]


# Pattern corresponds to a selection which is flattened to a join of selections
class Pattern(Slotted):
"""A non-ground term, tree of enodes possibly containing variables.
This class is used to represent a pattern in a query. The pattern is almost
identical to an ENode, except that it can contain variables.
Parameters
----------
head : type
The head or python type of the ENode to match against.
args : tuple
The arguments of the pattern. The arguments can be enodes, patterns,
variables or leaf values.
name : str, optional
The name of the pattern which is used to refer to it in a rewrite rule.
"""

__slots__ = ("head", "args", "name")

# TODO(kszucs): consider raising if the pattern matches none
def __init__(self, head, args, name=None, conditions=None):
    """Initialize the pattern from its head, arguments and optional name.

    ``args`` may be any iterable; it is converted to a tuple before being
    handed to the parent initializer. ``conditions`` is accepted but is not
    used by this constructor.
    """
    frozen_args = tuple(args)
    super().__init__(head, frozen_args, name)

def matches_none(self):
    """Tell whether the pattern is guaranteed to match nothing.

    This is the case when the number of pattern arguments differs from the
    number of argument names declared by the head. It can be evaluated
    before the matching loop starts, so such a pattern can eventually be
    eliminated from the flattened query.
    """
    expected_arity = len(self.head.__argnames__)
    actual_arity = len(self.args)
    return expected_arity != actual_arity

def matches_all(self):
    """Tell whether the pattern is guaranteed to match everything.

    This holds when the pattern does not trivially match nothing and every
    argument is a ``Variable``. It can be evaluated before the matching
    loop starts, so such a pattern can eventually be eliminated from the
    flattened query.
    """
    if self.matches_none():
        # An arity mismatch can never match, let alone match everything.
        return False
    return all(isinstance(arg, Variable) for arg in self.args)
Expand All @@ -72,16 +122,50 @@ def __repr__(self):
return f"P{self.head.__name__}({argstring})"

def __rshift__(self, rhs):
    """Syntax sugar: ``pattern >> rhs`` creates a rewrite rule.

    The pattern itself becomes the first argument of ``Rewrite`` and
    ``rhs`` the second.
    """
    rule = Rewrite(self, rhs)
    return rule

def __rmatmul__(self, name):
    """Syntax sugar: ``name @ pattern`` creates a named pattern.

    A new instance of the same class is returned with identical head and
    arguments and the given name; ``self`` is left untouched.
    """
    pattern_type = self.__class__
    return pattern_type(self.head, self.args, name)

def to_enode(self):
    """Convert the pattern to an ENode.

    None of the arguments can be a pattern or a variable; note that this
    precondition is not verified here.

    Returns
    -------
    enode : ENode
        An ENode built from the pattern's head and arguments.
    """
    # TODO(kszucs): ensure that self is a ground term
    head, args = self.head, self.args
    return ENode(head, args)

def flatten(self, var=None, counter=None):
"""Recursively flatten the pattern to a join of selections.
`Pattern(Add, (Pattern(Mul, ($x, 1)), $y))` is turned into a join of
selections by introducing auxiliary variables where each selection gets
executed as a dictionary lookup.
In SQL terms this is equivalent to the following query:
SELECT m.0 AS $x, a.1 AS $y FROM Add a JOIN Mul m ON a.0 = m.id WHERE m.1 = 1
Parameters
----------
var : Variable
The variable to assign to the flattened pattern.
counter : Iterator[int]
The counter to generate unique variable names for auxiliary variables
connecting the selections.
Yields
------
(var, pattern) : tuple[Variable, Pattern]
The variable and the flattened pattern where the flattened pattern
cannot contain any patterns just variables.
"""
# TODO(kszucs): convert a pattern to a query object instead by flattening it
counter = counter or itertools.count()

Expand All @@ -106,6 +190,22 @@ def flatten(self, var=None, counter=None):
yield (var, Pattern(self.head, args))

def substitute(self, egraph, enode, subst):
"""Substitute the variables in the pattern with the corresponding values.
Parameters
----------
egraph : EGraph
The egraph instance.
enode : ENode
The matched enode.
subst : dict
The substitution dictionary.
Returns
-------
enode : ENode
The substituted pattern which is a ground term aka. an ENode.
"""
args = []
for arg in self.args:
if isinstance(arg, (Variable, Pattern)):
Expand All @@ -115,6 +215,8 @@ def substitute(self, egraph, enode, subst):


class DynamicApplier(Slotted):
"""A dynamic applier which calls a function to compute the result."""

__slots__ = ("func",)

def substitute(self, egraph, enode, subst):
Expand All @@ -124,6 +226,8 @@ def substitute(self, egraph, enode, subst):


class Rewrite(Slotted):
"""A rewrite rule which matches a pattern and applies a pattern or a function."""

__slots__ = ("matcher", "applier")

def __init__(self, matcher, applier):
Expand All @@ -140,17 +244,29 @@ def __repr__(self):


class ENode(Slotted, Node):
"""A ground term which is a node in the EGraph, called ENode.
Parameters
----------
head : type
The type of the Node the ENode represents.
args : tuple
The arguments of the ENode which are either ENodes or leaf values.
"""

__slots__ = ("head", "args")

def __init__(self, head, args):
    """Initialize the enode from its head and arguments.

    ``args`` may be any iterable; it is converted to a tuple before being
    handed to the parent initializer.
    """
    frozen_args = tuple(args)
    super().__init__(head, frozen_args)

@property
def __argnames__(self):
    """Argument names for the `ibis.common.graph.Node` protocol.

    The names are not stored on the enode; they are read from the head.
    """
    head = self.head
    return head.__argnames__

@property
def __args__(self):
    """Argument values for the `ibis.common.graph.Node` protocol.

    This is the enode's ``args`` attribute, returned as is.
    """
    arguments = self.args
    return arguments

def __repr__(self):
Expand All @@ -162,12 +278,16 @@ def __lt__(self, other):

@classmethod
def from_node(cls, node: Any):
    """Convert an `ibis.common.graph.Node` to an `ENode`.

    The conversion is delegated to ``node.map``; the callback builds one
    instance of ``cls`` per visited node from the node's class and the
    keyword-argument values it receives. The entry for ``node`` itself is
    then picked from the result of ``node.map``.
    """

    # The callback keeps the parameter names ``node`` and ``_`` so that no
    # additional keyword argument name can collide with them.
    def build_enode(node, _, **kwargs):
        head = node.__class__
        return cls(head, kwargs.values())

    converted = node.map(build_enode)
    return converted[node]

def to_node(self):
"""Convert the ENode back to an `ibis.common.graph.Node`."""

def mapper(node, _, **kwargs):
return node.head(**kwargs)

Expand Down Expand Up @@ -261,6 +381,28 @@ def union(self, node1: Node, node2: Node) -> ENode:
return self._eclasses.union(enode1, enode2)

def _match_args(self, args, patargs):
"""Match the arguments of an enode against a pattern's arguments.
An enode matches a pattern if each of the arguments are:
- both leaf values and equal
- both enodes and in the same eclass
- an enode and a variable, in which case the variable gets bound to the enode
Parameters
----------
args : tuple
The arguments of the enode. Since an enode is a ground term, the arguments
are either enodes or leaf values.
patargs : tuple
The arguments of the pattern. Since a pattern is a flat term (flattened
using auxiliary variables), the arguments are either variables or leaf
values.
Returns
-------
dict[str, Any] :
The mapping of variable names to enodes or leaf values.
"""
subst = {}
for arg, patarg in zip(args, patargs):
if isinstance(patarg, Variable):
Expand Down

0 comments on commit 54ae015

Please sign in to comment.