Skip to content

Commit

Permalink
--amend
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Mar 15, 2024
1 parent dfaaeed commit 221de48
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper):
"""

def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self.counts = {}
self.counts = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if type(expr) not in counts:
self.counts[type(expr)] = 0
self.counts[type(expr)] += 1


def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int:
"""Returns the number of nodes of each given type in DAG *outputs*."""
def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]:
"""
Returns a dictionary mapping node types to node count for that type
in DAG *outputs*.
"""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)
Expand Down

0 comments on commit 221de48

Please sign in to comment.