CUBLAS_STATUS_EXECUTION_FAILED when calling cublasGemmEx with the DeepSpeed transformer kernel #294

Closed
tyler-romero opened this issue Jul 17, 2020 · 11 comments

@tyler-romero

Hi, I'm running into the following error when attempting to train with the deepspeed transformer kernel.

This error occurs during the forward pass of the first training step.
The line `!!!! kernel execution error.` is printed 80 times, followed by this traceback:

Traceback (most recent call last):
  File "RunPretrain.py", line 127, in <module>
    deepspeed_train.main(args)
  File "[REDACTED]/deepspeed_train.py", line 751, in main
    run(args, model, optimizer, start_epoch)
  File "[REDACTED]/deepspeed_train.py", line 700, in run
    train(args, index, model, optimizer)
  File "[REDACTED]/deepspeed_train.py", line 329, in train
    loss = model.network(batch)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/pt/deepspeed_light.py", line 689, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "[REDACTED]/nvidia/modeling.py", line 1110, in forward
    checkpoint_activations=checkpoint_activations)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "[REDACTED]/nvidia/modeling.py", line 1025, in forward
    pooled_output = self.pooler(sequence_output)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "[REDACTED]/nvidia/modeling.py", line 635, in forward
    pooled_output = self.dense_act(first_token_tensor)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "[REDACTED]/nvidia/modeling.py", line 207, in forward
    return bias_tanh(self.bias, F.linear(input, self.weight, None))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1371, in linear
    output = input.matmul(weight.t())
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, &fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`

If I disable the use of the deepspeed transformer kernel, everything works just fine.

I'm using a slightly modified version of the provided dockerfile:

FROM nvidia/cuda:10.0-devel-ubuntu18.04

##############################################################################
# Installation/Basic Utilities
##############################################################################
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    software-properties-common \
    openssh-client openssh-server \
    pdsh curl sudo net-tools \
    vim iputils-ping wget

##############################################################################
# Installation Latest Git
##############################################################################
RUN add-apt-repository ppa:git-core/ppa -y && \
    apt-get update && \
    apt-get install -y git && \
    git --version

##############################################################################
# Python
##############################################################################
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHON_VERSION=3
RUN apt-get install -y python3 python3-dev && \
    rm -f /usr/bin/python && \
    ln -s /usr/bin/python3 /usr/bin/python && \
    curl -O https://bootstrap.pypa.io/get-pip.py && \
        python get-pip.py && \
        rm get-pip.py && \
    pip install --upgrade pip && \
    # Print Python and pip versions
    python -V && pip -V

##############################################################################
# MXNet
##############################################################################
ENV MXNET_VERSION=1.5.0
RUN pip install mxnet-cu100==${MXNET_VERSION}

##############################################################################
# TensorFlow
##############################################################################
ENV TENSORFLOW_VERSION=1.15.2
RUN pip install tensorflow-gpu==${TENSORFLOW_VERSION}

##############################################################################
# PyTorch
##############################################################################
ENV PYTORCH_VERSION=1.2.0
ENV TORCHVISION_VERSION=0.4.0
ENV TENSORBOARDX_VERSION=1.8
RUN pip install torch==${PYTORCH_VERSION}
RUN pip install torchvision==${TORCHVISION_VERSION}
RUN pip install tensorboardX==${TENSORBOARDX_VERSION}

##############################################################################
# Temporary Installation Directory
##############################################################################
ENV STAGE_DIR=/tmp
RUN mkdir -p ${STAGE_DIR}

##############################################################################
# Mellanox OFED
##############################################################################
ENV MLNX_OFED_VERSION=4.6-1.0.1.1
RUN apt-get install -y libnuma-dev
RUN cd ${STAGE_DIR} && \
    wget -q -O - http://www.mellanox.com/downloads/ofed/MLNX_OFED-${MLNX_OFED_VERSION}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64.tgz | tar xzf - && \
    cd MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64 && \
    ./mlnxofedinstall --user-space-only --without-fw-update --all -q && \
    cd ${STAGE_DIR} && \
    rm -rf ${STAGE_DIR}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64*

##############################################################################
# nv_peer_mem
##############################################################################
RUN mkdir -p ${STAGE_DIR} && \
    git clone https://github.com/Mellanox/nv_peer_memory.git ${STAGE_DIR}/nv_peer_memory && \
    cd ${STAGE_DIR}/nv_peer_memory && \
    ./build_module.sh && \
    cd ${STAGE_DIR} && \
    tar xzf ${STAGE_DIR}/nvidia-peer-memory_1.0.orig.tar.gz && \
    cd ${STAGE_DIR}/nvidia-peer-memory-1.0 && \
    apt-get install -y dkms && \
    dpkg-buildpackage -us -uc && \
    dpkg -i ${STAGE_DIR}/nvidia-peer-memory_1.0-9_all.deb

##############################################################################
# Install OpenMPI
##############################################################################
RUN mkdir -p ${STAGE_DIR}/openmpi && \
    cd ${STAGE_DIR}/openmpi && \
    wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.1.tar.gz && \
    tar zxf openmpi-4.0.1.tar.gz && \
    cd openmpi-4.0.1 && \
    ./configure --enable-orterun-prefix-by-default && \
    make -j $(nproc) all && \
    make install && \
    ldconfig && \
    rm -rf ${STAGE_DIR}/openmpi

##############################################################################
# Uncomment and set SSH daemon port
##############################################################################
RUN mkdir -p /var/run/sshd
RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
    echo "    StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
    mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config

ENV SSH_PORT=2222
RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \
    sed "0,/^#Port 22/s//Port ${SSH_PORT}/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config

##############################################################################
# DeepSpeed
##############################################################################
RUN git clone https://github.com/microsoft/DeepSpeed.git ${STAGE_DIR}/DeepSpeed
RUN cd ${STAGE_DIR}/DeepSpeed && \
    git checkout . && \
    git checkout master && \
    ./install.sh --allow_sudo --pip_sudo
RUN rm -rf ${STAGE_DIR}/DeepSpeed
RUN python -c "import deepspeed; print(deepspeed.__version__)"

##############################################################################
# Install Additional Python Libs
##############################################################################
RUN pip install future typing
RUN pip install numpy scipy pandas h5py tqdm \
    scikit-learn pytest boto3 filelock \
    tokenizers requests regex mpi4py dill

##############################################################################
# Install Horovod
##############################################################################
RUN ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && \
    HOROVOD_GPU_ALLREDUCE=NCCL \
    HOROVOD_GPU_BROADCAST=NCCL \
    HOROVOD_WITH_TENSORFLOW=1 \
    HOROVOD_WITH_PYTORCH=1 \
    HOROVOD_WITH_MXNET=1          \
    pip install --no-cache-dir horovod && \
    ldconfig

##############################################################################
# Add-ons
##############################################################################
RUN pip install fastparquet
RUN pip install --no-cache-dir azureml-defaults

SHELL [ "/bin/bash", "-cu" ]

My deepspeed config looks like this:

{
  "train_batch_size": 16384,
  "train_micro_batch_size_per_gpu": 32,
  "steps_per_print": 1000,
  "prescale_gradients": false,
  "optimizer": {
    "type": "Lamb",
    "params": {
      "lr": 11e-3,
      "weight_decay": 0.01,
      "bias_correction": false,
      "max_coeff": 0.3,
      "min_coeff": 0.01
    }
  },
  "gradient_clipping": 1.0,
  "wall_clock_breakdown": false,
  "fp16": {
    "enabled": true,
    "loss_scale": 0
  }
}
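As a quick consistency check on this config: DeepSpeed requires train_batch_size to equal train_micro_batch_size_per_gpu * gradient accumulation steps * number of GPUs. The sketch below computes the accumulation steps that the values above imply. The GPU counts in the example are hypothetical, since the thread does not state the world size.

```python
def gradient_accumulation_steps(train_batch_size: int,
                                micro_batch_per_gpu: int,
                                world_size: int) -> int:
    """Accumulation steps implied by DeepSpeed's batch-size identity:
    train_batch_size = micro_batch_per_gpu * accumulation_steps * world_size."""
    samples_per_micro_step = micro_batch_per_gpu * world_size
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_batch_per_gpu * world_size = {samples_per_micro_step}"
        )
    return train_batch_size // samples_per_micro_step


# Values from the config above; the GPU counts are hypothetical examples.
for gpus in (8, 16, 64):
    print(gpus, "GPUs ->", gradient_accumulation_steps(16384, 32, gpus), "accumulation steps")
```

With 32 samples per GPU per micro-step, a 16384-sample optimizer step means every rank accumulates gradients over several forward passes before the LAMB update.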

This issue suggests it may be a bug in the CUDA and PyTorch versions I'm using:
pytorch/pytorch#24018

And this one indicates that it may have to do with fp16 casting:
NVIDIA/apex#580

Any help would be appreciated! Thanks.

@tyler-romero
Author

If I disable fp16, I get this error at the end of the same Python stack trace:

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

@RezaYazdaniAminabadi
Contributor

RezaYazdaniAminabadi commented Jul 25, 2020

Hi @tyler-romero,

Could you please try running one of our kernel unit tests, such as "DeepSpeed/tests/unit/test_cuda_forward.py"? This file includes several tests of our kernel with different batch sizes, sequence lengths, and hidden dimensions, in both fp16 and fp32. The result will tell us whether the problem is on the kernel side; I ask because your log shows that the error comes from a matmul initiated at line 207 of your modeling file. Is this executed after calling the transformer kernel? I wonder if there is a size mismatch when calling the cuBLAS library. You could also try printing the shapes of the tensors, such as input and self.weight.
Also, could you please send me the modeling file you are using, so that I can try to reproduce this issue on my side? I assume it is not one of our DeepSpeed examples.
Thanks.

Best regards,
Reza
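Reza's size-mismatch question comes down to the shape rule for F.linear: input is (..., in_features), weight is (out_features, in_features), and the output is (..., out_features). The sketch below checks that rule on plain shape tuples, so it runs without a GPU. Inside the pooler's forward you would pass tuple(input.shape) and tuple(self.weight.shape), and printing input.dtype and self.weight.dtype at the same spot would also reveal a precision mismatch. The pooler sizes in the example (batch 32, hidden 1024) are assumptions, not values confirmed in the thread.

```python
from typing import Sequence, Tuple


def linear_output_shape(input_shape: Sequence[int],
                        weight_shape: Sequence[int]) -> Tuple[int, ...]:
    """Shape of F.linear(input, weight) for input (..., in_features) and
    weight (out_features, in_features); raises ValueError on a mismatch."""
    if len(weight_shape) != 2:
        raise ValueError(f"weight must be 2-D, got {tuple(weight_shape)}")
    out_features, in_features = weight_shape
    if not input_shape or input_shape[-1] != in_features:
        raise ValueError(
            f"size mismatch: input {tuple(input_shape)} vs weight {tuple(weight_shape)}"
        )
    return tuple(input_shape[:-1]) + (out_features,)


# Hypothetical pooler sizes: first_token_tensor is (batch, hidden),
# self.weight is (hidden, hidden).
print(linear_output_shape((32, 1024), (1024, 1024)))  # (32, 1024)
```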

@yselivonchyk

Encountered the same issue after a few thousand iterations:

Iteration:  31%|███▏      | 490/1559 [04:34<07:11,  2.48it/s][1,0]<stdout>:bing_bert_progress: step=3599, loss=1.77734375, lr=0.003918867386312817, sample_count=235864064, step_time=0.3878452777862549
Iteration:  31%|███▏      | 491/1559 [04:35<07:28,  2.38it/s][1,290]<stderr>:!!!! kernel execution error.
[1,290]<stderr>:!!!! kernel execution error.
[... the previous line repeats 106 more times ...]
[1,290]<stdout>:Traceback (most recent call last):
[1,290]<stdout>:  File "deepspeed_train_herring.py", line 520, in <module>
[1,290]<stdout>:    main()
[1,290]<stdout>:  File "deepspeed_train_herring.py", line 513, in main
[1,290]<stdout>:    run(args, model, optimizer, start_epoch)
[1,290]<stdout>:  File "deepspeed_train_herring.py", line 478, in run
[1,290]<stdout>:    train(args, index, model, optimizer, pretrain_dataset_provider)
[1,290]<stdout>:  File "deepspeed_train_herring.py", line 166, in train
[1,290]<stdout>:    loss = model.network(batch)
[1,290]<stdout>:  File "/shared/pytorch_efa/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
[1,290]<stdout>:    result = self.forward(*input, **kwargs)
[1,290]<stdout>:  File "/shared/pytorch_efa/lib/python3.7/site-packages/deepspeed/pt/deepspeed_light.py", line 741, in forward
[1,290]<stdout>:    loss = self.module(*inputs, **kwargs)
[1,290]<stdout>:  File "/shared/pytorch_efa/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
[1,290]<stdout>:    result = self.forward(*input, **kwargs)
[1,290]<stdout>:  File "/shared/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 1056, in forward
[1,290]<stdout>:    checkpoint_activations=checkpoint_activations)
[1,290]<stdout>:  File "/shared/pytorch_efa/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
[1,290]<stdout>:    result = self.forward(*input, **kwargs)
[1,290]<stdout>:  File "/shared/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 979, in forward
[1,290]<stdout>:    pooled_output = self.pooler(sequence_output)
[1,290]<stdout>:  File "/shared/pytorch_efa/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
[1,290]<stdout>:    result = self.forward(*input, **kwargs)
[1,290]<stdout>:  File "/shared/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 633, in forward
[1,290]<stdout>:    pooled_output = self.dense_act(first_token_tensor)
[1,290]<stdout>:  File "/shared/pytorch_efa/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
[1,290]<stdout>:    result = self.forward(*input, **kwargs)
[1,290]<stdout>:  File "/shared/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 207, in forward
[1,290]<stdout>:    return bias_tanh(self.bias, F.linear(input, self.weight, None))
[1,290]<stdout>:  File "/shared/pytorch_efa/lib/python3.7/site-packages/torch/nn/functional.py", line 1612, in linear
[1,290]<stdout>:    output = input.matmul(weight.t())
[1,290]<stdout>:RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, &fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`

@RezaYazdaniAminabadi
Contributor

Hi yselivonchyk,

It's odd that this happens only after many iterations!
Can you tell me what batch size and sequence length you are using for this test? Also, are you using fp16?
I want to see whether I can reproduce the issue you mentioned, so if possible, please also share your training script so that I can run it on my end.
Thanks.

Best regards,
Reza

@tyler-romero
Author

@RezaYazdaniAminabadi true, the RuntimeError in the stack trace occurs at the matmul, but the `!!!! kernel execution error.` printout before it comes from the DeepSpeed transformer code:

https://github.com/microsoft/DeepSpeed/blob/8353c5949d395d27c898c78fd3ae72ccfe878c26/csrc/transformer/cublas_wrappers.cu

I'm pretty occupied at work right now, but hopefully in the near future I can try printing those tensor shapes and put together a repro. The modeling file I'm using is a very slightly modified version of the one in the BingBertSquad example.

@RezaYazdaniAminabadi
Contributor

Hi Tyler,

We have seen the same issue with some other benchmarks too! Normally, this happens when there is high memory pressure and the cuBLAS GEMM crashes with this error: CUBLAS_STATUS_EXECUTION_FAILED=13. Could you verify this on your end by printing the error number in the same file you mentioned? Also, what is your memory consumption when running the training? If it is close to the total memory available on the GPU, there may be a risk of crashing due to memory allocation issues.
Please let me know when you have time to work on this issue together.
Thanks.

Best regards,
Reza
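The "error" number printed by the wrapper is a raw cublasStatus_t value. The lookup table below maps the values defined in cuBLAS's public header to their names, which makes it easy to confirm that 13 is CUBLAS_STATUS_EXECUTION_FAILED. cuBLAS documents a separate code, CUBLAS_STATUS_ALLOC_FAILED (3), for resource allocation failures inside the library. The helper name describe_cublas_status is our own, not part of any library.

```python
# cublasStatus_t values as defined in cuBLAS's public header (cublas_api.h).
CUBLAS_STATUS = {
    0: "CUBLAS_STATUS_SUCCESS",
    1: "CUBLAS_STATUS_NOT_INITIALIZED",
    3: "CUBLAS_STATUS_ALLOC_FAILED",
    7: "CUBLAS_STATUS_INVALID_VALUE",
    8: "CUBLAS_STATUS_ARCH_MISMATCH",
    11: "CUBLAS_STATUS_MAPPING_ERROR",
    13: "CUBLAS_STATUS_EXECUTION_FAILED",
    14: "CUBLAS_STATUS_INTERNAL_ERROR",
    15: "CUBLAS_STATUS_NOT_SUPPORTED",
    16: "CUBLAS_STATUS_LICENSE_ERROR",
}


def describe_cublas_status(code: int) -> str:
    """Name of a cublasStatus_t value, or a fallback for unknown codes."""
    return CUBLAS_STATUS.get(code, f"unknown cublasStatus_t value {code}")


print(describe_cublas_status(13))  # CUBLAS_STATUS_EXECUTION_FAILED
print(describe_cublas_status(3))   # CUBLAS_STATUS_ALLOC_FAILED
```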

@tyler-romero
Author

tyler-romero commented Oct 6, 2020

I've since updated to DeepSpeed v0.3, so the printout is a bit more informative now, and the error is indeed 13. I was using a model with 24 transformer blocks and a batch size of 32 per GPU. I cut the number of layers to 6 and the batch size to 16, but I still see the same error.

I haven't calculated the memory percentage, but the 24-layer, batch-size-32 configuration worked just fine with the PyTorch implementation of the transformer block that exists in bing_bert. I would think that cutting the number of layers by a factor of 4 and the batch size by a factor of 2 would relieve the memory pressure.

Is there any other reason why error=13 might be thrown?

!!!! kernel execution error. (m: 128, n: 128, k: 64, error: 13) 
!!!! kernel execution error. (m: 64, n: 128, k: 128, error: 13) 
!!!! kernel execution error. (m: 1024, n: 2048, k: 1024, error: 13) 
!!!! kernel execution error. (m: 4096, n: 2048, k: 1024, error: 13) 
!!!! kernel execution error. (m: 1024, n: 2048, k: 4096, error: 13) 
!!!! kernel execution error. (m: 3072, n: 2048, k: 1024, error: 13) 
[... the same six-line sequence repeats four more times, followed by its first five lines again ...]
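The (m, n, k) values above can be checked mechanically against the GEMMs one BERT-style encoder layer is expected to issue. The sketch below parses the printout and compares it with shapes derived from the model size. Batch 16 comes from the thread; sequence length 128, hidden size 1024, 16 heads, and intermediate size 4096 are assumptions (BERT-large-like values that are consistent with the logged shapes), and the per-GEMM labels are our reading of the shapes, not taken from DeepSpeed's source. The function names are hypothetical.

```python
import re
from typing import Iterable, List, Set, Tuple

# Matches the DeepSpeed v0.3 printout quoted above.
ERROR_LINE = re.compile(
    r"kernel execution error\. \(m: (\d+), n: (\d+), k: (\d+), error: (\d+)\)"
)


def parse_gemm_errors(lines: Iterable[str]) -> List[Tuple[int, int, int, int]]:
    """Extract (m, n, k, error) tuples from log lines; other lines are skipped."""
    found = []
    for line in lines:
        match = ERROR_LINE.search(line)
        if match:
            m, n, k, err = (int(g) for g in match.groups())
            found.append((m, n, k, err))
    return found


def expected_gemm_shapes(batch: int, seq: int, hidden: int, heads: int,
                         intermediate: int) -> Set[Tuple[int, int, int]]:
    """(m, n, k) per encoder layer: m = output rows, n = columns, k = reduction dim."""
    tokens = batch * seq
    head_dim = hidden // heads
    return {
        (3 * hidden, tokens, hidden),    # fused Q/K/V projection
        (seq, seq, head_dim),            # attention scores Q @ K^T, per head
        (head_dim, seq, seq),            # attention probabilities @ V, per head
        (hidden, tokens, hidden),        # attention output projection
        (intermediate, tokens, hidden),  # feed-forward up-projection
        (hidden, tokens, intermediate),  # feed-forward down-projection
    }


log = [
    "!!!! kernel execution error. (m: 128, n: 128, k: 64, error: 13)",
    "!!!! kernel execution error. (m: 64, n: 128, k: 128, error: 13)",
    "!!!! kernel execution error. (m: 1024, n: 2048, k: 1024, error: 13)",
    "!!!! kernel execution error. (m: 4096, n: 2048, k: 1024, error: 13)",
    "!!!! kernel execution error. (m: 1024, n: 2048, k: 4096, error: 13)",
    "!!!! kernel execution error. (m: 3072, n: 2048, k: 1024, error: 13)",
]
errors = parse_gemm_errors(log)
expected = expected_gemm_shapes(batch=16, seq=128, hidden=1024, heads=16, intermediate=4096)
unexpected = {(m, n, k) for m, n, k, _ in errors} - expected
print(len(errors), sorted({err for _, _, _, err in errors}), unexpected)  # 6 [13] set()
```

Every logged shape falls in the expected set, which supports the observation that the dimensions themselves look fine.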

Edit:
This is the printout from torch.cuda.memory_summary from one of the gpus, after allocating a model with only 2 transformer blocks (still fails with same error).

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  394930 KB |  564130 KB |  651133 KB |  256202 KB |
|       from large pool |  393340 KB |  561278 KB |  618500 KB |  225160 KB |
|       from small pool |    1590 KB |    2852 KB |   32633 KB |   31042 KB |
|---------------------------------------------------------------------------|
| Active memory         |  394930 KB |  564130 KB |  651133 KB |  256202 KB |
|       from large pool |  393340 KB |  561278 KB |  618500 KB |  225160 KB |
|       from small pool |    1590 KB |    2852 KB |   32633 KB |   31042 KB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  573440 KB |  573440 KB |  573440 KB |       0 B  |
|       from large pool |  569344 KB |  569344 KB |  569344 KB |       0 B  |
|       from small pool |    4096 KB |    4096 KB |    4096 KB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |   33101 KB |   50417 KB |  363121 KB |  330020 KB |
|       from large pool |   32644 KB |   50116 KB |  331050 KB |  298406 KB |
|       from small pool |     457 KB |    1969 KB |   32071 KB |   31614 KB |
|---------------------------------------------------------------------------|
| Allocations           |     192    |     270    |    2464    |    2272    |
|       from large pool |      44    |      65    |      75    |      31    |
|       from small pool |     148    |     205    |    2389    |    2241    |
|---------------------------------------------------------------------------|
| Active allocs         |     192    |     270    |    2464    |    2272    |
|       from large pool |      44    |      65    |      75    |      31    |
|       from small pool |     148    |     205    |    2389    |    2241    |
|---------------------------------------------------------------------------|
| GPU reserved segments |      35    |      35    |      35    |       0    |
|       from large pool |      33    |      33    |      33    |       0    |
|       from small pool |       2    |       2    |       2    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      13    |      14    |    1220    |    1207    |
|       from large pool |       6    |       7    |      28    |      22    |
|       from small pool |       7    |       8    |    1192    |    1185    |
|===========================================================================|

I'm running this on P100s, which have 16 GB of GPU memory each. It seems like only about 560 MiB (573,440 KB reserved) is being used here.
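To put the summary in proportion, the arithmetic below converts the "GPU reserved memory" figure into a fraction of the card. It assumes that memory_summary's "KB" are 1024-byte units and uses the P100's nominal 16 GiB; the usable capacity is slightly lower, which would only raise the fraction marginally. The helper name is hypothetical.

```python
BYTES_PER_KB = 1024  # assumption: memory_summary's "KB" are 1024-byte units


def reserved_fraction(reserved_kb: int, total_gib: float) -> float:
    """Fraction of a GPU's memory held by PyTorch's caching allocator."""
    return reserved_kb * BYTES_PER_KB / (total_gib * 1024 ** 3)


reserved_kb = 573440  # "GPU reserved memory", peak usage, from the summary above
print(f"{reserved_kb / 1024:.0f} MiB reserved, "
      f"{reserved_fraction(reserved_kb, 16):.1%} of a nominal 16 GiB P100")
```

At roughly 3% of device memory, this run is far from the memory-pressure regime discussed earlier.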

@RezaYazdaniAminabadi
Contributor

Hi @tyler-romero,

Thanks for trying this again. The dimensions all look okay, so it is very strange that the cuBLAS GEMM call is giving this error. Is this the only error you are getting, or might there be something else hidden in your log? Regarding error 13 for the cuBLAS GEMM, the only explanation I could find is the one below, which is not very helpful:

CUBLAS_STATUS_EXECUTION_FAILED The GPU program failed to execute. This is often caused by a launch failure of the kernel on the GPU, which can be caused by multiple reasons. To correct: check that the hardware, an appropriate version of the driver, and the cuBLAS library are correctly installed.

I have also made a PR to fix a similar issue in some of our examples for BingBertSquad: microsoft/DeepSpeedExamples#58

Are you passing the local_rank argument when running with the transformer kernel? Also, I have seen this error sometimes caused by the parameters being in FP16 while the input is passed in as FP32. I think the P100 architecture does not support fp16. Could you please check this?

If everything is still fine with your test environment, do you have the option to run this test on different GPU hardware, just to rule out a hardware issue?

Thanks,
Reza

@tyler-romero
Author

Hi Reza, thanks for the response.

I am passing local_rank to the transformer kernel (I was following the modeling example from bing_bert). I also double-checked the docs and can confirm that the P100 does support fp16. I will check soon that the input matches the parameter precision.

I will also double-check my Dockerfile to make sure everything is installed correctly.

I am also doing something a bit unusual when launching training. I am a Microsoft employee, so could we discuss this method of launching deepspeed offline so I can share my code?

@RezaYazdaniAminabadi
Contributor

Hi @tyler-romero

Thanks in advance for checking the configuration and also the data types.
Of course, we can discuss this offline so that I learn more about the issue and hopefully fix it soon :)
Please let me know when you want to meet and discuss this further.
Thanks.

Reza

@tyler-romero
Author

Issue fixed offline. The problem was specific to the GPU architecture: after more testing, we noticed that it worked fine on V100s but not on P100s. This PR contains the fix.

Thanks Reza and Jeff!
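For anyone stuck on a DeepSpeed version without the fix, one hypothetical workaround is to gate the fused kernel on the GPU's compute capability and fall back to the PyTorch transformer layer elsewhere. The P100 is compute capability 6.0 (Pascal) and the V100 is 7.0 (Volta). The policy below (enable only at 7.0 and above, because V100s worked in this thread) is an assumption for illustration, not what the linked PR does. In a real script the tuple would come from torch.cuda.get_device_capability().

```python
from typing import Tuple

# Hypothetical policy: V100 (7.0) worked in this thread, P100 (6.0) did not.
TESTED_MIN_CAPABILITY: Tuple[int, int] = (7, 0)


def use_deepspeed_kernel(capability: Tuple[int, int],
                         min_capability: Tuple[int, int] = TESTED_MIN_CAPABILITY) -> bool:
    """Enable the fused transformer kernel only on architectures known to work."""
    return tuple(capability) >= tuple(min_capability)


for name, cap in [("P100", (6, 0)), ("V100", (7, 0))]:
    impl = "DeepSpeed transformer kernel" if use_deepspeed_kernel(cap) else "PyTorch layer"
    print(f"{name} (sm_{cap[0]}{cap[1]}): {impl}")
```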
