Skip to content

Commit

Permalink
[Refactor] Add missing functions in tensorclass register
Browse files Browse the repository at this point in the history
ghstack-source-id: 9f959be7b04c915596bdd67e149dc241cb8b4635
Pull Request resolved: #1153
  • Loading branch information
vmoens committed Dec 20, 2024
1 parent efb89a6 commit 951f680
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,25 +166,45 @@ def __subclasscheck__(self, subclass):
_FALLBACK_METHOD_FROM_TD = [
"__abs__",
"__add__",
"__bool__",
"__eq__",
"__ge__",
"__gt__",
"__iadd__",
"__imul__",
"__ipow__",
"__isub__",
"__itruediv__",
"__mul__",
"__ne__",
"__or__",
"__pow__",
"__sub__",
"__truediv__",
"__xor__",
"_add_batch_dim",
"_apply_nest",
"_clone",
"_clone_recurse",
"_data",
"_erase_names", # TODO: must be specialized
"_exclude", # TODO: must be specialized
"_fast_apply",
"_flatten_keys_inplace",
"_flatten_keys_outplace",
"_get_sub_tensordict",
"_grad",
"_map",
"_maybe_remove_batch_dim",
"_memmap_",
"_multithread_apply_flat",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild",
"_permute",
"_remove_batch_dim",
"_repeat",
"_repeat",
"_select", # TODO: must be specialized
"_set_at_tuple",
"_set_tuple",
Expand Down Expand Up @@ -217,6 +237,8 @@ def __subclasscheck__(self, subclass):
"clamp_max_",
"clamp_min",
"clamp_min_",
"clear",
"clear_device_",
"consolidate",
"contiguous",
"copy_",
Expand Down Expand Up @@ -251,10 +273,8 @@ def __subclasscheck__(self, subclass):
"frac_",
"from_any",
"from_dataclass",
"to_namedtuple",
"from_namedtuple",
"from_pytree",
"to_pytree",
"gather",
"isfinite",
"isnan",
Expand All @@ -275,6 +295,8 @@ def __subclasscheck__(self, subclass):
"log_",
"map",
"map_iter",
"to_namedtuple",
"to_pytree",
"masked_fill",
"masked_fill_",
"max",
Expand Down

0 comments on commit 951f680

Please sign in to comment.