From 0c13f4488ec98e2d155b657c1442d4e8b8f57e97 Mon Sep 17 00:00:00 2001
From: Louis Sugy
Date: Mon, 11 Oct 2021 20:25:36 +0200
Subject: [PATCH] Change calculation of ARIMA confidence intervals (#4248)

The formula that we have been using since I added support for confidence
intervals in ARIMA is slightly different from the one used in statsmodels.
The difference is particularly pronounced when datasets have missing
observations, which pushed me to raise the tolerance of the confidence
interval unit tests when I added test cases in the recent PR #4058.

In this PR, I change our calculation to match statsmodels, and decrease the
corresponding test tolerance, as we now have a strict match with statsmodels.

Previous formula:
```python
lower_t = fc_t - sqrt(2) * erfinv(level) * sqrt(F_t * mean(v_i**2 / F_i))
upper_t = fc_t + sqrt(2) * erfinv(level) * sqrt(F_t * mean(v_i**2 / F_i))
```

New formula:
```python
lower_t = fc_t - sqrt(2) * erfinv(level) * sqrt(F_t)
upper_t = fc_t + sqrt(2) * erfinv(level) * sqrt(F_t)
```
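As a quick aside (illustrative only, not part of the patch), the multiplier `sqrt(2) * erfinv(level)` in both formulas is simply the two-sided standard-normal quantile that statsmodels also uses, about 1.96 for a 95% interval. A minimal check with SciPy, assuming `level` is given as a fraction:

```python
# Illustrative check, not part of this patch: the interval multiplier
# sqrt(2) * erfinv(level) equals the two-sided standard-normal quantile.
import numpy as np
from scipy.special import erfinv
from scipy.stats import norm

level = 0.95
multiplier = np.sqrt(2.0) * erfinv(level)
assert np.isclose(multiplier, norm.ppf(0.5 + level / 2.0))  # ~1.96
```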
Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: https://github.com/rapidsai/cuml/pull/4248
---
 cpp/include/cuml/tsa/arima_common.h |  5 ++-
 cpp/src/arima/batched_kalman.cu     | 53 +++++++----------------------
 python/cuml/test/test_arima.py      |  4 +--
 3 files changed, 17 insertions(+), 45 deletions(-)

diff --git a/cpp/include/cuml/tsa/arima_common.h b/cpp/include/cuml/tsa/arima_common.h
index 1f0358f7a6..2ed9da31e2 100644
--- a/cpp/include/cuml/tsa/arima_common.h
+++ b/cpp/include/cuml/tsa/arima_common.h
@@ -201,8 +201,8 @@ struct ARIMAMemory {
     *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams,
     *Z_dense, *R_dense, *T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense,
     *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *loglike,
-    *loglike_base, *loglike_pert, *x_pert, *sigma2_buffer, *I_m_AxA_dense, *I_m_AxA_inv_dense,
-    *Ts_dense, *RQRs_dense, *Ps_dense;
+    *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense,
+    *RQRs_dense, *Ps_dense;
   T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches, **P_batches,
     **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches,
     **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches,
@@ -279,7 +279,6 @@ struct ARIMAMemory {
     append_buffer(K_batches, batch_size);
     append_buffer(TP_dense, rd * rd * batch_size);
     append_buffer(TP_batches, batch_size);
-    append_buffer(sigma2_buffer, batch_size);
 
     append_buffer(pred, n_obs * batch_size);
     append_buffer(y_diff, n_obs * batch_size);

diff --git a/cpp/src/arima/batched_kalman.cu b/cpp/src/arima/batched_kalman.cu
index 2b043dc3a4..7eb0cd2efe 100644
--- a/cpp/src/arima/batched_kalman.cu
+++ b/cpp/src/arima/batched_kalman.cu
@@ -96,7 +96,6 @@ DI void MM_l(const double* A, const double* B, double* out)
  * @param[in]  batch_size   Batch size
  * @param[out] d_pred       Predictions (nobs)
  * @param[out] d_loglike    Log-likelihood (1)
- * @param[out] d_ll_sigma2  Sigma^2 term in the log-likelihood (1)
  * @param[in]  n_diff       d + s*D
  * @param[in]  fc_steps     Number of steps to forecast
  * @param[out] d_fc         Array to store the forecast
@@ -116,7 +115,6 @@ __global__ void batched_kalman_loop_kernel(const double* ys,
                                            int batch_size,
                                            double* d_pred,
                                            double* d_loglike,
-                                           double* d_ll_sigma2,
                                            int n_diff,
                                            int fc_steps = 0,
                                            double* d_fc = nullptr,
@@ -257,7 +255,6 @@ __global__ void batched_kalman_loop_kernel(const double* ys,
 
   {
     double n_obs_ll_f = static_cast<double>(n_obs_ll);
     b_ll_s2 /= n_obs_ll_f;
-    if (conf_int) d_ll_sigma2[bid] = b_ll_s2;
     d_loglike[bid] = -.5 * (b_sum_logFs + n_obs_ll_f * (b_ll_s2 + log(2 * M_PI)));
   }
@@ -342,7 +339,6 @@ union KalmanLoopSharedMemory {
  * @param[in]  rd           State vector dimension
  * @param[out] d_pred       Predictions (nobs)
  * @param[out] d_loglike    Log-likelihood (1)
- * @param[out] d_ll_sigma2  Sigma^2 term in the log-likelihood (1)
  * @param[in]  n_diff       d + s*D
  * @param[in]  fc_steps     Number of steps to forecast
  * @param[out] d_fc         Array to store the forecast
@@ -365,7 +361,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
                                                          int rd,
                                                          double* d_pred,
                                                          double* d_loglike,
-                                                         double* d_ll_sigma2,
                                                          int n_diff,
                                                          int fc_steps,
                                                          double* d_fc,
@@ -603,7 +598,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
   if (threadIdx.x == 0) {
     double n_obs_ll_f = static_cast<double>(n_obs_ll);
     ll_s2 /= n_obs_ll_f;
-    if (conf_int) d_ll_sigma2[bid] = ll_s2;
     d_loglike[bid] = -.5 * (sum_logFs + n_obs_ll_f * (ll_s2 + log(2 * M_PI)));
   }
 }
@@ -625,7 +619,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
  * @param[in]  rd           Dimension of the state vector
  * @param[out] d_pred       Predictions (nobs)
  * @param[out] d_loglike    Log-likelihood (1)
- * @param[out] d_ll_sigma2  Sigma^2 term in the log-likelihood (1)
  * @param[in]  n_diff       d + s*D
  * @param[in]  fc_steps     Number of steps to forecast
  * @param[out] d_fc         Array to store the forecast
@@ -646,7 +639,6 @@ void _batched_kalman_device_loop_large(const ARIMAMemory<double>& arima_mem,
                                        int rd,
                                        double* d_pred,
                                        double* d_loglike,
-                                       double* d_ll_sigma2,
                                        int n_diff,
                                        int fc_steps = 0,
                                        double* d_fc = nullptr,
@@ -690,7 +682,6 @@ void _batched_kalman_device_loop_large(const ARIMAMemory<double>& arima_mem,
     rd,
     d_pred,
     d_loglike,
-    d_ll_sigma2,
     n_diff,
     fc_steps,
     d_fc,
@@ -713,7 +704,6 @@ void batched_kalman_loop(raft::handle_t& handle,
                          const ARIMAOrder& order,
                          double* d_pred,
                          double* d_loglike,
-                         double* d_ll_sigma2,
                          int fc_steps = 0,
                          double* d_fc = nullptr,
                          bool conf_int = false,
@@ -741,7 +731,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -762,7 +751,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -783,7 +771,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -804,7 +791,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -825,7 +811,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -846,7 +831,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -867,7 +851,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -888,7 +871,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         batch_size,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -917,7 +899,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         rd,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -939,7 +920,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         rd,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -963,7 +943,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         rd,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -985,7 +964,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         rd,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -1008,7 +986,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         rd,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -1030,7 +1007,6 @@ void batched_kalman_loop(raft::handle_t& handle,
         rd,
         d_pred,
         d_loglike,
-        d_ll_sigma2,
         n_diff,
         fc_steps,
         d_fc,
@@ -1046,25 +1022,21 @@ void batched_kalman_loop(raft::handle_t& handle,
  * @note: One block per batch member, one thread per forecast time step
  *
  * @param[in]    d_fc        Mean forecasts
- * @param[in]    d_sigma2    sum(v_t * v_t / F_t) / n_obs_diff
 * @param[inout] d_lower     Input: F_{n+t}
 *                           Output: lower bound of the confidence intervals
 * @param[out]   d_upper     Upper bound of the confidence intervals
- * @param[in]    fc_steps    Number of forecast steps
+ * @param[in]    n_elem      Total number of elements (fc_steps * batch_size)
  * @param[in]    multiplier  Coefficient associated with the confidence level
  */
-__global__ void confidence_intervals(const double* d_fc,
-                                     const double* d_sigma2,
-                                     double* d_lower,
-                                     double* d_upper,
-                                     int fc_steps,
-                                     double multiplier)
+__global__ void confidence_intervals(
+  const double* d_fc, double* d_lower, double* d_upper, int n_elem, double multiplier)
 {
-  int idx = blockIdx.x * fc_steps + threadIdx.x;
-  double fc = d_fc[idx];
-  double margin = multiplier * sqrt(d_lower[idx] * d_sigma2[blockIdx.x]);
-  d_lower[idx] = fc - margin;
-  d_upper[idx] = fc + margin;
+  for (int idx = threadIdx.x; idx < n_elem; idx += blockDim.x * gridDim.x) {
+    double fc = d_fc[idx];
+    double margin = multiplier * sqrt(d_lower[idx]);
+    d_lower[idx] = fc - margin;
+    d_upper[idx] = fc + margin;
+  }
 }
 
 void _lyapunov_wrapper(raft::handle_t& handle,
@@ -1287,15 +1259,16 @@ void _batched_kalman_filter(raft::handle_t& handle,
                       order,
                       d_pred,
                       d_loglike,
-                      arima_mem.sigma2_buffer,
                       fc_steps,
                       d_fc,
                       level > 0,
                       d_lower);
 
   if (level > 0) {
-    confidence_intervals<<<batch_size, fc_steps, 0, stream>>>(
-      d_fc, arima_mem.sigma2_buffer, d_lower, d_upper, fc_steps, sqrt(2.0) * erfinv(level));
+    constexpr int TPB_conf = 256;
+    int n_blocks = raft::ceildiv(fc_steps * batch_size, TPB_conf);
+    confidence_intervals<<<n_blocks, TPB_conf, 0, stream>>>(
+      d_fc, d_lower, d_upper, fc_steps * batch_size, sqrt(2.0) * erfinv(level));
     CUDA_CHECK(cudaPeekAtLastError());
   }
 }
diff --git a/python/cuml/test/test_arima.py b/python/cuml/test/test_arima.py
index d4823624a1..1a704abbc9 100644
--- a/python/cuml/test/test_arima.py
+++ b/python/cuml/test/test_arima.py
@@ -420,9 +420,9 @@ def _predict_common(key, data, dtype, start, end, num_steps=None, level=None,
     np.testing.assert_allclose(cuml_pred, ref_preds, rtol=0.001, atol=0.01)
     if level is not None:
         np.testing.assert_allclose(
-            cuml_lower, ref_lower, rtol=0.03, atol=0.01)
+            cuml_lower, ref_lower, rtol=0.005, atol=0.01)
         np.testing.assert_allclose(
-            cuml_upper, ref_upper, rtol=0.03, atol=0.01)
+            cuml_upper, ref_upper, rtol=0.005, atol=0.01)
 
 
 @pytest.mark.parametrize('key, data', test_data)
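Side note, not part of the patch: since the new intervals are intended to match statsmodels strictly, they can be sanity-checked directly against `get_forecast`. A minimal sketch, assuming a synthetic random-walk series, an arbitrary ARIMA(1,1,1) order, and `F_t` taken as the squared standard error of the mean forecast:

```python
# Illustrative sketch, not part of this patch: rebuild the intervals with the
# new formula and compare them to statsmodels' own confidence intervals,
# using the same tolerances as the updated unit test.
import numpy as np
import statsmodels.api as sm
from scipy.special import erfinv

level = 0.95
y = np.random.default_rng(42).standard_normal(200).cumsum()  # synthetic series

res = sm.tsa.SARIMAX(y, order=(1, 1, 1)).fit(disp=False)
fc = res.get_forecast(steps=10)

multiplier = np.sqrt(2.0) * erfinv(level)      # two-sided normal quantile
margin = multiplier * np.sqrt(fc.se_mean**2)   # sqrt(F_t), with F_t = se_mean**2
lower = fc.predicted_mean - margin
upper = fc.predicted_mean + margin

ref = fc.conf_int(alpha=1 - level)             # columns: lower, upper
np.testing.assert_allclose(lower, ref[:, 0], rtol=0.005, atol=0.01)
np.testing.assert_allclose(upper, ref[:, 1], rtol=0.005, atol=0.01)
```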