Skip to content

Commit

Permalink
[no ci] WIP: some progress on generalizing pushing indirections
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Apr 3, 2024
1 parent 92125af commit d71dc04
Showing 1 changed file with 35 additions and 137 deletions.
172 changes: 35 additions & 137 deletions pytato/transform/indirections.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import sys
from typing import (Any, Dict, Mapping, Tuple, TypeAlias, Iterable,
FrozenSet, Union, Set, List, Optional, Callable)
from pytato.array import (Array, InputArgumentBase, DictOfNamedArrays,
Expand All @@ -38,9 +40,13 @@
from immutables import Map
from pytato.utils import are_shape_components_equal

if sys.version >= (3, 11):
zip_equal = lambda *_args: zip(*_args, strict=True)
else:
from more_itertools import zip_equal

_ComposedIndirectionT: TypeAlias = Tuple[Array, ...]
IndexT: TypeAlias = Union[Array, NormalizedSlice]
IndexStackT: TypeAlias = Tuple[IndexT, ...]


def _is_materialized(expr: Array) -> bool:
Expand All @@ -53,15 +59,15 @@ def _is_materialized(expr: Array) -> bool:
or bool(expr.tags_of_type(ImplStored)))


def _is_trivial_slice(dim: ShapeComponent, slice_: NormalizedSlice) -> bool:
def _is_trivial_slice(dim: ShapeComponent, slice_: IndexT) -> bool:
"""
Returns *True* only if *slice_* indexes an entire axis of shape *dim* with
a step of 1.
"""
return (slice_.step == 1
return (isinstance(slice_, NormalizedSlice)
and slice_.step == 1
and are_shape_components_equal(slice_.start, 0)
and are_shape_components_equal(slice_.stop, dim)
)
and are_shape_components_equal(slice_.stop, dim))


def _take_along_axis(ary: Array, iaxis: int, idxs: IndexStackT) -> Array:
Expand Down Expand Up @@ -427,35 +433,35 @@ class _IndirectionPusher(Mapper):

def __init__(self) -> None:
self.get_reordarable_axes = _LegallyAxisReorderingFinder()
self._cache: Dict[Tuple[ArrayOrNames, Map[int, IndexStackT]],
self._cache: Dict[Tuple[ArrayOrNames, Map[int, IndexT]],
ArrayOrNames] = {}
super().__init__()

def rec(self, # type: ignore[override]
expr: MappedT,
index_stacks: Map[int, IndexStackT]) -> MappedT:
key = (expr, index_stacks)
indices: Tuple[IndexT, ...]) -> MappedT:
assert len(indices) == expr.ndim
key = (expr, indices)
try:
# type-ignore-reason: parametric mapping types aren't a thing in 'typing'
return self._cache[key] # type: ignore[return-value]
except KeyError:
result = Mapper.rec(self, expr, index_stacks)
result = Mapper.rec(self, expr, indices)
self._cache[key] = result
return result # type: ignore[no-any-return]

def __call__(self, # type: ignore[override]
expr: MappedT,
index_stacks: Map[int, IndexStackT]) -> MappedT:
return self.rec(expr, index_stacks)
indices: Map[int, IndexT]) -> MappedT:
return self.rec(expr, indices)

def _map_materialized(self,
expr: Array,
index_stacks: Map[int, IndexStackT]) -> Array:
result = expr
for iaxis, idxs in index_stacks.items():
result = _take_along_axis(result, iaxis, idxs)

return result
indices: Tuple[IndexT, ...]) -> Array:
if all(_is_trivial_slice(dim, idx)
for dim, idx in zip(expr.shape, indices)):
return expr
return expr[*indices]

def map_dict_of_named_arrays(self,
expr: DictOfNamedArrays,
Expand All @@ -467,9 +473,12 @@ def map_dict_of_named_arrays(self,

def map_index_lambda(self,
expr: IndexLambda,
index_stacks: Map[int, IndexStackT]
indices: Tuple[IndexT, ...],
) -> Array:
if _is_materialized(expr):
# FIXME: Move this logic to .rec (Why on earth do we need)
# to copy the damn node???

# do not propagate the indexings to the bindings.
expr = IndexLambda(expr.expr,
expr.shape,
Expand All @@ -478,9 +487,13 @@ def map_index_lambda(self,
for name, bnd in expr.bindings.items()}),
expr.var_to_reduction_descr,
tags=expr.tags,
axes=expr.axes,
)
return self._map_materialized(expr, index_stacks)
axes=expr.axes,)
return self._map_materialized(expr, indices)

# FIXME:
# This is the money shot. Over here we need to figure out the index
# propagation logic.


iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr)

Expand Down Expand Up @@ -886,128 +899,13 @@ def push_axis_indirections_towards_materialized_nodes(expr: MappedT
) -> MappedT:
"""
Returns a copy of *expr* with the indirections propagated closer to the
materialized nodes. We propagate an indirections only if the indirection in
an :class:`~pytato.array.AdvancedIndexInContiguousAxes` or
:class:`~pytato.array.AdvancedIndexInNoncontiguousAxes` is an indirection
over a single axis.
materialized nodes.
"""
mapper = _IndirectionPusher()

return mapper(expr, Map())


def _get_unbroadcasted_axis_in_indirections(
expr: AdvancedIndexInContiguousAxes) -> Optional[Mapping[int, int]]:
"""
Returns a mapping from the index of an indirection to its *only*
unbroadcasted axis as required by the logic. Returns *None* if no such
mapping exists.
"""
from pytato.utils import partition, get_shape_after_broadcasting
adv_indices, _ = partition(lambda i: isinstance(expr.indices[i],
NormalizedSlice),
range(expr.array.ndim))
i_ary_indices = [i_idx
for i_idx, idx in enumerate(expr.indices)
if isinstance(idx, Array)]

adv_idx_shape = get_shape_after_broadcasting([expr.indices[i_idx]
for i_idx in adv_indices])

if len(adv_idx_shape) != len(i_ary_indices):
return None

i_adv_out_axis_to_candidate_i_arys: Dict[int, Set[int]] = {
idim: set()
for idim, _ in enumerate(adv_idx_shape)
}

for i_ary_idx in i_ary_indices:
ary = expr.indices[i_ary_idx]
assert isinstance(ary, Array)
for iadv_out_axis, i_ary_axis in zip(range(len(adv_idx_shape)-1, -1, -1),
range(ary.ndim-1, -1, -1)):
if are_shape_components_equal(adv_idx_shape[iadv_out_axis],
ary.shape[i_ary_axis]):
i_adv_out_axis_to_candidate_i_arys[iadv_out_axis].add(i_ary_idx)

from itertools import permutations
# FIXME: O(expr.ndim!) complexity, typically ndim <= 4 so this should be fine.
for guess_i_adv_out_axis_to_i_ary in permutations(range(len(i_ary_indices))):
if all(i_ary in i_adv_out_axis_to_candidate_i_arys[i_adv_out]
for i_adv_out, i_ary in enumerate(guess_i_adv_out_axis_to_i_ary)):
# TODO: Return the mapping here...
i_ary_to_unbroadcasted_axis: Dict[int, int] = {}
for guess_i_adv_out_axis, i_ary_idx in enumerate(
guess_i_adv_out_axis_to_i_ary):
ary = expr.indices[i_ary_idx]
assert isinstance(ary, Array)
iunbroadcasted_axis, = [
i_ary_axis
for i_adv_out_axis, i_ary_axis in zip(
range(len(adv_idx_shape)-1, -1, -1),
range(ary.ndim-1, -1, -1))
if i_adv_out_axis == guess_i_adv_out_axis
]
i_ary_to_unbroadcasted_axis[i_ary_idx] = iunbroadcasted_axis

return Map(i_ary_to_unbroadcasted_axis)

return None


class MultiAxisIndirectionsDecoupler(CopyMapper):
def map_contiguous_advanced_index(self,
expr: AdvancedIndexInContiguousAxes
) -> Array:
i_ary_idx_to_unbroadcasted_axis = _get_unbroadcasted_axis_in_indirections(
expr)

if i_ary_idx_to_unbroadcasted_axis is not None:
from pytato.utils import partition
i_adv_indices, _ = partition(lambda idx: isinstance(expr.indices[idx],
NormalizedSlice),
range(len(expr.indices)))

result = self.rec(expr.array)

for iaxis, idx in enumerate(expr.indices):
if isinstance(idx, Array):
from pytato.array import squeeze
axes_to_squeeze = [
idim
for idim in range(expr
.indices[iaxis] # type: ignore[union-attr]
.ndim)
if idim != i_ary_idx_to_unbroadcasted_axis[iaxis]]
if axes_to_squeeze:
idx = squeeze(idx, axis=axes_to_squeeze)
if not (isinstance(idx, NormalizedSlice)
and _is_trivial_slice(expr.array.shape[iaxis], idx)):
result = result[
(slice(None),) * iaxis + (idx, )] # type: ignore[operator]

return result
else:
return super().map_contiguous_advanced_index(expr)


def decouple_multi_axis_indirections_into_single_axis_indirections(
expr: MappedT) -> MappedT:
"""
Returns a copy of *expr* with multiple indirections in an
:class:`~pytato.array.AdvancedIndexInContiguousAxes` decoupled as a
composition of indexing nodes with single-axis indirections.
.. note::
This is a dependency preserving transformation. If a decoupling an
advanced indexing node is not legal, we leave the node unmodified.
"""
mapper = MultiAxisIndirectionsDecoupler()
return mapper(expr)


# {{{ fold indirection constants

class _ConstantIndirectionArrayCollector(CombineMapper[FrozenSet[Array]]):
Expand Down

0 comments on commit d71dc04

Please sign in to comment.