Show requires_grad info for torch tensors and parameters. (#59)
For `Tensor`s, we add requires_grad=True to the summary when
that flag is set on the tensor.

For `nn.Parameter`s, we assume gradients are required by default
and add requires_grad=False when the parameter is frozen.

The summary is shown regardless of whether automatic visualization is enabled.

Fixes #51.
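
As a quick illustration of the behavior described above (a rough sketch, not part of the commit; it assumes the treescope package and its render_to_text entry point, which the tests below exercise):

import numpy as np
import torch
import treescope

# Plain tensor: the flag is off by default, so only the opted-in case is noted.
t = torch.tensor(np.arange(4, dtype=np.float32)).requires_grad_()
print(treescope.render_to_text(t))  # summary should include requires_grad=True

# Parameter: gradients are expected by default, so only the frozen case is noted.
p = torch.nn.Parameter(torch.zeros(4))
p.requires_grad_(False)
print(treescope.render_to_text(p))  # summary should include requires_grad=False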
danieldjohnson authored Feb 17, 2025
1 parent 022f8ee commit d08cd70
Showing 3 changed files with 85 additions and 1 deletion.
2 changes: 1 addition & 1 deletion run_tests.py
@@ -16,4 +16,4 @@
import subprocess

if __name__ == "__main__":
-  subprocess.check_call(["python", "-m", "pytest"])
+  subprocess.check_call(["python", "-m", "pytest", "--tb=short"])
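
(Aside: `--tb=short` is a standard pytest option that shortens traceback output on failures, presumably to keep failures in the new parameterized test cases easier to scan.)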
70 changes: 70 additions & 0 deletions tests/renderer_test.py
@@ -48,6 +48,12 @@ def _repr_html_(self):
    return self.repr_html


def _fill_grad_as_ones(tensor: torch.Tensor) -> torch.Tensor:
  # Backpropagating from the scalar sum fills `tensor.grad` with ones,
  # giving the renderer a populated .grad to display.
  tensor.requires_grad_()
  tensor.sum().backward()
  return tensor


class TreescopeRendererTest(parameterized.TestCase):

  def setUp(self):
@@ -399,6 +405,70 @@ def hook_that_crashes(node, path, node_renderer):
                      [ 7,  8,  9, 10, 11, 12, 13],
                      [14, 15, 16, 17, 18, 19, 20]])"""),
      ),
      dict(
          testcase_name="pytorch_tensor_large_requires_grad",
          target_builder=lambda: torch.tensor(
              np.arange(3 * 7, dtype=np.float32).reshape((3, 7))
          ).requires_grad_(),
          expected_collapsed=(
              """<torch.Tensor float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20 requires_grad=True>"""
          ),
          expected_expanded=textwrap.dedent(
              """\
              # torch.Tensor float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20 requires_grad=True
              tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
                      [ 7.,  8.,  9., 10., 11., 12., 13.],
                      [14., 15., 16., 17., 18., 19., 20.]], requires_grad=True)"""
          ),
      ),
      dict(
          testcase_name="pytorch_tensor_large_has_grad",
          target_builder=lambda: _fill_grad_as_ones(
              torch.tensor(np.arange(3 * 7, dtype=np.float32).reshape((3, 7)))
          ),
          expected_collapsed=(
              """<torch.Tensor float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20 requires_grad=True grad=<Tensor>>"""
          ),
          expected_expanded=textwrap.dedent(
              """\
              # torch.Tensor float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20 requires_grad=True grad=<Tensor>
              tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
                      [ 7.,  8.,  9., 10., 11., 12., 13.],
                      [14., 15., 16., 17., 18., 19., 20.]], requires_grad=True)"""
          ),
      ),
      dict(
          testcase_name="pytorch_parameter",
          target_builder=lambda: torch.nn.Parameter(
              torch.tensor(np.arange(3 * 7, dtype=np.float32).reshape((3, 7)))
          ),
          expected_collapsed=(
              """<torch.nn.Parameter float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20>"""
          ),
          expected_expanded=textwrap.dedent(
              """\
              # torch.nn.Parameter float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20
              Parameter containing:
              tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
                      [ 7.,  8.,  9., 10., 11., 12., 13.],
                      [14., 15., 16., 17., 18., 19., 20.]], requires_grad=True)"""
          ),
      ),
      dict(
          testcase_name="pytorch_parameter_frozen",
          target_builder=lambda: torch.nn.Parameter(
              torch.tensor(np.arange(3 * 7, dtype=np.float32).reshape((3, 7)))
          ).requires_grad_(False),
          expected_collapsed=(
              """<torch.nn.Parameter float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20 requires_grad=False>"""
          ),
          expected_expanded=textwrap.dedent("""\
              # torch.nn.Parameter float32(3, 7) ≈1e+01 ±3.7e+01 [≥0.0, ≤2e+01] zero:1 nonzero:20 requires_grad=False
              Parameter containing:
              tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
                      [ 7.,  8.,  9., 10., 11., 12., 13.],
                      [14., 15., 16., 17., 18., 19., 20.]])"""),
      ),
      dict(
          testcase_name="well_known_function",
          target=treescope.render_to_text,
14 changes: 14 additions & 0 deletions treescope/external/torch_support.py
@@ -179,6 +179,8 @@ def get_array_summary(
) -> rendering_parts.RenderableTreePart:
  assert torch is not None, "PyTorch is not available."
  ty = type(array)
  # Snapshot autograd info before detaching, since detach() drops it.
  array_grad = array.grad
  array_requires_grad = array.requires_grad
  array = array.detach()
  typename = f"{ty.__module__}.{ty.__name__}"
  abbrv = f"{ty.__name__}"
@@ -268,6 +270,18 @@ def get_array_summary(
  if ct_false:
    summary_parts.append(f" false:{ct_false:_d}")

  if issubclass(ty, torch.nn.Parameter):
    # Assume parameters require grad by default.
    if not array_requires_grad:
      summary_parts.append(" requires_grad=False")
  else:
    # Assume non-parameters don't require grad by default.
    if array_requires_grad:
      summary_parts.append(" requires_grad=True")

  if array_grad is not None:
    summary_parts.append(f" grad=<{type(array_grad).__name__}>")

  return rendering_parts.siblings(
      *always_show_parts,
      rendering_parts.abbreviatable(
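
Why the hunk above snapshots `grad` and `requires_grad` before detaching: `detach()` returns a tensor with its autograd state stripped, so reading the flags afterwards would always report them as absent. A minimal sketch in plain PyTorch (nothing treescope-specific):

import torch

t = torch.ones(2, requires_grad=True)
t.sum().backward()  # populates t.grad with ones

d = t.detach()  # shares storage with t, but drops autograd state
assert t.requires_grad and t.grad is not None
assert not d.requires_grad and d.grad is None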
