Fix BertScore on GPU #912
Conversation
Codecov Report
```
@@           Coverage Diff           @@
##           master    #912   +/-  ##
=====================================
  Coverage      95%     95%
=====================================
  Files         173     173
  Lines        7344    7344
=====================================
+ Hits         6946    6963    +17
+ Misses        398     381    -17
```
@SkafteNicki @Borda Won't there be a problem with GPU memory when the embeddings are stored in its RAM?
@stancld True, I guess we need to move to GPU at some later point in the computation?
Hi @SkafteNicki, I'd rather suggest the following :]

```diff
  out = torch.einsum("blsd, bs -> blsd", out, processed_attention_mask)
  embeddings_list.append(out.cpu())
  # Calculate weighted (w.r.t. sentence length) input_ids IDF matrix
  input_ids_idf = (
      batch["input_ids_idf"] * processed_attention_mask if idf else processed_attention_mask.type(out.dtype)
  )
  input_ids_idf /= input_ids_idf.sum(-1, keepdim=True)
- idf_scale_list.append(input_ids_idf)
+ idf_scale_list.append(input_ids_idf.cpu())
```
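The pattern in the suggestion above, moving each batch's output to CPU before appending it to an accumulation list so that GPU memory stays bounded, can be sketched in isolation. This is a minimal illustration, not the torchmetrics implementation; `accumulate_embeddings` and the stand-in model are hypothetical names:

```python
import torch

def accumulate_embeddings(batches, model_fn, device="cpu"):
    # Run the model on `device`, but store each output on CPU so the
    # accumulated list does not grow device memory across batches.
    embeddings_list = []
    for batch in batches:
        out = model_fn(batch.to(device))
        embeddings_list.append(out.cpu())
    # Concatenation happens entirely on CPU.
    return torch.cat(embeddings_list)

# Usage with a stand-in "model" (elementwise scaling):
batches = [torch.randn(4, 8) for _ in range(3)]
emb = accumulate_embeddings(batches, lambda x: 2.0 * x)
```

The trade-off is an extra device-to-host copy per batch; the tensors can be moved back to the GPU later if a subsequent step needs them there.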
@SkafteNicki @Borda Please have a look at the minor change I made. I believe it can be marked as ready for review :)
can we add a test to cover this case? 🐰
Should be tested thanks to a test parameter added by @SkafteNicki: `@pytest.mark.parametrize("device", ["cpu", "cuda"])`. Tests fail without our code change and pass with it.
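The device-parametrized test pattern mentioned above can be sketched as follows. This is not the torchmetrics test itself; the test name and shapes are hypothetical, and the `pytest.skip` guard keeps the `cuda` case from failing on machines without a GPU:

```python
import pytest
import torch

@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_stored_tensors_live_on_cpu(device):
    # Skip the CUDA case gracefully on CPU-only machines.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    embeddings = torch.randn(2, 3, 4, device=device)
    stored = embeddings.cpu()  # the fix under discussion
    assert stored.device.type == "cpu"
    assert stored.shape == (2, 3, 4)
```

Parametrizing over devices catches exactly the class of bug fixed here: code that silently assumes all intermediate tensors share one device.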
LGTM! Haven't looked at the implementation, but I assume it relies on `.to(device=device)`, where we should also test for no explicit device.
Co-authored-by: Justus Schock <[email protected]>
What does this PR do?
Fixes #909
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃