Move workspace memory-allocation to PyTorch (microsoft#661)
* move workspace memory-allocation to PyTorch

* refine the code based on the comments

* remove unnecessary options

* remove bsz from set_seq_len function
RezaYazdaniAminabadi authored Jan 13, 2021
1 parent e2fbe4d commit 981bc7d
Showing 3 changed files with 31 additions and 28 deletions.
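Taken together, the three files below replace the context-owned, cudaMalloc'd scratch buffer with a tensor that the binding allocates through PyTorch on every forward/backward call: get_workspace_size<T>() provides the size, torch::empty performs the allocation, and the raw pointer is handed to the new Context::SetWorkSpace. A minimal, self-contained sketch of that pattern (the simplified Context below stands in for the real singleton in csrc/includes/context.h, and the element count is a placeholder, not the real sizing formula):

#include <stdexcept>
#include <torch/torch.h>

// Simplified stand-in for the Context singleton in csrc/includes/context.h.
class Context {
public:
    static Context& Instance()
    {
        static Context ctx;
        return ctx;
    }
    // Mirrors the new SetWorkSpace: borrow the pointer, own nothing.
    void SetWorkSpace(void* workspace)
    {
        if (!workspace) { throw std::runtime_error("Workspace is null."); }
        _workspace = workspace;
    }
    void* GetWorkSpace() { return _workspace; }

private:
    void* _workspace = nullptr;
};

int main()
{
    // Placeholder for get_workspace_size<T>(...): a count of elements, not bytes.
    const int64_t workspace_elems = 1 << 20;
    // Requires a CUDA build/device; switch to torch::kCPU just to exercise the flow.
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);

    // PyTorch's caching allocator owns the memory and reports it in its memory stats.
    auto workspace = torch::empty({workspace_elems}, options);
    Context::Instance().SetWorkSpace(workspace.data_ptr());

    // Kernels fetch the pointer via Context::Instance().GetWorkSpace() while the
    // calling binding keeps `workspace` alive.
    return 0;
}

Because the buffer now comes from the caching allocator, repeated calls reuse pooled blocks instead of paying a cudaMalloc/cudaFree on every resize, and the workspace shows up in PyTorch's memory accounting (e.g. torch.cuda.memory_allocated()).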
14 changes: 3 additions & 11 deletions csrc/includes/context.h
@@ -64,17 +64,10 @@ class Context {
         return _ctx;
     }
 
-    void GenWorkSpace(size_t size)
+    void SetWorkSpace(void* workspace)
     {
-        if (!_workspace) {
-            assert(_workspace == nullptr);
-            cudaMalloc(&_workspace, size);
-        } else if (_workSpaceSize < size) {
-            cudaFree(_workspace);
-            cudaMalloc(&_workspace, size);
-        }
-
-        _workSpaceSize = size;
+        if (!workspace) { throw std::runtime_error("Workspace is null."); }
+        _workspace = workspace;
     }
 
     void* GetWorkSpace() { return _workspace; }
@@ -172,6 +165,5 @@ class Context {
     void* _workspace;
     uint64_t _seed;
     uint64_t _curr_offset;
-    size_t _workSpaceSize;
     std::vector<std::array<int, 3>> _gemm_algos;
 };
5 changes: 4 additions & 1 deletion csrc/includes/ds_transformer_cuda.h
File mode changed: 100644 → 100755
@@ -130,10 +130,13 @@ class BertTransformerLayer {
     inline int GetBatchSize() const { return _batch_size; }
     inline int GetNumHeads() const { return _heads; }
     inline int GetSeqLength() const { return _seq_length; }
+    inline int GetIntermediateSize() const { return _intermediate_size; }
 
-    void SetSeqLength(int seq_len, int bsz);
+    void SetSeqLength(int seq_len);
     inline int GetHiddenSize() const { return _hidden_size; }
     void SetTrainingMode(bool training);
+    inline bool IsTrainingMode() const { return _training; }
+    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }
 
 private:
     void Initialize();
40 changes: 24 additions & 16 deletions csrc/transformer/ds_transformer_cuda.cpp
@@ -33,7 +33,7 @@ size_t get_workspace_size(int maxBatchSize,
                                    2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
         if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
     }
-    return workSpacesize * sizeof(T);
+    return workSpacesize; // * sizeof(T);
 }
 
 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
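Dropping the sizeof(T) factor changes the unit returned by get_workspace_size from bytes (what cudaMalloc expects) to elements (what torch::empty({n}, options) expects, since the allocator applies the dtype's element size itself). A small CPU-runnable check of that accounting; the numbers in the comments assume T = __half:

#include <torch/torch.h>
#include <iostream>

int main()
{
    // get_workspace_size<T>() now returns an element count; torch::empty
    // multiplies by the dtype's element size on its own.
    const int64_t elems = 1 << 20;
    auto half_options = torch::TensorOptions().dtype(torch::kHalf);  // CPU tensor is enough here

    auto ws = torch::empty({elems}, half_options);
    std::cout << ws.numel() << " elements, " << ws.nbytes() << " bytes\n";
    // Prints: 1048576 elements, 2097152 bytes

    // Keeping the old "* sizeof(T)" and passing the result to torch::empty would
    // request elems * sizeof(__half) half-precision *elements*, i.e. twice the
    // intended allocation.
    return 0;
}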
@@ -123,7 +123,6 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                              gemm_algos[4]))
 {
     assert(_hidden_size % _heads == 0);
-    assert(_seq_length <= 1024);
 
     Initialize();
 }
@@ -136,14 +135,6 @@ BertTransformerLayer<T>::~BertTransformerLayer()
 template <typename T>
 void BertTransformerLayer<T>::Initialize()
 {
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(_batch_size,
-                                                           _seq_length,
-                                                           _hidden_size,
-                                                           _intermediate_size,
-                                                           _heads,
-                                                           _training,
-                                                           _gelu_checkpoint));
-
     if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
 }

@@ -574,17 +565,14 @@ void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_
 }
 
 template <typename T>
-void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
+void BertTransformerLayer<T>::SetSeqLength(int seq_len)
 {
     _seq_length = seq_len;
 
     _softmax.SetSeqLength(_seq_length);
     _attn_prob_dropout.SetDimension(_seq_length);
     _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
     _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
-
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(
-        bsz, _seq_length, _hidden_size, _intermediate_size, _heads, _training, _gelu_checkpoint));
 }
 
 template <typename T>
@@ -707,9 +695,19 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     int seq_len = layer->GetSeqLength();
     if (input.size(1) != seq_len) {
         seq_len = input.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }
 
+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                          seq_len,
+                                                          layer->GetHiddenSize(),
+                                                          layer->GetIntermediateSize(),
+                                                          layer->GetNumHeads(),
+                                                          layer->IsTrainingMode(),
+                                                          layer->GeluCheckpoint())},
+                                  options);
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
     auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
     auto attn_o_inp = torch::empty_like(input);
@@ -877,9 +875,19 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
     int seq_len = layer->GetSeqLength();
     if (g_output.size(1) != seq_len) {
         seq_len = g_output.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }
 
+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                          seq_len,
+                                                          layer->GetHiddenSize(),
+                                                          layer->GetIntermediateSize(),
+                                                          layer->GetNumHeads(),
+                                                          layer->IsTrainingMode(),
+                                                          layer->GeluCheckpoint())},
+                                  grad_output.options());
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto grad_input = torch::empty_like(input);
     auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
     auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
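After this change both bindings follow the same three steps before touching any activations: query the layer's shape and mode flags, allocate the workspace tensor, and register its pointer with the context. A hypothetical helper, not part of this commit, condensing that sequence (the const-pointer layer handle is an assumption for illustration):

// Hypothetical refactor sketch: the size -> allocate -> register sequence that
// ds_transformer_forward and ds_transformer_backward now both inline.
template <typename T>
torch::Tensor allocate_workspace(const BertTransformerLayer<T>* layer,
                                 int bsz,
                                 int seq_len,
                                 const torch::TensorOptions& options)
{
    auto workspace = torch::empty({(int64_t)get_workspace_size<T>(bsz,
                                                                   seq_len,
                                                                   layer->GetHiddenSize(),
                                                                   layer->GetIntermediateSize(),
                                                                   layer->GetNumHeads(),
                                                                   layer->IsTrainingMode(),
                                                                   layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace(workspace.data_ptr());
    // The caller must keep the returned tensor alive for the duration of the
    // forward/backward call so the registered pointer stays valid.
    return workspace;
}

The forward path passes an options object it already has in scope, while the backward path uses grad_output.options(), so the workspace lands on the same device and dtype as the tensors it serves; either way the tensor is a local in the binding function, which keeps it alive for the span of the call.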
