Skip to content

Commit

Permalink
Disable default installation of CPU Adam (microsoft#450)
Browse files Browse the repository at this point in the history
* Disable default installation of CPU Adam

* Handle cpufeature import/use errors separately
  • Loading branch information
tjruwase authored Sep 29, 2020
1 parent 6f28ea3 commit 7b8be2a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam

from deepspeed.utils import logger
#Toggle this to true to enable correctness test
Expand Down Expand Up @@ -1416,6 +1415,7 @@ def step(self, closure=None):
#torch.set_num_threads(12)
timers('optimizer_step').start()
if self.deepspeed_adam_offload:
from deepspeed.ops.adam import DeepSpeedCPUAdam
self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
#self.optimizer.step()
#for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
Expand Down
38 changes: 27 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import shutil
import subprocess
import warnings
import cpufeature
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension

Expand All @@ -25,6 +24,27 @@ def fetch_requirements(path):
return [r.strip() for r in fd.readlines()]


def available_vector_instructions():
    """Return the CPU feature table from the optional ``cpufeature`` package.

    The probe is best-effort: the build must still proceed on machines where
    ``cpufeature`` is not installed or cannot report its feature table.

    Returns:
        The object exposed as ``cpufeature.CPUFeature`` when it can be read,
        otherwise an empty dict. Callers query it with ``.get(name, False)``,
        so an empty dict means "no vector instructions available".
    """
    try:
        # Imported lazily so that a missing package only disables the
        # vector optimizations instead of breaking the whole setup script.
        import cpufeature
    except ImportError:
        warnings.warn(
            'import cpufeature failed - CPU vector optimizations are not available for CPUAdam'
        )
        return {}

    try:
        return cpufeature.CPUFeature
    except Exception:
        # Deliberately broad: whatever goes wrong while reading the feature
        # table, fall back to "no vector instructions" rather than abort.
        # (A handler naming an undefined identifier would itself raise
        # NameError here and defeat the fallback.)
        warnings.warn(
            'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam'
        )
        return {}


# Dependency lists, one per requirements file. Each value is whatever
# fetch_requirements() returns for that path (a list of stripped lines).
install_requires = fetch_requirements('requirements/requirements.txt')
dev_requires = fetch_requirements('requirements/requirements-dev.txt')
sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
Expand All @@ -43,29 +63,26 @@ def fetch_requirements(path):
SPARSE_ATTN = "sparse-attn"
CPU_ADAM = "cpu-adam"

cpu_vector_instructions = available_vector_instructions()

# Build environment variables for custom builds
# NOTE(review): this span comes from a diff view, so superseded and
# replacement statements appear together (two DS_BUILD_ALL_OPS assignments,
# two DS_BUILD_CPU_ADAM assignments). Read top to bottom, the later one wins.
# NOTE(review): the mask values are decimal literals, not distinct powers of
# two, so their bit patterns overlap (1000 & 10 == 8, 1000 & 100 == 96).
# A `mask & flag` test on the OR-ed result could match the wrong op -
# confirm how BUILD_MASK is consumed.
DS_BUILD_LAMB_MASK = 1
DS_BUILD_TRANSFORMER_MASK = 10
DS_BUILD_SPARSE_ATTN_MASK = 100
DS_BUILD_CPU_ADAM_MASK = 1000
DS_BUILD_AVX512_MASK = 10000

# Allow for build_cuda to turn on or off all ops
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK | DS_BUILD_AVX512_MASK
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK
# DS_BUILD_CUDA env var (default 1) scales the all-ops mask; 0 yields 0.
DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS

# Set default of each op based on if build_cuda is set
# OP_DEFAULT is True only when DS_BUILD_CUDA equals the full all-ops mask.
OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM',
OP_DEFAULT)) * DS_BUILD_CPU_ADAM_MASK
# This later assignment defaults CPU Adam to 0 (not built) unless the
# DS_BUILD_CPU_ADAM environment variable is set.
DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM', 0)) * DS_BUILD_CPU_ADAM_MASK
DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK
# NOTE(review): this statement reads the module-level name `cpufeature`,
# which is only imported inside available_vector_instructions() in the code
# shown here - presumably a superseded line; confirm against the final file.
DS_BUILD_AVX512 = int(os.environ.get(
'DS_BUILD_AVX512',
cpufeature.CPUFeature['AVX512f'])) * DS_BUILD_AVX512_MASK

# Final effective mask is the bitwise OR of each op
BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN
Expand Down Expand Up @@ -111,11 +128,10 @@ def fetch_requirements(path):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5

cpu_info = cpufeature.CPUFeature
SIMD_WIDTH = ''
if cpu_info['AVX512f'] and DS_BUILD_AVX512:
if cpu_vector_instructions.get('AVX512f', False):
SIMD_WIDTH = '-D__AVX512__'
elif cpu_info['AVX2']:
elif cpu_vector_instructions.get('AVX2', False):
SIMD_WIDTH = '-D__AVX256__'
print("SIMD_WIDTH = ", SIMD_WIDTH)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import copy

import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam

if not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed", allow_module_level=True)
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam


def check_equal(first, second, atol=1e-2, verbose=False):
Expand Down

0 comments on commit 7b8be2a

Please sign in to comment.