Skip to content

Commit

Permalink
Add equal_nan option to torchdynamo.testing.same() (pytorch#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
jansel authored May 12, 2022
1 parent 01b754c commit eb6d4e6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ build-deps: clone-deps
(cd ../detectron2 && python setup.py clean && python setup.py develop)
(cd ../functorch && python setup.py clean && python setup.py develop)
(cd ../torchbenchmark && python install.py)
make setup_lint

offline-autotune-cpu: develop
rm -rf subgraphs
Expand Down
10 changes: 5 additions & 5 deletions torchdynamo/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,20 @@ def reduce_to_scalar_loss(out):
raise NotImplementedError("Don't know how to reduce")


def same(a, b, cos_similarity=False, tol=1e-4):
def same(a, b, cos_similarity=False, tol=1e-4, equal_nan=False):
"""Check correctness to see if a and b match"""
if isinstance(a, (list, tuple, torch.nn.ParameterList, torch.Size)):
assert isinstance(b, (list, tuple)), f"type mismatch {type(a)} {type(b)}"
return len(a) == len(b) and all(
same(ai, bi, cos_similarity, tol) for ai, bi in zip(a, b)
same(ai, bi, cos_similarity, tol, equal_nan) for ai, bi in zip(a, b)
)
elif isinstance(a, dict):
assert isinstance(b, dict)
assert set(a.keys()) == set(
b.keys()
), f"keys mismatch {set(a.keys())} == {set(b.keys())}"
for k in a.keys():
if not (same(a[k], b[k], cos_similarity, tol)):
if not (same(a[k], b[k], cos_similarity, tol, equal_nan=equal_nan)):
print("Accuracy failed for key name", k)
return False
return True
Expand All @@ -101,7 +101,7 @@ def same(a, b, cos_similarity=False, tol=1e-4):
print(f"Similarity score={res.cpu().numpy()}")
return res >= 0.99
else:
return torch.allclose(a, b, atol=tol, rtol=tol)
return torch.allclose(a, b, atol=tol, rtol=tol, equal_nan=equal_nan)
elif isinstance(a, (str, int, float, type(None), bool, torch.device)):
return a == b
elif type(a).__name__ in (
Expand All @@ -119,7 +119,7 @@ def same(a, b, cos_similarity=False, tol=1e-4):
):
assert type(a) is type(b)
return all(
same(getattr(a, key), getattr(b, key), cos_similarity, tol)
same(getattr(a, key), getattr(b, key), cos_similarity, tol, equal_nan)
for key in a.__dict__.keys()
)
else:
Expand Down

0 comments on commit eb6d4e6

Please sign in to comment.