diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py
index 969ef1dffa..fe81f8edac 100644
--- a/composer/metrics/nlp.py
+++ b/composer/metrics/nlp.py
@@ -102,7 +102,17 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
         super().__init__(dist_sync_on_step=dist_sync_on_step)
 
         self.ignore_index = ignore_index
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        # Prefer the fused flash_attn loss; fall back to PyTorch only when it cannot be imported.
+        # NOTE(review): confirm the fused loss accepts CPU tensors before relying on it here.
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
+            self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        except ImportError:
+            log.debug(
+                'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' +
+                'to compute LanguageCrossEntropy metric, which will be slower.',
+            )
+            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
         self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
         self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')
 
diff --git a/pyproject.toml b/pyproject.toml
index 3a6f3bb0e8..0f12f3b148 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 # build requirements
 [build-system]
-requires = ["setuptools < 68.0.0"]
+requires = ["setuptools < 68.0.0", "packaging >= 21.3.0, < 24.1"]
 build-backend = "setuptools.build_meta"
 
 # iSort
diff --git a/setup.py b/setup.py
index 4568ceed1e..7d045e9180 100644
--- a/setup.py
+++ b/setup.py
@@ -225,6 +225,10 @@ def package_files(prefix: str, directory: str, extension: str):
     'mlflow>=2.11.1,<3.0',
 ]
 
+extra_deps['flash-attn'] = [
+    'flash-attn==2.5.8',
+]
+
 extra_deps['pandas'] = ['pandas>=2.0.0,<3.0']
 
 extra_deps['databricks'] = ['databricks-sdk==0.25.1']