Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Shard embedding hooks #80

Merged
merged 2 commits into from
Mar 15, 2023
Merged

[Bugfix] Shard embedding hooks #80

merged 2 commits into from
Mar 15, 2023

Conversation

comaniac
Copy link
Contributor

Description

When sharding word embedding, we use forward pre-hook to mask inputs for different devices in a TP group. Accordingly, we have to mask the invalid output using the same mask in forward post-hook. Since the input mask generated in pre-hook cannot be saved, we currently re-generate the mask in post-hook. However, at this moment the input is already masked, so we cannot re-generate a mask using the same logic.

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @chhzh123 @zarzen

Copy link
Contributor

@chhzh123 chhzh123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. This bug is pretty subtle. Thanks for fixing.

@comaniac comaniac merged commit 07e474a into awslabs:main Mar 15, 2023
@comaniac
Copy link
Contributor Author

Thanks @chhzh123

@comaniac comaniac deleted the fix_shard branch March 15, 2023 19:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants