diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fadf1c92d..a037c491a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,6 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( Array, @@ -46,7 +45,7 @@ ) from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall -from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper +from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper, _SelfMapper if TYPE_CHECKING: @@ -410,9 +409,14 @@ class NodeCountMapper(CachedWalkMapper): Dictionary mapping node types to number of nodes of that type. """ - def __init__(self, count_duplicates: bool = False) -> None: + def __init__( + self, + count_duplicates: bool = False, + _visited_functions: set[Any] | None = None, + ) -> None: + super().__init__(_visited_functions=_visited_functions) + from collections import defaultdict - super().__init__() self.expr_type_counts: dict[type[Any], int] = defaultdict(int) self.count_duplicates = count_duplicates @@ -420,10 +424,23 @@ def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: # Returns unique nodes only if count_duplicates is False return id(expr) if self.count_duplicates else expr + def get_function_definition_cache_key( + self, expr: FunctionDefinition) -> int | FunctionDefinition: + # Returns unique nodes only if count_duplicates is False + return id(expr) if self.count_duplicates else expr + def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 + def clone_for_callee( + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ + return type(self)( + count_duplicates=self.count_duplicates, # type: ignore[attr-defined] + _visited_functions=self._visited_functions) # type: ignore[call-arg,attr-defined] + def get_node_type_counts( outputs: Array | DictOfNamedArrays, @@ -485,15 +502,20 @@ class NodeMultiplicityMapper(CachedWalkMapper): .. autoattribute:: expr_multiplicity_counts """ - def __init__(self) -> None: + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + from collections import defaultdict - super().__init__() self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: # Returns each node, including nodes that are duplicates return id(expr) + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + # Returns each node, including nodes that are duplicates + return id(expr) + def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_multiplicity_counts[expr] += 1 @@ -527,14 +549,16 @@ class CallSiteCountMapper(CachedWalkMapper): The number of nodes. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.count = 0 def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - @memoize_method + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return diff --git a/pytato/codegen.py b/pytato/codegen.py index 0e1126289..78e02c107 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -24,7 +24,7 @@ """ import dataclasses -from typing import Any, Mapping, Tuple +from typing import Any, Hashable, Mapping, Tuple from immutabledict import immutabledict @@ -42,7 +42,7 @@ SizeParam, make_dict_of_named_arrays, ) -from pytato.function import NamedCallResult +from pytato.function import FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall from pytato.scalar_expr import IntegralScalarExpression from pytato.target import Target @@ -118,10 +118,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] ====================================== ===================================== """ - def __init__(self, target: Target, - kernels_seen: dict[str, lp.LoopKernel] | None = None - ) -> None: - super().__init__() + def __init__( + self, + target: Target, + kernels_seen: dict[str, lp.LoopKernel] | None = None, + _function_cache: dict[Hashable, FunctionDefinition] | None = None + ) -> None: + super().__init__(_function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target @@ -247,13 +250,16 @@ def normalize_outputs( @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NamesValidityChecker(CachedWalkMapper): - def __init__(self) -> None: + def __init__(self, _visited_functions: set[Any] | None = None) -> None: self.name_to_input: dict[str, InputArgumentBase] = {} - super().__init__() + super().__init__(_visited_functions=_visited_functions) def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def post_visit(self, expr: Any) -> None: if isinstance(expr, (Placeholder, SizeParam, DataWrapper)): if expr.name is not None: diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 5865ec491..5040557f1 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -292,8 +292,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], + _function_cache: dict[Hashable, FunctionDefinition] | None = None, ) -> None: - super().__init__() + super().__init__(_function_cache=_function_cache) self.recvd_ary_to_name = recvd_ary_to_name self.sptpo_ary_to_name = sptpo_ary_to_name @@ -307,7 +308,7 @@ def clone_for_callee( self, function: FunctionDefinition) -> _DistributedInputReplacer: # Function definitions aren't allowed to contain receives, # stored arrays promoted to part outputs, or part outputs - return type(self)({}, {}, {}) + return type(self)({}, {}, {}, _function_cache=self._function_cache) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index 730cf346c..0afb5add5 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -142,8 +142,8 @@ class MissingRecvError(DistributedPartitionVerificationError): @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class _SeenNodesWalkMapper(CachedWalkMapper): - def __init__(self) -> None: - super().__init__() + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.seen_nodes: set[ArrayOrNames] = set() def get_cache_key(self, expr: ArrayOrNames) -> int: diff --git a/pytato/equality.py b/pytato/equality.py index 79d038d72..87f613397 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -27,8 +27,6 @@ from typing import TYPE_CHECKING, Any, Callable, Union -from pytools import memoize_method - from pytato.array import ( AbstractResultWithNamedArrays, AdvancedIndexInContiguousAxes, @@ -83,19 +81,23 @@ class EqualityComparer: more on this. """ def __init__(self) -> None: + # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], bool] = {} - def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool: + def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool: cache_key = id(expr1), id(expr2) try: return self._cache[cache_key] except KeyError: - - method: Callable[[Array | AbstractResultWithNamedArrays, Any], - bool] + method: Callable[ + [Array | AbstractResultWithNamedArrays | FunctionDefinition, Any], + bool] try: - method = getattr(self, expr1._mapper_method) + method = ( + getattr(self, expr1._mapper_method) + if isinstance(expr1, (Array, AbstractResultWithNamedArrays)) + else self.map_function_definition) except AttributeError: if isinstance(expr1, Array): result = self.handle_unsupported_array(expr1, expr2) @@ -293,7 +295,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.tags == expr2.tags ) - @memoize_method def map_function_definition(self, expr1: FunctionDefinition, expr2: Any ) -> bool: return (expr1.__class__ is expr2.__class__ @@ -307,7 +308,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any def map_call(self, expr1: Call, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ - and self.map_function_definition(expr1.function, expr2.function) + and self.rec(expr1.function, expr2.function) and frozenset(expr1.bindings) == frozenset(expr2.bindings) and all(self.rec(bnd, expr2.bindings[name]) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 4f39713cf..5b6d5cf0b 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -31,8 +31,6 @@ import numpy as np from immutabledict import immutabledict -from pytools import memoize_method - from pytato.array import ( Array, Axis, @@ -68,6 +66,7 @@ def __init__(self, self.truncation_depth = truncation_depth self.truncation_string = truncation_string + # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], str] = {} def rec(self, expr: Any, depth: int) -> str: @@ -79,6 +78,15 @@ def rec(self, expr: Any, depth: int) -> str: self._cache[cache_key] = result return result # type: ignore[no-any-return] + def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str: + cache_key = (id(expr), depth) + try: + return self._cache[cache_key] + except KeyError: + result = super().rec_function_definition(expr, depth) + self._cache[cache_key] = result + return result # type: ignore[no-any-return] + def __call__(self, expr: Any, depth: int = 0) -> str: return self.rec(expr, depth) @@ -168,7 +176,6 @@ def _get_field_val(field: str) -> str: for field in attrs.fields(type(expr))) + ")") - @memoize_method def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string @@ -191,7 +198,7 @@ def map_call(self, expr: Call, depth: int) -> str: def _get_field_val(field: str) -> str: if field == "function": - return self.map_function_definition(expr.function, depth+1) + return self.rec_function_definition(expr.function, depth+1) else: return self.rec(getattr(expr, field), depth+1) diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 54b6328dc..af9595075 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -64,6 +64,7 @@ SizeParam, Stack, ) +from pytato.function import FunctionDefinition from pytato.raising import BinaryOpType, C99CallOp from pytato.reductions import ( AllReductionOperation, @@ -169,7 +170,7 @@ def _is_slice_trivial(slice_: NormalizedSlice, } -class NumpyCodegenMapper(CachedMapper[str]): +class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition]): """ .. note:: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index ea790e9b6..418e5a1bc 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1,7 +1,5 @@ from __future__ import annotations -from pytools import memoize_method - __copyright__ = """ Copyright (C) 2020 Matt Wala @@ -92,6 +90,7 @@ TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper Array, AbstractResultWithNamedArrays, ArrayOrNames) CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper +CachedMapperFunctionT = TypeVar("CachedMapperFunctionT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = FrozenSet[Array] _SelfMapper = TypeVar("_SelfMapper", bound="Mapper") @@ -149,10 +148,15 @@ A type variable representing the type of a :class:`CombineMapper`. +.. class:: CachedMapperFunctionT + + A type variable used to represent the output type of a :class:`CachedMapper` + for :class:`pytato.function.FunctionDefinition`. + .. class:: _SelfMapper A type variable used to represent the type of a mapper in - :meth:`TransformMapper.clone_for_callee`. + :meth:`CachedMapper.clone_for_callee`. """ transform_logger = logging.getLogger(__file__) @@ -199,7 +203,7 @@ def map_foreign(self, expr: Any, *args: Any, **kwargs: Any) -> Any: def rec(self, expr: MappedT, *args: Any, **kwargs: Any) -> Any: """Call the mapper method of *expr* and return the result.""" - method: Callable[..., Array] | None + method: Callable[..., Any] | None try: method = getattr(self, expr._mapper_method) @@ -219,6 +223,20 @@ def rec(self, expr: MappedT, *args: Any, **kwargs: Any) -> Any: assert method is not None return method(expr, *args, **kwargs) + def rec_function_definition( + self, expr: FunctionDefinition, *args: Any, **kwargs: Any + ) -> Any: + """Call the mapper method of *expr* and return the result.""" + method: Callable[..., Any] | None + + try: + method = self.map_function_definition # type: ignore[attr-defined] + except AttributeError: + return self.map_foreign(expr, *args, **kwargs) + + assert method is not None + return method(expr, *args, **kwargs) + def __call__(self, expr: MappedT, *args: Any, **kwargs: Any) -> Any: """Handle the mapping of *expr*.""" return self.rec(expr, *args, **kwargs) @@ -228,21 +246,37 @@ def __call__(self, expr: MappedT, *args: Any, **kwargs: Any) -> Any: # {{{ CachedMapper -class CachedMapper(Mapper, Generic[CachedMapperT]): +class CachedMapper(Mapper, Generic[CachedMapperT, CachedMapperFunctionT]): """Mapper class that maps each node in the DAG exactly once. This loses some information compared to :class:`Mapper` as a node is visited only from one of its predecessors. .. automethod:: get_cache_key + .. automethod:: get_function_definition_cache_key + .. automethod:: clone_for_callee """ - def __init__(self) -> None: + def __init__( + self, + # Arrays are cached separately for each call stack frame, but + # functions are cached globally + _function_cache: dict[Hashable, CachedMapperFunctionT] | None = None + ) -> None: super().__init__() self._cache: dict[Hashable, CachedMapperT] = {} + if _function_cache is None: + _function_cache = {} + + self._function_cache: dict[Hashable, CachedMapperFunctionT] = _function_cache + def get_cache_key(self, expr: ArrayOrNames) -> Hashable: return expr + def get_function_definition_cache_key( + self, expr: FunctionDefinition) -> Hashable: + return expr + def rec(self, expr: ArrayOrNames) -> CachedMapperT: key = self.get_cache_key(expr) try: @@ -253,6 +287,27 @@ def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type-ignore-reason: Mapper.rec has imprecise func. signature return result # type: ignore[no-any-return] + def rec_function_definition( + self, expr: FunctionDefinition) -> CachedMapperFunctionT: + key = self.get_function_definition_cache_key(expr) + try: + return self._function_cache[key] + except KeyError: + result = super().rec_function_definition(expr) + self._function_cache[key] = result + return result # type: ignore[no-any-return] + + def clone_for_callee( + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ + return type(self)( + _function_cache=self._function_cache) # type: ignore[call-arg,attr-defined] + if TYPE_CHECKING: def __call__(self, expr: ArrayOrNames) -> CachedMapperT: return self.rec(expr) @@ -262,15 +317,13 @@ def __call__(self, expr: ArrayOrNames) -> CachedMapperT: # {{{ TransformMapper -class TransformMapper(CachedMapper[ArrayOrNames]): +class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition]): """Base class for mappers that transform :class:`pytato.array.Array`\\ s into other :class:`pytato.array.Array`\\ s. Enables certain operations that can only be done if the mapping results are also arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. - - .. automethod:: clone_for_callee """ if TYPE_CHECKING: def rec(self, expr: TransformMapperResultT) -> TransformMapperResultT: @@ -279,45 +332,55 @@ def rec(self, expr: TransformMapperResultT) -> TransformMapperResultT: def __call__(self, expr: TransformMapperResultT) -> TransformMapperResultT: return self.rec(expr) - def clone_for_callee( - self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - return type(self)() - # }}} # {{{ TransformMapperWithExtraArgs -class TransformMapperWithExtraArgs(CachedMapper[ArrayOrNames]): +class TransformMapperWithExtraArgs(CachedMapper[ArrayOrNames, FunctionDefinition]): """ Similar to :class:`TransformMapper`, but each mapper method takes extra ``*args``, ``**kwargs`` that are propagated along a path by default. The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. - - .. automethod:: clone_for_callee """ - def __init__(self) -> None: - super().__init__() + def __init__( + self, + _function_cache: dict[Hashable, FunctionDefinition] | None = None + ) -> None: + super().__init__(_function_cache=_function_cache) # type-ignored as '._cache' attribute is not coherent with the base # class - self._cache: dict[tuple[ArrayOrNames, - tuple[Any, ...], - tuple[tuple[str, Any], ...] - ], - ArrayOrNames] = {} # type: ignore[assignment] - - def get_cache_key(self, - expr: ArrayOrNames, - *args: Any, **kwargs: Any) -> tuple[ArrayOrNames, - tuple[Any, ...], - tuple[tuple[str, Any], ...] - ]: + self._cache: dict[ + tuple[ + ArrayOrNames, + tuple[Any, ...], + tuple[tuple[str, Any], ...]], + ArrayOrNames] = {} # type: ignore[assignment] + # type-ignored as '._function_cache' attribute is not coherent with the base + # class + self._function_cache: dict[ + tuple[ + FunctionDefinition, + tuple[Any, ...], + tuple[tuple[str, Any], ...]], + FunctionDefinition] = self._function_cache # type: ignore[assignment] + + def get_cache_key( + self, expr: ArrayOrNames, *args: Any, **kwargs: Any + ) -> tuple[ + ArrayOrNames, + tuple[Any, ...], + tuple[tuple[str, Any], ...]]: + return (expr, args, tuple(sorted(kwargs.items()))) + + def get_function_definition_cache_key( + self, expr: FunctionDefinition, *args: Any, **kwargs: Any + ) -> tuple[ + FunctionDefinition, + tuple[Any, ...], + tuple[tuple[str, Any], ...]]: return (expr, args, tuple(sorted(kwargs.items()))) def rec(self, @@ -335,13 +398,17 @@ def rec(self, # type-ignore-reason: Mapper.rec is imprecise return result # type: ignore[no-any-return] - def clone_for_callee( - self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - return type(self)() + def rec_function_definition( + self, + expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> FunctionDefinition: + key = self.get_function_definition_cache_key(expr, *args, **kwargs) + try: + return self._function_cache[key] + except KeyError: + result = Mapper.rec_function_definition(self, expr, *args, **kwargs) + self._function_cache[key] = result + return result # type: ignore[no-any-return] # }}} @@ -517,18 +584,19 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array: dtype=expr.dtype, tags=expr.tags, axes=expr.axes, non_equality_tags=expr.non_equality_tags) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: # spawn a new mapper to avoid unsound cache hits, since the namespace of the # function's body is different from that of the caller. + # FIXME: Clone in rec_function_definition prior to calling + # map_function_definition? new_mapper = self.clone_for_callee(expr) new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} return attrs.evolve(expr, returns=immutabledict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return Call(self.map_function_definition(expr.function), + return Call(self.rec_function_definition(expr.function), immutabledict({name: self.rec(bnd) for name, bnd in expr.bindings.items()}), tags=expr.tags, @@ -731,7 +799,7 @@ def map_function_definition(self, expr: FunctionDefinition, def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: - return Call(self.map_function_definition(expr.function, *args, **kwargs), + return Call(self.rec_function_definition(expr.function, *args, **kwargs), immutabledict({name: self.rec(bnd, *args, **kwargs) for name, bnd in expr.bindings.items()}), tags=expr.tags, @@ -758,6 +826,7 @@ class CombineMapper(Mapper, Generic[CombineT]): def __init__(self) -> None: super().__init__() self.cache: dict[ArrayOrNames, CombineT] = {} + self.function_cache: dict[FunctionDefinition, CombineT] = {} def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[CombineT, ...]: @@ -770,6 +839,14 @@ def rec(self, expr: ArrayOrNames) -> CombineT: # type: ignore self.cache[expr] = result return result + def rec_function_definition( + self, expr: FunctionDefinition) -> CombineT: + if expr in self.function_cache: + return self.function_cache[expr] + result: CombineT = super().rec_function_definition(expr) + self.function_cache[expr] = result + return result + # type-ignore reason: incompatible ret. type with super class def __call__(self, expr: ArrayOrNames) -> CombineT: # type: ignore return self.rec(expr) @@ -854,7 +931,6 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> CombineT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> CombineT: raise NotImplementedError("Combining results from a callee expression" " is context-dependent. Derived classes" @@ -935,14 +1011,13 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. return frozenset() def map_call(self, expr: Call) -> R: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[self.rec(bnd) for bnd in expr.bindings.values()]) def map_named_call_result(self, expr: NamedCallResult) -> R: @@ -992,7 +1067,6 @@ def map_data_wrapper(self, expr: DataWrapper) -> frozenset[InputArgumentBase]: def map_size_param(self, expr: SizeParam) -> frozenset[SizeParam]: return frozenset([expr]) - @memoize_method def map_function_definition(self, expr: FunctionDefinition ) -> frozenset[InputArgumentBase]: # get rid of placeholders local to the function. @@ -1014,7 +1088,7 @@ def map_function_definition(self, expr: FunctionDefinition return frozenset(result) def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[ self.rec(bnd) for name, bnd in sorted(expr.bindings.items())]) @@ -1037,14 +1111,13 @@ def combine(self, *args: frozenset[SizeParam] def map_size_param(self, expr: SizeParam) -> frozenset[SizeParam]: return frozenset([expr]) - @memoize_method def map_function_definition(self, expr: FunctionDefinition ) -> frozenset[SizeParam]: return self.combine(*[self.rec(ret) for ret in expr.returns.values()]) def map_call(self, expr: Call) -> frozenset[SizeParam]: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[ self.rec(bnd) for name, bnd in sorted(expr.bindings.items())]) @@ -1236,7 +1309,7 @@ def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None: if not self.visit(expr, *args, **kwargs): return - self.map_function_definition(expr.function, *args, **kwargs) + self.rec_function_definition(expr.function, *args, **kwargs) for bnd in expr.bindings.values(): self.rec(bnd, *args, **kwargs) @@ -1263,25 +1336,60 @@ class CachedWalkMapper(WalkMapper): one of its predecessors. """ - def __init__(self) -> None: + def __init__( + self, + _visited_functions: set[Any] | None = None) -> None: super().__init__() - self._visited_nodes: set[Any] = set() + self._visited_arrays_or_names: set[Any] = set() + + if _visited_functions is None: + _visited_functions = set() + + self._visited_functions: set[Any] = _visited_functions def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError + def get_function_definition_cache_key( + self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any ) -> None: cache_key = self.get_cache_key(expr, *args, **kwargs) - if cache_key in self._visited_nodes: + if cache_key in self._visited_arrays_or_names: return super().rec(expr, *args, **kwargs) - self._visited_nodes.add(cache_key) + self._visited_arrays_or_names.add(cache_key) + + def rec_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + cache_key = self.get_function_definition_cache_key(expr, *args, **kwargs) + if cache_key in self._visited_functions: + return + + super().rec_function_definition(expr, *args, **kwargs) + self._visited_functions.add(cache_key) def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: - return type(self)() + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ + return type(self)( + _visited_functions=self._visited_functions) # type: ignore[call-arg,attr-defined] + + def map_function_definition( + self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> None: + if not self.visit(expr, *args, **kwargs): + return + + new_mapper = self.clone_for_callee(expr) + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + # }}} @@ -1299,8 +1407,10 @@ class TopoSortMapper(CachedWalkMapper): :class:`~pytato.function.FunctionDefinition`. """ - def __init__(self) -> None: - super().__init__() + def __init__( + self, + _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.topological_order: list[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: @@ -1309,7 +1419,6 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. return @@ -1325,15 +1434,21 @@ class CachedMapAndCopyMapper(CopyMapper): traversals are memoized i.e. each node is mapped via *map_fn* exactly once. """ - def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: - super().__init__() + def __init__( + self, + map_fn: Callable[[ArrayOrNames], ArrayOrNames], + _function_cache: dict[Hashable, FunctionDefinition] | None = None + ) -> None: + super().__init__(_function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: # type-ignore-reason: self.__init__ has a different function signature # than Mapper.__init__ and does not have map_fn - return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] + return type(self)( + self.map_fn, # type: ignore[call-arg,attr-defined] + _function_cache=self._function_cache) # type: ignore[attr-defined] def rec(self, expr: MappedT) -> MappedT: if expr in self._cache: @@ -1665,7 +1780,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: # {{{ UsersCollector -class UsersCollector(CachedMapper[ArrayOrNames]): +class UsersCollector(CachedMapper[ArrayOrNames, FunctionDefinition]): """ Maps a graph to a dictionary representation mapping a node to its users, i.e. all the nodes using its value. @@ -1790,7 +1905,6 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> None: self.rec_idx_or_size_tuple(expr, expr.shape) - @memoize_method def map_function_definition(self, expr: FunctionDefinition, *args: Any ) -> None: raise AssertionError("Control shouldn't reach at this point." diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 6677b13a8..04bf036de 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -53,6 +53,7 @@ class PlaceholderSubstitutor(CopyMapper): """ def __init__(self, substitutions: Mapping[str, Array]) -> None: + # Ignoring function cache, not needed super().__init__() self.substitutions = substitutions diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 4fc81208a..ac65d1da5 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -41,6 +41,7 @@ from typing import ( TYPE_CHECKING, Any, + Hashable, Iterable, List, Mapping, @@ -73,7 +74,7 @@ ) from pytato.diagnostic import UnknownIndexLambdaExpr from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder -from pytato.function import NamedCallResult +from pytato.function import FunctionDefinition, NamedCallResult from pytato.raising import ( BinaryOp, BroadcastOp, @@ -623,8 +624,9 @@ class AxisTagAttacher(CopyMapper): """ def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Iterable[Tag]], - tag_corresponding_redn_descr: bool): - super().__init__() + tag_corresponding_redn_descr: bool, + _function_cache: dict[Hashable, FunctionDefinition] | None = None): + super().__init__(_function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], Iterable[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index f241fe227..9331abbe1 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -178,7 +178,7 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): +class ArrayToDotNodeInfoMapper(CachedMapper[None, None]): def __init__(self) -> None: super().__init__() self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {} diff --git a/pytato/visualization/fancy_placeholder_data_flow.py b/pytato/visualization/fancy_placeholder_data_flow.py index b388227a9..e519f70b7 100644 --- a/pytato/visualization/fancy_placeholder_data_flow.py +++ b/pytato/visualization/fancy_placeholder_data_flow.py @@ -23,6 +23,7 @@ Placeholder, Stack, ) +from pytato.function import FunctionDefinition from pytato.transform import CachedMapper @@ -96,7 +97,7 @@ def _get_dot_node_from_predecessors(node_id: str, return NoShowNode(), frozenset() -class FancyDotWriter(CachedMapper[_FancyDotWriterNode]): +class FancyDotWriter(CachedMapper[_FancyDotWriterNode, FunctionDefinition]): def __init__(self) -> None: super().__init__() self.vng = UniqueNameGenerator()