Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the raising to high level operator within Unify Axis #565

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
afc3e9e
Move the AxesEquationCollector to not use the raising.py operations.
nkoskelo Nov 4, 2024
574ab23
Remove the usage check.
nkoskelo Nov 5, 2024
e02a800
Look at the variables directly and as long as we are using the reserv…
nkoskelo Nov 5, 2024
0948843
Correct ruff suggestions.
nkoskelo Nov 5, 2024
953a643
Only record usage if the array is indexed in some way.
nkoskelo Nov 6, 2024
88eac82
Add a unit test case which is unbroadcastable but is still a legal py…
nkoskelo Nov 6, 2024
52dc445
Add a unit test and split out a reserved pattern for the reductions a…
nkoskelo Nov 7, 2024
1031868
Fix ruff suggestions.
nkoskelo Nov 7, 2024
85e5395
More ruff suggestions.
nkoskelo Nov 7, 2024
a50d58e
Make sure that we return a value if we need to. :)
nkoskelo Nov 7, 2024
921b55f
Working on mypy errors.
nkoskelo Nov 25, 2024
e858110
Respond to comments.
nkoskelo Dec 11, 2024
a04374d
Merge branch 'main' into remove-raising-revived
nkoskelo Dec 11, 2024
17df871
Update for ruff.
nkoskelo Dec 11, 2024
5b01c24
Move typing information to only import if type checking.
nkoskelo Dec 11, 2024
882312f
More ruff CI.
nkoskelo Dec 12, 2024
0feea14
Add noqa: RUF052 for kernels in test_codegen.py.
nkoskelo Dec 12, 2024
8848fd5
Add assert statements for typing purposes.
nkoskelo Dec 12, 2024
a30b3bc
Reorganize. Ruff was out of date. :)
nkoskelo Dec 12, 2024
07bc5ab
Fix some of the mypy errors.
nkoskelo Dec 12, 2024
6e73ed6
Add a mapper for applying the updates in the case of a reduction oper…
nkoskelo Dec 20, 2024
408cdd5
Merge branch 'main' into remove-raising-revived
nkoskelo Dec 20, 2024
0336c00
Fix the ruff comments.
nkoskelo Dec 20, 2024
5242f3f
Update the test code to use the correct name of the argument of its i…
nkoskelo Dec 20, 2024
b14d238
Merge branch 'main' into remove-raising-revived
nkoskelo Jan 8, 2025
e7e750a
Add a test case for the pattern match on binding names. Use the reduc…
nkoskelo Jan 8, 2025
a7827b7
Merge branch 'main' into remove-raising-revived
nkoskelo Jan 9, 2025
23ffe34
Let's get the reduction descriptors at the start when we are recordin…
nkoskelo Jan 14, 2025
9ac0a7d
Update the record equations keys to be [Array, int | str] so that we …
nkoskelo Jan 14, 2025
5ccc11f
Remove unused code.
nkoskelo Jan 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression:
for name, bound in expr.bounds.items()}))


IDX_LAMBDA_RE = re.compile(r"_r?(0|([1-9][0-9]*))")
IDX_LAMBDA_RE = re.compile(r"^(_r?(0|([1-9][0-9]*)))$")
IDX_LAMBDA_INAME = re.compile(r"^(_(0|([1-9][0-9]*)))$")
IDX_LAMBDA_JUST_REDUCTIONS = re.compile(r"^(_r(0|([1-9][0-9]*)))$")


class DependencyMapper(DependencyMapperBase[P]):
Expand Down
187 changes: 93 additions & 94 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,27 @@
THE SOFTWARE.
"""


import logging
import re
from typing import (
TYPE_CHECKING,
Any,
ParamSpec,
TypeAlias,
TypeVar,
cast,
)

from bidict import bidict

import pymbolic.primitives as prim
from pytools import UniqueNameGenerator
from pytools.tag import Tag

from pytato.array import (
AbstractResultWithNamedArrays,
AdvancedIndexInContiguousAxes,
Array,
ArrayOrScalar,
AxisPermutation,
BasicIndex,
Concatenate,
Expand All @@ -70,18 +72,12 @@
Reshape,
Stack,
)
from pytato.diagnostic import UnknownIndexLambdaExpr
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.raising import (
BinaryOp,
BroadcastOp,
C99CallOp,
FullOp,
ReduceOp,
WhereOp,
index_lambda_to_high_level_op,
from pytato.function import NamedCallResult
from pytato.scalar_expr import (
IDX_LAMBDA_INAME,
CombineMapper,
)
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.transform import ArrayOrNames, CopyMapper, Mapper
from pytato.utils import are_shape_components_equal, are_shapes_equal

Expand All @@ -90,17 +86,63 @@


if TYPE_CHECKING:
from collections.abc import Collection, Mapping
from collections.abc import Collection, Iterable, Mapping

from pytato.function import NamedCallResult
from pytato.loopy import LoopyCall


GraphNodeT = TypeVar("GraphNodeT")

BindingName: TypeAlias = str
P = ParamSpec("P")

BINDING_NAME_RESERVED_PATTERN = re.compile(r"^(_in?(0|([1-9][0-9]*)))$")
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved


# {{{ BindingSubscriptsCollector


class BindingSubscriptsCollector(CombineMapper[dict[BindingName,
set[tuple[prim.Expression, ...]]],
[]]):
"""
Return all the subscript expressions used by a variable specified by BindingName.
Ex:
_in1[_0,_1] would result in an dictionary entry {"_in1": ("_0", "_1")}.
"""
def combine(self,
values: Iterable[dict[BindingName,
set[tuple[prim.Expression, ...]]]]) \
-> dict[BindingName, set[tuple[prim.Expression, ...]]]:
out: dict[BindingName, set[tuple[prim.Expression, ...]]] = {}
from functools import reduce
return reduce(lambda x, y: x | y, values, out)
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved

def map_subscript(self, expr: prim.Subscript) -> dict[BindingName,
set[tuple[prim.Expression, ...]]]:
"""
Record the indexing expression if the Subscript expression has a prim.Variable
as its aggregate.
"""

if isinstance(expr.aggregate, prim.Variable):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing a recursion. Subscripts can contain subscripts.

return {expr.aggregate.name: {expr.index_tuple}}
return {}

def map_algebraic_leaf(self, expr: prim.Expression) -> dict[BindingName,
set[tuple[prim.Expression, ...]]]:

return {}

def map_constant(self, expr: object) -> dict[BindingName,
set[tuple[prim.Expression, ...]]]:
return {}
# }}}

# {{{ AxesTagsEquationCollector


class AxesTagsEquationCollector(Mapper[None, []]):
r"""
Records equations arising from operand/output axes equivalence for an array
Expand Down Expand Up @@ -156,7 +198,7 @@ def __init__(self, tag_t: type[Tag]) -> None:

# axis_to_var: mapping from (array, iaxis) to the variable to be
# used for unification.
self.axis_to_var: bidict[tuple[Array, int], str] = bidict()
self.axis_to_var: bidict[tuple[Array, int | str], str] = bidict()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like you're changing what identifies an axis here. Why? How? Will this still be unique? (The answers to these questions should become documentation.)

self.known_tag_to_var: dict[Tag, str] = {}

self.equations: list[tuple[str, str]] = []
Expand All @@ -165,7 +207,7 @@ def __init__(self, tag_t: type[Tag]) -> None:

# {{{ unification helpers

def get_var_for_axis(self, ary: Array, iaxis: int) -> str:
def get_var_for_axis(self, ary: Array, iaxis: int | str) -> str:
key = (ary, iaxis)

try:
Expand Down Expand Up @@ -227,75 +269,34 @@ def _map_input_base(self, expr: InputArgumentBase) -> None:

def map_index_lambda(self, expr: IndexLambda) -> None:
"""
The propagation semantics for a :class:`~pytato.IndexLambda` are
implemented only for operations that can be raised to a
:class:`~pytato.raising.HighLevelOp`. In such cases, an equality
equation is recorded for every non-broadcasted axis of an operand and
its corresponding axis of *expr*.
Equality conditions are added between an axis of the operands which is indexed
by a :class:`~pymbolic.Variable` which has a name that follows the reserved
iname format, "_[0-9]+", and the axis of the output specified by the iname.
"""
for bnd in expr.bindings.values():
self.rec(bnd)

try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
from warnings import warn
warn(f"'{expr}' is an unknown index lambda type"
" no tags were propagated across it.", stacklevel=1)
# no propagation semantics implemented for such cases
return

if isinstance(hlo, BinaryOp):
subexprs: tuple[ArrayOrScalar, ...] = (hlo.x1, hlo.x2)
elif isinstance(hlo, WhereOp):
subexprs = (hlo.condition, hlo.then, hlo.else_)
elif isinstance(hlo, FullOp):
# A full-op does not impose any equations
subexprs = ()
elif isinstance(hlo, BroadcastOp):
subexprs = (hlo.x,)
elif isinstance(hlo, C99CallOp):
subexprs = hlo.args
elif isinstance(hlo, ReduceOp):

# {{{ ReduceOp doesn't quite involve broadcasting

i_out_axis = 0
for i_in_axis in range(hlo.x.ndim):
if i_in_axis not in hlo.axes:
self.record_equation(
self.get_var_for_axis(hlo.x, i_in_axis),
self.get_var_for_axis(expr, i_out_axis)
)
i_out_axis += 1
index_expr_used = BindingSubscriptsCollector()(expr.expr)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change the variable name to match the terminology of the collector.


assert i_out_axis == expr.ndim
assert len(expr.shape) == expr.ndim
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved

# }}}
for vname, set_of_ind_tuple in index_expr_used.items():
for ind_tuple in set_of_ind_tuple:
for axis_ind, var_ind_name in enumerate(ind_tuple):
if isinstance(var_ind_name, prim.Variable):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an else case here, to document that you've thought about it. E.g.

else:
    raise AssertionError()

lhs: str = self.get_var_for_axis(expr.bindings[vname],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the : str annotations necessary here? (I suspect not.)

axis_ind)
if IDX_LAMBDA_INAME.fullmatch(var_ind_name.name):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an else case here, to document that you've thought about it. E.g.

else:
    raise AssertionError()

nkoskelo marked this conversation as resolved.
Show resolved Hide resolved
# matched with an iname.
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved
inum = int(var_ind_name.name[1:])
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved
rhs: str = self.get_var_for_axis(expr, inum)
self.record_equation(lhs, rhs)
elif var_ind_name.name in expr.var_to_reduction_descr.keys():
# matched with a reduction iname.
rhs = self.get_var_for_axis(expr, var_ind_name.name)
self.record_equation(lhs, rhs)

return

else:
raise NotImplementedError(type(hlo))

for subexpr in subexprs:
if isinstance(subexpr, Array):
for i_in_axis, i_out_axis in zip(
range(subexpr.ndim),
range(expr.ndim-subexpr.ndim, expr.ndim),
strict=True):
in_dim = subexpr.shape[i_in_axis]
out_dim = expr.shape[i_out_axis]
if are_shape_components_equal(in_dim, out_dim):
self.record_equation(
self.get_var_for_axis(subexpr, i_in_axis),
self.get_var_for_axis(expr, i_out_axis)
)
else:
# i_in_axis is broadcasted => do not propagate
assert are_shape_components_equal(in_dim, 1)
else:
assert isinstance(subexpr, SCALAR_CLASSES)
return
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return


def map_stack(self, expr: Stack) -> None:
"""
Expand Down Expand Up @@ -594,10 +595,11 @@ class AxisTagAttacher(CopyMapper):
A mapper that tags the axes in a DAG as prescribed by *axis_to_tags*.
"""
def __init__(self,
axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]],
axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]],
tag_corresponding_redn_descr: bool):
super().__init__()
self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags
self.axis_to_tags: Mapping[tuple[Array, int | str],
Collection[Tag]] = axis_to_tags
self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr

def rec(self, expr: ArrayOrNames) -> Any:
Expand Down Expand Up @@ -634,18 +636,15 @@ def rec(self, expr: ArrayOrNames) -> Any:

if isinstance(expr, IndexLambda):
assert isinstance(expr_copy, IndexLambda)
try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
pass
else:
if isinstance(hlo, ReduceOp):
for iaxis, redn_var in hlo.axes.items():
expr_copy = expr_copy.with_tagged_reduction(
if expr_copy.var_to_reduction_descr:
# This is a reduction operation.
# We need to find the axes that are reduced over
# and update the tag/tag them appropriately.
for redn_var in expr.var_to_reduction_descr.keys():
expr_copy = expr_copy.with_tagged_reduction(
redn_var,
self.axis_to_tags.get((hlo.x, iaxis), [])
self.axis_to_tags.get((expr, redn_var), [])
)

# }}}

self._cache[key] = expr_copy
Expand Down Expand Up @@ -710,7 +709,7 @@ def unify_axes_tags(
)

known_tag_vars = frozenset(equations_collector.known_tag_to_var.values())
axis_to_solved_tags: dict[tuple[Array, int], set[Tag]] = {}
axis_to_solved_tags: dict[tuple[Array, int | str], set[Tag]] = {}

propagation_graph = undirected_graph_from_edges(
equations_collector.equations
Expand All @@ -721,10 +720,10 @@ def unify_axes_tags(
if isinstance(tag, AxisIgnoredForPropagationTag)
})

ignored_vars.update({
ax_var for (ary, ax), ax_var in equations_collector.axis_to_var.items()
if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag)
})
for (ary, ax), ax_var in equations_collector.axis_to_var.items():
if isinstance(ax, int):
if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag):
ignored_vars.update({ax_var})

for tag, var in equations_collector.known_tag_to_var.items():
reachable_nodes = get_reachable_nodes(propagation_graph, var,
Expand Down
14 changes: 7 additions & 7 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,11 +916,11 @@ def test_einsum_with_parameterized_shapes(ctx_factory):
m_in = np.random.randint(2, 20)
n_in = np.random.randint(2, 20)

def _get_a_shape(_m, _n):
return (2*_m+1, 3*_n+7)
def _get_a_shape(m_, n_):
return (2*m_+1, 3*n_+7)

def _get_x_shape(_m, _n):
return (3*_n+7, )
def _get_x_shape(_m, n_):
return (3*n_+7, )

A_in = np.random.rand(*_get_a_shape(m_in, n_in)) # noqa: N806
x_in = np.random.rand(*_get_x_shape(m_in, n_in))
Expand Down Expand Up @@ -1570,16 +1570,16 @@ def test_regression_reduction_in_conditional(ctx_factory):
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

def kernel(usr_np, _pt_data_9):
pt_tmp_53 = _pt_data_9 @ _pt_data_9
def kernel(usr_np, pt_data_9):
pt_tmp_53 = pt_data_9 @ pt_data_9
pt_tmp_42 = usr_np.maximum(pt_tmp_53, pt_tmp_53)
pt_tmp_27 = usr_np.sum(pt_tmp_42)
pt_tmp_0 = usr_np.maximum(pt_tmp_27, pt_tmp_53)
return pt_tmp_0

def get_np_input_args():
return {
"_pt_data_9": np.ones((2, 2)),
"pt_data_9": np.ones((2, 2)),
}

np_inputs = get_np_input_args()
Expand Down
Loading
Loading