Fix gpt-j inference issue #3639
Conversation
#3618 fixes the bug that caused GPT-J tests to get skipped!
While this PR fixes the GPT-J model, it breaks Pythia models (which are based on GPT-J):

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-12b")
model = deepspeed.init_inference(
    model,
    mp_size=2,
    dtype=torch.half,
    replace_with_kernel_inject=True
)

batch = tokenizer(
    "This is a test prompt",
    return_tensors="pt",
    add_special_tokens=False
)
batch = {k: v.cuda() for k, v in batch.items()}
generated = model.generate(**batch, max_length=100)
print(tokenizer.decode(generated[0]))
```

Output using this PR:
Output using current master (da8f4e0):
Ran this on A10s. ds_report:
Thanks @Yard1 for helping me verify the solution for other models. I have pushed some new changes to fix this now; please give it a try when you get a chance. Thanks :)
@RezaYazdaniAminabadi Thanks, seems to be fixed now 👍
* fix gpt-j inference issue for mlp_gemm_func call
* bring back the gpt-j inference-test
* fix formatting
* fix the neox and pythia injection issue
This PR fixes an issue with GPT-J inference, which runs into this error on master:

The failure comes from choosing the wrong kernel to run the MLP function for this model. Because GPT-J has only one LayerNorm per block, we should have called `fused_gemm_gelu` here; however, since this parameter was not set correctly in the base container, we ran into this issue. Also, the unit test for GPT-J was being skipped, which is why we did not catch this error before! (cc: @jeffra / @mrwyattii)
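To illustrate the class of bug described above, here is a minimal, self-contained sketch of kernel dispatch driven by a layer-layout flag. The names (`select_mlp_kernel`, `mlp_after_attn`, and the kernel name strings) are hypothetical and do not reflect DeepSpeed's actual container API; the point is only that a flag left at the wrong default silently routes every affected model down the wrong kernel path.

```python
def select_mlp_kernel(mlp_after_attn: bool) -> str:
    """Pick the fused MLP kernel based on the model's block layout.

    GPT-J-style blocks run attention and MLP in parallel off a single
    LayerNorm (mlp_after_attn=False), so the MLP path should use a
    fused GEMM+GeLU kernel. Models with two LayerNorms per block
    (mlp_after_attn=True) take the LayerNorm+GEMM path instead.
    """
    if mlp_after_attn:
        return "mlp_gemm_func"    # two-LayerNorm layout
    return "fused_gemm_gelu"      # single-LayerNorm (GPT-J-style) layout


# If the flag is not propagated from the model config and defaults to
# True, single-LayerNorm models silently get the wrong kernel:
print(select_mlp_kernel(False))  # fused_gemm_gelu
print(select_mlp_kernel(True))   # mlp_gemm_func
```

The fix in this PR is the real-world analogue: set the layout parameter correctly in the base container so GPT-J (and GPT-NeoX/Pythia) models dispatch to the intended kernel.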
Fixes #3604