feat: data parallel inference examples #2805
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py 2024-05-02 00:29:27.054073+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py 2024-05-02 00:31:18.785078+00:00
@@ -13,12 +13,26 @@
distributed_state = PartialState()
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)
-model.forward = torch.compile(model.forward, backend="torch_tensorrt", options={"truncate_long_and_double": True, "enabled_precisions": {torch.float16}, "debug": True}, dynamic=False,)
+model.forward = torch.compile(
+ model.forward,
+ backend="torch_tensorrt",
+ options={
+ "truncate_long_and_double": True,
+ "enabled_precisions": {torch.float16},
+ "debug": True,
+ },
+ dynamic=False,
+)
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
cur_input = torch.clone(prompt[0]).to(distributed_state.device)
- gen_tokens = model.generate(cur_input, do_sample=True, temperature=0.9, max_length=100,)
+ gen_tokens = model.generate(
+ cur_input,
+ do_sample=True,
+ temperature=0.9,
+ max_length=100,
+ )
gen_text = tokenizer.batch_decode(gen_tokens)[0]
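The diff above relies on accelerate's `PartialState.split_between_processes` context manager to hand each process its own slice of the prompt list. As a conceptual aid, here is a minimal, dependency-free sketch of the block partitioning such a split performs; `split_between_processes` below is a hypothetical helper for illustration, not the accelerate API itself (accelerate's real implementation also handles padding and nested structures).

```python
# Hypothetical sketch of how a list of inputs is block-partitioned
# across processes for data parallel inference.
def split_between_processes(inputs, num_processes, process_index):
    """Return the slice of `inputs` assigned to `process_index`.

    When len(inputs) is not divisible by num_processes, earlier
    ranks receive one extra item.
    """
    base, extra = divmod(len(inputs), num_processes)
    start = process_index * base + min(process_index, extra)
    end = start + base + (1 if process_index < extra else 0)
    return inputs[start:end]

# Two prompts across two processes: each rank gets exactly one prompt,
# mirroring [input_id1, input_id2] in the example script.
prompts = ["input_id1", "input_id2"]
rank0 = split_between_processes(prompts, num_processes=2, process_index=0)
rank1 = split_between_processes(prompts, num_processes=2, process_index=1)
```

With this partitioning, each GPU compiles and runs its own copy of the model on its assigned prompts, so the prompts are processed concurrently rather than sequentially.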
- Need a requirements.txt
- Annotate the script with a description of what's happening, following https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/torch_compile_advanced_usage.py
- Add a reference to index.rst so that it gets rendered in the docs, alongside the existing entry (index.rst, line 113 in 12e885a):
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
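For the third checklist item, the index.rst entry would look roughly like the sketch below. The toctree options and caption are assumptions modeled on typical Sphinx gallery setups, and the new path is inferred from the example's filename, so the actual entry in the TensorRT repo may differ.

```rst
.. toctree::
   :caption: Dynamo Examples
   :maxdepth: 1
   :hidden:

   tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
   tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
```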
LGTM
Force-pushed 4bc05b7 to dfbf6ea
Force-pushed dfbf6ea to 7b4b504
@bowang007 You didn't properly clean up the merge conflicts, so db24b3b had |
Description
This PR shows a simple example of using the accelerate library for data parallel inference.
Checklist: