diff --git a/.gitignore b/.gitignore index 92252a2c51f3..84340857f802 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ build/ dist/ *.so deepspeed.egg-info/ +build.txt # Website docs/_site/ @@ -23,3 +24,7 @@ docs/code-docs/build # Testing data tests/unit/saved_checkpoint/ + +# Dev/IDE data +.vscode +.theia diff --git a/.gitmodules b/.gitmodules index 1257dc13e0f4..37adb6f39e5c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "third_party/apex"] - path = third_party/apex - url = https://github.com/NVIDIA/apex.git [submodule "DeepSpeedExamples"] path = DeepSpeedExamples url = https://github.com/microsoft/DeepSpeedExamples diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000000..f974c3971738 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +global-include *.cpp *.h *.cu *.tr *.cuh *.cc *.txt diff --git a/README.md b/README.md index a3c78bb16a36..42c91f288cab 100755 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ [![Build Status](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_apis/build/status/microsoft.DeepSpeed?branchName=master)](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master) +[![PyPI version](https://badge.fury.io/py/deepspeed.svg)](https://badge.fury.io/py/deepspeed) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest) [![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE) [![Docker Pulls](https://img.shields.io/docker/pulls/deepspeed/deepspeed)](https://hub.docker.com/r/deepspeed/deepspeed) @@ -31,29 +32,25 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale) # News -* [2020/09/10] [DeepSpeed: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/) +* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation) +* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html) +* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/) * [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html) * [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html) * [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html) * [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html) * [2020/08/07] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) is now available on-demand -* [2020/07/24] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) on August 6th, 2020 - [![DeepSpeed webinar](docs/assets/images/webinar-aug2020.png)](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-Live.html) -* [2020/05/19] [ZeRO-2 & DeepSpeed: Shattering Barriers of Deep Learning Speed & 
Scale](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/) -* [2020/05/19] [An Order-of-Magnitude Larger and Faster Training with ZeRO-2](https://www.deepspeed.ai/news/2020/05/18/zero-stage2.html) -* [2020/05/19] [The Fastest and Most Efficient BERT Training through Optimized Transformer Kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html) -* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/) -* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/) # Table of Contents | Section | Description | | --------------------------------------- | ------------------------------------------- | | [Why DeepSpeed?](#why-deepspeed) | DeepSpeed overview | -| [Features](#features) | DeepSpeed features | -| [Further Reading](#further-reading) | DeepSpeed documentation, tutorials, etc. | -| [Contributing](#contributing) | Instructions for contributing to DeepSpeed | -| [Publications](#publications) | DeepSpeed publications | +| [Install](#installation) | Installation details | +| [Features](#features) | Feature list and overview | +| [Further Reading](#further-reading) | Documentation, tutorials, etc. | +| [Contributing](#contributing) | Instructions for contributing | +| [Publications](#publications) | Publications related to DeepSpeed | # Why DeepSpeed? Training advanced deep learning models is challenging. Beyond model design, @@ -65,8 +62,32 @@ a large model easily runs out of memory with pure data parallelism and it is difficult to use model parallelism. DeepSpeed addresses these challenges to accelerate model development *and* training. -# Features +# Installation + +The quickest way to get started with DeepSpeed is via pip, this will install +the latest release of DeepSpeed which is not tied to specific PyTorch or CUDA +versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer +to as our 'ops'. By default, all of these extensions/ops will be built +just-in-time (JIT) using [torch's JIT C++ extension loader that relies on +ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and +dynamically link them at runtime. + +```bash +pip install deepspeed +``` + +After installation you can validate your install and see which extensions/ops +your machine is compatible with via the DeepSpeed environment report. +```bash +ds_report +``` + +If you would like to pre-install any of the DeepSpeed extensions/ops (instead +of JIT compiling) or install pre-compiled ops via PyPI please see our [advanced +installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/). + +# Features Below we provide a brief feature list, see our detailed [feature overview](https://www.deepspeed.ai/features/) for descriptions and usage. 
diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 16d4d8c26501..f9ddfe606670 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -43,7 +43,6 @@ jobs: conda install -q --yes conda conda install -q --yes pip conda install -q --yes gxx_linux-64 - if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi echo "PATH=$PATH, LD_LIBRARY_PATH=$LD_LIBRARY_PATH" displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)' @@ -51,9 +50,8 @@ jobs: - script: | source activate $(conda_env) pip install --progress-bar=off torch==$(pytorch.version) torchvision==$(torchvision.version) - #-f https://download.pytorch.org/whl/torch_stable.html - ./install.sh --local_only - #python -I basic_install_test.py + pip install .[dev] + ds_report displayName: 'Install DeepSpeed' - script: | @@ -71,7 +69,8 @@ jobs: - script: | source activate $(conda_env) - pytest --durations=0 --forked --verbose -x tests/unit/ + if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/ displayName: 'Unit tests' # - script: | diff --git a/basic_install_test.py b/basic_install_test.py deleted file mode 100644 index 2090337db885..000000000000 --- a/basic_install_test.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import warnings -import importlib -import warnings - -GREEN = '\033[92m' -RED = '\033[91m' -YELLOW = '\033[93m' -END = '\033[0m' -SUCCESS = f"{GREEN} [SUCCESS] {END}" -WARNING = f"{YELLOW} [WARNING] {END}" -FAIL = f'{RED} [FAIL] {END}' -INFO = ' [INFO]' - -try: - import deepspeed - print(f"{SUCCESS} deepspeed successfully imported.") -except ImportError as err: - raise err - -print(f"{INFO} torch install path: {torch.__path__}") -print(f"{INFO} torch version: {torch.__version__}, torch.cuda: {torch.version.cuda}") -print(f"{INFO} deepspeed install path: {deepspeed.__path__}") -print( - f"{INFO} deepspeed info: {deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}" -) - -try: - apex_C = importlib.import_module('apex_C') - print(f"{SUCCESS} apex extensions successfully installed") -except Exception as err: - print(f'{WARNING} apex extensions are not installed') - -try: - from apex.optimizers import FP16_Optimizer - print(f"{INFO} using old-style apex") -except ImportError: - print(f"{INFO} using new-style apex") - -try: - importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda') - print(f'{SUCCESS} fused lamb successfully installed.') -except Exception as err: - print(f"{WARNING} fused lamb is NOT installed.") - -try: - importlib.import_module('deepspeed.ops.transformer.transformer_cuda') - print(f'{SUCCESS} transformer kernels successfully installed.') -except Exception as err: - print(f'{WARNING} transformer kernels are NOT installed.') - -try: - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils') - import triton - print(f'{SUCCESS} sparse attention successfully installed.') -except ImportError: - print(f'{WARNING} sparse attention is NOT installed.') - -try: - importlib.import_module('deepspeed.ops.adam.cpu_adam_op') - print(f'{SUCCESS} cpu-adam (used by ZeRO-offload) successfully installed.') -except ImportError: - print(f'{WARNING} cpu-adam (used by ZeRO-offload) is NOT installed.') diff --git a/bin/ds_report b/bin/ds_report new file mode 100644 index 000000000000..c03a95645eae --- /dev/null 
+++ b/bin/ds_report @@ -0,0 +1,6 @@ +#!/usr/bin/env python + +from deepspeed.env_report import main + +if __name__ == '__main__': + main() diff --git a/csrc/adam/compat.h b/csrc/adam/compat.h new file mode 100644 index 000000000000..86f84a85065c --- /dev/null +++ b/csrc/adam/compat.h @@ -0,0 +1,14 @@ +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/csrc/adam/custom_cuda_kernel.cu b/csrc/adam/custom_cuda_kernel.cu old mode 100644 new mode 100755 index 8f8d2c826771..2f282aff1aca --- a/csrc/adam/custom_cuda_kernel.cu +++ b/csrc/adam/custom_cuda_kernel.cu @@ -4,30 +4,15 @@ __global__ void param_update_kernel(const float* input, __half* output, int size) { - const float4* input_cast = reinterpret_cast(input); - float2* output_cast = reinterpret_cast(output); - int id = blockIdx.x * blockDim.x + threadIdx.x; - if (id < size) { - float4 data = input_cast[id]; - float2 cast_data; - __half* output_h = reinterpret_cast<__half*>(&cast_data); - - output_h[0] = (__half)data.x; - output_h[1] = (__half)data.y; - output_h[2] = (__half)data.z; - output_h[3] = (__half)data.w; - - output_cast[id] = cast_data; - } + if (id < size) { output[id] = (__half)input[id]; } } void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream) { - int threads = 512; + int threads = 1024; - size /= 4; dim3 grid_dim((size - 1) / threads + 1); dim3 block_dim(threads); diff --git a/csrc/adam/fused_adam_frontend.cpp b/csrc/adam/fused_adam_frontend.cpp new file mode 100644 index 000000000000..b06531c53002 --- /dev/null +++ b/csrc/adam/fused_adam_frontend.cpp @@ -0,0 +1,20 @@ +#include + +void multi_tensor_adam_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("multi_tensor_adam", + &multi_tensor_adam_cuda, + "Compute and apply gradient update to parameters for Adam optimizer"); +} diff --git a/csrc/adam/multi_tensor_adam.cu b/csrc/adam/multi_tensor_adam.cu new file mode 100644 index 000000000000..3cb9763befce --- /dev/null +++ b/csrc/adam/multi_tensor_adam.cu @@ -0,0 +1,163 @@ +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +typedef enum { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; + +template +struct AdamFunctor { + __device__ __forceinline__ void operator()(int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<4>& tl, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float epsilon, + const float lr, + adamMode_t mode, + const float decay) + { + // I'd like this kernel to propagate infs/nans. 
+ // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T* m = (T*)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T* v = (T*)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } +}; + +void multi_tensor_adam_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay) +{ + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/csrc/adam/multi_tensor_apply.cuh b/csrc/adam/multi_tensor_apply.cuh new file mode 100644 index 000000000000..13af4b7578f6 --- /dev/null +++ b/csrc/adam/multi_tensor_apply.cuh @@ -0,0 +1,127 @@ +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include 
+#include +#include +#include "compat.h" + +#include + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata { + void* addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; +}; + +template +__global__ void multi_tensor_apply_kernel(int chunk_size, + volatile int* noop_flag, + T tl, + U callable, + ArgTypes... args) +{ + // Hand the chunk information to the user-supplied functor to process however it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply(int block_size, + int chunk_size, + const at::Tensor& noop_flag, + const std::vector>& tensor_lists, + T callable, + ArgTypes... args) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. 
The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/csrc/utils/flatten_unflatten.cpp b/csrc/utils/flatten_unflatten.cpp new file mode 100644 index 000000000000..7d16c5c00a5f --- /dev/null +++ b/csrc/utils/flatten_unflatten.cpp @@ -0,0 +1,25 @@ +/* + Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h + +at::Tensor flatten(std::vector tensors) +{ + return torch::utils::flatten_dense_tensors(tensors); +} + +std::vector unflatten(at::Tensor flat, std::vector tensors) +{ + return torch::utils::unflatten_dense_tensors(flat, tensors); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("flatten", &flatten, "Flatten dense tensors"); + m.def("unflatten", &unflatten, "Unflatten dense tensors"); +} diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 61a6f730fb17..35cc1a4a26bd 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -17,24 +17,19 @@ from .pipe import PipelineModule -try: - from .git_version_info import version, git_hash, git_branch -except ImportError: - version = "0.0.0+unknown" - git_hash = None - git_branch = None +from .git_version_info import version, git_hash, git_branch + + +def _parse_version(version_str): + '''Parse a version string and extract the major, minor, and patch versions.''' + import re + matched = re.search('^(\d+)\.(\d+)\.(\d+)', version_str) + return int(matched.group(1)), int(matched.group(2)), int(matched.group(3)) + # Export version information -version, __version_tag__ = version.split('+') -__version_major__ = int(version.split('.')[0]) -__version_minor__ = int(version.split('.')[1]) -__version_patch__ = int(version.split('.')[2]) -__version__ = '.'.join( - map(str, - [__version_major__, - __version_minor__, - __version_patch__])) -__version__ = f"{__version__}+{__version_tag__}" +__version__ = version +__version_major__, __version_minor__, __version_patch__ = _parse_version(__version__) __git_hash__ = git_hash __git_branch__ = git_branch diff --git a/deepspeed/env_report.py b/deepspeed/env_report.py new file mode 100644 index 000000000000..36ad4789e69d --- /dev/null +++ b/deepspeed/env_report.py @@ -0,0 +1,107 @@ +import torch +import deepspeed +import subprocess +from .ops.op_builder import ALL_OPS +from .git_version_info import installed_ops, torch_info +from .ops import __compatible_ops__ as compatible_ops + +GREEN = '\033[92m' +RED = '\033[91m' +YELLOW = '\033[93m' +END = '\033[0m' +SUCCESS = f"{GREEN} [SUCCESS] {END}" +OKAY = f"{GREEN}[OKAY]{END}" +WARNING = f"{YELLOW}[WARNING]{END}" +FAIL = f'{RED}[FAIL]{END}' +INFO = '[INFO]' + +color_len = len(GREEN) + len(END) +okay = f"{GREEN}[OKAY]{END}" +warning = f"{YELLOW}[WARNING]{END}" + + +def op_report(): + max_dots = 23 + max_dots2 = 11 + h = ["op name", "installed", "compatible"] + print("-" * (max_dots + max_dots2 + len(h[0]) + 
len(h[1]))) + print("DeepSpeed C++/CUDA extension op report") + print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) + + print("NOTE: Ops not installed will be just-in-time (JIT) compiled at\n" + " runtime if needed. Op compatibility means that your system\n" + " meet the required dependencies to JIT install the op.") + + print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) + print("JIT compiled ops requires ninja") + ninja_status = OKAY if ninja_installed() else FAIL + print('ninja', "." * (max_dots - 5), ninja_status) + print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) + print(h[0], "." * (max_dots - len(h[0])), h[1], "." * (max_dots2 - len(h[1])), h[2]) + print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) + installed = f"{GREEN}[YES]{END}" + no = f"{YELLOW}[NO]{END}" + for op_name, builder in ALL_OPS.items(): + dots = "." * (max_dots - len(op_name)) + is_compatible = OKAY if builder.is_compatible() else no + is_installed = installed if installed_ops[op_name] else no + dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - + (len(is_installed) - color_len)) + print(op_name, dots, is_installed, dots2, is_compatible) + print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) + + +def ninja_installed(): + try: + import ninja + except ImportError: + return False + return True + + +def nvcc_version(): + import torch.utils.cpp_extension + cuda_home = torch.utils.cpp_extension.CUDA_HOME + try: + output = subprocess.check_output([cuda_home + "/bin/nvcc", + "-V"], + universal_newlines=True) + except FileNotFoundError: + return f"{RED} [FAIL] nvcc missing {END}" + output_split = output.split() + release_idx = output_split.index("release") + release = output_split[release_idx + 1].replace(',', '').split(".") + return ".".join(release) + + +def debug_report(): + max_dots = 33 + report = [ + ("torch install path", + torch.__path__), + ("torch version", + torch.__version__), + ("torch cuda version", + torch.version.cuda), + ("nvcc version", + nvcc_version()), + ("deepspeed install path", + deepspeed.__path__), + ("deepspeed info", + f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}" + ), + ("deepspeed wheel compiled w.", + f"torch {torch_info['version']}, cuda {torch_info['cuda_version']}"), + ] + print("DeepSpeed general environment info:") + for name, value in report: + print(name, "." 
* (max_dots - len(name)), value) + + +def main(): + op_report() + debug_report() + + +if __name__ == "__main__": + main() diff --git a/deepspeed/git_version_info.py b/deepspeed/git_version_info.py index 82f60a86a6f6..d17948ae41a7 100644 --- a/deepspeed/git_version_info.py +++ b/deepspeed/git_version_info.py @@ -3,12 +3,11 @@ from .git_version_info_installed import * except ModuleNotFoundError: # Will be missing from checkouts that haven't been installed (e.g., readthedocs) - version = '0.3.0+[none]' + version = open('version.txt', 'r').read().strip() git_hash = '[none]' git_branch = '[none]' - installed_ops = { - 'lamb': False, - 'transformer': False, - 'sparse-attn': False, - 'cpu-adam': False - } + + from .ops.op_builder import ALL_OPS + installed_ops = dict.fromkeys(ALL_OPS.keys(), False) + compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) + torch_info = {'version': "0.0", "cuda_version": "0.0"} diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py index 6c4187415aae..8aec76267ed3 100644 --- a/deepspeed/ops/__init__.py +++ b/deepspeed/ops/__init__.py @@ -1,7 +1,6 @@ -from ..git_version_info import installed_ops as __installed_ops__ +from . import adam from . import lamb +from . import sparse_attention from . import transformer -if __installed_ops__['sparse-attn']: - from . import sparse_attention -if __installed_ops__['cpu-adam']: - from . import adam + +from ..git_version_info import compatible_ops as __compatible_ops__ diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index 1d8844409374..6e620b36bd8e 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -1 +1,2 @@ from .cpu_adam import DeepSpeedCPUAdam +from .fused_adam import FusedAdam diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index ec2dc73fe8b7..1d1ff8a1ac5d 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -4,9 +4,9 @@ import math import torch -import importlib - -ds_opt_adam = None +import time +from pathlib import Path +from ..op_builder import CPUAdamBuilder class DeepSpeedCPUAdam(torch.optim.Optimizer): @@ -67,15 +67,15 @@ def __init__(self, self.opt_id = DeepSpeedCPUAdam.optimizer_id DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1 - global ds_opt_adam - ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op') - ds_opt_adam.create_adam(self.opt_id, - lr, - betas[0], - betas[1], - eps, - weight_decay, - adamw_mode) + self.ds_opt_adam = CPUAdamBuilder().load() + + self.ds_opt_adam.create_adam(self.opt_id, + lr, + betas[0], + betas[1], + eps, + weight_decay, + adamw_mode) def __setstate__(self, state): super(DeepSpeedCPUAdam, self).__setstate__(state) @@ -101,18 +101,20 @@ def step(self, closure=None, fp16_param_groups=None): print(f'group {group_id} param {param_id} = {p.numel()}') state['step'] = 0 # gradient momentums - state['exp_avg'] = torch.zeros_like( - p.data, - memory_format=torch.preserve_format) + state['exp_avg'] = torch.zeros_like(p.data, + dtype=p.dtype, + device='cpu') + #memory_format=torch.preserve_format) # gradient variances - state['exp_avg_sq'] = torch.zeros_like( - p.data, - memory_format=torch.preserve_format) + state['exp_avg_sq'] = torch.zeros_like(p.data, + dtype=p.dtype, + device='cpu') + #memory_format=torch.preserve_format) state['step'] += 1 if fp16_param_groups is not None: - ds_opt_adam.adam_update_copy( + self.ds_opt_adam.adam_update_copy( self.opt_id, state['step'], group['lr'], @@ -122,11 +124,11 @@ def step(self, 
closure=None, fp16_param_groups=None): state['exp_avg_sq'], fp16_param_groups[group_id][param_id].data) else: - ds_opt_adam.adam_update(self.opt_id, - state['step'], - group['lr'], - p.data, - p.grad.data, - state['exp_avg'], - state['exp_avg_sq']) + self.ds_opt_adam.adam_update(self.opt_id, + state['step'], + group['lr'], + p.data, + p.grad.data, + state['exp_avg'], + state['exp_avg_sq']) return loss diff --git a/deepspeed/ops/adam/fused_adam.py b/deepspeed/ops/adam/fused_adam.py new file mode 100644 index 000000000000..ae7c5fac88f0 --- /dev/null +++ b/deepspeed/ops/adam/fused_adam.py @@ -0,0 +1,182 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team + +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +''' + +import torch +import importlib +from .multi_tensor_apply import MultiTensorApply +multi_tensor_applier = MultiTensorApply(2048 * 32) +from ..op_builder import FusedAdamBuilder + + +class FusedAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + Currently GPU-only. + + This version of fused Adam implements 2 fusions. + + * Fusion of the Adam update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + adam_w_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + set_grad_none (bool, optional): whether set grad to None when zero_grad() + method is called. (default: True) + + .. _Adam - A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, + 0.999), + eps=1e-8, + adam_w_mode=True, + weight_decay=0., + amsgrad=False, + set_grad_none=True): + + if amsgrad: + raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + defaults = dict(lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay) + super(FusedAdam, self).__init__(params, defaults) + self.adam_w_mode = 1 if adam_w_mode else 0 + self.set_grad_none = set_grad_none + + fused_adam_cuda = FusedAdamBuilder().load() + # Skip buffer + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + self.multi_tensor_adam = fused_adam_cuda.multi_tensor_adam + + def zero_grad(self): + if self.set_grad_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedAdam, self).zero_grad() + + def step(self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None): + """Performs a single optimization step. 
+ + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.' + ) + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + # create lists for multi-tensor apply + g_16, p_16, m_16, v_16 = [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + 'FusedAdam does not support sparse gradients, please consider SparseAdam instead' + ) + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + if p.dtype == torch.float16: + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state['exp_avg']) + v_16.append(state['exp_avg_sq']) + elif p.dtype == torch.float32: + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state['exp_avg']) + v_32.append(state['exp_avg_sq']) + else: + raise RuntimeError('FusedAdam only support fp16 and fp32.') + + if (len(g_16) > 0): + multi_tensor_applier(self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, + p_16, + m_16, + v_16], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay']) + if (len(g_32) > 0): + multi_tensor_applier(self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, + p_32, + m_32, + v_32], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay']) + + return loss diff --git a/deepspeed/ops/adam/multi_tensor_apply.py b/deepspeed/ops/adam/multi_tensor_apply.py new file mode 100644 index 000000000000..173288ee76ab --- /dev/null +++ b/deepspeed/ops/adam/multi_tensor_apply.py @@ -0,0 +1,15 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team + +Copyright NVIDIA/apex +This file is adapted from NVIDIA/apex, commit a109f85 +''' +import torch + + +class MultiTensorApply(object): + def __init__(self, chunk_size): + self.chunk_size = chunk_size + + def __call__(self, op, noop_flag_buffer, tensor_lists, *args): + return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) diff --git a/deepspeed/ops/csrc b/deepspeed/ops/csrc new file mode 120000 index 000000000000..338652302505 --- /dev/null +++ b/deepspeed/ops/csrc @@ -0,0 +1 @@ +../../csrc \ No newline at end of file diff --git a/deepspeed/ops/lamb/__init__.py b/deepspeed/ops/lamb/__init__.py index 128f0d66195c..9d3448dd68ea 100644 --- a/deepspeed/ops/lamb/__init__.py +++ b/deepspeed/ops/lamb/__init__.py @@ -1 +1 @@ -from deepspeed.ops.lamb.fused_lamb import FusedLamb +from .fused_lamb import FusedLamb diff --git a/deepspeed/ops/lamb/fused_lamb.py b/deepspeed/ops/lamb/fused_lamb.py 
index 8117cdc7b0d0..e9210cdda9bc 100644 --- a/deepspeed/ops/lamb/fused_lamb.py +++ b/deepspeed/ops/lamb/fused_lamb.py @@ -5,8 +5,8 @@ This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer ''' import types -import importlib import torch +from ..op_builder import FusedLambBuilder class FusedLamb(torch.optim.Optimizer): @@ -48,15 +48,7 @@ def __init__(self, max_coeff=10.0, min_coeff=0.01, amsgrad=False): - global fused_lamb_cuda - try: - fused_lamb_cuda = importlib.import_module( - "deepspeed.ops.lamb.fused_lamb_cuda") - except ImportError as err: - print( - "Unable to import Lamb cuda extension, please build DeepSpeed with cuda/cpp extensions." - ) - raise err + self.fused_lamb_cuda = FusedLambBuilder().load() if amsgrad: raise RuntimeError('FusedLamb does not support the AMSGrad variant.') @@ -173,22 +165,22 @@ def step(self, out_p = torch.tensor( [], dtype=torch.float) if output_param is None else output_param - lamb_coeff = fused_lamb_cuda.lamb(p.data, - out_p, - exp_avg, - exp_avg_sq, - grad, - group['lr'], - beta1, - beta2, - max_coeff, - min_coeff, - group['eps'], - combined_scale, - state['step'], - self.eps_mode, - bias_correction, - group['weight_decay']) + lamb_coeff = self.fused_lamb_cuda.lamb(p.data, + out_p, + exp_avg, + exp_avg_sq, + grad, + group['lr'], + beta1, + beta2, + max_coeff, + min_coeff, + group['eps'], + combined_scale, + state['step'], + self.eps_mode, + bias_correction, + group['weight_decay']) self.lamb_coeffs.append(lamb_coeff) return loss diff --git a/deepspeed/ops/op_builder b/deepspeed/ops/op_builder new file mode 120000 index 000000000000..db4f9c335065 --- /dev/null +++ b/deepspeed/ops/op_builder @@ -0,0 +1 @@ +../../op_builder \ No newline at end of file diff --git a/deepspeed/ops/sparse_attention/matmul.py b/deepspeed/ops/sparse_attention/matmul.py index d60f967d38f3..db5b774c3243 100644 --- a/deepspeed/ops/sparse_attention/matmul.py +++ b/deepspeed/ops/sparse_attention/matmul.py @@ -2,13 +2,12 @@ # https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py import importlib import warnings -try: - import triton -except ImportError: - warnings.warn("Unable to import triton, sparse attention will not be accessible") import torch import math -from deepspeed.ops.sparse_attention.trsrc import matmul +from .trsrc import matmul +from ..op_builder import SparseAttnBuilder + +triton = None ############## @@ -27,6 +26,9 @@ class _sparse_matmul(torch.autograd.Function): # between `seg_size` elements @staticmethod def load_balance(sizes, block): + global triton + if triton is None: + triton = importlib.import_module('triton') # segment size # heuristics taken from OpenAI blocksparse code # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95 @@ -83,11 +85,18 @@ def get_locks(size, dev): ########################## # SPARSE = DENSE x DENSE # ########################## - cpp_utils = importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils') - sdd_segment = cpp_utils.sdd_segment + cpp_utils = None + sdd_segment = None + + @staticmethod + def _load_utils(): + if _sparse_matmul.cpp_utils is None: + _sparse_matmul.cpp_utils = SparseAttnBuilder().load() + _sparse_matmul.sdd_segment = _sparse_matmul.cpp_utils.sdd_segment @staticmethod def make_sdd_lut(layout, block, dtype, device): + _sparse_matmul._load_utils() start_width = 64 // block segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width) luts, widths, packs = [], [], [] @@ -118,6 +127,10 @@ 
def _sdd_matmul(a, packs, bench, time): + global triton + if triton is None: + triton = importlib.import_module('triton') + if trans_c: a, b = b, a trans_a, trans_b = not trans_b, not trans_a @@ -332,6 +345,10 @@ def _dds_matmul(a, packs, bench, time): + global triton + if triton is None: + triton = importlib.import_module('triton') + # shapes / dtypes AS0 = a.size(0) AS1 = a.size(1) @@ -413,6 +430,10 @@ def _dsd_matmul(a, packs, bench, time): + global triton + if triton is None: + triton = importlib.import_module('triton') + # shapes / dtypes AS0 = spdims[0] AS1 = block * spdims[2 if trans_a else 1] diff --git a/deepspeed/ops/sparse_attention/softmax.py b/deepspeed/ops/sparse_attention/softmax.py index 41267298a0a4..cd18fbcae71f 100644 --- a/deepspeed/ops/sparse_attention/softmax.py +++ b/deepspeed/ops/sparse_attention/softmax.py @@ -2,17 +2,17 @@ # https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py import warnings -try: - import triton -except ImportError: - warnings.warn("Unable to import triton, sparse attention will not be accessible") +import importlib import torch import math -from deepspeed.ops.sparse_attention.trsrc import softmax_fwd, softmax_bwd +from .trsrc import softmax_fwd, softmax_bwd fwd_kernels = dict() bwd_kernels = dict() +# Delay importing triton unless we need it +triton = None + class _sparse_softmax(torch.autograd.Function): @@ -52,6 +52,10 @@ def make_kernel(cache, apply_attn_mask, kp_mask_mode, attn_mask_mode): + global triton + if triton is None: + triton = importlib.import_module('triton') + if max_k >= 32768: raise NotImplementedError('Reductions larger than 32768 elements '\ 'are not yet implemented') @@ -112,6 +116,10 @@ def forward(ctx, maxlut, bench, time): + global triton + if triton is None: + triton = importlib.import_module('triton') + apply_scale = False if scale == 1.0 else True # handle None rpe @@ -180,6 +188,10 @@ def forward(ctx, @staticmethod def backward(ctx, dx): + global triton + if triton is None: + triton = importlib.import_module('triton') + # retrieve from context x, lut = ctx.saved_tensors # run kernel diff --git a/deepspeed/ops/transformer/__init__.py b/deepspeed/ops/transformer/__init__.py index 4a056762beac..63c5938bb9e7 100644 --- a/deepspeed/ops/transformer/__init__.py +++ b/deepspeed/ops/transformer/__init__.py @@ -1 +1 @@ -from deepspeed.ops.transformer.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig +from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index c2c7b28f64a4..a91e5ce6f08b 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -8,6 +8,8 @@ from torch import nn from torch.autograd import Function +from ..op_builder import TransformerBuilder, StochasticTransformerBuilder + # Cuda modules will be imported if needed transformer_cuda_module = None stochastic_transformer_cuda_module = None @@ -483,19 +485,12 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None): self.norm_w = initial_weights[7] self.norm_b = initial_biases[7] - # Import cuda modules if needed + # Load cuda modules if needed global transformer_cuda_module, stochastic_transformer_cuda_module - if transformer_cuda_module is None or stochastic_transformer_cuda_module is None: - try: - transformer_cuda_module = importlib.import_module( - "deepspeed.ops.transformer.transformer_cuda") - 
stochastic_transformer_cuda_module = importlib.import_module( - "deepspeed.ops.transformer.stochastic_transformer_cuda") - except ImportError as err: - print( - "Unable to import transformer cuda extension, please build DeepSpeed with cuda/cpp extensions." - ) - raise err + if transformer_cuda_module is None and not self.config.stochastic_mode: + transformer_cuda_module = TransformerBuilder().load() + if stochastic_transformer_cuda_module is None and self.config.stochastic_mode: + stochastic_transformer_cuda_module = StochasticTransformerBuilder().load() # create the layer in cuda kernels. cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 29bac42ab108..52cfda31b730 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -7,8 +7,6 @@ import warnings import torch.distributed as dist -import apex -from apex import amp from torch.nn.modules import Module from torch.distributed.distributed_c10d import _get_global_rank from tensorboardX import SummaryWriter @@ -36,22 +34,17 @@ from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop from .utils import ensure_directory_exists +from ..ops.op_builder import UtilsBuilder +from ..ops.adam import DeepSpeedCPUAdam +from ..ops.adam import FusedAdam MEMORY_OPT_ALLREDUCE_SIZE = 500000000 try: - from apex_C import flatten - from apex_C import unflatten + from apex import amp except ImportError: - try: - _ = warned_flatten - except NameError: - logger.warning( - "Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten." - ) - warned_flatten = True - from torch._utils import _flatten_dense_tensors as flatten - from torch._utils import _unflatten_dense_tensors as unflatten + # Fail silently so we don't spam logs unnecessarily if user isn't using amp + pass def split_half_float_double_csr(tensors): @@ -201,6 +194,11 @@ def __init__(self, if self.dump_state(): print_configuration(self, 'DeepSpeedEngine') + # Load pre-installed or JIT compile (un)flatten ops + util_ops = UtilsBuilder().load() + self.flatten = util_ops.flatten + self.unflatten = util_ops.unflatten + def _mpi_check(self, args, dist_init_required): if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: from mpi4py import MPI @@ -558,6 +556,12 @@ def _configure_optimizer(self, client_optimizer, model_parameters): amp_params = self.amp_params() if self.global_rank == 0: logger.info(f"Initializing AMP with these params: {amp_params}") + try: + logger.info("Initializing Apex amp from: {}".format(amp.__path__)) + except NameError: + # If apex/amp is available it will be imported above + raise RuntimeError( + "Unable to import apex/amp, please make sure it is installed") self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params) self._broadcast_model() elif self.fp16_enabled(): @@ -584,17 +588,18 @@ def _configure_basic_optimizer(self, model_parameters): # T|F T F torch.optim.Adam # T F T|F DeepSpeedCPUAdam(adam_w_mode) # F F T|F FusedAdam(adam_w_mode) - if torch_adam and adam_w_mode: - optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters) - elif torch_adam and not adam_w_mode: - optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) - elif self.zero_cpu_offload() and not torch_adam: - from deepspeed.ops.adam import DeepSpeedCPUAdam + if torch_adam: + if adam_w_mode: + optimizer = torch.optim.AdamW(model_parameters, + 
**optimizer_parameters) + else: + optimizer = torch.optim.Adam(model_parameters, + **optimizer_parameters) + elif self.zero_cpu_offload(): optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters, adamw_mode=adam_w_mode) - elif not self.zero_cpu_offload() and not torch_adam: - from apex.optimizers.fused_adam import FusedAdam + else: optimizer_parameters[ADAM_W_MODE_PARAM] = adam_w_mode optimizer = FusedAdam(model_parameters, **optimizer_parameters) @@ -614,8 +619,7 @@ def _configure_fp16_optimizer(self, optimizer): dynamic_loss_args = self.dynamic_loss_scale_args() clip_grad = self.gradient_clipping() if isinstance(optimizer, - apex.optimizers.FusedAdam) or self.optimizer_name( - ) == ONEBIT_ADAM_OPTIMIZER: + FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER: if self.dynamic_loss_scale(): logger.info('Creating fp16 optimizer with dynamic loss scale') timers = self.timers if self.wall_clock_breakdown() else None @@ -1072,7 +1076,7 @@ def _report_progress(self, step): ranks=[0]) def allreduce_bucket(self, bucket): - tensor = flatten(bucket) + tensor = self.flatten(bucket) tensor_to_allreduce = tensor @@ -1100,7 +1104,7 @@ def allreduce_bucket(self, bucket): def allreduce_and_copy(self, small_bucket): allreduced = self.allreduce_bucket(small_bucket) - for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) def allreduce_no_retain(self, bucket, numel_per_bucket=500000000): diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index 2eda764c8e3d..544458e41c75 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -15,26 +15,15 @@ from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS +from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.utils import logger +from ...ops.op_builder import UtilsBuilder + #Toggle this to true to enable correctness test #with gradient partitioning and without pg_correctness_test = False -try: - from apex_C import flatten - from apex_C import unflatten -except ImportError: - try: - _ = warned_flatten - except NameError: - logger.warning( - "apex was installed without --cpp_ext. Falling back to Python flatten and unflatten." 
- ) - warned_flatten = True - from torch._utils import _flatten_dense_tensors as flatten - from torch._utils import _unflatten_dense_tensors as unflatten - def input(msg): return @@ -132,6 +121,11 @@ def __init__(self, gradient_predivide_factor=1.0, gradient_accumulation_steps=1): + # Load pre-installed or JIT compile (un)flatten ops + util_ops = UtilsBuilder().load() + self.flatten = util_ops.flatten + self.unflatten = util_ops.unflatten + if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {allgather_bucket_size}") @@ -1053,7 +1047,7 @@ def set_none_gradients_to_zero(self, i, partition_id): def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): rank = None - tensor = flatten(bucket) + tensor = self.flatten(bucket) tensor_to_allreduce = tensor @@ -1095,7 +1089,7 @@ def allreduce_and_copy(self, small_bucket, rank=None, log=None): with torch.cuda.stream(stream): allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) if rank is None or rank == dist.get_rank(group=self.dp_process_group): - for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) def allreduce_no_retain(self, diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index e51ad7c1eca1..bdada1b4989a 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -1,8 +1,8 @@ import torch import torch.distributed as dist -import apex from deepspeed.utils import logger from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adam import FusedAdam def _initialize_parameter_parallel_groups(parameter_parallel_size=None): @@ -23,11 +23,14 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None): return my_group -ZERO_SUPPORTED_OPTIMIZERS = [ - torch.optim.Adam, - apex.optimizers.FusedAdam, - DeepSpeedCPUAdam -] +ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, FusedAdam, DeepSpeedCPUAdam] + +# Add apex FusedAdam to supported list if apex is installed +try: + import apex + ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam) +except ImportError: + pass def is_zero_supported_optimizer(optimizer): diff --git a/docs/_config.yml b/docs/_config.yml index 6983f5567e3d..4d64e8caf52f 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -30,6 +30,7 @@ collections: output: true permalink: /:collection/:path/ order: + - advanced-install.md - getting-started.md - azure.md - cifar-10.md diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md new file mode 100644 index 000000000000..4d8061e4c987 --- /dev/null +++ b/docs/_tutorials/advanced-install.md @@ -0,0 +1,86 @@ +--- +title: "Installation Details" +date: 2020-10-28 +--- + +The quickest way to get started with DeepSpeed is via pip, this will install +the latest release of DeepSpeed which is not tied to specific PyTorch or CUDA +versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer +to as our 'ops'. By default, all of these extensions/ops will be built +just-in-time (JIT) using [torch's JIT C++ extension loader that relies on +ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and +dynamically link them at runtime. 
+ +```bash +pip install deepspeed +``` + +After installation you can validate your install and see which ops your machine +is compatible with via the DeepSpeed environment report with `ds_report` or +`python -m deepspeed.env_report`. We've found this report useful when debugging +DeepSpeed install or compatibility issues. + +```bash +ds_report +``` + +## Install DeepSpeed from source + +After cloning the DeepSpeed repo from github you can install DeepSpeed in +JIT mode via pip (see below). This install should complete +quickly since it is not compiling any C++/CUDA source files. + +```bash +pip install . +``` + +For installs spanning multiple nodes we find it useful to install DeepSpeed +using the +[install.sh](https://github.com/microsoft/DeepSpeed/blob/master/install.sh) +script in the repo. This will build a python wheel locally and copy it to all +the nodes listed in your hostfile (either given via --hostfile, or defaults to +/job/hostfile). + +## Pre-install DeepSpeed Ops + +Sometimes we have found it useful to pre-install either some or all DeepSpeed +C++/CUDA ops instead of using the JIT compiled path. In order to support +pre-installation we introduce build environment flags to turn on/off building +specific ops. + +You can indicate to our installer (either install.sh or pip install) that you +want to attempt to install all of our ops by setting the `DS_BUILD_OPS` +environment variable to 1, for example: + +```bash +DS_BUILD_OPS=1 pip install . +``` + +We will only install any ops that are compatible with your machine, for more +details on which ops are compatible with your system please try our `ds_report` +tool described above. + +If you want to install only a specific op (e.g., FusedLamb) you can view the op +specific build environment variable (set as `BUILD_VAR`) in the corresponding +op builder class in the +[op\_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder) +directory. For example to install only the Fused Lamb op you would install via: + +```bash +DS_BUILD_FUSED_LAMB=1 pip install . +``` + +## Feature specific dependencies + +Some DeepSpeed features require specific dependencies outside of the general +dependencies of DeepSpeed. + +* Python package dependencies per feature/op please +see our [requirements +directory](https://github.com/microsoft/DeepSpeed/tree/master/requirements). + +* We attempt to keep the system level dependencies to a minimum, however some features do require special system-level packages. Please see our `ds_report` tool output to see if you are missing any system-level packages for a given feature. + +## Pre-compiled DeepSpeed builds from PyPI + +Coming soon diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index c62eef569a1d..c8ff47331f78 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -7,9 +7,9 @@ date: 2020-05-15 ## Installation +* Installing is as simple as `pip install deepspeed`, [see more details](/tutorials/advanced-install/). * Please see our [Azure tutorial](/tutorials/azure/) to get started with DeepSpeed on Azure! * If you're not on Azure, we recommend using our docker image via `docker pull deepspeed/deepspeed:latest` which contains a pre-installed version of DeepSpeed and all the necessary dependencies. -* If you want to install DeepSpeed manually, we provide an install script `install.sh` to help install on a local machine or across an entire cluster. 
## Writing DeepSpeed Models DeepSpeed model training is accomplished using the DeepSpeed engine. The engine diff --git a/docs/index.md b/docs/index.md index 13d1ff89873b..be265dc70a0b 100755 --- a/docs/index.md +++ b/docs/index.md @@ -28,8 +28,9 @@ initiative to enable next-generation AI capabilities at scale, where you can fin information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale). # What's New? -* [2020/10/28] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html) - * [DeepSpeed: Extreme-scale model training for everyone]({{ site.press_release_v3 }}) +* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation) +* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html) +* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone]({{ site.press_release_v3 }}) * [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html) * [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html) * [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html) diff --git a/install.sh b/install.sh index bc390a3c45ba..b027d319cdd6 100755 --- a/install.sh +++ b/install.sh @@ -15,16 +15,13 @@ By default will install deepspeed and all third party dependecies accross all ma hostfile (hostfile: /job/hostfile). If no hostfile exists, will only install locally [optional] - -d, --deepspeed_only Install only deepspeed and no third party dependencies - -t, --third_party_only Install only third party dependencies and not deepspeed -l, --local_only Install only on local machine -s, --pip_sudo Run pip install with sudo (default: no sudo) -r, --allow_sudo Allow script to be run by root (probably don't want this, instead use --pip_sudo) -n, --no_clean Do not clean prior build state, by default prior build files are removed before building wheels -m, --pip_mirror Use the specified pip mirror (default: the default pip mirror) -H, --hostfile Path to MPI-style hostfile (default: /job/hostfile) - -a, --apex_commit Install a specific commit hash of apex, instead of the one deepspeed points to - -k, --skip_requirements Skip installing DeepSpeed requirements + -v, --verbose Verbose logging -h, --help This help text """ } @@ -42,27 +39,12 @@ apex_commit="" skip_requirements=0 allow_sudo=0 no_clean=0 +verbose=0 while [[ $# -gt 0 ]] do key="$1" case $key in - -d|--deepspeed_only) - deepspeed_install=1; - third_party_install=0; - ds_only=1; - shift - ;; - -t|--third_party_only) - deepspeed_install=0; - third_party_install=1; - tp_only=1; - shift - ;; - -l|--local_only) - local_only=1; - shift - ;; -s|--pip_sudo) pip_sudo=1; shift @@ -72,13 +54,8 @@ case $key in shift shift ;; - -a|--apex_commit) - apex_commit=$2; - shift - shift - ;; - -k|--skip_requirements) - skip_requirements=1; + -v|--verbose) + verbose=1; shift ;; -r|--allow_sudo) @@ -126,12 +103,18 @@ if [ "$ds_only" == "1" ] && [ "$tp_only" == "1" ]; then exit 1 fi +if [ "$verbose" == "1" ]; then + VERBOSE="-v" +else + VERBOSE="" +fi + rm_if_exist() { echo "Attempting to remove $1" if [ -f $1 ]; then - rm -v $1 + rm $VERBOSE $1 elif [ -d 
$1 ]; then - rm -vr $1 + rm -r $VERBOSE $1 fi } @@ -141,10 +124,6 @@ if [ "$no_clean" == "0" ]; then rm_if_exist dist rm_if_exist build rm_if_exist deepspeed.egg-info - # remove apex build files - rm_if_exist third_party/apex/dist - rm_if_exist third_party/apex/build - rm_if_exist third_party/apex/apex.egg-info fi if [ "$pip_sudo" == "1" ]; then @@ -154,60 +133,25 @@ else fi if [ "$pip_mirror" != "" ]; then - PIP_INSTALL="pip install -v -i $pip_mirror" + PIP_INSTALL="pip install $VERBOSE -i $pip_mirror" else - PIP_INSTALL="pip install -v" + PIP_INSTALL="pip install $VERBOSE" fi + if [ ! -f $hostfile ]; then echo "No hostfile exists at $hostfile, installing locally" local_only=1 fi -if [ "$skip_requirements" == "0" ]; then - # Ensure dependencies are installed locally - $PIP_SUDO $PIP_INSTALL -r requirements/requirements.txt -fi - -# Build wheels -if [ "$third_party_install" == "1" ]; then - echo "Checking out sub-module(s)" - git submodule update --init --recursive - - echo "Building apex wheel" - cd third_party/apex - - if [ "$apex_commit" != "" ]; then - echo "Installing a non-standard version of apex at commit: $apex_commit" - git fetch - git checkout $apex_commit - fi - - python setup.py -v --cpp_ext --cuda_ext bdist_wheel - cd - - - echo "Installing apex locally so that deepspeed will build" - $PIP_SUDO pip uninstall -y apex - $PIP_SUDO $PIP_INSTALL third_party/apex/dist/apex*.whl -fi -if [ "$deepspeed_install" == "1" ]; then - echo "Building deepspeed wheel" - python setup.py -v bdist_wheel -fi +echo "Building deepspeed wheel" +python setup.py $VERBOSE bdist_wheel if [ "$local_only" == "1" ]; then - if [ "$deepspeed_install" == "1" ]; then - echo "Installing deepspeed" - $PIP_SUDO pip uninstall -y deepspeed - $PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl - # -I to exclude local directory files - python -I basic_install_test.py - if [ $? 
== 0 ]; then - echo "Installation is successful" - else - echo "Installation failed" - fi - fi + echo "Installing deepspeed" + $PIP_SUDO pip uninstall -y deepspeed + $PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl + ds_report else local_path=`pwd` if [ -f $hostfile ]; then @@ -216,28 +160,16 @@ else echo "hostfile not found, cannot proceed" exit 1 fi - export PDSH_RCMD_TYPE=ssh; + export PDSH_RCMD_TYPE=ssh tmp_wheel_path="/tmp/deepspeed_wheels" pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; else mkdir -pv $tmp_wheel_path; fi" pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/ - if [ "$skip_requirements" == "0" ]; then - pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt" - fi - if [ "$third_party_install" == "1" ]; then - pdsh -w $hosts "$PIP_SUDO pip uninstall -y apex" - pdcp -w $hosts third_party/apex/dist/apex*.whl $tmp_wheel_path/ - pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/apex*.whl" - pdsh -w $hosts 'python -c "import apex"' - fi - if [ "$deepspeed_install" == "1" ]; then - echo "Installing deepspeed" - pdsh -w $hosts "$PIP_SUDO pip uninstall -y deepspeed" - pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/ - pdcp -w $hosts basic_install_test.py $tmp_wheel_path/ - pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl" - pdsh -w $hosts "python $tmp_wheel_path/basic_install_test.py" - echo "Installation is successful" - fi - pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl $tmp_wheel_path/basic_install_test.py; rmdir $tmp_wheel_path; fi" + + echo "Installing deepspeed" + pdsh -w $hosts "$PIP_SUDO pip uninstall -y deepspeed" + pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/ + pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl" + pdsh -w $hosts "ds_report" + pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rmdir $tmp_wheel_path; fi" fi diff --git a/op_builder/__init__.py b/op_builder/__init__.py new file mode 100644 index 000000000000..40ebebc5685f --- /dev/null +++ b/op_builder/__init__.py @@ -0,0 +1,20 @@ +from .cpu_adam import CPUAdamBuilder +from .fused_adam import FusedAdamBuilder +from .fused_lamb import FusedLambBuilder +from .sparse_attn import SparseAttnBuilder +from .transformer import TransformerBuilder +from .stochastic_transformer import StochasticTransformerBuilder +from .utils import UtilsBuilder + +# TODO: infer this list instead of hard coded +# List of all available ops +__op_builders__ = [ + CPUAdamBuilder(), + FusedAdamBuilder(), + FusedLambBuilder(), + SparseAttnBuilder(), + TransformerBuilder(), + StochasticTransformerBuilder(), + UtilsBuilder() +] +ALL_OPS = {op.name: op for op in __op_builders__} diff --git a/op_builder/builder.py b/op_builder/builder.py new file mode 100644 index 000000000000..c1116fad007a --- /dev/null +++ b/op_builder/builder.py @@ -0,0 +1,245 @@ +import os +import time +import torch +import importlib +from pathlib import Path +import subprocess +from abc import ABC, abstractmethod + +YELLOW = '\033[93m' +END = '\033[0m' +WARNING = f"{YELLOW} [WARNING] {END}" + +DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions" + + +def assert_no_cuda_mismatch(): + import torch.utils.cpp_extension + cuda_home = torch.utils.cpp_extension.CUDA_HOME + assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" + # Ensure there is not a cuda version mismatch between torch and nvcc compiler + output = subprocess.check_output([cuda_home + "/bin/nvcc", + "-V"], + 
universal_newlines=True) + output_split = output.split() + release_idx = output_split.index("release") + release = output_split[release_idx + 1].replace(',', '').split(".") + # Ignore patch versions, only look at major + minor + installed_cuda_version = ".".join(release[:2]) + torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + # This is a show-stopping error, should probably not proceed past this + if installed_cuda_version != torch_cuda_version: + raise Exception( + f"Installed CUDA version {installed_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") + + +def assert_torch_info(torch_info): + install_torch_version = torch_info['version'] + install_cuda_version = torch_info['cuda_version'] + + current_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + current_torch_version = ".".join(torch.__version__.split('.')[:2]) + + if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version: + raise RuntimeError( + "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. Please re-install " + f"DeepSpeed or switch torch versions. DeepSpeed install versions: " + f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:" + f"torch={current_torch_version}, cuda={current_cuda_version}") + + +class OpBuilder(ABC): + def __init__(self, name): + self.name = name + self.jit_mode = False + + @abstractmethod + def absolute_name(self): + ''' + Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam + will be installed as something like: deepspeed/ops/adam/cpu_adam.so + ''' + pass + + @abstractmethod + def sources(self): + ''' + Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + pass + + def include_paths(self): + ''' + Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + return [] + + def nvcc_args(self): + ''' + Returns optional list of compiler flags to forward to nvcc when building CUDA sources + ''' + return [] + + def cxx_args(self): + ''' + Returns optional list of compiler flags to forward to the build + ''' + return [] + + def is_compatible(self): + ''' + Check if all non-python dependencies are satisfied to build this op + ''' + return True + + def python_requirements(self): + ''' + Override if op wants to define special dependencies, otherwise will + take self.name and load requirements-.txt if it exists. + ''' + path = f'requirements/requirements-{self.name}.txt' + requirements = [] + if os.path.isfile(path): + with open(path, 'r') as fd: + requirements = [r.strip() for r in fd.readlines()] + return requirements + + def command_exists(self, cmd): + if '|' in cmd: + cmds = cmd.split("|") + else: + cmds = [cmd] + valid = False + for cmd in cmds: + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + valid = valid or result.wait() == 0 + + if not valid and len(cmds) > 1: + print( + f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!" + ) + elif not valid and len(cmds) == 1: + print( + f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!" 
+ ) + return valid + + def warning(self, msg): + print(f"{WARNING} {msg}") + + def deepspeed_src_path(self, code_path): + if os.path.isabs(code_path): + return code_path + else: + return os.path.join(Path(__file__).parent.parent.absolute(), code_path) + + def builder(self): + from torch.utils.cpp_extension import CppExtension + return CppExtension(name=self.absolute_name(), + sources=self.sources(), + include_dirs=self.include_paths(), + extra_compile_args={'cxx': self.cxx_args()}) + + def load(self, verbose=True): + from ...git_version_info import installed_ops, torch_info + if installed_ops[self.name]: + # Ensure the op we're about to load was compiled with the same + # torch/cuda versions we are currently using at runtime. + if isinstance(self, CUDAOpBuilder): + assert_torch_info(torch_info) + + return importlib.import_module(self.absolute_name()) + else: + return self.jit_load(verbose) + + def jit_load(self, verbose=True): + if not self.is_compatible(): + raise RuntimeError( + f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue." + ) + try: + import ninja + except ImportError: + raise RuntimeError( + f"Unable to JIT load the {self.name} op due to ninja not being installed." + ) + + if isinstance(self, CUDAOpBuilder): + assert_no_cuda_mismatch() + + self.jit_mode = True + from torch.utils.cpp_extension import load + + # Ensure directory exists to prevent race condition in some cases + ext_path = os.path.join( + os.environ.get('TORCH_EXTENSIONS_DIR', + DEFAULT_TORCH_EXTENSION_PATH), + self.name) + os.makedirs(ext_path, exist_ok=True) + + start_build = time.time() + op_module = load( + name=self.name, + sources=[self.deepspeed_src_path(path) for path in self.sources()], + extra_include_paths=[ + self.deepspeed_src_path(path) for path in self.include_paths() + ], + extra_cflags=self.cxx_args(), + extra_cuda_cflags=self.nvcc_args(), + verbose=verbose) + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") + return op_module + + +class CUDAOpBuilder(OpBuilder): + def compute_capability_args(self, cross_compile_archs=['60', '61', '70']): + args = [] + if self.jit_mode: + # Compile for underlying architecture since we know it at runtime + CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability() + compute_capability = f"{CC_MAJOR}{CC_MINOR}" + args.append('-gencode') + args.append( + f'arch=compute_{compute_capability},code=compute_{compute_capability}') + else: + # Cross-compile mode, compile for various architectures + for compute_capability in cross_compile_archs: + args.append('-gencode') + args.append( + f'arch=compute_{compute_capability},code=compute_{compute_capability}' + ) + return args + + def version_dependent_macros(self): + # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + version_ge_1_1 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_3 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_5 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] + return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + + def is_compatible(self): + return super().is_compatible() + + def builder(self): + 
from torch.utils.cpp_extension import CUDAExtension + assert_no_cuda_mismatch() + return CUDAExtension(name=self.absolute_name(), + sources=self.sources(), + include_dirs=self.include_paths(), + extra_compile_args={ + 'cxx': self.cxx_args(), + 'nvcc': self.nvcc_args() + }) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py new file mode 100644 index 000000000000..d391301897e3 --- /dev/null +++ b/op_builder/cpu_adam.py @@ -0,0 +1,77 @@ +import os +import torch +import warnings +from .builder import CUDAOpBuilder + + +class CPUAdamBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/custom_cuda_kernel.cu'] + + def include_paths(self): + CUDA_INCLUDE = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include") + return ['csrc/includes', CUDA_INCLUDE] + + def available_vector_instructions(self): + try: + import cpufeature + except ImportError: + warnings.warn( + f'import cpufeature failed - CPU vector optimizations are not available for CPUAdam' + ) + return {} + + cpu_vector_instructions = {} + try: + cpu_vector_instructions = cpufeature.CPUFeature + except _: + warnings.warn( + f'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam' + ) + return {} + + return cpu_vector_instructions + + def cxx_args(self): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") + cpu_info = self.available_vector_instructions() + SIMD_WIDTH = '' + if 'Intel' in cpu_info.get('VendorId', ''): + if cpu_info.get('AVX512f', False): + SIMD_WIDTH = '-D__AVX512__' + elif cpu_info.get('AVX2', False): + SIMD_WIDTH = '-D__AVX256__' + + return [ + '-O3', + '-std=c++14', + f'-L{CUDA_LIB64}', + '-lcudart', + '-lcublas', + '-g', + '-Wno-reorder', + '-march=native', + '-fopenmp', + SIMD_WIDTH + ] + + def nvcc_args(self): + args = [ + '-O3', + '--use_fast_math', + '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__' + ] + args += self.compute_capability_args() + return args diff --git a/op_builder/fused_adam.py b/op_builder/fused_adam.py new file mode 100644 index 000000000000..4b43ff7f1cea --- /dev/null +++ b/op_builder/fused_adam.py @@ -0,0 +1,25 @@ +import torch +from .builder import CUDAOpBuilder + + +class FusedAdamBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/fused_adam_frontend.cpp', 'csrc/adam/multi_tensor_adam.cu'] + + def include_paths(self): + return ['csrc/includes'] + + def cxx_args(self): + return ['-O3'] + self.version_dependent_macros() + + def nvcc_args(self): + return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros() diff --git a/op_builder/fused_lamb.py b/op_builder/fused_lamb.py new file mode 100644 index 000000000000..272a9772abc3 --- /dev/null +++ b/op_builder/fused_lamb.py @@ -0,0 +1,25 @@ +import torch +from .builder import CUDAOpBuilder + + +class FusedLambBuilder(CUDAOpBuilder): + BUILD_VAR = 'DS_BUILD_FUSED_LAMB' + NAME = "fused_lamb" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lamb.{self.NAME}_op' + + def sources(self): + return 
['csrc/lamb/fused_lamb_cuda.cpp', 'csrc/lamb/fused_lamb_cuda_kernel.cu'] + + def include_paths(self): + return ['csrc/includes'] + + def cxx_args(self): + return ['-O3'] + self.version_dependent_macros() + + def nvcc_args(self): + return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros() diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py new file mode 100644 index 000000000000..4c716f859970 --- /dev/null +++ b/op_builder/sparse_attn.py @@ -0,0 +1,36 @@ +import torch +import warnings +from .builder import OpBuilder + + +class SparseAttnBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_SPARSE_ATTN" + NAME = "sparse_attn" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.sparse_attention.{self.NAME}_op' + + def sources(self): + return ['csrc/sparse_attention/utils.cpp'] + + def cxx_args(self): + return ['-O2', '-fopenmp'] + + def is_compatible(self): + # Check to see if llvm and cmake are installed since they are dependencies + required_commands = ['llvm-config|llvm-config-9', 'cmake'] + command_status = list(map(self.command_exists, required_commands)) + deps_compatible = all(command_status) + + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5 + if not torch_compatible: + self.warning( + f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}' + ) + + return super().is_compatible() and deps_compatible and torch_compatible diff --git a/op_builder/stochastic_transformer.py b/op_builder/stochastic_transformer.py new file mode 100644 index 000000000000..50dfea7c6698 --- /dev/null +++ b/op_builder/stochastic_transformer.py @@ -0,0 +1,18 @@ +import torch +from .transformer import TransformerBuilder + + +class StochasticTransformerBuilder(TransformerBuilder): + BUILD_VAR = "DS_BUILD_STOCHASTIC_TRANSFORMER" + NAME = "stochastic_transformer" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.transformer.{self.NAME}_op' + + def nvcc_args(self): + args = super().nvcc_args() + args.append('-D__STOCHASTIC_MODE__') + return args diff --git a/op_builder/transformer.py b/op_builder/transformer.py new file mode 100644 index 000000000000..2735b078fb98 --- /dev/null +++ b/op_builder/transformer.py @@ -0,0 +1,44 @@ +import torch +from .builder import CUDAOpBuilder + + +class TransformerBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_TRANSFORMER" + NAME = "transformer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.transformer.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/transformer/ds_transformer_cuda.cpp', + 'csrc/transformer/cublas_wrappers.cu', + 'csrc/transformer/transform_kernels.cu', + 'csrc/transformer/gelu_kernels.cu', + 'csrc/transformer/dropout_kernels.cu', + 'csrc/transformer/normalize_kernels.cu', + 'csrc/transformer/softmax_kernels.cu', + 'csrc/transformer/general_kernels.cu' + ] + + def include_paths(self): + return ['csrc/includes'] + + def nvcc_args(self): + args = [ + '-O3', + '--use_fast_math', + '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__' + ] + + return args + self.compute_capability_args() + + def cxx_args(self): + return ['-O3', '-std=c++14', '-g', '-Wno-reorder'] diff --git 
a/op_builder/utils.py b/op_builder/utils.py new file mode 100644 index 000000000000..1631a2cf18b2 --- /dev/null +++ b/op_builder/utils.py @@ -0,0 +1,15 @@ +from .builder import OpBuilder + + +class UtilsBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_UTILS" + NAME = "utils" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def sources(self): + return ['csrc/utils/flatten_unflatten.cpp'] diff --git a/requirements/requirements-sparse-attn.txt b/requirements/requirements-sparse_attn.txt similarity index 100% rename from requirements/requirements-sparse-attn.txt rename to requirements/requirements-sparse_attn.txt diff --git a/requirements/requirements.txt b/requirements/requirements.txt index d9881f4bc580..575a30ff5568 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,5 +2,6 @@ torch>=1.2 torchvision>=0.4.0 tqdm psutil -cpufeature tensorboardX==1.8 +ninja +cpufeature diff --git a/setup.py b/setup.py index 1364caab23f9..8560581bac73 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from setuptools import setup, find_packages from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension -VERSION = "0.3.0" +import op_builder def fetch_requirements(path): @@ -24,88 +24,33 @@ def fetch_requirements(path): return [r.strip() for r in fd.readlines()] -def available_vector_instructions(): - try: - import cpufeature - except ImportError: - warnings.warn( - f'import cpufeature failed - CPU vector optimizations are not available for CPUAdam' - ) - return {} - - cpu_vector_instructions = {} - try: - cpu_vector_instructions = cpufeature.CPUFeature - except _: - warnings.warn( - f'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam' - ) - return {} - - return cpu_vector_instructions - - install_requires = fetch_requirements('requirements/requirements.txt') -dev_requires = fetch_requirements('requirements/requirements-dev.txt') -sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt') +extras_require = { + '1bit_adam': fetch_requirements('requirements/requirements-1bit-adam.txt'), + 'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'), + 'dev': fetch_requirements('requirements/requirements-dev.txt'), +} # If MPI is available add 1bit-adam requirements if torch.cuda.is_available(): if shutil.which('ompi_info') or shutil.which('mpiname'): - onebit_adam_requires = fetch_requirements( - 'requirements/requirements-1bit-adam.txt') - onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}") - install_requires += onebit_adam_requires - -# Constants for each op -LAMB = "lamb" -TRANSFORMER = "transformer" -SPARSE_ATTN = "sparse-attn" -CPU_ADAM = "cpu-adam" - -cpu_vector_instructions = available_vector_instructions() - -# Build environment variables for custom builds -DS_BUILD_LAMB_MASK = 1 -DS_BUILD_TRANSFORMER_MASK = 10 -DS_BUILD_SPARSE_ATTN_MASK = 100 -DS_BUILD_CPU_ADAM_MASK = 1000 - -# Allow for build_cuda to turn on or off all ops -DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK -DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS - -# Set default of each op based on if build_cuda is set -OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS -DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM', 0)) * DS_BUILD_CPU_ADAM_MASK -DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', 
OP_DEFAULT)) * DS_BUILD_LAMB_MASK -DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER', - OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK -DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN', - OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK - -# Final effective mask is the bitwise OR of each op -BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN - | DS_BUILD_CPU_ADAM) - -install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, CPU_ADAM], False) -if BUILD_MASK & DS_BUILD_LAMB: - install_ops[LAMB] = True -if BUILD_MASK & DS_BUILD_CPU_ADAM: - install_ops[CPU_ADAM] = True -if BUILD_MASK & DS_BUILD_TRANSFORMER: - install_ops[TRANSFORMER] = True -if BUILD_MASK & DS_BUILD_SPARSE_ATTN: - install_ops[SPARSE_ATTN] = True -if len(install_ops) == 0: - print("Building without any cuda/cpp extensions") -print(f'BUILD_MASK={BUILD_MASK}, install_ops={install_ops}') + cupy = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}" + extras_require['1bit_adam'].append(cupy) + +# Make an [all] extra that installs all needed dependencies +all_extras = set() +for extra in extras_require.items(): + for req in extra[1]: + all_extras.add(req) +extras_require['all'] = list(all_extras) cmdclass = {} + +# For any pre-installed ops force disable ninja cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False) -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) +TORCH_MAJOR = torch.__version__.split('.')[0] +TORCH_MINOR = torch.__version__.split('.')[1] if not torch.cuda.is_available(): # Fix to allow docker buils, similar to https://github.com/NVIDIA/apex/issues/486 @@ -116,230 +61,118 @@ def available_vector_instructions(): if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" -# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 -version_ge_1_1 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): - version_ge_1_1 = ['-DVERSION_GE_1_1'] -version_ge_1_3 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - version_ge_1_3 = ['-DVERSION_GE_1_3'] -version_ge_1_5 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): - version_ge_1_5 = ['-DVERSION_GE_1_5'] -version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 - -SIMD_WIDTH = '' -if cpu_vector_instructions.get('AVX512f', False): - SIMD_WIDTH = '-D__AVX512__' -elif cpu_vector_instructions.get('AVX2', False): - SIMD_WIDTH = '-D__AVX256__' -print("SIMD_WIDTH = ", SIMD_WIDTH) - ext_modules = [] -## Lamb ## -if BUILD_MASK & DS_BUILD_LAMB: - ext_modules.append( - CUDAExtension(name='deepspeed.ops.lamb.fused_lamb_cuda', - sources=[ - 'csrc/lamb/fused_lamb_cuda.cpp', - 'csrc/lamb/fused_lamb_cuda_kernel.cu' - ], - include_dirs=['csrc/includes'], - extra_compile_args={ - 'cxx': [ - '-O3', - ] + version_dependent_macros, - 'nvcc': ['-O3', - '--use_fast_math'] + version_dependent_macros - })) - -## Adam ## -if BUILD_MASK & DS_BUILD_CPU_ADAM: - ext_modules.append( - CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op', - sources=[ - 'csrc/adam/cpu_adam.cpp', - 'csrc/adam/custom_cuda_kernel.cu', - ], - include_dirs=['csrc/includes', - '/usr/local/cuda/include'], - extra_compile_args={ - 'cxx': [ - '-O3', - '-std=c++14', - '-L/usr/local/cuda/lib64', - '-lcudart', - '-lcublas', - '-g', - '-Wno-reorder', - '-march=native', - '-fopenmp', - SIMD_WIDTH - ], - 'nvcc': [ - '-O3', - 
'--use_fast_math', - '-gencode', - 'arch=compute_61,code=compute_61', - '-gencode', - 'arch=compute_70,code=compute_70', - '-std=c++14', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__' - ] - })) - -## Transformer ## -if BUILD_MASK & DS_BUILD_TRANSFORMER: - ext_modules.append( - CUDAExtension(name='deepspeed.ops.transformer.transformer_cuda', - sources=[ - 'csrc/transformer/ds_transformer_cuda.cpp', - 'csrc/transformer/cublas_wrappers.cu', - 'csrc/transformer/transform_kernels.cu', - 'csrc/transformer/gelu_kernels.cu', - 'csrc/transformer/dropout_kernels.cu', - 'csrc/transformer/normalize_kernels.cu', - 'csrc/transformer/softmax_kernels.cu', - 'csrc/transformer/general_kernels.cu' - ], - include_dirs=['csrc/includes'], - extra_compile_args={ - 'cxx': ['-O3', - '-std=c++14', - '-g', - '-Wno-reorder'], - 'nvcc': [ - '-O3', - '--use_fast_math', - '-gencode', - 'arch=compute_61,code=compute_61', - '-gencode', - 'arch=compute_60,code=compute_60', - '-gencode', - 'arch=compute_70,code=compute_70', - '-std=c++14', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__' - ] - })) - ext_modules.append( - CUDAExtension(name='deepspeed.ops.transformer.stochastic_transformer_cuda', - sources=[ - 'csrc/transformer/ds_transformer_cuda.cpp', - 'csrc/transformer/cublas_wrappers.cu', - 'csrc/transformer/transform_kernels.cu', - 'csrc/transformer/gelu_kernels.cu', - 'csrc/transformer/dropout_kernels.cu', - 'csrc/transformer/normalize_kernels.cu', - 'csrc/transformer/softmax_kernels.cu', - 'csrc/transformer/general_kernels.cu' - ], - include_dirs=['csrc/includes'], - extra_compile_args={ - 'cxx': ['-O3', - '-std=c++14', - '-g', - '-Wno-reorder'], - 'nvcc': [ - '-O3', - '--use_fast_math', - '-gencode', - 'arch=compute_61,code=compute_61', - '-gencode', - 'arch=compute_60,code=compute_60', - '-gencode', - 'arch=compute_70,code=compute_70', - '-std=c++14', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', - '-D__STOCHASTIC_MODE__' - ] - })) +from op_builder import ALL_OPS + +# Default to pre-install kernels to false so we rely on JIT +BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', 0)) +print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}") def command_exists(cmd): - if '|' in cmd: - cmds = cmd.split("|") - else: - cmds = [cmd] - valid = False - for cmd in cmds: - result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) - valid = valid or result.wait() == 0 - return valid - - -## Sparse transformer ## -if BUILD_MASK & DS_BUILD_SPARSE_ATTN: - # Check to see if llvm and cmake are installed since they are dependencies - required_commands = ['llvm-config|llvm-config-9', 'cmake'] - - command_status = list(map(command_exists, required_commands)) - if not all(command_status): - zipped_status = list(zip(required_commands, command_status)) - warnings.warn( - f'Missing non-python requirements, please install the missing packages: {zipped_status}' - ) - warnings.warn( - 'Skipping sparse attention installation due to missing required packages') - # remove from installed ops list - install_ops[SPARSE_ATTN] = False - elif TORCH_MAJOR == 1 and TORCH_MINOR >= 5: - ext_modules.append( - CppExtension(name='deepspeed.ops.sparse_attention.cpp_utils', - sources=['csrc/sparse_attention/utils.cpp'], - extra_compile_args={'cxx': ['-O2', - '-fopenmp']})) - # Add sparse attention requirements - install_requires += sparse_attn_requires - else: - warnings.warn('Unable to 
meet requirements to install sparse attention') - # remove from installed ops list - install_ops[SPARSE_ATTN] = False - -# Add development requirements -install_requires += dev_requires + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + return result.wait() == 0 + + +def op_enabled(op_name): + assert hasattr(ALL_OPS[op_name], 'BUILD_VAR'), \ + f"{op_name} is missing BUILD_VAR field" + env_var = ALL_OPS[op_name].BUILD_VAR + return int(os.environ.get(env_var, BUILD_OP_DEFAULT)) + + +install_ops = dict.fromkeys(ALL_OPS.keys(), False) +for op_name, builder in ALL_OPS.items(): + op_compatible = builder.is_compatible() + + # If op is compatible update install reqs so it can potentially build/run later + if op_compatible: + reqs = builder.python_requirements() + install_requires += builder.python_requirements() + + # If op install enabled, add builder to extensions + if op_enabled(op_name) and op_compatible: + install_ops[op_name] = op_enabled(op_name) + ext_modules.append(builder.builder()) + +compatible_ops = {op_name: op.is_compatible() for (op_name, op) in ALL_OPS.items()} + +print(f'Install Ops={install_ops}') # Write out version/git info git_hash_cmd = "git rev-parse --short HEAD" git_branch_cmd = "git rev-parse --abbrev-ref HEAD" -if command_exists('git'): - result = subprocess.check_output(git_hash_cmd, shell=True) - git_hash = result.decode('utf-8').strip() - result = subprocess.check_output(git_branch_cmd, shell=True) - git_branch = result.decode('utf-8').strip() +if command_exists('git') and 'DS_BUILD_STRING' not in os.environ: + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" else: git_hash = "unknown" git_branch = "unknown" -print(f"version={VERSION}+{git_hash}, git_hash={git_hash}, git_branch={git_branch}") + +# Parse the DeepSpeed version string from version.txt +version_str = open('version.txt', 'r').read().strip() + +# Build specifiers like .devX can be added at install time. Otherwise, add the git hash. 
+# example: DS_BUILD_STR=".dev20201022" python setup.py sdist bdist_wheel +#version_str += os.environ.get('DS_BUILD_STRING', f'+{git_hash}') + +# Building wheel for distribution, update version file + +if 'DS_BUILD_STRING' in os.environ: + # Build string env specified, probably building for distribution + with open('build.txt', 'w') as fd: + fd.write(os.environ.get('DS_BUILD_STRING')) + version_str += os.environ.get('DS_BUILD_STRING') +elif os.path.isfile('build.txt'): + # build.txt exists, probably installing from distribution + with open('build.txt', 'r') as fd: + version_str += fd.read().strip() +else: + # None of the above, probably installing from source + version_str += f'+{git_hash}' + +torch_version = ".".join([TORCH_MAJOR, TORCH_MINOR]) +cuda_version = ".".join(torch.version.cuda.split('.')[:2]) +torch_info = {"version": torch_version, "cuda_version": cuda_version} + +print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}") with open('deepspeed/git_version_info_installed.py', 'w') as fd: - fd.write(f"version='{VERSION}+{git_hash}'\n") + fd.write(f"version='{version_str}'\n") fd.write(f"git_hash='{git_hash}'\n") fd.write(f"git_branch='{git_branch}'\n") fd.write(f"installed_ops={install_ops}\n") + fd.write(f"compatible_ops={compatible_ops}\n") + fd.write(f"torch_info={torch_info}\n") print(f'install_requires={install_requires}') +print(f'compatible_ops={compatible_ops}') +print(f'ext_modules={ext_modules}') setup(name='deepspeed', - version=f"{VERSION}+{git_hash}", + version=version_str, description='DeepSpeed library', author='DeepSpeed Team', author_email='deepspeed@microsoft.com', url='http://deepspeed.ai', install_requires=install_requires, + extras_require=extras_require, packages=find_packages(exclude=["docker", - "third_party", - "csrc"]), - package_data={'deepspeed.ops.sparse_attention.trsrc': ['*.tr']}, - scripts=['bin/deepspeed', - 'bin/deepspeed.pt', - 'bin/ds', - 'bin/ds_ssh'], + "third_party"]), + include_package_data=True, + scripts=[ + 'bin/deepspeed', + 'bin/deepspeed.pt', + 'bin/ds', + 'bin/ds_ssh', + 'bin/ds_report' + ], classifiers=[ 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', diff --git a/tests/unit/modelingpreln.py b/tests/unit/modelingpreln.py index e9e00727353f..8fcae8bcca18 100755 --- a/tests/unit/modelingpreln.py +++ b/tests/unit/modelingpreln.py @@ -363,10 +363,18 @@ def __init__(self, hidden_size, eps=1e-12): self.variance_epsilon = eps def forward(self, x): + pdtype = x.dtype + x = x.float() u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x + self.bias + return self.weight * x.to(pdtype) + self.bias + + #def forward(self, x): + # u = x.mean(-1, keepdim=True) + # s = (x - u).pow(2).mean(-1, keepdim=True) + # x = (x - u) / torch.sqrt(s + self.variance_epsilon) + # return self.weight * x + self.bias class BertEmbeddings(nn.Module): diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index c7dee1fa2c53..645bbf94e89e 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -12,6 +12,8 @@ from deepspeed.runtime.pipe.topology import * PipeTopo = PipeDataParallelTopology +from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder + import argparse import pytest import json @@ -152,8 +154,8 @@ def checkpoint_correctness_verification(args, compare_lr_scheduler_states(trained_model, loaded_model) -@pytest.mark.skipif(not 
deepspeed.ops.__installed_ops__['lamb'], - reason="lamb is not installed") +@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="lamb is not compatible") def test_checkpoint_unfused_optimizer(tmpdir): config_dict = { "train_batch_size": 2, @@ -264,11 +266,11 @@ def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_sta 'Adam'), (2, True, - 'deepspeed_adam'), + 'Adam'), ]) def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") config_dict = { "train_batch_size": 2, @@ -320,14 +322,14 @@ def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_stat "Adam"), (2, True, - 'deepspeed_adam'), + 'Adam'), ]) def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") config_dict = { "train_batch_size": 2, @@ -385,11 +387,11 @@ def _test_checkpoint_zero_no_optimizer(args, 'Adam'), (2, True, - 'deepspeed_adam'), + 'Adam'), ]) def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") config_dict = { "train_batch_size": 2, @@ -459,11 +461,11 @@ def _test_checkpoint_lr_scheduler(args, 'Adam'), (2, True, - 'deepspeed_adam'), + 'Adam'), ]) def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") config_dict = { "train_batch_size": 2, diff --git a/tests/unit/test_cpu_adam.py b/tests/unit/test_cpu_adam.py index 5130dc72fa74..2ae1f8201328 100755 --- a/tests/unit/test_cpu_adam.py +++ b/tests/unit/test_cpu_adam.py @@ -1,16 +1,16 @@ import argparse import torch -import apex import time import numpy as np import pytest import copy import deepspeed -if not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed", allow_module_level=True) -else: - from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adam import FusedAdam +from deepspeed.ops.op_builder import CPUAdamBuilder + +if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") def check_equal(first, second, atol=1e-2, verbose=False): @@ -32,6 +32,7 @@ def check_equal(first, second, atol=1e-2, verbose=False): (1048576), ]) # yapf: disable def test_cpu_adam_opt(model_size): + from deepspeed.ops.adam import DeepSpeedCPUAdam device = 'cpu' rng_state = torch.get_rng_state() param = torch.nn.Parameter(torch.randn(model_size, device=device)) @@ -42,7 +43,7 @@ def test_cpu_adam_opt(model_size): param2 = torch.nn.Parameter(param2_data) optimizer1 = torch.optim.AdamW([param1]) - optimizer2 = 
apex.optimizers.FusedAdam([param2]) + optimizer2 = FusedAdam([param2]) optimizer = DeepSpeedCPUAdam([param]) for i in range(10): diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py index 87010f6ea037..d4af962055ef 100755 --- a/tests/unit/test_cuda_backward.py +++ b/tests/unit/test_cuda_backward.py @@ -16,8 +16,8 @@ import sys -if not deepspeed.ops.__installed_ops__['transformer']: - pytest.skip("transformer kernels are not installed", allow_module_level=True) +#if not deepspeed.ops.__installed_ops__['transformer']: +# pytest.skip("transformer kernels are not installed", allow_module_level=True) def check_equal(first, second, atol=1e-2, verbose=False): @@ -254,6 +254,7 @@ def run_backward(ds_config, atol=1e-2, verbose=False): check_equal(base_grads, ds_grads, atol=atol, verbose=verbose) +#test_backward[3-1024-120-16-24-True-True-0.05] @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', [ (3,1024,120,16,24,True,False, 0.05), diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index 99604ecf31f5..6103ea7e12cf 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -16,8 +16,8 @@ import sys -if not deepspeed.ops.__installed_ops__['transformer']: - pytest.skip("transformer kernels are not installed", allow_module_level=True) +#if not deepspeed.ops.__installed_ops__['transformer']: +# pytest.skip("transformer kernels are not installed", allow_module_level=True) def check_equal(first, second, atol=1e-2, verbose=False): diff --git a/tests/unit/test_dynamic_loss_scale.py b/tests/unit/test_dynamic_loss_scale.py index 5bfe3353dcbb..7575d6b49454 100755 --- a/tests/unit/test_dynamic_loss_scale.py +++ b/tests/unit/test_dynamic_loss_scale.py @@ -8,9 +8,6 @@ from common import distributed_test from simple_model import SimpleModel, args_from_dict -lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'], - reason="lamb is not installed") - def run_model_step(model, gradient_list): for value in gradient_list: @@ -168,7 +165,6 @@ def _test_fused_some_overflow(args): _test_fused_some_overflow(args) -@lamb_available def test_unfused_no_overflow(tmpdir): config_dict = { "train_batch_size": 1, @@ -212,7 +208,6 @@ def _test_unfused_no_overflow(args): _test_unfused_no_overflow(args) -@lamb_available def test_unfused_all_overflow(tmpdir): config_dict = { "train_batch_size": 1, @@ -258,7 +253,6 @@ def _test_unfused_all_overflow(args): _test_unfused_all_overflow(args) -@lamb_available def test_unfused_some_overflow(tmpdir): config_dict = { "train_batch_size": 1, diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index 47344bfbc9a1..30d53a00251f 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -1,18 +1,21 @@ import torch -import apex import deepspeed import argparse import pytest import json import os +from deepspeed.ops.adam import FusedAdam from common import distributed_test from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict -lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'], - reason="lamb is not installed") +try: + from apex import amp + _amp_available = True +except ImportError: + _amp_available = False +amp_available = pytest.mark.skip(_amp_available, reason="apex/amp is not installed") -@lamb_available def test_lamb_fp32_grad_clip(tmpdir): config_dict = { "train_batch_size": 2, @@ -48,7 +51,6 @@ def _test_lamb_fp32_grad_clip(args, model, hidden_dim): 
_test_lamb_fp32_grad_clip(args=args, model=model, hidden_dim=hidden_dim) -@lamb_available def test_lamb_fp16_basic(tmpdir): config_dict = { "train_batch_size": 2, @@ -86,7 +88,6 @@ def _test_lamb_fp16_basic(args, model, hidden_dim): _test_lamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) -@lamb_available def test_lamb_fp16_empty_grad(tmpdir): config_dict = { "train_batch_size": 2, @@ -234,8 +235,8 @@ def _test_adamw_fp16_empty_grad(args, model, hidden_dim): True), ]) def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: + # pytest.skip("cpu-adam is not installed") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -302,8 +303,8 @@ def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim): True), ]) def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: + # pytest.skip("cpu-adam is not installed") config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -402,8 +403,8 @@ def _test_zero_static_scale(args): True), ]) def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: + # pytest.skip("cpu-adam is not installed") config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -442,8 +443,8 @@ def _test_zero_allow_untested_optimizer(args): True), ]) def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: - pytest.skip("cpu-adam is not installed") + #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']: + # pytest.skip("cpu-adam is not installed") config_dict = { "train_micro_batch_size_per_gpu": 1, "gradient_accumulation_steps": 1, @@ -489,6 +490,7 @@ def _test_zero_empty_partition(args): _test_zero_empty_partition(args) +@amp_available def test_adam_amp_basic(tmpdir): config_dict = {"train_batch_size": 1, "steps_per_print": 1, "amp": {"enabled": True}} args = args_from_dict(tmpdir, config_dict) @@ -514,7 +516,7 @@ def _test_adam_amp_basic(args, model, hidden_dim): _test_adam_amp_basic(args=args, model=model, hidden_dim=hidden_dim) -@lamb_available +@amp_available def test_lamb_amp_basic(tmpdir): config_dict = { "train_batch_size": 2, @@ -552,6 +554,7 @@ def _test_lamb_amp_basic(args, model, hidden_dim): _test_lamb_amp_basic(args=args, model=model, hidden_dim=hidden_dim) +@amp_available def test_adam_amp_o2(tmpdir): config_dict = { "train_batch_size": 2, @@ -590,6 +593,7 @@ def _test_adam_amp_o2(args, model, hidden_dim): _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim) +@amp_available def test_adam_amp_o2_empty_grad(tmpdir): config_dict = { "train_batch_size": 2, @@ -630,11 +634,11 @@ def _test_adam_amp_o2_empty_grad(args, model, hidden_dim): @pytest.mark.parametrize('zero_stage, optimizer_constructor', [(1, - apex.optimizers.FusedAdam), + FusedAdam), (2, torch.optim.Adam), (2, - apex.optimizers.FusedAdam)]) + FusedAdam)]) def test_zero_supported_client_optimizer(tmpdir, zero_stage, 
optimizer_constructor): config_dict = { "train_batch_size": 2, diff --git a/tests/unit/test_sparse_attention.py b/tests/unit/test_sparse_attention.py index cb68c20cf731..80eb1b31b596 100644 --- a/tests/unit/test_sparse_attention.py +++ b/tests/unit/test_sparse_attention.py @@ -6,9 +6,11 @@ import pytest import torch import deepspeed +from deepspeed.ops.op_builder import SparseAttnBuilder -if not deepspeed.ops.__installed_ops__['sparse-attn']: - pytest.skip("cpu-adam is not installed", allow_module_level=True) +if not deepspeed.ops.__compatible_ops__[SparseAttnBuilder.NAME]: + pytest.skip("sparse attention op is not compatible on this system", + allow_module_level=True) def test_sparse_attention_module_availability(): @@ -236,7 +238,7 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo def _skip_on_cuda_compatability(): - pytest.skip("Skip these tests for now until we get our docker image fixed.") + #pytest.skip("Skip these tests for now until we get our docker image fixed.") if torch.cuda.get_device_capability()[0] != 7: pytest.skip("needs compute capability 7; v100") cuda_major = int(torch.version.cuda.split('.')[0]) * 10 diff --git a/third_party/apex b/third_party/apex deleted file mode 160000 index 494f8ab3fc1b..000000000000 --- a/third_party/apex +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 494f8ab3fc1b0b26949a3bcbb2bcac78008d48c1 diff --git a/version.txt b/version.txt new file mode 100644 index 000000000000..9e11b32fcaa9 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +0.3.1