Commit bbfe082

Update
[ghstack-poisoned]
vmoens committed Dec 17, 2024
1 parent 0caad41 commit bbfe082
Showing 3 changed files with 55 additions and 17 deletions.
2 changes: 1 addition & 1 deletion tensordict/_td.py
@@ -1334,7 +1334,7 @@ def _apply_nest(
         filter_empty: bool | None = None,
         is_leaf: Callable | None = None,
         out: TensorDictBase | None = None,
-        **constructor_kwargs,
+        **constructor_kwargs: Any,
     ) -> T | None:
         if inplace:
             result = self
6 changes: 3 additions & 3 deletions tensordict/base.py
@@ -8393,7 +8393,7 @@ def func(*args, **kwargs):

     def map(
         self,
-        fn: Callable[[TensorDictBase], TensorDictBase | None],
+        fn: Callable,  # [[TensorDictBase], TensorDictBase | None],
         dim: int = 0,
         num_workers: int | None = None,
         *,
@@ -8573,7 +8573,7 @@ def map(

     def map_iter(
         self,
-        fn: Callable[[TensorDictBase], TensorDictBase | None],
+        fn: Callable,  # [[TensorDictBase], TensorDictBase | None],
         dim: int = 0,
         num_workers: int | None = None,
         *,
@@ -8771,7 +8771,7 @@ def map_iter(

     def _map(
         self,
-        fn: Callable[[TensorDictBase], TensorDictBase | None],
+        fn: Callable,  # [[TensorDictBase], TensorDictBase | None],
         dim: int = 0,
         *,
         shuffle: bool = False,
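
A minimal usage sketch for the map signature touched above, assuming only what the diff shows: fn is any callable that receives a TensorDict chunk (split along dim) and returns a TensorDict. The key names and worker count below are illustrative, not part of the commit.

import torch
from tensordict import TensorDict


def double_x(td: TensorDict) -> TensorDict:
    # Runs in a worker process on one chunk of the input, split along `dim`.
    td["y"] = td["x"] * 2
    return td


if __name__ == "__main__":
    # map() dispatches chunks to worker processes, so the callable should be
    # module-level (picklable) and the call guarded for multiprocessing.
    data = TensorDict({"x": torch.arange(16)}, batch_size=[16])
    out = data.map(double_x, dim=0, num_workers=2)
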
64 changes: 51 additions & 13 deletions test/test_fx.py
@@ -4,16 +4,58 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+import inspect

 import pytest
 import torch
 import torch.nn as nn

-from tensordict.nn import TensorDictModule, TensorDictSequential
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
 from tensordict.prototype.fx import symbolic_trace


+def test_fx():
+    seq = Seq(
+        Mod(lambda x: x + 1, in_keys=["x"], out_keys=["y"]),
+        Mod(lambda x, y: (x * y).sqrt(), in_keys=["x", "y"], out_keys=["z"]),
+        Mod(lambda z, x: z - z, in_keys=["z", "x"], out_keys=["a"]),
+    )
+    symbolic_trace(seq)
+
+
+class TestModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(2, 2)
+
+    def forward(self, td: TensorDict) -> torch.Tensor:
+        vals = td.values()  # pyre-ignore[6]
+        return torch.cat([val._values for val in vals], dim=0)
+
+
+def test_td_scripting() -> None:
+    for cls in (TensorDict,):
+        for name in dir(cls):
+            method = inspect.getattr_static(cls, name)
+            if isinstance(method, classmethod):
+                continue
+            elif isinstance(method, staticmethod):
+                continue
+            elif not callable(method):
+                continue
+            elif not name.startswith("__") or name in ("__init__", "__setitem__"):
+                setattr(cls, name, torch.jit.unused(method))
+
+    m = TestModule()
+    td = TensorDict(
+        a=torch.nested.nested_tensor([torch.ones((1,))], layout=torch.jagged)
+    )
+    m(td)
+    m = torch.jit.script(m, example_inputs=(td,))
+    m.code
+
+
 def test_tensordictmodule_trace_consistency():
     class Net(nn.Module):
         def __init__(self):
@@ -24,7 +66,7 @@ def forward(self, x):
             logits = self.linear(x)
             return logits, torch.sigmoid(logits)

-    module = TensorDictModule(
+    module = Mod(
         Net(),
         in_keys=["input"],
         out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
@@ -63,15 +105,13 @@ class Masker(nn.Module):
         def forward(self, x, mask):
             return torch.softmax(x * mask, dim=1)

-    net = TensorDictModule(
-        Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
-    )
-    masker = TensorDictModule(
+    net = Mod(Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")])
+    masker = Mod(
         Masker(),
         in_keys=[("intermediate", "x"), ("input", "mask")],
         out_keys=[("output", "probabilities")],
     )
-    module = TensorDictSequential(net, masker)
+    module = Seq(net, masker)
     graph_module = symbolic_trace(module)

     tensordict = TensorDict(
@@ -120,13 +160,11 @@ def forward(self, x):
     module2 = Net(50, 40)
     module3 = Output(40, 10)

-    tdmodule1 = TensorDictModule(module1, ["input"], ["x"])
-    tdmodule2 = TensorDictModule(module2, ["x"], ["x"])
-    tdmodule3 = TensorDictModule(module3, ["x"], ["probabilities"])
+    tdmodule1 = Mod(module1, ["input"], ["x"])
+    tdmodule2 = Mod(module2, ["x"], ["x"])
+    tdmodule3 = Mod(module3, ["x"], ["probabilities"])

-    tdmodule = TensorDictSequential(
-        TensorDictSequential(tdmodule1, tdmodule2), tdmodule3
-    )
+    tdmodule = Seq(Seq(tdmodule1, tdmodule2), tdmodule3)
     graph_module = symbolic_trace(tdmodule)

     tensordict = TensorDict({"input": torch.rand(32, 100)}, [32])
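
As a side note on test_td_scripting above: torch.jit.unused marks a method so the TorchScript compiler skips it rather than trying to compile it, and calling it from scripted code raises at runtime instead. A small self-contained sketch of that behavior, with illustrative names only:

import torch


class WithUnused(torch.nn.Module):
    @torch.jit.unused
    def python_only(self, x):
        # Uses Python features TorchScript cannot compile; the decorator tells
        # the compiler to skip this method instead of failing.
        return {k: v for k, v in zip("ab", x.unbind(0))}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


scripted = torch.jit.script(WithUnused())  # scripts fine despite python_only
print(scripted(torch.zeros(2)))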
