diff --git a/.github/scripts/m1_script.sh b/.github/scripts/m1_script.sh new file mode 100644 index 000000000..4226b2beb --- /dev/null +++ b/.github/scripts/m1_script.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +export BUILD_VERSION=0.2.1 diff --git a/.github/unittest/linux/scripts/install.sh b/.github/unittest/linux/scripts/install.sh index 65904ffe1..07557287a 100755 --- a/.github/unittest/linux/scripts/install.sh +++ b/.github/unittest/linux/scripts/install.sh @@ -47,7 +47,10 @@ printf "* Installing tensordict\n" python setup.py develop # install torchsnapshot nightly -python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation - +if [[ "$TORCH_VERSION" == "nightly" ]]; then + python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation +elif [[ "$TORCH_VERSION" == "stable" ]]; then + python -m pip install torchsnapshot +fi # smoke test python -c "import functorch;import torchsnapshot" diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml new file mode 100644 index 000000000..3ab789999 --- /dev/null +++ b/.github/workflows/build-wheels-m1.yml @@ -0,0 +1,43 @@ +name: Build M1 Wheels + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + workflow_dispatch: + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: macos-arm64 + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build: + needs: generate-matrix + name: pytorch/tensordict + uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main + with: + repository: pytorch/tensordict + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} +# pre-script: .github/scripts/pre_build_script_m1.sh + post-script: "" + package-name: tensordict + runner-type: macos-m1-12 + smoke-test-script: "" + trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/m1_script.sh + secrets: + AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} + AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 860ac2541..80069c09f 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -4,7 +4,7 @@ on: types: [opened, synchronize, reopened] push: branches: - - release/0.2.0 + - release/0.2.1 concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. @@ -32,7 +32,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.2.1 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -72,7 +72,7 @@ jobs: run: | export CC=clang CXX=clang++ python3 -mpip install wheel - BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.2.1 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: @@ -104,7 +104,7 @@ jobs: shell: bash run: | python3 -mpip install wheel - BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.2.1 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/tensordict/memmap.py b/tensordict/memmap.py index 73cbf1733..175aee471 100644 --- a/tensordict/memmap.py +++ b/tensordict/memmap.py @@ -308,7 +308,7 @@ def _init_shape( dtype: torch.dtype, transfer_ownership: bool, ): - self._device = device + self._device = torch.device(device) self._shape = shape self._shape_indexed = None self.np_shape = tuple(self._shape) diff --git a/tensordict/persistent.py b/tensordict/persistent.py index cf75d3ecf..c3103c524 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -124,9 +124,7 @@ class PersistentTensorDict(TensorDictBase): """ - def __new__(cls, *args, **kwargs): - cls._td_dim_names = None - return super().__new__(cls, *args, **kwargs) + _td_dim_names = None def __init__( self, diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index d0324c0f8..a3bc77a32 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -355,17 +355,14 @@ class TensorDictBase(MutableMapping): ) KEY_ERROR = 'key "{}" not found in {} with ' "keys {}" - def __new__(cls, *args: Any, **kwargs: Any) -> T: - self = super().__new__(cls) - self._safe = kwargs.get("_safe", False) - self._lazy = kwargs.get("_lazy", False) - self._inplace_set = kwargs.get("_inplace_set", False) - self.is_meta = kwargs.get("is_meta", False) - self._is_locked = kwargs.get("_is_locked", False) - self._cache = None - self._last_op = None - self.__last_op_queue = None - return self + _safe = False + _lazy = False + _inplace_set = False + is_meta = False + _is_locked = False + _cache = None + _last_op = None + __last_op_queue = None def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() @@ -1812,8 +1809,14 @@ def as_tensor(self): and will raise an exception in all other cases. """ - warnings.warn("as_tensor will soon be deprecated.", category=DeprecationWarning) - return self + + def as_tensor(x): + try: + return x.as_tensor() + except AttributeError: + return x + + return self._fast_apply(as_tensor) def update( self, @@ -3128,73 +3131,49 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBas A list of TensorDict with specified size in given dimension. """ - batch_sizes = [] - if self.batch_dims == 0: - raise RuntimeError("TensorDict with empty batch size is not splittable") - if not (-self.batch_dims <= dim < self.batch_dims): - raise IndexError( - f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})" - ) + # we must use slices to keep the storage of the tensors + batch_size = self.batch_size if dim < 0: - dim += self.batch_dims - if isinstance(split_size, int): - rep, remainder = divmod(self.batch_size[dim], split_size) - rep_shape = torch.Size( - [ - split_size if idx == dim else size - for (idx, size) in enumerate(self.batch_size) - ] + dim = len(batch_size) + dim + if dim < 0 or dim >= len(batch_size): + raise RuntimeError( + f"The number of dimensions is insufficient for the split_dim {dim}." ) - batch_sizes = [rep_shape for _ in range(rep)] - if remainder: - batch_sizes.append( - torch.Size( - [ - remainder if dim_idx == dim else dim_size - for (dim_idx, dim_size) in enumerate(self.batch_size) - ] - ) + if isinstance(split_size, int): + if split_size <= 0: + raise RuntimeError( + f"split_size must be strictly greater than 0, got {split_size}." ) + idx0 = 0 + idx1 = split_size + split_sizes = [slice(idx0, idx1)] + while idx1 < batch_size[dim]: + idx0 = idx1 + idx1 = idx1 + split_size + split_sizes.append(slice(idx0, idx1)) elif isinstance(split_size, list) and all( isinstance(element, int) for element in split_size ): - if sum(split_size) != self.batch_size[dim]: + if len(split_size) == 0: + raise RuntimeError("Insufficient number of elements in split_size.") + if sum(split_size) != batch_size[dim]: raise RuntimeError( f"Split method expects split_size to sum exactly to {self.batch_size[dim]} (tensor's size at dimension {dim}), but got split_size={split_size}" ) - for i in split_size: - batch_sizes.append( - torch.Size( - [ - i if dim_idx == dim else dim_size - for (dim_idx, dim_size) in enumerate(self.batch_size) - ] - ) - ) + idx0 = 0 + idx1 = split_size[0] + split_sizes = [slice(idx0, idx1)] + for idx in split_size[1:]: + idx0 = idx1 + idx1 += idx + split_sizes.append(slice(idx0, idx1)) else: raise TypeError( "split(): argument 'split_size' must be int or list of ints" ) - dictionaries = [{} for _ in range(len(batch_sizes))] - for key, item in self.items(): - split_tensors = torch.split(item, split_size, dim) - for idx, split_tensor in enumerate(split_tensors): - dictionaries[idx][key] = split_tensor - names = None - if self._has_names(): - names = copy(self.names) - return [ - TensorDict( - dictionaries[i], - batch_sizes[i], - device=self.device, - names=names, - _run_checks=False, - _is_shared=self.is_shared(), - _is_memmap=self.is_memmap(), - ) - for i in range(len(dictionaries)) - ] + + index = (slice(None),) * dim + return tuple(self[index + (ss,)] for ss in split_sizes) def gather(self, dim: int, index: Tensor, out: T | None = None) -> T: """Gathers values along an axis specified by `dim`. @@ -4072,26 +4051,11 @@ class TensorDict(TensorDictBase): """ - __slots__ = ( - "_tensordict", - "_batch_size", - "_is_shared", - "_is_memmap", - "_device", - "_is_locked", - "_td_dim_names", - "_lock_id", - "_locked_tensordicts", - "_cache", - "_last_op", - "__last_op_queue", - ) - - def __new__(cls, *args: Any, **kwargs: Any) -> TensorDict: - cls._is_shared = False - cls._is_memmap = False - cls._td_dim_names = None - return super().__new__(cls, *args, _safe=True, _lazy=False, **kwargs) + _is_shared = False + _is_memmap = False + _td_dim_names = None + _safe = True + _lazy = False def __init__( self, @@ -4995,11 +4959,12 @@ def _nested_keys( ) def __getstate__(self): - return { - slot: getattr(self, slot) - for slot in self.__slots__ - if slot not in ("_last_op", "_cache", "__last_op_queue") + result = { + key: val + for key, val in self.__dict__.items() + if key not in ("_last_op", "_cache", "__last_op_queue") } + return result def __setstate__(self, state): for slot, value in state.items(): @@ -5784,10 +5749,11 @@ class SubTensorDict(TensorDictBase): """ - def __new__(cls, *args: Any, **kwargs: Any) -> SubTensorDict: - cls._is_shared = False - cls._is_memmap = False - return super().__new__(cls, _safe=False, _lazy=True, _inplace_set=True) + _is_shared = False + _is_memmap = False + _safe = False + _lazy = True + _inplace_set = True def __init__( self, @@ -6408,9 +6374,9 @@ def __torch_function__( else: return super().__torch_function__(func, types, args, kwargs) - def __new__(cls, *args: Any, **kwargs: Any) -> LazyStackedTensorDict: - cls._td_dim_name = None - return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs) + _td_dim_name = None + _safe = False + _lazy = True def __init__( self, @@ -8156,8 +8122,8 @@ def _repr_exclusive_fields(self): class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" - def __new__(cls, *args: Any, **kwargs: Any) -> _CustomOpTensorDict: - return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs) + _safe = False + _lazy = True def __init__( self, diff --git a/tensordict/utils.py b/tensordict/utils.py index 78da3281f..7e79fc2ff 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -767,7 +767,6 @@ def _is_shared(tensor: torch.Tensor) -> bool: elif isinstance(tensor, KeyedJaggedTensor): return False else: - print(type(tensor)) return tensor.is_shared() diff --git a/test/test_tensordict.py b/test/test_tensordict.py index bc5d0d19b..894ad99a4 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2273,6 +2273,20 @@ def test_chunk(self, td_name, device, dim, chunks): assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim] assert (torch.cat(td_chunks, dim) == td).all() + def test_as_tensor(self, td_name, device): + td = getattr(self, td_name)(device) + if "memmap" in td_name and device == torch.device("cpu"): + tdt = td.as_tensor() + assert (tdt == td).all() + elif "memmap" in td_name: + with pytest.raises( + RuntimeError, match="can only be called with MemmapTensors stored" + ): + td.as_tensor() + else: + # checks that it runs + td.as_tensor() + def test_items_values_keys(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) @@ -2684,56 +2698,41 @@ def test_setdefault_nested(self, td_name, device): torch.testing.assert_close(td.get(("a", "b", "d")), tensor2) @pytest.mark.parametrize("performer", ["torch", "tensordict"]) - def test_split(self, td_name, device, performer): + @pytest.mark.parametrize("dim", range(4)) + def test_split(self, td_name, device, performer, dim): td = getattr(self, td_name)(device) + tensor = torch.zeros(()).expand(td.shape) + + rep, remainder = divmod(td.shape[dim], 2) + + # split_sizes to be [2, 2, ..., 2, 1] or [2, 2, ..., 2] + split_sizes = [2] * rep + [1] * remainder + for test_split_size in (2, split_sizes): + if performer == "torch": + tds = torch.split(td, test_split_size, dim) + elif performer == "tensordict": + tds = td.split(test_split_size, dim) + tensors = tensor.split(test_split_size, dim) + length = len(tensors) + assert len(tds) == length, ( + test_split_size, + dim, + [td.shape for td in tds], + td.shape, + length, + ) - for dim in range(td.batch_dims): - rep, remainder = divmod(td.shape[dim], 2) - length = rep + remainder - - # split_sizes to be [2, 2, ..., 2, 1] or [2, 2, ..., 2] - split_sizes = [2] * rep + [1] * remainder - for test_split_size in (2, split_sizes): - if performer == "torch": - tds = torch.split(td, test_split_size, dim) - elif performer == "tensordict": - tds = td.split(test_split_size, dim) - assert len(tds) == length - - for idx, split_td in enumerate(tds): - expected_split_dim_size = 1 if idx == rep else 2 - expected_batch_size = [ - expected_split_dim_size if dim_idx == dim else dim_size - for (dim_idx, dim_size) in enumerate(td.batch_size) - ] - - # Test each split_td has the expected batch_size - assert split_td.batch_size == torch.Size(expected_batch_size) - - if td_name == "nested_td": - assert isinstance(split_td["my_nested_td"], TensorDict) - assert isinstance( - split_td["my_nested_td"]["inner"], torch.Tensor - ) + for _tensor, split_td in zip(tensors, tds): + assert _tensor.shape == split_td.shape + + if td_name == "nested_td": + assert isinstance(split_td["my_nested_td"], TensorDict) + assert isinstance(split_td["my_nested_td", "inner"], torch.Tensor) - # Test each tensor (or nested_td) in split_td has the expected shape - for key, item in split_td.items(): - expected_shape = [ - expected_split_dim_size if dim_idx == dim else dim_size - for (dim_idx, dim_size) in enumerate(td[key].shape) - ] - assert item.shape == torch.Size(expected_shape) - - if key == "my_nested_td": - expected_inner_tensor_size = [ - expected_split_dim_size if dim_idx == dim else dim_size - for (dim_idx, dim_size) in enumerate( - td[key]["inner"].shape - ) - ] - assert item["inner"].shape == torch.Size( - expected_inner_tensor_size - ) + # Test each tensor (or nested_td) in split_td has the expected shape + for key, item in split_td.items(True, True): + expected_shape = _tensor.shape + td[key].shape[len(_tensor.shape) :] + assert item.shape == torch.Size(expected_shape) def test_pop(self, td_name, device): td = getattr(self, td_name)(device) @@ -4306,7 +4305,9 @@ def test_flatten_unflatten_key_collision(inplace, separator): def test_split_with_invalid_arguments(): td = TensorDict({"a": torch.zeros(2, 1)}, []) # Test empty batch size - with pytest.raises(RuntimeError, match="not splittable"): + with pytest.raises( + RuntimeError, match="The number of dimensions is insufficient for the split_dim" + ): td.split(1, 0) td = TensorDict({}, [3, 2]) @@ -4318,16 +4319,22 @@ def test_split_with_invalid_arguments(): td.split(["1", 2], 0) # Test invalid split_size sum - with pytest.raises(RuntimeError, match="expects split_size to sum exactly"): + with pytest.raises( + RuntimeError, match="Insufficient number of elements in split_size" + ): td.split([], 0) with pytest.raises(RuntimeError, match="expects split_size to sum exactly"): td.split([1, 1], 0) # Test invalid dimension input - with pytest.raises(IndexError, match="Dimension out of range"): + with pytest.raises( + RuntimeError, match="The number of dimensions is insufficient for the split_dim" + ): td.split(1, 2) - with pytest.raises(IndexError, match="Dimension out of range"): + with pytest.raises( + RuntimeError, match="The number of dimensions is insufficient for the split_dim" + ): td.split(1, -3) diff --git a/version.txt b/version.txt index 0ea3a944b..0c62199f1 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.0 +0.2.1