Do not use threading in LSTM.
hqucms committed Oct 3, 2019
1 parent 7222aea commit 04f3c76
Showing 4 changed files with 60 additions and 85 deletions.
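The gist of the change: ExecuteLambdaInParallel in rnn_helpers.h now takes a thread-pool pointer instead of a reference and runs the work inline when that pointer is null; the LSTM code paths pass nullptr, and the per-operator thread pools are deleted. A standalone C++ sketch of that dispatch follows (ThreadPoolStub and the surrounding harness are illustrative stand-ins, not the ORT API):

// Standalone sketch (not ORT source): the shape of the new dispatch in
// ExecuteLambdaInParallel. A null pool pointer selects the sequential
// path that the LSTM kernels now use.
struct ThreadPoolStub {};  // stand-in for onnxruntime::concurrency::ThreadPool

template <typename TLambda>
void ExecuteLambdaInParallelSketch(TLambda lambda, int max, int step, ThreadPoolStub* ttp) {
  if (ttp == nullptr) {
    for (int i = 0; i < max; i += step)
      lambda(i);  // run inline, in order -- no threading
  } else {
    // schedule on the pool, as in the rnn_helpers.h diff below
  }
}

int main() {
  int sum = 0;
  ExecuteLambdaInParallelSketch([&](int i) { sum += i; }, 10, 2, nullptr);
  return sum == 20 ? 0 : 1;  // 0 + 2 + 4 + 6 + 8
}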
5 changes: 0 additions & 5 deletions onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h
@@ -92,11 +92,6 @@ class DeepCpuAttnLstmOp final : public OpKernel {
 
   ActivationFuncs activation_funcs_;
 
-  // Threadpool for operator. If concurrent Compute calls are possible, it will be shared
-  // across them. mutable due to this.
-  // The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
-  // cost on every call.
-  mutable onnxruntime::concurrency::ThreadPool ttp_{"DEEPCPU_ATTN_LSTM", (int)std::thread::hardware_concurrency()};
 };
 
 }  // namespace contrib
17 changes: 5 additions & 12 deletions onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
@@ -196,7 +196,6 @@ class UniDirectionalLstm {
                       const gsl::span<const T>& initial_hidden_state, const gsl::span<const T>& initial_cell_state,
                       const ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g,
                       const ActivationFuncs::Entry& activation_func_h, float clip,
-                      concurrency::ThreadPool& lstm_tp_,
                       concurrency::ThreadPool* mlas_tp_);
 
   void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
@@ -279,7 +278,6 @@ class UniDirectionalLstm {
   ActivationInfo<deepcpu::ActivationFuncPtr> activation_g_;
   ActivationInfo<deepcpu::LstmMergeGatesFuncPtr> activation_h_;
 
-  concurrency::ThreadPool& lstm_tp_;
   concurrency::ThreadPool* mlas_tp_;
 };
 
@@ -460,15 +458,15 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
                                      activation_funcs_.Entries()[0],
                                      activation_funcs_.Entries()[1],
                                      activation_funcs_.Entries()[2],
-                                     clip_, lstm_tp_, mlas_thread_pool);
+                                     clip_, mlas_thread_pool);
 
     detail::UniDirectionalLstm<T> bw(alloc, logger, seq_length, batch_size, input_size,
                                      hidden_size_, Direction::kReverse, input_forget_,
                                      bias_2, peephole_weights_2, initial_hidden_2, initial_cell_2,
                                      activation_funcs_.Entries()[3],
                                      activation_funcs_.Entries()[4],
                                      activation_funcs_.Entries()[5],
-                                     clip_, lstm_tp_, mlas_thread_pool);
+                                     clip_, mlas_thread_pool);
 
     fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1,
                output_1, hidden_output_1, last_cell_1);
@@ -481,7 +479,7 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
                                      activation_funcs_.Entries()[0],
                                      activation_funcs_.Entries()[1],
                                      activation_funcs_.Entries()[2],
-                                     clip_, lstm_tp_, mlas_thread_pool);
+                                     clip_, mlas_thread_pool);
 
     fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1,
                output_1, hidden_output_1, last_cell_1);
@@ -554,7 +552,6 @@ UniDirectionalLstm<T>::UniDirectionalLstm(AllocatorPtr allocator,
                                           const ActivationFuncs::Entry& activation_func_g,
                                           const ActivationFuncs::Entry& activation_func_h,
                                           const float clip,
-                                          concurrency::ThreadPool& lstm_tp,
                                           concurrency::ThreadPool* mlas_tp)
     : allocator_(allocator),
       logger_(logger),
@@ -567,7 +564,6 @@ UniDirectionalLstm<T>::UniDirectionalLstm(AllocatorPtr allocator,
       clip_(clip),
       use_bias_(!bias.empty()),
       use_peepholes_(!peephole_weights.empty()),
-      lstm_tp_(lstm_tp),
       mlas_tp_(mlas_tp) {
   activation_f_ = {deepcpu::ActivationFuncByName(activation_func_f.name),
                    activation_func_f.alpha,
@@ -885,7 +881,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
       }
     };
 
-    ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, lstm_tp_, logger_);
+    ExecuteLambdaInParallel("Processing batch", hidden_gemm_and_activations, batch_size_, fused_hidden_rows, nullptr, logger_);
 
   } else {
     span_T_const_iter previous_state_end = batched_hidden_state_one_step.cend();
@@ -1124,10 +1120,7 @@ void UniDirectionalLstm<T>::GateComputations(span_T_iter& out, span_T_iter& out_
 
 template <typename T>
 void UniDirectionalLstm<T>::SetNumThreads() {
-  int threads = std::thread::hardware_concurrency() - 1;
-
-  if (threads < 1)
-    threads = 1;
+  int threads = 1;
 
   hidden_num_threads_ = threads;
   batch_parallel_ = false;
7 changes: 0 additions & 7 deletions onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h
@@ -77,13 +77,6 @@ class DeepCpuLstmOp final : public OpKernel {
   bool input_forget_ = false;
 
   rnn::detail::ActivationFuncs activation_funcs_;
-
-  // Threadpool for operator. If concurrent Compute calls are possible, it will be shared
-  // across them. mutable due to this.
-  // The alternative would be to create a threadpool in each call to Compute but that would incur thread creation
-  // cost on every call.
-  mutable onnxruntime::concurrency::ThreadPool lstm_tp_{"DEEPCPU_LSTM",
-                                                        static_cast<int>(std::thread::hardware_concurrency())};
 };
 
 }  // namespace onnxruntime
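For context on the lines removed above: the pool member was marked mutable so that the const Compute override could still use it while sharing one pool across concurrent calls. A minimal standalone sketch of that pattern, with a hypothetical FakeThreadPool standing in for onnxruntime::concurrency::ThreadPool:

// Minimal sketch of the removed pattern (names are hypothetical): a
// kernel-owned pool, mutable because the kernel's Compute() is const.
#include <thread>

struct FakeThreadPool {
  explicit FakeThreadPool(int num_threads) : num_threads_(num_threads) {}
  int num_threads_;
};

class LstmKernelSketch {
 public:
  int Compute() const {
    // Allowed despite 'const' because pool_ is mutable; concurrent
    // Compute() calls would share this one pool.
    return pool_.num_threads_;
  }

 private:
  mutable FakeThreadPool pool_{static_cast<int>(std::thread::hardware_concurrency())};
};

int main() {
  LstmKernelSketch kernel;
  return kernel.Compute() > 0 ? 0 : 1;
}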
116 changes: 55 additions & 61 deletions onnxruntime/core/providers/cpu/rnn/rnn_helpers.h
@@ -212,78 +212,72 @@ T* SafeRawPointer(typename gsl::span<T> span, size_t offset, size_t size) {
 
 template <typename TLambda>
 void ExecuteLambdaInParallel(const std::string& name, TLambda lambda, int max, int step,
-                             onnxruntime::concurrency::ThreadPool& ttp,
+                             onnxruntime::concurrency::ThreadPool* ttp,
                              const ::onnxruntime::logging::Logger& logger) {
   // #define NOTHREADS to execute the lambdas directly and in order if you need to do that to debug
 
 #ifdef NOTHREADS
   ORT_UNUSED_PARAMETER(ttp);
   ORT_UNUSED_PARAMETER(logger);
 
   for (int i = 0; i < max; i += step) {
     (void)name;
     std::bind(lambda, i)();
   }
 #else
 
   ORT_UNUSED_PARAMETER(name);
   ORT_UNUSED_PARAMETER(logger);
 
-  // ORT_ENFORCE may and does throw at times from within the tasks that run
-  // on a thread-pool. Without propagating exceptions the process exits silently
-  // which will make diagnosing bugs more difficult.
-
-  // \! UGLY
-  // We have a problem here with the current thread-pool is that it takes std::function
-  // by value and copies it more than once (even though it is movable).
-  //
-  // To report status and exceptions properly it's better to use
-  // futures and promises but they are not copyable, so we can't come up with a functor
-  // with a promise member and we are downgrading to C++11 where we can't have captures that moved in.
-  //
-  // At the same time promises MUST live in the child thread so if we throw from the main thread
-  // we don't destroy any promises that are on the main thread stack which children threads may still be using.
-  //
-  // The only solution with the current Eigen that comes to mind is to have shared_ptr to with std::promise.
-  //
-  const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
-  std::vector<std::future<void> > futures;
-  futures.reserve(total_tasks);
-
-  for (int i = 0, t = 0; i < max; i += step, ++t) {
-    auto p_ptr = std::make_shared<std::promise<void> >();
-    futures.push_back(p_ptr->get_future());
-    ttp.Schedule([p_ptr, lambda, i]() {
-      try {
-        lambda(i);
-        p_ptr->set_value();
-      } catch (...) {
-        p_ptr->set_exception(std::current_exception());
-      }
-    });
-  }
-
-  // We'd like to wait until all of the tasks have finished
-  // even though one or more have already thrown. We will store
-  // the first exception and then will re-throw at the end.
-  std::exception_ptr pending_exception;
-  for (auto& fut : futures) {
-    try {
-      // get() will re-throw any exceptions
-      // the running task may throw
-      fut.get();
-    } catch (...) {
-      if (!pending_exception) {
-        pending_exception = std::current_exception();
-      }
-    }
-  }
-
-  if (pending_exception) {
-    std::rethrow_exception(pending_exception);
-  }
+  if (ttp == nullptr) {
+    for (int i = 0; i < max; i += step) {
+      std::bind(lambda, i)();
+    }
+  } else {
+    // ORT_ENFORCE may and does throw at times from within the tasks that run
+    // on a thread-pool. Without propagating exceptions the process exits silently,
+    // which would make diagnosing bugs more difficult.
+
+    // \! UGLY
+    // A problem with the current thread-pool is that it takes std::function
+    // by value and copies it more than once (even though it is movable).
+    //
+    // To report status and exceptions properly it's better to use
+    // futures and promises, but they are not copyable, so we can't build a functor
+    // with a promise member, and we are limited to C++11 where captures cannot be moved in.
+    //
+    // At the same time promises MUST live in the child thread so that, if we throw from the main thread,
+    // we don't destroy any promises on the main thread's stack which child threads may still be using.
+    //
+    // The only solution with the current Eigen that comes to mind is to use a shared_ptr to a std::promise.
+    //
+    const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
+    std::vector<std::future<void> > futures;
+    futures.reserve(total_tasks);
+
+    for (int i = 0, t = 0; i < max; i += step, ++t) {
+      auto p_ptr = std::make_shared<std::promise<void> >();
+      futures.push_back(p_ptr->get_future());
+      ttp->Schedule([p_ptr, lambda, i]() {
+        try {
+          lambda(i);
+          p_ptr->set_value();
+        } catch (...) {
+          p_ptr->set_exception(std::current_exception());
+        }
+      });
+    }
+
+    // We'd like to wait until all of the tasks have finished
+    // even though one or more have already thrown. We will store
+    // the first exception and then re-throw at the end.
+    std::exception_ptr pending_exception;
+    for (auto& fut : futures) {
+      try {
+        // get() will re-throw any exceptions
+        // the running task may throw
+        fut.get();
+      } catch (...) {
+        if (!pending_exception) {
+          pending_exception = std::current_exception();
+        }
+      }
+    }
+
+    if (pending_exception) {
+      std::rethrow_exception(pending_exception);
+    }
+  }
 
 #endif
 }
 
 void DumpMatrixImpl(const std::string& name, const float* src, int row, int col,
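The pooled branch that survives in rnn_helpers.h relies on the promise/future pattern described in its comments: each task owns a shared_ptr to its promise, and the caller drains every future before re-throwing the first stored exception. A self-contained sketch of the same pattern, with std::thread standing in for the Eigen-backed pool (a simplification, not the ORT scheduler):

// Standalone sketch of the exception-propagation pattern kept in
// rnn_helpers.h: each task sets its promise on success or stores the
// exception; the caller waits for every task, then re-throws the first
// stored exception. std::thread stands in for the Eigen thread pool.
#include <future>
#include <memory>
#include <stdexcept>
#include <thread>
#include <vector>

template <typename TLambda>
void RunAllPropagatingExceptions(TLambda lambda, int max, int step) {
  std::vector<std::future<void>> futures;
  std::vector<std::thread> workers;

  for (int i = 0; i < max; i += step) {
    // shared_ptr keeps the promise alive inside the worker even if the
    // caller unwinds first -- the constraint the original comment describes.
    auto p_ptr = std::make_shared<std::promise<void>>();
    futures.push_back(p_ptr->get_future());
    workers.emplace_back([p_ptr, lambda, i]() {
      try {
        lambda(i);
        p_ptr->set_value();
      } catch (...) {
        p_ptr->set_exception(std::current_exception());
      }
    });
  }

  std::exception_ptr pending_exception;
  for (auto& fut : futures) {
    try {
      fut.get();  // re-throws anything the task threw
    } catch (...) {
      if (!pending_exception) pending_exception = std::current_exception();
    }
  }
  for (auto& w : workers) w.join();

  if (pending_exception) std::rethrow_exception(pending_exception);
}

int main() {
  try {
    RunAllPropagatingExceptions(
        [](int i) { if (i == 4) throw std::runtime_error("task failed"); }, 8, 2);
  } catch (const std::runtime_error&) {
    return 0;  // the worker's exception surfaced on the calling thread
  }
  return 1;
}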
