Skip to content

Commit

Permalink
Change calculation of ARIMA confidence intervals (#4248)
Browse files Browse the repository at this point in the history
The formula that we have been using since I added support for confidence intervals in ARIMA is slightly different than the one used in statsmodels. The difference is in particular quite pronounced when datasets have missing observations, which pushed me to raise tolerance for the intervals unit tests when I added test cases in the recent PR #4058.

In this PR, I change our calculation to match statsmodels, and decrease the corresponding test tolerance, as we now have a strict match with statsmodels.

Previous formula:
```python
lower_t = fc_t - sqrt(2) * erfinv(level) * sqrt(F_t * mean(v_i**2 / F_i))
upper_t = fc_t + sqrt(2) * erfinv(level) * sqrt(F_t * mean(v_i**2 / F_i))
```
New formula:
```python
lower_t = fc_t - sqrt(2) * erfinv(level) * sqrt(F_t)
upper_t = fc_t + sqrt(2) * erfinv(level) * sqrt(F_t)
```

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #4248
  • Loading branch information
Nyrio authored Oct 11, 2021
1 parent 5408c6b commit 0c13f44
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 45 deletions.
5 changes: 2 additions & 3 deletions cpp/include/cuml/tsa/arima_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ struct ARIMAMemory {
*Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams,
*Z_dense, *R_dense, *T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense,
*ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *loglike,
*loglike_base, *loglike_pert, *x_pert, *sigma2_buffer, *I_m_AxA_dense, *I_m_AxA_inv_dense,
*Ts_dense, *RQRs_dense, *Ps_dense;
*loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense,
*RQRs_dense, *Ps_dense;
T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches, **P_batches,
**alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches,
**K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches,
Expand Down Expand Up @@ -279,7 +279,6 @@ struct ARIMAMemory {
append_buffer<assign>(K_batches, batch_size);
append_buffer<assign>(TP_dense, rd * rd * batch_size);
append_buffer<assign>(TP_batches, batch_size);
append_buffer<assign>(sigma2_buffer, batch_size);

append_buffer<assign>(pred, n_obs * batch_size);
append_buffer<assign>(y_diff, n_obs * batch_size);
Expand Down
53 changes: 13 additions & 40 deletions cpp/src/arima/batched_kalman.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ DI void MM_l(const double* A, const double* B, double* out)
* @param[in] batch_size Batch size
* @param[out] d_pred Predictions (nobs)
* @param[out] d_loglike Log-likelihood (1)
* @param[out] d_ll_sigma2 Sigma^2 term in the log-likelihood (1)
* @param[in] n_diff d + s*D
* @param[in] fc_steps Number of steps to forecast
* @param[out] d_fc Array to store the forecast
Expand All @@ -116,7 +115,6 @@ __global__ void batched_kalman_loop_kernel(const double* ys,
int batch_size,
double* d_pred,
double* d_loglike,
double* d_ll_sigma2,
int n_diff,
int fc_steps = 0,
double* d_fc = nullptr,
Expand Down Expand Up @@ -257,7 +255,6 @@ __global__ void batched_kalman_loop_kernel(const double* ys,
{
double n_obs_ll_f = static_cast<double>(n_obs_ll);
b_ll_s2 /= n_obs_ll_f;
if (conf_int) d_ll_sigma2[bid] = b_ll_s2;
d_loglike[bid] = -.5 * (b_sum_logFs + n_obs_ll_f * (b_ll_s2 + log(2 * M_PI)));
}

Expand Down Expand Up @@ -342,7 +339,6 @@ union KalmanLoopSharedMemory {
* @param[in] rd State vector dimension
* @param[out] d_pred Predictions (nobs)
* @param[out] d_loglike Log-likelihood (1)
* @param[out] d_ll_sigma2 Sigma^2 term in the log-likelihood (1)
* @param[in] n_diff d + s*D
* @param[in] fc_steps Number of steps to forecast
* @param[out] d_fc Array to store the forecast
Expand All @@ -365,7 +361,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
int rd,
double* d_pred,
double* d_loglike,
double* d_ll_sigma2,
int n_diff,
int fc_steps,
double* d_fc,
Expand Down Expand Up @@ -603,7 +598,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
if (threadIdx.x == 0) {
double n_obs_ll_f = static_cast<double>(n_obs_ll);
ll_s2 /= n_obs_ll_f;
if (conf_int) d_ll_sigma2[bid] = ll_s2;
d_loglike[bid] = -.5 * (sum_logFs + n_obs_ll_f * (ll_s2 + log(2 * M_PI)));
}
}
Expand All @@ -625,7 +619,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
* @param[in] rd Dimension of the state vector
* @param[out] d_pred Predictions (nobs)
* @param[out] d_loglike Log-likelihood (1)
* @param[out] d_ll_sigma2 Sigma^2 term in the log-likelihood (1)
* @param[in] n_diff d + s*D
* @param[in] fc_steps Number of steps to forecast
* @param[out] d_fc Array to store the forecast
Expand All @@ -646,7 +639,6 @@ void _batched_kalman_device_loop_large(const ARIMAMemory<double>& arima_mem,
int rd,
double* d_pred,
double* d_loglike,
double* d_ll_sigma2,
int n_diff,
int fc_steps = 0,
double* d_fc = nullptr,
Expand Down Expand Up @@ -690,7 +682,6 @@ void _batched_kalman_device_loop_large(const ARIMAMemory<double>& arima_mem,
rd,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -713,7 +704,6 @@ void batched_kalman_loop(raft::handle_t& handle,
const ARIMAOrder& order,
double* d_pred,
double* d_loglike,
double* d_ll_sigma2,
int fc_steps = 0,
double* d_fc = nullptr,
bool conf_int = false,
Expand Down Expand Up @@ -741,7 +731,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -762,7 +751,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -783,7 +771,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -804,7 +791,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -825,7 +811,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -846,7 +831,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -867,7 +851,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -888,7 +871,6 @@ void batched_kalman_loop(raft::handle_t& handle,
batch_size,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand Down Expand Up @@ -917,7 +899,6 @@ void batched_kalman_loop(raft::handle_t& handle,
rd,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -939,7 +920,6 @@ void batched_kalman_loop(raft::handle_t& handle,
rd,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -963,7 +943,6 @@ void batched_kalman_loop(raft::handle_t& handle,
rd,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -985,7 +964,6 @@ void batched_kalman_loop(raft::handle_t& handle,
rd,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -1008,7 +986,6 @@ void batched_kalman_loop(raft::handle_t& handle,
rd,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -1030,7 +1007,6 @@ void batched_kalman_loop(raft::handle_t& handle,
rd,
d_pred,
d_loglike,
d_ll_sigma2,
n_diff,
fc_steps,
d_fc,
Expand All @@ -1046,25 +1022,21 @@ void batched_kalman_loop(raft::handle_t& handle,
* @note: One block per batch member, one thread per forecast time step
*
* @param[in] d_fc Mean forecasts
* @param[in] d_sigma2 sum(v_t * v_t / F_t) / n_obs_diff
* @param[inout] d_lower Input: F_{n+t}
* Output: lower bound of the confidence intervals
* @param[out] d_upper Upper bound of the confidence intervals
* @param[in] fc_steps Number of forecast steps
* @param[in] n_elem Total number of elements (fc_steps * batch_size)
* @param[in] multiplier Coefficient associated with the confidence level
*/
__global__ void confidence_intervals(const double* d_fc,
const double* d_sigma2,
double* d_lower,
double* d_upper,
int fc_steps,
double multiplier)
__global__ void confidence_intervals(
const double* d_fc, double* d_lower, double* d_upper, int n_elem, double multiplier)
{
int idx = blockIdx.x * fc_steps + threadIdx.x;
double fc = d_fc[idx];
double margin = multiplier * sqrt(d_lower[idx] * d_sigma2[blockIdx.x]);
d_lower[idx] = fc - margin;
d_upper[idx] = fc + margin;
for (int idx = threadIdx.x; idx < n_elem; idx += blockDim.x * gridDim.x) {
double fc = d_fc[idx];
double margin = multiplier * sqrt(d_lower[idx]);
d_lower[idx] = fc - margin;
d_upper[idx] = fc + margin;
}
}

void _lyapunov_wrapper(raft::handle_t& handle,
Expand Down Expand Up @@ -1287,15 +1259,16 @@ void _batched_kalman_filter(raft::handle_t& handle,
order,
d_pred,
d_loglike,
arima_mem.sigma2_buffer,
fc_steps,
d_fc,
level > 0,
d_lower);

if (level > 0) {
confidence_intervals<<<batch_size, fc_steps, 0, stream>>>(
d_fc, arima_mem.sigma2_buffer, d_lower, d_upper, fc_steps, sqrt(2.0) * erfinv(level));
constexpr int TPB_conf = 256;
int n_blocks = raft::ceildiv<int>(fc_steps * batch_size, TPB_conf);
confidence_intervals<<<n_blocks, TPB_conf, 0, stream>>>(
d_fc, d_lower, d_upper, fc_steps * batch_size, sqrt(2.0) * erfinv(level));
CUDA_CHECK(cudaPeekAtLastError());
}
}
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/test/test_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,9 @@ def _predict_common(key, data, dtype, start, end, num_steps=None, level=None,
np.testing.assert_allclose(cuml_pred, ref_preds, rtol=0.001, atol=0.01)
if level is not None:
np.testing.assert_allclose(
cuml_lower, ref_lower, rtol=0.03, atol=0.01)
cuml_lower, ref_lower, rtol=0.005, atol=0.01)
np.testing.assert_allclose(
cuml_upper, ref_upper, rtol=0.03, atol=0.01)
cuml_upper, ref_upper, rtol=0.005, atol=0.01)


@pytest.mark.parametrize('key, data', test_data)
Expand Down

0 comments on commit 0c13f44

Please sign in to comment.