Cherry-pick LLaMA/SDXL to rel-1.16.2 (#18202)
Cherry-pick changes related to LLaMA and StableDiffusion XL to the 1.16.2 release branch.

Co-authored-by: kunal-vaishnavi <[email protected]>
Co-authored-by: Patrice Vignola <[email protected]>
Co-authored-by: petermcaughan <[email protected]>
Co-authored-by: Peter McAughan <[email protected]>
Co-authored-by: Jambay Kinley <[email protected]>
Co-authored-by: PeixuanZuo <[email protected]>
Co-authored-by: Ye Wang <[email protected]>
Co-authored-by: Your Name <[email protected]>
Co-authored-by: aciddelgado <[email protected]>
Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Yufeng Li <[email protected]>
Co-authored-by: JiCheng <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
1 parent: 0240274. Commit: c273f7a. Showing 221 changed files with 28,481 additions and 4,055 deletions.
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_helper;

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
    RotaryEmbedding,
    kMSDomain,
    1,
    float,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
    RotaryEmbedding<float>);

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
  scale = info.GetAttrOrDefault<float>("scale", 1.0);
  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* position_ids = context->Input<Tensor>(1);
  const Tensor* cos_cache = context->Input<Tensor>(2);
  const Tensor* sin_cache = context->Input<Tensor>(3);

  RotaryParameters parameters = {};
  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
                                                                   position_ids,
                                                                   cos_cache,
                                                                   sin_cache,
                                                                   &parameters));

  Tensor* output = context->Output(0, input->Shape());

  if (parameters.sequence_length > parameters.max_sequence_length) {
    // Launch update_cos_sin_cache kernel with scale
    ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
  }

  const T* input_src = input->Data<T>();
  const int64_t* pos_ids_data = position_ids->Data<int64_t>();
  const T* cos_cache_data = cos_cache->Data<T>();
  const T* sin_cache_data = sin_cache->Data<T>();
  T* output_dest = output->MutableData<T>();

  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int num_heads = parameters.num_heads;
  const int head_size = parameters.head_size;
  const int position_ids_format = parameters.position_ids_format;
  const int half_head_size = head_size / 2;

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
  auto* tp = context->GetOperatorThreadPool();

  // Parallelize over all (batch, sequence, head) blocks; each block rotates one head vector.
  const int loop_len = batch_size * sequence_length * num_heads;
  const double cost = static_cast<double>(head_size);
  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
      // Decompose the flat index into (batch, sequence, head) coordinates
      const int b = static_cast<int>((ptr / num_heads) / sequence_length);
      const int s = static_cast<int>((ptr / num_heads) % sequence_length);
      const int n = static_cast<int>(ptr % num_heads);

      const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
      const int data_offset = block_offset * head_size;

      const T* input_data = input_src + data_offset;
      T* output_data = output_dest + data_offset;

      // Cache is (M, H/2)
      const int position_id = (position_ids_format == 0)
                                  ? static_cast<int>(pos_ids_data[0]) + s
                                  : static_cast<int>(pos_ids_data[b * sequence_length + s]);
      const int cache_offset = position_id * half_head_size;
      const T* cos_data = cos_cache_data + cache_offset;
      const T* sin_data = sin_cache_data + cache_offset;

      // Rotate each element with its pair partner j, using the cached cos/sin
      // for that pair's frequency band.
      int cache_idx = 0;
      T sign = 0;
      int j = 0;
      for (int i = 0; i < head_size; i++) {
        if (interleaved) {
          // Pairs are adjacent elements: (0,1), (2,3), ...
          cache_idx = (i / 2) % half_head_size;
          sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
        } else {
          // Pairs are split across halves: (0, H/2), (1, H/2+1), ...
          cache_idx = i % half_head_size;
          sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i + half_head_size) % head_size;
        }
        output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
      }
    }
  });

  return Status::OK();
}

}  // namespace contrib
}  // namespace onnxruntime
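For reference, the rotation this kernel applies to a single head can be reproduced in isolation. The standalone sketch below is a hypothetical illustration, not part of the commit: it applies the same interleaved and non-interleaved pairing rules to one head of size 4 at a fixed position, with plain std::vector buffers in place of ORT tensors. In the real op the cos/sin caches arrive as inputs; here a cache row is synthesized using the standard RoPE frequency convention theta_k = position / 10000^(2k / head_size), which is an assumption about how callers populate the cache.

// Minimal standalone sketch of the rotation in RotaryEmbedding<T>::Compute.
// Hypothetical illustration; buffer contents and the frequency formula are
// assumptions, not taken from this file.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int head_size = 4;
  const int half_head_size = head_size / 2;
  const bool interleaved = false;  // flip to true for the interleaved layout

  // One head of input data plus a synthesized cos/sin cache row for one position.
  std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> cos_data(half_head_size), sin_data(half_head_size);
  const float position_id = 5.0f;
  for (int k = 0; k < half_head_size; k++) {
    // Assumed standard RoPE frequencies: theta_k = position / 10000^(2k / head_size)
    const float theta = position_id / std::pow(10000.0f, 2.0f * k / head_size);
    cos_data[k] = std::cos(theta);
    sin_data[k] = std::sin(theta);
  }

  // Same pairing logic as the kernel's inner loop over i.
  std::vector<float> output(head_size);
  for (int i = 0; i < head_size; i++) {
    int cache_idx, j;
    float sign;
    if (interleaved) {
      // Pairs are adjacent elements: (0,1), (2,3), ...
      cache_idx = (i / 2) % half_head_size;
      sign = (i % 2 == 0) ? -1.0f : 1.0f;
      j = (i % 2 == 0) ? i + 1 : i - 1;
    } else {
      // Pairs are split across halves: (0, H/2), (1, H/2+1), ...
      cache_idx = i % half_head_size;
      sign = (i < half_head_size) ? -1.0f : 1.0f;
      j = (i + half_head_size) % head_size;
    }
    output[i] = input[i] * cos_data[cache_idx] + sign * input[j] * sin_data[cache_idx];
  }

  for (int i = 0; i < head_size; i++) {
    std::printf("output[%d] = %f\n", i, output[i]);
  }
  return 0;
}

Each pair (input[i], input[j]) is rotated by the angle cached for its frequency band, which is why the cache only needs head_size / 2 entries per position, matching the "Cache is (M, H/2)" comment in the kernel.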