diff --git a/.clang-format b/.clang-format
index 9b547c1a18fe..9f90836e1f2a 100755
--- a/.clang-format
+++ b/.clang-format
@@ -20,7 +20,7 @@ AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
-AlwaysBreakTemplateDeclarations: Yes
+AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
new file mode 100644
index 000000000000..173a51cda5de
--- /dev/null
+++ b/.github/workflows/main.yml
@@ -0,0 +1,51 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Build
+
+# Controls when the action will run.
+on:
+ push:
+ paths-ignore:
+ - 'docs/**'
+ pull_request:
+ paths-ignore:
+ - 'docs/**'
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ # This workflow contains a single job called "build"
+ build:
+ # The type of runner that the job will run on
+ runs-on: self-hosted
+
+ # Steps represent a sequence of tasks that will be executed as part of the job
+ steps:
+ # Checks out your repository under $GITHUB_WORKSPACE, so your job can access it
+ - uses: actions/checkout@v2
+
+ # Runs a set of commands using the runner's shell
+ - name: environment
+ run: |
+ nvidia-smi
+ which python
+ python --version
+ which nvcc
+ nvcc --version
+ python -c "import torch; print('torch:', torch.__version__, torch)"
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ # Runs a set of commands using the runner's shell
+ - name: Install deepspeed
+ run: |
+ pip install .[dev]
+ ds_report
+
+ - name: Formatting checks
+ run: |
+ pre-commit run --all-files
+
+ # Runs a set of commands using the runner's shell
+ - name: Unit tests
+ run: |
+ if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
diff --git a/.github/workflows/pre-compile-ops.yml b/.github/workflows/pre-compile-ops.yml
new file mode 100644
index 000000000000..4005d4baf2fc
--- /dev/null
+++ b/.github/workflows/pre-compile-ops.yml
@@ -0,0 +1,47 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Tests-w-precompiled-ops
+
+# Controls when the action will run.
+on:
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ # This workflow contains a single job called "build"
+ build:
+ # The type of runner that the job will run on
+ runs-on: self-hosted
+
+ # Steps represent a sequence of tasks that will be executed as part of the job
+ steps:
+ # Checks out your repository under $GITHUB_WORKSPACE, so your job can access it
+ - uses: actions/checkout@v2
+
+ # Runs a set of commands using the runner's shell
+ - name: environment
+ run: |
+ nvidia-smi
+ which python
+ python --version
+ which nvcc
+ nvcc --version
+ python -c "import torch; print('torch:', torch.__version__, torch)"
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ # Runs a set of commands using the runner's shell
+ - name: Install deepspeed
+ run: |
+ DS_BUILD_OPS=1 pip install .[dev]
+ ds_report
+
+ - name: Formatting checks
+ run: |
+ pre-commit run --all-files
+
+ # Runs a set of commands using the runner's shell
+ - name: Unit tests
+ run: |
+ if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
diff --git a/.gitignore b/.gitignore
index e83ac2d32f53..84340857f802 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,21 +3,28 @@
*~
*.swp
*.log
-deepspeed/git_version_info.py
+deepspeed/git_version_info_installed.py
# Build + installation data
build/
dist/
-fused_lamb_*.so
+*.so
deepspeed.egg-info/
+build.txt
# Website
docs/_site/
docs/build
+docs/code-docs/source/_build
docs/code-docs/_build
docs/code-docs/build
.sass-cache/
.jekyll-cache/
.jekyll-metadata
+# Testing data
tests/unit/saved_checkpoint/
+
+# Dev/IDE data
+.vscode
+.theia
diff --git a/.gitmodules b/.gitmodules
index 1257dc13e0f4..37adb6f39e5c 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,3 @@
-[submodule "third_party/apex"]
- path = third_party/apex
- url = https://github.com/NVIDIA/apex.git
[submodule "DeepSpeedExamples"]
path = DeepSpeedExamples
url = https://github.com/microsoft/DeepSpeedExamples
diff --git a/.readthedocs.yml b/.readthedocs.yml
new file mode 100644
index 000000000000..a2da36620152
--- /dev/null
+++ b/.readthedocs.yml
@@ -0,0 +1,18 @@
+
+# Required
+version: 2
+
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+ configuration: docs/code-docs/source/conf.py
+ fail_on_warning: false
+
+# Optionally build your docs in additional formats such as PDF
+formats:
+ - pdf
+
+# Optionally set the version of Python and requirements required to build your docs
+python:
+ version: 3.7
+ install:
+ - requirements: requirements/requirements-readthedocs.txt
diff --git a/CODEOWNERS b/CODEOWNERS
new file mode 100644
index 000000000000..ec7993c060aa
--- /dev/null
+++ b/CODEOWNERS
@@ -0,0 +1 @@
+* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @arashashari @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @niumanar
diff --git a/DeepSpeedExamples b/DeepSpeedExamples
index fd6fb5148ccf..78d69cb2f89a 160000
--- a/DeepSpeedExamples
+++ b/DeepSpeedExamples
@@ -1 +1 @@
-Subproject commit fd6fb5148ccf5c9ce222432006f1d93806187cd9
+Subproject commit 78d69cb2f89a27b1e9b072df8c3e47d00c024fdc
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 000000000000..53fcc885090e
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+include *.txt README.md
+recursive-include requirements *.txt
+recursive-include deepspeed *.cpp *.h *.cu *.tr *.cuh *.cc
+recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
diff --git a/README.md b/README.md
index edfcb2a98e6c..ee2d3e6bb676 100755
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
-[![Build Status](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_apis/build/status/microsoft.DeepSpeed?branchName=master)](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master)
+[![Build Status](https://github.com/microsoft/deepspeed/workflows/Build/badge.svg)](https://github.com/microsoft/DeepSpeed/actions)
+[![PyPI version](https://badge.fury.io/py/deepspeed.svg)](https://pypi.org/project/deepspeed/)
[![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)
[![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
+[![Docker Pulls](https://img.shields.io/docker/pulls/deepspeed/deepspeed)](https://hub.docker.com/r/deepspeed/deepspeed)
[DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization
library that makes distributed training easy, efficient, and effective.
@@ -9,9 +11,13 @@ library that makes distributed training easy, efficient, and effective.
10x Faster Training
Minimal Code Change
-DeepSpeed can train deep learning models with over a hundred billion parameters on current
-generation of GPU clusters, while achieving over 10x in system performance
-compared to the state-of-art. Early adopters of DeepSpeed have already produced
+DeepSpeed delivers extreme-scale model training for everyone, from data scientists training on massive supercomputers to those training on low-end clusters or even on a single GPU:
+* Extreme scale: Using the current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
+* Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of the art, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models.
+* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution compared with dense transformers.
+* Extremely communication efficient: 3D parallelism improves communication efficiency, allowing users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks.
+
+Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
establishing a new SOTA in the LM category.
@@ -25,25 +31,26 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
# News
-
-* [2020/05/19] [ZeRO-2 & DeepSpeed: Shattering Barriers of Deep Learning Speed & Scale](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/)
-**[_NEW_]**
-* [2020/05/19] [An Order-of-Magnitude Larger and Faster Training with ZeRO-2](https://www.deepspeed.ai/news/2020/05/18/zero-stage2.html)
-**[_NEW_]**
-* [2020/05/19] [The Fastest and Most Efficient BERT Training through Optimized Transformer Kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
-**[_NEW_]**
-* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
-* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
+* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
+* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
+* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
+ * [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
+ * [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html)
+ * [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)
+ * [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html)
+* [2020/08/07] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) is now available on-demand
# Table of Contents
| Section | Description |
| --------------------------------------- | ------------------------------------------- |
| [Why DeepSpeed?](#why-deepspeed) | DeepSpeed overview |
-| [Features](#features) | DeepSpeed features |
-| [Further Reading](#further-reading) | DeepSpeed documentation, tutorials, etc. |
-| [Contributing](#contributing) | Instructions for contributing to DeepSpeed |
-| [Publications](#publications) | DeepSpeed publications |
+| [Install](#installation) | Installation details |
+| [Features](#features) | Feature list and overview |
+| [Further Reading](#further-reading) | Documentation, tutorials, etc. |
+| [Contributing](#contributing) | Instructions for contributing |
+| [Publications](#publications) | Publications related to DeepSpeed |
+| [Videos](#videos) | Videos related to DeepSpeed |
# Why DeepSpeed?
Training advanced deep learning models is challenging. Beyond model design,
@@ -55,8 +62,35 @@ a large model easily runs out of memory with pure data parallelism and it is
difficult to use model parallelism. DeepSpeed addresses these challenges to
accelerate model development *and* training.
-# Features
+# Installation
+
+The quickest way to get started with DeepSpeed is via pip. This will install
+the latest release of DeepSpeed, which is not tied to specific PyTorch or CUDA
+versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer
+to as our 'ops'. By default, all of these extensions/ops will be built
+just-in-time (JIT) using [torch's JIT C++ extension loader that relies on
+ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and
+dynamically link them at runtime.
+
+**Note:** [PyTorch](https://pytorch.org/) must be installed _before_ installing
+DeepSpeed.
+```bash
+pip install deepspeed
+```
+
+After installation, you can validate your install and see which extensions/ops
+your machine is compatible with via the DeepSpeed environment report.
+
+```bash
+ds_report
+```
+
+If you would like to pre-install any of the DeepSpeed extensions/ops (instead
+of JIT compiling) or install pre-compiled ops via PyPI, please see our [advanced
+installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/).
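+
+For example, a sketch of pre-building all ops at install time, using the same
+`DS_BUILD_OPS` flag that our pre-compiled-ops CI workflow uses (see the advanced
+installation instructions above for the full set of build options):
+
+```bash
+DS_BUILD_OPS=1 pip install deepspeed
+```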
+
+# Features
Below we provide a brief feature list, see our detailed [feature
overview](https://www.deepspeed.ai/features/) for descriptions and usage.
@@ -66,10 +100,27 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* [Model Parallelism](https://www.deepspeed.ai/features/#model-parallelism)
* Support for Custom Model Parallelism
* Integration with Megatron-LM
-* [Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#memory-and-bandwidth-optimizations)
- * The Zero Redundancy Optimizer (ZeRO)
- * Constant Buffer Optimization (CBO)
+* [Pipeline Parallelism](https://www.deepspeed.ai/tutorials/pipeline/)
+ * 3D Parallelism
+* [The Zero Redundancy Optimizer (ZeRO)](https://www.deepspeed.ai/tutorials/zero/)
+ * Optimizer State and Gradient Partitioning
+ * Activation Partitioning
+ * Constant Buffer Optimization
+ * Contiguous Memory Optimization
+* [ZeRO-Offload](https://www.deepspeed.ai/tutorials/zero-offload/)
+ * Leverage both CPU/GPU memory for model training
+ * Support 10B model training on a single GPU
+* [Ultra-fast dense transformer kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
+* [Sparse attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html)
+ * Memory- and compute-efficient sparse kernels
+ * Support 10x longer sequences than dense
+ * Flexible support for different sparse structures
+* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html)
+ * Custom communication collective
+ * Up to 5x communication volume saving
+* [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations)
* Smart Gradient Accumulation
+ * Communication/Computation Overlap
* [Training Features](https://www.deepspeed.ai/features/#training-features)
* Simplified training API
* Gradient Clipping
@@ -79,6 +130,7 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* Memory bandwidth optimized FP16 Optimizer
* Large Batch Training with LAMB Optimizer
* Memory efficient Training with ZeRO Optimizer
+ * CPU-Adam
* [Training Agnostic Checkpointing](https://www.deepspeed.ai/features/#training-agnostic-checkpointing)
* [Advanced Parameter Search](https://www.deepspeed.ai/features/#advanced-parameter-search)
* Learning Rate Range Test
@@ -127,8 +179,23 @@ all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the
[Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
-[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or
-comments.
+[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
# Publications
-1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: Memory Optimization Towards Training A Trillion Parameter Models. [ArXiv:1910.02054](https://arxiv.org/abs/1910.02054)
+1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: memory optimizations toward training trillion parameter models. [arXiv:1910.02054](https://arxiv.org/abs/1910.02054) and [In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis (SC '20)](https://dl.acm.org/doi/10.5555/3433701.3433727).
+2. Jeff Rasley, Samyam Rajbhandari, Olatunji Ruwase, and Yuxiong He. (2020) DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. [In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '20, Tutorial)](https://dl.acm.org/doi/10.1145/3394486.3406703).
+3. Minjia Zhang, Yuxiong He. (2020) Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping. [arXiv:2010.13369](https://arxiv.org/abs/2010.13369) and [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/a1140a3d0df1c81e24ae954d935e8926-Abstract.html).
+4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840).
+
+# Videos
+1. DeepSpeed KDD 2020 Tutorial
+ 1. [Overview](https://www.youtube.com/watch?v=CaseqC45DNc&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=29)
+ 2. [ZeRO + large model training](https://www.youtube.com/watch?v=y4_bCiAsIAk&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=28)
+ 3. [17B T-NLG demo](https://www.youtube.com/watch?v=9V-ZbP92drg&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=27)
+ 4. [Fastest BERT training + RScan tuning](https://www.youtube.com/watch?v=o1K-ZG9F6u0&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=26)
+ 5. DeepSpeed hands-on deep dive: [part 1](https://www.youtube.com/watch?v=_NOk-mBwDYg&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=92), [part 2](https://www.youtube.com/watch?v=sG6_c4VXLww&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=94), [part 3](https://www.youtube.com/watch?v=k9yPkBTayos&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=93)
+ 6. [FAQ](https://www.youtube.com/watch?v=nsHu6vEgPew&list=PLa85ZdUjfWS21mgibJ2vCvLziprjpKoW0&index=24)
+2. Microsoft Research Webinar
+ * Registration is free and all videos are available on-demand.
+ * [ZeRO & Fastest BERT: Increasing the scale and speed of deep learning training in DeepSpeed](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html).
+3. [DeepSpeed on AzureML](https://youtu.be/yBVXR8G8Bg8)
diff --git a/azure-pipelines-docker.yml b/azure-pipelines-docker.yml
deleted file mode 100644
index dc1782f997f3..000000000000
--- a/azure-pipelines-docker.yml
+++ /dev/null
@@ -1,36 +0,0 @@
-# Docker
-# Build a Docker image
-# https://docs.microsoft.com/azure/devops/pipelines/languages/docker
-
-trigger:
-- master
-
-resources:
-- repo: self
-
-variables:
- tag: '$(Build.BuildId)'
-
-stages:
-- stage: Build
- displayName: Build image
- jobs:
- - job: Build
- displayName: Build
- pool:
- vmImage: 'ubuntu-latest'
- steps:
- - task: Docker@2
- displayName: Login to Docker Hub
- inputs:
- command: login
- containerRegistry: DeepSpeedDocker
- - task: Docker@2
- displayName: Build and Push
- inputs:
- command: buildAndPush
- dockerfile: '$(Build.SourcesDirectory)/Dockerfile'
- repository: deepspeed/deepspeed
- tags: |
- $(tag)
- latest
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
deleted file mode 100644
index ba6502606ee8..000000000000
--- a/azure-pipelines.yml
+++ /dev/null
@@ -1,86 +0,0 @@
-
-jobs:
-- job: Default
- timeoutInMinutes: 360
- pool:
- name: 'GPU_testing'
-
- strategy:
- matrix:
- Python36:
- python.version: '3.6'
- #Python35:
- # python.version: '3.5'
- #Python37:
- # python.version: '3.7'
- #Python38:
- # python.version: '3.8'
-
-
- steps:
- - task: UsePythonVersion@0
- inputs:
- versionSpec: '$(python.version)'
- addToPath: true
- architecture: 'x64'
- displayName: 'Use Python $(python.version)'
-
- - script: |
- python -m pip install --upgrade pip
- pip install --user -r requirements.txt
- ./install.sh --pip_sudo
- displayName: 'Install dependencies'
-
- - script: |
- pre-commit run --all-files
- displayName: 'Formatting checks'
-
- - script: |
- pytest --forked --verbose tests/unit/
- displayName: 'Unit tests'
-
- - script: |
- ln -s /data/Megatron-LM/data DeepSpeedExamples/Megatron-LM/
- pip install --user -r DeepSpeedExamples/Megatron-LM/requirements.txt
- cd tests/model/
- pytest -s run_sanity_check.py
- displayName: 'Model tests'
-
- #BingBertSquad logs
- - task: PublishPipelineArtifact@1
- inputs:
- targetPath: '$(Build.SourcesDirectory)/tests/model/BingBertSquad/test/'
- artifactName: BingBertSquad_logs
- displayName: 'BingBertSquad log uploads'
- condition: always()
-
- # Megatron test logs
- #- task: PublishPipelineArtifact@1
- # inputs:
- # targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/test/'
- # artifactName: Megatron_GPT2_logs
- # displayName: 'Megatron GPT2 log uploads'
- # condition: always()
-
- #- task: PublishPipelineArtifact@1
- # inputs:
- # targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/checkpoint_test_logs/'
- # artifactName: Megatron_GPT2_checkpoint_logs
- # displayName: 'Megatron GPT2 checkpoint log uploads'
- # condition: always()
-
-
- #BingBert logs
- #- task: PublishPipelineArtifact@1
- # inputs:
- # targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/pretrain_test/'
- # artifactName: BingBert_pretrain_logs
- # displayName: 'BingBert pretrain logs'
- # condition: always()
-
- #- task: PublishPipelineArtifact@1
- # inputs:
- # targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/checkpoint_test_logs/'
- # artifactName: BingBert_checkpoint_logs
- # displayName: 'BingBert checkpoint logs'
- # condition: always()
diff --git a/azure/README.md b/azure/README.md
deleted file mode 120000
index fb962e96a1f9..000000000000
--- a/azure/README.md
+++ /dev/null
@@ -1 +0,0 @@
-../docs/_tutorials/azure.md
\ No newline at end of file
diff --git a/azure/README.md b/azure/README.md
new file mode 100644
index 000000000000..1cca695bfa7e
--- /dev/null
+++ b/azure/README.md
@@ -0,0 +1,3 @@
+# Getting Started with DeepSpeed on Azure
+
+Please see our [Azure tutorial](https://www.deepspeed.ai/tutorials/azure/) to get started with DeepSpeed on Azure!
diff --git a/basic_install_test.py b/basic_install_test.py
deleted file mode 100644
index 966b124f5813..000000000000
--- a/basic_install_test.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-import importlib
-
-try:
- import deepspeed as ds
- print("deepspeed successfully imported")
-except ImportError as err:
- raise err
-
-print(f"torch version: {torch.__version__}")
-
-print(f"deepspeed info: {ds.__version__}, {ds.__git_hash__}, {ds.__git_branch__}")
-
-try:
- apex_C = importlib.import_module('apex_C')
- print("apex successfully installed")
-except Exception as err:
- raise err
-
-try:
- fused_lamb = importlib.import_module('deepspeed_lamb_cuda')
- print('deepspeed fused lamb kernels successfully installed')
-except Exception as err:
- raise err
-
-try:
- from apex.optimizers import FP16_Optimizer
- print("using old-style apex")
-except ImportError:
- print("using new-style apex")
-
-try:
- ds_transformer = importlib.import_module('deepspeed_transformer_cuda')
- print('deepspeed transformer kernels successfully installed')
-except Exception as err:
- raise err
diff --git a/bin/ds b/bin/ds
index 47efea32da34..6bb47da8ce7c 100755
--- a/bin/ds
+++ b/bin/ds
@@ -1,6 +1,6 @@
#!/usr/bin/env python
-from deepspeed.pt.deepspeed_run import main
+from deepspeed.launcher.runner import main
if __name__ == '__main__':
main()
diff --git a/bin/ds_elastic b/bin/ds_elastic
new file mode 100644
index 000000000000..ef92cbdab32d
--- /dev/null
+++ b/bin/ds_elastic
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+
+import argparse
+import json
+
+import deepspeed
+from deepspeed.elasticity import compute_elastic_config
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
+ parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
+ args = parser.parse_args()
+ ds_config = json.load(open(args.config, 'r'))
+
+ ds_version = deepspeed.__version__
+
+ elastic_config = ds_config['elasticity']
+ print('------------------------------------------')
+ print("Elasticity config:")
+ print('------------------------------------------')
+ print(json.dumps(elastic_config, indent=4, sort_keys=True))
+
+ if args.world_size > 0:
+ final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size)
+ print('------------------------------------------')
+ print(f"Calculated results for world size {args.world_size}:")
+ print('------------------------------------------')
+ print(f'final_batch_size .... {final_batch_size}')
+ print(f'valid_gpus .......... {valid_gpus}')
+ print(f'micro_batch_size .... {micro_batch_size}')
+ else:
+ final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
+ print('------------------------------------------')
+ print("Calculated results:")
+ print('------------------------------------------')
+ print(f'final_batch_size .... {final_batch_size}')
+ print(f'valid_gpus .......... {valid_gpus}')
diff --git a/bin/ds_report b/bin/ds_report
new file mode 100644
index 000000000000..c03a95645eae
--- /dev/null
+++ b/bin/ds_report
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+
+from deepspeed.env_report import main
+
+if __name__ == '__main__':
+ main()
diff --git a/csrc/adam/compat.h b/csrc/adam/compat.h
new file mode 100644
index 000000000000..86f84a85065c
--- /dev/null
+++ b/csrc/adam/compat.h
@@ -0,0 +1,14 @@
+/* Copyright 2020 The Microsoft DeepSpeed Team
+ Copyright NVIDIA/apex
+ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+*/
+
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
new file mode 100755
index 000000000000..e817322630b8
--- /dev/null
+++ b/csrc/adam/cpu_adam.cpp
@@ -0,0 +1,677 @@
+#include "cpu_adam.h"
+#include <cuda_runtime_api.h>
+#include <iostream>
+#include <math.h>
+#include <memory>
+#include <omp.h>
+#include <torch/extension.h>
+#include <type_traits>
+#include <unordered_map>
+#include "cublas_v2.h"
+#include "cuda.h"
+#include "curand.h"
+#include "custom_cuda_layers.h"
+
+static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
+
+#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+
+// C++ interface
+
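+// Adam_Optimizer::Step applies one Adam/AdamW update to a contiguous span of
+// fp32 parameters. With AVX512/AVX256 the main loop processes SIMD_WIDTH
+// elements per iteration and the scalar loop below handles the remainder. When
+// dev_params is non-null, updated values are also staged in one of the two
+// host buffers (_doubled_buffer) and pushed to the GPU as fp16 via
+// launch_param_update, toggling _buf_index between tiles (double buffering).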
+void Adam_Optimizer::Step(float* _params,
+ float* grads,
+ float* _exp_avg,
+ float* _exp_avg_sq,
+ size_t _param_size,
+ __half* dev_params)
+{
+ float betta1_minus1 = 1 - _betta1;
+ float betta2_minus1 = 1 - _betta2;
+
+ float step_size = -1 * _alpha / _bias_correction1;
+ float w_decay = -1 * _alpha * _weight_decay;
+ size_t rounded_size = 0;
+
+#if defined(__AVX512__) or defined(__AVX256__)
+
+ AVX_Data betta1_4;
+ betta1_4.data = SIMD_SET(_betta1);
+ AVX_Data betta2_4;
+ betta2_4.data = SIMD_SET(_betta2);
+
+ AVX_Data betta1_minus1_4;
+ betta1_minus1_4.data = SIMD_SET(betta1_minus1);
+ AVX_Data betta2_minus1_4;
+ betta2_minus1_4.data = SIMD_SET(betta2_minus1);
+
+ AVX_Data bias2_sqrt;
+ bias2_sqrt.data = SIMD_SET(_bias_correction2);
+
+ AVX_Data eps_4;
+ eps_4.data = SIMD_SET(_eps);
+
+ AVX_Data step_size_4;
+ step_size_4.data = SIMD_SET(step_size);
+
+ AVX_Data weight_decay4;
+ if (_weight_decay > 0)
+ weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
+ rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
+
+ for (size_t t = 0; t < rounded_size; t += TILE) {
+ size_t copy_size = TILE;
+ if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+ size_t offset = copy_size + t;
+#pragma omp parallel for
+ for (size_t i = t; i < offset; i += SIMD_WIDTH) {
+ AVX_Data grad_4;
+ grad_4.data = SIMD_LOAD(grads + i);
+
+ AVX_Data momentum_4;
+ momentum_4.data = SIMD_LOAD(_exp_avg + i);
+ AVX_Data variance_4;
+ variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
+
+ AVX_Data param_4;
+ param_4.data = SIMD_LOAD(_params + i);
+
+ if (_weight_decay > 0 && !_adamw_mode) {
+ grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
+ }
+ momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
+ momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
+
+ variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
+ grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
+ variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
+
+ grad_4.data = SIMD_SQRT(variance_4.data);
+ grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
+ grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
+ if (_weight_decay > 0 && _adamw_mode) {
+ param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data);
+ }
+ param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
+
+ SIMD_STORE(_params + i, param_4.data);
+
+ if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);
+
+ SIMD_STORE(_exp_avg + i, momentum_4.data);
+ SIMD_STORE(_exp_avg_sq + i, variance_4.data);
+ }
+ if (dev_params) {
+ launch_param_update(_doubled_buffer[_buf_index],
+ dev_params + t,
+ copy_size,
+ Context::Instance().GetCurrentStream());
+ _buf_index = !_buf_index;
+ }
+ }
+
+#endif
+
+ if (_param_size > rounded_size) {
+#pragma omp parallel for
+ for (size_t k = rounded_size; k < _param_size; k++) {
+ float grad = grads[k];
+ float param = _params[k];
+ float momentum = _exp_avg[k];
+ float variance = _exp_avg_sq[k];
+ if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
+ momentum = momentum * _betta1;
+ momentum = grad * betta1_minus1 + momentum;
+
+ variance = variance * _betta2;
+ grad = grad * grad;
+ variance = grad * betta2_minus1 + variance;
+
+ grad = sqrt(variance);
+ grad = grad * _bias_correction2 + _eps;
+ grad = momentum / grad;
+ if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
+ param = grad * step_size + param;
+ if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;
+
+ _params[k] = param;
+ _exp_avg[k] = momentum;
+ _exp_avg_sq[k] = variance;
+ }
+ if (dev_params) {
+ launch_param_update(_doubled_buffer[_buf_index],
+ dev_params + rounded_size,
+ (_param_size - rounded_size),
+ Context::Instance().GetCurrentStream());
+ }
+ }
+}
+
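+// Step_4 and Step_8 are 4x- and 8x-unrolled variants of Step; each delegates
+// any tail that does not fill its wider unroll width to the narrower version.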
+void Adam_Optimizer::Step_4(float* _params,
+ float* grads,
+ float* _exp_avg,
+ float* _exp_avg_sq,
+ size_t _param_size,
+ __half* dev_params)
+{
+ size_t rounded_size = 0;
+
+#if defined(__AVX512__) or defined(__AVX256__)
+
+ AVX_Data betta1_4;
+ betta1_4.data = SIMD_SET(_betta1);
+ AVX_Data betta2_4;
+ betta2_4.data = SIMD_SET(_betta2);
+
+ float betta1_minus1 = 1 - _betta1;
+ float betta2_minus1 = 1 - _betta2;
+ AVX_Data betta1_minus1_4;
+ betta1_minus1_4.data = SIMD_SET(betta1_minus1);
+ AVX_Data betta2_minus1_4;
+ betta2_minus1_4.data = SIMD_SET(betta2_minus1);
+
+ AVX_Data bias2_sqrt;
+ bias2_sqrt.data = SIMD_SET(_bias_correction2);
+
+ AVX_Data eps_4;
+ eps_4.data = SIMD_SET(_eps);
+
+ float step_size = -1 * _alpha / _bias_correction1;
+ AVX_Data step_size_4;
+ step_size_4.data = SIMD_SET(step_size);
+
+ float w_decay = -1 * _alpha * _weight_decay;
+ AVX_Data weight_decay4;
+ if (_weight_decay > 0)
+ weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
+ rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));
+
+ for (size_t t = 0; t < rounded_size; t += TILE) {
+ size_t copy_size = TILE;
+ if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+ size_t offset = copy_size + t;
+#pragma omp parallel for
+ for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
+ AVX_Data grad_4[4];
+ grad_4[0].data = SIMD_LOAD(grads + i);
+ grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
+ grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
+ grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
+
+ AVX_Data momentum_4[4];
+ momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
+ momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
+ momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
+ momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
+
+ AVX_Data variance_4[4];
+ variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
+ variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
+ variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
+ variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
+
+ AVX_Data param_4[4];
+ param_4[0].data = SIMD_LOAD(_params + i);
+ param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
+ param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
+ param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
+
+ if (_weight_decay > 0 && !_adamw_mode) {
+ grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
+ grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
+ grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
+ grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
+ }
+
+ momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
+ momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
+ momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
+ momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
+ momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
+ momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
+ momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
+ momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
+
+ variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
+ variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
+ variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
+ variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
+ grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
+ grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
+ grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
+ grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
+ variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
+ variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
+ variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
+ variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
+
+ grad_4[0].data = SIMD_SQRT(variance_4[0].data);
+ grad_4[1].data = SIMD_SQRT(variance_4[1].data);
+ grad_4[2].data = SIMD_SQRT(variance_4[2].data);
+ grad_4[3].data = SIMD_SQRT(variance_4[3].data);
+
+ grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
+ grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
+ grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
+ grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
+ grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
+ grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
+ grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
+ grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
+
+ if (_weight_decay > 0 && _adamw_mode) {
+ param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
+ param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
+ param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
+ param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
+ }
+
+ param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
+ param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
+ param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
+ param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
+
+ SIMD_STORE(_params + i, param_4[0].data);
+ SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
+ SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
+ SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
+
+ if (dev_params) {
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
+ param_4[2].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
+ }
+
+ SIMD_STORE(_exp_avg + i, momentum_4[0].data);
+ SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
+ SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
+ SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
+
+ SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
+ SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
+ SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
+ SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
+ }
+
+ if (dev_params) {
+ launch_param_update(_doubled_buffer[_buf_index],
+ dev_params + t,
+ copy_size,
+ Context::Instance().GetCurrentStream());
+ _buf_index = !_buf_index;
+ }
+ }
+#endif
+ if (_param_size > rounded_size)
+ Step((_params + rounded_size),
+ (grads + rounded_size),
+ (_exp_avg + rounded_size),
+ (_exp_avg_sq + rounded_size),
+ (_param_size - rounded_size),
+ (dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
+}
+
+int create_adam_optimizer(int optimizer_id,
+ float alpha = 1e-3,
+ float betta1 = 0.9,
+ float betta2 = 0.999,
+ float eps = 1e-8,
+ float weight_decay = 0,
+ bool adamw_mode = true)
+{
+ auto opt =
+ std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);
+
+ s_optimizers[optimizer_id] = opt;
+#if defined(__AVX512__)
+ std::cout << "Adam Optimizer #" << optimizer_id
+ << " is created with AVX512 arithmetic capability." << std::endl;
+ printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
+ alpha,
+ betta1,
+ betta2,
+ weight_decay,
+ (int)adamw_mode);
+#else
+#if defined(__AVX256__)
+ std::cout << "Adam Optimizer #" << optimizer_id
+ << " is created with AVX2 arithmetic capability." << std::endl;
+ printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
+ alpha,
+ betta1,
+ betta2,
+ weight_decay,
+ (int)adamw_mode);
+#else
+ std::cout << "Adam Optimizer #" << optimizer_id
+ << " is created with scalar arithmetic capability." << std::endl;
+ printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
+ alpha,
+ betta1,
+ betta2,
+ weight_decay,
+ (int)adamw_mode);
+#endif
+#endif
+ return 0;
+}
+
+void Adam_Optimizer::Step_8(float* _params,
+ float* grads,
+ float* _exp_avg,
+ float* _exp_avg_sq,
+ size_t _param_size,
+ __half* dev_params)
+{
+ size_t rounded_size = 0;
+
+#if defined(__AVX512__) or defined(__AVX256__)
+
+ AVX_Data betta1_4;
+ betta1_4.data = SIMD_SET(_betta1);
+ AVX_Data betta2_4;
+ betta2_4.data = SIMD_SET(_betta2);
+
+ float betta1_minus1 = 1 - _betta1;
+ float betta2_minus1 = 1 - _betta2;
+ AVX_Data betta1_minus1_4;
+ betta1_minus1_4.data = SIMD_SET(betta1_minus1);
+ AVX_Data betta2_minus1_4;
+ betta2_minus1_4.data = SIMD_SET(betta2_minus1);
+
+ AVX_Data bias2_sqrt;
+ bias2_sqrt.data = SIMD_SET(_bias_correction2);
+
+ AVX_Data eps_4;
+ eps_4.data = SIMD_SET(_eps);
+
+ float step_size = -1 * _alpha / _bias_correction1;
+ AVX_Data step_size_4;
+ step_size_4.data = SIMD_SET(step_size);
+
+ float w_decay = -1 * _alpha * _weight_decay;
+ AVX_Data weight_decay4;
+ if (_weight_decay > 0)
+ weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
+ rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));
+
+ for (size_t t = 0; t < rounded_size; t += TILE) {
+ size_t copy_size = TILE;
+ if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+ size_t offset = copy_size + t;
+#pragma omp parallel for
+ for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
+ AVX_Data grad_4[8];
+ grad_4[0].data = SIMD_LOAD(grads + i);
+ grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
+ grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
+ grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
+ grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2));
+ grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5);
+ grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6);
+ grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7);
+
+ AVX_Data momentum_4[8];
+ momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
+ momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
+ momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
+ momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
+ momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2));
+ momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5);
+ momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6);
+ momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7);
+
+ AVX_Data variance_4[8];
+ variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
+ variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
+ variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
+ variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
+ variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2));
+ variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5);
+ variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6);
+ variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7);
+
+ AVX_Data param_4[8];
+ param_4[0].data = SIMD_LOAD(_params + i);
+ param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
+ param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
+ param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
+ param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2));
+ param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5);
+ param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
+ param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);
+
+ if (_weight_decay > 0 && !_adamw_mode) {
+ grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
+ grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
+ grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
+ grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
+ grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data);
+ grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data);
+ grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data);
+ grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data);
+ }
+
+ momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
+ momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
+ momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
+ momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
+ momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
+ momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
+ momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
+ momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
+ momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data);
+ momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data);
+ momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data);
+ momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data);
+ momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data);
+ momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data);
+ momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data);
+ momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data);
+
+ variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
+ variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
+ variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
+ variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
+ variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data);
+ variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data);
+ variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data);
+ variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data);
+ grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
+ grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
+ grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
+ grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
+ grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data);
+ grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data);
+ grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data);
+ grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data);
+ variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
+ variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
+ variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
+ variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
+ variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data);
+ variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data);
+ variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data);
+ variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data);
+
+ grad_4[0].data = SIMD_SQRT(variance_4[0].data);
+ grad_4[1].data = SIMD_SQRT(variance_4[1].data);
+ grad_4[2].data = SIMD_SQRT(variance_4[2].data);
+ grad_4[3].data = SIMD_SQRT(variance_4[3].data);
+ grad_4[4].data = SIMD_SQRT(variance_4[4].data);
+ grad_4[5].data = SIMD_SQRT(variance_4[5].data);
+ grad_4[6].data = SIMD_SQRT(variance_4[6].data);
+ grad_4[7].data = SIMD_SQRT(variance_4[7].data);
+
+ grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
+ grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
+ grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
+ grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
+ grad_4[4].data = SIMD_FMA(grad_4[4].data, bias2_sqrt.data, eps_4.data);
+ grad_4[5].data = SIMD_FMA(grad_4[5].data, bias2_sqrt.data, eps_4.data);
+ grad_4[6].data = SIMD_FMA(grad_4[6].data, bias2_sqrt.data, eps_4.data);
+ grad_4[7].data = SIMD_FMA(grad_4[7].data, bias2_sqrt.data, eps_4.data);
+ grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
+ grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
+ grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
+ grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
+ grad_4[4].data = SIMD_DIV(momentum_4[4].data, grad_4[4].data);
+ grad_4[5].data = SIMD_DIV(momentum_4[5].data, grad_4[5].data);
+ grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data);
+ grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data);
+
+ if (_weight_decay > 0 && _adamw_mode) {
+ param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
+ param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
+ param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
+ param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
+ param_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, param_4[4].data);
+ param_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, param_4[5].data);
+ param_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, param_4[6].data);
+ param_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, param_4[7].data);
+ }
+
+ param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
+ param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
+ param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
+ param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
+ param_4[4].data = SIMD_FMA(grad_4[4].data, step_size_4.data, param_4[4].data);
+ param_4[5].data = SIMD_FMA(grad_4[5].data, step_size_4.data, param_4[5].data);
+ param_4[6].data = SIMD_FMA(grad_4[6].data, step_size_4.data, param_4[6].data);
+ param_4[7].data = SIMD_FMA(grad_4[7].data, step_size_4.data, param_4[7].data);
+
+ SIMD_STORE(_params + i, param_4[0].data);
+ SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
+ SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
+ SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
+ SIMD_STORE(_params + i + (SIMD_WIDTH << 2), param_4[4].data);
+ SIMD_STORE(_params + i + SIMD_WIDTH * 5, param_4[5].data);
+ SIMD_STORE(_params + i + SIMD_WIDTH * 6, param_4[6].data);
+ SIMD_STORE(_params + i + SIMD_WIDTH * 7, param_4[7].data);
+
+ if (dev_params) {
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
+ param_4[2].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2),
+ param_4[4].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6].data);
+ SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7].data);
+ }
+
+ SIMD_STORE(_exp_avg + i, momentum_4[0].data);
+ SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
+ SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
+ SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
+ SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 2), momentum_4[4].data);
+ SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 5, momentum_4[5].data);
+ SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 6, momentum_4[6].data);
+ SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 7, momentum_4[7].data);
+
+ SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
+ SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
+ SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
+ SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
+ SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 2), variance_4[4].data);
+ SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 5, variance_4[5].data);
+ SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 6, variance_4[6].data);
+ SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
+ }
+ if (dev_params) {
+ launch_param_update(_doubled_buffer[_buf_index],
+ dev_params + t,
+ copy_size,
+ Context::Instance().GetCurrentStream());
+ _buf_index = !_buf_index;
+ }
+ }
+#endif
+ if (_param_size > rounded_size)
+ Step_4((_params + rounded_size),
+ (grads + rounded_size),
+ (_exp_avg + rounded_size),
+ (_exp_avg_sq + rounded_size),
+ (_param_size - rounded_size),
+ (dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
+}
+
+int ds_adam_step(int optimizer_id,
+ size_t step,
+ float lr,
+ float beta1,
+ float beta2,
+ float epsilon,
+ float weight_decay,
+ bool bias_correction,
+ torch::Tensor& params,
+ torch::Tensor& grads,
+ torch::Tensor& exp_avg,
+ torch::Tensor& exp_avg_sq)
+{
+ auto params_c = params.contiguous();
+ auto grads_c = grads.contiguous();
+ auto exp_avg_c = exp_avg.contiguous();
+ auto exp_avg_sq_c = exp_avg_sq.contiguous();
+
+ float* params_ptr = (float*)params_c.data_ptr();
+ float* grads_ptr = (float*)grads_c.data_ptr();
+ float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
+ float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
+
+ std::shared_ptr<Adam_Optimizer> opt =
+ std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
+ opt->IncrementStep(step, beta1, beta2);
+ opt->update_state(lr, epsilon, weight_decay, bias_correction);
+ opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
+
+ return 0;
+}
+
+int ds_adam_step_plus_copy(int optimizer_id,
+ size_t step,
+ float lr,
+ float beta1,
+ float beta2,
+ float epsilon,
+ float weight_decay,
+ bool bias_correction,
+ torch::Tensor& params,
+ torch::Tensor& grads,
+ torch::Tensor& exp_avg,
+ torch::Tensor& exp_avg_sq,
+ torch::Tensor& gpu_params)
+{
+ auto params_c = params.contiguous();
+ auto gpu_params_c = gpu_params.contiguous();
+ auto exp_avg_c = exp_avg.contiguous();
+ auto exp_avg_sq_c = exp_avg_sq.contiguous();
+ auto grads_c = grads.contiguous();
+
+ float* params_ptr = (float*)params_c.data_ptr();
+ float* grads_ptr = (float*)grads_c.data_ptr();
+ __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
+ float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
+ float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
+
+ std::shared_ptr<Adam_Optimizer> opt =
+ std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
+ opt->IncrementStep(step, beta1, beta2);
+ opt->update_state(lr, epsilon, weight_decay, bias_correction);
+ opt->Step_8(
+ params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
+
+ return 0;
+}
+
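+// Python-side usage sketch for the bindings below (the module name
+// `cpu_adam_op` is only illustrative of however the built extension is imported):
+//   cpu_adam_op.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.01, True)
+//   cpu_adam_op.adam_update(0, step, lr, beta1, beta2, eps, weight_decay,
+//                           True, params, grads, exp_avg, exp_avg_sq)
+// where params, grads, exp_avg and exp_avg_sq are fp32 tensors of equal size.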
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
+ m.def("adam_update_copy",
+ &ds_adam_step_plus_copy,
+ "DeepSpeed CPU Adam update and param copy (C++)");
+ m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
+}
diff --git a/csrc/adam/custom_cuda_kernel.cu b/csrc/adam/custom_cuda_kernel.cu
new file mode 100755
index 000000000000..2f282aff1aca
--- /dev/null
+++ b/csrc/adam/custom_cuda_kernel.cu
@@ -0,0 +1,20 @@
+
+
+#include "custom_cuda_layers.h"
+
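+// Casts each fp32 input value to fp16 and writes it to the output buffer,
+// one element per thread.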
+__global__ void param_update_kernel(const float* input, __half* output, int size)
+{
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (id < size) { output[id] = (__half)input[id]; }
+}
+
+void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
+{
+ int threads = 1024;
+
+ dim3 grid_dim((size - 1) / threads + 1);
+ dim3 block_dim(threads);
+
+ param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
+}
diff --git a/csrc/adam/fused_adam_frontend.cpp b/csrc/adam/fused_adam_frontend.cpp
new file mode 100644
index 000000000000..b06531c53002
--- /dev/null
+++ b/csrc/adam/fused_adam_frontend.cpp
@@ -0,0 +1,20 @@
+#include <torch/extension.h>
+
+void multi_tensor_adam_cuda(int chunk_size,
+ at::Tensor noop_flag,
+ std::vector<std::vector<at::Tensor>> tensor_lists,
+ const float lr,
+ const float beta1,
+ const float beta2,
+ const float epsilon,
+ const int step,
+ const int mode,
+ const int bias_correction,
+ const float weight_decay);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("multi_tensor_adam",
+ &multi_tensor_adam_cuda,
+ "Compute and apply gradient update to parameters for Adam optimizer");
+}
diff --git a/csrc/adam/multi_tensor_adam.cu b/csrc/adam/multi_tensor_adam.cu
new file mode 100644
index 000000000000..3cb9763befce
--- /dev/null
+++ b/csrc/adam/multi_tensor_adam.cu
@@ -0,0 +1,163 @@
+/* Copyright 2020 The Microsoft DeepSpeed Team
+ Copyright NVIDIA/apex
+ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+*/
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
+// Another possibility:
+// #include <torch/all.h>
+
+#include <assert.h>
+
+#include "multi_tensor_apply.cuh"
+#include "type_shim.h"
+
+#define BLOCK_SIZE 512
+#define ILP 4
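+// Each thread processes up to ILP elements per iteration of the inner loops in
+// AdamFunctor below; BLOCK_SIZE is the thread-block size passed to
+// multi_tensor_apply.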
+
+typedef enum {
+ ADAM_MODE_0 = 0, // L2 regularization mode
+ ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
+} adamMode_t;
+
+using MATH_T = float;
+
+template <typename T>
+struct AdamFunctor {
+ __device__ __forceinline__ void operator()(int chunk_size,
+ volatile int* noop_gmem,
+ TensorListMetadata<4>& tl,
+ const float beta1,
+ const float beta2,
+ const float beta1_correction,
+ const float beta2_correction,
+ const float epsilon,
+ const float lr,
+ adamMode_t mode,
+ const float decay)
+ {
+ // I'd like this kernel to propagate infs/nans.
+ // if(*noop_gmem == 1)
+ // return;
+
+ int tensor_loc = tl.block_to_tensor[blockIdx.x];
+
+ // potentially use to pass in list of scalar
+ // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+
+ int chunk_idx = tl.block_to_chunk[blockIdx.x];
+ int n = tl.sizes[tensor_loc];
+
+ T* g = (T*)tl.addresses[0][tensor_loc];
+ g += chunk_idx * chunk_size;
+
+ T* p = (T*)tl.addresses[1][tensor_loc];
+ p += chunk_idx * chunk_size;
+
+ T* m = (T*)tl.addresses[2][tensor_loc];
+ m += chunk_idx * chunk_size;
+
+ T* v = (T*)tl.addresses[3][tensor_loc];
+ v += chunk_idx * chunk_size;
+
+ n -= chunk_idx * chunk_size;
+
+ // see note in multi_tensor_scale_kernel.cu
+ for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+ MATH_T r_g[ILP];
+ MATH_T r_p[ILP];
+ MATH_T r_m[ILP];
+ MATH_T r_v[ILP];
+#pragma unroll
+ for (int ii = 0; ii < ILP; ii++) {
+ int i = i_start + threadIdx.x + ii * blockDim.x;
+ if (i < n && i < chunk_size) {
+ r_g[ii] = g[i];
+ r_p[ii] = p[i];
+ r_m[ii] = m[i];
+ r_v[ii] = v[i];
+ } else {
+ r_g[ii] = MATH_T(0);
+ r_p[ii] = MATH_T(0);
+ r_m[ii] = MATH_T(0);
+ r_v[ii] = MATH_T(0);
+ }
+ }
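+ // Apply the Adam update in registers: ADAM_MODE_0 folds weight decay into the
+ // gradient (L2 regularization), while the other branch decays the parameter
+ // directly (AdamW-style decoupled weight decay).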
+#pragma unroll
+ for (int ii = 0; ii < ILP; ii++) {
+ if (mode == ADAM_MODE_0) { // L2
+ r_g[ii] = r_g[ii] + (decay * r_p[ii]);
+ r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
+ r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
+ MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
+ MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
+ MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
+ MATH_T update = next_m_unbiased / denom;
+ r_p[ii] = r_p[ii] - (lr * update);
+ } else { // weight decay
+ r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
+ r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
+ MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
+ MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
+ MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
+ MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
+ r_p[ii] = r_p[ii] - (lr * update);
+ }
+ }
+#pragma unroll
+ for (int ii = 0; ii < ILP; ii++) {
+ int i = i_start + threadIdx.x + ii * blockDim.x;
+ if (i < n && i < chunk_size) {
+ p[i] = r_p[ii];
+ m[i] = r_m[ii];
+ v[i] = r_v[ii];
+ }
+ }
+ }
+ }
+};
+
+void multi_tensor_adam_cuda(int chunk_size,
+ at::Tensor noop_flag,
+ std::vector<std::vector<at::Tensor>> tensor_lists,
+ const float lr,
+ const float beta1,
+ const float beta2,
+ const float epsilon,
+ const int step,
+ const int mode,
+ const int bias_correction,
+ const float weight_decay)
+{
+ using namespace at;
+
+ // Handle bias correction mode
+ float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
+ if (bias_correction == 1) {
+ bias_correction1 = 1 - std::pow(beta1, step);
+ bias_correction2 = 1 - std::pow(beta2, step);
+ }
+
+ // Assume single type across p,g,m1,m2 now
+ DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
+ 0,
+ "adam",
+ multi_tensor_apply<4>(BLOCK_SIZE,
+ chunk_size,
+ noop_flag,
+ tensor_lists,
+ AdamFunctor<scalar_t_0>(),
+ beta1,
+ beta2,
+ bias_correction1,
+ bias_correction2,
+ epsilon,
+ lr,
+ (adamMode_t)mode,
+ weight_decay);)
+
+ AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/csrc/adam/multi_tensor_apply.cuh b/csrc/adam/multi_tensor_apply.cuh
new file mode 100644
index 000000000000..13af4b7578f6
--- /dev/null
+++ b/csrc/adam/multi_tensor_apply.cuh
@@ -0,0 +1,127 @@
+/* Copyright 2020 The Microsoft DeepSpeed Team
+ Copyright NVIDIA/apex
+ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+*/
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "compat.h"
+
+#include <assert.h>
+
+// #include <iostream>
+
+// This header is the one-stop shop for all your multi-tensor apply needs.
+
+// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
+constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
+constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
+
+template <int n>
+struct TensorListMetadata {
+ void* addresses[n][depth_to_max_tensors[n - 1]];
+ int sizes[depth_to_max_tensors[n - 1]];
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
+ int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
+ int start_tensor_this_launch;
+};
+
+template <typename T, typename U, typename... ArgTypes>
+__global__ void multi_tensor_apply_kernel(int chunk_size,
+ volatile int* noop_flag,
+ T tl,
+ U callable,
+ ArgTypes... args)
+{
+ // Hand the chunk information to the user-supplied functor to process however it likes.
+ callable(chunk_size, noop_flag, tl, args...);
+}
+
+template <int depth, typename T, typename... ArgTypes>
+void multi_tensor_apply(int block_size,
+ int chunk_size,
+ const at::Tensor& noop_flag,
+ const std::vector>& tensor_lists,
+ T callable,
+ ArgTypes... args)
+{
+ TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
+ int len0 = tensor_lists[0].size();
+ TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
+ auto ref_device = tensor_lists[0][0].device();
+ TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+ for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
+ {
+ TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
+ for (int t = 0; t < tensor_lists[l].size(); t++) {
+ // TODO: Print which tensor fails.
+ bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+#ifdef VERSION_GE_1_5
+ contiguous_memory = (contiguous_memory ||
+ tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
+#endif
+ TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
+ TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
+ "A tensor was not on the same device as the first tensor");
+ TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
+ }
+ }
+
+ int ntensors = tensor_lists[0].size();
+
+ TensorListMetadata<depth> tl;
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ tl.start_tensor_this_launch = 0;
+ int loc_block_info = 0;
+ int loc_tensor_info = 0;
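+ // Pack per-tensor addresses and per-block chunk indices into tl, launching the
+ // kernel whenever the tensor or block metadata fills up, or at the final chunk.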
+ for (int t = 0; t < ntensors; t++) {
+ tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
+ for (int d = 0; d < depth; d++)
+ tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+ loc_tensor_info++;
+
+ int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+
+ for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
+ // std::cout << chunks_this_tensor << std::endl;
+ tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+ tl.block_to_chunk[loc_block_info] = chunk;
+ loc_block_info++;
+
+ bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+ chunk == chunks_this_tensor - 1);
+ bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
+ bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
+ if (tensors_full || blocks_full || last_chunk) {
+ // using accscalar_t = acc_type;
+ multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
+ chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
+
+ AT_CUDA_CHECK(cudaGetLastError());
+
+ // Reset. The control flow possibilities here make my brain hurt.
+ loc_block_info = 0;
+ if (chunk == chunks_this_tensor - 1) {
+ // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
+ // std::endl;
+ loc_tensor_info = 0;
+ tl.start_tensor_this_launch = t + 1;
+ } else {
+ // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
+ // std::endl;
+ tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
+ for (int d = 0; d < depth; d++)
+ tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
+ loc_tensor_info = 1;
+ tl.start_tensor_this_launch = t;
+ }
+ }
+ }
+ }
+}
diff --git a/csrc/includes/context.h b/csrc/includes/context.h
old mode 100644
new mode 100755
index 1e4820177c5d..c2e26cdfa708
--- a/csrc/includes/context.h
+++ b/csrc/includes/context.h
@@ -29,12 +29,12 @@
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
-#define DS_MAXIMUM_NUM_BLOCKS 4096
+#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
- return std::max(
- std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
+ return (std::max)(
+ (std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
// Use at least 1 block, since CUDA does not allow empty block
1);
}
@@ -64,17 +64,10 @@ class Context {
return _ctx;
}
- void GenWorkSpace(size_t size)
+ void SetWorkSpace(void* workspace)
{
- if (!_workspace) {
- assert(_workspace == nullptr);
- cudaMalloc(&_workspace, size);
- } else if (_workSpaceSize != size) {
- cudaFree(_workspace);
- cudaMalloc(&_workspace, size);
- }
-
- _workSpaceSize = size;
+ if (!workspace) { throw std::runtime_error("Workspace is null."); }
+ _workspace = workspace;
}
void* GetWorkSpace() { return _workspace; }
@@ -172,6 +165,5 @@ class Context {
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
- size_t _workSpaceSize;
std::vector<std::array<int, 3>> _gemm_algos;
};
diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
new file mode 100755
index 000000000000..0f45409186c1
--- /dev/null
+++ b/csrc/includes/cpu_adam.h
@@ -0,0 +1,155 @@
+#pragma once
+
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <x86intrin.h>
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include "context.h"
+#include "cublas_v2.h"
+#include "cuda.h"
+#include "curand.h"
+
+#define CUDA_CHECK(callstr) \
+ { \
+ cudaError_t error_code = callstr; \
+ if (error_code != cudaSuccess) { \
+ std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
+ assert(0); \
+ } \
+ }
+
+#define TILE (1024 * 1024 * 1024)
+
+#if defined(__AVX512__)
+#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
+#define SIMD_LOAD(x) _mm512_loadu_ps(x)
+#define SIMD_SET(x) _mm512_set1_ps(x)
+#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
+#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
+#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
+#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
+#define SIMD_WIDTH 16
+#else
+#if defined(__AVX256__)
+#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
+#define SIMD_LOAD(x) _mm256_loadu_ps(x)
+#define SIMD_SET(x) _mm256_set1_ps(x)
+#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
+#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
+#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
+#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
+#define SIMD_WIDTH 8
+#endif
+#endif
+
+class Adam_Optimizer {
+public:
+ Adam_Optimizer(float alpha = 1e-3,
+ float betta1 = 0.9,
+ float betta2 = 0.999,
+ float eps = 1e-8,
+ float weight_decay = 0,
+ bool adamw_mode = true)
+ : _alpha(alpha),
+ _betta1(betta1),
+ _betta2(betta2),
+ _eps(eps),
+ _weight_decay(weight_decay),
+ _betta1_t(1.0),
+ _betta2_t(1.0),
+ _step(0),
+ _buf_index(false),
+ _adamw_mode(adamw_mode)
+ {
+ cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
+ cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
+ }
+ ~Adam_Optimizer()
+ {
+ cudaFreeHost(_doubled_buffer[0]);
+ cudaFreeHost(_doubled_buffer[1]);
+ }
+ void Step(float* _params,
+ float* grads,
+ float* _exp_avg,
+ float* _exp_avg_sq,
+ size_t param_size,
+ __half* dev_param = nullptr);
+ void Step_4(float* _params,
+ float* grads,
+ float* _exp_avg,
+ float* _exp_avg_sa,
+ size_t param_size,
+ __half* dev_param = nullptr);
+ void Step_8(float* _params,
+ float* grads,
+ float* _exp_avg,
+ float* _exp_avg_sq,
+ size_t _param_size,
+ __half* dev_params = nullptr);
+
+ inline void IncrementStep(size_t step, float beta1, float beta2)
+ {
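+ // Recompute beta^t from scratch when the betas change or steps are skipped;
+ // otherwise just multiply in one more factor per step.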
+ if (beta1 != _betta1 || beta2 != _betta2) {
+ _step = step;
+ _betta1 = beta1;
+ _betta2 = beta2;
+ _betta1_t = std::pow(_betta1, step);
+ _betta2_t = std::pow(_betta2, step);
+ } else {
+ _step++;
+ if (_step != step) {
+ _betta1_t = std::pow(_betta1, step);
+ _betta2_t = std::pow(_betta2, step);
+ _step = step;
+ } else {
+ _betta1_t *= _betta1;
+ _betta2_t *= _betta2;
+ }
+ }
+ }
+ inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
+ {
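+ // Note that _bias_correction2 caches the reciprocal term, 1 / sqrt(1 - beta2^t).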
+ _alpha = lr;
+ _eps = epsilon;
+ _weight_decay = weight_decay;
+
+ _bias_correction1 = 1.0f;
+ _bias_correction2 = 1.0f;
+ if (bias_correction == 1) {
+ _bias_correction1 = 1 - _betta1_t;
+ _bias_correction2 = 1 / sqrt(1 - _betta2_t);
+ }
+ }
+
+private:
+#if defined(__AVX512__) or defined(__AVX256__)
+ union AVX_Data {
+#if defined(__AVX512__)
+ __m512 data;
+#else
+ __m256 data;
+#endif
+ // float data_f[16];
+ };
+#endif
+
+ float _alpha;
+ float _betta1;
+ float _betta2;
+ float _eps;
+ float _weight_decay;
+
+ float _betta1_t;
+ float _betta2_t;
+ size_t _step;
+
+ float _bias_correction1;
+ float _bias_correction2;
+
+ float* _doubled_buffer[2];
+ bool _buf_index;
+ bool _adamw_mode;
+};
diff --git a/csrc/includes/custom_cuda_layers.h b/csrc/includes/custom_cuda_layers.h
index d9d0a8a68b38..88b51161d8d6 100755
--- a/csrc/includes/custom_cuda_layers.h
+++ b/csrc/includes/custom_cuda_layers.h
@@ -22,6 +22,8 @@
#define MAX_THREAD_ITERATIONS 8 // Maximum 8K
#define MAX_WARP_NUM 32
+#define MAX_REGISTERS 256
+
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
@@ -29,7 +31,6 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
- int sequence_length,
cudaStream_t stream);
template <typename T>
@@ -37,7 +38,6 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
- int sequence_length,
cudaStream_t stream);
template <typename T>
@@ -46,7 +46,6 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
- int sequence_length,
cudaStream_t stream);
// Custom fused bias add with layer normalization
@@ -57,14 +56,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
- bool training = false,
- T* vars = nullptr,
- T* means = nullptr,
- T* vals_hat = nullptr);
+ bool training,
+ T* vars,
+ T* means);
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
@@ -73,14 +70,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
- bool training = false,
- T* vars = nullptr,
- T* vals_hat = nullptr,
- bool save_vals = false);
+ bool training,
+ T* vars);
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
@@ -93,7 +87,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
template <typename T>
@@ -106,7 +99,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
@@ -122,7 +114,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
@@ -135,7 +126,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
@@ -153,7 +143,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
@@ -264,3 +253,5 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
int rows,
int cols,
cudaStream_t stream);
+
+void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
diff --git a/csrc/includes/dropout.h b/csrc/includes/dropout.h
index 090df3a0abf8..f6e32af5608d 100644
--- a/csrc/includes/dropout.h
+++ b/csrc/includes/dropout.h
@@ -9,15 +9,13 @@ class Dropout {
public:
struct Config {
float ratio;
- uint32_t batch, dim;
+ uint32_t dim;
bool training;
- Config(float r, uint32_t batch, uint32_t dim)
- : ratio(r), batch(batch), dim(dim), training(true)
- {
- }
+ Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
float RATIO() const { return training ? ratio : 0.0; }
+ inline void SetDim(uint32_t d) { dim = d; }
};
Dropout(const Config& config) : _config(config), _mask(nullptr) {}
@@ -70,6 +68,8 @@ class Dropout {
Config GetConfig() const { return _config; }
+ inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
+
private:
uint8_t* _mask;
Config _config;
diff --git a/csrc/includes/ds_transformer_cuda.h b/csrc/includes/ds_transformer_cuda.h
index 896dce8c26db..dbae797a8ecd 100755
--- a/csrc/includes/ds_transformer_cuda.h
+++ b/csrc/includes/ds_transformer_cuda.h
@@ -121,13 +121,22 @@ class BertTransformerLayer {
void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
- uint8_t* layer_output_dropout_mask_ptr);
+ uint8_t* layer_output_dropout_mask_ptr,
+ T* layer_norm_var,
+ T* layer_norm_mean,
+ T* attn_layer_norm_var,
+ T* attn_layer_norm_mean);
inline int GetBatchSize() const { return _batch_size; }
inline int GetNumHeads() const { return _heads; }
inline int GetSeqLength() const { return _seq_length; }
+ inline int GetIntermediateSize() const { return _intermediate_size; }
+
+ void SetSeqLength(int seq_len);
inline int GetHiddenSize() const { return _hidden_size; }
void SetTrainingMode(bool training);
+ inline bool IsTrainingMode() const { return _training; }
+ inline bool GeluCheckpoint() const { return _gelu_checkpoint; }
private:
void Initialize();
@@ -150,8 +159,8 @@ class BertTransformerLayer {
// layers
FeedForward _qkv_linear;
FeedForward _attn_out_linear;
- Normalize_Layer _norm_layer2;
- Normalize_Layer _norm_layer3;
+ Normalize_Layer _attn_layer_norm;
+ Normalize_Layer _layer_norm;
Normalize_Layer* _last_normalize;
FeedForward _ff1, _ff2;
Softmax _softmax;
diff --git a/csrc/includes/gelu.h b/csrc/includes/gelu.h
index 247bfb273de0..41cf6f2a68a7 100644
--- a/csrc/includes/gelu.h
+++ b/csrc/includes/gelu.h
@@ -9,13 +9,8 @@ template
class Gelu {
public:
struct Config {
- uint32_t batch_size;
- uint32_t seq_length;
uint32_t intermediate_size;
- Config(uint32_t batch, uint32_t seq, uint32_t inter_size)
- : batch_size(batch), seq_length(seq), intermediate_size(inter_size)
- {
- }
+ Config(uint32_t inter_size) : intermediate_size(inter_size) {}
};
Gelu(const Config& config) : _config(config) {}
@@ -28,14 +23,12 @@ class Gelu {
T* output,
cudaStream_t stream)
{
- launch_bias_gelu(
- input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream);
+ launch_bias_gelu(input_buf, bias, output, _config.intermediate_size, bsz, stream);
}
void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream)
{
- launch_d_gelu(
- d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream);
+ launch_d_gelu(d_output, input_buf, bias, _config.intermediate_size, bsz, stream);
}
private:
diff --git a/csrc/includes/gemm_test.h b/csrc/includes/gemm_test.h
index ff06f884351c..b920896b419e 100644
--- a/csrc/includes/gemm_test.h
+++ b/csrc/includes/gemm_test.h
@@ -97,7 +97,7 @@ class GemmTest {
template <typename Func>
int Run(int loops, Func f)
{
- float fast_latency = std::numeric_limits<float>::max();
+ float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
@@ -252,7 +252,7 @@ class StridedGemmTest {
template <typename Func>
int Run(int loops, Func f)
{
- float fast_latency = std::numeric_limits<float>::max();
+ float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
diff --git a/csrc/includes/normalize_layer.h b/csrc/includes/normalize_layer.h
index 37ee752c88b5..bfe84636ddb9 100644
--- a/csrc/includes/normalize_layer.h
+++ b/csrc/includes/normalize_layer.h
@@ -16,57 +16,27 @@ class Normalize_Layer {
uint32_t seqLength;
uint32_t hiddenDim;
float epsilon;
- bool training, save_vals;
- bool allocateGrad;
+ bool training;
bool useMean;
- Config(uint32_t batch,
- uint32_t seq,
- uint32_t h,
- bool training,
- bool save_vals = true,
- bool allocateGrad = true,
- bool useMean = true)
+ Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
epsilon(1e-12),
training(training),
- save_vals(save_vals),
- allocateGrad(allocateGrad),
useMean(useMean)
{
}
};
- Normalize_Layer(Config config) : config_(config), vars(nullptr), vals_hat(nullptr)
+ Normalize_Layer(Config config)
+ : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{
- if (config_.training) {
- cudaMalloc((void**)&vars, config_.batchSize * config_.seqLength * sizeof(T));
-
- if (config_.useMean)
- cudaMalloc((void**)&means, config_.batchSize * config_.seqLength * sizeof(T));
-
- if (config_.save_vals)
- cudaMalloc((void**)&vals_hat,
- config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
-
- if (config_.allocateGrad)
- cudaMalloc((void**)&inp_grad,
- config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
- }
}
- ~Normalize_Layer()
- {
- if (config_.training) {
- cudaFree(vars);
- if (config_.useMean) cudaFree(means);
- if (config_.save_vals) cudaFree(vals_hat);
- if (config_.allocateGrad) cudaFree(inp_grad);
- }
- }
+ ~Normalize_Layer() {}
- void ForwardCheckpoint(int bsz,
+ void ForwardCheckpoint(int bsz, // batch * seq
T* vals,
const T* residual,
const T* gamma,
@@ -80,14 +50,12 @@ class Normalize_Layer {
betta,
config_.epsilon,
bsz,
- config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
- means,
- vals_hat);
+ means);
}
void Forward(int bsz,
@@ -104,14 +72,11 @@ class Normalize_Layer {
betta,
config_.epsilon,
bsz,
- config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
- vars,
- vals_hat,
- config_.save_vals);
+ vars);
}
void Backward(int bsz,
@@ -120,7 +85,7 @@ class Normalize_Layer {
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
- T* inp_grad_out = nullptr,
+ T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward(out_grad,
@@ -130,9 +95,8 @@ class Normalize_Layer {
gamma,
gamma_grad,
betta_grad,
- (config_.allocateGrad ? inp_grad : inp_grad_out),
+ inp_grad_out,
bsz,
- config_.seqLength,
config_.hiddenDim,
stream);
}
@@ -144,21 +108,20 @@ class Normalize_Layer {
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
- T* inp_grad_out = nullptr,
- const T* norm_out = nullptr)
+ T* inp_grad_out,
+ const T* norm_out)
{
launch_layerNorm_backward(out_grad,
- (config_.save_vals ? vals_hat : norm_out),
+ norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
- (config_.allocateGrad ? inp_grad : inp_grad_out),
+ inp_grad_out,
bsz,
- config_.seqLength,
config_.hiddenDim,
stream,
- config_.save_vals,
+ !config_.useMean,
betta);
}
@@ -169,7 +132,7 @@ class Normalize_Layer {
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
- T* inp_grad_out = nullptr,
+ T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward_fused_add(out_grad1,
@@ -180,9 +143,8 @@ class Normalize_Layer {
gamma,
gamma_grad,
betta_grad,
- (config_.allocateGrad ? inp_grad : inp_grad_out),
+ inp_grad_out,
bsz,
- config_.seqLength,
config_.hiddenDim,
stream);
}
@@ -195,33 +157,41 @@ class Normalize_Layer {
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
- T* inp_grad_out = nullptr,
- const T* norm_out = nullptr)
+ T* inp_grad_out,
+ const T* norm_out)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
- (config_.save_vals ? vals_hat : norm_out),
+ norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
- (config_.allocateGrad ? inp_grad : inp_grad_out),
+ inp_grad_out,
bsz,
- config_.seqLength,
config_.hiddenDim,
stream,
- config_.save_vals,
+ !config_.useMean,
betta);
}
- inline T* GetInputGrad() const { return inp_grad; }
-
inline bool UseMean() const { return config_.useMean; }
+ inline void SetVar(T* variance)
+ {
+ if (!variance) { throw std::runtime_error("Normalize variance is null."); }
+ vars = variance;
+ }
+
+ inline void SetMean(T* mean)
+ {
+ if (!mean) { throw std::runtime_error("Normalize mean is null."); }
+ means = mean;
+ }
+
private:
Config config_;
T* vars;
T* means;
T* vals_hat;
- T* inp_grad;
};
diff --git a/csrc/includes/softmax.h b/csrc/includes/softmax.h
old mode 100644
new mode 100755
index 2a18daee0b78..2bc2f67059cf
--- a/csrc/includes/softmax.h
+++ b/csrc/includes/softmax.h
@@ -45,13 +45,15 @@ class Softmax {
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
- inline int GetProbDepth() const { return config_.prob_depth; }
+ inline size_t GetProbDepth() const { return config_.prob_depth; }
- inline int GetBatchSize() const { return config_.batchSize; }
+ inline size_t GetBatchSize() const { return config_.batchSize; }
- inline int GetNumHeads() const { return config_.heads; }
+ inline size_t GetNumHeads() const { return config_.heads; }
- inline int GetSeqLength() const { return config_.seq_length; }
+ inline size_t GetSeqLength() const { return config_.seq_length; }
+
+ inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h
index 8c43608e2ecf..44a1b313b986 100644
--- a/csrc/includes/strided_batch_gemm.h
+++ b/csrc/includes/strided_batch_gemm.h
@@ -3,6 +3,7 @@
#include
#include
#include
+#include "context.h"
template <typename T>
class StridedBatchGemm {
@@ -38,6 +39,12 @@ class StridedBatchGemm {
gemm_algos(algos)
{
}
+ void SetConfig(int mm, int nn, int kk)
+ {
+ m = mm;
+ n = nn;
+ k = kk;
+ }
};
StridedBatchGemm(const Config& config) : _config(config) {}
@@ -163,6 +170,8 @@ class StridedBatchGemm {
inline const T* GetBufferB() const { return q_buf; }
+ inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
+
private:
Config _config;
const T* q_buf;
diff --git a/csrc/sparse_attention/utils.cpp b/csrc/sparse_attention/utils.cpp
new file mode 100644
index 000000000000..a802025e92ed
--- /dev/null
+++ b/csrc/sparse_attention/utils.cpp
@@ -0,0 +1,120 @@
+// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
+// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
+
+#include <torch/extension.h>
+#include <string>
+#include <tuple>
+#include <vector>
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
+
+void segment_blocks(torch::Tensor layout,
+ torch::Tensor idx,
+ torch::Tensor scratch,
+ int max_width,
+ ret_t& ret)
+{
+ size_t H = layout.size(0);
+ size_t M = layout.size(1);
+ size_t N = layout.size(2);
+ torch::Tensor tmp = torch::zeros_like(layout);
+
+ auto _tmp = tmp.accessor<int, 3>();
+ auto _layout = layout.accessor<int, 3>();
+ auto _idx = idx.accessor<int, 3>();
+ auto _scratch = scratch.accessor<int, 3>();
+ std::vector<int> current(H, 0);
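+ // Per head, grow square blocks with a min(left, top, topleft) + 1 recurrence and
+ // emit every block that reaches max_width into the scratch buffer.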
+
+#ifdef _OPENMP
+#pragma omp parallel for
+#endif
+ for (size_t h = 0; h < H; h++) {
+ // surrounding indices
+ std::vector<int> ii_left(max_width, -1);
+ std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
+
+ for (size_t m = 0; m < M; m++) {
+ for (size_t n = 0; n < N; n++) {
+ int v = _layout[h][m][n];
+ if (v == 0) continue;
+ int n_left = ii_left[max_width - 1];
+ int m_top = ii_top[max_width - 1][n];
+ int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
+ int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
+ int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
+ int width = std::min(left, std::min(top, topleft)) + 1;
+
+ // reset width if blocks cannot be
+ // packed together (i.e., there's a 1 "in the middle")
+ for (int nn = n_left + 1; nn < n; nn++)
+ if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
+ _tmp[h][m][n] = width;
+
+ // update n_left ring buffer
+ for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
+ ii_left[max_width - 1] = n;
+
+ // update ii_top ring buffer
+ for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
+ ii_top[max_width - 1][n] = m;
+
+ // block is too small -- skip
+ if (width != max_width) continue;
+
+ // retained blocks are set to zeros
+ for (size_t km = 0; km < max_width; km++)
+ for (size_t kn = 0; kn < max_width; kn++) {
+ int mm = ii_top[km][n];
+ int nn = ii_left[kn];
+ if (mm < 0 || nn < 0) continue;
+ _layout[h][mm][nn] = 0;
+ _tmp[h][mm][nn] = 0;
+ _scratch[h][current[h]][0] = (int)h;
+ _scratch[h][current[h]][1] = (int)mm;
+ _scratch[h][current[h]][2] = (int)nn;
+ _scratch[h][current[h]][3] = _idx[h][mm][nn];
+ current[h]++;
+ }
+ }
+ }
+ }
+ std::vector<torch::Tensor> to_cat;
+ for (size_t h = 0; h < H; h++)
+ if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
+ if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
+}
+
+ret_t sdd_segment(torch::Tensor layout, int start_width)
+{
+ ret_t ret;
+
+ // block index
+ torch::Tensor idx = torch::zeros_like(layout);
+ int current = 0;
+ size_t H = layout.size(0);
+ size_t M = layout.size(1);
+ size_t N = layout.size(2);
+ auto _layout = layout.accessor<int, 3>();
+ auto _idx = idx.accessor<int, 3>();
+ for (size_t h = 0; h < H; h++)
+ for (size_t m = 0; m < M; m++)
+ for (size_t n = 0; n < N; n++) {
+ if (_layout[h][m][n] == 0) continue;
+ _idx[h][m][n] = current++;
+ }
+
+ // scratch memory
+ torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
+
+ for (int max_width = start_width; max_width > 0; max_width /= 2)
+ segment_blocks(layout, idx, scratch, max_width, ret);
+ return ret;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
+}
diff --git a/csrc/transformer/cublas_wrappers.cu b/csrc/transformer/cublas_wrappers.cu
index 7b0016bcae5e..72b62386ea6d 100644
--- a/csrc/transformer/cublas_wrappers.cu
+++ b/csrc/transformer/cublas_wrappers.cu
@@ -34,7 +34,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr, "!!!! kernel execution error.\n");
+ fprintf(stderr,
+ "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
+ m,
+ n,
+ k,
+ (int)status);
return EXIT_FAILURE;
}
return 0;
@@ -74,7 +79,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr, "!!!! kernel execution error.\n");
+ fprintf(stderr,
+ "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
+ m,
+ n,
+ k,
+ (int)status);
return EXIT_FAILURE;
}
return 0;
@@ -122,7 +132,13 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr, "!!!! kernel execution error.\n");
+ fprintf(stderr,
+ "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
+ batch,
+ m,
+ n,
+ k,
+ (int)status);
return EXIT_FAILURE;
}
return 0;
@@ -170,7 +186,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr, "!!!! kernel execution error.\n");
+ fprintf(stderr,
+ "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
+ m,
+ n,
+ k,
+ (int)status);
return EXIT_FAILURE;
}
diff --git a/csrc/transformer/dropout_kernels.cu b/csrc/transformer/dropout_kernels.cu
index faf39b44eb91..6b0655b788eb 100644
--- a/csrc/transformer/dropout_kernels.cu
+++ b/csrc/transformer/dropout_kernels.cu
@@ -1,5 +1,7 @@
#include "custom_cuda_layers.h"
+const int unroll_factor = 4;
+
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
@@ -13,17 +15,17 @@ __global__ void dropout_kernel(const int N,
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
- CUDA_1D_KERNEL_LOOP(j, N / 4)
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float4 rand = curand_uniform4(&state);
- uint8_t m[4];
+ uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
- int i = j * 4;
+ int i = j * unroll_factor;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
@@ -35,6 +37,18 @@ __global__ void dropout_kernel(const int N,
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
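+ // Handle the tail elements that do not fill a complete float4 vector.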
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ float4 rand = curand_uniform4(&state);
+ float* rand_data = &(rand.x);
+ int k = 0;
+ for (int i = high_index; i < N; i++) {
+ uint8_t m = (uint8_t)(rand_data[k++] > ratio);
+ out[i] = Xdata[i] * scale * m;
+ mask[i] = m;
+ }
+ }
}
__global__ void dropout_kernel(const int N,
@@ -66,7 +80,7 @@ __global__ void dropout_kernel(const int N,
__half2 mask_h[2];
float2 mask_f[2];
- CUDA_1D_KERNEL_LOOP(j, N / 4)
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
@@ -80,7 +94,7 @@ __global__ void dropout_kernel(const int N,
float* mask_f_data = &mask_f[0].x;
#pragma unroll
- for (int i = 0; i < 4; i++) mask_f_data[i] = (float)(m[i]);
+ for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
@@ -95,16 +109,16 @@ __global__ void dropout_kernel(const int N,
#else
- CUDA_1D_KERNEL_LOOP(j, N / 4)
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
- int i = j * 4;
+ int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
- uint8_t m[4];
+ uint8_t m[unroll_factor];
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
@@ -123,6 +137,18 @@ __global__ void dropout_kernel(const int N,
}
#endif
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ float4 rand = curand_uniform4(&state);
+ float* rand_data = &(rand.x);
+ int k = 0;
+ for (int i = high_index; i < N; i++) {
+ uint8_t m = (uint8_t)(rand_data[k++] > ratio);
+ out[i] = __float2half((float)Xdata[i] * scale * m);
+ mask[i] = m;
+ }
+ }
}
__global__ void dropout_kernel_bwd(const int N,
@@ -133,15 +159,20 @@ __global__ void dropout_kernel_bwd(const int N,
std::pair seed)
{
const float scale = 1. / (1. - ratio);
- CUDA_1D_KERNEL_LOOP(j, N / 4)
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
- int i = j * 4;
+ int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
+ }
}
__global__ void dropout_kernel_bwd(const int N,
@@ -161,18 +192,20 @@ __global__ void dropout_kernel_bwd(const int N,
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
- CUDA_1D_KERNEL_LOOP(j, N / 4)
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
- uint8_t* m = reinterpret_cast<uint8_t*>(mask_cast + j);
+ uint32_t m_32 = mask_cast[j];
+ uint8_t* m = (uint8_t*)&m_32;
+
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
- for (int i = 0; i < 4; i++) mask_f_data[i] = (float)(m[i]);
+ for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
@@ -191,9 +224,9 @@ __global__ void dropout_kernel_bwd(const int N,
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
- CUDA_1D_KERNEL_LOOP(j, N / 4)
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
- int i = j * 4;
+ int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
@@ -211,6 +244,13 @@ __global__ void dropout_kernel_bwd(const int N,
}
#endif
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ for (int i = high_index; i < N; i++) {
+ out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
+ }
+ }
}
template <typename T>
@@ -223,7 +263,9 @@ void launch_dropout(T* out,
cudaStream_t stream,
bool bwd)
{
- dim3 grid_dim = DS_GET_BLOCKS(total_count / 4);
+ assert(unroll_factor == 4);
+
+ dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
@@ -264,55 +306,70 @@ __global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
-#ifdef __STOCHASTIC_MODE__
-
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
- CUDA_1D_KERNEL_LOOP(j, N / 4)
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
- uint8_t* m = reinterpret_cast<uint8_t*>(mask_cast + j);
+ float2 x_data = x_cast[j];
+ uint32_t m_32 = mask_cast[j];
+ uint8_t* m = (uint8_t*)&m_32;
+
+ float2 result_f;
+ __half2* result_h = reinterpret_cast<__half2*>(&result_f);
+
+#ifdef __STOCHASTIC_MODE__
+
+ __half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
- for (int i = 0; i < 4; i++) *(mask_f_data++) = (float)(m[i]);
+ for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
- float2 x_data = x_cast[j];
- __half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
-
- float2 result_f;
- __half2* result_h = reinterpret_cast<__half2*>(&result_f);
-
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
- x_cast[j] = result_f;
- }
-
#else
- CUDA_1D_KERNEL_LOOP(j, N / 2)
- {
- int i = j * 2;
- Xdata[i] = (__half)((float)Xdata[i] * scale * mask[i]);
- Xdata[i + 1] = (__half)((float)Xdata[i + 1] * scale * mask[i + 1]);
- }
+ __half* x_data_h = reinterpret_cast<__half*>(&x_data);
+ float2 result[2];
+
+ result[0].x = (float)x_data_h[0] * scale * m[0];
+ result[0].y = (float)x_data_h[1] * scale * m[1];
+ result[1].x = (float)x_data_h[2] * scale * m[2];
+ result[1].y = (float)x_data_h[3] * scale * m[3];
+
+ result_h[0] = __float22half2_rn(result[0]);
+ result_h[1] = __float22half2_rn(result[1]);
#endif
+ x_cast[j] = result_f;
+ }
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ for (int i = high_index; i < N; i++) {
+ Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
+ }
+ }
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
+ assert(unroll_factor == 4);
+
const float scale = 1. / (1. - ratio);
- dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / 4), DS_CUDA_NUM_THREADS, 0, stream>>>(
- total_count, scale, vals, mask);
+ dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor), DS_CUDA_NUM_THREADS, 0, stream>>>(total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
@@ -341,11 +398,38 @@ __global__ void dropout_grad_kernel(const int N,
__half* out,
uint8_t* mask)
{
- CUDA_1D_KERNEL_LOOP(j, N / 2)
+ const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
+ float2* out_cast = reinterpret_cast<float2*>(out);
+ const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
+
+ float2 result_f;
+ __half2* result_h = reinterpret_cast<__half2*>(&result_f);
+
+ CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
- int i = j * 2;
- out[i] = (__half)((float)Xdata[i] * scale * mask[i]);
- out[i + 1] = (__half)((float)Xdata[i + 1] * scale * mask[i + 1]);
+ float2 x_data = x_cast[j];
+ uint32_t m_32 = mask_cast[j];
+ uint8_t* m = (uint8_t*)&m_32;
+
+ __half* x_data_h = reinterpret_cast<__half*>(&x_data);
+ float2 result[2];
+
+ result[0].x = (float)x_data_h[0] * scale * m[0];
+ result[0].y = (float)x_data_h[1] * scale * m[1];
+ result[1].x = (float)x_data_h[2] * scale * m[2];
+ result[1].y = (float)x_data_h[3] * scale * m[3];
+
+ result_h[0] = __float22half2_rn(result[0]);
+ result_h[1] = __float22half2_rn(result[1]);
+
+ out_cast[j] = result_f;
+ }
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ for (int i = high_index; i < N; i++) {
+ out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
+ }
}
}
@@ -357,9 +441,13 @@ void launch_dropout_grad(T* vals_out,
float ratio,
cudaStream_t stream)
{
+ assert(unroll_factor == 4);
+
const float scale = 1. / (1. - ratio);
- dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / 4), DS_CUDA_NUM_THREADS, 0, stream>>>(
- total_count, scale, vals, vals_out, mask);
+ dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor), DS_CUDA_NUM_THREADS, 0, stream>>>(total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
@@ -374,7 +462,8 @@ template void launch_dropout_grad(__half*,
float ratio,
cudaStream_t stream);
-__global__ void dropout_kernel(const int dim,
+__global__ void dropout_kernel(const int N,
+ const int dim,
const float ratio,
const float* bias,
float* Xdata,
@@ -383,26 +472,27 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
- int tid = threadIdx.x;
+ int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
+ uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
+ CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = curand_uniform4(&state);
- uint8_t m[4];
+ uint32_t m_32;
+ uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
- int i = blockIdx.x * dim + tid * 4;
-
- float4 x_data = Xdata_cast[idx];
+ float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[tid];
x_data.x += b_data.x;
@@ -415,16 +505,26 @@ __global__ void dropout_kernel(const int dim,
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
- mask[i] = (uint8_t)m[0];
- mask[i + 1] = (uint8_t)m[1];
- mask[i + 2] = (uint8_t)m[2];
- mask[i + 3] = (uint8_t)m[3];
-
- Xdata_cast[idx] = x_data;
+ mask_32[j] = m_32;
+ Xdata_cast[j] = x_data;
+ }
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ float4 rand = curand_uniform4(&state);
+ float* rand_data = &(rand.x);
+ int k = 0;
+ for (int i = high_index; i < N; i++) {
+ float x_data = Xdata[i] + bias[threadIdx.x % dim];
+ uint8_t m = (uint8_t)(rand_data[k++] > ratio);
+ Xdata[i] = x_data * scale * m;
+ mask[i] = m;
+ }
}
}
-__global__ void dropout_kernel(const int dim,
+__global__ void dropout_kernel(const int N,
+ const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
@@ -433,17 +533,17 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
- int tid = threadIdx.x;
+ int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
+ uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
+ CUDA_1D_KERNEL_LOOP(j, N)
{
- int i = blockIdx.x * dim + tid * 4;
-
float4 rand = curand_uniform4(&state);
float2 data_f;
@@ -452,7 +552,7 @@ __global__ void dropout_kernel(const int dim,
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
- data_f = Xdata_cast[idx];
+ data_f = Xdata_cast[j];
bias_f = bias_cast[tid];
float2 data_h_0 = __half22float2(data_h[0]);
@@ -466,7 +566,8 @@ __global__ void dropout_kernel(const int dim,
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
- uint8_t m[4]; // = mask + i;
+ uint32_t m_32;
+ uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
@@ -484,12 +585,21 @@ __global__ void dropout_kernel(const int dim,
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
- Xdata_cast[idx] = result_f;
-
- mask[i] = m[0];
- mask[i + 1] = m[1];
- mask[i + 2] = m[2];
- mask[i + 3] = m[3];
+ Xdata_cast[j] = result_f;
+ mask_32[j] = m_32;
+ }
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ float4 rand = curand_uniform4(&state);
+ float* rand_data = &(rand.x);
+ int k = 0;
+ for (int i = high_index; i < N; i++) {
+ float x_data = (float)Xdata[i] + (float)bias[threadIdx.x % dim];
+ uint8_t m = (uint8_t)(rand_data[k++] > ratio);
+ Xdata[i] = __float2half(x_data * scale * m);
+ mask[i] = m;
+ }
}
}
@@ -502,13 +612,18 @@ void launch_dropout(T* out,
float ratio,
cudaStream_t stream)
{
- dim3 grid_dim(batch); // DS_GET_BLOCKS(total_count/4);
- dim3 block_dim(dim / 4); // DS_CUDA_NUM_THREADS;
+ assert(unroll_factor == 4);
+
+ int total_count = batch * dim / unroll_factor;
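+ // Each thread consumes unroll_factor elements, so the grid is sized on total_count.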
+
+ dim3 grid_dim = DS_GET_BLOCKS(total_count);
+ dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair seed = Context::Instance().IncrementOffset(inc);
- dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(dim, ratio, bias, out, mask, seed);
+ dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
+ total_count, dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
@@ -526,7 +641,8 @@ template void launch_dropout(__half*,
float ratio,
cudaStream_t stream);
-__global__ void dropout_kernel(const int dim,
+__global__ void dropout_kernel(const int N,
+ const int dim,
const float ratio,
const float* input,
const float* residual,
@@ -537,31 +653,34 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
- int tid = threadIdx.x;
+ int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
+ uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
+
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
+ CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = curand_uniform4(&state);
- uint8_t m[4];
+
+ uint32_t m_32;
+ uint8_t* m = (uint8_t*)&m_32;
+
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
- // int bid = k * blockDim.x + tid;
- int i = blockIdx.x * dim + tid * 4;
-
- float4 out_data = out_cast[idx];
+ float4 out_data;
float4 b_data = bias_cast[tid];
- float4 res_data = residual_cast[idx];
- float4 inp_data = input_cast[idx];
+ float4 res_data = residual_cast[j];
+ float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
@@ -578,16 +697,29 @@ __global__ void dropout_kernel(const int dim,
out_data.z += res_data.z;
out_data.w += res_data.w;
- mask[i] = m[0];
- mask[i + 1] = m[1];
- mask[i + 2] = m[2];
- mask[i + 3] = m[3];
-
- out_cast[idx] = out_data;
+ mask_32[j] = m_32;
+ out_cast[j] = out_data;
+ }
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ float4 rand = curand_uniform4(&state);
+ float* rand_data = &(rand.x);
+ int k = 0;
+ for (int i = high_index; i < N; i++) {
+ float x_data = input[i] + bias[threadIdx.x % dim];
+ uint8_t m = (uint8_t)(rand_data[k++] > ratio);
+ x_data = x_data * scale * m;
+ x_data += residual[i];
+
+ out[i] = x_data;
+ mask[i] = m;
+ }
}
}
-__global__ void dropout_kernel(const int dim,
+__global__ void dropout_kernel(const int N,
+ const int dim,
const float ratio,
const __half* input,
const __half* residual,
@@ -598,19 +730,20 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
- int tid = threadIdx.x;
+ int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
+ uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
+
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
+ CUDA_1D_KERNEL_LOOP(j, N)
{
- int i = blockIdx.x * dim + tid * 4;
-
float4 rand = curand_uniform4(&state);
float2 data_f;
@@ -625,10 +758,9 @@ __global__ void dropout_kernel(const int dim,
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
- data_f = out_cast[idx];
bias_f = bias_cast[tid];
- residual_f = residual_cast[idx];
- input_f = input_cast[idx];
+ residual_f = residual_cast[j];
+ input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
@@ -647,7 +779,8 @@ __global__ void dropout_kernel(const int dim,
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
- uint8_t m[4]; // = mask + i;
+ uint32_t m_32;
+ uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
@@ -670,12 +803,24 @@ __global__ void dropout_kernel(const int dim,
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
- out_cast[idx] = result_f;
-
- mask[i] = m[0];
- mask[i + 1] = m[1];
- mask[i + 2] = m[2];
- mask[i + 3] = m[3];
+ out_cast[j] = result_f;
+ mask_32[j] = m_32;
+ }
+ int high_index =
+ ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
+ if (N > high_index) {
+ float4 rand = curand_uniform4(&state);
+ float* rand_data = &(rand.x);
+ int k = 0;
+ for (int i = high_index; i < N; i++) {
+ float x_data = (float)input[i] + (float)bias[threadIdx.x % dim];
+ uint8_t m = (uint8_t)(rand_data[k++] > ratio);
+ x_data = x_data * scale * m;
+ x_data += (float)residual[i];
+
+ out[i] = __float2half(x_data);
+ mask[i] = m;
+ }
}
}
@@ -690,14 +835,17 @@ void launch_dropout(T* out,
float ratio,
cudaStream_t stream)
{
- dim3 grid_dim(batch); // DS_GET_BLOCKS(total_count/4);
- dim3 block_dim(dim / 4); // DS_CUDA_NUM_THREADS;
+ assert(unroll_factor == 4);
+
+ int total_count = batch * dim / unroll_factor;
+ dim3 grid_dim = DS_GET_BLOCKS(total_count);
+ dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair seed = Context::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
- dim, ratio, input, residual, bias, out, mask, seed);
+ total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp
index 269468bdfdb4..f22b8a0743f1 100755
--- a/csrc/transformer/ds_transformer_cuda.cpp
+++ b/csrc/transformer/ds_transformer_cuda.cpp
@@ -14,23 +14,26 @@
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;
+const int init_seq_length = 128;
+
// C++ interface
template <typename T>
size_t get_workspace_size(int maxBatchSize,
int seq_len,
int hidden_size,
+ int intermediate_size,
int heads,
bool training,
bool gelu_checkpoint)
{
size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
if (training) {
- workSpacesize += (std::max((4 * size_t(maxBatchSize) * seq_len * hidden_size),
- 2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
+ workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
+ 2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
}
- return workSpacesize * sizeof(T);
+ return workSpacesize; // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
@@ -77,39 +80,29 @@ BertTransformerLayer::BertTransformerLayer(int layer_id,
hidden_size,
hidden_size,
gemm_algos[0])),
- _norm_layer2(typename Normalize_Layer<T>::Config(batch_size,
- seq_length,
- hidden_size,
- true,
- false,
- false,
- !normalize_invertible)),
- _norm_layer3(typename Normalize_Layer<T>::Config(batch_size,
- seq_length,
- hidden_size,
- true,
- false,
- false,
- !normalize_invertible)),
+ _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
+ seq_length,
+ hidden_size,
+ true,
+ !normalize_invertible)),
+ _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
+ seq_length,
+ hidden_size,
+ true,
+ !normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
- 4 * hidden_size,
+ _intermediate_size,
hidden_size,
gemm_algos[1])),
_ff2(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
- 4 * hidden_size,
+ _intermediate_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
- _gelu(typename Gelu<T>::Config(_batch_size, _seq_length, _intermediate_size)),
- _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio,
- _batch_size * _heads * _seq_length,
- _seq_length)),
- _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
- _batch_size * _seq_length,
- _hidden_size)),
- _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
- _batch_size * _seq_length,
- _hidden_size)),
+ _gelu(typename Gelu<T>::Config(_intermediate_size)),
+ _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
+ _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
+ _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_seq_length,
_seq_length,
@@ -130,7 +123,6 @@ BertTransformerLayer::BertTransformerLayer(int layer_id,
gemm_algos[4]))
{
assert(_hidden_size % _heads == 0);
- assert(_seq_length <= 1024);
Initialize();
}
@@ -143,9 +135,6 @@ BertTransformerLayer::~BertTransformerLayer()
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
- Context::Instance().GenWorkSpace(get_workspace_size<T>(
- _batch_size, _seq_length, _hidden_size, _heads, _training, _gelu_checkpoint));
-
if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
}
@@ -190,18 +179,18 @@ void BertTransformerLayer::Forward(int bsz,
if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size;
if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size;
+ int bsz_seq = bsz * _seq_length;
+
if (_pre_or_postLayerNorm) {
- if (_norm_layer3.UseMean())
- _norm_layer3.ForwardCheckpoint(
- bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
+ if (_layer_norm.UseMean())
+ _layer_norm.ForwardCheckpoint(
+ bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
- _norm_layer3.Forward(
- bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
+ _layer_norm.Forward(
+ bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
- int bsz_seq = bsz * _seq_length;
-
if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else
@@ -241,19 +230,19 @@ void BertTransformerLayer::Forward(int bsz,
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) {
- if (_norm_layer2.UseMean())
- _norm_layer2.ForwardCheckpoint(
- bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
+ if (_attn_layer_norm.UseMean())
+ _attn_layer_norm.ForwardCheckpoint(
+ bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
- _norm_layer2.Forward(
- bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
+ _attn_layer_norm.Forward(
+ bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else {
- if (_norm_layer2.UseMean())
- _norm_layer2.ForwardCheckpoint(
- bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
+ if (_attn_layer_norm.UseMean())
+ _attn_layer_norm.ForwardCheckpoint(
+ bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
- _norm_layer2.Forward(
- bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
+ _attn_layer_norm.Forward(
+ bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
}
_ff1.Forward(bsz_seq,
@@ -262,7 +251,7 @@ void BertTransformerLayer::Forward(int bsz,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle);
- _gelu.ForwardWithBiasAdd(bsz,
+ _gelu.ForwardWithBiasAdd(bsz_seq,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr,
(_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
@@ -283,11 +272,12 @@ void BertTransformerLayer::Forward(int bsz,
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) {
- if (_norm_layer3.UseMean())
- _norm_layer3.ForwardCheckpoint(
- bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
+ if (_layer_norm.UseMean())
+ _layer_norm.ForwardCheckpoint(
+ bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
- _norm_layer3.Forward(bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
+ _layer_norm.Forward(
+ bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
}
@@ -343,7 +333,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
T* buf_2 = buf_1 + small_buf_size;
T* buf_3 = buf_2 + small_buf_size;
- T* ff2_buf = buf_3 + (_gelu_checkpoint ? 3 : 1) * small_buf_size;
+ T* ff2_buf = (_gelu_checkpoint ? buf_2 + (bsz * _seq_length * _intermediate_size)
+ : buf_3 + small_buf_size);
T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);
cudaStream_t streams[2] = {_stream, _stream};
@@ -352,26 +343,26 @@ void BertTransformerLayer<T>::Backward(int bsz,
int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) {
- if (_norm_layer3.UseMean())
- _norm_layer3.Backward(bsz,
- grad_output_ptr,
- norm_w_ptr,
- grad_norm_w_ptr,
- grad_norm_b_ptr,
- streams,
- buf_1,
- inp_norm_ptr);
+ if (_layer_norm.UseMean())
+ _layer_norm.Backward(bsz_seq,
+ grad_output_ptr,
+ norm_w_ptr,
+ grad_norm_w_ptr,
+ grad_norm_b_ptr,
+ streams,
+ buf_1,
+ inp_norm_ptr);
else
- _norm_layer3.Backward(bsz,
- grad_output_ptr,
- norm_w_ptr,
- norm_b_ptr,
- grad_norm_w_ptr,
- grad_norm_b_ptr,
- streams,
- buf_1,
- output_ptr);
+ _layer_norm.Backward(bsz_seq,
+ grad_output_ptr,
+ norm_w_ptr,
+ norm_b_ptr,
+ grad_norm_w_ptr,
+ grad_norm_b_ptr,
+ streams,
+ buf_1,
+ output_ptr);
}
if (_pre_or_postLayerNorm)
@@ -383,7 +374,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
- if (_gelu_checkpoint) _gelu.ForwardWithBiasAdd(bsz, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
+ if (_gelu_checkpoint)
+ _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq,
layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
@@ -395,7 +387,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
ff2_buf);
_gelu.Backward(
- bsz, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
+ bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq,
ff2_buf,
@@ -411,49 +403,49 @@ void BertTransformerLayer<T>::Backward(int bsz,
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) {
- if (_norm_layer2.UseMean())
- _norm_layer2.BackwardFusedAdd(bsz,
- buf_3,
- grad_output_ptr,
- attn_nw_ptr,
- grad_attn_nw_ptr,
- grad_attn_nb_ptr,
- streams,
- buf_0,
- add_res_ptr);
+ if (_attn_layer_norm.UseMean())
+ _attn_layer_norm.BackwardFusedAdd(bsz_seq,
+ buf_3,
+ grad_output_ptr,
+ attn_nw_ptr,
+ grad_attn_nw_ptr,
+ grad_attn_nb_ptr,
+ streams,
+ buf_0,
+ add_res_ptr);
else
- _norm_layer2.BackwardFusedAdd(bsz,
- buf_3,
- grad_output_ptr,
- attn_nw_ptr,
- attn_nb_ptr,
- grad_attn_nw_ptr,
- grad_attn_nb_ptr,
- streams,
- buf_0,
- ff1_inp_ptr);
+ _attn_layer_norm.BackwardFusedAdd(bsz_seq,
+ buf_3,
+ grad_output_ptr,
+ attn_nw_ptr,
+ attn_nb_ptr,
+ grad_attn_nw_ptr,
+ grad_attn_nb_ptr,
+ streams,
+ buf_0,
+ ff1_inp_ptr);
} else {
- if (_norm_layer2.UseMean())
- _norm_layer2.Backward(bsz,
- buf_2,
- attn_nw_ptr,
- grad_attn_nw_ptr,
- grad_attn_nb_ptr,
- streams,
- buf_0,
- add_res_ptr);
+ if (_attn_layer_norm.UseMean())
+ _attn_layer_norm.Backward(bsz_seq,
+ buf_2,
+ attn_nw_ptr,
+ grad_attn_nw_ptr,
+ grad_attn_nb_ptr,
+ streams,
+ buf_0,
+ add_res_ptr);
else
- _norm_layer2.Backward(bsz,
- buf_2,
- attn_nw_ptr,
- attn_nb_ptr,
- grad_attn_nw_ptr,
- grad_attn_nb_ptr,
- streams,
- buf_0,
- ff1_inp_ptr);
+ _attn_layer_norm.Backward(bsz_seq,
+ buf_2,
+ attn_nw_ptr,
+ attn_nb_ptr,
+ grad_attn_nw_ptr,
+ grad_attn_nb_ptr,
+ streams,
+ buf_0,
+ ff1_inp_ptr);
}
_attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);
@@ -518,28 +510,28 @@ void BertTransformerLayer<T>::Backward(int bsz,
buf_2);
if (_pre_or_postLayerNorm) {
- if (_norm_layer3.UseMean())
- _norm_layer3.BackwardFusedAdd(bsz,
- buf_2,
- buf_0,
- norm_w_ptr,
- grad_norm_w_ptr,
- grad_norm_b_ptr,
- streams,
- grad_input_ptr,
- input_ptr);
+ if (_layer_norm.UseMean())
+ _layer_norm.BackwardFusedAdd(bsz_seq,
+ buf_2,
+ buf_0,
+ norm_w_ptr,
+ grad_norm_w_ptr,
+ grad_norm_b_ptr,
+ streams,
+ grad_input_ptr,
+ input_ptr);
else
- _norm_layer3.BackwardFusedAdd(bsz,
- buf_2,
- buf_0,
- norm_w_ptr,
- norm_b_ptr,
- grad_norm_w_ptr,
- grad_norm_b_ptr,
- streams,
- grad_input_ptr,
- inp_norm_ptr);
+ _layer_norm.BackwardFusedAdd(bsz_seq,
+ buf_2,
+ buf_0,
+ norm_w_ptr,
+ norm_b_ptr,
+ grad_norm_w_ptr,
+ grad_norm_b_ptr,
+ streams,
+ grad_input_ptr,
+ inp_norm_ptr);
} else
launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
@@ -556,11 +548,31 @@ void BertTransformerLayer<T>::SetTrainingMode(bool training)
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
- uint8_t* layer_output_dropout_mask_ptr)
+ uint8_t* layer_output_dropout_mask_ptr,
+ T* attn_layer_norm_var,
+ T* attn_layer_norm_mean,
+ T* layer_norm_var,
+ T* layer_norm_mean)
{
_attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
_attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
_layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
+
+ _attn_layer_norm.SetVar(attn_layer_norm_var);
+ _attn_layer_norm.SetMean(attn_layer_norm_mean);
+ _layer_norm.SetVar(layer_norm_var);
+ _layer_norm.SetMean(layer_norm_mean);
+}
+
+template <typename T>
+void BertTransformerLayer<T>::SetSeqLength(int seq_len)
+{
+ _seq_length = seq_len;
+
+ _softmax.SetSeqLength(_seq_length);
+ _attn_prob_dropout.SetDimension(_seq_length);
+ _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
+ _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
template <typename T>
@@ -569,7 +581,6 @@ int create_transformer_layer(int layer_id,
int hidden_dim,
int num_heads,
int intermediate_size,
- int seq_length,
float attn_dropout_ratio,
float hidden_dropout_ratio,
int seed,
@@ -582,14 +593,14 @@ int create_transformer_layer(int layer_id,
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
- test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads);
+ test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
- seq_length,
+ init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
pre_or_postLayerNorm,
@@ -681,54 +692,71 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
+ int seq_len = layer->GetSeqLength();
+ if (input.size(1) != seq_len) {
+ seq_len = input.size(1);
+ layer->SetSeqLength(seq_len);
+ }
+
+ auto workspace = torch::empty({get_workspace_size<T>(bsz,
+ seq_len,
+ layer->GetHiddenSize(),
+ layer->GetIntermediateSize(),
+ layer->GetNumHeads(),
+ layer->IsTrainingMode(),
+ layer->GeluCheckpoint())},
+ options);
+ Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
auto attn_o_inp = torch::empty_like(input);
- auto qkv_tf = torch::empty({(bsz * layer->GetSeqLength()), output_w.size(0) * 3}, options);
+ auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);
auto attn_prob_dropout_mask =
- torch::empty({(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
- uint8_options);
+ torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
auto attn_output_dropout_mask =
- torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
+ torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto layer_output_dropout_mask =
- torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
+ torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
+
+ auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
+ auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
+ auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
+ auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
T* inp_norm_ptr = (T*)inp_norm.data_ptr();
T* add_res_ptr = (T*)add_res.data_ptr();
T* q_tf_ptr = (T*)qkv_tf.data_ptr();
- T* k_tf_ptr =
- q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)k_tf.data_ptr();
- T* v_tf_ptr =
- k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)v_tf.data_ptr();
+ T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr();
+ T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr();
T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
- torch::Tensor ff2_inp =
- torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options);
+ torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
torch::Tensor gelu_inp =
- (gelu_checkpoint
- ? ff2_inp
- : torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options));
+ (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
auto ff1_inp = torch::empty_like(input);
T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
- torch::Tensor soft_out = torch::empty(
- {(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, options);
+ torch::Tensor soft_out =
+ torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
torch::Tensor ctx_bufB =
(attn_dropout_checkpoint
? soft_out
- : torch::empty(
- {(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
- options));
+ : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
T* soft_out_ptr = (T*)soft_out.data_ptr();
T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
layer->SetTrainingMode(training_mode);
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
- (uint8_t*)layer_output_dropout_mask.data_ptr());
+ (uint8_t*)layer_output_dropout_mask.data_ptr(),
+ (T*)attn_layer_norm_var.data_ptr(),
+ (T*)attn_layer_norm_mean.data_ptr(),
+ (T*)layer_norm_var.data_ptr(),
+ (T*)layer_norm_mean.data_ptr());
layer->Forward(bsz,
input_ptr,
@@ -770,7 +798,11 @@ std::vector ds_transformer_forward(int layer_id,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
- layer_output_dropout_mask};
+ layer_output_dropout_mask,
+ attn_layer_norm_var,
+ attn_layer_norm_mean,
+ layer_norm_var,
+ layer_norm_mean};
}
template <typename T>
@@ -789,6 +821,10 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask,
+ const torch::Tensor& attn_layer_norm_var,
+ const torch::Tensor& attn_layer_norm_mean,
+ const torch::Tensor& layer_norm_var,
+ const torch::Tensor& layer_norm_mean,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
@@ -832,9 +868,26 @@ std::vector ds_transformer_backward(int layer_id,
CHECK_INPUT(norm_b);
int bsz = g_output.size(0);
+
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
+ int seq_len = layer->GetSeqLength();
+ if (g_output.size(1) != seq_len) {
+ seq_len = g_output.size(1);
+ layer->SetSeqLength(seq_len);
+ }
+
+ auto workspace = torch::empty({get_workspace_size<T>(bsz,
+ seq_len,
+ layer->GetHiddenSize(),
+ layer->GetIntermediateSize(),
+ layer->GetNumHeads(),
+ layer->IsTrainingMode(),
+ layer->GeluCheckpoint())},
+ grad_output.options());
+ Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
@@ -894,7 +947,11 @@ std::vector ds_transformer_backward(int layer_id,
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
- (uint8_t*)layer_output_dropout_mask.data_ptr());
+ (uint8_t*)layer_output_dropout_mask.data_ptr(),
+ (T*)attn_layer_norm_var.data_ptr(),
+ (T*)attn_layer_norm_mean.data_ptr(),
+ (T*)layer_norm_var.data_ptr(),
+ (T*)layer_norm_mean.data_ptr());
layer->Backward(bsz,
grad_output_ptr,
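The binding changes above make the sequence length a per-call property: the layer reads it from the input shape, reconfigures its softmax, dropout, and strided GEMMs through SetSeqLength, and allocates the kernel workspace as an ordinary tensor sized by get_workspace_size, which is what allows the constructor-time assert(_seq_length <= 1024) and GenWorkSpace call to be dropped. A minimal sketch of the same pattern, using hypothetical stand-in types rather than the actual DeepSpeed classes:

#include <torch/extension.h>

// Hypothetical stand-in; names and sizes are illustrative, not the real DeepSpeed API.
struct LayerSketch {
    int seq_len = 128;
    int hidden = 1024;
    void SetSeqLength(int s) { seq_len = s; }  // would also reconfigure softmax/GEMM shapes
};

template <typename T>
void forward_sketch(LayerSketch& layer, const torch::Tensor& input)
{
    // Input is (batch, seq, hidden); the sequence length follows the tensor on every call.
    if (input.size(1) != layer.seq_len) layer.SetSeqLength(input.size(1));

    // The workspace is allocated per call for the current shapes instead of once at
    // construction, so no maximum sequence length has to be baked into the layer.
    int64_t ws_elems = (int64_t)input.size(0) * layer.seq_len * layer.hidden * 8;  // illustrative size
    auto workspace = torch::empty({ws_elems}, input.options());
    T* ws_ptr = (T*)workspace.data_ptr();
    (void)ws_ptr;  // the real code hands this pointer to the kernels via a context singleton
}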
diff --git a/csrc/transformer/gelu_kernels.cu b/csrc/transformer/gelu_kernels.cu
old mode 100755
new mode 100644
index f0e65e3829b5..209b64a90902
--- a/csrc/transformer/gelu_kernels.cu
+++ b/csrc/transformer/gelu_kernels.cu
@@ -279,13 +279,12 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
- int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
- dim3 grid_dims(sequence_length * batch_size);
+ dim3 grid_dims(batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
}
@@ -295,24 +294,26 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
- int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
- dim3 grid_dims(sequence_length * batch_size);
+ dim3 grid_dims(batch_size);
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
}
-template void
-launch_bias_gelu<float>(const float*, const float*, float*, int, int, int, cudaStream_t);
-template void
-launch_bias_gelu<__half>(const __half*, const __half*, __half*, int, int, int, cudaStream_t);
+template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
+template void launch_bias_gelu<__half>(const __half*,
+ const __half*,
+ __half*,
+ int,
+ int,
+ cudaStream_t);
-template void launch_gelu<float>(const float*, float*, int, int, int, cudaStream_t);
-template void launch_gelu<__half>(const __half*, __half*, int, int, int, cudaStream_t);
+template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
+template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
@@ -320,17 +321,15 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
- int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
- dim3 grid_dims(sequence_length * batch_size);
+ dim3 grid_dims(batch_size);
d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
}
-template void launch_d_gelu<float>(float*, const float*, const float*, int, int, int, cudaStream_t);
-template void
-launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, int, cudaStream_t);
+template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
+template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
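With sequence_length removed from these launchers, the batch_size argument is already the flattened token count (the bsz_seq = bsz * _seq_length that the transformer layer now passes), and the grid gets one block per token row. A rough sketch of that launch shape, with a simplified, non-vectorized stand-in for fused_bias_gelu:

#include <cuda_runtime.h>

// Simplified stand-in: one block per token row, threads strided across the row.
__global__ void bias_gelu_rowwise(const float* input, const float* bias,
                                  float* output, int intermediate_size)
{
    const float* in_row = input + blockIdx.x * intermediate_size;
    float* out_row = output + blockIdx.x * intermediate_size;
    for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) {
        float x = in_row[i] + bias[i];
        // tanh approximation of GELU (the real kernel is vectorized over float4)
        out_row[i] = x * 0.5f * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)));
    }
}

void launch_bias_gelu_sketch(const float* input, const float* bias, float* output,
                             int intermediate_size, int token_count, cudaStream_t stream)
{
    dim3 grid(token_count);   // token_count == batch_size * sequence_length
    dim3 block(256);
    bias_gelu_rowwise<<<grid, block, 0, stream>>>(input, bias, output, intermediate_size);
}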
diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu
index 0ce280a702ab..fbe4d0536789 100644
--- a/csrc/transformer/general_kernels.cu
+++ b/csrc/transformer/general_kernels.cu
@@ -14,15 +14,18 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
+
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- localSum += (float)inp[offset];
- offset += y_stride;
+ if (idx < width) {
+ int offset = threadIdx.y * width + idx;
+ for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
+ localSum += (float)inp[offset];
+ offset += y_stride;
+ }
}
tile[threadIdx.x][threadIdx.y] = localSum;
@@ -40,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- out[pos] = sum;
+ if (pos < (rows * width)) out[pos] = sum;
}
}
@@ -58,10 +61,10 @@ void launch_fuse_transpose_bias_kernel<float>(const float* inp,
int cols,
cudaStream_t stream)
{
- assert(rows % TILE_DIM == 0);
- assert(cols % TILE_DIM == 0);
+ // assert(rows % TILE_DIM == 0);
+ // assert(cols % TILE_DIM == 0);
- dim3 grid_dim(cols / TILE_DIM);
+ dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
@@ -74,49 +77,38 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
int cols,
cudaStream_t stream)
{
- assert(rows % TILE_DIM == 0);
- assert(cols % TILE_DIM == 0);
+ // assert(rows % TILE_DIM == 0);
+ // assert(cols % TILE_DIM == 0);
- dim3 grid_dim(cols / TILE_DIM);
+ dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
-__global__ void fused_add2_kernel(float* out,
- const float* inp1,
- const float* inp2,
- int size,
- int row_stride)
+__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
- int row = blockIdx.x;
- int id = threadIdx.x;
-
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
- float4 val;
- float4 inp1_reg = inp1_4[row * row_stride + id];
- float4 inp2_reg = inp2_4[row * row_stride + id];
+ CUDA_1D_KERNEL_LOOP(j, N)
+ {
+ float4 val;
+ float4 inp1_reg = inp1_4[j];
+ float4 inp2_reg = inp2_4[j];
- val.x = inp1_reg.x + inp2_reg.x;
- val.y = inp1_reg.y + inp2_reg.y;
- val.z = inp1_reg.z + inp2_reg.z;
- val.w = inp1_reg.w + inp2_reg.w;
+ val.x = inp1_reg.x + inp2_reg.x;
+ val.y = inp1_reg.y + inp2_reg.y;
+ val.z = inp1_reg.z + inp2_reg.z;
+ val.w = inp1_reg.w + inp2_reg.w;
- out_4[row * row_stride + id] = val;
+ out_4[j] = val;
+ }
}
-__global__ void fused_add2_kernel(__half* out,
- const __half* inp1,
- const __half* inp2,
- int size,
- int row_stride)
+__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
- int row = blockIdx.x;
- int id = threadIdx.x;
-
float2 inp1_4;
float2 inp2_4;
@@ -126,28 +118,31 @@ __global__ void fused_add2_kernel(__half* out,
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
- inp1_4 = inp1_arr[row * row_stride + id];
- inp2_4 = inp2_arr[row * row_stride + id];
+ CUDA_1D_KERNEL_LOOP(j, N)
+ {
+ inp1_4 = inp1_arr[j];
+ inp2_4 = inp2_arr[j];
- float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
- float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
+ float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
+ float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
- float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
- float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
+ float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
+ float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
- inp1_h_f_0.x += inp2_h_f_0.x;
- inp1_h_f_0.y += inp2_h_f_0.y;
- inp1_h_f_1.x += inp2_h_f_1.x;
- inp1_h_f_1.y += inp2_h_f_1.y;
+ inp1_h_f_0.x += inp2_h_f_0.x;
+ inp1_h_f_0.y += inp2_h_f_0.y;
+ inp1_h_f_1.x += inp2_h_f_1.x;
+ inp1_h_f_1.y += inp2_h_f_1.y;
- float2 val_f;
- __half2* val_h = reinterpret_cast<__half2*>(&val_f);
+ float2 val_f;
+ __half2* val_h = reinterpret_cast<__half2*>(&val_f);
- val_h[0] = __float22half2_rn(inp1_h_f_0);
- val_h[1] = __float22half2_rn(inp1_h_f_1);
+ val_h[0] = __float22half2_rn(inp1_h_f_0);
+ val_h[1] = __float22half2_rn(inp1_h_f_1);
- float2* out_4 = reinterpret_cast(out);
- out_4[row * row_stride + id] = val_f;
+ float2* out_4 = reinterpret_cast(out);
+ out_4[j] = val_f;
+ }
}
template <>
@@ -159,12 +154,12 @@ void launch_fused_add2<float>(float* out,
int hidden_dim,
cudaStream_t& stream)
{
- dim3 grid_dim(batch_size * seq_length);
+ int total_count = batch_size * seq_length * hidden_dim / 4;
+ dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
- dim3 block_dim(hidden_dim / 4);
+ dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
- fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
- out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+ fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
template <>
@@ -176,12 +171,12 @@ void launch_fused_add2<__half>(__half* out,
int hidden_dim,
cudaStream_t& stream)
{
- dim3 grid_dim(batch_size * seq_length);
+ int total_count = batch_size * seq_length * hidden_dim / 4;
+ dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
- dim3 block_dim(hidden_dim / 4);
+ dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
- fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
- out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+ fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
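fused_add2_kernel now walks a flat element count with CUDA_1D_KERNEL_LOOP and is launched with DS_GET_BLOCKS / DS_CUDA_NUM_THREADS, so no dimension has to divide evenly into the grid. Those macros come from a DeepSpeed header that is not part of this diff; assuming they expand to the usual grid-stride-loop helpers, the pattern is roughly:

#include <cuda_runtime.h>

// Assumed expansions; the actual macro definitions live elsewhere in DeepSpeed.
#define DS_CUDA_NUM_THREADS 512
#define DS_GET_BLOCKS(n) (((n) + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS)
#define CUDA_1D_KERNEL_LOOP(i, n) \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

__global__ void add2_grid_stride(const int N, float* out, const float* a, const float* b)
{
    // Each thread strides by (threads per block * blocks), so any N is covered
    // regardless of hidden_dim or sequence length.
    CUDA_1D_KERNEL_LOOP(j, N) { out[j] = a[j] + b[j]; }
}

void launch_add2_sketch(float* out, const float* a, const float* b, int N, cudaStream_t stream)
{
    add2_grid_stride<<<DS_GET_BLOCKS(N), DS_CUDA_NUM_THREADS, 0, stream>>>(N, out, a, b);
}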
diff --git a/csrc/transformer/normalize_kernels.cu b/csrc/transformer/normalize_kernels.cu
old mode 100755
new mode 100644
index 7345175694bf..a0e512c73d44
--- a/csrc/transformer/normalize_kernels.cu
+++ b/csrc/transformer/normalize_kernels.cu
@@ -5,51 +5,53 @@ namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
-Unlike the GELU, which doesn't require template parameters, this layer does since it
-does rely fairly heavily on unrolling loops. Currently, I exclude bounds checks and
-assume that the number of elements is a multiple of a power of 2. Default behavior
-for our purposes uses 256 threads for floats, and 128 threads for __half. This restriction
-is a result of using the shift parameter to perform the minimum number of register file
-shuffles necessary, which requires the number of threads in the secondary reduction to
-be 1, 2, 4, 8, 16, or 32. The number of threads here corresponds to the number of complete
-warps in the threadblock.
-
For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
__half2 instructions, and avoid the conversion overhead (1/8 of __half2 arithmetic).
For specific launch constraints, see the launch functions.
*/
-template <int row_stride, int iterations>
+#define NORM_REG (MAX_REGISTERS / 4)
+
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
- bool training = false,
- float* vars = nullptr,
- float* means = nullptr,
- float* vals_hat = nullptr)
+ bool training,
+ float* vars,
+ float* means,
+ int row_stride)
{
- constexpr int iteration_stride = row_stride / iterations;
+ int iteration_stride = blockDim.x;
+ int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
+ cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
- int gid = id / 32;
+ int gid = id / WARP_SIZE;
+
+ float vals_arr[NORM_REG];
+ __shared__ float shr[MAX_WARP_NUM];
- float vals_arr[iterations];
- __shared__ float shr[iteration_stride >> 5];
+ residual += (row * row_stride);
+ vals += (row * row_stride);
float sum = 0.f;
+ int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
- vals_arr[i] = residual[row * row_stride + i * iteration_stride + id];
+ vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
+ if (high_index < row_stride) {
+ vals_arr[iterations] = residual[high_index];
+ sum += vals_arr[iterations];
+ iterations++;
+ }
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -71,7 +73,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
if (g.thread_rank() == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
- variance += (vals_arr[i] - mean) * (vals_arr[i] - mean);
+ vals_arr[i] -= mean;
+ variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -93,28 +96,34 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
+ iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
- vals_arr[i] = (vals_arr[i] - mean) * rsqrtf(variance);
+ vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
- vals[row * row_stride + i * iteration_stride + id] = vals_arr[i];
+ vals[i * iteration_stride + id] = vals_arr[i];
+ }
+ if ((high_index) < row_stride) {
+ vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
+ vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
+ vals[high_index] = vals_arr[iterations];
}
}
-template <int row_stride, int iterations>
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
- bool training = false,
- __half* vars = nullptr,
- __half* means = nullptr,
- __half* vals_hat = nullptr)
+ bool training,
+ __half* vars,
+ __half* means,
+ int row_stride)
{
#if __CUDA_ARCH__ >= 700
- constexpr int iteration_stride = row_stride / iterations;
+ int iteration_stride = blockDim.x;
+ int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -123,20 +132,29 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
int id = threadIdx.x;
int gid = id >> 5;
- __half2 vals_arr[iterations];
- float2 vals_f[iterations];
- __shared__ float shr[iteration_stride >> 5];
+ float2 vals_f[NORM_REG];
+ __shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
+ residual_cast += (row * row_stride);
+ vals_cast += (row * row_stride);
+
float sum = 0.f;
+ int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
- vals_f[i] = __half22float2(residual_cast[row * row_stride + i * iteration_stride + id]);
+ vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
+ if ((high_index) < row_stride) {
+ vals_f[iterations] = __half22float2(residual_cast[high_index]);
+ sum += vals_f[iterations].x;
+ sum += vals_f[iterations].y;
+ iterations++;
+ }
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -156,8 +174,10 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
- variance += (vals_f[i].x - mean) * (vals_f[i].x - mean);
- variance += (vals_f[i].y - mean) * (vals_f[i].y - mean);
+ vals_f[i].x -= mean;
+ vals_f[i].y -= mean;
+ variance += vals_f[i].x * vals_f[i].x;
+ variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -177,7 +197,6 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
variance /= (row_stride * 2);
variance += epsilon;
- __half2 mean_h = __float2half2_rn(mean);
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
@@ -186,13 +205,19 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
-
+ iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
- vals_arr[i] = __float22half2_rn(vals_f[i]);
- vals_arr[i] = (vals_arr[i] - mean_h) * h2rsqrt(variance_h);
- vals_arr[i] = vals_arr[i] * gamma_cast[i * iteration_stride + id] +
- beta_cast[i * iteration_stride + id];
- vals_cast[row * row_stride + i * iteration_stride + id] = vals_arr[i];
+ __half2 vals_arr = __float22half2_rn(vals_f[i]);
+ vals_arr = vals_arr * h2rsqrt(variance_h);
+ vals_arr =
+ vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
+ vals_cast[i * iteration_stride + id] = vals_arr;
+ }
+ if ((high_index) < row_stride) {
+ __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
+ vals_arr = vals_arr * h2rsqrt(variance_h);
+ vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
+ vals_cast[high_index] = vals_arr;
}
#endif
}
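The rewritten kernels above trade the old compile-time <row_stride, iterations> template pair for a fixed register array of NORM_REG elements plus a guarded high_index load that picks up the remainder when the row length is not a multiple of blockDim.x. A stripped-down float sketch of just that indexing scheme (the warp reduction and the gamma/beta scaling are omitted, and NORM_REG_SKETCH stands in for NORM_REG):

#include <cuda_runtime.h>

#define NORM_REG_SKETCH 32  // illustrative stand-in for NORM_REG (MAX_REGISTERS / 4)

// Loads one row into registers: full strided passes first, then one guarded extra
// element per thread for the tail that does not divide evenly into blockDim.x.
__global__ void row_sum_sketch(const float* residual, float* row_sums, int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;
    int id = threadIdx.x;

    float vals_arr[NORM_REG_SKETCH];
    const float* row = residual + blockIdx.x * row_stride;

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = row[i * iteration_stride + id];
        sum += vals_arr[i];
    }
    if (high_index < row_stride) {  // tail element, taken only by some threads
        vals_arr[iterations] = row[high_index];
        sum += vals_arr[iterations];
    }
    // The real kernel reduces `sum` across the block and then normalizes the row;
    // here we just accumulate per-thread partial sums (row_sums assumed zeroed).
    atomicAdd(&row_sums[blockIdx.x], sum);
}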
@@ -204,14 +229,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
- T* means,
- T* vals_hat);
+ T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
@@ -220,42 +243,28 @@ void launch_bias_residual_layer_norm(float* vals,
const float* beta,
float epsilon,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
- float* means,
- float* vals_hat)
+ float* means)
{
- constexpr int threads = THREADS;
+ int threads = THREADS;
- dim3 grid_dim(batch_size * sequence_length);
+ dim3 grid_dim(batch_size);
+
+ if (hidden_dim > 16384 && hidden_dim <= 32768)
+ threads <<= 1;
+ else if (hidden_dim > 32768 && hidden_dim <= 65536)
+ threads <<= 2;
+ else if (hidden_dim > 65536)
+ throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
- // There are some limitations to call below functions, now just enumerate the situations.
- if (hidden_dim == 768)
- fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 512)
- fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 1024)
- fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 1536)
- fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 2048)
- fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 2560)
- fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else
- throw std::runtime_error("Unsupport hidden_dim.");
+ fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
+ vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
@@ -265,56 +274,44 @@ void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* beta,
float epsilon,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
- __half* means,
- __half* vals_hat)
+ __half* means)
{
- constexpr int threads = 128;
+ int threads = 128;
- dim3 grid_dim(batch_size * sequence_length);
- dim3 block_dim(threads);
+ dim3 grid_dim(batch_size);
- // There are some limitations to call below functions, now just enumerate the situations.
- if (hidden_dim == 768)
- fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 512)
- fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 1024)
- fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 1536)
- fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 2048)
- fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else if (hidden_dim == 2560)
- fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
- else
+ if (hidden_dim > 8192 && hidden_dim <= 16384)
+ threads <<= 1;
+ else if (hidden_dim > 16384 && hidden_dim <= 32768)
+ threads <<= 2;
+ else if (hidden_dim > 32768 && hidden_dim <= 65536)
+ threads <<= 3;
+ else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
+
+ dim3 block_dim(threads);
+
+ fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
+ vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
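Instead of dispatching on an enumerated hidden_dim, both launchers now size the block at run time: the thread count doubles per range so that row_stride / blockDim.x stays within the NORM_REG register budget, and for __half the row stride is hidden_dim / 2 because elements are handled as __half2 pairs. A small sketch of the FP32 selection logic, assuming THREADS is 256 (the constant is defined elsewhere in DeepSpeed and not shown in this diff):

#include <stdexcept>

// Mirrors the run-time block-size selection above; the THREADS value is assumed.
inline int pick_layernorm_threads_fp32(int hidden_dim)
{
    int threads = 256;  // assumed THREADS
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;  // 512 threads
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;  // 1024 threads, the CUDA per-block limit
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
    return threads;
}

Each row then runs in a single block of that size, with the register-array loop shown earlier covering hidden_dim / threads elements per thread.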
-template <int row_stride, int iterations>
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
- bool training = false,
- float* vars = nullptr,
- float* vals_hat = nullptr,
- bool save_vals = false)
+ bool training,
+ float* vars,
+ int row_stride)
{
- constexpr int iteration_stride = row_stride / iterations;
+ int iteration_stride = blockDim.x;
+ int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -323,15 +320,24 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
int id = threadIdx.x;
int gid = id / 32;
- float vals_arr[iterations];
- __shared__ float shr[iteration_stride >> 5];
+ float vals_arr[NORM_REG];
+ __shared__ float shr[MAX_WARP_NUM];
+
+ residual += (row * row_stride);
+ vals += (row * row_stride);
float sum = 0.f;
+ int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
- vals_arr[i] = residual[row * row_stride + i * iteration_stride + id];
+ vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
+ if ((high_index) < row_stride) {
+ vals_arr[iterations] = residual[high_index];
+ sum += vals_arr[iterations];
+ iterations++;
+ }
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -351,7 +357,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
- variance += (vals_arr[i] - mean) * (vals_arr[i] - mean);
+ vals_arr[i] -= mean;
+ variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -373,28 +380,34 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
+ iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
- vals_arr[i] = (vals_arr[i] - mean) * rsqrtf(variance);
+ vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
- vals[row * row_stride + i * iteration_stride + id] = vals_arr[i];
+ vals[i * iteration_stride + id] = vals_arr[i];
+ }
+ if ((high_index) < row_stride) {
+ vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
+ vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
+ vals[high_index] = vals_arr[iterations];
}
}
-template <int row_stride, int iterations>
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
- bool training = false,
- __half* vars = nullptr,
- __half* vals_hat = nullptr,
- bool save_vals = false)
+ bool training,
+ __half* vars,
+ int row_stride)
{
#if __CUDA_ARCH__ >= 700
- constexpr int iteration_stride = row_stride / iterations;
+
+ int iteration_stride = blockDim.x;
+ int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -403,20 +416,29 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
int id = threadIdx.x;
int gid = id >> 5;
- __half2 vals_arr[iterations];
- float2 vals_f[iterations];
- __shared__ float shr[iteration_stride >> 5];
+ float2 vals_f[NORM_REG];
+ __shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast(residual);
+ residual_cast += (row * row_stride);
+ vals_cast += (row * row_stride);
+
float sum = 0.f;
+ int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
- vals_f[i] = __half22float2(residual_cast[row * row_stride + i * iteration_stride + id]);
+ vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
+ if ((high_index) < row_stride) {
+ vals_f[iterations] = __half22float2(residual_cast[high_index]);
+ sum += vals_f[iterations].x;
+ sum += vals_f[iterations].y;
+ iterations++;
+ }
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -436,8 +458,10 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
- variance += (vals_f[i].x - mean) * (vals_f[i].x - mean);
- variance += (vals_f[i].y - mean) * (vals_f[i].y - mean);
+ vals_f[i].x -= mean;
+ vals_f[i].y -= mean;
+ variance += vals_f[i].x * vals_f[i].x;
+ variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -457,19 +481,25 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
variance /= (row_stride * 2);
variance += epsilon;
- __half2 mean_h = __float2half2_rn(mean);
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) vars[row] = __float2half(variance);
+ iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
- vals_arr[i] = __float22half2_rn(vals_f[i]);
- vals_arr[i] = (vals_arr[i] - mean_h) * h2rsqrt(variance_h);
- vals_arr[i] = vals_arr[i] * gamma_cast[i * iteration_stride + id] +
- beta_cast[i * iteration_stride + id];
- vals_cast[row * row_stride + i * iteration_stride + id] = vals_arr[i];
+ __half2 vals_arr = __float22half2_rn(vals_f[i]);
+ vals_arr = vals_arr * h2rsqrt(variance_h);
+ vals_arr =
+ vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
+ vals_cast[i * iteration_stride + id] = vals_arr;
+ }
+ if ((high_index) < row_stride) {
+ __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
+ vals_arr = vals_arr * h2rsqrt(variance_h);
+ vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
+ vals_cast[high_index] = vals_arr;
}
#endif
}
@@ -481,14 +511,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
- T* vars,
- T* vals_hat,
- bool save_vals);
+ T* vars);
/*
To tune this launch the following restrictions must be met:
@@ -512,90 +539,29 @@ void launch_bias_residual_layer_norm<float>(float* vals,
const float* beta,
float epsilon,
int batch_size,
- int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
- float* vars,
- float* vals_hat,
- bool save_vals)
+ float* vars)
{
- constexpr int threads = THREADS;
-
- dim3 grid_dim(batch_size * sequence_length);
+ int threads = THREADS;
- dim3 block_dim(threads);
+ dim3 grid_dim(batch_size);
// There are some limitations to call below functions, now just enumerate the situations.
- if (hidden_dim == 768)
- fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(vals,
- residual,
- gamma,
- beta,
- epsilon,
- preLayerNorm,
- training,
- vars,
- vals_hat,
- save_vals);
- else if (hidden_dim == 512)
- fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(vals,
- residual,
- gamma,
- beta,
- epsilon,
- preLayerNorm,
- training,
- vars,
- vals_hat,
- save_vals);
- else if (hidden_dim == 1024)
- fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(vals,
- residual,
- gamma,
- beta,
- epsilon,
- preLayerNorm,
- training,
- vars,
- vals_hat,
- save_vals);
- else if (hidden_dim == 1536)
- fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(vals,
- residual,
- gamma,
- beta,
- epsilon,
- preLayerNorm,
- training,
- vars,
- vals_hat,
- save_vals);
- else if (hidden_dim == 2048)
- fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(vals,
- residual,
- gamma,
- beta,
- epsilon,
- preLayerNorm,
- training,
- vars,
- vals_hat,
- save_vals);
- else if (hidden_dim == 2560)
- fused_bias_residual_layer_norm<2560, 10><<