From 221de48e26e1f89bb137d9d544dba788524f9236 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Fri, 15 Mar 2024 17:47:43 -0500 Subject: [PATCH] --amend --- pytato/analysis/__init__.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index dce70e62f..79e7cd59f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper): """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.counts = {} + self.counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) def post_visit(self, expr: Any) -> None: + if type(expr) not in counts: + self.counts[type(expr)] = 0 self.counts[type(expr)] += 1 -def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes of each given type in DAG *outputs*.""" +def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs)