ARIMA: pre-allocation of temporary memory to reduce latencies (#3895)
This PR can speed up the evaluation of the log-likelihood in ARIMA by 5x for non-seasonal datasets (the impact is smaller for seasonal datasets). It achieves this by pre-allocating all the temporary memory once, instead of at every iteration, and by providing all the pointers with very low overhead through a dedicated structure (see the usage sketch below). Additionally, I removed some unnecessary copies.

![arima_memory](https://user-images.githubusercontent.com/17441062/119530801-a44ff100-bd83-11eb-9278-3f9071521553.png)

Regarding the unnecessary synchronizations, I'll fix those later in a separate PR. Note that non-seasonal ARIMA performance is now even more limited by the Python-side solver bottleneck:

![optimizer_bottleneck](https://user-images.githubusercontent.com/17441062/119531952-b8e0b900-bd84-11eb-88cc-b58497b283fc.png)

One problem is that batched matrix operations are quite memory-hungry, so I've duplicated or refactored some bits to avoid allocating extra memory there, but that leads to some code duplication that I'm not entirely happy with. Both the ARIMA code and the batched matrix prims are due for some refactoring.
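For reference, the new calling pattern looks roughly like this (a sketch pieced together from the benchmark changes below; `handle`, `order`, `allocator`, `stream`, and the device arrays are assumed to come from the caller's context):

```cpp
// One-time setup: query the required size and make a single allocation.
size_t temp_buf_size =
  ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
char* temp_mem = (char*)allocator->allocate(temp_buf_size, stream);

// Per evaluation: building the pointer structure is just offset arithmetic,
// so it is cheap enough to run at every optimizer iteration.
ARIMAMemory<double> arima_mem(order, batch_size, n_obs, temp_mem);
batched_loglike(handle, arima_mem, d_y, batch_size, n_obs, order, d_params,
                loglike, d_vs);

// Teardown: ownership stays with the caller, who frees the buffer once.
allocator->deallocate(temp_mem, temp_buf_size, stream);
```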

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3895
Nyrio authored Jun 1, 2021
1 parent 95efa25 commit 94be76f
Showing 9 changed files with 759 additions and 306 deletions.
15 changes: 12 additions & 3 deletions cpp/bench/sg/arima_loglikelihood.cu
@@ -63,10 +63,13 @@ class ArimaLoglikelihood : public TsFixtureRandom<DataT> {

// Benchmark loop
this->loopOnState(state, [this]() {
+      ARIMAMemory<double> arima_mem(order, this->params.batch_size,
+                                    this->params.n_obs, temp_mem);
+
// Evaluate log-likelihood
-      batched_loglike(*this->handle, this->data.X, this->params.batch_size,
-                      this->params.n_obs, order, param, loglike, residual, true,
-                      false);
+      batched_loglike(*this->handle, arima_mem, this->data.X,
+                      this->params.batch_size, this->params.n_obs, order, param,
+                      loglike, residual, true, false);
});
}

@@ -86,6 +89,11 @@ class ArimaLoglikelihood : public TsFixtureRandom<DataT> {
this->params.batch_size * sizeof(DataT), stream);
residual = (DataT*)allocator->allocate(
this->params.batch_size * this->params.n_obs * sizeof(DataT), stream);

+    // Temporary memory
+    size_t temp_buf_size = ARIMAMemory<double>::compute_size(
+      order, this->params.batch_size, this->params.n_obs);
+    temp_mem = (char*)allocator->allocate(temp_buf_size, stream);
}

void deallocateBuffers(const ::benchmark::State& state) {
@@ -110,6 +118,7 @@ class ArimaLoglikelihood : public TsFixtureRandom<DataT> {
DataT* param;
DataT* loglike;
DataT* residual;
+  char* temp_mem;
};

std::vector<ArimaParams> getInputs() {
160 changes: 159 additions & 1 deletion cpp/include/cuml/tsa/arima_common.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -207,4 +207,162 @@ struct ARIMAParams {
}
};

/**
* Structure to manage ARIMA temporary memory allocations
 * @note The user is expected to pass a preallocated buffer to the constructor,
 * and ownership is not transferred to this struct! The buffer must remain
 * allocated for as long as the object lives, and be deallocated afterwards.
*/
template <typename T, int ALIGN = 256>
struct ARIMAMemory {
T *params_mu, *params_ar, *params_ma, *params_sar, *params_sma,
*params_sigma2, *Tparams_mu, *Tparams_ar, *Tparams_ma, *Tparams_sar,
*Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams, *Z_dense, *R_dense,
*T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense,
*ImT_inv_dense, *T_values, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense,
*vs, *y_diff, *loglike, *loglike_base, *loglike_pert, *x_pert, *F_buffer,
*sumLogF_buffer, *sigma2_buffer, *I_m_AxA_dense, *I_m_AxA_inv_dense,
*Ts_dense, *RQRs_dense, *Ps_dense;
T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches,
**P_batches, **alpha_batches, **ImT_batches, **ImT_inv_batches,
**v_tmp_batches, **m_tmp_batches, **K_batches, **TP_batches,
**I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches, **RQRs_batches,
**Ps_batches;
int *T_col_index, *T_row_index, *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P,
*I_m_AxA_info;

size_t size;

protected:
char* buf;

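// Two-pass layout: buf_offsets<false> only accumulates `size` (used by
// compute_size to report the required allocation), while buf_offsets<true>
// additionally points each member into the caller-provided buffer.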
template <bool assign, typename ValType>
inline void append_buffer(ValType*& ptr, size_t n_elem) {
if (assign) {
ptr = reinterpret_cast<ValType*>(buf + size);
}
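// Round each segment up to a multiple of ALIGN so that every pointer
// handed out is aligned relative to the start of the buffer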
size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
}

template <bool assign>
inline void buf_offsets(const ARIMAOrder& order, int batch_size, int n_obs,
char* in_buf = nullptr) {
buf = in_buf;
size = 0;

int r = order.r();            // dimension of the Kalman state
int rd = order.rd();          // state dimension including differencing terms
int N = order.complexity();   // number of free parameters per series
int n_diff = order.n_diff();  // total order of differencing (simple + seasonal)

append_buffer<assign>(params_mu, order.k * batch_size);
append_buffer<assign>(params_ar, order.p * batch_size);
append_buffer<assign>(params_ma, order.q * batch_size);
append_buffer<assign>(params_sar, order.P * batch_size);
append_buffer<assign>(params_sma, order.Q * batch_size);
append_buffer<assign>(params_sigma2, batch_size);

append_buffer<assign>(Tparams_mu, order.k * batch_size);
append_buffer<assign>(Tparams_ar, order.p * batch_size);
append_buffer<assign>(Tparams_ma, order.q * batch_size);
append_buffer<assign>(Tparams_sar, order.P * batch_size);
append_buffer<assign>(Tparams_sma, order.Q * batch_size);
append_buffer<assign>(Tparams_sigma2, batch_size);

append_buffer<assign>(d_params, N * batch_size);
append_buffer<assign>(d_Tparams, N * batch_size);
append_buffer<assign>(Z_dense, rd * batch_size);
append_buffer<assign>(Z_batches, batch_size);
append_buffer<assign>(R_dense, rd * batch_size);
append_buffer<assign>(R_batches, batch_size);
append_buffer<assign>(T_dense, rd * rd * batch_size);
append_buffer<assign>(T_batches, batch_size);
append_buffer<assign>(RQ_dense, rd * batch_size);
append_buffer<assign>(RQ_batches, batch_size);
append_buffer<assign>(RQR_dense, rd * rd * batch_size);
append_buffer<assign>(RQR_batches, batch_size);
append_buffer<assign>(P_dense, rd * rd * batch_size);
append_buffer<assign>(P_batches, batch_size);
append_buffer<assign>(alpha_dense, rd * batch_size);
append_buffer<assign>(alpha_batches, batch_size);
append_buffer<assign>(ImT_dense, r * r * batch_size);
append_buffer<assign>(ImT_batches, batch_size);
append_buffer<assign>(ImT_inv_dense, r * r * batch_size);
append_buffer<assign>(ImT_inv_batches, batch_size);
append_buffer<assign>(ImT_inv_P, r * batch_size);
append_buffer<assign>(ImT_inv_info, batch_size);
append_buffer<assign>(T_values, rd * rd * batch_size);
append_buffer<assign>(T_col_index, rd * rd);
append_buffer<assign>(T_row_index, rd + 1);
append_buffer<assign>(v_tmp_dense, rd * batch_size);
append_buffer<assign>(v_tmp_batches, batch_size);
append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);
append_buffer<assign>(m_tmp_batches, batch_size);
append_buffer<assign>(K_dense, rd * batch_size);
append_buffer<assign>(K_batches, batch_size);
append_buffer<assign>(TP_dense, rd * rd * batch_size);
append_buffer<assign>(TP_batches, batch_size);
append_buffer<assign>(F_buffer, n_obs * batch_size);
append_buffer<assign>(sumLogF_buffer, batch_size);
append_buffer<assign>(sigma2_buffer, batch_size);

append_buffer<assign>(vs, n_obs * batch_size);
append_buffer<assign>(y_diff, n_obs * batch_size);
append_buffer<assign>(loglike, batch_size);
append_buffer<assign>(loglike_base, batch_size);
append_buffer<assign>(loglike_pert, batch_size);
append_buffer<assign>(x_pert, N * batch_size);

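// Extra matrices that are only needed when the model includes differencing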
if (n_diff > 0) {
append_buffer<assign>(Ts_dense, r * r * batch_size);
append_buffer<assign>(Ts_batches, batch_size);
append_buffer<assign>(RQRs_dense, r * r * batch_size);
append_buffer<assign>(RQRs_batches, batch_size);
append_buffer<assign>(Ps_dense, r * r * batch_size);
append_buffer<assign>(Ps_batches, batch_size);
}

if (r <= 5) {
// Note: temp mem for the direct Lyapunov solver grows very quickly!
// This solver is used iff the condition above is satisfied
append_buffer<assign>(I_m_AxA_dense, r * r * r * r * batch_size);
append_buffer<assign>(I_m_AxA_batches, batch_size);
append_buffer<assign>(I_m_AxA_inv_dense, r * r * r * r * batch_size);
append_buffer<assign>(I_m_AxA_inv_batches, batch_size);
append_buffer<assign>(I_m_AxA_P, r * r * batch_size);
append_buffer<assign>(I_m_AxA_info, batch_size);
}
}

/** Protected constructor to compute the required buffer size */
ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs) {
buf_offsets<false>(order, batch_size, n_obs);
}

public:
/** Constructor to create pointers from buffer
* @param[in] order ARIMA order
* @param[in] batch_size Number of series in the batch
* @param[in] n_obs Length of the series
* @param[in] in_buf Pointer to the temporary memory buffer.
* Ownership is retained by the caller
*/
ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs,
char* in_buf) {
buf_offsets<true>(order, batch_size, n_obs, in_buf);
}

/** Static method to get the size of the required buffer allocation
* @param[in] order ARIMA order
* @param[in] batch_size Number of series in the batch
* @param[in] n_obs Length of the series
* @return Buffer size in bytes
*/
static size_t compute_size(const ARIMAOrder& order, int batch_size,
int n_obs) {
ARIMAMemory temp(order, batch_size, n_obs);
return temp.size;
}
};

} // namespace ML
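As an aside, the size-then-assign trick that `ARIMAMemory` uses above is a general arena-layout pattern. Below is a minimal standalone sketch of the same idea; all names (`Workspace`, `append`, `layout`) are illustrative only, not part of the cuML API:

```cpp
#include <cstddef>
#include <vector>

// Two-pass arena layout, mirroring ARIMAMemory: pass 1 measures, pass 2
// assigns pointers into a single caller-owned buffer.
template <int ALIGN = 256>
struct Workspace {
  float* a = nullptr;   // first segment: n floats
  double* b = nullptr;  // second segment: 2*n doubles
  size_t size = 0;      // total bytes consumed by the layout

  // Pass 1: dry run over the layout to learn the required buffer size.
  static size_t compute_size(int n) {
    Workspace w;
    w.layout<false>(nullptr, n);
    return w.size;
  }

  // Pass 2: same walk, but this time each member is pointed into `buf`.
  // Assumes `buf` itself is at least ALIGN-byte aligned.
  Workspace(char* buf, int n) { layout<true>(buf, n); }

 private:
  Workspace() = default;  // only used internally by compute_size

  template <bool assign, typename T>
  void append(T*& ptr, char* buf, size_t n_elem) {
    if (assign) ptr = reinterpret_cast<T*>(buf + size);
    // Round each segment up to a multiple of ALIGN to keep pointers aligned
    size += ((n_elem * sizeof(T) + ALIGN - 1) / ALIGN) * ALIGN;
  }

  template <bool assign>
  void layout(char* buf, int n) {
    size = 0;
    append<assign>(a, buf, n);
    append<assign>(b, buf, 2 * static_cast<size_t>(n));
  }
};

int main() {
  int n = 100;
  std::vector<char> buf(Workspace<>::compute_size(n));  // one allocation
  Workspace<> ws(buf.data(), n);  // cheap: offset arithmetic only
  (void)ws;
  return 0;
}
```

Compared with allocating each temporary separately, this caps allocator traffic at one allocation per fit and makes the total footprint queryable up front, which is exactly what the benchmark above exploits.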
52 changes: 31 additions & 21 deletions cpp/include/cuml/tsa/batched_arima.hpp
@@ -68,6 +68,7 @@ void batched_diff(raft::handle_t& handle, double* d_y_diff, const double* d_y,
* in a batched context.
*
* @param[in] handle cuML handle
+ * @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] batch_size Number of time series
@@ -91,13 +92,14 @@ void batched_diff(raft::handle_t& handle, double* d_y_diff, const double* d_y,
* @param[out] d_lower Lower limit of the prediction interval
* @param[out] d_upper Upper limit of the prediction interval
*/
-void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size,
-                     int n_obs, const ARIMAOrder& order, const double* d_params,
-                     double* loglike, double* d_vs, bool trans = true,
-                     bool host_loglike = true, LoglikeMethod method = MLE,
-                     int truncate = 0, int fc_steps = 0, double* d_fc = nullptr,
-                     double level = 0, double* d_lower = nullptr,
-                     double* d_upper = nullptr);
+void batched_loglike(raft::handle_t& handle,
+                     const ARIMAMemory<double>& arima_mem, const double* d_y,
+                     int batch_size, int n_obs, const ARIMAOrder& order,
+                     const double* d_params, double* loglike, double* d_vs,
+                     bool trans = true, bool host_loglike = true,
+                     LoglikeMethod method = MLE, int truncate = 0,
+                     int fc_steps = 0, double* d_fc = nullptr, double level = 0,
+                     double* d_lower = nullptr, double* d_upper = nullptr);

/**
* Compute the loglikelihood of the given parameter on the given time series
@@ -107,6 +109,7 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size,
* to avoid useless packing / unpacking
*
* @param[in] handle cuML handle
+ * @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] batch_size Number of time series
@@ -129,8 +132,9 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size,
* @param[out] d_lower Lower limit of the prediction interval
* @param[out] d_upper Upper limit of the prediction interval
*/
-void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size,
-                     int n_obs, const ARIMAOrder& order,
+void batched_loglike(raft::handle_t& handle,
+                     const ARIMAMemory<double>& arima_mem, const double* d_y,
+                     int batch_size, int n_obs, const ARIMAOrder& order,
const ARIMAParams<double>& params, double* loglike,
double* d_vs, bool trans = true, bool host_loglike = true,
LoglikeMethod method = MLE, int truncate = 0,
@@ -141,6 +145,7 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size,
* Compute the gradient of the log-likelihood
*
* @param[in] handle cuML handle
+ * @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] batch_size Number of time series
@@ -154,17 +159,19 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size,
* @param[in] truncate For CSS, start the sum-of-squares after a given
* number of observations
*/
-void batched_loglike_grad(raft::handle_t& handle, const double* d_y,
-                          int batch_size, int n_obs, const ARIMAOrder& order,
-                          const double* d_x, double* d_grad, double h,
-                          bool trans = true, LoglikeMethod method = MLE,
-                          int truncate = 0);
+void batched_loglike_grad(raft::handle_t& handle,
+                          const ARIMAMemory<double>& arima_mem,
+                          const double* d_y, int batch_size, int n_obs,
+                          const ARIMAOrder& order, const double* d_x,
+                          double* d_grad, double h, bool trans = true,
+                          LoglikeMethod method = MLE, int truncate = 0);

/**
* Batched in-sample and out-of-sample prediction of a time-series given all
* the model parameters
*
* @param[in] handle cuML handle
+ * @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Batched Time series to predict.
* Shape: (num_samples, batch size) (device)
* @param[in] batch_size Total number of batched time series
@@ -181,16 +188,17 @@ void batched_loglike_grad(raft::handle_t& handle, const double* d_y,
* @param[out] d_lower Lower limit of the prediction interval
* @param[out] d_upper Upper limit of the prediction interval
*/
-void predict(raft::handle_t& handle, const double* d_y, int batch_size,
-             int n_obs, int start, int end, const ARIMAOrder& order,
-             const ARIMAParams<double>& params, double* d_y_p,
-             bool pre_diff = true, double level = 0, double* d_lower = nullptr,
-             double* d_upper = nullptr);
+void predict(raft::handle_t& handle, const ARIMAMemory<double>& arima_mem,
+             const double* d_y, int batch_size, int n_obs, int start, int end,
+             const ARIMAOrder& order, const ARIMAParams<double>& params,
+             double* d_y_p, bool pre_diff = true, double level = 0,
+             double* d_lower = nullptr, double* d_upper = nullptr);

/**
* Compute an information criterion (AIC, AICc, BIC)
*
* @param[in] handle cuML handle
+ * @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] batch_size Total number of batched time series
@@ -203,8 +211,10 @@ void predict(raft::handle_t& handle, const double* d_y, int batch_size,
* @param[in] ic_type Type of information criterion wanted.
* 0: AIC, 1: AICc, 2: BIC
*/
-void information_criterion(raft::handle_t& handle, const double* d_y,
-                           int batch_size, int n_obs, const ARIMAOrder& order,
+void information_criterion(raft::handle_t& handle,
+                           const ARIMAMemory<double>& arima_mem,
+                           const double* d_y, int batch_size, int n_obs,
+                           const ARIMAOrder& order,
const ARIMAParams<double>& params, double* ic,
int ic_type);

21 changes: 12 additions & 9 deletions cpp/include/cuml/tsa/batched_kalman.hpp
@@ -29,6 +29,7 @@ namespace ML {
* provide the resulting prediction as well as loglikelihood fit.
*
* @param[in] handle cuML handle
+ * @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_ys_b Batched time series
* Shape (nobs, batch_size) (col-major, device)
* @param[in] nobs Number of samples per time series
@@ -46,20 +47,20 @@ namespace ML {
* @param[out] d_lower Lower limit of the prediction interval
* @param[out] d_upper Upper limit of the prediction interval
*/
-void batched_kalman_filter(raft::handle_t& handle, const double* d_ys_b,
-                           int nobs, const ARIMAParams<double>& params,
-                           const ARIMAOrder& order, int batch_size,
-                           double* d_loglike, double* d_vs, int fc_steps = 0,
-                           double* d_fc = nullptr, double level = 0,
-                           double* d_lower = nullptr,
-                           double* d_upper = nullptr);
+void batched_kalman_filter(
+  raft::handle_t& handle, const ARIMAMemory<double>& arima_mem,
+  const double* d_ys_b, int nobs, const ARIMAParams<double>& params,
+  const ARIMAOrder& order, int batch_size, double* d_loglike, double* d_vs,
+  int fc_steps = 0, double* d_fc = nullptr, double level = 0,
+  double* d_lower = nullptr, double* d_upper = nullptr);

/**
* Convenience function for batched "jones transform" used in ARIMA to ensure
* certain properties of the AR and MA parameters (takes host array and
* returns host array)
*
* @param[in] handle cuML handle
+ * @param[in] arima_mem Pre-allocated temporary memory
* @param[in] order ARIMA hyper-parameters
* @param[in] batch_size Number of time series analyzed.
* @param[in] isInv Do the inverse transform?
@@ -68,7 +69,9 @@ void batched_kalman_filter(raft::handle_t& handle, const double* d_ys_b,
* (expects pre-allocated array of size
* (p+q)*batch_size) (host)
*/
-void batched_jones_transform(raft::handle_t& handle, const ARIMAOrder& order,
-                             int batch_size, bool isInv, const double* h_params,
+void batched_jones_transform(raft::handle_t& handle,
+                             const ARIMAMemory<double>& arima_mem,
+                             const ARIMAOrder& order, int batch_size,
+                             bool isInv, const double* h_params,
double* h_Tparams);
} // namespace ML