Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally use flash-attn's CE loss for metrics #3394

Merged
merged 16 commits into from
Jun 17, 2024
2 changes: 1 addition & 1 deletion .github/workflows/pr-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-3.11-2.3
container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-doctest
Expand Down
3 changes: 3 additions & 0 deletions composer/devices/device_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.backends.cudnn
import torch.cuda
import torch.cuda.amp
import torch.distributed as torch_dist
import torch.utils.data

from composer.devices.device import Device
Expand Down Expand Up @@ -42,6 +43,8 @@ def __init__(
):
if not torch.cuda.is_available():
raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.')
if torch_dist.is_gloo_available():
DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo'
if device_id is None:
device_id = dist.get_local_rank()
self._device = torch.device(f'cuda:{device_id}')
Expand Down
22 changes: 20 additions & 2 deletions composer/metrics/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,21 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.ignore_index = ignore_index
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
self.flash_loss_fn = None
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
log.debug(
'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' +
'to compute LanguageCrossEntropy metric for CUDA tensors, which will be faster.',
)
self.flash_loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
except ImportError:
if torch.cuda.is_available():
log.debug(
'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' +
'to compute LanguageCrossEntropy metric for CUDA tensors, which will be slower.',
)
self.torch_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')

Expand All @@ -104,7 +118,11 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:

target = target.view(-1)
logits = logits.view(target.shape[0], -1)
losses = self.loss_fn(logits, target)
# Use Flash attn's CE loss function, if available, if inputs are both CUDA tensors.
if self.flash_loss_fn is not None and target.is_cuda and logits.is_cuda:
losses = self.flash_loss_fn(logits, target)
else:
losses = self.torch_loss_fn(logits, target)

total_items = (target != self.ignore_index).sum()
self.total_items += total_items #type: ignore (third-party)
Expand Down
6 changes: 5 additions & 1 deletion tests/checkpoint/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
import torch
import torch.distributed as torch_dist
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import adam
Expand Down Expand Up @@ -530,7 +531,10 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz
assert 'model_name' in metadata_sd

assert 'dist_backend' in metadata_sd
assert metadata_sd['dist_backend'] == 'nccl'
if torch_dist.is_gloo_available():
assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo'
else:
assert metadata_sd['dist_backend'] == 'nccl'


@pytest.mark.filterwarnings('ignore:SWA has')
Expand Down
89 changes: 89 additions & 0 deletions tests/metrics/test_nlp_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
LanguagePerplexity,
MaskedAccuracy,
)
from tests.common import device


@pytest.mark.parametrize('ignore_index', [-100])
Expand Down Expand Up @@ -50,12 +51,100 @@ def test_masked_accuracy(ignore_index, num_classes):
assert abs(final_acc - (1.0 / num_classes)) < 0.02


@device('cpu', 'gpu')
@pytest.mark.parametrize('ignore_index', [-100])
@pytest.mark.parametrize('batch_size', [1e2, 1e3])
@pytest.mark.parametrize('sequence_length', [128])
@pytest.mark.parametrize('num_classes', [2, 10])
@pytest.mark.parametrize('minibatch_size', [56, 256, 768])
@pytest.mark.parametrize('tensor_device', ['cpu', 'gpu'])
def test_cross_entropy(
device: str,
batch_size: float,
ignore_index: Optional[int],
sequence_length: int,
num_classes: int,
minibatch_size: int,
tensor_device: str,
):
"""Sanity check to make sure that batched CrossEntropyLoss matches the expected performance.

Generates a predicted distribution from a normal distribution, and a ground truth from a normal distribution.
Verifies Cross Entropy Loss against the baseline performance.

Args:
device (str): the device to run the test on
batch_size (int): how many samples are in each batch
ignore_index (Optional[int]): if present, the class index to ignore in accuracy calculations.
sequence_length (int): the length of the generated sequence
num_classes (int): the number of classes in the classification task
minibatch_size (int): the minibatch size to simulate for model predictions
tensor_device (str): which device the input tensors to the metric are on
"""

if device == 'cpu' and tensor_device == 'gpu':
pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.')

batch_size = int(batch_size)
generated_preds = torch.randn((batch_size, sequence_length, num_classes))
generated_true = torch.randint(low=0, high=num_classes, size=(batch_size, sequence_length))

assert ignore_index is not None
torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)

if tensor_device == 'cpu':
torchmetrics_xent = torchmetrics_xent.to('cpu')
ce_with_keys_metric = ce_with_keys_metric.to('cpu')
elif tensor_device == 'gpu':
torchmetrics_xent = torchmetrics_xent.to('cuda')
ce_with_keys_metric = ce_with_keys_metric.to('cuda')

if device == 'gpu':
assert torchmetrics_xent.flash_loss_fn is not None

labels_mask = torch.rand((batch_size, sequence_length))
labels_mask[labels_mask > 0.8] = 1
labels_mask[labels_mask <= 0.8] = 0
labels_mask = labels_mask.bool()
generated_true[labels_mask] = ignore_index

num_batches = math.ceil(batch_size / minibatch_size)
for batch_idx in range(num_batches):
begin_idx = (batch_idx * minibatch_size)
end_idx = ((batch_idx + 1) * minibatch_size)
preds_subset = generated_preds[begin_idx:end_idx]
true_subset = generated_true[begin_idx:end_idx]

if tensor_device == 'cpu':
preds_subset = preds_subset.cpu()
true_subset = true_subset.cpu()
elif tensor_device == 'gpu':
preds_subset = preds_subset.cuda()
true_subset = true_subset.cuda()

torchmetrics_xent.update(preds_subset, true_subset)
ce_with_keys_metric.update(
{
'logits': preds_subset.view(-1, num_classes),
'loss': cross_entropy(preds_subset.view(-1, num_classes), true_subset.view(-1)),
},
true_subset.view(-1),
)

torchmetrics_loss = torchmetrics_xent.compute()
ce_with_keys_loss = ce_with_keys_metric.compute()
correct_loss = cross_entropy(generated_preds.view(-1, num_classes), generated_true.view(-1))
assert torchmetrics_loss == ce_with_keys_loss
assert torch.isclose(correct_loss, torchmetrics_loss)


@pytest.mark.parametrize('ignore_index', [-100])
@pytest.mark.parametrize('batch_size', [1e2, 1e3])
@pytest.mark.parametrize('sequence_length', [128])
@pytest.mark.parametrize('num_classes', [2, 10])
@pytest.mark.parametrize('minibatch_size', [56, 256, 768])
def test_torch_cpu_cross_entropy(
batch_size: float,
ignore_index: Optional[int],
sequence_length: int,
Expand Down
Loading