Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] TorchScript compat #1141

Open
wants to merge 1 commit into
base: gh/vmoens/35/base
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Update
[ghstack-poisoned]
vmoens committed Dec 17, 2024
commit bbfe08242df0cff81b3d78213352a9ae8ef2cdac
2 changes: 1 addition & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
@@ -1334,7 +1334,7 @@ def _apply_nest(
filter_empty: bool | None = None,
is_leaf: Callable | None = None,
out: TensorDictBase | None = None,
**constructor_kwargs,
**constructor_kwargs: Any,
) -> T | None:
if inplace:
result = self
6 changes: 3 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
@@ -8393,7 +8393,7 @@ def func(*args, **kwargs):

def map(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
dim: int = 0,
num_workers: int | None = None,
*,
@@ -8573,7 +8573,7 @@ def map(

def map_iter(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
dim: int = 0,
num_workers: int | None = None,
*,
@@ -8771,7 +8771,7 @@ def map_iter(

def _map(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
fn: Callable, # [[TensorDictBase], TensorDictBase | None],
dim: int = 0,
*,
shuffle: bool = False,
64 changes: 51 additions & 13 deletions test/test_fx.py
Original file line number Diff line number Diff line change
@@ -4,16 +4,58 @@
# LICENSE file in the root directory of this source tree.

import argparse
import inspect

import pytest
import torch
import torch.nn as nn

from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from tensordict.prototype.fx import symbolic_trace


def test_fx():
    """Smoke test: a ``TensorDictSequential`` of lambda-backed modules can be
    symbolically traced with :func:`symbolic_trace`."""
    stages = [
        Mod(lambda x: x + 1, in_keys=["x"], out_keys=["y"]),
        Mod(lambda x, y: (x * y).sqrt(), in_keys=["x", "y"], out_keys=["z"]),
        # NOTE(review): ``z - z`` is identically zero and ``x`` is unused —
        # presumably deliberate, to exercise tracing of multi-input lambdas;
        # confirm with the author.
        Mod(lambda z, x: z - z, in_keys=["z", "x"], out_keys=["a"]),
    ]
    pipeline = Seq(*stages)
    symbolic_trace(pipeline)


class TestModule(torch.nn.Module):
    """Module that concatenates the ``_values`` buffers of every entry of a
    ``TensorDict`` along dim 0.

    Used by ``test_td_scripting`` below to exercise TorchScript compatibility
    of modules taking a ``TensorDict`` input.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the linear layer is never used in ``forward`` —
        # presumably it is only here so the scripted module owns at least one
        # parameterized submodule; confirm.
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, td: TensorDict) -> torch.Tensor:
        """Return the concatenation of each entry's ``_values`` tensor.

        Assumes every value stored in ``td`` exposes a ``_values`` attribute
        (e.g. a jagged nested tensor) — TODO confirm with callers.
        """
        entries = td.values()  # pyre-ignore[6]
        leaf_tensors = [entry._values for entry in entries]
        return torch.cat(leaf_tensors, dim=0)


def test_td_scripting() -> None:
    """Check that a module consuming a ``TensorDict`` can pass through
    ``torch.jit.script`` once unsupported methods are marked as unused."""
    # NOTE(review): this mutates the TensorDict class in place, which leaks
    # into any test that runs afterwards in the same process — confirm this
    # is acceptable for the test suite.
    for cls in (TensorDict,):
        for attr_name in dir(cls):
            attr = inspect.getattr_static(cls, attr_name)
            # Skip class/static methods and plain (non-callable) attributes.
            if isinstance(attr, (classmethod, staticmethod)) or not callable(attr):
                continue
            # Mark every regular method — plus the two dunders TorchScript
            # would otherwise try to compile — as unused so scripting skips them.
            if not attr_name.startswith("__") or attr_name in ("__init__", "__setitem__"):
                setattr(cls, attr_name, torch.jit.unused(attr))

    module = TestModule()
    td = TensorDict(
        a=torch.nested.nested_tensor([torch.ones((1,))], layout=torch.jagged)
    )
    # Eager sanity check before handing the module to the scripter.
    module(td)
    scripted = torch.jit.script(module, example_inputs=(td,))
    # Accessing .code forces TorchScript code generation; raises on failure.
    scripted.code


def test_tensordictmodule_trace_consistency():
class Net(nn.Module):
def __init__(self):
@@ -24,7 +66,7 @@ def forward(self, x):
logits = self.linear(x)
return logits, torch.sigmoid(logits)

module = TensorDictModule(
module = Mod(
Net(),
in_keys=["input"],
out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
@@ -63,15 +105,13 @@ class Masker(nn.Module):
def forward(self, x, mask):
return torch.softmax(x * mask, dim=1)

net = TensorDictModule(
Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
)
masker = TensorDictModule(
net = Mod(Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")])
masker = Mod(
Masker(),
in_keys=[("intermediate", "x"), ("input", "mask")],
out_keys=[("output", "probabilities")],
)
module = TensorDictSequential(net, masker)
module = Seq(net, masker)
graph_module = symbolic_trace(module)

tensordict = TensorDict(
@@ -120,13 +160,11 @@ def forward(self, x):
module2 = Net(50, 40)
module3 = Output(40, 10)

tdmodule1 = TensorDictModule(module1, ["input"], ["x"])
tdmodule2 = TensorDictModule(module2, ["x"], ["x"])
tdmodule3 = TensorDictModule(module3, ["x"], ["probabilities"])
tdmodule1 = Mod(module1, ["input"], ["x"])
tdmodule2 = Mod(module2, ["x"], ["x"])
tdmodule3 = Mod(module3, ["x"], ["probabilities"])

tdmodule = TensorDictSequential(
TensorDictSequential(tdmodule1, tdmodule2), tdmodule3
)
tdmodule = Seq(Seq(tdmodule1, tdmodule2), tdmodule3)
graph_module = symbolic_trace(tdmodule)

tensordict = TensorDict({"input": torch.rand(32, 100)}, [32])