diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2486643ce..62eeb83bb 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -166,23 +166,41 @@ def __subclasscheck__(self, subclass): _FALLBACK_METHOD_FROM_TD = [ "__abs__", "__add__", + "__bool__", + "__eq__", + "__ge__", + "__gt__", "__iadd__", "__imul__", "__ipow__", "__isub__", "__itruediv__", "__mul__", + "__ne__", + "__or__", "__pow__", "__sub__", "__truediv__", + "__xor__", "_add_batch_dim", "_apply_nest", + "_clone", + "_clone_recurse", + "_data", "_erase_names", # TODO: must be specialized "_exclude", # TODO: must be specialized "_fast_apply", + "_flatten_keys_inplace", + "_flatten_keys_outplace", "_get_sub_tensordict", + "_grad", + "_map", "_maybe_remove_batch_dim", + "_memmap_", "_multithread_apply_flat", + "_multithread_apply_nest", + "_multithread_rebuild", + "_permute", "_remove_batch_dim", "_repeat", "_select", # TODO: must be specialized @@ -217,6 +235,8 @@ def __subclasscheck__(self, subclass): "clamp_max_", "clamp_min", "clamp_min_", + "clear", + "clear_device_", "consolidate", "contiguous", "copy_", @@ -251,10 +271,8 @@ def __subclasscheck__(self, subclass): "frac_", "from_any", "from_dataclass", - "to_namedtuple", "from_namedtuple", "from_pytree", - "to_pytree", "gather", "isfinite", "isnan", @@ -275,6 +293,8 @@ def __subclasscheck__(self, subclass): "log_", "map", "map_iter", + "to_namedtuple", + "to_pytree", "masked_fill", "masked_fill_", "max",