diff --git a/tensordict/memmap_refact.py b/tensordict/memmap_refact.py index c034c4365..445a32a1b 100644 --- a/tensordict/memmap_refact.py +++ b/tensordict/memmap_refact.py @@ -162,6 +162,30 @@ def from_handler(cls, handler, dtype, shape, index): out.parent_shape = shape return out + def __setstate__(self, state): + if 'filename' in state: + self.__dict__.update(type(self).from_filename(**state).__dict__) + else: + self.__dict__.update(type(self).from_handler(**state).__dict__) + + def __getstate__(self): + if getattr(self, "_handler", None) is not None: + return { + 'handler': self._handler, + 'dtype': self.dtype, + 'shape': self.parent_shape, + 'index': self.index, + } + elif getattr(self, "_filename", None) is not None: + return { + 'filename': self._filename, + 'dtype': self.dtype, + 'shape': self.parent_shape, + 'index': self.index, + } + else: + raise RuntimeError("Could not find handler or filename.") + def __reduce__(self): if getattr(self, "_handler", None) is not None: return type(self).from_handler, (