Sparse attn + ops/runtime refactor + v0.3.0 #343

Merged (1 commit) on Sep 2, 2020
2 changes: 1 addition & 1 deletion .clang-format
@@ -20,7 +20,7 @@ AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
2 changes: 1 addition & 1 deletion .gitignore
@@ -8,7 +8,7 @@ deepspeed/git_version_info.py
# Build + installation data
build/
dist/
fused_lamb_*.so
*.so
deepspeed.egg-info/

# Website
5 changes: 3 additions & 2 deletions Dockerfile
@@ -8,7 +8,8 @@ RUN apt-get update && \
software-properties-common \
openssh-client openssh-server \
pdsh curl sudo net-tools \
vim iputils-ping wget
vim iputils-ping wget \
llvm-9-dev cmake

##############################################################################
# Installation Latest Git
@@ -85,7 +86,7 @@ RUN mkdir -p ${STAGE_DIR} && \
dpkg -i ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_TAG}_all.deb

##############################################################################
## Ucomment and set SSH Daemon port
## SSH daemon port inside container cannot conflict with host OS port
###############################################################################
ENV SSH_PORT=2222
RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \
128 changes: 78 additions & 50 deletions azure-pipelines.yml
@@ -1,49 +1,83 @@

jobs:
- job: Default
- job: DeepSpeed_Tests
timeoutInMinutes: 360
pool:
name: 'GPU_testing'
name: 'DS_testing'

strategy:
matrix:
Python36:
PyTorch12-CUDA100:
python.version: '3.6'
#Python35:
# python.version: '3.5'
#Python37:
cuda.version: '10.0'
pytorch.version: '1.2'
torchvision.version: '0.4.0'
runmodeltests: true
#PyTorch15-CUDA101:
# python.version: '3.7'
#Python38:
# python.version: '3.8'
# cuda.version: '10.1'
# pytorch.version: '1.5.0+cu101'
# torchvision.version: '0.6.0+cu101'
# runmodeltests: true
##PyTorch15-CUDA102:
# python.version: '3.7'
# cuda.version: '10.2'
# pytorch.version: '1.5'
# torchvision.version: '0.6.1'
# runmodeltests: true

variables:
conda_env: 'ds_test_py$(python.version)_cuda$(cuda.version)_pytorch$(pytorch.version)'

steps:
- task: UsePythonVersion@0
inputs:
versionSpec: '$(python.version)'
addToPath: true
architecture: 'x64'
displayName: 'Use Python $(python.version)'
# Unfortunately nvidia's nvcc_linux-64=<version> seems to install 10.1 regardless?
# Most of this complexity is a workaround to get the compiler toolchain to match the
# cudatoolkit runtime
- script: |
conda create --force --yes -n $(conda_env) python=$(python.version) cudatoolkit=$(cuda.version)
source activate $(conda_env)
conda install -q --yes conda
conda install -q --yes pip
conda install -q --yes gxx_linux-64
if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'

# Manually install torch/torchvision first to enforce versioning.
- script: |
python -m pip install --upgrade pip
pip install --user -r requirements.txt
./install.sh --pip_sudo
displayName: 'Install dependencies'
source activate $(conda_env)
pip install --progress-bar=off torch==$(pytorch.version) torchvision==$(torchvision.version)
#-f https://download.pytorch.org/whl/torch_stable.html
./install.sh --local_only
#python -I basic_install_test.py
displayName: 'Install DeepSpeed'

- script: |
pre-commit run --all-files
displayName: 'Formatting checks'
source activate $(conda_env)
which python
python --version
which nvcc
nvcc --version
which deepspeed
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
python -c "import deepspeed; print('deepspeed:', deepspeed.__version__)"
displayName: 'Show environment'


- script: |
pytest --forked --verbose tests/unit/
source activate $(conda_env)
pytest --durations=0 --forked --verbose -x tests/unit/
displayName: 'Unit tests'

- script: |
source activate $(conda_env)
ln -s /data/Megatron-LM/data DeepSpeedExamples/Megatron-LM/
pip install --user -r DeepSpeedExamples/Megatron-LM/requirements.txt
pip install --progress-bar=off -r DeepSpeedExamples/Megatron-LM/requirements.txt
cd tests/model/
pytest -s run_sanity_check.py
rm -rf BingBertSquad/baseline
rm -rf Megatron_GPT2/baseline
pytest --durations=0 -s run_sanity_check.py
condition: and(succeeded(), eq(variables['runmodeltests'], true))
displayName: 'Model tests'

#BingBertSquad logs
@@ -52,35 +86,29 @@ jobs:
targetPath: '$(Build.SourcesDirectory)/tests/model/BingBertSquad/test/'
artifactName: BingBertSquad_logs
displayName: 'BingBertSquad log uploads'
condition: always()
condition: eq(variables['runmodeltests'], true)

# Megatron test logs
#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/test/'
# artifactName: Megatron_GPT2_logs
# displayName: 'Megatron GPT2 log uploads'
# condition: always()

#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/checkpoint_test_logs/'
# artifactName: Megatron_GPT2_checkpoint_logs
# displayName: 'Megatron GPT2 checkpoint log uploads'
# condition: always()
- job: Code_Quality_Checks
pool:
name: 'DS_testing'
variables:
conda_env: 'ds_codetest'

steps:
- script: |
conda create --force --yes -n $(conda_env) python=3.7
source activate $(conda_env)
displayName: 'Create code test environment'

#BingBert logs
#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/pretrain_test/'
# artifactName: BingBert_pretrain_logs
# displayName: 'BingBert pretrain logs'
# condition: always()
- script: |
source activate $(conda_env)
pip install pre-commit
pre-commit run --all-files
displayName: 'Formatting checks'

#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/checkpoint_test_logs/'
# artifactName: BingBert_checkpoint_logs
# displayName: 'BingBert checkpoint logs'
# condition: always()
- script: |
source activate $(conda_env)
pip install pylint
pylint --exit-zero deepspeed/
displayName: 'Code linter'
5 changes: 3 additions & 2 deletions basic_install_test.py
@@ -18,7 +18,7 @@
raise err

try:
fused_lamb = importlib.import_module('deepspeed_lamb_cuda')
fused_lamb = importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
print('deepspeed fused lamb kernels successfully installed')
except Exception as err:
raise err
@@ -30,7 +30,8 @@
print("using new-style apex")

try:
ds_transformer = importlib.import_module('deepspeed_transformer_cuda')
ds_transformer = importlib.import_module(
'deepspeed.ops.transformer.transformer_cuda')
print('deepspeed transformer kernels successfully installed')
except Exception as err:
raise err
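
As an aside, the probe pattern in this file generalizes beyond two kernels. Below is a minimal sketch (not part of the diff; the probe_ops helper is hypothetical, and only the module paths are taken from the changes above) that reports every compiled op instead of raising on the first failure:

import importlib

# Module paths as renamed by this PR; availability depends on how
# DeepSpeed was built (CUDA kernels are absent on CPU-only installs).
COMPILED_OPS = [
    'deepspeed.ops.lamb.fused_lamb_cuda',
    'deepspeed.ops.transformer.transformer_cuda',
]

def probe_ops(paths=COMPILED_OPS):
    """Return {module_path: imported_ok} for each compiled kernel."""
    status = {}
    for path in paths:
        try:
            importlib.import_module(path)
            status[path] = True
        except ImportError:
            status[path] = False
    return status

if __name__ == '__main__':
    for path, ok in probe_ops().items():
        print(path, '->', 'OK' if ok else 'MISSING')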
2 changes: 1 addition & 1 deletion bin/ds
@@ -1,6 +1,6 @@
#!/usr/bin/env python

from deepspeed.pt.deepspeed_run import main
from deepspeed.launcher.runner import main

if __name__ == '__main__':
main()
120 changes: 120 additions & 0 deletions csrc/sparse_attention/utils.cpp
@@ -0,0 +1,120 @@
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp

#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;

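// Greedily packs max_width x max_width squares of nonzero blocks in each
// head's layout: every matched square is recorded into `scratch` as rows of
// (head, block_row, block_col, block_index), then zeroed out of `layout` so
// the next (smaller) width only sees the remaining blocks.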
void segment_blocks(torch::Tensor layout,
torch::Tensor idx,
torch::Tensor scratch,
int max_width,
ret_t& ret)
{
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
torch::Tensor tmp = torch::zeros_like(layout);

auto _tmp = tmp.accessor<int, 3>();
auto _layout = layout.accessor<int, 3>();
auto _idx = idx.accessor<int, 3>();
auto _scratch = scratch.accessor<int, 3>();
std::vector<int> current(H, 0);

#ifdef _OPENMP
#pragma omp parallel for
#endif
for (size_t h = 0; h < H; h++) {
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));

for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
int v = _layout[h][m][n];
if (v == 0) continue;
int n_left = ii_left[max_width - 1];
int m_top = ii_top[max_width - 1][n];
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
int width = std::min(left, std::min(top, topleft)) + 1;

// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for (int nn = n_left + 1; nn < n; nn++)
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
_tmp[h][m][n] = width;

// update n_left ring buffer
for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
ii_left[max_width - 1] = n;

// update ii_top ring buffer
for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
ii_top[max_width - 1][n] = m;

// block is too small -- skip
if (width != max_width) continue;

// retained blocks are set to zeros
for (size_t km = 0; km < max_width; km++)
for (size_t kn = 0; kn < max_width; kn++) {
int mm = ii_top[km][n];
int nn = ii_left[kn];
if (mm < 0 || nn < 0) continue;
_layout[h][mm][nn] = 0;
_tmp[h][mm][nn] = 0;
_scratch[h][current[h]][0] = (int)h;
_scratch[h][current[h]][1] = (int)mm;
_scratch[h][current[h]][2] = (int)nn;
_scratch[h][current[h]][3] = _idx[h][mm][nn];
current[h]++;
}
}
}
}
std::vector<torch::Tensor> to_cat;
for (size_t h = 0; h < H; h++)
if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
}

ret_t sdd_segment(torch::Tensor layout, int start_width)
{
ret_t ret;

// block index
torch::Tensor idx = torch::zeros_like(layout);
int current = 0;
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
auto _layout = layout.accessor<int, 3>();
auto _idx = idx.accessor<int, 3>();
for (size_t h = 0; h < H; h++)
for (size_t m = 0; m < M; m++)
for (size_t n = 0; n < N; n++) {
if (_layout[h][m][n] == 0) continue;
_idx[h][m][n] = current++;
}

// scratch memory
torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());

for (int max_width = start_width; max_width > 0; max_width /= 2)
segment_blocks(layout, idx, scratch, max_width, ret);
return ret;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
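
For orientation (not part of the diff): once compiled, the extension is driven from Python. Below is a minimal usage sketch, assuming a local C++ toolchain; the extension name sparse_attn_utils, the toy layout, and the start width of 2 are illustrative assumptions, and only the source path and the sdd_segment entry point come from this file:

import torch
from torch.utils import cpp_extension

# JIT-compile the binding defined above (requires a C++ compiler).
utils = cpp_extension.load(
    name='sparse_attn_utils',
    sources=['csrc/sparse_attention/utils.cpp'],
)

# One head over a 4x4 grid of blocks, with a dense 2x2 square top-left.
# The accessors above require an int32 layout tensor.
layout = torch.zeros(1, 4, 4, dtype=torch.int32)
layout[0, :2, :2] = 1

# sdd_segment zeroes matched blocks in-place, so pass a clone. It returns
# a list of (width, tensor) pairs; each tensor row holds
# (head, block_row, block_col, block_index).
for width, blocks in utils.sdd_segment(layout.clone(), 2):
    print(width, blocks.tolist())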