Skip to content

Commit

Permalink
Add support for exogenous variables to ARIMA (rapidsai#4221)
Browse files Browse the repository at this point in the history
Closes rapidsai#3846 

Adds support for exogenous variables to ARIMA.
All series in the batch must have the same number of exogenous variables, and exogenous variables are not shared across the batch (`exog` therefore has `n_exog * batch_size` columns).

Example:
```python
model = ARIMA(endog=df_endog, exog=df_exog_past, order=(1,0,1),
              seasonal_order=(1,1,1,12), fit_intercept=True,
              simple_differencing=False)
model.fit()
fc, lower, upper = model.forecast(40, exog=df_exog_future, level=0.95)
```

![2021-09-22_exog_fc](https://user-images.githubusercontent.com/17441062/134339807-f815a7a3-98dc-49e5-8599-9607e660597a.png)

Authors:
  - Louis Sugy (https://github.com/Nyrio)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4221
  • Loading branch information
Nyrio authored Nov 18, 2021
1 parent 1f029bb commit 9ef49be
Show file tree
Hide file tree
Showing 20 changed files with 1,878 additions and 297 deletions.
11 changes: 6 additions & 5 deletions cpp/bench/sg/arima_loglikelihood.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ArimaLoglikelihood : public TsFixtureRandom<DataT> {
batched_loglike(*this->handle,
arima_mem,
this->data.X.data(),
nullptr,
this->params.batch_size,
this->params.n_obs,
order,
Expand Down Expand Up @@ -122,11 +123,11 @@ std::vector<ArimaParams> getInputs()
struct std::vector<ArimaParams> out;
ArimaParams p;
p.data.seed = 12345ULL;
std::vector<ARIMAOrder> list_order = {{1, 1, 1, 0, 0, 0, 0, 0},
{1, 1, 1, 1, 1, 1, 4, 0},
{1, 1, 1, 1, 1, 1, 12, 0},
{1, 1, 1, 1, 1, 1, 24, 0},
{1, 1, 1, 1, 1, 1, 52, 0}};
std::vector<ARIMAOrder> list_order = {{1, 1, 1, 0, 0, 0, 0, 0, 0},
{1, 1, 1, 1, 1, 1, 4, 0, 0},
{1, 1, 1, 1, 1, 1, 12, 0, 0},
{1, 1, 1, 1, 1, 1, 24, 0, 0},
{1, 1, 1, 1, 1, 1, 52, 0, 0}};
std::vector<int> list_batch_size = {10, 100, 1000, 10000};
std::vector<int> list_n_obs = {200, 500, 1000};
for (auto& order : list_order) {
Expand Down
35 changes: 26 additions & 9 deletions cpp/include/cuml/tsa/arima_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,16 @@ struct ARIMAOrder {
int P; // Seasonal order
int D;
int Q;
int s; // Seasonal period
int k; // Fit intercept?
int s; // Seasonal period
int k; // Fit intercept?
int n_exog; // Number of exogenous regressors

inline int n_diff() const { return d + s * D; }
inline int n_phi() const { return p + s * P; }
inline int n_theta() const { return q + s * Q; }
inline int r() const { return std::max(n_phi(), n_theta() + 1); }
inline int rd() const { return n_diff() + r(); }
inline int complexity() const { return p + P + q + Q + k + 1; }
inline int complexity() const { return p + P + q + Q + k + n_exog + 1; }
inline bool need_diff() const { return static_cast<bool>(d + D); }
};

Expand All @@ -58,6 +59,7 @@ struct ARIMAOrder {
template <typename DataT>
struct ARIMAParams {
DataT* mu = nullptr;
DataT* beta = nullptr;
DataT* ar = nullptr;
DataT* ma = nullptr;
DataT* sar = nullptr;
Expand All @@ -77,6 +79,8 @@ struct ARIMAParams {
{
rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
if (order.k && !tr) mu = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream);
if (order.n_exog && !tr)
beta = (DataT*)rmm_alloc->allocate(order.n_exog * batch_size * sizeof(DataT), stream);
if (order.p) ar = (DataT*)rmm_alloc->allocate(order.p * batch_size * sizeof(DataT), stream);
if (order.q) ma = (DataT*)rmm_alloc->allocate(order.q * batch_size * sizeof(DataT), stream);
if (order.P) sar = (DataT*)rmm_alloc->allocate(order.P * batch_size * sizeof(DataT), stream);
Expand All @@ -97,6 +101,8 @@ struct ARIMAParams {
{
rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
if (order.k && !tr) rmm_alloc->deallocate(mu, batch_size * sizeof(DataT), stream);
if (order.n_exog && !tr)
rmm_alloc->deallocate(beta, order.n_exog * batch_size * sizeof(DataT), stream);
if (order.p) rmm_alloc->deallocate(ar, order.p * batch_size * sizeof(DataT), stream);
if (order.q) rmm_alloc->deallocate(ma, order.q * batch_size * sizeof(DataT), stream);
if (order.P) rmm_alloc->deallocate(sar, order.P * batch_size * sizeof(DataT), stream);
Expand All @@ -118,14 +124,19 @@ struct ARIMAParams {
int N = order.complexity();
auto counting = thrust::make_counting_iterator(0);
// The device lambda can't capture structure members...
const DataT *_mu = mu, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma, *_sigma2 = sigma2;
const DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
*_sigma2 = sigma2;
thrust::for_each(
thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
DataT* param = param_vec + bid * N;
if (order.k) {
*param = _mu[bid];
param++;
}
for (int i = 0; i < order.n_exog; i++) {
param[i] = _beta[order.n_exog * bid + i];
}
param += order.n_exog;
for (int ip = 0; ip < order.p; ip++) {
param[ip] = _ar[order.p * bid + ip];
}
Expand Down Expand Up @@ -160,14 +171,19 @@ struct ARIMAParams {
int N = order.complexity();
auto counting = thrust::make_counting_iterator(0);
// The device lambda can't capture structure members...
DataT *_mu = mu, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma, *_sigma2 = sigma2;
DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
*_sigma2 = sigma2;
thrust::for_each(
thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
const DataT* param = param_vec + bid * N;
if (order.k) {
_mu[bid] = *param;
param++;
}
for (int i = 0; i < order.n_exog; i++) {
_beta[order.n_exog * bid + i] = param[i];
}
param += order.n_exog;
for (int ip = 0; ip < order.p; ip++) {
_ar[order.p * bid + ip] = param[ip];
}
Expand Down Expand Up @@ -197,11 +213,11 @@ struct ARIMAParams {
*/
template <typename T, int ALIGN = 256>
struct ARIMAMemory {
T *params_mu, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2, *Tparams_mu,
T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2,
*Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams,
*Z_dense, *R_dense, *T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense,
*ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *loglike,
*loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense,
*ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *exog_diff,
*loglike, *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense,
*RQRs_dense, *Ps_dense;
T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches, **P_batches,
**alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches,
Expand Down Expand Up @@ -236,13 +252,13 @@ struct ARIMAMemory {
int n_diff = order.n_diff();

append_buffer<assign>(params_mu, order.k * batch_size);
append_buffer<assign>(params_beta, order.n_exog * batch_size);
append_buffer<assign>(params_ar, order.p * batch_size);
append_buffer<assign>(params_ma, order.q * batch_size);
append_buffer<assign>(params_sar, order.P * batch_size);
append_buffer<assign>(params_sma, order.Q * batch_size);
append_buffer<assign>(params_sigma2, batch_size);

append_buffer<assign>(Tparams_mu, order.k * batch_size);
append_buffer<assign>(Tparams_ar, order.p * batch_size);
append_buffer<assign>(Tparams_ma, order.q * batch_size);
append_buffer<assign>(Tparams_sar, order.P * batch_size);
Expand Down Expand Up @@ -282,6 +298,7 @@ struct ARIMAMemory {

append_buffer<assign>(pred, n_obs * batch_size);
append_buffer<assign>(y_diff, n_obs * batch_size);
append_buffer<assign>(exog_diff, n_obs * order.n_exog * batch_size);
append_buffer<assign>(loglike, batch_size);
append_buffer<assign>(loglike_base, batch_size);
append_buffer<assign>(loglike_pert, batch_size);
Expand Down
55 changes: 34 additions & 21 deletions cpp/include/cuml/tsa/batched_arima.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ void batched_diff(raft::handle_t& handle,
* @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] d_exog Exogenous variables: shape = (n_obs, n_exog * batch_size) and
* expects column major data layout. (device)
* @param[in] batch_size Number of time series
* @param[in] n_obs Number of observations in a time series
* @param[in] order ARIMA hyper-parameters
Expand All @@ -101,16 +103,11 @@ void batched_diff(raft::handle_t& handle,
* @param[in] method Whether to use sum-of-squares or Kalman filter
* @param[in] truncate For CSS, start the sum-of-squares after a given
* number of observations
* @param[in] fc_steps Number of steps to forecast
* @param[in] d_fc Array to store the forecast
* @param[in] level Confidence level for prediction intervals. 0 to
* skip the computation. Else 0 < level < 1
* @param[out] d_lower Lower limit of the prediction interval
* @param[out] d_upper Upper limit of the prediction interval
*/
void batched_loglike(raft::handle_t& handle,
const ARIMAMemory<double>& arima_mem,
const double* d_y,
const double* d_exog,
int batch_size,
int n_obs,
const ARIMAOrder& order,
Expand All @@ -119,12 +116,7 @@ void batched_loglike(raft::handle_t& handle,
bool trans = true,
bool host_loglike = true,
LoglikeMethod method = MLE,
int truncate = 0,
int fc_steps = 0,
double* d_fc = nullptr,
double level = 0,
double* d_lower = nullptr,
double* d_upper = nullptr);
int truncate = 0);

/**
* Compute the loglikelihood of the given parameter on the given time series
Expand All @@ -137,6 +129,8 @@ void batched_loglike(raft::handle_t& handle,
* @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] d_exog Exogenous variables: shape = (n_obs, n_exog * batch_size) and
* expects column major data layout. (device)
* @param[in] batch_size Number of time series
* @param[in] n_obs Number of observations in a time series
* @param[in] order ARIMA hyper-parameters
Expand All @@ -149,6 +143,8 @@ void batched_loglike(raft::handle_t& handle,
* number of observations
* @param[in] fc_steps Number of steps to forecast
* @param[in] d_fc Array to store the forecast
* @param[in] d_exog_fut Future values of exogenous variables
* Shape (fc_steps, n_exog * batch_size) (col-major, device)
* @param[in] level Confidence level for prediction intervals. 0 to
* skip the computation. Else 0 < level < 1
* @param[out] d_lower Lower limit of the prediction interval
Expand All @@ -157,20 +153,22 @@ void batched_loglike(raft::handle_t& handle,
void batched_loglike(raft::handle_t& handle,
const ARIMAMemory<double>& arima_mem,
const double* d_y,
const double* d_exog,
int batch_size,
int n_obs,
const ARIMAOrder& order,
const ARIMAParams<double>& params,
double* loglike,
bool trans = true,
bool host_loglike = true,
LoglikeMethod method = MLE,
int truncate = 0,
int fc_steps = 0,
double* d_fc = nullptr,
double level = 0,
double* d_lower = nullptr,
double* d_upper = nullptr);
bool trans = true,
bool host_loglike = true,
LoglikeMethod method = MLE,
int truncate = 0,
int fc_steps = 0,
double* d_fc = nullptr,
const double* d_exog_fut = nullptr,
double level = 0,
double* d_lower = nullptr,
double* d_upper = nullptr);

/**
* Compute the gradient of the log-likelihood
Expand All @@ -179,6 +177,8 @@ void batched_loglike(raft::handle_t& handle,
* @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] d_exog Exogenous variables: shape = (n_obs, n_exog * batch_size) and
* expects column major data layout. (device)
* @param[in] batch_size Number of time series
* @param[in] n_obs Number of observations in a time series
* @param[in] order ARIMA hyper-parameters
Expand All @@ -193,6 +193,7 @@ void batched_loglike(raft::handle_t& handle,
void batched_loglike_grad(raft::handle_t& handle,
const ARIMAMemory<double>& arima_mem,
const double* d_y,
const double* d_exog,
int batch_size,
int n_obs,
const ARIMAOrder& order,
Expand All @@ -211,6 +212,10 @@ void batched_loglike_grad(raft::handle_t& handle,
* @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Batched Time series to predict.
* Shape: (num_samples, batch size) (device)
* @param[in] d_exog Exogenous variables.
* Shape = (n_obs, n_exog * batch_size) (device)
* @param[in] d_exog_fut Future values of exogenous variables
* Shape: (end - n_obs, batch_size) (device)
* @param[in] batch_size Total number of batched time series
* @param[in] n_obs Number of samples per time series
* (all series must be identical)
Expand All @@ -228,6 +233,8 @@ void batched_loglike_grad(raft::handle_t& handle,
void predict(raft::handle_t& handle,
const ARIMAMemory<double>& arima_mem,
const double* d_y,
const double* d_exog,
const double* d_exog_fut,
int batch_size,
int n_obs,
int start,
Expand All @@ -247,6 +254,8 @@ void predict(raft::handle_t& handle,
* @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] d_exog Exogenous variables.
* Shape = (n_obs, n_exog * batch_size) (device)
* @param[in] batch_size Total number of batched time series
* @param[in] n_obs Number of samples per time series
* (all series must be identical)
Expand All @@ -260,6 +269,7 @@ void predict(raft::handle_t& handle,
void information_criterion(raft::handle_t& handle,
const ARIMAMemory<double>& arima_mem,
const double* d_y,
const double* d_exog,
int batch_size,
int n_obs,
const ARIMAOrder& order,
Expand All @@ -274,6 +284,8 @@ void information_criterion(raft::handle_t& handle,
* @param[in] params ARIMA parameters (device)
* @param[in] d_y Series to fit: shape = (n_obs, batch_size) and
* expects column major data layout. (device)
* @param[in] d_exog Exogenous variables.
* Shape = (n_obs, n_exog * batch_size) (device)
* @param[in] batch_size Total number of batched time series
* @param[in] n_obs Number of samples per time series
* (all series must be identical)
Expand All @@ -283,6 +295,7 @@ void information_criterion(raft::handle_t& handle,
void estimate_x0(raft::handle_t& handle,
ARIMAParams<double>& params,
const double* d_y,
const double* d_exog,
int batch_size,
int n_obs,
const ARIMAOrder& order,
Expand Down
20 changes: 13 additions & 7 deletions cpp/include/cuml/tsa/batched_kalman.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ namespace ML {
*
* @param[in] handle cuML handle
* @param[in] arima_mem Pre-allocated temporary memory
* @param[in] d_ys_b Batched time series
* @param[in] d_ys Batched time series
* Shape (nobs, batch_size) (col-major, device)
* @param[in] d_exog Batched exogenous variables
* Shape (nobs, n_exog * batch_size) (col-major, device)
* @param[in] nobs Number of samples per time series
* @param[in] params ARIMA parameters (device)
* @param[in] order ARIMA hyper-parameters
Expand All @@ -41,25 +43,29 @@ namespace ML {
* shape=(nobs-d-s*D, batch_size) (device)
* @param[in] fc_steps Number of steps to forecast
* @param[in] d_fc Array to store the forecast
* @param[in] d_exog_fut Future values of exogenous variables
* Shape (fc_steps, n_exog * batch_size) (col-major, device)
* @param[in] level Confidence level for prediction intervals. 0 to
* skip the computation. Else 0 < level < 1
* @param[out] d_lower Lower limit of the prediction interval
* @param[out] d_upper Upper limit of the prediction interval
*/
void batched_kalman_filter(raft::handle_t& handle,
const ARIMAMemory<double>& arima_mem,
const double* d_ys_b,
const double* d_ys,
const double* d_exog,
int nobs,
const ARIMAParams<double>& params,
const ARIMAOrder& order,
int batch_size,
double* d_loglike,
double* d_pred,
int fc_steps = 0,
double* d_fc = nullptr,
double level = 0,
double* d_lower = nullptr,
double* d_upper = nullptr);
int fc_steps = 0,
double* d_fc = nullptr,
const double* d_exog_fut = nullptr,
double level = 0,
double* d_lower = nullptr,
double* d_upper = nullptr);

/**
* Convenience function for batched "jones transform" used in ARIMA to ensure
Expand Down
Loading

0 comments on commit 9ef49be

Please sign in to comment.