Add edge counter functionality and tests #535

Open
wants to merge 39 commits into
base: main

kajalpatelinfo
Contributor

This PR adds a way to count the number of edges in a DAG, with or without duplicates, along with tests for both modes.

@kajalpatelinfo kajalpatelinfo marked this pull request as ready for review August 15, 2024 16:17
Collaborator

@majosm majosm left a comment

Thanks! I added a few suggestions.

Also, there's an edge (ha) case that we need to think about here: what happens if the same array appears twice as a dependency? For example, y = x + x; should there be one edge between x and y, or two? The behavior should ideally match whatever the visualization does (I don't remember offhand if it emits a single edge or multiple). If it needs to be multiple edges, I think one way to do that would be to tweak DirectPredecessorsGetter to return either a list or a frozenset based on an argument to __init__.
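To make the y = x + x question concrete, here is a minimal sketch using a toy node class instead of pytato's actual types. `Node`, `direct_predecessors`, `count_edges`, and the `keep_duplicates` flag are all hypothetical names chosen for illustration. They are not the real `DirectPredecessorsGetter` interface.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy stand-in for an array expression (not a pytato class)."""
    name: str
    operands: tuple[Node, ...] = ()


def direct_predecessors(node: Node, *, keep_duplicates: bool):
    """Return operands as a list (duplicates kept) or a frozenset (collapsed)."""
    if keep_duplicates:
        return list(node.operands)
    return frozenset(node.operands)


def count_edges(root: Node, *, keep_duplicates: bool) -> int:
    """Visit each distinct node object once and sum its incoming edges."""
    seen: set[int] = set()
    stack = [root]
    total = 0
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        preds = direct_predecessors(node, keep_duplicates=keep_duplicates)
        total += len(preds)
        stack.extend(preds)
    return total


x = Node("x")
y = Node("y", (x, x))  # models y = x + x

print(count_edges(y, keep_duplicates=False))  # 1: the frozenset collapses x, x
print(count_edges(y, keep_duplicates=True))   # 2: the list keeps both uses of x
```

Whichever answer the visualization gives for this case is the one the edge counter should reproduce.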

    Concatenate,
    DataWrapper as DataWrapper,

Suggested change
-    DataWrapper as DataWrapper,
+    DataWrapper,

from pytato.transform import (
    ArrayOrNames,
    CachedWalkMapper,
    DependencyMapper as DependencyMapper,

Suggested change
-    DependencyMapper as DependencyMapper,
+    DependencyMapper,

Comment on lines +547 to +558
        # Each dependency is connected by an edge
        self.edge_count += len(self.get_dependencies(expr))

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        # Retrieve dependencies based on the type of the expression
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())
        elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            return frozenset([expr.array])
        elif isinstance(expr, Einsum):
            return frozenset(expr.args)
        return frozenset()

Tempted to say that this and the DirectPredecessorsGetter implementation in the tests should be swapped. DirectPredecessorsGetter seems like the more "proper" way to do this, and get_dependencies makes sense as an alternate implementation to check that it's working.



def get_num_edges(outputs: Array | DictOfNamedArrays,
                  count_duplicates: bool | None = None) -> int:

Suggested change
-                  count_duplicates: bool | None = None) -> int:
+                  count_duplicates: bool = False) -> int:

(Since get_num_edges is a new function, we don't have to keep the deprecation stuff that get_num_nodes has.)
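As a rough illustration of the two signatures, the sketch below contrasts the `None`-sentinel pattern (useful only while an existing default is being changed) with a plain boolean default. `resolve_legacy` and `resolve_new` are hypothetical helper names used only to isolate the argument handling. They are not pytato functions.

```python
from __future__ import annotations

import warnings


def resolve_legacy(count_duplicates: bool | None = None) -> bool:
    """Sentinel default: warn callers who rely on a default that will change."""
    if count_duplicates is None:
        warnings.warn(
            "The default value of 'count_duplicates' will change "
            "from True to False in 2025. "
            "For now, pass the desired value explicitly.",
            DeprecationWarning, stacklevel=2)
        return True
    return count_duplicates


def resolve_new(count_duplicates: bool = False) -> bool:
    """Plain default: a new function has no old behavior to migrate from."""
    return count_duplicates
```

With the plain default there is no warning path to test or to remove later, and callers that want the non-duplicate count can omit the argument.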

Comment on lines +568 to +575
    if count_duplicates is None:
        from warnings import warn
        warn(
            "The default value of 'count_duplicates' will change "
            "from True to False in 2025. "
            "For now, pass the desired value explicitly.",
            DeprecationWarning, stacklevel=2)
        count_duplicates = True

Suggested change
-    if count_duplicates is None:
-        from warnings import warn
-        warn(
-            "The default value of 'count_duplicates' will change "
-            "from True to False in 2025. "
-            "For now, pass the desired value explicitly.",
-            DeprecationWarning, stacklevel=2)
-        count_duplicates = True


    def post_visit(self, expr: Any) -> None:
        # Each dependency is connected by an edge
        self.edge_count += len(self.get_dependencies(expr))

Probably will also want an if not isinstance(expr, DictOfNamedArrays): check here if switching to DirectPredecessorsGetter.
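A small sketch of what that guard could look like, assuming the container of named outputs reports its entries as predecessors and that those links should not count as DAG edges. `Node`, `OutputDict`, `predecessors`, and `EdgeCounter` are toy stand-ins written for this example, not pytato classes.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy array expression with zero or more operands."""
    name: str
    operands: tuple = ()


@dataclass(frozen=True)
class OutputDict:
    """Toy container of named outputs (plays the role of the dict node)."""
    entries: tuple  # tuple of (name, Node) pairs


def predecessors(expr) -> frozenset:
    """Direct predecessors; the container reports its entries as predecessors."""
    if isinstance(expr, OutputDict):
        return frozenset(node for _, node in expr.entries)
    return frozenset(expr.operands)


class EdgeCounter:
    """Visits every distinct object once and counts edges in post_visit."""

    def __init__(self) -> None:
        self.edge_count = 0
        self._seen: set[int] = set()

    def visit(self, expr) -> None:
        if id(expr) in self._seen:
            return
        self._seen.add(id(expr))
        for pred in predecessors(expr):
            self.visit(pred)
        self.post_visit(expr)

    def post_visit(self, expr) -> None:
        # Skip the container so its links to the outputs are not counted.
        if not isinstance(expr, OutputDict):
            self.edge_count += len(predecessors(expr))


x = Node("x")
y = Node("y", (x,))
z = Node("z", (x, y))
outputs = OutputDict((("y", y), ("z", z)))

counter = EdgeCounter()
counter.visit(outputs)
print(counter.edge_count)  # 3: x->y, x->z, y->z
```

Without the `isinstance` check, the two container-to-output links would be added and this toy example would report 5.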

Comment on lines +609 to +620
        for dep in dependencies:
            self.edge_multiplicity_counts[dep, expr] += 1

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        # Retrieve dependencies based on the type of the expression
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())
        elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            return frozenset([expr.array])
        elif isinstance(expr, Einsum):
            return frozenset(expr.args)
        return frozenset()

(Same deal here with DirectPredecessorsGetter.)

    empty_dag = pt.make_dict_of_named_arrays({})

    # Verify that get_num_edges returns 0 for an empty DAG
    assert get_num_edges(empty_dag, count_duplicates=False) == 0

Suggested change
-    assert get_num_edges(empty_dag, count_duplicates=False) == 0
+    assert get_num_edges(empty_dag) == 0

(And same for all the rest.)
