[RAG] Fix various generator issues #590
Conversation
The `top_k` fix seems good and straightforward.

I have some doubts about the GPU fix though. With the `.to(self.device)` calls removed, the tensors should still reside on the CPU. I haven't tested this, but I suppose this solves the error message in #587 by not utilizing the GPU at all anymore. We should instead make sure that all required tensors are on the GPU (i.e., by adding more `.to(self.device)` calls). Let me know if I should jump in and help code this up.
Yes please. I tried to debug but couldn't, as I don't have a GPU. But I would suggest creating another PR to enable GPU support for the generator.
Ok, then let's keep this one limited to the `top_k` fix (and remove the GPU changes). I will raise a new one for GPU support and fixing #587.
I would suggest keeping these changes as well, because they will allow people to use the generator on CPU and GPU without error (even if at low speed).
Well, isn't there only an error if the user sets
Agreed. I will remove the GPU-related changes and raise an exception if the user tries the generator with
An exception will now be raised if the user uses
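A hedged sketch of what raising such an exception could look like. The flag name `use_gpu` is a hypothetical stand-in, since the actual parameter name is elided in the comments above:

```python
class Generator:
    """Illustrative generator wrapper; not the actual Haystack class."""

    def __init__(self, use_gpu: bool = False):
        # Hypothetical guard: fail fast until GPU support lands in a
        # follow-up PR, rather than crashing later with device errors.
        if use_gpu:
            raise NotImplementedError(
                "GPU support for the generator is not implemented yet; "
                "please run on CPU for now."
            )
        self.use_gpu = use_gpu
```

Failing fast with a clear message is friendlier than the mixed-device runtime error reported in #587.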
Looking good!
Resolves #587 and #585.

For `num_return_sequences`, refer to https://github.com/huggingface/transformers/blob/a1bbcf3f6c20e15fe799a8659d6b7bd36fdf11ed/src/transformers/modeling_rag.py#L852. It should not be greater than `num_beams`; hence, when the user passes a `top_k` value greater than `num_beams`, we reset it to `num_beams` and print a warning.
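The clamping described above can be sketched roughly as follows (function name and logging style are assumptions, not the PR's actual code):

```python
import logging

logger = logging.getLogger(__name__)


def clamp_top_k(top_k: int, num_beams: int) -> int:
    """Cap top_k at num_beams, warning the user when it is reduced.

    In beam search, num_return_sequences (driven here by top_k) must
    not exceed num_beams, or generation fails.
    """
    if top_k > num_beams:
        logger.warning(
            "top_k (%d) is greater than num_beams (%d); "
            "setting top_k to num_beams.",
            top_k,
            num_beams,
        )
        top_k = num_beams
    return top_k
```

For example, `clamp_top_k(10, 4)` returns `4` with a warning, while `clamp_top_k(2, 4)` returns `2` unchanged.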