
Cuda 11 cannot be supported #536

Closed
hujian233 opened this issue Nov 18, 2020 · 15 comments · Fixed by #570

@hujian233

hujian233 commented Nov 18, 2020

I wanted to run DeepSpeed on an RTX 3090, which only supports CUDA 11. In the Docker release you provide I updated the PyTorch version to 1.7.0 and hit the following error:

building GPT2 model ...
Traceback (most recent call last):
  File "pretrain_gpt2.py", line 711, in <module>
    main()
  File "pretrain_gpt2.py", line 659, in main
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
  File "pretrain_gpt2.py", line 158, in setup_model_and_optimizer
    model = get_model(args)
  File "pretrain_gpt2.py", line 69, in get_model
    parallel_output=True)
  File "/data/Megatron-LM/model/gpt2_modeling.py", line 81, in __init__
    checkpoint_num_layers)
  File "/data/Megatron-LM/mpu/transformer.py", line 384, in __init__
    [get_layer() for _ in range(num_layers)])
  File "/data/Megatron-LM/mpu/transformer.py", line 384, in <listcomp>
    [get_layer() for _ in range(num_layers)])
  File "/data/Megatron-LM/mpu/transformer.py", line 380, in get_layer
    output_layer_init_method=output_layer_init_method)
  File "/data/Megatron-LM/mpu/transformer.py", line 259, in __init__
    self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
  File "/usr/local/lib/python3.6/dist-packages/apex/normalization/fused_layer_norm.py", line 133, in __init__
    fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
  File "/usr/lib/python3.6/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 994, in _gcd_import
  File "<frozen importlib._bootstrap>", line 971, in _find_and_load
  File "<frozen importlib._bootstrap>", line 955, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 658, in _load_unlocked
  File "<frozen importlib._bootstrap>", line 571, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 922, in create_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /usr/local/lib/python3.6/dist-packages/fused_layer_norm_cuda.cpython-36m-x86_64-linux-gnu.so: undefined symbol: _ZN6caffe26detail37_typeMetaDataInstance_preallocated_32E

How can I get this running on the 3090?

@jeffra
Collaborator

jeffra commented Nov 18, 2020

We're in the process of updating our docker images and are still fully testing torch 1.7 + CUDA 11. In the meantime, can you try this docker image we just pushed to Docker Hub? deepspeed/deepspeed:v031_torch17_cuda11

The core error here seems to be coming from the Megatron example code, which is trying to load an Apex extension for fused layer norm. You might try re-installing apex via these steps: https://github.com/nvidia/apex#linux

@jeffra
Collaborator

jeffra commented Nov 18, 2020

The above image I linked does not include apex, since deepspeed core no longer requires it. However, the Megatron example does rely on a few CUDA/C++ extensions that are included in apex. I've pushed another docker image that installs the latest apex as well; you can grab it here: deepspeed/deepspeed:v031_torch17_cuda11_apex
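
For readers following along, a minimal sketch of pulling and starting that image; the --gpus flag assumes the NVIDIA container toolkit is installed, and the mounted host path is a placeholder:

# Pull the image with apex pre-installed, then start it with GPU access.
docker pull deepspeed/deepspeed:v031_torch17_cuda11_apex
# /path/to/Megatron-LM is a placeholder for wherever the Megatron-LM example lives on the host;
# the tracebacks in this thread use /data/Megatron-LM inside the container.
docker run --gpus all -it --rm \
    -v /path/to/Megatron-LM:/data/Megatron-LM \
    deepspeed/deepspeed:v031_torch17_cuda11_apex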

@jeffra
Collaborator

jeffra commented Nov 18, 2020

The Dockerfiles for these two images are below (the second image just adds the apex lines at the end):

FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    software-properties-common pdsh ninja-build \
    llvm-9-dev cmake git && \
    rm -rf /var/lib/apt/lists/* && \
    apt-get purge --auto-remove && \
    apt-get clean

RUN DS_BUILD_OPS=1 pip install --no-cache -v deepspeed && \
    ds_report

# apex
RUN mkdir /tmp/stage && \
    git clone https://github.com/NVIDIA/apex /tmp/stage/apex && \
    cd /tmp/stage/apex && \
    pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ && \
    cd - && \
    rm -rf /tmp/stage

@hujian233
Author

Thank you for your answer. I will try it and hope for a good result.

@hujian233
Author

hujian233 commented Nov 18, 2020

@jeffra Hello, I used the docker image you gave me above. When testing the Megatron example, I got an error running bash scripts/pretrain_gpt2.sh on a single GPU:

Partition Activations False and Correctness Check False
Traceback (most recent call last):
  File "pretrain_gpt2.py", line 711, in <module>
    main()
  File "pretrain_gpt2.py", line 684, in main
    iteration, skipped = train(model, optimizer,
  File "pretrain_gpt2.py", line 410, in train
    lm_loss, skipped_iter = train_step(train_data_iterator, model, optimizer, lr_scheduler, args, timers)
  File "pretrain_gpt2.py", line 365, in train_step
    lm_loss = forward_step(data_iterator, model, args, timers)
  File "pretrain_gpt2.py", line 286, in forward_step
    output = model(tokens, position_ids, attention_mask)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/Megatron-LM/model/distributed.py", line 78, in forward
    return self.module(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/Megatron-LM/fp16/fp16.py", line 65, in forward
    return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/Megatron-LM/model/gpt2_modeling.py", line 94, in forward
    transformer_output = self.transformer(embeddings, attention_mask)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/Megatron-LM/mpu/transformer.py", line 411, in forward
    hidden_states = checkpoint(custom(l, l+chunk_length),
  File "/data/Megatron-LM/mpu/random.py", line 378, in checkpoint
    return CheckpointFunction.apply(function, *args)
  File "/data/Megatron-LM/mpu/random.py", line 314, in forward
    outputs = run_function(*inputs_cuda)
  File "/data/Megatron-LM/mpu/transformer.py", line 402, in custom_forward
    x_ = layer(x_, inputs[1])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/Megatron-LM/mpu/transformer.py", line 294, in forward
    mlp_output = self.mlp(layernorm_output)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/Megatron-LM/mpu/transformer.py", line 209, in forward
    intermediate_parallel = gelu(intermediate_parallel)
  File "/data/Megatron-LM/mpu/transformer.py", line 166, in gelu
    return gelu_impl(x)
RuntimeError: nvrtc: error: invalid value for --gpu-architecture (-arch)

When I ran bash scripts/ds_zero-offload_10B_pretrain_gpt2_model_parallel.sh I got another error:

RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1603729096996/work/torch/lib/c10d/ProcessGroupNCCL.cpp:784, unhandled system error, NCCL version 2.7.8
    initialize_distributed(args)
  File "pretrain_gpt2.py", line 555, in initialize_distributed
    torch.distributed.init_process_group(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 442, in init_process_group
    work = _default_pg.barrier()

I think these are version-related problems, but I don't know how to resolve them. Could you give me some suggestions? Thank you very much.

@jeffra
Collaborator

jeffra commented Nov 20, 2020

Hi @hujian233, I see. Just realized this is compute capability 8.0 (ampere).

The previous image I linked pre-builds our cpp/cuda extensions. Let's try an image where the ops will be built just-in-time (JIT) instead; I've pushed it here: deepspeed/deepspeed:v034-jitops-torch170-cuda11. Alternatively, if you don't want to re-pull a new image, you can run pip uninstall deepspeed; pip install deepspeed, which should remove the old ops and install the latest version from PyPI.

We're still doing initial testing on Ampere with DeepSpeed. I just tested the previous image on an A100 and saw a similar (but not identical) error. After switching to the JIT-compiled version it worked.
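
A minimal sketch of the reinstall path described above; clearing the JIT extension cache is an extra precaution (the cache directory appears in the build logs later in this thread), not something jeffra asked for:

pip uninstall -y deepspeed
pip install deepspeed
# Optional: drop any previously JIT-built DeepSpeed ops so they get rebuilt cleanly.
rm -rf ~/.cache/torch_extensions
# Report which ops are installed and whether they are compatible with this environment.
ds_report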

@hujian233
Author

hujian233 commented Nov 20, 2020

Hi @jeffra, thanks for your answer. I actually solved the above two problems yesterday, although I'm not entirely sure why the fixes work. This is what I did:
For the first problem, I noticed that the last frame of the error was in def gelu_impl(x), so I commented out the @torch.jit.script decorator, since I don't need the JIT-compiled implementation right now, and then it worked.

error:
     return gelu_impl(x)
RuntimeError: nvrtc: error: invalid value for --gpu-architecture (-arch)

operation:
# @torch.jit.script
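# Without @torch.jit.script, gelu_impl runs as ordinary eager PyTorch ops, so NVRTC is never
# asked to compile a fused kernel for a GPU architecture the installed toolkit doesn't recognize.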
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))

For the second problem, I set the Docker shm-size to 2048m and the error went away.

error:
NCCL error in: /opt/conda/conda-bld/pytorch_1603729096996/work/torch/lib/c10d/ProcessGroupNCCL.cpp:784, unhandled system error, NCCL version 2.7.8
    initialize_distributed(args)
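
For reference, a sketch of how the shared-memory size can be raised when starting the container; the image name is the one from earlier in the thread, and the other flags are assumptions:

# NCCL uses /dev/shm for intra-node communication; Docker's default shm size (64 MB) is often too small.
docker run --gpus all --shm-size=2048m -it --rm \
    deepspeed/deepspeed:v031_torch17_cuda11_apex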

@hujian233
Author

@jeffra Hello, I installed the latest version of DeepSpeed (0.3.4) and ran ds_zero-offload_10B_pretrain_gpt2_model_parallel.sh, and I get an error:

 ImportError: No module named 'cpu_adam'
Loading extension module cpu_adam...
Traceback (most recent call last):
  File "pretrain_gpt2.py", line 711, in <module>
    main()
  File "pretrain_gpt2.py", line 659, in main
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
  File "pretrain_gpt2.py", line 165, in setup_model_and_optimizer
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/__init__.py", line 109, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 175, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 585, in _configure_optimizer
    basic_optimizer = self._configure_basic_optimizer(model_parameters)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 650, in _configure_basic_optimizer
    optimizer = DeepSpeedCPUAdam(model_parameters,
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/ops/adam/cpu_adam.py", line 70, in __init__
    self.ds_opt_adam = CPUAdamBuilder().load()
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/ops/op_builder/builder.py", line 156, in load
    return self.jit_load(verbose)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/ops/op_builder/builder.py", line 184, in jit_load
    op_module = load(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 969, in load
    return _jit_compile(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1196, in _jit_compile
    return _import_module_from_library(name, build_directory, is_python_module)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1543, in _import_module_from_library
    file, path, description = imp.find_module(module_name, [path])
  File "/opt/conda/lib/python3.8/imp.py", line 296, in find_module
    raise ImportError(_ERR_MSG.format(name), name=name)
ImportError: No module named 'cpu_adam'

This error appeared after I upgraded, and when I went back to 0.3.2 I still got it, as if I can't simply install via pip.

@jeffra
Collaborator

jeffra commented Nov 20, 2020

Hi @hujian233, in some of my local testing with this image I am also seeing strange issues with pip and conda. I believe my issue was that the docker build installs deepspeed as root, while at runtime (in my environment at least) I am running as a user without write permissions to /opt/conda/lib/python3.8/site-packages. I apologize; this is a different base image than we have been using for our production scenarios, so it's not as well tested. I suspect (I hope?) you won't hit the same issue with the new docker image. Can you give it a try? Here's the image name: deepspeed/deepspeed:v034-jitops-torch170-cuda11

@hujian233
Author

hujian233 commented Nov 23, 2020

@jeffra, I can't use deepspeed/deepspeed:v034-jitops-torch170-cuda11; I hit the ImportError: No module named 'cpu_adam' that I mentioned above. It happens essentially every time.

Here is the log leading up to the error:

Using /root/.cache/torch_extensions as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/cpu_adam...
Using /root/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -I/opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode arch=compute_86,code=compute_86 -c /opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/adam/custom_cuda_kernel.cu -o custom_cuda_kernel.cuda.o 
FAILED: custom_cuda_kernel.cuda.o 
/usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -I/opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode arch=compute_86,code=compute_86 -c /opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/adam/custom_cuda_kernel.cu -o custom_cuda_kernel.cuda.o 
nvcc fatal   : Unsupported gpu architecture 'compute_86'
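
Note (not from the thread): when the bundled nvcc is too old to know compute_86, one common workaround is to restrict the architecture list that torch.utils.cpp_extension builds for to one the toolkit does support; code built for sm_80 still runs on sm_86 hardware. A sketch, assuming CUDA 11.0's nvcc accepts compute_80:

# Build the JIT ops for compute capability 8.0 only, which CUDA 11.0's nvcc understands.
export TORCH_CUDA_ARCH_LIST="8.0"
# Drop the failed build so the cpu_adam op is recompiled with the new arch list.
rm -rf ~/.cache/torch_extensions
bash scripts/ds_zero-offload_10B_pretrain_gpt2_model_parallel.sh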

@hujian233
Author

@jeffra Hello, I now have some idea of the cause. I can use deepspeed/deepspeed:v031_torch17_cuda11_apex, the image with pre-built C++/CUDA extensions you gave me earlier, directly on my 3090, but I can't use the just-in-time (JIT) build path because of the architecture incompatibility, i.e. logs like "RuntimeError: nvrtc: error: invalid value for --gpu-architecture (-arch)".

The following screenshot compares the environment variables of the two versions:
[image attachment, not reproduced here]

My guess is that with torch 1.7.0 and CUDA 11, the C++/CUDA extensions can be built on your V100 GPU but not on my 3090. I'll try to find out why, but could you first send me a docker image of the latest version that contains the pre-built C++/CUDA extensions? I also recommend pre-building the extensions when you package the images, so that people can choose whether they want to recompile via JIT. Looking forward to your reply. Thank you very much.

@jeffra
Collaborator

jeffra commented Dec 2, 2020

@hujian233 there's a recent PR from @stas00 that might help here as well. Can you give it a try?

@jeffra
Collaborator

jeffra commented Dec 3, 2020

Actually, I don't believe the previously linked PR is related here. However, I believe PR #572 should fix your issues.

@hujian233
Author

@jeffra Hi, sounds great. I tried the same thing last night with CUDA 11.1 and torch 1.8.0 and the version check disabled, and it compiled successfully. I will try the latest code, thank you very much.

@hujian233
Author

@jeffra I am excited. Yeah, it works very well on the RTX 3090 with CUDA 11.0 and PyTorch 1.8.0 or 1.7.0. This issue can be closed. Thanks again.
