Merge remote-tracking branch 'origin/main' into memmap_tensor_refact
# Conflicts:
#	tensordict/tensordict.py
#	test/test_tensordict.py
vmoens committed Nov 14, 2023
2 parents 19ae8e0 + 924a46a commit ba931ef
Showing 10 changed files with 182 additions and 163 deletions.
3 changes: 3 additions & 0 deletions .github/scripts/m1_script.sh
@@ -0,0 +1,3 @@
#!/bin/bash

export BUILD_VERSION=0.2.1
7 changes: 5 additions & 2 deletions .github/unittest/linux/scripts/install.sh
@@ -47,7 +47,10 @@ printf "* Installing tensordict\n"
python setup.py develop

# install torchsnapshot nightly
python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation

if [[ "$TORCH_VERSION" == "nightly" ]]; then
python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation
elif [[ "$TORCH_VERSION" == "stable" ]]; then
python -m pip install torchsnapshot
fi
# smoke test
python -c "import functorch;import torchsnapshot"
43 changes: 43 additions & 0 deletions .github/workflows/build-wheels-m1.yml
@@ -0,0 +1,43 @@
name: Build M1 Wheels

on:
pull_request:
push:
branches:
- nightly
- main
- release/*
tags:
# NOTE: Binary build pipelines should only get triggered on release candidate builds
# Release candidate tags look like: v1.11.0-rc1
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
workflow_dispatch:

jobs:
generate-matrix:
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
with:
package-type: wheel
os: macos-arm64
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build:
needs: generate-matrix
name: pytorch/tensordict
uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main
with:
repository: pytorch/tensordict
ref: ""
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
# pre-script: .github/scripts/pre_build_script_m1.sh
post-script: ""
package-name: tensordict
runner-type: macos-m1-12
smoke-test-script: ""
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/m1_script.sh
secrets:
AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
8 changes: 4 additions & 4 deletions .github/workflows/wheels.yml
@@ -4,7 +4,7 @@ on:
types: [opened, synchronize, reopened]
push:
branches:
- release/0.2.0
- release/0.2.1

concurrency:
# Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
@@ -32,7 +32,7 @@ jobs:
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install wheel
BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel
BUILD_VERSION=0.2.1 python3 setup.py bdist_wheel
# NB: wheels have the linux_x86_64 tag so we rename to manylinux1
# find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
@@ -72,7 +72,7 @@ jobs:
run: |
export CC=clang CXX=clang++
python3 -mpip install wheel
BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel
BUILD_VERSION=0.2.1 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
with:
@@ -104,7 +104,7 @@ jobs:
shell: bash
run: |
python3 -mpip install wheel
BUILD_VERSION=0.2.0 python3 setup.py bdist_wheel
BUILD_VERSION=0.2.1 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
with:
2 changes: 1 addition & 1 deletion tensordict/memmap.py
@@ -308,7 +308,7 @@ def _init_shape(
dtype: torch.dtype,
transfer_ownership: bool,
):
self._device = device
self._device = torch.device(device)
self._shape = shape
self._shape_indexed = None
self.np_shape = tuple(self._shape)
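The one-line change above wraps whatever the caller passes (a string such as "cpu" or an existing torch.device) in torch.device(...) before storing it, so _device always has the same type downstream. A minimal standalone sketch of the normalization, separate from the diff, with illustrative values:

import torch

raw = "cpu"                        # callers may pass a plain string
dev = torch.device(raw)            # ... or an existing torch.device; both normalize to the same thing
assert isinstance(dev, torch.device)
assert dev == torch.device("cpu")  # comparisons are now well-defined
assert dev.type == "cpu"           # attribute access like .type works on every code path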
4 changes: 1 addition & 3 deletions tensordict/persistent.py
@@ -124,9 +124,7 @@ class PersistentTensorDict(TensorDictBase):
"""

def __new__(cls, *args, **kwargs):
cls._td_dim_names = None
return super().__new__(cls, *args, **kwargs)
_td_dim_names = None

def __init__(
self,
166 changes: 66 additions & 100 deletions tensordict/tensordict.py
@@ -355,17 +355,14 @@ class TensorDictBase(MutableMapping):
)
KEY_ERROR = 'key "{}" not found in {} with ' "keys {}"

def __new__(cls, *args: Any, **kwargs: Any) -> T:
self = super().__new__(cls)
self._safe = kwargs.get("_safe", False)
self._lazy = kwargs.get("_lazy", False)
self._inplace_set = kwargs.get("_inplace_set", False)
self.is_meta = kwargs.get("is_meta", False)
self._is_locked = kwargs.get("_is_locked", False)
self._cache = None
self._last_op = None
self.__last_op_queue = None
return self
_safe = False
_lazy = False
_inplace_set = False
is_meta = False
_is_locked = False
_cache = None
_last_op = None
__last_op_queue = None

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
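The hunk above drops the __new__ override that assigned default flags on every construction and declares the same defaults as plain class attributes. Reads of self._safe, self._lazy, and the other flags behave as before, but nothing mutates shared class state at construction time and subclasses simply inherit or override the defaults. A small standalone sketch of the two patterns, separate from the diff, with illustrative names:

class Before:
    def __new__(cls, *args, **kwargs):
        cls._safe = kwargs.get("_safe", False)  # writes to the class: the last call wins for every instance
        return super().__new__(cls)

class After:
    _safe = False          # one shared default, never touched at construction time

    def lock(self):
        self._safe = True  # an instance assignment shadows the class default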
@@ -1812,8 +1809,14 @@ def as_tensor(self):
and will raise an exception in all other cases.
"""
warnings.warn("as_tensor will soon be deprecated.", category=DeprecationWarning)
return self

def as_tensor(x):
try:
return x.as_tensor()
except AttributeError:
return x

return self._fast_apply(as_tensor)

def update(
self,
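With the change above, as_tensor() no longer returns self unchanged: it runs an as_tensor conversion leaf by leaf through _fast_apply, so memory-mapped leaves come back as plain tensors while ordinary tensors pass through the AttributeError fallback untouched. A hypothetical usage sketch, separate from the diff, with made-up keys and shapes:

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3]).memmap_()
plain = td.as_tensor()                         # still emits the DeprecationWarning
assert isinstance(plain["obs"], torch.Tensor)  # leaves are ordinary tensors again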
@@ -3128,73 +3131,49 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
A list of TensorDict with specified size in given dimension.
"""
batch_sizes = []
if self.batch_dims == 0:
raise RuntimeError("TensorDict with empty batch size is not splittable")
if not (-self.batch_dims <= dim < self.batch_dims):
raise IndexError(
f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
)
# we must use slices to keep the storage of the tensors
batch_size = self.batch_size
if dim < 0:
dim += self.batch_dims
if isinstance(split_size, int):
rep, remainder = divmod(self.batch_size[dim], split_size)
rep_shape = torch.Size(
[
split_size if idx == dim else size
for (idx, size) in enumerate(self.batch_size)
]
dim = len(batch_size) + dim
if dim < 0 or dim >= len(batch_size):
raise RuntimeError(
f"The number of dimensions is insufficient for the split_dim {dim}."
)
batch_sizes = [rep_shape for _ in range(rep)]
if remainder:
batch_sizes.append(
torch.Size(
[
remainder if dim_idx == dim else dim_size
for (dim_idx, dim_size) in enumerate(self.batch_size)
]
)
if isinstance(split_size, int):
if split_size <= 0:
raise RuntimeError(
f"split_size must be strictly greater than 0, got {split_size}."
)
idx0 = 0
idx1 = split_size
split_sizes = [slice(idx0, idx1)]
while idx1 < batch_size[dim]:
idx0 = idx1
idx1 = idx1 + split_size
split_sizes.append(slice(idx0, idx1))
elif isinstance(split_size, list) and all(
isinstance(element, int) for element in split_size
):
if sum(split_size) != self.batch_size[dim]:
if len(split_size) == 0:
raise RuntimeError("Insufficient number of elements in split_size.")
if sum(split_size) != batch_size[dim]:
raise RuntimeError(
f"Split method expects split_size to sum exactly to {self.batch_size[dim]} (tensor's size at dimension {dim}), but got split_size={split_size}"
)
for i in split_size:
batch_sizes.append(
torch.Size(
[
i if dim_idx == dim else dim_size
for (dim_idx, dim_size) in enumerate(self.batch_size)
]
)
)
idx0 = 0
idx1 = split_size[0]
split_sizes = [slice(idx0, idx1)]
for idx in split_size[1:]:
idx0 = idx1
idx1 += idx
split_sizes.append(slice(idx0, idx1))
else:
raise TypeError(
"split(): argument 'split_size' must be int or list of ints"
)
dictionaries = [{} for _ in range(len(batch_sizes))]
for key, item in self.items():
split_tensors = torch.split(item, split_size, dim)
for idx, split_tensor in enumerate(split_tensors):
dictionaries[idx][key] = split_tensor
names = None
if self._has_names():
names = copy(self.names)
return [
TensorDict(
dictionaries[i],
batch_sizes[i],
device=self.device,
names=names,
_run_checks=False,
_is_shared=self.is_shared(),
_is_memmap=self.is_memmap(),
)
for i in range(len(dictionaries))
]

index = (slice(None),) * dim
return tuple(self[index + (ss,)] for ss in split_sizes)

def gather(self, dim: int, index: Tensor, out: T | None = None) -> T:
"""Gathers values along an axis specified by `dim`.
@@ -4072,26 +4051,11 @@ class TensorDict(TensorDictBase):
"""

__slots__ = (
"_tensordict",
"_batch_size",
"_is_shared",
"_is_memmap",
"_device",
"_is_locked",
"_td_dim_names",
"_lock_id",
"_locked_tensordicts",
"_cache",
"_last_op",
"__last_op_queue",
)

def __new__(cls, *args: Any, **kwargs: Any) -> TensorDict:
cls._is_shared = False
cls._is_memmap = False
cls._td_dim_names = None
return super().__new__(cls, *args, _safe=True, _lazy=False, **kwargs)
_is_shared = False
_is_memmap = False
_td_dim_names = None
_safe = True
_lazy = False

def __init__(
self,
@@ -4995,11 +4959,12 @@ def _nested_keys(
)

def __getstate__(self):
return {
slot: getattr(self, slot)
for slot in self.__slots__
if slot not in ("_last_op", "_cache", "__last_op_queue")
result = {
key: val
for key, val in self.__dict__.items()
if key not in ("_last_op", "_cache", "__last_op_queue")
}
return result

def __setstate__(self, state):
for slot, value in state.items():
@@ -5784,10 +5749,11 @@ class SubTensorDict(TensorDictBase):
"""

def __new__(cls, *args: Any, **kwargs: Any) -> SubTensorDict:
cls._is_shared = False
cls._is_memmap = False
return super().__new__(cls, _safe=False, _lazy=True, _inplace_set=True)
_is_shared = False
_is_memmap = False
_safe = False
_lazy = True
_inplace_set = True

def __init__(
self,
@@ -6408,9 +6374,9 @@ def __torch_function__(
else:
return super().__torch_function__(func, types, args, kwargs)

def __new__(cls, *args: Any, **kwargs: Any) -> LazyStackedTensorDict:
cls._td_dim_name = None
return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs)
_td_dim_name = None
_safe = False
_lazy = True

def __init__(
self,
@@ -8156,8 +8122,8 @@ def _repr_exclusive_fields(self):
class _CustomOpTensorDict(TensorDictBase):
"""Encodes lazy operations on tensors contained in a TensorDict."""

def __new__(cls, *args: Any, **kwargs: Any) -> _CustomOpTensorDict:
return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs)
_safe = False
_lazy = True

def __init__(
self,
1 change: 0 additions & 1 deletion tensordict/utils.py
@@ -767,7 +767,6 @@ def _is_shared(tensor: torch.Tensor) -> bool:
elif isinstance(tensor, KeyedJaggedTensor):
return False
else:
print(type(tensor))
return tensor.is_shared()

