[Feature Request] Support OrderedDict as input to TensorDictSequential #1126
Comments
Happy to make this an option!

m = TensorDictSequential(OrderedDict(fc0=W))
m["fc0"]  # returns W

Also, there should be no mapping between module names and entry names, right? Or would that be an interesting avenue to investigate (e.g., all entries written by the "fc0" module are written in |
AFAICT:
>>> block = nn.Sequential(OrderedDict(linear=nn.Linear(10, 10)))
>>> block
Sequential(
(linear): Linear(in_features=10, out_features=10, bias=True)
)
>>> block[0]
Linear(in_features=10, out_features=10, bias=True)
>>> block['linear']
[...]
TypeError: 'str' object cannot be interpreted as an integer
So I think your example should throw a TypeError too:
>>> m = TensorDictSequential(OrderedDict(fc0=W))
>>> m["fc0"]
[...]
TypeError: ...
FWIW, this is my intended use case. I currently do something like this:

# mapping from module "name" to its corresponding module, input keys, and output keys
model_config: dict[str, tuple[nn.Module, list | dict, list]]
modules = [
    # nest the output keys in a sub-td under the module's name
    TensorDictModule(module, in_keys, [(name, key) for key in out_keys])
    for name, (module, in_keys, out_keys) in model_config.items()
]
model = TensorDictSequential(modules)

I think it would be a nice feature because it would mean I don't have to do this workaround and could just do this instead:

modules = {
    name: TensorDictModule(module, in_keys, out_keys)
    for name, (module, in_keys, out_keys) in model_config.items()
}
model = TensorDictSequential(OrderedDict(modules))

But I think this would be a breaking change, no? Because then you either (1) implement the same behavior for the |
We could also display the graph using this |
We have this in PyTorch for many functions, eg Open to discussion though! I can also see the point that different signatures may not be an excellent strategy, but one more kwarg may also be a bit clunky. (That being said, I think we should not have used an expanded sequence in Sequential, but what is done is done.)
Yeah fair enough. I guess it's just a personal preference of mine when writing python code that all input signatures share the same output behavior. But FWIW, your proposal does make my life easier and seems like reasonable behavior, so I'll defer to you on what you think is best.
I was actually thinking of implementing something like this in my own code for users to visualize the compute graph of the model's This also brings up a related feature request that I was thinking about in terms of allowing users to supply just a bag of |
I guess the "problem" here is that I can overwrite entries in the input tensordict. If you flip the order in these cases, we can't backtrack which is the original input and which is the output, e.g.

seq = TensorDictSequential(
    Mod(lambda x: x + 1, in_keys=["x"], out_keys=["x"]),
    Mod(lambda x: x.sqrt(), in_keys=["x"], out_keys=["x"]),
)

The order is important: if we swap the two modules, the graph will be different. In other words, your idea can work, but only if no module has an out_key that matches one of its own in_keys and no out_key appears twice in the graph (I think these conditions are sufficient, but I might be wrong).
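For what it's worth, the two conditions above can be checked mechanically before trying to infer an order. The following is only a sketch in plain Python, not tensordict code: `Edge` and `order_edges` are hypothetical names, each `Edge` stands in for a module with unique name, `in_keys`, and `out_keys`, and keys that no module writes are assumed to be inputs of the whole graph.

```python
from dataclasses import dataclass
from graphlib import TopologicalSorter


@dataclass(frozen=True)
class Edge:
    """Hypothetical stand-in for a module: a unique name plus its keys."""
    name: str
    in_keys: tuple
    out_keys: tuple


def order_edges(edges):
    """Return the edges ordered so that every key is written before it is read.

    Raises ValueError when the order would be ambiguous: a module overwrites
    one of its own inputs, or two modules write the same key.
    """
    producer = {}  # key -> name of the module that writes it
    for edge in edges:
        if set(edge.in_keys) & set(edge.out_keys):
            raise ValueError(f"{edge.name!r} overwrites one of its own in_keys")
        for key in edge.out_keys:
            if key in producer:
                raise ValueError(
                    f"{key!r} is written by both {producer[key]!r} and {edge.name!r}"
                )
            producer[key] = edge.name

    # Each module depends on the producers of its in_keys; keys without a
    # producer are treated as external inputs and add no dependency.
    graph = {
        edge.name: {producer[k] for k in edge.in_keys if k in producer}
        for edge in edges
    }
    by_name = {edge.name: edge for edge in edges}
    # static_order() raises graphlib.CycleError if the edges are not a DAG.
    return [by_name[name] for name in TopologicalSorter(graph).static_order()]


# The modules are supplied out of order on purpose.
edges = [
    Edge("head", in_keys=("hidden",), out_keys=("logits",)),
    Edge("encoder", in_keys=("obs",), out_keys=("hidden",)),
]
ordered = order_edges(edges)
print([e.name for e in ordered])  # ['encoder', 'head']
```

The `x + 1` / `sqrt` pair from the example would be rejected by the first check, since both modules read and write "x", which is exactly the case where the order cannot be recovered.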
Yeah exactly, and I view that as an implicit "feature" of

from collections.abc import Collection
from typing import Self

class TensorDictSequential:
    @classmethod
    def from_edges(cls, modules: Collection[TensorDictModule]) -> Self:
        """Build a :class:`TensorDictSequential` from the collection of edges that define its forward computation graph."""

But yeah, thinking about this more, the behavior in the edge cases seems hard to define. Like, it would be nice if I had a simple DAG with no overlapping |
Motivation
The native nn.Sequential supports input in the form of an OrderedDict[str, Module] to set the names of the underlying modules rather than just using ordinal numbers (as is the default when an Iterable[Module] is supplied). This isn't currently possible with the analogous TensorDictSequential class, but it would be nice if we could do the same.
Example
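As a rough illustration of the naming rule described under Motivation, here is a plain-Python sketch. It does not use torch or tensordict: `name_modules` is a hypothetical helper, and the strings stand in for modules. A single OrderedDict supplies the names, while any other positional input falls back to ordinal names.

```python
from collections import OrderedDict


def name_modules(*args):
    """Hypothetical helper mimicking the naming rule described above.

    A single OrderedDict argument provides the names; otherwise the
    positional arguments are named "0", "1", ... in order.
    """
    if len(args) == 1 and isinstance(args[0], OrderedDict):
        return OrderedDict(args[0])
    return OrderedDict((str(i), module) for i, module in enumerate(args))


# Strings stand in for modules here.
by_position = name_modules("linear_a", "linear_b")
by_name = name_modules(OrderedDict(fc0="linear_a", fc1="linear_b"))

print(list(by_position))  # ['0', '1']
print(list(by_name))      # ['fc0', 'fc1']
```

The request in this issue amounts to having TensorDictSequential accept the second form as well.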
Native torch behavior
Current tensordict behavior
Desired behavior