Use FA's CrossEntropyLoss for metrics calculation #3214
Conversation
#2987 this is the third time we've had this PR 😂 once in foundry, once in composer, and now again. @mvpatel2000 last time I thought we should put it in foundry. Do you want to take FA as a Composer dependency?
Yea, I'm fine adding it. @snarayan21 please add it as an optional dependency
added as optional dep
Thanks for the PR!
setup.py (Outdated)

```python
extra_deps['gpu-flash2'] = [
    'flash-attn==2.5.0',
]
```
why `gpu-flash2` instead of just `flash2`?
should just be flash-attention
what about `flash-attn`, since it's the package name?
changed to `flash-attn`
ya sounds good
…sed_ce_eval_metrics merging main
…l/composer into saaketh/fused_ce_eval_metrics hello
```diff
@@ -1,6 +1,6 @@
 # build requirements
 [build-system]
-requires = ["setuptools < 68.0.0"]
+requires = ["setuptools < 68.0.0", "packaging >= 21.3.0, < 24.1"]
```
why this?
It gets around some of the installation errors, but torch is also a dependency, and then we go round and round in this circle trying to get build isolation working correctly. So I'm closing this PR; it should live foundry-side.
Closed this PR, since adding flash-attn as a Composer dependency ran into build-isolation issues.
What does this PR do?
If available, we should use FA's CrossEntropyLoss for calculating the CE metric in Composer. We already use it in foundry, and it's simply faster. Below we can see that the CE loss metric is identical, and there's a nice MFU boost as well.
Pink: using FA's CrossEntropyLoss
Green: using torch's CrossEntropyLoss
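A minimal sketch of the "use FA's loss if available" pattern this PR describes. The `flash_attn.losses.cross_entropy` module path matches flash-attn 2.x, but treat it as an assumption; the exact wiring in composer's metrics code may differ.

```python
# Prefer flash-attn's fused CrossEntropyLoss when the optional
# dependency is installed; otherwise fall back to torch's.
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed flash-attn 2.x path
except ImportError:
    from torch.nn import CrossEntropyLoss

# Both classes accept ignore_index, so the metric code can treat
# them interchangeably for masked-token cross entropy.
loss_fn = CrossEntropyLoss(ignore_index=-100)
```

Because both implementations compute the same mathematical quantity, the CE metric value should be unchanged; only the kernel speed differs.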
What issue(s) does this change relate to?
Before submitting
Did you run `pre-commit` on your change? (see the `pre-commit` section of prerequisites)