diff --git a/docs/examples/use_cases/index.py b/docs/examples/use_cases/index.py
index 71f1bab090d..526c37ecc12 100644
--- a/docs/examples/use_cases/index.py
+++ b/docs/examples/use_cases/index.py
@@ -4,6 +4,7 @@
     "video_superres/README.rst",
     "pytorch/resnet50/pytorch-resnet50.rst",
     "pytorch/single_stage_detector/pytorch_ssd.rst",
+    "pytorch/efficientnet/readme.rst",
     "tensorflow/resnet-n/README.rst",
     "tensorflow/yolov4/readme.rst",
     "tensorflow/efficientdet/README.rst",
diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py
new file mode 100644
index 00000000000..2ca1f20f0b6
--- /dev/null
+++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nvidia.dali import fn
+from nvidia.dali import types
+
+from nvidia.dali.pipeline.experimental import pipeline_def
+
+from nvidia.dali.auto_aug import auto_augment, trivial_augment
+
+
+@pipeline_def(enable_conditionals=True)
+def training_pipe(data_dir, interpolation, image_size, automatic_augmentation,
+                  dali_device="gpu", rank=0, world_size=1):
+    rng = fn.random.coin_flip(probability=0.5)
+
+    jpegs, labels = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank,
+                                    num_shards=world_size, random_shuffle=True,
+                                    pad_last_batch=True)
+
+    if dali_device == "gpu":
+        decoder_device = "mixed"
+        rrc_device = "gpu"
+    else:
+        decoder_device = "cpu"
+        rrc_device = "cpu"
+
+    # The memory padding sets the size of the internal nvJPEG buffers so that they can
+    # handle all images from full-sized ImageNet without additional reallocations.
+    images = fn.decoders.image(jpegs, device=decoder_device, output_type=types.RGB,
+                               device_memory_padding=211025920, host_memory_padding=140544512)
+
+    images = fn.random_resized_crop(images, device=rrc_device, size=[image_size, image_size],
+                                    interp_type=interpolation,
+                                    random_aspect_ratio=[0.75, 4.0 / 3.0],
+                                    random_area=[0.08, 1.0],
+                                    num_attempts=100, antialias=False)
+
+    # Make sure that from this point on we are processing on the GPU, regardless of the
+    # dali_device parameter.
+    images = images.gpu()
+
+    images = fn.flip(images, horizontal=rng)
+
+    # Apply the automatic augmentation policy according to the specification. Note that from
+    # the point of view of the pipeline definition, this `if` statement relies on a static
+    # scalar parameter, so it is evaluated exactly once during build - we either include the
+    # automatic augmentations in the pipeline or not.
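+    # The original image shapes are peeked from the encoded JPEGs, as some augmentations in
+    # the policy compute their magnitudes relative to the image shape.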
+    if automatic_augmentation == "autoaugment":
+        shapes = fn.peek_image_shape(jpegs)
+        output = auto_augment.auto_augment_image_net(images, shapes)
+    elif automatic_augmentation == "trivialaugment":
+        output = trivial_augment.trivial_augment_wide(images)
+    else:
+        output = images
+
+    output = fn.crop_mirror_normalize(output, dtype=types.FLOAT, output_layout=types.NCHW,
+                                      crop=(image_size, image_size),
+                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+
+    return output, labels
+
+
+@pipeline_def
+def validation_pipe(data_dir, interpolation, image_size, image_crop, rank=0, world_size=1):
+    jpegs, label = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank,
+                                   num_shards=world_size, random_shuffle=False,
+                                   pad_last_batch=True)
+
+    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
+
+    images = fn.resize(images, resize_shorter=image_size, interp_type=interpolation,
+                       antialias=False)
+
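+    # Center-crop to the final image_crop size and normalize with the ImageNet mean and
+    # standard deviation (the values are given in the 0-255 range), producing NCHW floats.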
+    output = fn.crop_mirror_normalize(images, dtype=types.FLOAT, output_layout=types.NCHW,
+                                      crop=(image_crop, image_crop),
+                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+    return output, label
diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py
index 47a25862a21..ff4bf6c8832 100644
--- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py
+++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py
@@ -30,8 +30,6 @@
 import os
 import torch
 import numpy as np
-import torchvision.datasets as datasets
-import torchvision.transforms as transforms
 from PIL import Image
 from functools import partial
@@ -44,13 +42,17 @@
     import nvidia.dali.ops as ops
     import nvidia.dali.types as types
 
-    DATA_BACKEND_CHOICES.append("dali-gpu")
-    DATA_BACKEND_CHOICES.append("dali-cpu")
-except ImportError:
+    from image_classification.dali import training_pipe, validation_pipe
+
+    DATA_BACKEND_CHOICES.append("dali")
+except ImportError as e:
     print(
         "Please install DALI from https://www.github.com/NVIDIA/DALI to run this example."
     )
+# TODO(klecki): Move it back again
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
 
 
 def load_jpeg_from_file(path, cuda=True):
     img_transforms = transforms.Compose(
@@ -75,135 +77,6 @@
     return input
 
-
-class HybridTrainPipe(Pipeline):
-    def __init__(
-        self,
-        batch_size,
-        num_threads,
-        device_id,
-        data_dir,
-        interpolation,
-        crop,
-        dali_cpu=False,
-    ):
-        super(HybridTrainPipe, self).__init__(
-            batch_size, num_threads, device_id, seed=12 + device_id
-        )
-        interpolation = {
-            "bicubic": types.INTERP_CUBIC,
-            "bilinear": types.INTERP_LINEAR,
-            "triangular": types.INTERP_TRIANGULAR,
-        }[interpolation]
-        if torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            world_size = torch.distributed.get_world_size()
-        else:
-            rank = 0
-            world_size = 1
-
-        self.input = ops.FileReader(
-            file_root=data_dir,
-            shard_id=rank,
-            num_shards=world_size,
-            random_shuffle=True,
-            pad_last_batch=True,
-        )
-
-        if dali_cpu:
-            dali_device = "cpu"
-            self.decode = ops.ImageDecoder(device=dali_device, output_type=types.RGB)
-        else:
-            dali_device = "gpu"
-            # This padding sets the size of the internal nvJPEG buffers to be able to
-            # handle all images from full-sized ImageNet without additional reallocations
-            self.decode = ops.ImageDecoder(
-                device="mixed",
-                output_type=types.RGB,
-                device_memory_padding=211025920,
-                host_memory_padding=140544512,
-            )
-
-        self.res = ops.RandomResizedCrop(
-            device=dali_device,
-            size=[crop, crop],
-            interp_type=interpolation,
-            random_aspect_ratio=[0.75, 4.0 / 3.0],
-            random_area=[0.08, 1.0],
-            num_attempts=100,
-            antialias=False,
-        )
-
-        self.cmnp = ops.CropMirrorNormalize(
-            device="gpu",
-            dtype=types.FLOAT,
-            output_layout=types.NCHW,
-            crop=(crop, crop),
-            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
-        )
-        self.coin = ops.CoinFlip(probability=0.5)
-
-    def define_graph(self):
-        rng = self.coin()
-        self.jpegs, self.labels = self.input(name="Reader")
-        images = self.decode(self.jpegs)
-        images = self.res(images)
-        output = self.cmnp(images.gpu(), mirror=rng)
-        return [output, self.labels]
-
-
-class HybridValPipe(Pipeline):
-    def __init__(
-        self, batch_size, num_threads, device_id, data_dir, interpolation, crop, size
-    ):
-        super(HybridValPipe, self).__init__(
-            batch_size, num_threads, device_id, seed=12 + device_id
-        )
-        interpolation = {
-            "bicubic": types.INTERP_CUBIC,
-            "bilinear": types.INTERP_LINEAR,
-            "triangular": types.INTERP_TRIANGULAR,
-        }[interpolation]
-        if torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            world_size = torch.distributed.get_world_size()
-        else:
-            rank = 0
-            world_size = 1
-
-        self.input = ops.FileReader(
-            file_root=data_dir,
-            shard_id=rank,
-            num_shards=world_size,
-            random_shuffle=False,
-            pad_last_batch=True,
-        )
-
-        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
-        self.res = ops.Resize(
-            device="gpu",
-            resize_shorter=size,
-            interp_type=interpolation,
-            antialias=False,
-        )
-        self.cmnp = ops.CropMirrorNormalize(
-            device="gpu",
-            dtype=types.FLOAT,
-            output_layout=types.NCHW,
-            crop=(crop, crop),
-            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
-        )
-
-    def define_graph(self):
-        self.jpegs, self.labels = self.input(name="Reader")
-        images = self.decode(self.jpegs)
-        images = self.res(images)
-        output = self.cmnp(images)
-        return [output, self.labels]
-
-
 class DALIWrapper(object):
     def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format):
         for data in dalipipeline:
@@ -226,7 +99,7 @@ def __iter__(self):
     )
 
 
-def get_dali_train_loader(dali_cpu=False):
+def get_dali_train_loader(dali_device="gpu"):
     def gdtl(
         data_path,
         image_size,
@@ -234,7 +107,7 @@
         num_classes,
         one_hot,
         interpolation="bilinear",
-        augmentation=None,
+        augmentation="disabled",
         start_epoch=0,
         workers=5,
         _worker_init_fn=None,
@@ -248,21 +121,24 @@
             rank = 0
             world_size = 1
 
+        interpolation = {
+            "bicubic": types.INTERP_CUBIC,
+            "bilinear": types.INTERP_LINEAR,
+            "triangular": types.INTERP_TRIANGULAR,
+        }[interpolation]
+
         traindir = os.path.join(data_path, "train")
-        if augmentation is not None:
-            raise NotImplementedError(
-                f"Augmentation {augmentation} for dali loader is not supported"
-            )
-
-        pipe = HybridTrainPipe(
-            batch_size=batch_size,
-            num_threads=workers,
-            device_id=rank % torch.cuda.device_count(),
-            data_dir=traindir,
-            interpolation=interpolation,
-            crop=image_size,
-            dali_cpu=dali_cpu,
-        )
+
+        pipeline_kwargs = {
+            "batch_size": batch_size,
+            "num_threads": workers,
+            "device_id": rank % torch.cuda.device_count(),
+            "seed": 12 + rank % torch.cuda.device_count(),
+        }
+
+        pipe = training_pipe(data_dir=traindir, interpolation=interpolation,
+                             image_size=image_size, automatic_augmentation=augmentation,
+                             dali_device=dali_device, rank=rank, world_size=world_size,
+                             **pipeline_kwargs)
 
         pipe.build()
         train_loader = DALIClassificationIterator(
@@ -298,17 +174,24 @@ def gdvl(
             rank = 0
             world_size = 1
 
+        interpolation = {
+            "bicubic": types.INTERP_CUBIC,
+            "bilinear": types.INTERP_LINEAR,
+            "triangular": types.INTERP_TRIANGULAR,
+        }[interpolation]
+
         valdir = os.path.join(data_path, "val")
 
-        pipe = HybridValPipe(
-            batch_size=batch_size,
-            num_threads=workers,
-            device_id=rank % torch.cuda.device_count(),
-            data_dir=valdir,
-            interpolation=interpolation,
-            crop=image_size,
-            size=image_size + crop_padding,
-        )
+        pipeline_kwargs = {
+            "batch_size": batch_size,
+            "num_threads": workers,
+            "device_id": rank % torch.cuda.device_count(),
+            "seed": 12 + rank % torch.cuda.device_count(),
+        }
+
+        pipe = validation_pipe(data_dir=valdir, interpolation=interpolation,
+                               image_size=image_size + crop_padding, image_crop=image_size,
+                               rank=rank, world_size=world_size, **pipeline_kwargs)
 
         pipe.build()
         val_loader = DALIClassificationIterator(
@@ -430,8 +313,13 @@ def get_pytorch_train_loader(
         transforms.RandomResizedCrop(image_size, interpolation=interpolation),
         transforms.RandomHorizontalFlip(),
     ]
-    if augmentation == "autoaugment":
+    if augmentation == "disabled":
+        pass
+    elif augmentation == "autoaugment":
         transforms_list.append(AutoaugmentImageNetPolicy())
+    else:
+        raise NotImplementedError(f"Automatic augmentation: '{augmentation}' is not supported"
+                                  " for the PyTorch data loader.")
     train_dataset = datasets.ImageFolder(traindir, transforms.Compose(transforms_list))
 
     if torch.distributed.is_initialized():
diff --git a/docs/examples/use_cases/pytorch/efficientnet/main.py b/docs/examples/use_cases/pytorch/efficientnet/main.py
index 2c3bff565e1..64fdc87eb76 100644
--- a/docs/examples/use_cases/pytorch/efficientnet/main.py
+++ b/docs/examples/use_cases/pytorch/efficientnet/main.py
@@ -72,11 +72,11 @@ def add_parser_arguments(parser, skip_arch=False):
     parser.add_argument(
         "--data-backend",
         metavar="BACKEND",
-        default="dali-cpu",
+        default="dali",
         choices=DATA_BACKEND_CHOICES,
         help="data backend: "
         + " | ".join(DATA_BACKEND_CHOICES)
-        + " (default: dali-cpu)",
+        + " (default: dali)",
     )
     parser.add_argument(
         "--interpolation",
@@ -111,7 +111,14 @@ def add_parser_arguments(parser, skip_arch=False):
         default=2,
         type=int,
         metavar="N",
-        help="number of samples prefetched by each loader",
+        help="number of samples prefetched by each loader (PyTorch only)",
+    )
+    parser.add_argument(
+        "--dali-device",
+        default="gpu",
+        type=str,
+        choices=["cpu", "gpu"],
+        help="The placement of DALI decode and random resized crop operations (default: gpu)",
     )
     parser.add_argument(
         "--epochs",
@@ -309,17 +316,17 @@ def add_parser_arguments(parser, skip_arch=False):
     parser.add_argument(
         "--memory-format",
         type=str,
-        default="nchw",
+        default="nhwc",
         choices=["nchw", "nhwc"],
         help="memory layout, nchw or nhwc",
     )
     parser.add_argument("--use-ema", default=None, type=float, help="use EMA")
     parser.add_argument(
-        "--augmentation",
+        "--automatic-augmentation",
        type=str,
-        default=None,
-        choices=[None, "autoaugment"],
-        help="augmentation method",
+        default="autoaugment",
+        choices=["disabled", "autoaugment", "trivialaugment"],
+        help="Automatic augmentation method; trivialaugment is supported only for the DALI data backend",
     )
 
     parser.add_argument(
@@ -480,11 +487,8 @@ def _worker_init_fn(id):
         args.workers = args.workers * 2
         get_train_loader = get_pytorch_train_loader
         get_val_loader = get_pytorch_val_loader
-    elif args.data_backend == "dali-gpu":
-        get_train_loader = get_dali_train_loader(dali_cpu=False)
-        get_val_loader = get_dali_val_loader()
-    elif args.data_backend == "dali-cpu":
-        get_train_loader = get_dali_train_loader(dali_cpu=True)
+    elif args.data_backend == "dali":
+        get_train_loader = get_dali_train_loader(dali_device=args.dali_device)
         get_val_loader = get_dali_val_loader()
     elif args.data_backend == "synthetic":
         get_val_loader = get_synthetic_loader
@@ -500,7 +504,7 @@ def _worker_init_fn(id):
         model_args.num_classes,
         args.mixup > 0.0,
         interpolation=args.interpolation,
-        augmentation=args.augmentation,
+        augmentation=args.automatic_augmentation,
         start_epoch=start_epoch,
         workers=args.workers,
         _worker_init_fn=_worker_init_fn,
diff --git a/docs/examples/use_cases/pytorch/efficientnet/readme.rst b/docs/examples/use_cases/pytorch/efficientnet/readme.rst
new file mode 100644
index 00000000000..3059f298235
--- /dev/null
+++ b/docs/examples/use_cases/pytorch/efficientnet/readme.rst
@@ -0,0 +1,172 @@
+EfficientNet for PyTorch with DALI and AutoAugment
+==================================================
+
+This example shows how DALI's implementation of automatic augmentations - most notably `AutoAugment <https://arxiv.org/abs/1805.09501>`_ and `TrivialAugment <https://arxiv.org/abs/2103.10158>`_ - can be used in training. It demonstrates the training of EfficientNet, an image classification model first described in `EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`_.
+
+The code is based on `NVIDIA Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples>`_ and has been extended with a DALI pipeline supporting automatic augmentations, which can be found :fileref:`here <docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py>`.
+
+
+Differences to the Deep Learning Examples configuration
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+* The default values of the parameters were adjusted to the values used in EfficientNet training.
+* The ``--data-backend`` parameter was changed to accept ``dali``, ``pytorch``, or ``synthetic``; it is set to ``dali`` by default.
+* ``--dali-device`` was added to control the placement of some of the DALI operators.
+* ``--augmentation`` was replaced with ``--automatic-augmentation``, now supporting ``disabled``, ``autoaugment``, and ``trivialaugment`` values.
+* The ``--workers`` default was halved to accommodate DALI; the value is automatically doubled when the ``pytorch`` data backend is used.
+* The model is restricted to the EfficientNet-B0 architecture.
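+
+The DALI training pipeline added in this example can also be built and inspected directly from Python, without the training harness. The following is a minimal, illustrative sketch; the dataset path and the parameter values are placeholders, not part of the example code:
+
+.. code-block:: python
+
+    from nvidia.dali import types
+    from nvidia.dali.plugin.pytorch import DALIClassificationIterator
+    from image_classification.dali import training_pipe
+
+    # batch_size, num_threads, device_id and seed are added by the @pipeline_def decorator.
+    pipe = training_pipe(data_dir="/path/to/imagenet/train", interpolation=types.INTERP_LINEAR,
+                         image_size=224, automatic_augmentation="autoaugment",
+                         dali_device="gpu", batch_size=128, num_threads=4,
+                         device_id=0, seed=12)
+    pipe.build()
+
+    # The iterator yields a list with one dictionary per GPU, holding "data" and "label".
+    loader = DALIClassificationIterator(pipe, reader_name="Reader")
+    batch = next(iter(loader))
+    images, labels = batch[0]["data"], batch[0]["label"]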
+
+
+Data backends
+^^^^^^^^^^^^^
+
+This model uses the following data augmentation:
+
+* For training:
+
+  * Random resized crop to the target image size (in this case 224)
+
+    * Scale from 8% to 100%
+    * Aspect ratio from 3/4 to 4/3
+
+  * Random horizontal flip
+  * [Optional: AutoAugment or TrivialAugment]
+  * Normalization
+
+* For inference:
+
+  * Scale to the target image size + 32
+  * Center crop to the target image size
+  * Normalization
+
+
+Requirements
+^^^^^^^^^^^^
+
+The EfficientNet script operates on ImageNet 1k, a widely used image classification dataset from the ILSVRC challenge.
+
+1. Download the dataset from http://image-net.org/download-images
+
+2. Extract the training data:
+
+.. code-block:: bash
+
+    mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
+    tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
+    find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
+    cd ..
+
+3. Extract the validation data and move the images to subfolders:
+
+.. code-block:: bash
+
+    mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
+    wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
+
+The directory in which the ``train/`` and ``val/`` directories are placed is referred to as ``$PATH_TO_IMAGENET`` in this document.
+
+4. Make sure you are either using the `NVIDIA PyTorch NGC container <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_ or you have `DALI <https://github.com/NVIDIA/DALI>`_ and `PyTorch <https://pytorch.org>`_ installed.
+
+5. Install `NVIDIA DLLogger <https://github.com/NVIDIA/dllogger>`_ and `pynvml <https://pypi.org/project/pynvml/>`_.
+
+
+Running the model
+^^^^^^^^^^^^^^^^^
+
+Training
+--------
+
+To run training on a single GPU, use the ``main.py`` entry point:
+
+* For FP32: ``python ./main.py $PATH_TO_IMAGENET``
+* For AMP: ``python ./main.py --amp --static-loss-scale 128 $PATH_TO_IMAGENET``
+
+You may need to adjust the ``--batch-size`` parameter for your machine.
+
+You can change the data loader and the automatic augmentation scheme by adding:
+
+* ``--data-backend``: dali | pytorch | synthetic,
+* ``--automatic-augmentation``: disabled | autoaugment | trivialaugment (the last one only for DALI),
+* ``--dali-device``: cpu | gpu (only for DALI).
+
+By default, the DALI GPU variant with AutoAugment is used.
+
+For example, to run EfficientNet with AMP, a batch size of 128, and DALI with TrivialAugment, invoke:
+
+.. code-block:: bash
+
+    python ./main.py --amp --static-loss-scale 128 --batch-size 128 --data-backend dali --automatic-augmentation trivialaugment $PATH_TO_IMAGENET
+
+To run on multiple GPUs, use ``multiproc.py`` to launch the ``main.py`` entry point script, passing the number of GPUs as the ``--nproc_per_node`` argument. For example, to run the model on 8 GPUs using AMP and DALI with AutoAugment, invoke:
+
+.. code-block:: bash
+
+    python ./multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --data-backend dali --automatic-augmentation autoaugment $PATH_TO_IMAGENET
+
+To see the full list of available options and their descriptions, use the ``-h`` or ``--help`` command-line option, for example:
+
+.. code-block:: bash
+
+    python main.py -h
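+
+Internally, ``--data-backend`` and ``--dali-device`` select the loader factories roughly as follows (a condensed sketch of the dispatch code in ``main.py``; the ``synthetic`` branch is abbreviated here):
+
+.. code-block:: python
+
+    if args.data_backend == "pytorch":
+        get_train_loader = get_pytorch_train_loader
+        get_val_loader = get_pytorch_val_loader
+    elif args.data_backend == "dali":
+        # --dali-device controls where decoding and random resized crop run.
+        get_train_loader = get_dali_train_loader(dali_device=args.dali_device)
+        get_val_loader = get_dali_val_loader()
+    elif args.data_backend == "synthetic":
+        get_val_loader = get_synthetic_loader
+        # (the branch continues as in main.py)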
+
+Training with standard configuration
+------------------------------------
+
+To run training in a standard configuration (DGX-A100/DGX-1V, AMP, 400 epochs, DALI with AutoAugment), invoke one of the following commands:
+
+* for DGX1V-16G: ``python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 $PATH_TO_IMAGENET``
+
+* for DGX-A100: ``python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 256 $PATH_TO_IMAGENET``
+
+Benchmarking
+------------
+
+To run training benchmarks with different data loaders and automatic augmentations, you can use the following commands. They assume a DGX1V-16G with 8 GPUs, a batch size of 128, and AMP:
+
+.. code-block:: bash
+
+    # Adjust the following variable to control where to store the results of the benchmark runs
+    export RESULT_WORKSPACE=./
+
+    # synthetic benchmark
+    python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 1 --prof 100 --no-checkpoints --training-only --data-backend synthetic --workspace $RESULT_WORKSPACE --raport-file bench_report_synthetic.json $PATH_TO_IMAGENET
+
+    # DALI without automatic augmentations
+    python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 4 --no-checkpoints --training-only --data-backend dali --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --raport-file bench_report_dali.json $PATH_TO_IMAGENET
+
+    # DALI with AutoAugment
+    python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 4 --no-checkpoints --training-only --data-backend dali --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --raport-file bench_report_dali_aa.json $PATH_TO_IMAGENET
+
+    # DALI with TrivialAugment
+    python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 4 --no-checkpoints --training-only --data-backend dali --automatic-augmentation trivialaugment --workspace $RESULT_WORKSPACE --raport-file bench_report_dali_ta.json $PATH_TO_IMAGENET
+
+    # PyTorch without automatic augmentations
+    python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 4 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --raport-file bench_report_pytorch.json $PATH_TO_IMAGENET
+
+    # PyTorch with AutoAugment
+    python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 4 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --raport-file bench_report_pytorch_aa.json $PATH_TO_IMAGENET
+
+
+Inference
+---------
+
+Validation is run every epoch and can also be run separately on a checkpointed model:
+
+.. code-block:: bash
+
+    python ./main.py --evaluate --epochs 1 --resume <path to checkpoint> -b <batch size> $PATH_TO_IMAGENET
+
+To run inference on a JPEG image, first extract the model weights from the checkpoint:
+
+.. code-block:: bash
+
+    python checkpoint2model.py --checkpoint-path <path to checkpoint> --weight-path <path to extracted weights>
+
+Then, run the classification script:
+
+.. code-block:: bash
+
+    python classify.py --pretrained-from-file <path to extracted weights> --precision AMP|FP32 --image <path to JPEG image>
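+
+To sanity-check the DALI validation pipeline itself, outside of the training script, it can be built and run directly. The following is a minimal sketch, assuming ``/path/to/imagenet/val`` is a placeholder for your extracted validation set; the values mirror the defaults used above (224 target size plus 32 pixels of crop padding):
+
+.. code-block:: python
+
+    from nvidia.dali import types
+    from image_classification.dali import validation_pipe
+
+    # Resize the shorter side to 224 + 32, then center-crop to 224, matching the
+    # inference preprocessing described in the Data backends section.
+    pipe = validation_pipe(data_dir="/path/to/imagenet/val", interpolation=types.INTERP_LINEAR,
+                           image_size=224 + 32, image_crop=224,
+                           batch_size=32, num_threads=4, device_id=0)
+    pipe.build()
+    images, labels = pipe.run()
+    print(images.as_tensor().shape())  # expected: [32, 3, 224, 224]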