From 2cf9262e988c7cc4ee107259b98efec0298c5017 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Mon, 17 Jun 2024 12:09:55 -0700 Subject: [PATCH] Optionally use `flash-attn`'s CE loss for metrics (#3394) * yo * slam * cuda * cuda checks * test * fix_test * gloo * gloo * lint * lint --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> Co-authored-by: Mihir Patel --- .github/workflows/pr-cpu.yaml | 2 +- composer/devices/device_gpu.py | 3 + composer/metrics/nlp.py | 22 ++++++- tests/checkpoint/test_state_dict.py | 6 +- tests/metrics/test_nlp_metrics.py | 89 +++++++++++++++++++++++++++++ 5 files changed, 118 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 1bdb383823..12f471749e 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -22,7 +22,7 @@ jobs: markers: not daily and not remote and not gpu and not doctest pytest_command: coverage run -m pytest - name: cpu-3.11-2.3 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04 markers: not daily and not remote and not gpu and not doctest pytest_command: coverage run -m pytest - name: cpu-doctest diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py index 19cb0a774a..401368576e 100644 --- a/composer/devices/device_gpu.py +++ b/composer/devices/device_gpu.py @@ -12,6 +12,7 @@ import torch.backends.cudnn import torch.cuda import torch.cuda.amp +import torch.distributed as torch_dist import torch.utils.data from composer.devices.device import Device @@ -42,6 +43,8 @@ def __init__( ): if not torch.cuda.is_available(): raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.') + if torch_dist.is_gloo_available(): + DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo' if device_id is None: device_id = dist.get_local_rank() self._device = torch.device(f'cuda:{device_id}') diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index e6877292cf..c1562e5936 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -83,7 +83,21 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_index = ignore_index - self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + self.flash_loss_fn = None + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss + log.debug( + 'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' + + 'to compute LanguageCrossEntropy metric for CUDA tensors, which will be faster.', + ) + self.flash_loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + except ImportError: + if torch.cuda.is_available(): + log.debug( + 'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' + + 'to compute LanguageCrossEntropy metric for CUDA tensors, which will be slower.', + ) + self.torch_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum') @@ -104,7 +118,11 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None: target = target.view(-1) logits = logits.view(target.shape[0], -1) - losses = self.loss_fn(logits, target) + # Use Flash attn's CE loss function, if available, if inputs are both CUDA tensors. + if self.flash_loss_fn is not None and target.is_cuda and logits.is_cuda: + losses = self.flash_loss_fn(logits, target) + else: + losses = self.torch_loss_fn(logits, target) total_items = (target != self.ignore_index).sum() self.total_items += total_items #type: ignore (third-party) diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index af0ca34961..bd14154dc9 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -7,6 +7,7 @@ import pytest import torch +import torch.distributed as torch_dist from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim import adam @@ -530,7 +531,10 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz assert 'model_name' in metadata_sd assert 'dist_backend' in metadata_sd - assert metadata_sd['dist_backend'] == 'nccl' + if torch_dist.is_gloo_available(): + assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo' + else: + assert metadata_sd['dist_backend'] == 'nccl' @pytest.mark.filterwarnings('ignore:SWA has') diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 7fe854bd96..9b198003d3 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -14,6 +14,7 @@ LanguagePerplexity, MaskedAccuracy, ) +from tests.common import device @pytest.mark.parametrize('ignore_index', [-100]) @@ -50,12 +51,100 @@ def test_masked_accuracy(ignore_index, num_classes): assert abs(final_acc - (1.0 / num_classes)) < 0.02 +@device('cpu', 'gpu') @pytest.mark.parametrize('ignore_index', [-100]) @pytest.mark.parametrize('batch_size', [1e2, 1e3]) @pytest.mark.parametrize('sequence_length', [128]) @pytest.mark.parametrize('num_classes', [2, 10]) @pytest.mark.parametrize('minibatch_size', [56, 256, 768]) +@pytest.mark.parametrize('tensor_device', ['cpu', 'gpu']) def test_cross_entropy( + device: str, + batch_size: float, + ignore_index: Optional[int], + sequence_length: int, + num_classes: int, + minibatch_size: int, + tensor_device: str, +): + """Sanity check to make sure that batched CrossEntropyLoss matches the expected performance. + + Generates a predicted distribution from a normal distribution, and a ground truth from a normal distribution. + Verifies Cross Entropy Loss against the baseline performance. + + Args: + device (str): the device to run the test on + batch_size (int): how many samples are in each batch + ignore_index (Optional[int]): if present, the class index to ignore in accuracy calculations. + sequence_length (int): the length of the generated sequence + num_classes (int): the number of classes in the classification task + minibatch_size (int): the minibatch size to simulate for model predictions + tensor_device (str): which device the input tensors to the metric are on + """ + + if device == 'cpu' and tensor_device == 'gpu': + pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.') + + batch_size = int(batch_size) + generated_preds = torch.randn((batch_size, sequence_length, num_classes)) + generated_true = torch.randint(low=0, high=num_classes, size=(batch_size, sequence_length)) + + assert ignore_index is not None + torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index) + ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index) + + if tensor_device == 'cpu': + torchmetrics_xent = torchmetrics_xent.to('cpu') + ce_with_keys_metric = ce_with_keys_metric.to('cpu') + elif tensor_device == 'gpu': + torchmetrics_xent = torchmetrics_xent.to('cuda') + ce_with_keys_metric = ce_with_keys_metric.to('cuda') + + if device == 'gpu': + assert torchmetrics_xent.flash_loss_fn is not None + + labels_mask = torch.rand((batch_size, sequence_length)) + labels_mask[labels_mask > 0.8] = 1 + labels_mask[labels_mask <= 0.8] = 0 + labels_mask = labels_mask.bool() + generated_true[labels_mask] = ignore_index + + num_batches = math.ceil(batch_size / minibatch_size) + for batch_idx in range(num_batches): + begin_idx = (batch_idx * minibatch_size) + end_idx = ((batch_idx + 1) * minibatch_size) + preds_subset = generated_preds[begin_idx:end_idx] + true_subset = generated_true[begin_idx:end_idx] + + if tensor_device == 'cpu': + preds_subset = preds_subset.cpu() + true_subset = true_subset.cpu() + elif tensor_device == 'gpu': + preds_subset = preds_subset.cuda() + true_subset = true_subset.cuda() + + torchmetrics_xent.update(preds_subset, true_subset) + ce_with_keys_metric.update( + { + 'logits': preds_subset.view(-1, num_classes), + 'loss': cross_entropy(preds_subset.view(-1, num_classes), true_subset.view(-1)), + }, + true_subset.view(-1), + ) + + torchmetrics_loss = torchmetrics_xent.compute() + ce_with_keys_loss = ce_with_keys_metric.compute() + correct_loss = cross_entropy(generated_preds.view(-1, num_classes), generated_true.view(-1)) + assert torchmetrics_loss == ce_with_keys_loss + assert torch.isclose(correct_loss, torchmetrics_loss) + + +@pytest.mark.parametrize('ignore_index', [-100]) +@pytest.mark.parametrize('batch_size', [1e2, 1e3]) +@pytest.mark.parametrize('sequence_length', [128]) +@pytest.mark.parametrize('num_classes', [2, 10]) +@pytest.mark.parametrize('minibatch_size', [56, 256, 768]) +def test_torch_cpu_cross_entropy( batch_size: float, ignore_index: Optional[int], sequence_length: int,