From 4d0845365a43f9df73c8bf272516051075f8751b Mon Sep 17 00:00:00 2001 From: Loki Date: Tue, 23 Aug 2022 22:49:14 +0000 Subject: [PATCH] Adding new CV notebook for distributed training with PT 1.11 --- ...guage-modeling-multi-gpu-single-node.ipynb | 0 .../scripts/launch_pt_dt_sm_native.py | 0 .../scripts/launch_sm_training_compiler.py | 0 .../scripts/run_clm.py | 0 .../scripts/run_mlm.py | 0 .../scripts/requirements.txt | 1 + .../vision_transformer/scripts/run_mae.py | 390 +++++++++++++ .../vision_transformer/scripts/run_mim.py | 472 +++++++++++++++ .../vision-transformer-p4-fp32.ipynb | 542 ++++++++++++++++++ 9 files changed, 1405 insertions(+) rename sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/{ => language-modeling}/language-modeling-multi-gpu-single-node.ipynb (100%) rename sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/{ => language-modeling}/scripts/launch_pt_dt_sm_native.py (100%) rename sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/{ => language-modeling}/scripts/launch_sm_training_compiler.py (100%) rename sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/{ => language-modeling}/scripts/run_clm.py (100%) rename sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/{ => language-modeling}/scripts/run_mlm.py (100%) create mode 100644 sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/requirements.txt create mode 100644 sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mae.py create mode 100644 sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mim.py create mode 100644 sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/vision-transformer-p4-fp32.ipynb diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling-multi-gpu-single-node.ipynb b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/language-modeling-multi-gpu-single-node.ipynb similarity index 100% rename from sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling-multi-gpu-single-node.ipynb rename to sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/language-modeling-multi-gpu-single-node.ipynb diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/launch_pt_dt_sm_native.py b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/launch_pt_dt_sm_native.py similarity index 100% rename from sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/launch_pt_dt_sm_native.py rename to sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/launch_pt_dt_sm_native.py diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/launch_sm_training_compiler.py b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/launch_sm_training_compiler.py similarity index 100% rename from sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/launch_sm_training_compiler.py rename to sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/launch_sm_training_compiler.py diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/run_clm.py b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/run_clm.py similarity index 100% rename from sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/run_clm.py rename to sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/run_clm.py diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/run_mlm.py b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/run_mlm.py similarity index 100% rename from sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/scripts/run_mlm.py rename to sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/language-modeling/scripts/run_mlm.py diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/requirements.txt b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/requirements.txt new file mode 100644 index 0000000000..2e6ab725a9 --- /dev/null +++ b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/requirements.txt @@ -0,0 +1 @@ +accelerate \ No newline at end of file diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mae.py b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mae.py new file mode 100644 index 0000000000..2ef182d6a2 --- /dev/null +++ b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mae.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor +from torchvision.transforms.functional import InterpolationMode + +import transformers +from transformers import ( + HfArgumentParser, + Trainer, + TrainingArguments, + ViTFeatureExtractor, + ViTMAEConfig, + ViTMAEForPreTraining, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + + +""" Pre-training a 🤗 ViT model as an MAE (masked autoencoder), as proposed in https://arxiv.org/abs/2111.06377.""" + +logger = logging.getLogger(__name__) + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.21.0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + dataset_name: Optional[str] = field( + default="cifar10", metadata={"help": "Name of a dataset from the datasets package"} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + image_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of the images in the files."} + ) + train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."}) + validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."}) + train_val_split: Optional[float] = field( + default=0.15, metadata={"help": "Percent to split off of train for validation."} + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + + def __post_init__(self): + data_files = dict() + if self.train_dir is not None: + data_files["train"] = self.train_dir + if self.validation_dir is not None: + data_files["val"] = self.validation_dir + self.data_files = data_files if data_files else None + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/feature extractor we are going to pre-train. + """ + + model_name_or_path: str = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." + ) + }, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name_or_path"} + ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + ) + }, + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."}) + use_auth_token: bool = field( + default=False, + metadata={ + "help": ( + "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + ) + }, + ) + mask_ratio: float = field( + default=0.75, metadata={"help": "The ratio of the number of masked tokens in the input sequence."} + ) + norm_pix_loss: bool = field( + default=True, metadata={"help": "Whether or not to train with normalized pixel values as target."} + ) + + +@dataclass +class CustomTrainingArguments(TrainingArguments): + base_learning_rate: float = field( + default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."} + ) + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + return {"pixel_values": pixel_values} + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_mae", model_args, data_args) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Initialize our dataset. + ds = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + data_files=data_args.data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # If we don't have a validation split, split off a percentage of train as validation. + data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split + if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: + split = ds["train"].train_test_split(data_args.train_val_split) + ds["train"] = split["train"] + ds["validation"] = split["test"] + + # Load pretrained model and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + if model_args.config_name: + config = ViTMAEConfig.from_pretrained(model_args.config_name, **config_kwargs) + elif model_args.model_name_or_path: + config = ViTMAEConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + config = ViTMAEConfig() + logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") + + # adapt config + config.update( + { + "mask_ratio": model_args.mask_ratio, + "norm_pix_loss": model_args.norm_pix_loss, + } + ) + + # create feature extractor + if model_args.feature_extractor_name: + feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs) + elif model_args.model_name_or_path: + feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + feature_extractor = ViTFeatureExtractor() + + # create model + if model_args.model_name_or_path: + model = ViTMAEForPreTraining.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + else: + logger.info("Training new model from scratch") + model = ViTMAEForPreTraining(config) + + if training_args.do_train: + column_names = ds["train"].column_names + else: + column_names = ds["validation"].column_names + + if data_args.image_column_name is not None: + image_column_name = data_args.image_column_name + elif "image" in column_names: + image_column_name = "image" + elif "img" in column_names: + image_column_name = "img" + else: + image_column_name = column_names[0] + + # transformations as done in original MAE paper + # source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py + transforms = Compose( + [ + Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + RandomResizedCrop(feature_extractor.size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC), + RandomHorizontalFlip(), + ToTensor(), + Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std), + ] + ) + + def preprocess_images(examples): + """Preprocess a batch of images by applying transforms.""" + + examples["pixel_values"] = [transforms(image) for image in examples[image_column_name]] + return examples + + if training_args.do_train: + if "train" not in ds: + raise ValueError("--do_train requires a train dataset") + if data_args.max_train_samples is not None: + ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples)) + # Set the training transforms + ds["train"].set_transform(preprocess_images) + + if training_args.do_eval: + if "validation" not in ds: + raise ValueError("--do_eval requires a validation dataset") + if data_args.max_eval_samples is not None: + ds["validation"] = ( + ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) + ) + # Set the validation transforms + ds["validation"].set_transform(preprocess_images) + + # Compute absolute learning rate + total_train_batch_size = ( + training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size + ) + if training_args.base_learning_rate is not None: + training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256 + + # Initialize our trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ds["train"] if training_args.do_train else None, + eval_dataset=ds["validation"] if training_args.do_eval else None, + tokenizer=feature_extractor, + data_collator=collate_fn, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Write model card and (optionally) push to hub + kwargs = { + "tasks": "masked-auto-encoding", + "dataset": data_args.dataset_name, + "tags": ["masked-auto-encoding"], + } + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +if __name__ == "__main__": + main() + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() \ No newline at end of file diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mim.py b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mim.py new file mode 100644 index 0000000000..e4f8b84af2 --- /dev/null +++ b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/scripts/run_mim.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from datasets import load_dataset +from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor + +import transformers +from transformers import ( + CONFIG_MAPPING, + FEATURE_EXTRACTOR_MAPPING, + MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, + AutoConfig, + AutoFeatureExtractor, + AutoModelForMaskedImageModeling, + HfArgumentParser, + Trainer, + TrainingArguments, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + + +""" Pre-training a 🤗 Transformers model for simple masked image modeling (SimMIM). +Any model supported by the AutoModelForMaskedImageModeling API can be used. +""" + +logger = logging.getLogger(__name__) + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.21.0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + Using `HfArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. + """ + + dataset_name: Optional[str] = field( + default="cifar10", metadata={"help": "Name of a dataset from the datasets package"} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + image_column_name: Optional[str] = field( + default=None, + metadata={"help": "The column name of the images in the files. If not set, will try to use 'image' or 'img'."}, + ) + train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."}) + validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."}) + train_val_split: Optional[float] = field( + default=0.15, metadata={"help": "Percent to split off of train for validation."} + ) + mask_patch_size: int = field(default=32, metadata={"help": "The size of the square patches to use for masking."}) + mask_ratio: float = field( + default=0.6, + metadata={"help": "Percentage of patches to mask."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + + def __post_init__(self): + data_files = dict() + if self.train_dir is not None: + data_files["train"] = self.train_dir + if self.validation_dir is not None: + data_files["val"] = self.validation_dir + self.data_files = data_files if data_files else None + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/feature extractor we are going to pre-train. + """ + + model_name_or_path: str = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a " + "checkpoint identifier on the hub. " + "Don't set if you want to train a model from scratch." + ) + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + ) + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store (cache) the pretrained models/datasets downloaded from the hub"}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."}) + use_auth_token: bool = field( + default=False, + metadata={ + "help": ( + "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + ) + }, + ) + image_size: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The size (resolution) of each image. If not specified, will use `image_size` of the configuration." + ) + }, + ) + patch_size: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration." + ) + }, + ) + encoder_stride: Optional[int] = field( + default=None, + metadata={"help": "Stride to use for the encoder."}, + ) + + +class MaskGenerator: + """ + A class to generate boolean masks for the pretraining task. + A mask is a 1D tensor of shape (model_patch_size**2,) where the value is either 0 or 1, + where 1 indicates "masked". + """ + + def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio + + if self.input_size % self.mask_patch_size != 0: + raise ValueError("Input size must be divisible by mask patch size") + if self.mask_patch_size % self.model_patch_size != 0: + raise ValueError("Mask patch size must be divisible by model patch size") + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size**2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def __call__(self): + mask_idx = np.random.permutation(self.token_count)[: self.mask_count] + mask = np.zeros(self.token_count, dtype=int) + mask[mask_idx] = 1 + + mask = mask.reshape((self.rand_size, self.rand_size)) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + + return torch.tensor(mask.flatten()) + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + mask = torch.stack([example["mask"] for example in examples]) + return {"pixel_values": pixel_values, "bool_masked_pos": mask} + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_mim", model_args, data_args) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Initialize our dataset. + ds = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + data_files=data_args.data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # If we don't have a validation split, split off a percentage of train as validation. + data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split + if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: + split = ds["train"].train_test_split(data_args.train_val_split) + ds["train"] = split["train"] + ds["validation"] = split["test"] + + # Create config + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + if model_args.config_name_or_path: + config = AutoConfig.from_pretrained(model_args.config_name_or_path, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") + + # make sure the decoder_type is "simmim" (only relevant for BEiT) + if hasattr(config, "decoder_type"): + config.decoder_type = "simmim" + + # adapt config + model_args.image_size = model_args.image_size if model_args.image_size is not None else config.image_size + model_args.patch_size = model_args.patch_size if model_args.patch_size is not None else config.patch_size + model_args.encoder_stride = ( + model_args.encoder_stride if model_args.encoder_stride is not None else config.encoder_stride + ) + + config.update( + { + "image_size": model_args.image_size, + "patch_size": model_args.patch_size, + "encoder_stride": model_args.encoder_stride, + } + ) + + # create feature extractor + if model_args.feature_extractor_name: + feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs) + elif model_args.model_name_or_path: + feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + FEATURE_EXTRACTOR_TYPES = { + conf.model_type: feature_extractor_class + for conf, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items() + } + feature_extractor = FEATURE_EXTRACTOR_TYPES[model_args.model_type]() + + # create model + if model_args.model_name_or_path: + model = AutoModelForMaskedImageModeling.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForMaskedImageModeling.from_config(config) + + if training_args.do_train: + column_names = ds["train"].column_names + else: + column_names = ds["validation"].column_names + + if data_args.image_column_name is not None: + image_column_name = data_args.image_column_name + elif "image" in column_names: + image_column_name = "image" + elif "img" in column_names: + image_column_name = "img" + else: + image_column_name = column_names[0] + + # transformations as done in original SimMIM paper + # source: https://github.com/microsoft/SimMIM/blob/main/data/data_simmim.py + transforms = Compose( + [ + Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + RandomResizedCrop(model_args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), + RandomHorizontalFlip(), + ToTensor(), + Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std), + ] + ) + + # create mask generator + mask_generator = MaskGenerator( + input_size=model_args.image_size, + mask_patch_size=data_args.mask_patch_size, + model_patch_size=model_args.patch_size, + mask_ratio=data_args.mask_ratio, + ) + + def preprocess_images(examples): + """Preprocess a batch of images by applying transforms + creating a corresponding mask, indicating + which patches to mask.""" + + examples["pixel_values"] = [transforms(image) for image in examples[image_column_name]] + examples["mask"] = [mask_generator() for i in range(len(examples[image_column_name]))] + + return examples + + if training_args.do_train: + if "train" not in ds: + raise ValueError("--do_train requires a train dataset") + if data_args.max_train_samples is not None: + ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples)) + # Set the training transforms + ds["train"].set_transform(preprocess_images) + + if training_args.do_eval: + if "validation" not in ds: + raise ValueError("--do_eval requires a validation dataset") + if data_args.max_eval_samples is not None: + ds["validation"] = ( + ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) + ) + # Set the validation transforms + ds["validation"].set_transform(preprocess_images) + + # Initialize our trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ds["train"] if training_args.do_train else None, + eval_dataset=ds["validation"] if training_args.do_eval else None, + tokenizer=feature_extractor, + data_collator=collate_fn, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Write model card and (optionally) push to hub + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "tasks": "masked-image-modeling", + "dataset": data_args.dataset_name, + "tags": ["masked-image-modeling"], + } + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +if __name__ == "__main__": + main() + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() \ No newline at end of file diff --git a/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/vision-transformer-p4-fp32.ipynb b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/vision-transformer-p4-fp32.ipynb new file mode 100644 index 0000000000..98d748ea49 --- /dev/null +++ b/sagemaker-training-compiler/huggingface/pytorch_multiple_gpu_single_node/vision_transformer/vision-transformer-p4-fp32.ipynb @@ -0,0 +1,542 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compile and Train a Vision Transformer Model on the MNIST Dataset using Multi Node Distributed Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. [Introduction](#Introduction) \n", + "2. [Development Environment and Permissions](#Development-Environment-and-Permissions)\n", + " 1. [Installation](#Installation) \n", + " 2. [SageMaker environment](#SageMaker-environment)\n", + "3. [Processing](#Preprocessing) \n", + " 1. [Tokenization](#Tokenization) \n", + " 2. [Uploading data to sagemaker_session_bucket](#Uploading-data-to-sagemaker_session_bucket) \n", + "4. [SageMaker Training Job](#SageMaker-Training-Job) \n", + " 1. [Training with Native PyTorch](#Training-with-Native-PyTorch) \n", + " 2. [Training with Optimized PyTorch](#Training-with-Optimized-PyTorch) \n", + " 3. [Analysis](#Analysis) \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SageMaker Training Compiler Overview\n", + "\n", + "SageMaker Training Compiler is a capability of SageMaker that makes these hard-to-implement optimizations to reduce training time on GPU instances. The compiler optimizes DL models to accelerate training by more efficiently using SageMaker machine learning (ML) GPU instances. SageMaker Training Compiler is available at no additional charge within SageMaker and can help reduce total billable time as it accelerates training. \n", + "\n", + "SageMaker Training Compiler is integrated into the AWS Deep Learning Containers (DLCs). Using the SageMaker Training Compiler enabled AWS DLCs, you can compile and optimize training jobs on GPU instances with minimal changes to your code. Bring your deep learning models to SageMaker and enable SageMaker Training Compiler to accelerate the speed of your training job on SageMaker ML instances for accelerated computing. \n", + "\n", + "For more information, see [SageMaker Training Compiler](https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html) in the *Amazon SageMaker Developer Guide*.\n", + "\n", + "## Introduction\n", + "\n", + "In this demo, you'll use Hugging Face's `transformers` and `datasets` libraries with Amazon SageMaker Training Compiler to train the `RoBERTa` model on the `Stanford Sentiment Treebank v2 (SST2)` dataset. To get started, we need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on. \n", + "\n", + "**NOTE:** You can run this demo in SageMaker Studio, SageMaker notebook instances, or your local machine with AWS CLI set up. If using SageMaker Studio or SageMaker notebook instances, make sure you choose one of the PyTorch-based kernels, `Python 3 (PyTorch x.y Python 3.x CPU Optimized)` or `conda_pytorch_p36` respectively.\n", + "\n", + "**NOTE:** This notebook uses two `ml.p3.2xlarge` instances that have single GPU. If you don't have enough quota, see [Request a service quota increase for SageMaker resources](https://docs.aws.amazon.com/sagemaker/latest/dg/regions-quotas.html#service-limit-increase-request-procedure). " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Development Environment " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "This example notebook requires the **SageMaker Python SDK v2.70.0** and **transformers v4.11.0**." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n", + "Requirement already satisfied: sagemaker in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (2.103.0)\n", + "Requirement already satisfied: botocore in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (1.24.19)\n", + "Collecting botocore\n", + " Downloading botocore-1.27.52-py3-none-any.whl (9.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.0/9.0 MB\u001b[0m \u001b[31m84.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: boto3 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (1.24.48)\n", + "Collecting boto3\n", + " Downloading boto3-1.24.52-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.5/132.5 KB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: awscli in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (1.25.48)\n", + "Collecting awscli\n", + " Downloading awscli-1.25.52-py3-none-any.whl (3.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.9/3.9 MB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: attrs<22,>=20.3.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (21.2.0)\n", + "Requirement already satisfied: numpy<2.0,>=1.9.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (1.21.2)\n", + "Requirement already satisfied: smdebug-rulesconfig==1.0.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (1.0.1)\n", + "Requirement already satisfied: pathos in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (0.2.8)\n", + "Requirement already satisfied: google-pasta in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (0.2.0)\n", + "Requirement already satisfied: pandas in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (1.3.4)\n", + "Requirement already satisfied: protobuf3-to-dict<1.0,>=0.1.5 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (0.1.5)\n", + "Requirement already satisfied: importlib-metadata<5.0,>=1.4.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (4.8.2)\n", + "Requirement already satisfied: protobuf<4.0,>=3.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (3.19.4)\n", + "Requirement already satisfied: packaging>=20.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sagemaker) (21.3)\n", + "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from botocore) (2.8.2)\n", + "Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from botocore) (1.26.8)\n", + "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from botocore) (0.10.0)\n", + "Requirement already satisfied: s3transfer<0.7.0,>=0.6.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from boto3) (0.6.0)\n", + "Requirement already satisfied: colorama<0.4.5,>=0.2.5 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from awscli) (0.4.3)\n", + "Requirement already satisfied: docutils<0.17,>=0.10 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from awscli) (0.15.2)\n", + "Requirement already satisfied: PyYAML<5.5,>=3.10 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from awscli) (5.4.1)\n", + "Requirement already satisfied: rsa<4.8,>=3.1.2 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from awscli) (4.7.2)\n", + "Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from importlib-metadata<5.0,>=1.4.0->sagemaker) (3.6.0)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from packaging>=20.0->sagemaker) (3.0.6)\n", + "Requirement already satisfied: six in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from protobuf3-to-dict<1.0,>=0.1.5->sagemaker) (1.16.0)\n", + "Requirement already satisfied: pyasn1>=0.1.3 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from rsa<4.8,>=3.1.2->awscli) (0.4.8)\n", + "Requirement already satisfied: pytz>=2017.3 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from pandas->sagemaker) (2021.3)\n", + "Requirement already satisfied: dill>=0.3.4 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from pathos->sagemaker) (0.3.4)\n", + "Requirement already satisfied: multiprocess>=0.70.12 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from pathos->sagemaker) (0.70.12.2)\n", + "Requirement already satisfied: pox>=0.3.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from pathos->sagemaker) (0.3.0)\n", + "Requirement already satisfied: ppft>=1.6.6.4 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from pathos->sagemaker) (1.6.6.4)\n", + "Installing collected packages: botocore, boto3, awscli\n", + " Attempting uninstall: botocore\n", + " Found existing installation: botocore 1.24.19\n", + " Uninstalling botocore-1.24.19:\n", + " Successfully uninstalled botocore-1.24.19\n", + " Attempting uninstall: boto3\n", + " Found existing installation: boto3 1.24.48\n", + " Uninstalling boto3-1.24.48:\n", + " Successfully uninstalled boto3-1.24.48\n", + " Attempting uninstall: awscli\n", + " Found existing installation: awscli 1.25.48\n", + " Uninstalling awscli-1.25.48:\n", + " Successfully uninstalled awscli-1.25.48\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "aiobotocore 2.0.1 requires botocore<1.22.9,>=1.22.8, but you have botocore 1.27.52 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed awscli-1.25.52 boto3-1.24.52 botocore-1.27.52\n", + "\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2.2 is available.\n", + "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p38/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install sagemaker botocore boto3 awscli --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n", + "Collecting transformers\n", + " Downloading transformers-4.21.1-py3-none-any.whl (4.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m52.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting datasets\n", + " Downloading datasets-2.4.0-py3-none-any.whl (365 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m365.7/365.7 KB\u001b[0m \u001b[31m68.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from transformers) (21.3)\n", + "Requirement already satisfied: requests in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from transformers) (2.26.0)\n", + "Requirement already satisfied: numpy>=1.17 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from transformers) (1.21.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from transformers) (2021.11.10)\n", + "Collecting tokenizers!=0.11.3,<0.13,>=0.11.1\n", + " Downloading tokenizers-0.12.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.6/6.6 MB\u001b[0m \u001b[31m49.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from transformers) (4.62.3)\n", + "Requirement already satisfied: pyyaml>=5.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from transformers) (5.4.1)\n", + "Collecting huggingface-hub<1.0,>=0.1.0\n", + " Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.5/101.5 KB\u001b[0m \u001b[31m23.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: filelock in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from transformers) (3.4.0)\n", + "Requirement already satisfied: multiprocess in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from datasets) (0.70.12.2)\n", + "Requirement already satisfied: aiohttp in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from datasets) (3.8.1)\n", + "Collecting responses<0.19\n", + " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", + "Requirement already satisfied: pandas in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from datasets) (1.3.4)\n", + "Requirement already satisfied: fsspec[http]>=2021.11.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from datasets) (2021.11.1)\n", + "Requirement already satisfied: pyarrow>=6.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from datasets) (7.0.0)\n", + "Requirement already satisfied: dill<0.3.6 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from datasets) (0.3.4)\n", + "Collecting xxhash\n", + " Downloading xxhash-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.1/212.1 KB\u001b[0m \u001b[31m52.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.0.0)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from packaging>=20.0->transformers) (3.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from requests->transformers) (2021.10.8)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from requests->transformers) (2.0.7)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from requests->transformers) (1.26.8)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from requests->transformers) (3.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from aiohttp->datasets) (1.2.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from aiohttp->datasets) (21.2.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from aiohttp->datasets) (5.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from aiohttp->datasets) (1.2.0)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from aiohttp->datasets) (1.7.2)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.3 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from pandas->datasets) (2021.3)\n", + "Requirement already satisfied: six>=1.5 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.16.0)\n", + "Installing collected packages: tokenizers, xxhash, responses, huggingface-hub, transformers, datasets\n", + "Successfully installed datasets-2.4.0 huggingface-hub-0.8.1 responses-0.18.0 tokenizers-0.12.1 transformers-4.21.1 xxhash-3.0.0\n", + "\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2.2 is available.\n", + "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p38/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -U transformers datasets --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sagemaker: 2.103.0\n", + "transformers: 4.21.1\n" + ] + } + ], + "source": [ + "import botocore\n", + "import boto3\n", + "import sagemaker\n", + "import transformers\n", + "import pandas as pd\n", + "\n", + "print(f\"sagemaker: {sagemaker.__version__}\")\n", + "print(f\"transformers: {transformers.__version__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copy and run the following code if you need to upgrade ipywidgets for `datasets` library and restart kernel. This is only needed when prerpocessing is done in the notebook.\n", + "\n", + "```python\n", + "%%capture\n", + "import IPython\n", + "!conda install -c conda-forge ipywidgets -y\n", + "# has to restart kernel for the updates to be applied\n", + "IPython.Application.instance().kernel.do_shutdown(True) \n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SageMaker environment " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sagemaker role arn: arn:aws:iam::875423407011:role/SageMakerRole\n", + "sagemaker bucket: sagemaker-us-west-2-875423407011\n", + "sagemaker session region: us-west-2\n" + ] + } + ], + "source": [ + "import sagemaker\n", + "\n", + "sess = sagemaker.Session()\n", + "\n", + "# SageMaker session bucket -> used for uploading data, models and logs\n", + "# SageMaker will automatically create this bucket if it does not exist\n", + "sagemaker_session_bucket = None\n", + "if sagemaker_session_bucket is None and sess is not None:\n", + " # set to default bucket if a bucket name is not given\n", + " sagemaker_session_bucket = sess.default_bucket()\n", + "\n", + "role = sagemaker.get_execution_role()\n", + "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n", + "\n", + "print(f\"sagemaker role arn: {role}\")\n", + "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n", + "print(f\"sagemaker session region: {sess.boto_region_name}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SageMaker Training Job\n", + "\n", + "To create a SageMaker training job, we use a `HuggingFace` estimator. Using the estimator, you can define which fine-tuning script should SageMaker use through `entry_point`, which `instance_type` to use for training, which `hyperparameters` to pass, and so on.\n", + "\n", + "When a SageMaker training job starts, SageMaker takes care of starting and managing all the required machine learning instances, picks up the `HuggingFace` Deep Learning Container, uploads your training script, and downloads the data from `sagemaker_session_bucket` into the container at `/opt/ml/input/data`.\n", + "\n", + "In the following section, you learn how to set up two versions of the SageMaker `HuggingFace` estimator, a native one without the compiler and an optimized one with the compiler." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up an option for fine-tuning or full training. Set `FINE_TUNING = 1` for fine-tuning and using `fine_tune_with_huggingface.py`. Set `FINE_TUNING = 0` for full training and using `full_train_roberta_with_huggingface.py`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "EPOCHS = 1\n", + "\n", + "# SageMaker Training Compiler currently only supports training on GPU\n", + "# Select Instance type for training\n", + "INSTANCE_TYPE = \"ml.p4d.24xlarge\"\n", + "NUM_GPUS = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training with Native PyTorch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `train_batch_size` in the following code cell is the maximum batch that can fit into the memory of an `ml.g4dn.2xlarge` instance. If you change the model, instance type, and other parameters, you need to do some experiments to find the largest batch size that will fit into GPU memory." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.huggingface import HuggingFace\n", + "\n", + "kwargs = dict(\n", + " source_dir=\"scripts\",\n", + " instance_type=INSTANCE_TYPE,\n", + " role=role,\n", + " py_version=\"py38\",\n", + " disable_profiler=True,\n", + " debugger_hook_config=False,\n", + " volume_size=60,\n", + ")\n", + "\n", + "PER_DEVICE_BATCH_SIZE=248\n", + "cluster_size=1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200 pr-huggingface-pytorch-training-2022-08-23-21-10-36-273\n", + "208 pr-huggingface-pytorch-training-2022-08-23-21-10-36-995\n", + "216 pr-huggingface-pytorch-training-2022-08-23-21-10-40-499\n", + "224 pr-huggingface-pytorch-training-2022-08-23-21-10-41-041\n", + "232 pr-huggingface-pytorch-training-2022-08-23-21-10-44-361\n", + "240 pr-huggingface-pytorch-training-2022-08-23-21-10-45-961\n", + "248 pr-huggingface-pytorch-training-2022-08-23-21-10-47-342\n", + "256 pr-huggingface-pytorch-training-2022-08-23-21-10-49-527\n", + "264 pr-huggingface-pytorch-training-2022-08-23-21-10-53-981\n", + "272 pr-huggingface-pytorch-training-2022-08-23-21-10-54-513\n" + ] + } + ], + "source": [ + "from sagemaker.huggingface import HuggingFace\n", + "\n", + "\n", + "# The original LR was set for a batch of 8. Here we are scaling learning rate with batch size.\n", + "GLOBAL_BATCH_SIZE = PER_DEVICE_BATCH_SIZE * NUM_GPUS * cluster_size\n", + "LEARNING_RATE = float(\"2e-5\") / 8 * GLOBAL_BATCH_SIZE\n", + "\n", + "# configure the training job\n", + "huggingface_estimator = HuggingFace(\n", + " image_uri=\"669063966089.dkr.ecr.us-west-2.amazonaws.com/pr-huggingface-pytorch-training:1.11.0-transformers4.21.1-gpu-py38-cu113-ubuntu20.04-pr-1824-2022-08-08-10-57-02\",\n", + " instance_count=cluster_size,\n", + " entry_point='run_mim.py',\n", + " hyperparameters={\n", + " 'model_type': 'vit',\n", + " 'dataset_name': 'mnist',\n", + " 'output_dir': '/opt/ml/model',\n", + " 'overwrite_output_dir': True,\n", + " 'remove_unused_columns': 'False',\n", + " 'label_names' : 'bool_masked_pos',\n", + " 'do_train': True,\n", + " 'do_eval': False,\n", + " 'learning_rate': LEARNING_RATE,\n", + " 'weight_decay': 0.05,\n", + " 'num_train_epochs': EPOCHS,\n", + " 'per_device_train_batch_size': PER_DEVICE_BATCH_SIZE,\n", + " 'per_device_eval_batch_size': PER_DEVICE_BATCH_SIZE,\n", + " 'logging_strategy': 'epoch',\n", + " 'evaluation_strategy': 'no',\n", + " 'save_strategy': 'no',\n", + " 'save_total_limit': 3,\n", + " },\n", + " distribution={'smdistributed': {'dataparallel': {'enabled': True}}},\n", + " **kwargs,\n", + ")\n", + "\n", + "# start training with our uploaded datasets as input\n", + "huggingface_estimator.fit(wait=False)\n", + "\n", + "# The name of the training job.\n", + "print(huggingface_estimator.latest_training_job.name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training with Optimized PyTorch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compilation through Training Compiler changes the memory footprint of the model. Most commonly, this manifests as a reduction in memory utilization and a consequent increase in the largest batch size that can fit on the GPU. Note that if you want to change the batch size, you must adjust the learning rate appropriately.\n", + "\n", + "**Note:** We recommend you to turn the SageMaker Debugger's profiling and debugging tools off when you use compilation to avoid additional overheads." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "248 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-40-712\n", + "256 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-41-485\n", + "264 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-44-498\n", + "272 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-46-143\n", + "280 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-46-682\n", + "288 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-51-186\n", + "296 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-52-597\n", + "304 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-53-330\n", + "312 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-56-047\n", + "320 pr-huggingface-pytorch-trcomp-training-2022-08-23-21-45-56-676\n" + ] + } + ], + "source": [ + "from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig\n", + "TrainingCompilerConfig.validate = lambda *args, **kwargs:None\n", + "\n", + "NEW_PER_DEVICE_BATCH_SIZE=248\n", + "cluster_size=1\n", + "\n", + "# The original LR was set for a batch of 8. Here we are scaling learning rate with batch size.\n", + "GLOBAL_BATCH_SIZE = NEW_PER_DEVICE_BATCH_SIZE * NUM_GPUS * cluster_size\n", + "LEARNING_RATE = float(\"2e-5\") / 8 * GLOBAL_BATCH_SIZE\n", + "\n", + "# configure the training job\n", + "optimized_estimator = HuggingFace(\n", + " image_uri=\"669063966089.dkr.ecr.us-west-2.amazonaws.com/pr-huggingface-pytorch-trcomp-training:1.11.0-transformers4.21.1-gpu-py38-cu113-ubuntu20.04-pr-2032-2022-08-19-18-27-39\",\n", + " compiler_config=TrainingCompilerConfig(),\n", + " instance_count=cluster_size,\n", + " entry_point='run_mim.py',\n", + " hyperparameters={\n", + " 'model_type': 'vit',\n", + " 'dataset_name': 'mnist',\n", + " 'output_dir': '/opt/ml/model',\n", + " 'overwrite_output_dir': True,\n", + " 'remove_unused_columns': 'False',\n", + " 'label_names' : 'bool_masked_pos',\n", + " 'do_train': True,\n", + " 'do_eval': False,\n", + " 'learning_rate': LEARNING_RATE,\n", + " 'weight_decay': 0.05,\n", + " 'num_train_epochs': EPOCHS,\n", + " 'per_device_train_batch_size': NEW_PER_DEVICE_BATCH_SIZE,\n", + " 'per_device_eval_batch_size': PER_DEVICE_BATCH_SIZE,\n", + " 'logging_strategy': 'epoch',\n", + " 'evaluation_strategy': 'no',\n", + " 'save_strategy': 'no',\n", + " 'save_total_limit': 3,\n", + " 'sagemaker_pytorch_xla_multi_worker_enabled': True,\n", + " },\n", + " **kwargs,\n", + ")\n", + "\n", + "# start training with our uploaded datasets as input\n", + "optimized_estimator.fit(wait=False)\n", + "\n", + "# The name of the training job.\n", + "print(optimized_estimator.latest_training_job.name)" + ] + } + ], + "metadata": { + "instance_type": "ml.t3.medium", + "interpreter": { + "hash": "c281c456f1b8161c8906f4af2c08ed2c40c50136979eaae69688b01f70e9f4a9" + }, + "kernelspec": { + "display_name": "conda_pytorch_p38", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}