
[Feature Request] support OrderedDict as input to TensorDictSequential #1126

Closed · davidegraff opened this issue Dec 3, 2024 · 7 comments · Fixed by #1142
Labels: enhancement (New feature or request)

@davidegraff

Motivation

The native nn.Sequential supports input in the form of an OrderedDict[str, Module] to set the names of the underlying modules rather than just using ordinal numbers (as is the default when an Iterable[Module] is supplied). This isn't currently possible with the analogous TensorDictSequential class, but it would be nice if we could do the same.

Example

Native torch behavior

>>> W = nn.Linear(10, 1)
>>> nn.Sequential(W)
Sequential(
  (0): Linear(in_features=10, out_features=1, bias=True)
)
>>> nn.Sequential(OrderedDict(fc0=W))
Sequential(
  (fc0): Linear(in_features=10, out_features=1, bias=True)
)

Current tensordict behavior

>>> W_tdm = TensorDictModule(W, ['x'], ['y'])
>>> TensorDictSequential(W_tdm)
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=Linear(in_features=10, out_features=1, bias=True),
          device=cpu,
          in_keys=['x'],
          out_keys=['y'])
    ),
    device=cpu,
    in_keys=['x'],
    out_keys=['y'])
>>> TensorDictSequential(OrderedDict(fc0=W_tdm))
TensorDictSequential(
    module=ModuleList(
      (0): WrapModule()
    ),
    device=cpu,
    in_keys=[],
    out_keys=[])

Desired behavior

>>> TensorDictSequential(OrderedDict(fc0=W_tdm))
TensorDictSequential(
    module=ModuleList(
      (fc0): TensorDictModule(
          module=Linear(in_features=10, out_features=1, bias=True),
          device=cpu,
          in_keys=['x'],
          out_keys=['y'])
    ),
    device=cpu,
    in_keys=['x'],
    out_keys=['y'])
@davidegraff added the enhancement (New feature or request) label Dec 3, 2024
@vmoens (Contributor) commented Dec 4, 2024

Happy to make this an option!
I guess we'd then want indexing the sequence to work like indexing a dict:

m = TensorDictSequential(OrderedDict(fc0=W))
m["fc0"]  # returns W

Also, there should be no mapping between module names and entry names, right? Or would that be an interesting avenue to investigate (e.g., all entries written by the "fc0" module are written to the tensordict["fc0"] sub-td)?
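
For concreteness, a minimal sketch of what that could look like. The automatic nesting is hypothetical, but tuple (NestedKey) out_keys already work today:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

W = nn.Linear(10, 1)
# today: the sub-td nesting must be spelled out by hand via tuple out_keys
fc0 = TensorDictModule(W, in_keys=["x"], out_keys=[("fc0", "y")])
seq = TensorDictSequential(fc0)
td = seq(TensorDict({"x": torch.randn(3, 10)}, batch_size=[3]))
assert td["fc0", "y"].shape == (3, 1)
# hypothetical: TensorDictSequential(OrderedDict(fc0=...)) could add the
# ("fc0", ...) prefix to the module's out_keys automatically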

@davidegraff (Author) commented Dec 4, 2024

AFAICT, nn.Sequential.__getitem__() doesn't support indexing by string key, only by integer:

>>> block = nn.Sequential(OrderedDict(linear=nn.Linear(10, 10)))
>>> block
Sequential(
  (linear): Linear(in_features=10, out_features=10, bias=True)
)
>>> block[0]
Linear(in_features=10, out_features=10, bias=True)
>>> block['linear']
[...]
TypeError: 'str' object cannot be interpreted as an integer

So I think your example should throw a TypeError as well:

>>> m = TensorDictSequential(OrderedDict(fc0=W))
>>> m["fc0"]
[...]
TypeError: ...
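
(That said, named submodules remain reachable through the standard nn.Module API even without string indexing; both calls below are existing PyTorch:)

>>> block.linear
Linear(in_features=10, out_features=10, bias=True)
>>> block.get_submodule('linear')
Linear(in_features=10, out_features=10, bias=True)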

Also, there should be no mapping between module names and entry names, right? Or would that be an interesting avenue to investigate (e.g., all entries written by the "fc0" module are written to the tensordict["fc0"] sub-td)?

FWIW, this is my intended use case. I currently do something like this:

# mapping from module "name" to its corresponding module, input keys, and output keys
model_config: dict[str, tuple[nn.Module, list | dict, list]]
modules = [
    # nest the output keys in a sub-td under the module's name
    TensorDictModule(module, in_keys, [(name, key) for key in out_keys])
    for name, (module, in_keys, out_keys) in model_config.items()
]
model = TensorDictSequential(*modules)
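
Concretely, with a made-up two-stage config (the names and shapes here are illustrative only), that workaround produces nested outputs like so:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

# hypothetical config: name -> (module, in_keys, out_keys)
model_config = {
    "encoder": (nn.Linear(10, 4), ["x"], ["h"]),
    "head": (nn.Linear(4, 1), [("encoder", "h")], ["y"]),  # reads the nested key
}
modules = [
    TensorDictModule(module, in_keys, [(name, key) for key in out_keys])
    for name, (module, in_keys, out_keys) in model_config.items()
]
model = TensorDictSequential(*modules)
td = model(TensorDict({"x": torch.randn(3, 10)}, batch_size=[3]))
assert td["encoder", "h"].shape == (3, 4)
assert td["head", "y"].shape == (3, 1)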

I think it would be a nice feature because it would mean I don't have to do this workaround and could just do this instead:

modules = {
    name: TensorDictModule(module, in_keys, out_keys)
    for name, (module, in_keys, out_keys) in model_config.items()
}
model = TensorDictSequential(OrderedDict(modules))

But I think this would be a breaking change, no? You'd have to either:

1. implement the same behavior for the TensorDictSequential(*modules: TensorDictModule) case, putting all intermediate entries under their ordinal value, i.e., placing intermediate values into sub-tensordicts under keys tuple[int, str | tuple[str, ...]] whose first item is the ordinal (0, 1, ..., n); or
2. produce different behaviors for the two cases, keeping the current behavior for TensorDictSequential(*modules: TensorDictModule) and adding the new behavior only for TensorDictSequential(module_dict: OrderedDict[str, TensorDictModule]).

IMO having differing behavior depending on the input signature isn't a good strategy. A possible middle ground would be a new keyword argument that opts into this behavior, something like nest_intermediate_values: bool = False (open to other name ideas); when True, it would produce the nested layout, as sketched below.
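
For concreteness, the kwarg variant might be used like this (entirely hypothetical; neither the nesting nor the kwarg exists at the time of writing):

# hypothetical API sketch, reusing the `modules` dict from above
model = TensorDictSequential(
    OrderedDict(modules),
    nest_intermediate_values=True,  # proposed kwarg; name open to bikeshedding
)
# each module's outputs would then land under its name, e.g. td["encoder", "h"]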

@vmoens (Contributor) commented Dec 17, 2024

We could also display the graph using this: https://gist.github.com/vmoens/ab18938154997b94f6940b1fadde8999
It's not as fine-grained as an fx plot, but it's nice for seeing the macro structure of the model.

@vmoens (Contributor) commented Dec 17, 2024

IMO having differing behavior depending on the input signature isn't a good strategy.

We have this in PyTorch for many functions, e.g., view(shape) vs. view(*shape), etc.
So I'd be in favor of TensorDictSequential(ordered_dict) and TensorDictSequential(*modules).

Open to discussion though! I can also see the point that different signatures may not be an excellent strategy but one more kwarg may also be a bit clunky.

(That being said, I think we should not have used an expanded sequence in Sequential, but what is done is done.)

@vmoens linked a pull request Dec 17, 2024 that will close this issue
@davidegraff (Author) commented:

We have this in PyTorch for many functions, e.g., view(shape) vs. view(*shape), etc. So I'd be in favor of TensorDictSequential(ordered_dict) and TensorDictSequential(*modules).

Yeah, fair enough. I guess it's just a personal preference of mine, when writing Python code, that all input signatures share the same output behavior. But FWIW, your proposal does make my life easier and seems like reasonable behavior, so I'll defer to you on what you think is best.

We could also display the graph using this: https://gist.github.com/vmoens/ab18938154997b94f6940b1fadde8999. It's not as fine-grained as an fx plot, but it's nice for seeing the macro structure of the model.

I was actually thinking of implementing something like this in my own code for users to visualize the compute graph of the model's forward() method.

This also brings up a related feature request I was thinking about: allowing users to supply just a bag of TensorDictModules and then inferring the compute graph from there. I don't know how useful this would be for most users, but it would remove the one requirement that the modules in the graph be "ordered" with respect to their dependencies when supplied to TensorDictSequential. Maybe it would be better as a subclass of TensorDictSequential, e.g., TensorDictDigraph extends TensorDictSequential. Curious to hear what you think; I didn't want to overload you with feature requests!

@vmoens (Contributor) commented Dec 18, 2024

This also brings up a related feature request I was thinking about: allowing users to supply just a bag of TensorDictModules and then inferring the compute graph from there.

I guess the "problem" here is that a module can overwrite entries in the input tensordict. If you flip the order in those cases, we can't backtrack which value is the original input and which is the output, e.g.:

seq = TensorDictSequential(
   Mod(lambda x: x + 1, in_keys=["x"], out_keys=["x"]),
   Mod(lambda x: x.sqrt(), in_keys=["x"], out_keys=["x"]),
)

The order is important: if we swap the two modules, the graph will be different.

In other words, your idea can work, but only if no module's out_keys overlap with its own in_keys, and no out_key appears twice in the graph (I think these conditions are sufficient, but I might be wrong).
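
A quick numeric check of the order dependence (a runnable sketch; Mod here is TensorDictModule, which also wraps plain callables):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential

add_then_sqrt = TensorDictSequential(
    Mod(lambda x: x + 1, in_keys=["x"], out_keys=["x"]),
    Mod(lambda x: x.sqrt(), in_keys=["x"], out_keys=["x"]),
)
sqrt_then_add = TensorDictSequential(
    Mod(lambda x: x.sqrt(), in_keys=["x"], out_keys=["x"]),
    Mod(lambda x: x + 1, in_keys=["x"], out_keys=["x"]),
)
td = TensorDict({"x": torch.tensor(3.0)}, batch_size=[])
print(add_then_sqrt(td.clone())["x"])  # sqrt(3 + 1) = 2.0
print(sqrt_then_add(td.clone())["x"])  # sqrt(3) + 1 ≈ 2.732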

@vmoens closed this as completed Dec 18, 2024
@davidegraff (Author) commented Dec 18, 2024

Yeah, exactly, and I view that as an implicit "feature" of TensorDictSequential. Of course, you don't "know" the precedence of these edges when they're supplied as a bag, but as long as the input collection has some known, canonical ordering, you can just establish that edges are added to the graph on a FIFO basis. Perhaps this could be made clear by making it an alternate constructor of TensorDictSequential:

from collections.abc import Collection
from typing import Self

class TensorDictSequential:
    @classmethod
    def from_edges(cls, modules: Collection[TensorDictModule]) -> Self:
        """Build a :class:`TensorDictSequential` from the collection of edges that define its forward computation graph."""

But thinking about this more, the behavior for the edge cases seems hard to pin down. It would be nice if, for a simple DAG with no overlapping out_keys, I could just toss the bag of edges in and build my TensorDictSequential, but the ultimate utility might not be that great. If you think it's worthwhile in the long run, you could cover just the simple case at first, throw a ValueError for the edge cases, and sort the rest out later (a sketch of that simple case follows). Just wanted to provide some thoughts, but thanks as always!
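
A minimal sketch of that simple case, following the from_edges idea above (illustrative only, not an existing API; it assumes unique out_keys, no module writing to its own in_keys, and an acyclic graph, raising ValueError where it can detect a violation):

from collections.abc import Collection

from tensordict.nn import TensorDictModule, TensorDictSequential

def from_edges(modules: Collection[TensorDictModule]) -> TensorDictSequential:
    """Toposort a bag of modules by their key dependencies (simple DAG case only)."""
    # map each out_key to the unique module that produces it
    producers = {}
    for m in modules:
        for key in m.out_keys:
            if key in producers:
                raise ValueError(f"out_key {key!r} is produced by more than one module")
            producers[key] = m

    ordered, seen = [], set()

    def visit(m):
        # depth-first: append a module only after every module it reads from
        if id(m) in seen:
            return
        seen.add(id(m))
        for key in m.in_keys:
            if key in m.out_keys:
                raise ValueError(f"module reads and writes {key!r}; order is ambiguous")
            if key in producers:
                visit(producers[key])
        ordered.append(m)

    for m in modules:
        visit(m)
    return TensorDictSequential(*ordered)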
