[WIP] Generate base class for better integration of distributed inference #1355
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1355
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 144c81c with merge base 2fcc37c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Rename SingleGpuGenerator -> LocalGenerator, or some other name that reflects the broader use of this class, please?
Hi @Jack-Khuu, it would be great to get your opinion on this PR as well, because it does some more aggressive refactoring to integrate distributed inference. It is only a step in that direction, to keep the PR size small. The final goal is to have LocalGenerator and DistributedGenerator use a unified interface. Then we can reuse the code for preparing the data etc. by sharing the chat() method, which would eventually be raised into the base class as well. This branch contains the next PR, where I aligned the generate() interface. The next step is to refactor chat() to support both LocalGenerator and DistributedGenerator.
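For orientation, here is a minimal sketch of the hierarchy this is aiming for. Apart from the class names Generator, LocalGenerator, and DistributedGenerator, all method signatures below are illustrative assumptions rather than the actual torchchat API:

from abc import ABC, abstractmethod
from typing import Iterator

import torch


class Generator(ABC):
    """Owns the shared flow; subclasses provide backend-specific generation."""

    @abstractmethod
    def generate(self, prompt: str, max_new_tokens: int) -> Iterator[torch.Tensor]:
        ...

    def chat(self, prompt: str, max_new_tokens: int = 256) -> None:
        # Prompt preparation and output handling would be shared here once the
        # generate() interfaces of both backends are aligned.
        for _token in self.generate(prompt, max_new_tokens):
            pass  # decode/print tokens


class LocalGenerator(Generator):
    # Single-device generation (the class previously named SingleGpuGenerator).
    def generate(self, prompt: str, max_new_tokens: int) -> Iterator[torch.Tensor]:
        yield from ()  # placeholder for the existing local decoding loop


class DistributedGenerator(Generator):
    # Generation coordinated across distributed ranks.
    def generate(self, prompt: str, max_new_tokens: int) -> Iterator[torch.Tensor]:
        yield from ()  # placeholder for the distributed decoding loop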
Thanks for putting up the PR @mreso. Conceptually this PR makes sense and will make the distributed integration smoother, though I have ideas to refactor the overall Generator architecture during Q1. The main idea is that I want to avoid abstraction classes wherever possible, since they hurt "copy and pastability". To temporarily avoid code duplication and move the distributed integration along, we can ride with the LocalGenerator concept for now. The CI is currently borked by what I suspect is an Nvidia bug: huggingface/diffusers#9704. But I'll force-land this PR if I find the CI fix non-trivial.
CI is fixed. The changes with LocalGenerator look intuitive to me, and I see that generate is going to be pushed up a layer into Generator in your next PR, which is great. Left one question about _gen_model_input below.
def _gen_model_input(
    self,
    prompt: Union[str | List[Any]],
    image_prompts: Optional[List[str | Image.Image]] = None,
    max_new_tokens: Optional[int] = None,
    max_seq_len: Optional[int] = 2048,
) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
Something similar to a generator base class was something we were looking into for Q4/H1, so I'm glad to see the initial work here lining up with my mental model.
Is _gen_model_input used in the DistributedGenerator? If not, let's keep it in the LocalGenerator so that Generator is more succinct.
We can be more aggressive about leaving things in LocalGenerator and keeping Generator thin, but that's something we can do later (_gen_model_input should be moved, though, unless there's a need for it in the base class).
Thanks for the review and feedback. The "copy and pastability" design aspect of TorchChat got through to me via the recent discussions. Regarding _gen_model_input: I was planning to move into Generator only the things I could reuse in the distributed path. In my dev branch I actually pulled chat() into the base class to reuse it, so _gen_model_input would need to live in Generator as well. I think in the end you could have only load_model, decode_one_token, decode_n_tokens, prefill, and some model properties in the Local/DistributedGenerator, which in some sense would make it "copy and pastable" again, as everything else lives in Generator.
Let's hold off on merging this, and let me look into a more "copy and pastable" approach. As we'll end up with only model specifics in the Local/DistributedGenerator, we might get away with just touching the model level.
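As a rough illustration of that end state (the hook signatures below and the builder_args parameter are hypothetical assumptions, not the final API), the subclasses would shrink to the model-facing hooks while everything else sits in the base class:

from abc import ABC, abstractmethod


class Generator(ABC):
    # Shared here: chat(), generate(), _gen_model_input(), prompt/output handling.

    @abstractmethod
    def load_model(self, builder_args):
        ...

    @abstractmethod
    def prefill(self, tokens, input_pos):
        ...

    @abstractmethod
    def decode_one_token(self, cur_token, input_pos):
        ...

    @abstractmethod
    def decode_n_tokens(self, cur_token, input_pos, num_new_tokens):
        ...


class LocalGenerator(Generator):
    ...  # single-device implementations of the four hooks plus model properties


class DistributedGenerator(Generator):
    ...  # distributed implementations of the same hooks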
Feel free to spin up an RFC/doc or notes. We'll gladly take a look, since it's something we've been thinking about doing for a while now. cc @Gasoonjia
Closing this in favor of #1381
This PR introduces a new Generator base class to better integrate distributed inference with the existing infrastructure.