diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index aeb5a8f23..966f36872 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -32,7 +32,6 @@ import orjson as json import torch -import torch.distributed as dist from tensordict.memmap import MemoryMappedTensor @@ -2388,7 +2387,7 @@ def _send( dst: int, _tag: int = -1, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, ) -> int: for td in self.tensordicts: _tag = td._send(dst, _tag=_tag, pseudo_rand=pseudo_rand, group=group) @@ -2400,7 +2399,7 @@ def _isend( _tag: int = -1, _futures: list[torch.Future] | None = None, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, ) -> int: if _futures is None: is_root = True @@ -2421,7 +2420,7 @@ def _recv( src: int, _tag: int = -1, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, ) -> int: for td in self.tensordicts: _tag = td._recv(src, _tag=_tag, pseudo_rand=pseudo_rand, group=group) @@ -2434,7 +2433,7 @@ def _irecv( _tag: int = -1, _future_list: list[torch.Future] = None, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, ) -> tuple[int, list[torch.Future]] | list[torch.Future] | None: root = False if _future_list is None: diff --git a/tensordict/base.py b/tensordict/base.py index b8712abd1..621db37db 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -98,7 +98,7 @@ unravel_key, unravel_key_list, ) -from torch import distributed as dist, multiprocessing as mp, nn, Tensor +from torch import multiprocessing as mp, nn, Tensor from torch.nn.parameter import Parameter, UninitializedTensorMixin from torch.utils._pytree import tree_map @@ -7260,7 +7260,7 @@ def del_(self, key: NestedKey) -> T: # Distributed functionality def gather_and_stack( - self, dst: int, group: "dist.ProcessGroup" | None = None + self, dst: int, group: "torch.distributed.ProcessGroup" | None = None ) -> T | None: """Gathers tensordicts from various workers and stacks them onto self in the destination worker. @@ -7319,6 +7319,8 @@ def gather_and_stack( ... main_worker.join() ... secondary_worker.join() """ + from torch import distributed as dist + output = ( [None for _ in range(dist.get_world_size(group=group))] if dst == dist.get_rank(group=group) @@ -7336,7 +7338,7 @@ def send( self, dst: int, *, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, init_tag: int = 0, pseudo_rand: bool = False, ) -> None: # noqa: D417 @@ -7426,8 +7428,10 @@ def _send( dst: int, _tag: int = -1, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, ) -> int: + from torch import distributed as dist + for key in self.sorted_keys: value = self._get_str(key, NO_DEFAULT) if isinstance(value, Tensor): @@ -7449,7 +7453,7 @@ def recv( self, src: int, *, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, init_tag: int = 0, pseudo_rand: bool = False, ) -> int: # noqa: D417 @@ -7481,9 +7485,11 @@ def _recv( src: int, _tag: int = -1, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, non_blocking: bool = False, ) -> int: + from torch import distributed as dist + for key in self.sorted_keys: value = self._get_str(key, NO_DEFAULT) if isinstance(value, Tensor): @@ -7508,7 +7514,7 @@ def isend( self, dst: int, *, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, init_tag: int = 0, pseudo_rand: bool = False, ) -> int: # noqa: D417 @@ -7603,8 +7609,10 @@ def _isend( _tag: int = -1, _futures: list[torch.Future] | None = None, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, ) -> int: + from torch import distributed as dist + root = False if _futures is None: root = True @@ -7639,7 +7647,7 @@ def irecv( self, src: int, *, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, return_premature: bool = False, init_tag: int = 0, pseudo_rand: bool = False, @@ -7687,8 +7695,10 @@ def _irecv( _tag: int = -1, _future_list: list[torch.Future] = None, pseudo_rand: bool = False, - group: "dist.ProcessGroup" | None = None, + group: "torch.distributed.ProcessGroup" | None = None, ) -> tuple[int, list[torch.Future]] | list[torch.Future] | None: + from torch import distributed as dist + root = False if _future_list is None: _future_list = [] @@ -7736,6 +7746,8 @@ def reduce( Only the process with ``rank`` dst is going to receive the final result. """ + from torch import distributed as dist + if op is None: op = dist.ReduceOp.SUM return self._reduce(dst, op, async_op, return_premature, group=group) @@ -7749,6 +7761,8 @@ def _reduce( _future_list=None, group=None, ): + from torch import distributed as dist + if op is None: op = dist.ReduceOp.SUM root = False diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index da3dc7e07..87a7d7753 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -50,7 +50,7 @@ from tensordict.utils import ( unravel_key as unravel_key, unravel_key_list as unravel_key_list, ) -from torch import distributed as dist, multiprocessing as mp, nn, Tensor +from torch import multiprocessing as mp, nn, Tensor class _NoDefault(enum.IntEnum): ZERO = 0 @@ -663,7 +663,7 @@ class TensorClass: ) -> T: ... def del_(self, key: NestedKey) -> T: ... def gather_and_stack( - self, dst: int, group: dist.ProcessGroup | None = None + self, dst: int, group: "dist.ProcessGroup" | None = None ) -> T | None: ... def send( self, diff --git a/test/smoke_test.py b/test/smoke_test.py index d3c6a8a06..cc45dd52c 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +import sys import pytest @@ -11,6 +12,14 @@ def test_imports(): from tensordict import TensorDict # noqa: F401 from tensordict.nn import TensorDictModule # noqa: F401 + # # Check that distributed is not imported + # v = set(sys.modules.values()) + # try: + # from torch import distributed + # except ImportError: + # return + # assert distributed not in v + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()