Skip to content

Commit

Permalink
avoid traversing functions multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored and inducer committed Aug 25, 2024
1 parent a92a0d1 commit 58c8dbf
Show file tree
Hide file tree
Showing 12 changed files with 261 additions and 105 deletions.
42 changes: 33 additions & 9 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from typing import TYPE_CHECKING, Any, Mapping

from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

from pytato.array import (
Array,
Expand All @@ -46,7 +45,7 @@
)
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper, _SelfMapper


if TYPE_CHECKING:
Expand Down Expand Up @@ -410,20 +409,38 @@ class NodeCountMapper(CachedWalkMapper):
Dictionary mapping node types to number of nodes of that type.
"""

def __init__(self, count_duplicates: bool = False) -> None:
def __init__(
self,
count_duplicates: bool = False,
_visited_functions: set[Any] | None = None,
) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
super().__init__()
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
self.count_duplicates = count_duplicates

def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
# Returns unique nodes only if count_duplicates is False
return id(expr) if self.count_duplicates else expr

def get_function_definition_cache_key(
self, expr: FunctionDefinition) -> int | FunctionDefinition:
# Returns unique nodes only if count_duplicates is False
return id(expr) if self.count_duplicates else expr

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_type_counts[type(expr)] += 1

def clone_for_callee(
self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper:
# type-ignore-reason: self.__init__ has a different function signature
# than Mapper.__init__
return type(self)(
count_duplicates=self.count_duplicates, # type: ignore[attr-defined]
_visited_functions=self._visited_functions) # type: ignore[call-arg,attr-defined]


def get_node_type_counts(
outputs: Array | DictOfNamedArrays,
Expand Down Expand Up @@ -485,15 +502,20 @@ class NodeMultiplicityMapper(CachedWalkMapper):
.. autoattribute:: expr_multiplicity_counts
"""
def __init__(self) -> None:
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
super().__init__()
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_multiplicity_counts[expr] += 1
Expand Down Expand Up @@ -527,14 +549,16 @@ class CallSiteCountMapper(CachedWalkMapper):
The number of nodes.
"""

def __init__(self) -> None:
super().__init__()
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)
self.count = 0

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

@memoize_method
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def map_function_definition(self, expr: FunctionDefinition) -> None:
if not self.visit(expr):
return
Expand Down
22 changes: 14 additions & 8 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""

import dataclasses
from typing import Any, Mapping, Tuple
from typing import Any, Hashable, Mapping, Tuple

from immutabledict import immutabledict

Expand All @@ -42,7 +42,7 @@
SizeParam,
make_dict_of_named_arrays,
)
from pytato.function import NamedCallResult
from pytato.function import FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall
from pytato.scalar_expr import IntegralScalarExpression
from pytato.target import Target
Expand Down Expand Up @@ -118,10 +118,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
====================================== =====================================
"""

def __init__(self, target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None
) -> None:
super().__init__()
def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
Expand Down Expand Up @@ -247,13 +250,16 @@ def normalize_outputs(

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NamesValidityChecker(CachedWalkMapper):
def __init__(self) -> None:
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
self.name_to_input: dict[str, InputArgumentBase] = {}
super().__init__()
super().__init__(_visited_functions=_visited_functions)

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if isinstance(expr, (Placeholder, SizeParam, DataWrapper)):
if expr.name is not None:
Expand Down
5 changes: 3 additions & 2 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,9 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
) -> None:
super().__init__()
super().__init__(_function_cache=_function_cache)

self.recvd_ary_to_name = recvd_ary_to_name
self.sptpo_ary_to_name = sptpo_ary_to_name
Expand All @@ -307,7 +308,7 @@ def clone_for_callee(
self, function: FunctionDefinition) -> _DistributedInputReplacer:
# Function definitions aren't allowed to contain receives,
# stored arrays promoted to part outputs, or part outputs
return type(self)({}, {}, {})
return type(self)({}, {}, {}, _function_cache=self._function_cache)

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
Expand Down
4 changes: 2 additions & 2 deletions pytato/distributed/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ class MissingRecvError(DistributedPartitionVerificationError):

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class _SeenNodesWalkMapper(CachedWalkMapper):
def __init__(self) -> None:
super().__init__()
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)
self.seen_nodes: set[ArrayOrNames] = set()

def get_cache_key(self, expr: ArrayOrNames) -> int:
Expand Down
19 changes: 10 additions & 9 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@

from typing import TYPE_CHECKING, Any, Callable, Union

from pytools import memoize_method

from pytato.array import (
AbstractResultWithNamedArrays,
AdvancedIndexInContiguousAxes,
Expand Down Expand Up @@ -83,19 +81,23 @@ class EqualityComparer:
more on this.
"""
def __init__(self) -> None:
# Uses the same cache for both arrays and functions
self._cache: dict[tuple[int, int], bool] = {}

def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool:
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
cache_key = id(expr1), id(expr2)
try:
return self._cache[cache_key]
except KeyError:

method: Callable[[Array | AbstractResultWithNamedArrays, Any],
bool]
method: Callable[
[Array | AbstractResultWithNamedArrays | FunctionDefinition, Any],
bool]

try:
method = getattr(self, expr1._mapper_method)
method = (
getattr(self, expr1._mapper_method)
if isinstance(expr1, (Array, AbstractResultWithNamedArrays))
else self.map_function_definition)
except AttributeError:
if isinstance(expr1, Array):
result = self.handle_unsupported_array(expr1, expr2)
Expand Down Expand Up @@ -293,7 +295,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
and expr1.tags == expr2.tags
)

@memoize_method
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
) -> bool:
return (expr1.__class__ is expr2.__class__
Expand All @@ -307,7 +308,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any

def map_call(self, expr1: Call, expr2: Any) -> bool:
return (expr1.__class__ is expr2.__class__
and self.map_function_definition(expr1.function, expr2.function)
and self.rec(expr1.function, expr2.function)
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
and all(self.rec(bnd,
expr2.bindings[name])
Expand Down
15 changes: 11 additions & 4 deletions pytato/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import numpy as np
from immutabledict import immutabledict

from pytools import memoize_method

from pytato.array import (
Array,
Axis,
Expand Down Expand Up @@ -68,6 +66,7 @@ def __init__(self,
self.truncation_depth = truncation_depth
self.truncation_string = truncation_string

# Uses the same cache for both arrays and functions
self._cache: dict[tuple[int, int], str] = {}

def rec(self, expr: Any, depth: int) -> str:
Expand All @@ -79,6 +78,15 @@ def rec(self, expr: Any, depth: int) -> str:
self._cache[cache_key] = result
return result # type: ignore[no-any-return]

def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
cache_key = (id(expr), depth)
try:
return self._cache[cache_key]
except KeyError:
result = super().rec_function_definition(expr, depth)
self._cache[cache_key] = result
return result # type: ignore[no-any-return]

def __call__(self, expr: Any, depth: int = 0) -> str:
return self.rec(expr, depth)

Expand Down Expand Up @@ -168,7 +176,6 @@ def _get_field_val(field: str) -> str:
for field in attrs.fields(type(expr)))
+ ")")

@memoize_method
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string
Expand All @@ -191,7 +198,7 @@ def map_call(self, expr: Call, depth: int) -> str:

def _get_field_val(field: str) -> str:
if field == "function":
return self.map_function_definition(expr.function, depth+1)
return self.rec_function_definition(expr.function, depth+1)
else:
return self.rec(getattr(expr, field), depth+1)

Expand Down
3 changes: 2 additions & 1 deletion pytato/target/python/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
SizeParam,
Stack,
)
from pytato.function import FunctionDefinition
from pytato.raising import BinaryOpType, C99CallOp
from pytato.reductions import (
AllReductionOperation,
Expand Down Expand Up @@ -169,7 +170,7 @@ def _is_slice_trivial(slice_: NormalizedSlice,
}


class NumpyCodegenMapper(CachedMapper[str]):
class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition]):
"""
.. note::
Expand Down
Loading

0 comments on commit 58c8dbf

Please sign in to comment.