Add edge counter functionality and tests #535
base: main
Co-authored-by: Matt Smith <[email protected]>
Co-authored-by: Andreas Klöckner <[email protected]>
Thanks! I added a few suggestions.
Also, there's an edge (ha) case that we need to think about here: what happens if the same array appears twice as a dependency? For example, `y = x + x`; should there be one edge between `x` and `y`, or two? The behavior should ideally match whatever the visualization does (I don't remember offhand if it emits a single edge or multiple). If it needs to be multiple edges, I think one way to do that would be to tweak `DirectPredecessorsGetter` to return either a `list` or a `frozenset` based on an argument to `__init__`.
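To make the `y = x + x` question concrete, here is a minimal sketch on a toy graph. It does not use pytato: `Node`, `direct_predecessors`, and `count_edges` are hypothetical names invented for illustration, and the `include_duplicates` flag only mimics the idea of returning a `list` versus a `frozenset`.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True, eq=False)
class Node:
    """Hypothetical stand-in for an array expression (not a pytato class)."""
    name: str
    operands: tuple["Node", ...] = ()


def direct_predecessors(node, *, include_duplicates):
    """Return operands as a list (keeps repeats) or a frozenset (drops them)."""
    if include_duplicates:
        return list(node.operands)
    return frozenset(node.operands)


def count_edges(root, *, include_duplicates):
    """Count edges reachable from root, visiting each node once."""
    seen = set()
    stack = [root]
    edge_counts = Counter()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        preds = direct_predecessors(node, include_duplicates=include_duplicates)
        for pred in preds:
            edge_counts[pred, node] += 1
            stack.append(pred)
    return sum(edge_counts.values())


x = Node("x")
y = Node("y", (x, x))  # models y = x + x

single = count_edges(y, include_duplicates=False)   # one edge x -> y
multiple = count_edges(y, include_duplicates=True)  # two edges x -> y
```

Under these assumptions the set-based variant reports one edge and the list-based variant reports two, so whichever matches the visualization decides which container the getter should return.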
    Concatenate,
    DataWrapper as DataWrapper,
Suggested change:
-    DataWrapper as DataWrapper,
+    DataWrapper,
from pytato.transform import (
    ArrayOrNames,
    CachedWalkMapper,
    DependencyMapper as DependencyMapper,
Suggested change:
-    DependencyMapper as DependencyMapper,
+    DependencyMapper,
        # Each dependency is connected by an edge
        self.edge_count += len(self.get_dependencies(expr))

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        # Retrieve dependencies based on the type of the expression
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())
        elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            return frozenset([expr.array])
        elif isinstance(expr, Einsum):
            return frozenset(expr.args)
        return frozenset()
Tempted to say that this and the `DirectPredecessorsGetter` implementation in the tests should be swapped. `DirectPredecessorsGetter` seems like the more "proper" way to do this, and `get_dependencies` makes sense as an alternate implementation to check that it's working.
def get_num_edges(outputs: Array | DictOfNamedArrays,
                  count_duplicates: bool | None = None) -> int:
Suggested change:
-                  count_duplicates: bool | None = None) -> int:
+                  count_duplicates: bool = False) -> int:
(Since `get_num_edges` is a new function, we don't have to keep the deprecation stuff that `get_num_nodes` has.)
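As a rough illustration of what the plain `bool = False` default buys, here is a sketch of a helper with no `None` sentinel and no `DeprecationWarning`. The function `num_edges_from_counts` is hypothetical, and it assumes that `count_duplicates=True` means summing per-edge multiplicities while `False` means counting distinct edges; the real semantics in the PR may differ.

```python
import warnings
from collections import Counter


def num_edges_from_counts(edge_multiplicities, count_duplicates=False):
    """Hypothetical helper: sum multiplicities, or count distinct edges."""
    if count_duplicates:
        return sum(edge_multiplicities.values())
    return len(edge_multiplicities)


# Edge (x, y) appears twice, edge (y, z) appears once.
counts = Counter({("x", "y"): 2, ("y", "z"): 1})

with warnings.catch_warnings():
    # Turn any warning into an error to show that none is emitted.
    warnings.simplefilter("error")
    default_result = num_edges_from_counts(counts)
    with_duplicates = num_edges_from_counts(counts, count_duplicates=True)
```

Callers that omit the argument get the final default right away, so there is no transition period to manage later.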
    if count_duplicates is None:
        from warnings import warn
        warn(
            "The default value of 'count_duplicates' will change "
            "from True to False in 2025. "
            "For now, pass the desired value explicitly.",
            DeprecationWarning, stacklevel=2)
        count_duplicates = True
Suggested change (delete these lines):
-    if count_duplicates is None:
-        from warnings import warn
-        warn(
-            "The default value of 'count_duplicates' will change "
-            "from True to False in 2025. "
-            "For now, pass the desired value explicitly.",
-            DeprecationWarning, stacklevel=2)
-        count_duplicates = True
    def post_visit(self, expr: Any) -> None:
        # Each dependency is connected by an edge
        self.edge_count += len(self.get_dependencies(expr))
Probably will also want an `if not isinstance(expr, DictOfNamedArrays):` check here if switching to `DirectPredecessorsGetter`.
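A toy sketch of why that guard matters, again without pytato. `ArrayNode`, `OutputDict`, and `EdgeCounter` are hypothetical stand-ins, and the sketch assumes that a generic predecessor getter would report the container's entries as predecessors, which would otherwise be counted as edges.

```python
class ArrayNode:
    """Hypothetical array expression with a set of direct predecessors."""
    def __init__(self, name, preds=()):
        self.name = name
        self.preds = frozenset(preds)


class OutputDict:
    """Hypothetical stand-in for a named-outputs container."""
    def __init__(self, entries):
        self.entries = dict(entries)
        # A generic getter is assumed to see the entries as predecessors.
        self.preds = frozenset(self.entries.values())


class EdgeCounter:
    def __init__(self):
        self.edge_count = 0
        self._seen = set()

    def visit(self, expr):
        if id(expr) in self._seen:
            return
        self._seen.add(id(expr))
        for pred in expr.preds:
            self.visit(pred)
        self.post_visit(expr)

    def post_visit(self, expr):
        # Skip the container so container-to-output links are not counted.
        if not isinstance(expr, OutputDict):
            self.edge_count += len(expr.preds)


a = ArrayNode("a")
b = ArrayNode("b", [a])
c = ArrayNode("c", [a, b])
outputs = OutputDict({"b": b, "c": c})

counter = EdgeCounter()
counter.visit(outputs)
num_edges = counter.edge_count  # a->b, a->c, b->c
```

Without the `isinstance` check, the two links from the container to `b` and `c` would be added to the total.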
        for dep in dependencies:
            self.edge_multiplicity_counts[dep, expr] += 1

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        # Retrieve dependencies based on the type of the expression
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())
        elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            return frozenset([expr.array])
        elif isinstance(expr, Einsum):
            return frozenset(expr.args)
        return frozenset()
(Same deal here with `DirectPredecessorsGetter`.)
    empty_dag = pt.make_dict_of_named_arrays({})

    # Verify that get_num_edges returns 0 for an empty DAG
    assert get_num_edges(empty_dag, count_duplicates=False) == 0
Suggested change:
-    assert get_num_edges(empty_dag, count_duplicates=False) == 0
+    assert get_num_edges(empty_dag) == 0
(And same for all the rest.)
This PR adds functionality to count the number of edges in a DAG, with or without duplicates. It also adds tests for these functionalities.