Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyems committed Dec 8, 2022
1 parent c54445f commit 08cb1e8
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,6 @@ Status GreedySearchProcessLogits(
int step, // iteration counter
void* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
#ifndef DEBUG_GENERATION
ORT_UNUSED_PARAMETER(dumper);
#endif

int batch_size = parameters->batch_size;
int vocab_size = parameters->vocab_size;
Expand Down Expand Up @@ -458,18 +455,18 @@ Status GreedySearchProcessLogits(
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size);
#endif

constexpr unsigned top_k = 1;

if (do_sampling) {
SamplingCpuHelper::TopPSamplingCpu<T> top_p_sampler(allocator,
thread_pool,
sampling_state,
greedy_state,
parameters);
parameters,
dumper);
ORT_RETURN_IF_ERROR(top_p_sampler.Sample(next_token_scores));

return Status::OK();
}

// next_tokens = torch.argmax(scores, dim=-1)
int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size), vocab_size};
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
Expand All @@ -482,6 +479,7 @@ Status GreedySearchProcessLogits(
next_token_scores_value);
const Tensor& input = next_token_scores_value.Get<Tensor>();

constexpr unsigned top_k = 1;
constexpr int axis = 1;
constexpr bool largest = true;
constexpr bool sorted = false;
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ struct ISamplingState {
BufferUniquePtr storage_buffer;
size_t temp_storage_bytes;
std::default_random_engine generator;

std::vector<T> sorted_scores;
std::vector<size_t> sorted_indices;
std::vector<T> cumulative_probs;
};

class ISequences {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,18 @@ struct SamplingState : public ISamplingState<T> {
this->h_sampled_all = AllocateBuffer<float>(cpu_allocator, h_sampled_all_buffer_, SafeInt<size_t>(batch_size * max_iter));
this->d_indices = AllocateBuffer<int64_t>(allocator, d_indices_buffer_, SafeInt<size_t>(batch_size));
this->temp_storage_bytes = 0;
// TODO(wy): Do not allocate this buffer if there's no presence_mask
// TODO: Do not allocate this buffer if there's no presence_mask
this->d_presence_mask = AllocateBuffer<int>(allocator, d_presence_mask_buffer_, SafeInt<size_t>(total_count));

std::uniform_real_distribution<float> distribution(0.0, 1.0);
distribution(this->generator);
for (size_t i = 0; i < this->h_sampled_all.size(); ++i) {
this->h_sampled_all[i] = distribution(this->generator);
}
} else {
this->sorted_scores.reserve(total_count);
this->sorted_indices.reserve(total_count);
this->cumulative_probs.reserve(total_count);
}
}

Expand Down
70 changes: 38 additions & 32 deletions onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ class TopPSamplingCpu{
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ISamplingState<T>* sampling_state,
transformers::IGreedySearchState<T>* greedy_state,
const transformers::IGenerationParameters* parameters):
const transformers::IGenerationParameters* parameters,
const transformers::IConsoleDumper* dumper):
allocator_(allocator),
thread_pool_(thread_pool),
sampling_state_(sampling_state),
greedy_state_(greedy_state),
parameters_(parameters) {}
parameters_(parameters),
dumper_(dumper) {}

Status Sample(gsl::span<T>& next_token_scores);

Expand All @@ -30,6 +32,7 @@ class TopPSamplingCpu{
transformers::ISamplingState<T>* sampling_state_;
transformers::IGreedySearchState<T>* greedy_state_;
const transformers::IGenerationParameters* parameters_;
const transformers::IConsoleDumper* dumper_;
};

template <typename T>
Expand All @@ -46,40 +49,43 @@ Status TopPSamplingCpu<T>::Sample(gsl::span<T>& next_token_scores) {
ORT_THROW("top_p shall be greater than 0");
}

for (int i = 0; i < parameters_->batch_size; i++) {
gsl::span<T> next_token_score = next_token_scores.subspan(i * parameters_->vocab_size,
parameters_->vocab_size);

// Copy the vector
std::vector<T> sorted_score(next_token_score.begin(), next_token_score.end());

// Descending sort
std::vector<size_t> sorted_indice(parameters_->vocab_size);
std::iota(sorted_indice.begin(), sorted_indice.end(), 0);
std::sort(sorted_indice.begin(),
sorted_indice.end(),
[&sorted_score](size_t i1, size_t i2) {
return sorted_score[i1] > sorted_score[i2];
std::vector<T>& sorted_scores = sampling_state_->sorted_scores;
sorted_scores.assign(next_token_scores.begin(), next_token_scores.end());
// Descending sort
std::vector<size_t>& sorted_indices = sampling_state_->sorted_indices;

for (size_t i = 0; i < static_cast<size_t>(parameters_->batch_size); i++) {
auto indices_begin = sorted_indices.begin() + i * parameters_->vocab_size;
auto indices_end = sorted_indices.begin() + (i + 1) * parameters_->vocab_size;
std::iota(indices_begin, indices_end, 0);
std::sort(indices_begin, indices_end,
[&next_token_scores](size_t i1, size_t i2) {
return next_token_scores[i1] > next_token_scores[i2];
});

std::sort(sorted_score.begin(), sorted_score.end(), std::greater<T>());
std::vector<T> cumulative_prob(parameters_->vocab_size);
std::sort(sorted_scores.begin() + i * parameters_->vocab_size,
sorted_scores.end() + (i + 1) * parameters_->vocab_size,
std::greater<T>());
}

std::vector<T>& cumulative_probs = sampling_state_->cumulative_probs;

// TODO: batch
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(1,
parameters_->vocab_size,
sorted_score.data(),
cumulative_prob.data(),
false,
thread_pool_));
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters_->batch_size,
parameters_->vocab_size,
sorted_scores.data(),
cumulative_probs.data(),
false,
thread_pool_));

if (cumulative_prob[0] > parameters_->top_p) {
filter_scores(sorted_indice, next_token_score, 1);
for (size_t i = 0; i < static_cast<size_t>(parameters_->batch_size); i++) {
size_t offset = i * parameters_->vocab_size;
if (cumulative_probs[offset] > parameters_->top_p) {
filter_scores(sorted_indices, next_token_scores, 1 + offset);
}
for (size_t i = 1; i < static_cast<size_t>(parameters_->vocab_size) - 1; i++) {
cumulative_prob[i] += cumulative_prob[i - 1];
if (cumulative_prob[i] > parameters_->top_p) {
filter_scores(sorted_indice, next_token_score, i + 1);
for (size_t j = 1; j < static_cast<size_t>(parameters_->vocab_size) - 1; j++) {
cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
if (cumulative_probs[j + offset] > parameters_->top_p) {
filter_scores(sorted_indices, next_token_scores, j + offset + 1);
}
}
}
Expand Down Expand Up @@ -132,7 +138,7 @@ Status TopPSamplingCpu<T>::Sample(gsl::span<T>& next_token_scores) {
// TODO: update presence_mask()

#ifdef DEBUG_GENERATION
dumper->Print("sampled_idx", *sampled_idx);
dumper_->Print("sampled_idx", *sampled_idx);
#endif

return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,8 @@ Status GreedySearchProcessLogits(
cuda_stream,
sampling_state,
greedy_state,
parameters);
parameters,
dumper);
ORT_RETURN_IF_ERROR(top_p_sampler.Sample(step, next_token_scores));

return Status::OK();
Expand Down
39 changes: 27 additions & 12 deletions onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ class TopPSamplingCuda{
cudaStream_t cuda_stream,
transformers::ISamplingState<T>* sampling_state,
transformers::IGreedySearchState<T>* greedy_state,
const transformers::IGenerationParameters* parameters):
const transformers::IGenerationParameters* parameters,
const transformers::IConsoleDumper* dumper):
allocator_(allocator),
cuda_stream_(cuda_stream),
sampling_state_(sampling_state),
greedy_state_(greedy_state),
parameters_(parameters) {}
parameters_(parameters),
dumper_(dumper) {}

Status Sample(int step, gsl::span<T>& next_token_scores);

Expand All @@ -37,6 +39,7 @@ class TopPSamplingCuda{
transformers::ISamplingState<T>* sampling_state_;
transformers::IGreedySearchState<T>* greedy_state_;
const transformers::IGenerationParameters* parameters_;
const transformers::IConsoleDumper* dumper_;
};

template <typename T>
Expand All @@ -63,7 +66,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
cuda_stream_);

#ifdef DEBUG_GENERATION
dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1);
dumper_->Print("d_offset_buffer", d_offset.data(), parameters_->batch_size + 1, 1);
#endif

void* temp_storage = allocator_->Alloc(sampling_state_->temp_storage_bytes);
Expand All @@ -75,7 +78,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
gsl::span<int>& d_index_out = sampling_state_->d_index_out;

#ifdef DEBUG_GENERATION
dumper->Print("temp_storage_bytes", temp_storage_bytes, true);
dumper_->Print("temp_storage_bytes", sampling_state_->temp_storage_bytes, true);
#endif

cuda::LaunchSortPairsDescending<CudaT>(storage_buffer.get(),
Expand All @@ -90,9 +93,12 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
cuda_stream_);

#ifdef DEBUG_GENERATION
dumper->Print("d_sorted_score_buffer", reinterpret_cast<T*>(d_sorted_score.data()), batch_size, vocab_size);
dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size);
dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size);
dumper_->Print("d_sorted_score_buffer",
reinterpret_cast<T*>(d_sorted_score.data()),
parameters_->batch_size,
parameters_->vocab_size);
dumper_->Print("d_index_buffer_in", d_index_in.data(), parameters_->batch_size, parameters_->vocab_size);
dumper_->Print("d_index_buffer_out", d_index_out.data(), parameters_->batch_size, parameters_->vocab_size);
#endif

gsl::span<float>& d_sorted_softmaxed_score = sampling_state_->d_sorted_softmaxed_score;
Expand All @@ -105,7 +111,10 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
parameters_->batch_size);

#ifdef DEBUG_GENERATION
dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size);
dumper_->Print("d_sorted_softmaxed_score_buffer",
d_sorted_softmaxed_score.data(),
parameters_->batch_size,
parameters_->vocab_size);
#endif

cuda::LaunchFilterLogitsKernel<CudaT>(d_sorted_softmaxed_score.data(),
Expand All @@ -118,7 +127,10 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
cuda_stream_);

#ifdef DEBUG_GENERATION
dumper->Print("next_token_scores after filtering logits", reinterpret_cast<T*>(next_token_scores.data()), batch_size, vocab_size);
dumper_->Print("next_token_scores after filtering logits",
reinterpret_cast<T*>(next_token_scores.data()),
parameters_->batch_size,
parameters_->vocab_size);
#endif

// TODO(wy): Can we only do softmax at the very beginning and sort the softmaxed scores.
Expand All @@ -132,7 +144,10 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
parameters_->batch_size);

#ifdef DEBUG_GENERATION
dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size);
dumper_->Print("d_softmaxed_score_buffer",
d_softmaxed_score.data(),
parameters_->batch_size,
parameters_->vocab_size);
#endif

// Multinomial sampling
Expand All @@ -145,7 +160,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
cuda_stream_));

#ifdef DEBUG_GENERATION
dumper->Print("d_sampled", d_sampled.data(), batch_size, 1);
dumper_->Print("d_sampled", d_sampled.data(), parameters_->batch_size, 1);
#endif

gsl::span<int64_t>& d_indices = sampling_state_->d_indices;
Expand All @@ -159,7 +174,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<T>& next_token_scores) {
cuda_stream_);

#ifdef DEBUG_GENERATION
dumper->Print("d_indices", d_indices.data(), batch_size, 1);
dumper_->Print("d_indices", d_indices.data(), parameters_->batch_size, 1);
#endif

CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state_->next_tokens_cpu.data(),
Expand Down

0 comments on commit 08cb1e8

Please sign in to comment.