From 92f7f6f208be1850f80f71ee506d9a94696379f6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 4 Jan 2024 21:57:05 +0000 Subject: [PATCH] [BugFix, Doc] Fix tutorial (#606) --- .github/workflows/docs.yml | 169 +++-- docs/source/reference/nn.rst | 2 +- docs/source/reference/tensordict.rst | 3 +- tensordict/__init__.py | 15 +- tensordict/_td.py | 2 +- tutorials/sphinx_tuto/data_fashion.py | 17 +- tutorials/sphinx_tuto/tensorclass_fashion.py | 8 +- tutorials/sphinx_tuto/tensorclass_imagenet.py | 715 +++++++++--------- tutorials/sphinx_tuto/tensordict_memory.py | 12 +- 9 files changed, 520 insertions(+), 423 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f90df1678..541d3cf42 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -4,6 +4,9 @@ on: push: branches: - main + pull_request: + branches: + - "*" workflow_dispatch: concurrency: @@ -14,46 +17,126 @@ concurrency: jobs: build-docs: - runs-on: ubuntu-20.04 - defaults: - run: - shell: bash -l {0} - steps: - - name: Checkout tensordict - uses: actions/checkout@v3 - - name: Setup Miniconda - uses: conda-incubator/setup-miniconda@v2 - with: - activate-environment: build_docs - python-version: "3.9" - - name: Check Python version - run: | - python --version - - name: Install PyTorch - run: | - pip3 install pip --upgrade - pip3 install packaging - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu - - name: Install TensorDict - run: | - python -m pip install -e . - - name: Test tensordict installation - run: | - mkdir _tmp - cd _tmp - python -c "import tensordict" - cd .. - - name: Build the docset - working-directory: ./docs - run: | - python -m pip install -r requirements.txt - make docs - - name: Get output time - run: echo "The time was ${{ steps.build.outputs.time }}" - - name: Deploy - uses: JamesIves/github-pages-deploy-action@releases/v3 - with: - ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} - BRANCH: gh-pages # The branch the action should deploy to. - FOLDER: docs/build/html # The folder the action should deploy. - CLEAN: false + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/tensordict + upload-artifact: docs + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -e + set -v + apt-get update && apt-get install -y git wget gcc g++ + root_dir="$(pwd)" + conda_dir="${root_dir}/conda" + env_dir="${root_dir}/env" + os=Linux + + # 1. Install conda at ./conda + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" + eval "$(${conda_dir}/bin/conda shell.bash hook)" + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python=3.8 + printf "* Activating\n" + conda activate "${env_dir}" + + # 2. upgrade pip, ninja and packaging + apt-get install python3.8 python3-pip unzip -y + python3 -m pip install --upgrade pip + python3 -m pip install setuptools ninja packaging -U + + # 3. check python version + python3 --version + + # 4. Check git version + git version + + # 5. Install PyTorch + python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu --quiet --root-user-action=ignore + + # 6. Install tensordict + python3 setup.py develop + + # 7. 
Install requirements + python3 -m pip install -r docs/requirements.txt --quiet --root-user-action=ignore + + # 8. Test tensordict installation + mkdir _tmp + cd _tmp + PYOPENGL_PLATFORM=egl MUJOCO_GL=egl python3 -c """from tensordict import *""" + cd .. + + # 10. Build doc + cd ./docs + make docs + # PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build + cd .. + + cp -r docs/build/* "${RUNNER_ARTIFACT_DIR}" + echo $(ls "${RUNNER_ARTIFACT_DIR}") + if [[ ${{ github.event_name == 'pull_request' }} ]]; then + cp -r docs/build/* "${RUNNER_DOCS_DIR}" + fi + + upload: + needs: build-docs + if: github.repository == 'pytorch/tensordict' && github.event_name == 'push' && + ((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag') + permissions: + contents: write + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/tensordict + download-artifact: docs + ref: gh-pages + test-infra-ref: main + script: | + set -euo pipefail + + REF_TYPE=${{ github.ref_type }} + REF_NAME=${{ github.ref_name }} + + # TODO: adopt this behaviour + # if [[ "${REF_TYPE}" == branch ]]; then + # TARGET_FOLDER="${REF_NAME}" + # elif [[ "${REF_TYPE}" == tag ]]; then + # case "${REF_NAME}" in + # *-rc*) + # echo "Aborting upload since this is an RC tag: ${REF_NAME}" + # exit 0 + # ;; + # *) + # # Strip the leading "v" as well as the trailing patch version. For example: + # # 'v0.15.2' -> '0.15' + # TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/') + # ;; + # esac + # fi + TARGET_FOLDER="./" + echo "Target Folder: ${TARGET_FOLDER}" + + # mkdir -p "${TARGET_FOLDER}" + # rm -rf "${TARGET_FOLDER}"/* + echo $(ls "${RUNNER_ARTIFACT_DIR}") + rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}" + git add "${TARGET_FOLDER}" || true + + # if [[ "${TARGET_FOLDER}" == main ]]; then + # mkdir -p _static + # rm -rf _static/* + # cp -r "${TARGET_FOLDER}"/_static/* _static + # git add _static || true + # fi + + git config user.name 'pytorchbot' + git config user.email 'soumith+bot@pytorch.org' + git config http.postBuffer 524288000 + git commit -m "auto-generating sphinx docs" || true + git push diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index 9e8b43ce3..5b526eb47 100644 --- a/docs/source/reference/nn.rst +++ b/docs/source/reference/nn.rst @@ -320,7 +320,7 @@ first traced using :func:`~.symbolic_trace`. Distributions ------------- -.. py:currentmodule::tensordict.nn.distributions +.. currentmodule:: tensordict.nn.distributions .. autosummary:: :toctree: generated/ diff --git a/docs/source/reference/tensordict.rst b/docs/source/reference/tensordict.rst index b4f5ced08..3ab8e31f8 100644 --- a/docs/source/reference/tensordict.rst +++ b/docs/source/reference/tensordict.rst @@ -50,7 +50,7 @@ Memory-mapped tensors `tensordict` offers the :class:`~tensordict.MemoryMappedTensor` primitive which allows you to work with tensors stored in physical memory in a handy way. The main advantages of :class:`~tensordict.MemoryMappedTensor` -are its easiness of construction (no need to handle the storage of a tensor), +are its ease of construction (no need to handle the storage of a tensor), the possibility to work with big contiguous data that would not fit in memory, an efficient (de)serialization across processes and efficient indexing of stored tensors. 
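+
+As a minimal sketch of this workflow (shapes are illustrative; ``empty`` is the
+same constructor used in the tutorials of this package)::
+
+    import torch
+    from tensordict import MemoryMappedTensor
+
+    # allocate a disk-backed tensor without materializing it in RAM
+    images = MemoryMappedTensor.empty((1000, 3, 64, 64), dtype=torch.uint8)
+    # writes go directly to the underlying memory-mapped storage
+    images[:10] = torch.randint(256, (10, 3, 64, 64), dtype=torch.uint8)
+    # indexing reads back only the requested slice
+    first_batch = images[:10]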
@@ -96,3 +96,4 @@ Utils pad_sequence dense_stack_tds set_lazy_legacy + lazy_legacy diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 2238fb0a3..ad9f55358 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -17,17 +17,22 @@ from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership from tensordict.persistent import PersistentTensorDict from tensordict.tensorclass import NonTensorData, tensorclass -from tensordict.utils import assert_allclose_td, is_batchedtensor, is_tensorclass +from tensordict.utils import ( + assert_allclose_td, + is_batchedtensor, + is_tensorclass, + lazy_legacy, + set_lazy_legacy, +) +from tensordict._pytree import * + +from tensordict._tensordict import unravel_key, unravel_key_list try: from tensordict.version import __version__ except ImportError: __version__ = None -from tensordict._pytree import * - -from tensordict._tensordict import unravel_key, unravel_key_list - __all__ = [ "LazyStackedTensorDict", "MemmapTensor", diff --git a/tensordict/_td.py b/tensordict/_td.py index d355c6be0..c603e77cd 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -554,7 +554,7 @@ def __setitem__( if key in keys: self._set_at_str(key, item, index, validated=False) else: - subtd.set(key, item) + subtd.set(key, item, inplace=True) else: for key in self.keys(): self.set_at_(key, value, index) diff --git a/tutorials/sphinx_tuto/data_fashion.py b/tutorials/sphinx_tuto/data_fashion.py index 3f2fb8ac7..8a42ddff7 100644 --- a/tutorials/sphinx_tuto/data_fashion.py +++ b/tutorials/sphinx_tuto/data_fashion.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from tensordict import MemmapTensor, TensorDict +from tensordict import MemoryMappedTensor, TensorDict from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor @@ -50,27 +50,26 @@ # batches of transformed data from disk rather than repeatedly load and transform # individual images. # -# First we create the ``MemmapTensor`` containers. +# First we create the :class:`~tensordict.MemoryMappedTensor` containers. 
training_data_td = TensorDict( { - "images": MemmapTensor( - len(training_data), - *training_data[0][0].squeeze().shape, + "images": MemoryMappedTensor.empty( + (len(training_data), *training_data[0][0].squeeze().shape), dtype=torch.float32, ), - "targets": MemmapTensor(len(training_data), dtype=torch.int64), + "targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64), }, batch_size=[len(training_data)], device=device, ) test_data_td = TensorDict( { - "images": MemmapTensor( - len(test_data), *test_data[0][0].squeeze().shape, dtype=torch.float32 + "images": MemoryMappedTensor.empty( + (len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32 ), - "targets": MemmapTensor(len(test_data), dtype=torch.int64), + "targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64), }, batch_size=[len(test_data)], device=device, diff --git a/tutorials/sphinx_tuto/tensorclass_fashion.py b/tutorials/sphinx_tuto/tensorclass_fashion.py index fad38283d..98f3d8fb0 100644 --- a/tutorials/sphinx_tuto/tensorclass_fashion.py +++ b/tutorials/sphinx_tuto/tensorclass_fashion.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn -from tensordict import MemmapTensor +from tensordict import MemoryMappedTensor from tensordict.prototype import tensorclass from torch.utils.data import DataLoader from torchvision import datasets @@ -68,10 +68,10 @@ class FashionMNISTData: @classmethod def from_dataset(cls, dataset, device=None): data = cls( - images=MemmapTensor( - len(dataset), *dataset[0][0].squeeze().shape, dtype=torch.float32 + images=MemoryMappedTensor.empty( + (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32 ), - targets=MemmapTensor(len(dataset), dtype=torch.int64), + targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64), batch_size=[len(dataset)], device=device, ) diff --git a/tutorials/sphinx_tuto/tensorclass_imagenet.py b/tutorials/sphinx_tuto/tensorclass_imagenet.py index 4e7a39ffa..3e711e3bf 100644 --- a/tutorials/sphinx_tuto/tensorclass_imagenet.py +++ b/tutorials/sphinx_tuto/tensorclass_imagenet.py @@ -39,374 +39,381 @@ import torch.nn as nn import tqdm -from tensordict import MemmapTensor +from tensordict import MemoryMappedTensor from tensordict.prototype import tensorclass from torch.utils.data import DataLoader from torchvision import datasets, transforms -NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "4")) -# sphinx_gallery_start_ignore -# this example can be run locally or in the tensordict CI on the small hymenoptera -# subset of imagenet, but we also want to be able to compare to runs on larger subsets -# of imagenet which would be impractical to run in CI. If this script is run with the -# environment variable RUN_ON_CLUSTER set, then we set everything to run on a larger -# subset of imagenet. the fraction of images can be set with the FRACTION environment -# variable, we use the first `len(dataset) // FRACTION` images. Default is 10. -RUN_ON_CLUSTER = strtobool(os.environ.get("RUN_ON_CLUSTER", "False")) -FRACTION = int(os.environ.get("FRACTION", 10)) -# sphinx_gallery_end_ignore -device = "cuda:0" if torch.cuda.is_available() else "cpu" -print(f"Using device: {device}") - - -############################################################################## -# Transforms -# ---------- -# First we define train and val transforms that will be applied to train and -# val examples respectively. Note that there are random components in the -# train transform to prevent overfitting to training data over multiple -# epochs. 
- -train_transform = transforms.Compose( - [ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] -) - -val_transform = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] -) - -############################################################################## -# We use ``torchvision.datasets.ImageFolder`` to conveniently load and -# transform the data from disk. - -data_dir = Path("data") / "hymenoptera_data/" -# sphinx_gallery_start_ignore -if RUN_ON_CLUSTER: - data_dir = Path("/datasets01_ontap/imagenet_full_size/061417/") -# sphinx_gallery_end_ignore - -train_data = datasets.ImageFolder(root=data_dir / "train", transform=train_transform) -val_data = datasets.ImageFolder(root=data_dir / "val", transform=val_transform) -# sphinx_gallery_start_ignore -if RUN_ON_CLUSTER: - if FRACTION > 1: - train_data.samples = train_data.samples[: len(train_data) // FRACTION] - val_data.samples = val_data.samples[: len(val_data) // FRACTION] -# sphinx_gallery_end_ignore - -############################################################################## -# We’ll also create a dataset of the raw training data that simply resizes -# the image to a common size and converts to tensor. We’ll use this to -# load the data into memory-mapped tensors. The random transformations -# need to be different each time we iterate through the data, so they -# cannot be pre-computed. We also do not scale the data yet so that we can set the -# ``dtype`` of the memory-mapped array to ``uint8`` and save space. - -train_data_raw = datasets.ImageFolder( - root=data_dir / "train", - transform=transforms.Compose( - [transforms.Resize((256, 256)), transforms.PILToTensor()] - ), -) -# sphinx_gallery_start_ignore -if RUN_ON_CLUSTER: - train_data_raw.samples = train_data_raw.samples[: len(train_data_raw) // FRACTION] -# sphinx_gallery_end_ignore - - -############################################################################## -# Since we'll be loading our data in batches, we write a few custom transformations -# that take advantage of this, and apply the transformations in a vectorized way. -# -# First a transformation that can be used for normalization. -class InvAffine(nn.Module): - """A custom normalization layer.""" - - def __init__(self, loc, scale): - super().__init__() - self.loc = loc - self.scale = scale - - def forward(self, x): - return (x - self.loc) / self.scale - +if __name__ == "__main__": + NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "4")) + # sphinx_gallery_start_ignore + # this example can be run locally or in the tensordict CI on the small hymenoptera + # subset of imagenet, but we also want to be able to compare to runs on larger subsets + # of imagenet which would be impractical to run in CI. If this script is run with the + # environment variable RUN_ON_CLUSTER set, then we set everything to run on a larger + # subset of imagenet. the fraction of images can be set with the FRACTION environment + # variable, we use the first `len(dataset) // FRACTION` images. Default is 10. 
+ RUN_ON_CLUSTER = strtobool(os.environ.get("RUN_ON_CLUSTER", "False")) + FRACTION = int(os.environ.get("FRACTION", 10)) + # sphinx_gallery_end_ignore + device = "cuda:0" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + ############################################################################## + # Transforms + # ---------- + # First we define train and val transforms that will be applied to train and + # val examples respectively. Note that there are random components in the + # train transform to prevent overfitting to training data over multiple + # epochs. + + train_transform = transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) -############################################################################## -# Next two transformations that can be used to randomly crop and flip the images. - - -class RandomHFlip(nn.Module): - def forward(self, x: torch.Tensor): - idx = ( - torch.zeros(*x.shape[:-3], 1, 1, 1, device=x.device, dtype=torch.bool) - .bernoulli_() - .expand_as(x) - ) - return x.masked_fill(idx, 0.0) + x.masked_fill(~idx, 0.0).flip(-1) - - -class RandomCrop(nn.Module): - def __init__(self, w, h): - super(RandomCrop, self).__init__() - self.w = w - self.h = h - - def forward(self, x): - batch = x.shape[:-3] - index0 = torch.randint(x.shape[-2] - self.h, (*batch, 1), device=x.device) - index0 = index0 + torch.arange(self.h, device=x.device) - index0 = ( - index0.unsqueeze(1).unsqueeze(-1).expand((*batch, 3), self.h, x.shape[-1]) - ) - index1 = torch.randint(x.shape[-1] - self.w, (*batch, 1), device=x.device) - index1 = index1 + torch.arange(self.w, device=x.device) - index1 = index1.unsqueeze(1).unsqueeze(-2).expand((*batch, 3), self.h, self.w) - return x.gather(-2, index0).gather(-1, index1) + val_transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + ############################################################################## + # We use ``torchvision.datasets.ImageFolder`` to conveniently load and + # transform the data from disk. -############################################################################## -# When each batch is loaded, we will scale it, then randomly crop and flip. The random -# transformations cannot be pre-applied as they must differ each time we iterate over -# the data. The scaling could be pre-applied in principle, but by waiting until we load -# the data into RAM, we are able to set the dtype of the memory-mapped array to -# ``uint8``, a significant space saving over ``float32``. - -collate_transform = nn.Sequential( - InvAffine( - loc=torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1) * 255, - scale=torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1) * 255, - ), - RandomCrop(224, 224), - RandomHFlip(), -) + data_dir = Path("data") / "hymenoptera_data/" + # sphinx_gallery_start_ignore + if RUN_ON_CLUSTER: + data_dir = Path("/datasets01_ontap/imagenet_full_size/061417/") + # sphinx_gallery_end_ignore -############################################################################## -# Representing data with a TensorClass -# ------------------------------------ -# Tensorclasses are a good choice when the structure of your data is known -# apriori. 
They are dataclasses that expose dedicated tensor methods over -# their contents much like a ``TensorDict``. -# -# As well as specifying the contents (in this case ``images`` and -# ``targets``) we can also encapsulate related logic as custom methods -# when defining the class. Here we add a classmethod that takes a dataset -# and creates a tensorclass containing the data by iterating over the -# dataset. We create memory-mapped tensors to hold the data so that they -# can be efficiently loaded in batches later. - - -@tensorclass -class ImageNetData: - images: torch.Tensor - targets: torch.Tensor - - @classmethod - def from_dataset(cls, dataset): - data = cls( - images=MemmapTensor( - len(dataset), - *dataset[0][0].squeeze().shape, - dtype=torch.uint8, - ), - targets=MemmapTensor(len(dataset), dtype=torch.int64), - batch_size=[len(dataset)], - ) - # locks the tensorclass and ensures that is_memmap will return True. - data.memmap_() - - batch = 64 - dl = DataLoader(dataset, batch_size=batch, num_workers=NUM_WORKERS) - i = 0 - pbar = tqdm.tqdm(total=len(dataset)) - for image, target in dl: - _batch = image.shape[0] - pbar.update(_batch) - data[i : i + _batch] = cls( - images=image, targets=target, batch_size=[_batch] + train_data = datasets.ImageFolder( + root=data_dir / "train", transform=train_transform + ) + val_data = datasets.ImageFolder(root=data_dir / "val", transform=val_transform) + # sphinx_gallery_start_ignore + if RUN_ON_CLUSTER: + if FRACTION > 1: + train_data.samples = train_data.samples[: len(train_data) // FRACTION] + val_data.samples = val_data.samples[: len(val_data) // FRACTION] + # sphinx_gallery_end_ignore + + ############################################################################## + # We’ll also create a dataset of the raw training data that simply resizes + # the image to a common size and converts to tensor. We’ll use this to + # load the data into memory-mapped tensors. The random transformations + # need to be different each time we iterate through the data, so they + # cannot be pre-computed. We also do not scale the data yet so that we can set the + # ``dtype`` of the memory-mapped array to ``uint8`` and save space. + + train_data_raw = datasets.ImageFolder( + root=data_dir / "train", + transform=transforms.Compose( + [transforms.Resize((256, 256)), transforms.PILToTensor()] + ), + ) + # sphinx_gallery_start_ignore + if RUN_ON_CLUSTER: + train_data_raw.samples = train_data_raw.samples[ + : len(train_data_raw) // FRACTION + ] + # sphinx_gallery_end_ignore + + ############################################################################## + # Since we'll be loading our data in batches, we write a few custom transformations + # that take advantage of this, and apply the transformations in a vectorized way. + # + # First a transformation that can be used for normalization. + class InvAffine(nn.Module): + """A custom normalization layer.""" + + def __init__(self, loc, scale): + super().__init__() + self.loc = loc + self.scale = scale + + def forward(self, x): + return (x - self.loc) / self.scale + + ############################################################################## + # Next two transformations that can be used to randomly crop and flip the images. 
+ + class RandomHFlip(nn.Module): + def forward(self, x: torch.Tensor): + idx = ( + torch.zeros(*x.shape[:-3], 1, 1, 1, device=x.device, dtype=torch.bool) + .bernoulli_() + .expand_as(x) ) - i += _batch - - return data - - -############################################################################## -# We create two tensorclasses, one for the training and on for the -# validation data. Note that while this step can be slightly expensive, it -# allows us to save repeated computation later during training. - -train_data_tc = ImageNetData.from_dataset(train_data_raw) -val_data_tc = ImageNetData.from_dataset(val_data) - -############################################################################## -# DataLoaders -# ----------- -# -# We can create dataloaders both from the ``torchvision``-provided -# Datasets, as well as from our memory-mapped tensorclasses. -# -# Since tensorclasses implement ``__len__`` and ``__getitem__`` (and also -# ``__getitems__``) we can use them like a map-style Dataset and create a -# ``DataLoader`` directly from them. -# -# Since the TensorClass data will be loaded in batches, we need to specify how these -# batches should be collated. For this we write the following helper class - - -class Collate(nn.Module): - def __init__(self, transform=None, device=None): - super().__init__() - self.transform = transform - self.device = torch.device(device) - - def __call__(self, x: ImageNetData): - # move data to RAM - if self.device.type == "cuda": - out = x.apply(lambda x: x.as_tensor()).pin_memory() - else: - out = x.apply(lambda x: x.as_tensor()) - if self.device: - # move data to gpu - out = out.to(self.device) - if self.transform: - # apply transforms on gpu - out.images = self.transform(out.images) - return out - + return x.masked_fill(idx, 0.0) + x.masked_fill(~idx, 0.0).flip(-1) + + class RandomCrop(nn.Module): + def __init__(self, w, h): + super(RandomCrop, self).__init__() + self.w = w + self.h = h + + def forward(self, x): + batch = x.shape[:-3] + index0 = torch.randint(x.shape[-2] - self.h, (*batch, 1), device=x.device) + index0 = index0 + torch.arange(self.h, device=x.device) + index0 = ( + index0.unsqueeze(1) + .unsqueeze(-1) + .expand((*batch, 3, self.h, x.shape[-1])) + ) + index1 = torch.randint(x.shape[-1] - self.w, (*batch, 1), device=x.device) + index1 = index1 + torch.arange(self.w, device=x.device) + index1 = ( + index1.unsqueeze(1).unsqueeze(-2).expand((*batch, 3, self.h, self.w)) + ) + return x.gather(-2, index0).gather(-1, index1) + + ############################################################################## + # When each batch is loaded, we will scale it, then randomly crop and flip. The random + # transformations cannot be pre-applied as they must differ each time we iterate over + # the data. The scaling could be pre-applied in principle, but by waiting until we load + # the data into RAM, we are able to set the dtype of the memory-mapped array to + # ``uint8``, a significant space saving over ``float32``. + + collate_transform = nn.Sequential( + InvAffine( + loc=torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1) * 255, + scale=torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1) + * 255, + ), + RandomCrop(224, 224), + RandomHFlip(), + ) -############################################################################## -# ``DataLoader`` has support for multiple workers loading data in parallel. The -# tensorclass dataloader will use just one worker, but load data in batches. 
-# -# Note that under this approach our ``collate_fn`` is essentially just an ``nn.Module``, -# making it transparent and easy to implement. But this approach also offers -# flexibility, for example, if needed we could move the collation step into the training -# loop by considering the ``Collate`` module as part of the model. - -batch_size = 8 -# sphinx_gallery_start_ignore -if RUN_ON_CLUSTER: - batch_size = 128 -# sphinx_gallery_end_ignore -train_dataloader = DataLoader( - train_data, - batch_size=batch_size, - num_workers=NUM_WORKERS, -) -val_dataloader = DataLoader( - val_data, - batch_size=batch_size, - num_workers=NUM_WORKERS, -) - -train_dataloader_tc = DataLoader( # noqa: TOR401 - train_data_tc, - batch_size=batch_size, - collate_fn=Collate(collate_transform, device), -) -val_dataloader_tc = DataLoader( # noqa: TOR401 - val_data_tc, - batch_size=batch_size, - collate_fn=Collate(device=device), -) + ############################################################################## + # Representing data with a TensorClass + # ------------------------------------ + # Tensorclasses are a good choice when the structure of your data is known + # apriori. They are dataclasses that expose dedicated tensor methods over + # their contents much like a ``TensorDict``. + # + # As well as specifying the contents (in this case ``images`` and + # ``targets``) we can also encapsulate related logic as custom methods + # when defining the class. Here we add a classmethod that takes a dataset + # and creates a tensorclass containing the data by iterating over the + # dataset. We create memory-mapped tensors to hold the data so that they + # can be efficiently loaded in batches later. + + @tensorclass + class ImageNetData: + images: torch.Tensor + targets: torch.Tensor + + @classmethod + def from_dataset(cls, dataset): + data = cls( + images=MemoryMappedTensor.empty( + ( + len(dataset), + *dataset[0][0].squeeze().shape, + ), + dtype=torch.uint8, + ), + targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64), + batch_size=[len(dataset)], + ) + # locks the tensorclass and ensures that is_memmap will return True. + data.memmap_() + + batch = 64 + dl = DataLoader(dataset, batch_size=batch, num_workers=NUM_WORKERS) + i = 0 + pbar = tqdm.tqdm(total=len(dataset)) + for image, target in dl: + _batch = image.shape[0] + pbar.update(_batch) + print(data) + print(cls(images=image, targets=target, batch_size=[_batch])) + data[i : i + _batch] = cls( + images=image, targets=target, batch_size=[_batch] + ) + i += _batch + + return data + + ############################################################################## + # We create two tensorclasses, one for the training and on for the + # validation data. Note that while this step can be slightly expensive, it + # allows us to save repeated computation later during training. + + train_data_tc = ImageNetData.from_dataset(train_data_raw) + val_data_tc = ImageNetData.from_dataset(val_data) + + ############################################################################## + # DataLoaders + # ----------- + # + # We can create dataloaders both from the ``torchvision``-provided + # Datasets, as well as from our memory-mapped tensorclasses. + # + # Since tensorclasses implement ``__len__`` and ``__getitem__`` (and also + # ``__getitems__``) we can use them like a map-style Dataset and create a + # ``DataLoader`` directly from them. + # + # Since the TensorClass data will be loaded in batches, we need to specify how these + # batches should be collated. 
For this we write the following helper class + + class Collate(nn.Module): + def __init__(self, transform=None, device=None): + super().__init__() + self.transform = transform + self.device = torch.device(device) + + def __call__(self, x: ImageNetData): + # move data to RAM + if self.device.type == "cuda": + out = x.pin_memory() + else: + out = x + if self.device: + # move data to gpu + out = out.to(self.device) + if self.transform: + # apply transforms on gpu + out.images = self.transform(out.images) + return out + + ############################################################################## + # ``DataLoader`` has support for multiple workers loading data in parallel. The + # tensorclass dataloader will use just one worker, but load data in batches. + # + # Note that under this approach our ``collate_fn`` is essentially just an ``nn.Module``, + # making it transparent and easy to implement. But this approach also offers + # flexibility, for example, if needed we could move the collation step into the training + # loop by considering the ``Collate`` module as part of the model. + + batch_size = 8 + # sphinx_gallery_start_ignore + if RUN_ON_CLUSTER: + batch_size = 128 + # sphinx_gallery_end_ignore + train_dataloader = DataLoader( + train_data, + batch_size=batch_size, + num_workers=NUM_WORKERS, + ) + val_dataloader = DataLoader( + val_data, + batch_size=batch_size, + num_workers=NUM_WORKERS, + ) -############################################################################## -# We can now compare how long it takes to iterate once over the data in -# each case. The regular dataloader loads images one by one from disk, -# applies the transform sequentially and then stacks the results -# (note: we start measuring time a little after the first iteration, as -# starting the dataloader can take some time). - -total = 0 -for i, (image, target) in enumerate(train_dataloader): - if i == 3: - t0 = time.time() - if i >= 3: - total += image.shape[0] - image, target = image.to(device), target.to(device) -t = time.time() - t0 -print(f"One iteration over dataloader done! Rate: {total/t:4.4f} fps, time: {t: 4.4f}s") + train_dataloader_tc = DataLoader( # noqa: TOR401 + train_data_tc, + batch_size=batch_size, + collate_fn=Collate(collate_transform, device), + ) + val_dataloader_tc = DataLoader( # noqa: TOR401 + val_data_tc, + batch_size=batch_size, + collate_fn=Collate(device=device), + ) -############################################################################## -# Our tensorclass-based dataloader instead loads data from the -# memory-mapped tensor in batches. We then apply the batched random -# transformations to the batched images. - -total = 0 -for i, batch in enumerate(train_dataloader_tc): - if i == 3: - t0 = time.time() - if i >= 3: - total += batch.numel() - image, target = batch.images, batch.targets -t = time.time() - t0 -print( - f"One iteration over tensorclass dataloader done! Rate: {total/t:4.4f} fps, time: {t: 4.4f}s" -) + ############################################################################## + # We can now compare how long it takes to iterate once over the data in + # each case. The regular dataloader loads images one by one from disk, + # applies the transform sequentially and then stacks the results + # (note: we start measuring time a little after the first iteration, as + # starting the dataloader can take some time). 
+ + total = 0 + for i, (image, target) in enumerate(train_dataloader): + if i == 3: + t0 = time.time() + if i >= 3: + total += image.shape[0] + image, target = image.to(device), target.to(device) + t = time.time() - t0 + print( + f"One iteration over dataloader done! Rate: {total/t:4.4f} fps, time: {t: 4.4f}s" + ) -############################################################################## -# In the case of the validation set, we see an even bigger performance -# improvement, because there are no random transformations, so we can save -# the fully transformed data in the memory-mapped tensor, eliminating the -# need for additional transformations as we load from disk. - -total = 0 -for i, (image, target) in enumerate(val_dataloader): - if i == 3: - t0 = time.time() - if i >= 3: - total += image.shape[0] - image, target = image.to(device), target.to(device) -t = time.time() - t0 -print(f"One iteration over val data done! Rate: {total/t:4.4f} fps, time: {t: 4.4f}s") - -total = 0 -for i, batch in enumerate(val_dataloader_tc): - if i == 3: - t0 = time.time() - if i >= 3: - total += batch.shape[0] - image, target = batch.images.contiguous().to(device), batch.targets.contiguous().to( - device + ############################################################################## + # Our tensorclass-based dataloader instead loads data from the + # memory-mapped tensor in batches. We then apply the batched random + # transformations to the batched images. + + total = 0 + for i, batch in enumerate(train_dataloader_tc): + if i == 3: + t0 = time.time() + if i >= 3: + total += batch.numel() + image, target = batch.images, batch.targets + t = time.time() - t0 + print( + f"One iteration over tensorclass dataloader done! Rate: {total/t:4.4f} fps, time: {t: 4.4f}s" ) -t = time.time() - t0 -print( - f"One iteration over tensorclass val data done! Rate: {total/t:4.4f} fps, time: {t: 4.4f}s" -) -############################################################################## -# Results from ImageNet -# --------------------- -# -# We repeated the above on full-size ImageNet data, running on an AWS EC2 instance with -# 32 cores and 1 A100 GPU. We compare against the regular ``DataLoader`` with different -# numbers of workers. We found that our single-threaded TensorClass approach -# out-performed the ``DataLoader`` even when we used a large number of workers. -# -# .. image:: /reference/generated/tutorials/media/imagenet-benchmark-time.png -# :alt: Bar chart showing runtimes of dataloaders compared with TensorClass -# -# .. image:: /reference/generated/tutorials/media/imagenet-benchmark-speed.png -# :alt: Bar chart showing collection rate of dataloaders compared with TensorClass + ############################################################################## + # In the case of the validation set, we see an even bigger performance + # improvement, because there are no random transformations, so we can save + # the fully transformed data in the memory-mapped tensor, eliminating the + # need for additional transformations as we load from disk. + + total = 0 + for i, (image, target) in enumerate(val_dataloader): + if i == 3: + t0 = time.time() + if i >= 3: + total += image.shape[0] + image, target = image.to(device), target.to(device) + t = time.time() - t0 + print( + f"One iteration over val data done! 
Rate: {total/t:4.4f} fps, time: {t: 4.4f}s" + ) + total = 0 + for i, batch in enumerate(val_dataloader_tc): + if i == 3: + t0 = time.time() + if i >= 3: + total += batch.shape[0] + image, target = batch.images.contiguous().to( + device + ), batch.targets.contiguous().to(device) + t = time.time() - t0 + print( + f"One iteration over tensorclass val data done! Rate: {total/t:4.4f} fps, time: {t: 4.4f}s" + ) -############################################################################## -# This shows that much of the overhead is coming from i/o operations rather than the -# transforms, and hence explains how the memory-mapped array helps us load data more -# efficiently. Check out the `distributed example `__ -# for more context about the other results from these charts. -# -# We can get even better performance with the TensorClass approach by using multiple -# workers to load batches from the memory-mapped array, though this comes with some -# added complexity. See `this example in our benchmarks -# `__ -# for an example of how this could work. + ############################################################################## + # Results from ImageNet + # --------------------- + # + # We repeated the above on full-size ImageNet data, running on an AWS EC2 instance with + # 32 cores and 1 A100 GPU. We compare against the regular ``DataLoader`` with different + # numbers of workers. We found that our single-threaded TensorClass approach + # out-performed the ``DataLoader`` even when we used a large number of workers. + # + # .. image:: /reference/generated/tutorials/media/imagenet-benchmark-time.png + # :alt: Bar chart showing runtimes of dataloaders compared with TensorClass + # + # .. image:: /reference/generated/tutorials/media/imagenet-benchmark-speed.png + # :alt: Bar chart showing collection rate of dataloaders compared with TensorClass + + ############################################################################## + # This shows that much of the overhead is coming from i/o operations rather than the + # transforms, and hence explains how the memory-mapped array helps us load data more + # efficiently. Check out the `distributed example `__ + # for more context about the other results from these charts. + # + # We can get even better performance with the TensorClass approach by using multiple + # workers to load batches from the memory-mapped array, though this comes with some + # added complexity. See `this example in our benchmarks + # `__ + # for an example of how this could work. diff --git a/tutorials/sphinx_tuto/tensordict_memory.py b/tutorials/sphinx_tuto/tensordict_memory.py index c4e1b5ef9..bcf25c680 100644 --- a/tutorials/sphinx_tuto/tensordict_memory.py +++ b/tutorials/sphinx_tuto/tensordict_memory.py @@ -102,9 +102,10 @@ # # Memory-mapped Tensors # --------------------- -# ``tensordict`` provides a class :class:`~.MemmapTensor` which allows us to store the -# contents of a tensor on disk, while still supporting fast indexing and loading of the -# contents in batches. See the `ImageNet Tutorial <./tensorclass_imagenet.html>`_ for an +# ``tensordict`` provides a class :class:`~tensordict.MemoryMappedTensor` +# which allows us to store the contents of a tensor on disk, while still +# supporting fast indexing and loading of the contents in batches. +# See the `ImageNet Tutorial <./tensorclass_imagenet.html>`_ for an # example of this in action. 
 #
 # To convert the :class:`TensorDict` to a collection of memory-mapped tensors, use the
@@ -127,8 +128,9 @@
 ##############################################################################
 # Alternatively one can use the
 # :meth:`TensorDict.memmap_like <tensordict.TensorDict.memmap_like>` method. This will
-# create a new :class:`~.TensorDict` of the same structure with :class:`~.MemmapTensor`
-# values, however it will not copy the contents of the original tensors to the
+# create a new :class:`~.TensorDict` of the same structure with
+# :class:`~tensordict.MemoryMappedTensor` values, however it will not copy the
+# contents of the original tensors to the
 # memory-mapped tensors. This allows you to create the memory-mapped
 # :class:`~.TensorDict` and then populate it slowly, and hence should generally be
 # preferred to ``memmap_``.
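+
+##############################################################################
+# As a minimal sketch of that pattern (toy shapes, purely illustrative): we
+# build the memory-mapped structure with ``memmap_like`` and then fill it
+# slice by slice with ordinary indexed assignment.
+
+import torch
+
+from tensordict import TensorDict
+
+source_td = TensorDict({"a": torch.arange(10.0)}, batch_size=[10])
+mm_td = source_td.memmap_like()  # same structure on disk, contents not copied
+mm_td[0:5] = source_td[0:5]  # populate incrementally...
+mm_td[5:10] = source_td[5:10]  # ...until the full batch dimension is filled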