refactor cpu kernel

wangyems committed Dec 8, 2022
1 parent a63ebf1 commit 44866c4
Showing 12 changed files with 463 additions and 439 deletions.
@@ -218,7 +218,7 @@ Status BeamSearchBase<T>::Initialize() {
if (!IsCuda()) {
// Logits processor is used in CPU only. In CUDA, cuda kernels are used instead.
// Initialize processors after CheckInputs so that parameters_->vocab_mask is ready.
logits_processors_.Init(*parameters_, thread_pool_);
logits_processors_.Init(*parameters_);
}

return Status::OK();
18 changes: 6 additions & 12 deletions onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
@@ -140,19 +140,13 @@ class GenerateBase {

if (attention_mask != nullptr) {
const auto& dims_attn = attention_mask->Shape().GetDims();
if (dims_attn.size() == 2) {
if (!SpanEq(dims_attn, dims)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have same shape as input_ids");
}
} else if (dims_attn.size() == 4) {
if (dims_attn[0] != dims[0] || dims_attn[1] != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to shape [batch_size, 1, max_sequence_length, max_sequence_length]");
}
} else {
if (dims_attn.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size());
}
if (!SpanEq(dims_attn, dims)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have 2 or 4 dimensions, got ", dims_attn.size());
"Input 'attention_mask' is expected to have same shape as input_ids");
}
}

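The generate_impl_base.h change above drops the old 4-D mask branch, so the only accepted attention_mask is now a 2-D tensor whose shape matches input_ids exactly. A minimal standalone sketch of the new check, assuming plain std::vector dims in place of the actual gsl::span/TensorShape types and an illustrative helper name:

```cpp
// Sketch of the simplified attention_mask validation (illustrative only).
#include <cstdint>
#include <string>
#include <vector>

using Dims = std::vector<int64_t>;

// Returns an empty string when the mask shape is acceptable,
// otherwise an error message mirroring the ones in the diff.
std::string CheckAttentionMaskShape(const Dims& mask_dims, const Dims& input_ids_dims) {
  if (mask_dims.size() != 2) {
    return "Input 'attention_mask' is expected to have 2 dimensions, got " +
           std::to_string(mask_dims.size());
  }
  if (mask_dims != input_ids_dims) {
    return "Input 'attention_mask' is expected to have same shape as input_ids";
  }
  return "";
}
```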
168 changes: 62 additions & 106 deletions onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
@@ -12,6 +12,7 @@
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/generation_device_helper.h"
#include "contrib_ops/cpu/transformers/sampling_cpu_helper.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"

@@ -138,8 +139,7 @@ Status CreateGptInputs(
OrtValue attention_mask;
if (attn_mask_value != nullptr) {
const Tensor& attn_mask = attn_mask_value->Get<Tensor>();
const TensorShape& attn_mask_shape = attn_mask.Shape(); // 2d or 4d
Tensor::InitOrtValue(element_type, attn_mask_shape, const_cast<Tensor*>(&attn_mask)->MutableData<int32_t>(),
Tensor::InitOrtValue(element_type, input_ids_shape, const_cast<Tensor*>(&attn_mask)->MutableData<int32_t>(),
allocator->Info(), attention_mask);
} else {
auto mask_type = DataTypeImpl::GetType<int32_t>();
@@ -181,13 +181,13 @@ Status CreateGptInputs(
expanded_input_ids = std::move(input_ids);
expanded_position_ids = std::move(position_ids);
expanded_attention_mask = std::move(attention_mask);
} else {
// bugbug: 4d not supported here
ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);
return Status::OK();
}

ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);

return Status::OK();
}

@@ -408,7 +408,7 @@ template <typename T>
Status GreedySearchProcessLogits(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingCudaState<T>* sampling_state, // sampling_state
transformers::ISamplingState<T>* sampling_state, // sampling_state
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
@@ -418,7 +418,6 @@ Status GreedySearchProcessLogits(
int step, // iteration counter
void* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
ORT_UNUSED_PARAMETER(sampling_state);
#ifndef DEBUG_GENERATION
ORT_UNUSED_PARAMETER(dumper);
#endif
@@ -462,94 +461,51 @@ Status GreedySearchProcessLogits(
constexpr unsigned top_k = 1;

if (do_sampling) {
// bugbug: is this Softmax really needed?
gsl::span<T>& next_token_probs = greedy_state->next_token_probs;
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(batch_size,
vocab_size,
next_token_scores.data(),
next_token_probs.data(),
false,
thread_pool));

// torch.multinomial()
int64_t next_token_probs_dims[] = {static_cast<int64_t>(batch_size), vocab_size};
TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2);
auto element_type = DataTypeImpl::GetType<T>();
OrtValue next_token_probs_value;
Tensor::InitOrtValue(element_type,
next_token_probs_shape,
next_token_probs.data(),
allocator->Info(),
next_token_probs_value);
const Tensor& input = next_token_probs_value.Get<Tensor>();

std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast<uint32_t>(parameters->seed)};

int64_t sampled_idx_dims[] = {static_cast<int64_t>(batch_size), 1};
TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2);

gsl::span<int64_t>& next_token_idx = greedy_state->next_tokens_cpu;

OrtValue sampled_idx_ov;
Tensor::InitOrtValue(DataTypeImpl::GetType<int64_t>(),
sampled_idx_shape,
next_token_idx.data(),
allocator->Info(),
sampled_idx_ov);
Tensor* sampled_idx = sampled_idx_ov.GetMutable<Tensor>();

AllocatorPtr allocator_temp = allocator;
ORT_RETURN_IF_ERROR(MultinomialComputeShared<int64_t>(allocator_temp,
input,
batch_size,
vocab_size,
1,
generator,
*sampled_idx));
// TODO: update presense_mask
SamplingCpuHelper::TopPSamplingCpu<T> top_p_sampler(allocator,
thread_pool,
sampling_state,
greedy_state,
parameters);
ORT_RETURN_IF_ERROR(top_p_sampler.Sample(next_token_scores));

#ifdef DEBUG_GENERATION
dumper->Print("sampled_idx", *sampled_idx);
#endif
} else {
// next_tokens = torch.argmax(scores, dim=-1)
int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size), vocab_size};
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
auto element_type = DataTypeImpl::GetType<T>();
OrtValue next_token_scores_value;
Tensor::InitOrtValue(element_type,
next_token_scores_shape,
next_token_scores.data(),
allocator->Info(),
next_token_scores_value);
const Tensor& input = next_token_scores_value.Get<Tensor>();

constexpr int axis = 1;
constexpr bool largest = true;
constexpr bool sorted = false;

Tensor topk_scores;
Tensor topk_indices;
ORT_RETURN_IF_ERROR(TopK(&input,
axis,
top_k,
largest,
sorted,
allocator,
stream,
thread_pool,
topk_scores,
topk_indices));
return Status::OK();
}
// next_tokens = torch.argmax(scores, dim=-1)
int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size), vocab_size};
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
auto element_type = DataTypeImpl::GetType<T>();
OrtValue next_token_scores_value;
Tensor::InitOrtValue(element_type,
next_token_scores_shape,
next_token_scores.data(),
allocator->Info(),
next_token_scores_value);
const Tensor& input = next_token_scores_value.Get<Tensor>();

constexpr int axis = 1;
constexpr bool largest = true;
constexpr bool sorted = false;

Tensor topk_scores;
Tensor topk_indices;
ORT_RETURN_IF_ERROR(TopK(&input,
axis,
top_k,
largest,
sorted,
allocator,
stream,
thread_pool,
topk_scores,
topk_indices));

#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", topk_scores);
dumper->Print("topk_indices", topk_indices);
#endif

gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
gsl::copy(next_token_indices, greedy_state->next_tokens_cpu);
}

gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
gsl::copy(next_token_indices, greedy_state->next_tokens_cpu);

#ifdef DEBUG_GENERATION
gsl::span<const int64_t> next_tokens(greedy_state->next_tokens_cpu.data(),
@@ -624,6 +580,7 @@ Status UpdateGptFeeds(
// last_outputs: logits, present_0, present_1, ...
// next_inputs: input_ids, position_id, attention_mask, past_0, past_1
ORT_UNUSED_PARAMETER(stream);

// The following updates inputs for subgraph

// Update input_ids with next tokens.
@@ -639,6 +596,7 @@
input_ids_data[i] = beam_next_tokens[i];
}
next_inputs[0] = input_ids;

if (increase_position) {
// Update position IDs
int32_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int32_t>();
@@ -647,24 +605,22 @@
}
}
next_inputs[1] = position_ids;

// Update attention mask
const OrtValue& old_mask = next_inputs[2];
const auto& mask_dims = old_mask.Get<Tensor>().Shape().GetDims();
if (mask_dims.size() == 2) {
const int32_t* old_mask_data = old_mask.Get<Tensor>().Data<int32_t>();
int64_t mask_dims[] = {batch_beam_size, current_length};
TensorShape mask_shape(&mask_dims[0], 2);
OrtValue attention_mask;
Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask);
int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < current_length - 1; j++) {
mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j];
}
mask_data[i * current_length + current_length - 1] = 1;
const int32_t* old_mask_data = old_mask.Get<Tensor>().Data<int32_t>();
int64_t mask_dims[] = {batch_beam_size, current_length};
TensorShape mask_shape(&mask_dims[0], 2);
OrtValue attention_mask;
Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask);
int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < current_length - 1; j++) {
mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j];
}
next_inputs[2] = attention_mask;
} // if mask_dims.size() == 4 do nothing
mask_data[i * current_length + current_length - 1] = 1;
}
next_inputs[2] = attention_mask;

// Update past state
if (num_beams == 1) {
@@ -894,7 +850,7 @@ template Status ProcessLogits<float>(
template Status GreedySearchProcessLogits<float>(
const OrtValue& logits,
transformers::IGreedySearchState<float>* greedy_state,
transformers::ISamplingCudaState<float>* sampling_state,
transformers::ISamplingState<float>* sampling_state,
transformers::ISequences* sequences,
AllocatorPtr& allocator,
onnxruntime::concurrency::ThreadPool* thread_pool,
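The largest change in generation_device_helper.cc is that the inline softmax + MultinomialComputeShared sampling path is replaced by the new SamplingCpuHelper::TopPSamplingCpu helper. For readers unfamiliar with the technique, a rough sketch of top-p (nucleus) sampling for a single row of logits follows; it illustrates the idea only, is not the onnxruntime implementation, and the top_p and seed parameters are assumptions:

```cpp
// Illustrative top-p (nucleus) sampling over one row of logits.
// Not the onnxruntime implementation; parameter names are assumed.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

int SampleTopP(const std::vector<float>& logits, float top_p, uint32_t seed) {
  const size_t vocab_size = logits.size();

  // Softmax over the row.
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> probs(vocab_size);
  float sum = 0.f;
  for (size_t i = 0; i < vocab_size; ++i) {
    probs[i] = std::exp(logits[i] - max_logit);
    sum += probs[i];
  }
  for (float& p : probs) p /= sum;

  // Sort token ids by probability, descending.
  std::vector<int> ids(vocab_size);
  std::iota(ids.begin(), ids.end(), 0);
  std::sort(ids.begin(), ids.end(), [&](int a, int b) { return probs[a] > probs[b]; });

  // Keep the smallest prefix whose cumulative probability reaches top_p.
  std::vector<int> kept;
  std::vector<float> kept_probs;
  float cumulative = 0.f;
  for (int id : ids) {
    kept.push_back(id);
    kept_probs.push_back(probs[id]);
    cumulative += probs[id];
    if (cumulative >= top_p) break;
  }

  // Multinomial draw restricted to the kept tokens.
  std::default_random_engine generator(seed);
  std::discrete_distribution<int> dist(kept_probs.begin(), kept_probs.end());
  return kept[dist(generator)];
}
```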
@@ -91,7 +91,7 @@ template <typename T>
using GreedySearchProcessLogitsFunc = std::function<Status(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingCudaState<T>* sampling_state, // sampling buffers
transformers::ISamplingState<T>* sampling_state, // sampling buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
@@ -209,7 +209,7 @@ Status ProcessLogits(const OrtValue& logits, //
template <typename T>
Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingCudaState<T>* sampling_state, // sampling buffers
transformers::ISamplingState<T>* sampling_state, // sampling buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
7 changes: 4 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -4,6 +4,7 @@
#pragma once

#include <utility>
#include <random>
#include "core/common/gsl.h"
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
@@ -33,7 +34,7 @@ struct IBeamSearchState {
gsl::span<float> scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
gsl::span<float> remaining_scores; // portion of scores that is available for appending next token scores.
gsl::span<float> topk_buffer; // temp buffer for topk computation, including:
// 1st stage needs:
// 1st stage needs:
// temp score: (batch_size * num_beams * parts_vocab, 2 * num_beams)
// temp token: (batch_size * num_beams * parts_vocab, 2 * num_beams)
// 2nd stage needs:
@@ -62,12 +63,11 @@ struct IGreedySearchState {
gsl::span<int32_t> next_positions; // shape (batch_size, num_beams). Next position value for position_ids.
gsl::span<bool> eos_meet; // shape (batch_size)
gsl::span<T> next_token_scores; // shape (batch_size, vocab_size)
gsl::span<T> next_token_probs; // shape (batch_size, vocab_size)
gsl::span<int32_t> next_tokens; // shape (batch_size)
};

template <typename T>
struct ISamplingCudaState {
struct ISamplingState {
gsl::span<int> d_index_in;
gsl::span<int> d_index_out;
gsl::span<int> d_offset;
@@ -82,6 +82,7 @@ struct ISamplingCudaState {

BufferUniquePtr storage_buffer;
size_t temp_storage_bytes;
std::default_random_engine generator;
};

class ISequences {
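generation_shared.h now pulls in <random> and renames ISamplingCudaState to ISamplingState, adding a std::default_random_engine member — presumably so a seeded engine can live alongside the other sampling buffers and be reused across generation steps rather than recreated per call. A small usage sketch of that pattern, with a simplified struct and a made-up seed and probability vector:

```cpp
// Sketch: seed the engine once, then reuse it for every multinomial draw.
// The struct is a simplification; only the generator field mirrors the diff.
#include <random>
#include <vector>

struct SamplingStateSketch {
  std::default_random_engine generator;  // persists across generation steps
};

int DrawToken(SamplingStateSketch& state, const std::vector<float>& probs) {
  std::discrete_distribution<int> dist(probs.begin(), probs.end());
  return dist(state.generator);  // engine state advances, so draws differ per step
}

int main() {
  SamplingStateSketch state;
  state.generator.seed(42);  // e.g. parameters->seed
  std::vector<float> probs = {0.1f, 0.7f, 0.2f};
  int first = DrawToken(state, probs);
  int second = DrawToken(state, probs);
  (void)first;
  (void)second;
  return 0;
}
```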