Optionally use `flash-attn`'s CE loss for metrics #3394
Conversation
Holding review until after freeze
Can we offer a flag to gate this as well? IIRC there are occasionally numerics issues for long sequences...
@ShashankMosaicML do you remember?
Flash attention fixed the long sequence issue in this PR: Dao-AILab/flash-attention@c79de85
will review once CI passes
Seeing the error below on CPU tests, so I'm going to add a check for it.
jk. The torch 3.11 CPU tests were using the CUDA image by accident, causing this problem. It was only the torch 3.11 tests, too. Fixed that in this PR as well.
lgtm
Please add unit tests for this before merging.
holding till offline discussion
LGTM. Noting that changing the GPU device to provide both backends mirrors PyTorch's default behavior, which initializes both a GPU and a CPU dist backend.
Added manual test names to PR description.
(This PR was later reverted in mosaicml#3408, "Revert 'Optionally use `flash-attn`'s CE loss for metrics (mosaicml#3394)'", which reverts commit 2cf9262.)
What does this PR do?
Resubmission of #3214 -- using FA's CE loss results in lower peak reserved memory usage and higher throughput. We are not adding flash attention as an optional dependency to composer, since that would make installs and correct builds messy and take a lot longer.
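As a rough illustration of the idea (a sketch, not the exact Composer code): metrics can prefer flash-attn's fused cross-entropy when the package is importable and CUDA is available, and otherwise fall back to `torch.nn.CrossEntropyLoss`. The import path `flash_attn.losses.cross_entropy` is what recent flash-attn releases expose, but treat the specifics below as assumptions.

```python
# Sketch only: prefer flash-attn's fused CE loss for metric computation when
# it is importable and a CUDA device is present; otherwise use torch's CE loss.
import torch

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
    _FLASH_CE_AVAILABLE = True
except ImportError:
    _FLASH_CE_AVAILABLE = False


def build_metric_ce_loss(ignore_index: int = -100) -> torch.nn.Module:
    """Return a cross-entropy loss module, using flash-attn's kernel if usable."""
    if _FLASH_CE_AVAILABLE and torch.cuda.is_available():
        # The fused kernel is where the reported memory and throughput gains come from;
        # it expects logits on a CUDA device.
        return FlashCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
    return torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
```

Because flash-attn stays out of Composer's dependencies, the try/except import keeps CPU-only installs working while letting GPU images that already ship flash-attn pick up the faster path automatically.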
Fixed a small typo where the torch 3.11 CPU tests were using the GPU image with flash attn installed by accident.
Also modified the `DeviceGPU` class so that it instantiates a `gloo` backend for CPU tensors, if `gloo` is available. This handles cases where users may want to perform distributed operations on tensors that live on CPU even when they are training on GPUs; a sketch of this dual-backend setup is shown below.
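A minimal, hypothetical sketch of the dual-backend idea (not Composer's actual `DeviceGPU` code); it relies on PyTorch's `'cuda:nccl,cpu:gloo'` backend string, which routes CUDA tensors to NCCL and CPU tensors to gloo within one process group, and assumes the usual `env://` rendezvous variables are set.

```python
# Hypothetical illustration: initialize NCCL for GPU tensors and, when the
# gloo backend is built into this torch install, gloo for CPU tensors too.
import torch.distributed as dist


def init_dist_backends() -> None:
    if dist.is_gloo_available():
        # One default process group that dispatches CUDA tensors to NCCL and
        # CPU tensors to gloo, mirroring torch's own default behavior.
        dist.init_process_group(backend='cuda:nccl,cpu:gloo')
    else:
        dist.init_process_group(backend='nccl')
```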
Manual tests:
- Run started (`13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-start-5PtEdK`), resumed with this branch (`13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-resume-E5SieL`)
- Run started (`13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-start-0g8uD4`), resumed with dev branch (`13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-resume-TSGoUC`)

4th time's the charm :0
- Run with torch CE loss (green): `tiny-sp-dtms1-32h-wCFWfa`
- Run with FA CE loss (tan): `tiny-sp-dtms1-32h-jOfIPL`
What issue(s) does this change relate to?
Before submitting
- Did you run `pre-commit` on your change? (see the `pre-commit` section of prerequisites)