From 951f680eb6c8acfbed5c7dbe167bd00d6e844934 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 14:03:01 +0000 Subject: [PATCH] [Refactor] Add missing functions in tensorclass register ghstack-source-id: 9f959be7b04c915596bdd67e149dc241cb8b4635 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1153 --- tensordict/tensorclass.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2486643ce..72fe055a8 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -166,25 +166,45 @@ def __subclasscheck__(self, subclass): _FALLBACK_METHOD_FROM_TD = [ "__abs__", "__add__", + "__bool__", + "__eq__", + "__ge__", + "__gt__", "__iadd__", "__imul__", "__ipow__", "__isub__", "__itruediv__", "__mul__", + "__ne__", + "__or__", "__pow__", "__sub__", "__truediv__", + "__xor__", "_add_batch_dim", "_apply_nest", + "_clone", + "_clone_recurse", + "_data", "_erase_names", # TODO: must be specialized "_exclude", # TODO: must be specialized "_fast_apply", + "_flatten_keys_inplace", + "_flatten_keys_outplace", "_get_sub_tensordict", + "_grad", + "_map", "_maybe_remove_batch_dim", + "_memmap_", "_multithread_apply_flat", + "_multithread_apply_flat", + "_multithread_apply_nest", + "_multithread_rebuild", + "_permute", "_remove_batch_dim", "_repeat", + "_repeat", "_select", # TODO: must be specialized "_set_at_tuple", "_set_tuple", @@ -217,6 +237,8 @@ def __subclasscheck__(self, subclass): "clamp_max_", "clamp_min", "clamp_min_", + "clear", + "clear_device_", "consolidate", "contiguous", "copy_", @@ -251,10 +273,8 @@ def __subclasscheck__(self, subclass): "frac_", "from_any", "from_dataclass", - "to_namedtuple", "from_namedtuple", "from_pytree", - "to_pytree", "gather", "isfinite", "isnan", @@ -275,6 +295,8 @@ def __subclasscheck__(self, subclass): "log_", "map", "map_iter", + "to_namedtuple", + "to_pytree", "masked_fill", "masked_fill_", "max",