Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] TorchScript compat #1141

Open
wants to merge 1 commit into
base: gh/vmoens/35/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,7 @@ def _apply_nest(
filter_empty: bool | None = None,
is_leaf: Callable | None = None,
out: TensorDictBase | None = None,
**constructor_kwargs,
**constructor_kwargs: Any,
) -> T | None:
if inplace:
result = self
Expand Down
6 changes: 3 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8393,7 +8393,7 @@ def func(*args, **kwargs):

def map(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
dim: int = 0,
num_workers: int | None = None,
*,
Expand Down Expand Up @@ -8573,7 +8573,7 @@ def map(

def map_iter(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
dim: int = 0,
num_workers: int | None = None,
*,
Expand Down Expand Up @@ -8771,7 +8771,7 @@ def map_iter(

def _map(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
dim: int = 0,
*,
shuffle: bool = False,
Expand Down
64 changes: 51 additions & 13 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,58 @@
# LICENSE file in the root directory of this source tree.

import argparse
import inspect

import pytest
import torch
import torch.nn as nn

from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from tensordict.prototype.fx import symbolic_trace


def test_fx():
seq = Seq(
Mod(lambda x: x + 1, in_keys=["x"], out_keys=["y"]),
Mod(lambda x, y: (x * y).sqrt(), in_keys=["x", "y"], out_keys=["z"]),
Mod(lambda z, x: z - z, in_keys=["z", "x"], out_keys=["a"]),
)
symbolic_trace(seq)


class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, td: TensorDict) -> torch.Tensor:
vals = td.values() # pyre-ignore[6]
return torch.cat([val._values for val in vals], dim=0)


def test_td_scripting() -> None:
for cls in (TensorDict,):
for name in dir(cls):
method = inspect.getattr_static(cls, name)
if isinstance(method, classmethod):
continue
elif isinstance(method, staticmethod):
continue
elif not callable(method):
continue
elif not name.startswith("__") or name in ("__init__", "__setitem__"):
setattr(cls, name, torch.jit.unused(method))

m = TestModule()
td = TensorDict(
a=torch.nested.nested_tensor([torch.ones((1,))], layout=torch.jagged)
)
m(td)
m = torch.jit.script(m, example_inputs=(td,))
m.code


def test_tensordictmodule_trace_consistency():
class Net(nn.Module):
def __init__(self):
Expand All @@ -24,7 +66,7 @@ def forward(self, x):
logits = self.linear(x)
return logits, torch.sigmoid(logits)

module = TensorDictModule(
module = Mod(
Net(),
in_keys=["input"],
out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
Expand Down Expand Up @@ -63,15 +105,13 @@ class Masker(nn.Module):
def forward(self, x, mask):
return torch.softmax(x * mask, dim=1)

net = TensorDictModule(
Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
)
masker = TensorDictModule(
net = Mod(Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")])
masker = Mod(
Masker(),
in_keys=[("intermediate", "x"), ("input", "mask")],
out_keys=[("output", "probabilities")],
)
module = TensorDictSequential(net, masker)
module = Seq(net, masker)
graph_module = symbolic_trace(module)

tensordict = TensorDict(
Expand Down Expand Up @@ -120,13 +160,11 @@ def forward(self, x):
module2 = Net(50, 40)
module3 = Output(40, 10)

tdmodule1 = TensorDictModule(module1, ["input"], ["x"])
tdmodule2 = TensorDictModule(module2, ["x"], ["x"])
tdmodule3 = TensorDictModule(module3, ["x"], ["probabilities"])
tdmodule1 = Mod(module1, ["input"], ["x"])
tdmodule2 = Mod(module2, ["x"], ["x"])
tdmodule3 = Mod(module3, ["x"], ["probabilities"])

tdmodule = TensorDictSequential(
TensorDictSequential(tdmodule1, tdmodule2), tdmodule3
)
tdmodule = Seq(Seq(tdmodule1, tdmodule2), tdmodule3)
graph_module = symbolic_trace(tdmodule)

tensordict = TensorDict({"input": torch.rand(32, 100)}, [32])
Expand Down
Loading