Skip to content

Commit

Permalink
Remove index to access descr (#528)
Browse files Browse the repository at this point in the history
* Remove the index_access_descr from Einsum object.

* Ruff corrections.

* Make the index_to_access_descriptor a cached property that will automatically build a dictionary for strings to access descriptor. Note that this may not be the exact same string as what the user passed in pt.einsum, but will be equivalent.

* Add a comment on the new cached property.

* Fixing comments made by ruff.

* Remove the cached property. Add in a warning to users requesting a subscript string.

* Remove duplicated code in generating a einsum string. The numpy code generator returned a string which contained spaces after the punctuation. It now no longer does.

* Numpy code generator seems to be dependent on the spaces after the punctuation. So, I am adding those back in.

* Deprecate the old function get_einsum_subscript_str, in favor of get_einsum_specification. Update the code to use the new function.

* Visualization needed to be updated as well.

* Update pytato/utils.py

Co-authored-by: Andreas Klöckner <[email protected]>

* Update pytato/utils.py

Co-authored-by: Andreas Klöckner <[email protected]>

* Updates target.

* Update the documentation to be more clear.

* Fixed a typo.

* Remove excess assert statement.

* Ensure there is a unique error message for arguments of string type since we are removing that functionality.

---------

Co-authored-by: Andreas Klöckner <[email protected]>
  • Loading branch information
nkoskelo and inducer authored Jul 30, 2024
1 parent a8793a9 commit 0030ec2
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 88 deletions.
28 changes: 7 additions & 21 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,8 +1149,6 @@ class Einsum(_SuppliedAxesAndTagsMixin, Array):
redn_axis_to_redn_descr: Mapping[EinsumReductionAxis,
ReductionDescriptor] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))
index_to_access_descr: Mapping[str, EinsumAxisDescriptor] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))
_mapper_method: ClassVar[str] = "map_einsum"

@memoize_method
Expand Down Expand Up @@ -1200,30 +1198,20 @@ def dtype(self) -> np.dtype[Any]:
return np.result_type(*[arg.dtype for arg in self.args])

def with_tagged_reduction(self,
redn_axis: EinsumReductionAxis | str,
redn_axis: EinsumReductionAxis,
tag: Tag) -> Einsum:
"""
Returns a copy of *self* with the :class:`ReductionDescriptor`
associated with *redn_axis* tagged with *tag*.
"""
from pytato.diagnostic import InvalidEinsumIndex, NotAReductionAxis
# {{{ sanity checks

# {{{ sanity checks
if isinstance(redn_axis, str):
try:
redn_axis_ = self.index_to_access_descr[redn_axis]
except KeyError as err:
raise InvalidEinsumIndex(
f"'{redn_axis}': not a valid axis index.") from err
if isinstance(redn_axis_, EinsumReductionAxis):
redn_axis = redn_axis_
else:
raise NotAReductionAxis(f"'{redn_axis}' is not"
" a reduction axis.")
elif isinstance(redn_axis, EinsumReductionAxis):
pass
else:
raise TypeError("Argument 'redn_axis' expected to be"
raise TypeError("Argument `redn_axis' as a string is no longer"
" accepted as a valid index type."
" Use the actual EinsumReductionAxis object instead.")
elif not isinstance(redn_axis, EinsumReductionAxis):
raise TypeError(f"Argument `redn_axis' expected to be"
f" EinsumReductionAxis, got {type(redn_axis)}")

if redn_axis in self.redn_axis_to_redn_descr:
Expand All @@ -1246,7 +1234,6 @@ def with_tagged_reduction(self,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
index_to_access_descr=self.index_to_access_descr,
non_equality_tags=self.non_equality_tags,
)

Expand Down Expand Up @@ -1453,7 +1440,6 @@ def einsum(subscripts: str, *operands: Array,
EinsumElementwiseAxis)})
),
redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr),
index_to_access_descr=index_to_descr,
non_equality_tags=_get_created_at_tag(),
)

Expand Down
28 changes: 2 additions & 26 deletions pytato/target/python/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.target.python import BoundPythonProgram, NumpyLikePythonTarget
from pytato.transform import CachedMapper
from pytato.utils import are_shape_components_equal
from pytato.utils import are_shape_components_equal, get_einsum_specification


T = TypeVar("T")
Expand Down Expand Up @@ -124,30 +124,6 @@ def first_true(iterable: Iterable[T], default: T,
return next(filter(pred, iterable), default)


def _get_einsum_subscripts(einsum: Einsum) -> str:
from pytato.array import EinsumAxisDescriptor, EinsumElementwiseAxis

idx_stream = (chr(i) for i in range(ord("i"), ord("z")))
idx_gen: Callable[[], str] = lambda: next(idx_stream) # noqa: E731
axis_descr_to_idx: dict[EinsumAxisDescriptor, str] = {}
input_specs = []
for access_descr in einsum.access_descriptors:
spec = ""
for axis_descr in access_descr:
try:
spec += axis_descr_to_idx[axis_descr]
except KeyError:
axis_descr_to_idx[axis_descr] = idx_gen()
spec += axis_descr_to_idx[axis_descr]

input_specs.append(spec)

output_spec = "".join(axis_descr_to_idx[EinsumElementwiseAxis(i)]
for i in range(einsum.ndim))

return f"{', '.join(input_specs)} -> {output_spec}"


def _is_slice_trivial(slice_: NormalizedSlice,
dim: ShapeComponent) -> bool:
"""
Expand Down Expand Up @@ -520,7 +496,7 @@ def map_einsum(self, expr: Einsum) -> str:
lhs = self.vng("_pt_tmp")
args = [ast.Name(self.rec(arg)) for arg in expr.args]
rhs = ast.Call(ast.Attribute(ast.Name(self.numpy_backend), "einsum"),
args=[ast.Constant(_get_einsum_subscripts(expr)),
args=[ast.Constant(get_einsum_specification(expr)),
*args],
keywords=[],
)
Expand Down
3 changes: 0 additions & 3 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ def map_einsum(self, expr: Einsum) -> Array:
tuple(self.rec(arg) for arg in expr.args),
axes=expr.axes,
redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr,
index_to_access_descr=expr.index_to_access_descr,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

Expand Down Expand Up @@ -615,7 +614,6 @@ def map_einsum(self, expr: Einsum, *args: Any, **kwargs: Any) -> Array:
tuple(self.rec(arg, *args, **kwargs) for arg in expr.args),
axes=expr.axes,
redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr,
index_to_access_descr=expr.index_to_access_descr,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

Expand Down Expand Up @@ -1478,7 +1476,6 @@ def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator:
new_expr = Einsum(expr.access_descriptors,
tuple(ary.expr for ary in rec_arrays),
expr.redn_axis_to_redn_descr,
expr.index_to_access_descr,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)
Expand Down
4 changes: 0 additions & 4 deletions pytato/transform/einsum_distributive_law.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ class _EinsumDistributiveLawMapperContext:
surrounding_args: Mapping[int, Array]
redn_axis_to_redn_descr: Mapping[EinsumReductionAxis,
ReductionDescriptor]
index_to_access_descr: Mapping[str, EinsumAxisDescriptor]
axes: AxesT = attrs.field(kw_only=True)
tags: frozenset[Tag] = attrs.field(kw_only=True)

Expand Down Expand Up @@ -122,7 +121,6 @@ def _wrap_einsum_from_ctx(expr: Array,
ctx.access_descriptors,
new_args,
ctx.redn_axis_to_redn_descr,
ctx.index_to_access_descr,
tags=ctx.tags,
axes=ctx.axes
)
Expand Down Expand Up @@ -266,7 +264,6 @@ def map_einsum(self,
for iarg, arg in enumerate(expr.args)
if iarg != distributive_law_descr.ioperand}),
immutabledict(expr.redn_axis_to_redn_descr),
immutabledict(expr.index_to_access_descr),
tags=expr.tags,
axes=expr.axes,
)
Expand All @@ -277,7 +274,6 @@ def map_einsum(self,
expr.access_descriptors,
tuple(self.rec(arg, None) for arg in expr.args),
expr.redn_axis_to_redn_descr,
index_to_access_descr=expr.index_to_access_descr,
tags=expr.tags,
axes=expr.axes
)
Expand Down
1 change: 0 additions & 1 deletion pytato/transform/remove_broadcasts_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def map_einsum(self, expr: Einsum) -> Array:
return Einsum(tuple(new_access_descriptors),
tuple(new_args),
expr.redn_axis_to_redn_descr,
expr.index_to_access_descr,
tags=expr.tags,
axes=expr.axes,)

Expand Down
72 changes: 55 additions & 17 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,9 +658,10 @@ def get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[ArrayOrScalar]

def get_einsum_subscript_str(expr: Einsum) -> str:
"""
Returns the index subscript expression that was used in constructing *expr*
using the :func:`pytato.einsum` routine.
Returns the index subscript expression that can be
used in constructing *expr* using the :func:`pytato.einsum` routine.
Deprecated: use get_einsum_specification_str instead.
.. testsetup::
Expand All @@ -672,28 +673,65 @@ def get_einsum_subscript_str(expr: Einsum) -> str:
>>> A = pt.make_placeholder("A", (10, 6), np.float64)
>>> B = pt.make_placeholder("B", (6, 5), np.float64)
>>> C = pt.make_placeholder("B", (5, 4), np.float64)
>>> C = pt.make_placeholder("C", (5, 4), np.float64)
>>> ABC = pt.einsum("ij,jk,kl->il", A, B, C)
>>> get_einsum_subscript_str(ABC)
'ij,jk,kl->il'
"""
from pytato.array import EinsumElementwiseAxis
from warnings import warn

acc_descr_to_index = {
acc_descr: idx
for idx, acc_descr in expr.index_to_access_descr.items()
}
warn("get_einsum_subscript_str has been deprecated and will be removed in "
" Oct 2024. Use get_einsum_specification instead.",
DeprecationWarning, stacklevel=2)

output_subscripts = "".join(
[acc_descr_to_index[EinsumElementwiseAxis(idim)]
for idim in range(expr.ndim)]
)
arg_subscripts: list[str] = []
return get_einsum_specification(expr)

for acc_descrs in expr.access_descriptors:
arg_subscripts.append("".join(acc_descr_to_index[acc_descr]
for acc_descr in acc_descrs))

return f"{','.join(arg_subscripts)}->{output_subscripts}"
def get_einsum_specification(expr: Einsum) -> str:
"""
Returns the index subscript expression that can be
used in constructing *expr* using the :func:`pytato.einsum` routine.
Note this function may not return the exact same string as the
string you input as part of a call to :func:`pytato.einsum'.
Instead you will get a canonical version of the specification
starting the indices with the letter 'i'.
.. testsetup::
>>> import pytato as pt
>>> import numpy as np
>>> from pytato.utils import get_einsum_subscript_str
.. doctest::
>>> A = pt.make_placeholder("A", (10, 6), np.float64)
>>> B = pt.make_placeholder("B", (6, 5), np.float64)
>>> C = pt.make_placeholder("C", (5, 4), np.float64)
>>> ABC = pt.einsum("ab,bc,cd->ad", A, B, C)
>>> get_einsum_subscript_str(ABC)
'ij,jk,kl->il'
"""

from pytato.array import EinsumAxisDescriptor, EinsumElementwiseAxis

index_letters = (chr(i) for i in range(ord("i"), ord("z")))
axis_descr_to_idx: dict[EinsumAxisDescriptor, str] = {}
input_specs = []
for access_descr in expr.access_descriptors:
spec = ""
for axis_descr in access_descr:
try:
spec += axis_descr_to_idx[axis_descr]
except KeyError:
axis_descr_to_idx[axis_descr] = next(index_letters)
spec += axis_descr_to_idx[axis_descr]

input_specs.append(spec)

output_spec = "".join(axis_descr_to_idx[EinsumElementwiseAxis(i)]
for i in range(expr.ndim))

return f"{','.join(input_specs)}->{output_spec}"
# vim: fdm=marker
4 changes: 2 additions & 2 deletions pytato/visualization/fancy_placeholder_data_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def map_index_lambda(self, expr: IndexLambda) -> _FancyDotWriterNode:
return ret_node

def map_einsum(self, expr: Einsum) -> _FancyDotWriterNode:
from pytato.utils import get_einsum_subscript_str
from pytato.utils import get_einsum_specification

ensm_spec = get_einsum_subscript_str(expr)
ensm_spec = get_einsum_specification(expr)
node_id = self.vng("_pt_ensm")
spec = ensm_spec.replace("->", "→")
node_decl = (f'{node_id} [label="{spec}",'
Expand Down
1 change: 1 addition & 0 deletions test/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_einsum(spec, argshapes, jit):
np.testing.assert_allclose(np_out, pt_out)


# Ignore deprecation warnings starting with get_einsum_subscript_str
@pytest.mark.parametrize("jit", ([False, True]))
def test_random_dag_against_numpy(jit):
from testlib import RandomDAGContext, make_random_dag
Expand Down
21 changes: 7 additions & 14 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ def test_expand_dims_input_validate():
def test_with_tagged_reduction():
from testlib import FooRednTag

from pytato.diagnostic import InvalidEinsumIndex, NotAReductionAxis
from pytato.diagnostic import NotAReductionAxis
from pytato.raising import index_lambda_to_high_level_op
x = pt.make_placeholder("x", shape=(10, 10))
x_sum = pt.sum(x)
Expand All @@ -1022,25 +1022,18 @@ def test_with_tagged_reduction():
assert x_sum.var_to_reduction_descr[hlo.axes[1]].tags_of_type(FooRednTag)
assert not x_sum.var_to_reduction_descr[hlo.axes[0]].tags_of_type(FooRednTag)

x_trace = pt.einsum("ii->i", x)
x_colsum = pt.einsum("ij->j", x)

with pytest.raises(NotAReductionAxis):
# 'j': not being reduced over.
with pytest.raises(TypeError):
# no longer support indexing by string.
x_colsum.with_tagged_reduction("j", FooRednTag())

with pytest.raises(InvalidEinsumIndex):
# 'k': unknown axis
x_colsum.with_tagged_reduction("k", FooRednTag())

with pytest.raises(NotAReductionAxis):
# 'i': not being reduced over.
x_trace.with_tagged_reduction("i", FooRednTag())

x_colsum = x_colsum.with_tagged_reduction("i", FooRednTag())
my_descr = x_colsum.access_descriptors[0][0]
x_colsum = x_colsum.with_tagged_reduction(my_descr,
FooRednTag())

assert (x_colsum
.redn_axis_to_redn_descr[x_colsum.index_to_access_descr["i"]]
.redn_axis_to_redn_descr[my_descr]
.tags_of_type(FooRednTag))


Expand Down

0 comments on commit 0030ec2

Please sign in to comment.