[Refactor] Refactor split (#555)
vmoens authored Nov 10, 2023
1 parent 4f216a3 commit 924a46a
Showing 2 changed files with 74 additions and 105 deletions.
84 changes: 30 additions & 54 deletions tensordict/tensordict.py
@@ -3130,73 +3130,49 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
             A list of TensorDict with specified size in given dimension.
         """
-        batch_sizes = []
-        if self.batch_dims == 0:
-            raise RuntimeError("TensorDict with empty batch size is not splittable")
-        if not (-self.batch_dims <= dim < self.batch_dims):
-            raise IndexError(
-                f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
-            )
+        # we must use slices to keep the storage of the tensors
+        batch_size = self.batch_size
         if dim < 0:
-            dim += self.batch_dims
-        if isinstance(split_size, int):
-            rep, remainder = divmod(self.batch_size[dim], split_size)
-            rep_shape = torch.Size(
-                [
-                    split_size if idx == dim else size
-                    for (idx, size) in enumerate(self.batch_size)
-                ]
-            )
-            batch_sizes = [rep_shape for _ in range(rep)]
-            if remainder:
-                batch_sizes.append(
-                    torch.Size(
-                        [
-                            remainder if dim_idx == dim else dim_size
-                            for (dim_idx, dim_size) in enumerate(self.batch_size)
-                        ]
-                    )
-                )
+            dim = len(batch_size) + dim
+        if dim < 0 or dim >= len(batch_size):
+            raise RuntimeError(
+                f"The number of dimensions is insufficient for the split_dim {dim}."
+            )
+        if isinstance(split_size, int):
+            if split_size <= 0:
+                raise RuntimeError(
+                    f"split_size must be strictly greater than 0, got {split_size}."
+                )
+            idx0 = 0
+            idx1 = split_size
+            split_sizes = [slice(idx0, idx1)]
+            while idx1 < batch_size[dim]:
+                idx0 = idx1
+                idx1 = idx1 + split_size
+                split_sizes.append(slice(idx0, idx1))
         elif isinstance(split_size, list) and all(
             isinstance(element, int) for element in split_size
         ):
-            if sum(split_size) != self.batch_size[dim]:
+            if len(split_size) == 0:
+                raise RuntimeError("Insufficient number of elements in split_size.")
+            if sum(split_size) != batch_size[dim]:
                 raise RuntimeError(
                     f"Split method expects split_size to sum exactly to {self.batch_size[dim]} (tensor's size at dimension {dim}), but got split_size={split_size}"
                 )
-            for i in split_size:
-                batch_sizes.append(
-                    torch.Size(
-                        [
-                            i if dim_idx == dim else dim_size
-                            for (dim_idx, dim_size) in enumerate(self.batch_size)
-                        ]
-                    )
-                )
+            idx0 = 0
+            idx1 = split_size[0]
+            split_sizes = [slice(idx0, idx1)]
+            for idx in split_size[1:]:
+                idx0 = idx1
+                idx1 += idx
+                split_sizes.append(slice(idx0, idx1))
         else:
             raise TypeError(
                 "split(): argument 'split_size' must be int or list of ints"
             )
-        dictionaries = [{} for _ in range(len(batch_sizes))]
-        for key, item in self.items():
-            split_tensors = torch.split(item, split_size, dim)
-            for idx, split_tensor in enumerate(split_tensors):
-                dictionaries[idx][key] = split_tensor
-        names = None
-        if self._has_names():
-            names = copy(self.names)
-        return [
-            TensorDict(
-                dictionaries[i],
-                batch_sizes[i],
-                device=self.device,
-                names=names,
-                _run_checks=False,
-                _is_shared=self.is_shared(),
-                _is_memmap=self.is_memmap(),
-            )
-            for i in range(len(dictionaries))
-        ]
+        index = (slice(None),) * dim
+        return tuple(self[index + (ss,)] for ss in split_sizes)

     def gather(self, dim: int, index: Tensor, out: T | None = None) -> T:
         """Gathers values along an axis specified by `dim`.
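The inline comment in the new implementation ("we must use slices to keep the storage of the tensors") captures the point of the refactor: instead of calling torch.split per key and assembling fresh dictionaries, each chunk is now obtained by slice-indexing the parent tensordict, so every returned tensor is a view over the original storage, and the return type becomes a tuple. Note also that the int branch deliberately lets the last slice overrun the dimension (e.g. slice(4, 6) on a length-5 dim); basic indexing clamps the end, which is how the remainder chunk falls out. A minimal sketch of the post-refactor behavior (the tensordict contents here are illustrative, not taken from the commit):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.arange(6).reshape(3, 2)}, batch_size=[3])

# split now returns a tuple of slice-indexed tensordicts rather than a
# list of freshly constructed ones
first, rest = td.split([1, 2], dim=0)
assert first.batch_size == torch.Size([1])

# slice indexing produces views, so a write through a chunk should be
# visible in the parent (assuming no intervening copy)
first["a"][0, 0] = 100
assert td["a"][0, 0] == 100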
95 changes: 44 additions & 51 deletions test/test_tensordict.py
@@ -2697,56 +2697,41 @@ def test_setdefault_nested(self, td_name, device):
         torch.testing.assert_close(td.get(("a", "b", "d")), tensor2)

     @pytest.mark.parametrize("performer", ["torch", "tensordict"])
-    def test_split(self, td_name, device, performer):
+    @pytest.mark.parametrize("dim", range(4))
+    def test_split(self, td_name, device, performer, dim):
         td = getattr(self, td_name)(device)
+        tensor = torch.zeros(()).expand(td.shape)
+
+        rep, remainder = divmod(td.shape[dim], 2)
+
+        # split_sizes to be [2, 2, ..., 2, 1] or [2, 2, ..., 2]
+        split_sizes = [2] * rep + [1] * remainder
+        for test_split_size in (2, split_sizes):
+            if performer == "torch":
+                tds = torch.split(td, test_split_size, dim)
+            elif performer == "tensordict":
+                tds = td.split(test_split_size, dim)
+            tensors = tensor.split(test_split_size, dim)
+            length = len(tensors)
+            assert len(tds) == length, (
+                test_split_size,
+                dim,
+                [td.shape for td in tds],
+                td.shape,
+                length,
+            )

-        for dim in range(td.batch_dims):
-            rep, remainder = divmod(td.shape[dim], 2)
-            length = rep + remainder
-
-            # split_sizes to be [2, 2, ..., 2, 1] or [2, 2, ..., 2]
-            split_sizes = [2] * rep + [1] * remainder
-            for test_split_size in (2, split_sizes):
-                if performer == "torch":
-                    tds = torch.split(td, test_split_size, dim)
-                elif performer == "tensordict":
-                    tds = td.split(test_split_size, dim)
-                assert len(tds) == length
-
-                for idx, split_td in enumerate(tds):
-                    expected_split_dim_size = 1 if idx == rep else 2
-                    expected_batch_size = [
-                        expected_split_dim_size if dim_idx == dim else dim_size
-                        for (dim_idx, dim_size) in enumerate(td.batch_size)
-                    ]
-
-                    # Test each split_td has the expected batch_size
-                    assert split_td.batch_size == torch.Size(expected_batch_size)
-
-                    if td_name == "nested_td":
-                        assert isinstance(split_td["my_nested_td"], TensorDict)
-                        assert isinstance(
-                            split_td["my_nested_td"]["inner"], torch.Tensor
-                        )
+            for _tensor, split_td in zip(tensors, tds):
+                assert _tensor.shape == split_td.shape

-                    # Test each tensor (or nested_td) in split_td has the expected shape
-                    for key, item in split_td.items():
-                        expected_shape = [
-                            expected_split_dim_size if dim_idx == dim else dim_size
-                            for (dim_idx, dim_size) in enumerate(td[key].shape)
-                        ]
-                        assert item.shape == torch.Size(expected_shape)
-
-                        if key == "my_nested_td":
-                            expected_inner_tensor_size = [
-                                expected_split_dim_size if dim_idx == dim else dim_size
-                                for (dim_idx, dim_size) in enumerate(
-                                    td[key]["inner"].shape
-                                )
-                            ]
-                            assert item["inner"].shape == torch.Size(
-                                expected_inner_tensor_size
-                            )
+                if td_name == "nested_td":
+                    assert isinstance(split_td["my_nested_td"], TensorDict)
+                    assert isinstance(split_td["my_nested_td", "inner"], torch.Tensor)
+
+                # Test each tensor (or nested_td) in split_td has the expected shape
+                for key, item in split_td.items(True, True):
+                    expected_shape = _tensor.shape + td[key].shape[len(_tensor.shape) :]
+                    assert item.shape == torch.Size(expected_shape)

     def test_pop(self, td_name, device):
         td = getattr(self, td_name)(device)
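The rewritten test iterates with split_td.items(True, True); in tensordict those positional arguments correspond to include_nested and leaves_only, so the loop visits every leaf tensor and returns nested keys as tuples. A short sketch of that iteration, with made-up contents mirroring the nested_td fixture:

import torch
from tensordict import TensorDict

td = TensorDict(
    {
        "a": torch.zeros(3, 4),
        "my_nested_td": TensorDict({"inner": torch.zeros(3, 4, 5)}, [3, 4]),
    },
    batch_size=[3],
)

# include_nested=True recurses into nested tensordicts; leaves_only=True
# yields only tensors, skipping the intermediate tensordict nodes
for key, leaf in td.items(True, True):
    print(key, tuple(leaf.shape))
# prints: a (3, 4) and ('my_nested_td', 'inner') (3, 4, 5)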
@@ -4317,7 +4302,9 @@ def test_flatten_unflatten_key_collision(inplace, separator):
 def test_split_with_invalid_arguments():
     td = TensorDict({"a": torch.zeros(2, 1)}, [])
     # Test empty batch size
-    with pytest.raises(RuntimeError, match="not splittable"):
+    with pytest.raises(
+        RuntimeError, match="The number of dimensions is insufficient for the split_dim"
+    ):
         td.split(1, 0)

     td = TensorDict({}, [3, 2])
@@ -4329,16 +4316,22 @@ def test_split_with_invalid_arguments():
         td.split(["1", 2], 0)

     # Test invalid split_size sum
-    with pytest.raises(RuntimeError, match="expects split_size to sum exactly"):
+    with pytest.raises(
+        RuntimeError, match="Insufficient number of elements in split_size"
+    ):
         td.split([], 0)

     with pytest.raises(RuntimeError, match="expects split_size to sum exactly"):
         td.split([1, 1], 0)

     # Test invalid dimension input
-    with pytest.raises(IndexError, match="Dimension out of range"):
+    with pytest.raises(
+        RuntimeError, match="The number of dimensions is insufficient for the split_dim"
+    ):
         td.split(1, 2)
-    with pytest.raises(IndexError, match="Dimension out of range"):
+    with pytest.raises(
+        RuntimeError, match="The number of dimensions is insufficient for the split_dim"
+    ):
         td.split(1, -3)
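One behavioral change the updated assertions pin down: out-of-range dimensions now surface as RuntimeError with the "number of dimensions is insufficient" message, where the old code raised IndexError (as torch.Tensor.split still does). A quick sketch of the post-refactor failure mode, with an illustrative tensordict:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3, 2, 1)}, [3, 2])

try:
    td.split(1, 2)  # dim 2 is out of range for batch_size [3, 2]
except RuntimeError as err:
    print(err)  # The number of dimensions is insufficient for the split_dim 2.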
