From bbfe08242df0cff81b3d78213352a9ae8ef2cdac Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 17 Dec 2024 10:32:12 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/_td.py | 2 +- tensordict/base.py | 6 ++--- test/test_fx.py | 64 ++++++++++++++++++++++++++++++++++++---------- 3 files changed, 55 insertions(+), 17 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 7a7364d4b..a335921b7 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1334,7 +1334,7 @@ def _apply_nest( filter_empty: bool | None = None, is_leaf: Callable | None = None, out: TensorDictBase | None = None, - **constructor_kwargs, + **constructor_kwargs: Any, ) -> T | None: if inplace: result = self diff --git a/tensordict/base.py b/tensordict/base.py index c868d1447..d774325af 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -8393,7 +8393,7 @@ def func(*args, **kwargs): def map( self, - fn: Callable[[TensorDictBase], TensorDictBase | None], + fn: Callable, # [[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, @@ -8573,7 +8573,7 @@ def map( def map_iter( self, - fn: Callable[[TensorDictBase], TensorDictBase | None], + fn: Callable, # [[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, @@ -8771,7 +8771,7 @@ def map_iter( def _map( self, - fn: Callable[[TensorDictBase], TensorDictBase | None], + fn: Callable, # [[TensorDictBase], TensorDictBase | None], dim: int = 0, *, shuffle: bool = False, diff --git a/test/test_fx.py b/test/test_fx.py index 50cd8e762..747dbdfde 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4,16 +4,58 @@ # LICENSE file in the root directory of this source tree. import argparse +import inspect import pytest import torch import torch.nn as nn from tensordict import TensorDict -from tensordict.nn import TensorDictModule, TensorDictSequential +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq from tensordict.prototype.fx import symbolic_trace +def test_fx(): + seq = Seq( + Mod(lambda x: x + 1, in_keys=["x"], out_keys=["y"]), + Mod(lambda x, y: (x * y).sqrt(), in_keys=["x", "y"], out_keys=["z"]), + Mod(lambda z, x: z - z, in_keys=["z", "x"], out_keys=["a"]), + ) + symbolic_trace(seq) + + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, td: TensorDict) -> torch.Tensor: + vals = td.values() # pyre-ignore[6] + return torch.cat([val._values for val in vals], dim=0) + + +def test_td_scripting() -> None: + for cls in (TensorDict,): + for name in dir(cls): + method = inspect.getattr_static(cls, name) + if isinstance(method, classmethod): + continue + elif isinstance(method, staticmethod): + continue + elif not callable(method): + continue + elif not name.startswith("__") or name in ("__init__", "__setitem__"): + setattr(cls, name, torch.jit.unused(method)) + + m = TestModule() + td = TensorDict( + a=torch.nested.nested_tensor([torch.ones((1,))], layout=torch.jagged) + ) + m(td) + m = torch.jit.script(m, example_inputs=(td,)) + m.code + + def test_tensordictmodule_trace_consistency(): class Net(nn.Module): def __init__(self): @@ -24,7 +66,7 @@ def forward(self, x): logits = self.linear(x) return logits, torch.sigmoid(logits) - module = TensorDictModule( + module = Mod( Net(), in_keys=["input"], out_keys=[("outputs", "logits"), ("outputs", "probabilities")], @@ -63,15 +105,13 @@ class Masker(nn.Module): def forward(self, x, mask): return torch.softmax(x * mask, dim=1) - net = TensorDictModule( - Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")] - ) - masker = TensorDictModule( + net = Mod(Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]) + masker = Mod( Masker(), in_keys=[("intermediate", "x"), ("input", "mask")], out_keys=[("output", "probabilities")], ) - module = TensorDictSequential(net, masker) + module = Seq(net, masker) graph_module = symbolic_trace(module) tensordict = TensorDict( @@ -120,13 +160,11 @@ def forward(self, x): module2 = Net(50, 40) module3 = Output(40, 10) - tdmodule1 = TensorDictModule(module1, ["input"], ["x"]) - tdmodule2 = TensorDictModule(module2, ["x"], ["x"]) - tdmodule3 = TensorDictModule(module3, ["x"], ["probabilities"]) + tdmodule1 = Mod(module1, ["input"], ["x"]) + tdmodule2 = Mod(module2, ["x"], ["x"]) + tdmodule3 = Mod(module3, ["x"], ["probabilities"]) - tdmodule = TensorDictSequential( - TensorDictSequential(tdmodule1, tdmodule2), tdmodule3 - ) + tdmodule = Seq(Seq(tdmodule1, tdmodule2), tdmodule3) graph_module = symbolic_trace(tdmodule) tensordict = TensorDict({"input": torch.rand(32, 100)}, [32])