-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support dynamic sequence length #320
Changes from all commits
a308637
0d3039b
c896b5c
aad95c4
dc3e064
5963997
00439bd
643c33b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,11 +121,17 @@ class BertTransformerLayer { | |
|
||
void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr, | ||
uint8_t* attn_output_dropout_mask_ptr, | ||
uint8_t* layer_output_dropout_mask_ptr); | ||
uint8_t* layer_output_dropout_mask_ptr, | ||
T* layer_norm_var, | ||
T* layer_norm_mean, | ||
T* attn_layer_norm_var, | ||
T* attn_layer_norm_mean); | ||
|
||
inline int GetBatchSize() const { return _batch_size; } | ||
inline int GetNumHeads() const { return _heads; } | ||
inline int GetSeqLength() const { return _seq_length; } | ||
|
||
void SetSeqLength(int seq_len, int bsz); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this used somewhere? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes it is used in ds_transformer_cuda.cpp: https://github.com/microsoft/DeepSpeed/blob/reyazda/support_dynamic_seqlength/csrc/transformer/ds_transformer_cuda.cpp#L708 |
||
inline int GetHiddenSize() const { return _hidden_size; } | ||
void SetTrainingMode(bool training); | ||
|
||
|
@@ -150,8 +156,8 @@ class BertTransformerLayer { | |
// layers | ||
FeedForward<T> _qkv_linear; | ||
FeedForward<T> _attn_out_linear; | ||
Normalize_Layer<T> _norm_layer2; | ||
Normalize_Layer<T> _norm_layer3; | ||
Normalize_Layer<T> _attn_layer_norm; | ||
Normalize_Layer<T> _layer_norm; | ||
Normalize_Layer<T>* _last_normalize; | ||
FeedForward<T> _ff1, _ff2; | ||
Softmax<T> _softmax; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,13 +9,8 @@ template <typename T> | |
class Gelu { | ||
public: | ||
struct Config { | ||
uint32_t batch_size; | ||
uint32_t seq_length; | ||
uint32_t intermediate_size; | ||
Config(uint32_t batch, uint32_t seq, uint32_t inter_size) | ||
: batch_size(batch), seq_length(seq), intermediate_size(inter_size) | ||
{ | ||
} | ||
Config(uint32_t inter_size) : intermediate_size(inter_size) {} | ||
}; | ||
|
||
Gelu(const Config& config) : _config(config) {} | ||
|
@@ -28,14 +23,12 @@ class Gelu { | |
T* output, | ||
cudaStream_t stream) | ||
{ | ||
launch_bias_gelu<T>( | ||
input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream); | ||
launch_bias_gelu<T>(input_buf, bias, output, _config.intermediate_size, bsz, stream); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. both "batch" and "seq_length" in config can be removed? |
||
} | ||
|
||
void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream) | ||
{ | ||
launch_d_gelu<T>( | ||
d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream); | ||
launch_d_gelu<T>(d_output, input_buf, bias, _config.intermediate_size, bsz, stream); | ||
} | ||
|
||
private: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks "batch" is useless in config, remove it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes I remove that