Stable Diffusion CUDA Optimizations Part 4 #14680

Merged · 7 commits · Feb 15, 2023

6 changes: 3 additions & 3 deletions docs/ContribOperators.md
@@ -2310,12 +2310,12 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Number of attention heads</dd>
</dl>

#### Inputs (2 - 6)
#### Inputs (1 - 6)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
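For readers unfamiliar with the packed layouts documented above, here is a minimal, self-contained sketch (not part of this PR; all dimension values are made-up assumptions) of how the unpacked, packed-KV, and packed-QKV shapes relate:

```cpp
// Minimal sketch (not part of this PR). Dimension values are illustrative only.
#include <cstdio>

int main() {
  const int batch_size = 2;          // B
  const int sequence_length = 64;    // S
  const int kv_sequence_length = 77; // L (e.g. CLIP text length in Stable Diffusion)
  const int num_heads = 8;           // N
  const int head_size = 40;          // H
  const int hidden_size = num_heads * head_size;  // D = N * H

  // Unpacked (cross attention): query (B, S, D), key (B, L, D), value (B, L, D_v).
  std::printf("query      : (%d, %d, %d)\n", batch_size, sequence_length, hidden_size);
  std::printf("key        : (%d, %d, %d)\n", batch_size, kv_sequence_length, hidden_size);

  // Packed KV: key carries both K and V as (B, L, N, 2, H); value is omitted.
  std::printf("packed kv  : (%d, %d, %d, 2, %d)\n", batch_size, kv_sequence_length, num_heads, head_size);

  // Packed QKV (self attention): query carries Q, K and V as (B, S, N, 3, H);
  // key and value are omitted, so the kv length equals the query length.
  std::printf("packed qkv : (%d, %d, %d, 3, %d)\n", batch_size, sequence_length, num_heads, head_size);
  return 0;
}
```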
86 changes: 56 additions & 30 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -22,53 +22,79 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
int max_threads_per_block) {
// key_padding_mask (K/V) : (B) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// When no packing for q/k/v:
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D_v)
// bias (Q/K/V) : (D + D + D_v)
// key_padding_mask (K/V) : (B) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// When packed kv is used:
// query (Q) : (B, S, D)
// key (K) : (B, L, N, 2, H)
// value (V) : None
// bias (Q/K/V) : None
// When packed qkv is used:
// query (Q) : (B, L, N, 3, H)
// key (K) : None
// value (V) : None
// bias (Q/K/V) : None

const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
if (query_dims.size() != 3 && query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ",
query_dims.size());
}

const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int hidden_size = static_cast<int>(query_dims[2]);
int hidden_size = query_dims.size() == 3 ? static_cast<int>(query_dims[2]) : (num_heads * static_cast<int>(query_dims[4]));
int head_size = static_cast<int>(hidden_size) / num_heads;
int kv_sequence_length = static_cast<int>(key_dims[1]);
int kv_sequence_length = sequence_length;

if (key != nullptr) {
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ",
query_dims.size());
}

if (key_dims.size() == 3) {
if (key_dims[2] != query_dims[2]) {
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}
} else // if (key_dims.size() == 5)
{
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {

if (key_dims.size() == 3) {
if (key_dims[2] != query_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
}
} else // if (key_dims.size() == 5)
{
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
}
}

kv_sequence_length = static_cast<int>(key_dims[1]);
} else { // packed QKV
if (query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 5 dimensions when key is empty, got ",
query_dims.size());
}
if (static_cast<int>(query_dims[2]) != num_heads || static_cast<int>(query_dims[3]) != 3) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
"Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv");
}
}

@@ -82,17 +108,17 @@ Status CheckInputs(const T* query,
// Currently, bias is not allowed for packed KV. This constraint can be removed later.
// Here we assume that fusion tool will not include bias for packed KV.
if (value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. ");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed qkv or kv. ");
}
}

AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
const auto& mask_dims = key_padding_mask->Shape().GetDims();
if (mask_dims.size() == 1 && mask_dims[0] == key_dims[0]) {
if (mask_dims.size() == 1 && mask_dims[0] == static_cast<int64_t>(batch_size)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
} else if (mask_dims.size() == 2 && mask_dims[0] == key_dims[0] && mask_dims[1] == key_dims[1]) {
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
}

@@ -115,7 +141,7 @@ Status CheckInputs(const T* query,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}

if (key_dims[1] != value_dims[1]) {
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)");
}
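As a condensed summary of the new CheckInputs dispatch above, the following standalone sketch (simplified and hypothetical: plain exceptions instead of ORT_MAKE_STATUS, and only the shape/derivation rules shown, with the batch-size and value checks omitted) shows how hidden_size, head_size, and kv_sequence_length are derived for the three layouts:

```cpp
// Condensed sketch of the new CheckInputs dispatch (not the real ORT code).
// Shapes: B = batch, S = query length, L = kv length, N = heads, H = head size.
#include <cstdint>
#include <stdexcept>
#include <vector>

struct Derived { int hidden_size, head_size, kv_sequence_length; };

Derived DeriveSizes(const std::vector<int64_t>& query_dims,
                    const std::vector<int64_t>* key_dims,  // nullptr => packed QKV
                    int num_heads) {
  if (query_dims.size() != 3 && query_dims.size() != 5)
    throw std::invalid_argument("query must be 3-D or 5-D");

  // 3-D query: (B, S, D). 5-D query (packed QKV): (B, S, N, 3, H), so D = N * H.
  const int hidden_size = query_dims.size() == 3
                              ? static_cast<int>(query_dims[2])
                              : num_heads * static_cast<int>(query_dims[4]);
  const int head_size = hidden_size / num_heads;

  // Default (packed QKV): K and V come from the query tensor, so L == S.
  int kv_sequence_length = static_cast<int>(query_dims[1]);

  if (key_dims != nullptr) {
    // key given: 3-D (B, L, D) for the unpacked case, or 5-D (B, L, N, 2, H) for packed KV.
    if (query_dims.size() != 3)
      throw std::invalid_argument("query must be 3-D when key is given");
    if (key_dims->size() == 5 &&
        ((*key_dims)[2] != num_heads || (*key_dims)[3] != 2 || (*key_dims)[4] != head_size))
      throw std::invalid_argument("packed kv key must be (B, L, N, 2, H)");
    kv_sequence_length = static_cast<int>((*key_dims)[1]);
  } else if (query_dims.size() != 5 || query_dims[2] != num_heads || query_dims[3] != 3) {
    throw std::invalid_argument("packed qkv query must be (B, S, N, 3, H)");
  }

  return {hidden_size, head_size, kv_sequence_length};
}

int main() {
  // Packed QKV example: query (2, 64, 8, 3, 40) with no key.
  Derived d = DeriveSizes({2, 64, 8, 3, 40}, nullptr, /*num_heads=*/8);
  // d.hidden_size == 320, d.head_size == 40, d.kv_sequence_length == 64
  return d.hidden_size == 320 ? 0 : 1;
}
```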
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -181,6 +181,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@@ -190,6 +191,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner,
use_fused_cross_attention,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());

@@ -204,12 +206,15 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
data.past = (nullptr == past) ? nullptr : reinterpret_cast<const CudaT*>(past->Data<T>());
data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = nullptr;
data.use_memory_efficient_attention = use_memory_efficient_attention;
data.cumulated_sequence_length_q_cache = nullptr;
data.cumulated_sequence_length_kv_cache = nullptr;

return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
}
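The attention.cc changes above mostly thread the new options through with their "not used here" values (the fused cross-attention kernel and the cumulated-sequence-length caches stay null), so the existing Attention op keeps its behavior while the shared helpers gain support for the new fused cross-attention path. A tiny hypothetical sketch of that call-site pattern (names and sizing logic invented for illustration, not the real onnxruntime signatures):

```cpp
// Hypothetical illustration of the call-site pattern above; not onnxruntime code.
#include <cstddef>
#include <cstdio>

// Stand-in for a shared helper that gained an extra capability flag.
size_t WorkspaceSizeSketch(size_t element_size, int batch_size, int num_heads,
                           bool use_fused_cross_attention) {
  // Invented sizing: the fused path is assumed to need no extra scratch space.
  return use_fused_cross_attention
             ? 0
             : element_size * static_cast<size_t>(batch_size) * num_heads;
}

int main() {
  // Legacy call site: opt out of the new path explicitly so behavior is unchanged.
  constexpr bool use_fused_cross_attention = false;
  std::printf("workspace bytes: %zu\n",
              WorkspaceSizeSketch(sizeof(float), /*batch_size=*/1, /*num_heads=*/8,
                                  use_fused_cross_attention));
  return 0;
}
```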