Fix gpt-j inference issue #3639

Merged 8 commits from ds-inference/fix-gptj into master on Jun 7, 2023

Conversation

@RezaYazdaniAminabadi (Contributor) commented May 30, 2023

This PR fixes GPT-J inference, which currently fails on master with this error:

File "/vc_data_1/users/bapatra/ds-master/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py", line 68, in forward
    output, residual_add = self.mlp_gemm_func(
TypeError: mlp_gemm_fp16(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: torch.Tensor, arg6: torch.Tensor, arg7: torch.Tensor, arg8: float, arg9: bool, arg10: bool, arg11: torch.Tensor, arg12: torch.Tensor, arg13: bool, arg14: int, arg15: bool) -> List[torch.Tensor]

Invoked with: tensor([[[-0.0549, -0.4070,  0.1932,  ...,  0.1614, -0.3074, -0.0302],
         [ 0.0147, -0.4172,  0.2135,  ...,  0.0264, -0.2727, -0.2778],
         [ 0.0952, -0.0844,  0.0266,  ..., -0.0730, -0.1125, -0.0426],
         ...,
         [ 0.3167, -0.0099,  0.2551,  ..., -0.3394, -0.0636,  0.0848],
         [-0.1514, -0.3774, -0.1428,  ..., -0.7607, -0.2854, -0.3149],
         [-0.1008, -0.0690,  0.0038,  ..., -0.3818, -0.1300, -0.2012]]],
       device='cuda:0', dtype=torch.float16), tensor([[[ 0.0086,  0.0225, -0.0049,  ..., -0.0125,  0.0019, -0.0077],
         [-0.0086, -0.0046, -0.0257,  ..., -0.0141, -0.0074, -0.0209],
         [ 0.0147,  0.0043,  0.0167,  ..., -0.0083,  0.0008,  0.0096],
         ...,
         [-0.0099,  0.0013,  0.0294,  ..., -0.0185,  0.0043, -0.0280],
         [ 0.0038,  0.0182,  0.0184,  ..., -0.0090, -0.0004, -0.0129],
         [ 0.0175,  0.0097,  0.0035,  ...,  0.0085, -0.0027, -0.0117]]],
       device='cuda:0', dtype=torch.float16), None, Parameter containing:
tensor([[-1.6632e-02,  2.8137e-02,  2.7332e-03,  ...,  1.8406e-03,
          9.3765e-03, -1.3969e-02],
        [-2.0233e-02, -1.3191e-02,  1.2337e-02,  ..., -1.7670e-02,
          5.6915e-03,  5.6648e-03],
        [-1.3817e-02, -7.0870e-05,  1.4000e-03,  ..., -1.3336e-02,
         -1.5976e-02, -5.7640e-03],
        ...,
        [ 4.2076e-03,  2.7115e-02,  7.9651e-03,  ..., -5.3024e-03,
         -4.6806e-03,  1.3374e-02],
        [-1.7128e-03,  6.7635e-03,  9.7351e-03,  ...,  1.0294e-04,
          1.8530e-03,  9.9258e-03],
        [-7.0534e-03,  3.6438e-02, -1.5087e-03,  ..., -2.2221e-03,
         -3.5973e-03,  6.6566e-04]], device='cuda:0', dtype=torch.float16), Parameter containing:
tensor([[-2.6798e-03,  2.1019e-03, -1.0033e-02,  ...,  8.4763e-03,
          1.2566e-02,  1.0633e-03],
        [ 5.3024e-03,  1.3947e-02,  1.8082e-02,  ...,  1.6922e-02,
          1.0347e-03,  1.2619e-02],
        [-3.6278e-03, -3.1681e-03, -4.5166e-03,  ...,  6.3972e-03,
          7.2975e-03,  3.9062e-03],
        ...,
        [-1.6251e-03,  2.3743e-02, -1.0262e-02,  ..., -7.4983e-05,
          3.5763e-07,  9.3842e-03],
        [-5.2719e-03, -4.3068e-03, -4.6654e-03,  ..., -8.1940e-03,
          1.7838e-02, -2.2182e-03],
        [-4.9925e-04,  1.2421e-02,  1.7490e-03,  ..., -9.2850e-03,
         -6.4240e-03,  1.7080e-03]], device='cuda:0', dtype=torch.float16), Parameter containing:
tensor([-0.0587, -0.0428, -0.0135,  ..., -0.0426, -0.0233, -0.0378],
       device='cuda:0', dtype=torch.float16), Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', dtype=torch.float16), Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', dtype=torch.float16), 1e-05, True, False, tensor([2.3694e-38]), tensor([2.3694e-38]), False, <ActivationFuncType.GELU: 1>, False

This happens because the wrong kernel is chosen to run the MLP for this model: a None is passed where mlp_gemm_fp16 expects a tensor, which is what trips the pybind signature check above. Since GPT-J has only one LayerNorm per block, fused_gemm_gelu should have been called here instead, but because this parameter was not set correctly in the base container, we run into this error.
Also, the unit test for GPT-J is currently skipped, which is why we did not catch this error earlier! (cc: @jeffra / @mrwyattii)
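
For clarity, here is a rough sketch of the dispatch this container flag is supposed to drive. The function and flag names below are illustrative (they do not exist under these names in DeepSpeed); only mlp_gemm_fp16 and fused_gemm_gelu come from the bindings discussed above:

```python
# Illustrative only: select_mlp_kernel and single_layer_norm are made-up names;
# the real selection happens inside DeepSpeed's inference op bindings.
def select_mlp_kernel(inference_module, single_layer_norm: bool):
    if single_layer_norm:
        # GPT-J style block: attention and MLP share one LayerNorm, so the
        # fused GEMM+GeLU kernel should handle the MLP path.
        return inference_module.fused_gemm_gelu
    # Otherwise a second LayerNorm precedes the MLP, and the regular
    # mlp_gemm kernel (mlp_gemm_fp16 above) takes its parameters.
    return inference_module.mlp_gemm_fp16
```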

Fixes #3604

@mrwyattii (Contributor)

#3618 fixes the bug that caused GPT-J tests to get skipped!
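
For reference, the coverage that was being skipped boils down to an end-to-end kernel-injection run along these lines. This is a hedged sketch rather than the actual test in the DeepSpeed suite; the model id, prompt, and assertion are illustrative:

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_gptj_kernel_inject():
    # Hypothetical smoke test: load GPT-J, inject the inference kernels, and
    # check that generation completes without the mlp_gemm_fp16 TypeError above.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6B", torch_dtype=torch.half
    )
    model = deepspeed.init_inference(
        model, mp_size=1, dtype=torch.half, replace_with_kernel_inject=True
    )
    batch = tokenizer("This is a test prompt", return_tensors="pt").to("cuda")
    generated = model.generate(**batch, max_length=50)
    assert generated.shape[0] == 1
```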

@Yard1 commented Jun 1, 2023

While this PR fixes the GPT-J model, it breaks Pythia models (which are based on GPT-J):

import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-12b")
# wrap the model with DeepSpeed kernel injection, sharded across 2 GPUs
model = deepspeed.init_inference(
    model,
    mp_size=2,
    dtype=torch.half,
    replace_with_kernel_inject=True
)

batch = tokenizer(
    "This is a test prompt",
    return_tensors="pt", 
    add_special_tokens=False
)
batch = {k: v.cuda() for k, v in batch.items()}
generated = model.generate(**batch, max_length=100)
print(tokenizer.decode(generated[0]))

Output using this PR:

This is a test promptmedscfS1C1S1C1S1C1S1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1MA1C1MA1CThis is a test promptmedscfS1C1S1C1S1C1S1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1M1C1MA1C1MA1C

Output using current master (da8f4e0):

This is a test prompt.

```js
var test = require('tap').test
var fs = require('fs')
var path = require('path')

var read = fs.readFile
var write = fs.writeFile

test('should get test', function (t) {
  read(
    path.join(
      path.dirname(
        path.resolve(__dirname, '..', 'test.js')
      ),This is a test prompt.

```js
var test = require('tap').test
var fs = require('fs')
var path = require('path')

var read = fs.readFile
var write = fs.writeFile

test('should get test', function (t) {
  read(
    path.join(
      path.dirname(
        path.resolve(__dirname, '..', 'test.js')
      ),

Ran this on A10s.

ds_report:

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/ray/anaconda3/lib/python3.10/site-packages/torch']
torch version .................... 2.0.1+cu118
deepspeed install path ........... ['/home/ray/anaconda3/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.9.3+96152df1, 96152df1, ds-inference/fix-gptj
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8

@RezaYazdaniAminabadi (Contributor, Author)

Thanks @Yard1 for helping me verify the fix on other models. I have pushed some new changes to address this; please give it a try when you get a chance. Thanks :)

@Yard1 commented Jun 5, 2023

@RezaYazdaniAminabadi Thanks, seems to be fixed now 👍

@RezaYazdaniAminabadi enabled auto-merge (squash) June 5, 2023 19:16
@RezaYazdaniAminabadi merged commit 34a9fbf into master Jun 7, 2023
@mrwyattii deleted the ds-inference/fix-gptj branch June 7, 2023 21:50
molly-smith pushed a commit that referenced this pull request Jun 23, 2023
* fix gpt-j inference issue for mlp_gemm_func call

* bring back the gpt-j inference-test

* fix formatting

* fix the neox and pythia injection issue
Closes: [BUG] GPT-J inference with kernel inject fails on master (7667988)