
Use FA's CrossEntropyLoss for metrics calculation #3214

Closed · 13 commits
10 changes: 9 additions & 1 deletion composer/metrics/nlp.py
@@ -102,7 +102,15 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.ignore_index = ignore_index
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
+            self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        except ImportError:
+            log.debug(
+                'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss '
+                'to compute the LanguageCrossEntropy metric, which will be slower.',
+            )
+            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')
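For reference, `reduction='sum'` with `ignore_index` means the metric accumulates the summed negative log-likelihood over non-ignored positions only; `sum_loss / total_items` then gives the per-token cross entropy. A minimal pure-Python sketch of that semantics (the helper name is invented for illustration and is not part of the PR):

```python
import math

def cross_entropy_sum(logits, targets, ignore_index=-100):
    """Sketch of CrossEntropyLoss(reduction='sum', ignore_index=...) semantics:
    sum of -log softmax(logits)[target] over rows whose target is not ignored."""
    total = 0.0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored positions contribute nothing to the sum
        m = max(row)  # subtract the row max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # -log softmax(row)[target]
    return total
```

For example, with uniform logits over two classes each non-ignored token contributes log 2 to the sum, and a token whose target equals `ignore_index` contributes nothing.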

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
# build requirements
[build-system]
-requires = ["setuptools < 68.0.0"]
+requires = ["setuptools < 68.0.0", "packaging >= 21.3.0, < 24.1"]
Contributor: why this?

Contributor (author): It gets around some of the installation errors, but torch is also a build dependency, and then we go round and round in this circle trying to get build isolation working correctly. So I'm closing this PR; it should live foundry-side.

build-backend = "setuptools.build_meta"

# iSort
4 changes: 4 additions & 0 deletions setup.py
@@ -225,6 +225,10 @@ def package_files(prefix: str, directory: str, extension: str):
'mlflow>=2.11.1,<3.0',
]

+extra_deps['flash-attn'] = [
+    'flash-attn==2.5.8',
+]

extra_deps['pandas'] = ['pandas>=2.0.0,<3.0']

extra_deps['databricks'] = ['databricks-sdk==0.25.1']
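Declaring the extra only controls installation; whether the fused kernel is actually used still depends on `flash_attn` being importable at runtime, which is what the try/except in the nlp.py diff above decides. A hedged sketch of an equivalent capability check (the helper name is invented, not part of the PR):

```python
import importlib.util

def pick_cross_entropy_impl():
    """Report which cross-entropy implementation would be selected at runtime:
    the fused flash-attn kernel when the package is importable, else torch.
    (Sketch of the fallback logic in the nlp.py diff; not Composer API.)"""
    if importlib.util.find_spec('flash_attn') is not None:
        return 'flash_attn'
    return 'torch'
```

Using `find_spec` instead of a bare import avoids triggering `flash_attn`'s CUDA-extension import cost just to probe availability.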