Commit: amend
vmoens committed Jan 3, 2024
1 parent 929343c commit 837ea2a
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions tensordict/_torch_func.py
@@ -5,7 +5,6 @@
 
 from __future__ import annotations
 
-import concurrent.futures
 import functools
 from typing import Any, Callable, Sequence, TypeVar
 
@@ -16,8 +15,13 @@
 
 from tensordict.base import NO_DEFAULT, TensorDictBase
 from tensordict.persistent import PersistentTensorDict
-from tensordict.utils import _check_keys, _ErrorInteceptor, DeviceType, \
-    lazy_legacy, set_lazy_legacy
+from tensordict.utils import (
+    _check_keys,
+    _ErrorInteceptor,
+    DeviceType,
+    lazy_legacy,
+    set_lazy_legacy,
+)
 from torch import Tensor
 
 T = TypeVar("T", bound="TensorDictBase")
@@ -377,20 +381,25 @@ def _stack(
 
     out = {}
     for key in keys:
-        out[key] = [_tensordict._get_str(key, default=None) for _tensordict in list_of_tensordicts]
-        tensor_shape = out[key][0].shape
-        for tensor in out[key][1:]:
-            if tensor.shape != tensor_shape:
+        out[key] = []
+        tensor_shape = None
+        for _tensordict in list_of_tensordicts:
+            tensor = _tensordict._get_str(key, default=NO_DEFAULT)
+            if tensor_shape is None:
+                tensor_shape = tensor.shape
+            elif tensor.shape != tensor_shape:
                 with set_lazy_legacy(True):
                     return _stack(list_of_tensordicts, dim=dim)
-    def stack_fn(key: str):
+            out[key].append(tensor)
+
+    def stack_fn(key_values):
+        key, values = key_values
         with _ErrorInteceptor(
             key, "Attempted to stack tensors on different devices at key"
         ):
-            out[key] = torch.stack(out[key], dim)
+            return torch.stack(values, dim)
 
-    with concurrent.futures.ThreadPoolExecutor(32) as pool:
-        pool.map(stack_fn, out.keys())
+    out = {key: stack_fn((key, value)) for key, value in out.items()}
 
     return TensorDict(
         out,
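
For reference, here is a minimal standalone sketch of the per-key logic introduced by this hunk, written against plain dicts of tensors rather than TensorDict internals. The helper name stack_dicts and the None return on shape mismatch are illustrative only; the real _stack falls back to a lazy stack under set_lazy_legacy(True) instead.

from typing import Dict, List, Optional

import torch
from torch import Tensor


def stack_dicts(dicts: List[Dict[str, Tensor]], dim: int = 0) -> Optional[Dict[str, Tensor]]:
    # Collect each key's tensors in a single pass, remembering the first shape seen.
    out: Dict[str, List[Tensor]] = {}
    for key in dicts[0]:
        out[key] = []
        tensor_shape = None
        for d in dicts:
            tensor = d[key]  # the diff uses _tensordict._get_str(key, default=NO_DEFAULT)
            if tensor_shape is None:
                tensor_shape = tensor.shape
            elif tensor.shape != tensor_shape:
                # Shape mismatch: the real code re-enters _stack lazily; this sketch just signals it.
                return None
            out[key].append(tensor)
    # One torch.stack per key, run sequentially -- the commit drops the ThreadPoolExecutor.
    return {key: torch.stack(values, dim) for key, values in out.items()}


if __name__ == "__main__":
    a = {"x": torch.zeros(3), "y": torch.zeros(2, 2)}
    b = {"x": torch.ones(3), "y": torch.ones(2, 2)}
    stacked = stack_dicts([a, b])
    assert stacked is not None and stacked["x"].shape == torch.Size([2, 3])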

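_torch_func.py appears to host TensorDict's __torch_function__ overrides, with _stack as the handler invoked when torch.stack receives a list of TensorDicts. A hedged usage sketch, assuming the tensordict package around this revision (exact return types may differ):

import torch
from tensordict import TensorDict

td1 = TensorDict({"a": torch.zeros(3), "b": torch.zeros(2)}, batch_size=[])
td2 = TensorDict({"a": torch.ones(3), "b": torch.ones(2)}, batch_size=[])

# All shapes match key by key, so the dense per-key torch.stack path runs.
stacked = torch.stack([td1, td2], dim=0)
print(stacked["a"].shape)  # torch.Size([2, 3])

# Had any key's shapes differed across the inputs, _stack would instead re-enter
# itself under set_lazy_legacy(True) and return a lazy stack.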