From 94be76f60d9e763c81cbb2a9185a02046e81a3cf Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 1 Jun 2021 16:37:27 +0200 Subject: [PATCH] ARIMA: pre-allocation of temporary memory to reduce latencies (#3895) This PR can speed up the evaluation of the log-likelihood in ARIMA by 5x for non-seasonal datasets (the impact is smaller for seasonal datasets). It achieves this by pre-allocating all the temporary memory only once instead of every iteration and providing all the pointers with a very low overhead thanks to a dedicated structure. Additionally, I removed some unnecessary copies. ![arima_memory](https://user-images.githubusercontent.com/17441062/119530801-a44ff100-bd83-11eb-9278-3f9071521553.png) Regarding the unnecessary synchronizations, I'll fix that later in a separate PR. Note that non-seasonal ARIMA performance is now even more limited by the python-side solver bottleneck: ![optimizer_bottleneck](https://user-images.githubusercontent.com/17441062/119531952-b8e0b900-bd84-11eb-88cc-b58497b283fc.png) One problem is that batched matrix operations are quite memory-hungry so I've duplicated or refactored some bits to avoid allocating extra memory there, but that leads to some duplication that I'm not entirely happy with. Both the ARIMA code and batched matrix prims are due some refactoring. Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/3895 --- cpp/bench/sg/arima_loglikelihood.cu | 15 +- cpp/include/cuml/tsa/arima_common.h | 160 ++++++++++++++++- cpp/include/cuml/tsa/batched_arima.hpp | 52 +++--- cpp/include/cuml/tsa/batched_kalman.hpp | 21 ++- cpp/src/arima/batched_arima.cu | 132 +++++++------- cpp/src/arima/batched_kalman.cu | 223 ++++++++++++++--------- cpp/src_prims/linalg/batched/matrix.cuh | 225 ++++++++++++++++++------ cpp/src_prims/sparse/batched/csr.cuh | 107 ++++++++--- python/cuml/tsa/arima.pyx | 130 ++++++++++---- 9 files changed, 759 insertions(+), 306 deletions(-) diff --git a/cpp/bench/sg/arima_loglikelihood.cu b/cpp/bench/sg/arima_loglikelihood.cu index a54b84f845..0c1060440b 100644 --- a/cpp/bench/sg/arima_loglikelihood.cu +++ b/cpp/bench/sg/arima_loglikelihood.cu @@ -63,10 +63,13 @@ class ArimaLoglikelihood : public TsFixtureRandom { // Benchmark loop this->loopOnState(state, [this]() { + ARIMAMemory arima_mem(order, this->params.batch_size, + this->params.n_obs, temp_mem); + // Evaluate log-likelihood - batched_loglike(*this->handle, this->data.X, this->params.batch_size, - this->params.n_obs, order, param, loglike, residual, true, - false); + batched_loglike(*this->handle, arima_mem, this->data.X, + this->params.batch_size, this->params.n_obs, order, param, + loglike, residual, true, false); }); } @@ -86,6 +89,11 @@ class ArimaLoglikelihood : public TsFixtureRandom { this->params.batch_size * sizeof(DataT), stream); residual = (DataT*)allocator->allocate( this->params.batch_size * this->params.n_obs * sizeof(DataT), stream); + + // Temporary memory + size_t temp_buf_size = ARIMAMemory::compute_size( + order, this->params.batch_size, this->params.n_obs); + temp_mem = (char*)allocator->allocate(temp_buf_size, stream); } void deallocateBuffers(const ::benchmark::State& state) { @@ -110,6 +118,7 @@ class ArimaLoglikelihood : public TsFixtureRandom { DataT* param; DataT* loglike; DataT* residual; + char* temp_mem; }; std::vector getInputs() { diff --git 
a/cpp/include/cuml/tsa/arima_common.h b/cpp/include/cuml/tsa/arima_common.h index e2220727e8..6a39a5f9d0 100644 --- a/cpp/include/cuml/tsa/arima_common.h +++ b/cpp/include/cuml/tsa/arima_common.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -207,4 +207,162 @@ struct ARIMAParams { } }; +/** + * Structure to manage ARIMA temporary memory allocations + * @note The user is expected to give a preallocated buffer to the constructor, + * and ownership is not transferred to this struct! The buffer must be allocated + * as long as the object lives, and deallocated afterwards. + */ +template +struct ARIMAMemory { + T *params_mu, *params_ar, *params_ma, *params_sar, *params_sma, + *params_sigma2, *Tparams_mu, *Tparams_ar, *Tparams_ma, *Tparams_sar, + *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams, *Z_dense, *R_dense, + *T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense, + *ImT_inv_dense, *T_values, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, + *vs, *y_diff, *loglike, *loglike_base, *loglike_pert, *x_pert, *F_buffer, + *sumLogF_buffer, *sigma2_buffer, *I_m_AxA_dense, *I_m_AxA_inv_dense, + *Ts_dense, *RQRs_dense, *Ps_dense; + T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches, + **P_batches, **alpha_batches, **ImT_batches, **ImT_inv_batches, + **v_tmp_batches, **m_tmp_batches, **K_batches, **TP_batches, + **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches, **RQRs_batches, + **Ps_batches; + int *T_col_index, *T_row_index, *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, + *I_m_AxA_info; + + size_t size; + + protected: + char* buf; + + template + inline void append_buffer(ValType*& ptr, size_t n_elem) { + if (assign) { + ptr = reinterpret_cast(buf + size); + } + size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN; + } + + template + inline void buf_offsets(const ARIMAOrder& order, int batch_size, int n_obs, + char* in_buf = nullptr) { + buf = in_buf; + size = 0; + + int r = order.r(); + int rd = order.rd(); + int N = order.complexity(); + int n_diff = order.n_diff(); + + append_buffer(params_mu, order.k * batch_size); + append_buffer(params_ar, order.p * batch_size); + append_buffer(params_ma, order.q * batch_size); + append_buffer(params_sar, order.P * batch_size); + append_buffer(params_sma, order.Q * batch_size); + append_buffer(params_sigma2, batch_size); + + append_buffer(Tparams_mu, order.k * batch_size); + append_buffer(Tparams_ar, order.p * batch_size); + append_buffer(Tparams_ma, order.q * batch_size); + append_buffer(Tparams_sar, order.P * batch_size); + append_buffer(Tparams_sma, order.Q * batch_size); + append_buffer(Tparams_sigma2, batch_size); + + append_buffer(d_params, N * batch_size); + append_buffer(d_Tparams, N * batch_size); + append_buffer(Z_dense, rd * batch_size); + append_buffer(Z_batches, batch_size); + append_buffer(R_dense, rd * batch_size); + append_buffer(R_batches, batch_size); + append_buffer(T_dense, rd * rd * batch_size); + append_buffer(T_batches, batch_size); + append_buffer(RQ_dense, rd * batch_size); + append_buffer(RQ_batches, batch_size); + append_buffer(RQR_dense, rd * rd * batch_size); + append_buffer(RQR_batches, batch_size); + append_buffer(P_dense, rd * rd * batch_size); + append_buffer(P_batches, batch_size); + append_buffer(alpha_dense, rd * batch_size); + append_buffer(alpha_batches, batch_size); + 
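The `append_buffer` calls above and below carve every temporary out of a single caller-owned allocation. A minimal sketch of the measure-then-assign pattern (simplified, hypothetical names — not the cuML code itself): the same traversal runs once with `assign = false` to accumulate the aligned total size, and once with `assign = true` to hand out pointers into the caller's buffer.

```cpp
#include <cstddef>

// Minimal sketch of the measure-then-assign layout pattern used by ARIMAMemory
// (simplified, hypothetical names -- not the cuML code itself).
template <int ALIGN = 256>
struct WorkspaceSketch {
  double* a = nullptr;    // stands in for one of the typed sub-buffers
  int* b = nullptr;       // stands in for another sub-buffer
  std::size_t size = 0;   // total bytes needed, accumulated by the traversal

 private:
  char* buf = nullptr;

  template <bool assign, typename T>
  void append(T*& ptr, std::size_t n_elem) {
    if (assign) ptr = reinterpret_cast<T*>(buf + size);
    // Round each sub-buffer up to ALIGN bytes so the next pointer stays aligned
    size += ((n_elem * sizeof(T) + ALIGN - 1) / ALIGN) * ALIGN;
  }

  template <bool assign>
  void layout(int n, char* in_buf = nullptr) {
    buf = in_buf;
    size = 0;
    append<assign>(a, n);      // the traversal order is identical in both passes,
    append<assign>(b, 2 * n);  // so the measured size matches the assigned offsets
  }

  // Measuring pass only (no buffer): used by compute_size
  explicit WorkspaceSketch(int n) { layout<false>(n); }

 public:
  // Assigning pass: carves pointers out of a caller-owned buffer
  WorkspaceSketch(int n, char* in_buf) { layout<true>(n, in_buf); }

  static std::size_t compute_size(int n) { return WorkspaceSketch(n).size; }
};
```

In `ARIMAMemory` the protected constructor performs the measuring pass for `compute_size` and the public constructor performs the assigning pass over the user-provided `in_buf`, so the reported size and the pointer offsets can never drift apart.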
append_buffer(ImT_dense, r * r * batch_size); + append_buffer(ImT_batches, batch_size); + append_buffer(ImT_inv_dense, r * r * batch_size); + append_buffer(ImT_inv_batches, batch_size); + append_buffer(ImT_inv_P, r * batch_size); + append_buffer(ImT_inv_info, batch_size); + append_buffer(T_values, rd * rd * batch_size); + append_buffer(T_col_index, rd * rd); + append_buffer(T_row_index, rd + 1); + append_buffer(v_tmp_dense, rd * batch_size); + append_buffer(v_tmp_batches, batch_size); + append_buffer(m_tmp_dense, rd * rd * batch_size); + append_buffer(m_tmp_batches, batch_size); + append_buffer(K_dense, rd * batch_size); + append_buffer(K_batches, batch_size); + append_buffer(TP_dense, rd * rd * batch_size); + append_buffer(TP_batches, batch_size); + append_buffer(F_buffer, n_obs * batch_size); + append_buffer(sumLogF_buffer, batch_size); + append_buffer(sigma2_buffer, batch_size); + + append_buffer(vs, n_obs * batch_size); + append_buffer(y_diff, n_obs * batch_size); + append_buffer(loglike, batch_size); + append_buffer(loglike_base, batch_size); + append_buffer(loglike_pert, batch_size); + append_buffer(x_pert, N * batch_size); + + if (n_diff > 0) { + append_buffer(Ts_dense, r * r * batch_size); + append_buffer(Ts_batches, batch_size); + append_buffer(RQRs_dense, r * r * batch_size); + append_buffer(RQRs_batches, batch_size); + append_buffer(Ps_dense, r * r * batch_size); + append_buffer(Ps_batches, batch_size); + } + + if (r <= 5) { + // Note: temp mem for the direct Lyapunov solver grows very quickly! + // This solver is used iff the condition above is satisifed + append_buffer(I_m_AxA_dense, r * r * r * r * batch_size); + append_buffer(I_m_AxA_batches, batch_size); + append_buffer(I_m_AxA_inv_dense, r * r * r * r * batch_size); + append_buffer(I_m_AxA_inv_batches, batch_size); + append_buffer(I_m_AxA_P, r * r * batch_size); + append_buffer(I_m_AxA_info, batch_size); + } + } + + /** Protected constructor to estimate max size */ + ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs) { + buf_offsets(order, batch_size, n_obs); + } + + public: + /** Constructor to create pointers from buffer + * @param[in] order ARIMA order + * @param[in] batch_size Number of series in the batch + * @param[in] n_obs Length of the series + * @param[in] in_buf Pointer to the temporary memory buffer. + * Ownership is retained by the caller + */ + ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, + char* in_buf) { + buf_offsets(order, batch_size, n_obs, in_buf); + } + + /** Static method to get the size of the required buffer allocation + * @param[in] order ARIMA order + * @param[in] batch_size Number of series in the batch + * @param[in] n_obs Length of the series + * @return Buffer size in bytes + */ + static size_t compute_size(const ARIMAOrder& order, int batch_size, + int n_obs) { + ARIMAMemory temp(order, batch_size, n_obs); + return temp.size; + } +}; + } // namespace ML diff --git a/cpp/include/cuml/tsa/batched_arima.hpp b/cpp/include/cuml/tsa/batched_arima.hpp index 8b50560bb3..f11f4ab99d 100644 --- a/cpp/include/cuml/tsa/batched_arima.hpp +++ b/cpp/include/cuml/tsa/batched_arima.hpp @@ -68,6 +68,7 @@ void batched_diff(raft::handle_t& handle, double* d_y_diff, const double* d_y, * in a batched context. * * @param[in] handle cuML handle + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. 
(device) * @param[in] batch_size Number of time series @@ -91,13 +92,14 @@ void batched_diff(raft::handle_t& handle, double* d_y_diff, const double* d_y, * @param[out] d_lower Lower limit of the prediction interval * @param[out] d_upper Upper limit of the prediction interval */ -void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, - int n_obs, const ARIMAOrder& order, const double* d_params, - double* loglike, double* d_vs, bool trans = true, - bool host_loglike = true, LoglikeMethod method = MLE, - int truncate = 0, int fc_steps = 0, double* d_fc = nullptr, - double level = 0, double* d_lower = nullptr, - double* d_upper = nullptr); +void batched_loglike(raft::handle_t& handle, + const ARIMAMemory& arima_mem, const double* d_y, + int batch_size, int n_obs, const ARIMAOrder& order, + const double* d_params, double* loglike, double* d_vs, + bool trans = true, bool host_loglike = true, + LoglikeMethod method = MLE, int truncate = 0, + int fc_steps = 0, double* d_fc = nullptr, double level = 0, + double* d_lower = nullptr, double* d_upper = nullptr); /** * Compute the loglikelihood of the given parameter on the given time series @@ -107,6 +109,7 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, * to avoid useless packing / unpacking * * @param[in] handle cuML handle + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) * @param[in] batch_size Number of time series @@ -129,8 +132,9 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, * @param[out] d_lower Lower limit of the prediction interval * @param[out] d_upper Upper limit of the prediction interval */ -void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, - int n_obs, const ARIMAOrder& order, +void batched_loglike(raft::handle_t& handle, + const ARIMAMemory& arima_mem, const double* d_y, + int batch_size, int n_obs, const ARIMAOrder& order, const ARIMAParams& params, double* loglike, double* d_vs, bool trans = true, bool host_loglike = true, LoglikeMethod method = MLE, int truncate = 0, @@ -141,6 +145,7 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, * Compute the gradient of the log-likelihood * * @param[in] handle cuML handle + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) * @param[in] batch_size Number of time series @@ -154,17 +159,19 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, * @param[in] truncate For CSS, start the sum-of-squares after a given * number of observations */ -void batched_loglike_grad(raft::handle_t& handle, const double* d_y, - int batch_size, int n_obs, const ARIMAOrder& order, - const double* d_x, double* d_grad, double h, - bool trans = true, LoglikeMethod method = MLE, - int truncate = 0); +void batched_loglike_grad(raft::handle_t& handle, + const ARIMAMemory& arima_mem, + const double* d_y, int batch_size, int n_obs, + const ARIMAOrder& order, const double* d_x, + double* d_grad, double h, bool trans = true, + LoglikeMethod method = MLE, int truncate = 0); /** * Batched in-sample and out-of-sample prediction of a time-series given all * the model parameters * * @param[in] handle cuML handle + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Batched Time series to predict. 
* Shape: (num_samples, batch size) (device) * @param[in] batch_size Total number of batched time series @@ -181,16 +188,17 @@ void batched_loglike_grad(raft::handle_t& handle, const double* d_y, * @param[out] d_lower Lower limit of the prediction interval * @param[out] d_upper Upper limit of the prediction interval */ -void predict(raft::handle_t& handle, const double* d_y, int batch_size, - int n_obs, int start, int end, const ARIMAOrder& order, - const ARIMAParams& params, double* d_y_p, - bool pre_diff = true, double level = 0, double* d_lower = nullptr, - double* d_upper = nullptr); +void predict(raft::handle_t& handle, const ARIMAMemory& arima_mem, + const double* d_y, int batch_size, int n_obs, int start, int end, + const ARIMAOrder& order, const ARIMAParams& params, + double* d_y_p, bool pre_diff = true, double level = 0, + double* d_lower = nullptr, double* d_upper = nullptr); /** * Compute an information criterion (AIC, AICc, BIC) * * @param[in] handle cuML handle + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) * @param[in] batch_size Total number of batched time series @@ -203,8 +211,10 @@ void predict(raft::handle_t& handle, const double* d_y, int batch_size, * @param[in] ic_type Type of information criterion wanted. * 0: AIC, 1: AICc, 2: BIC */ -void information_criterion(raft::handle_t& handle, const double* d_y, - int batch_size, int n_obs, const ARIMAOrder& order, +void information_criterion(raft::handle_t& handle, + const ARIMAMemory& arima_mem, + const double* d_y, int batch_size, int n_obs, + const ARIMAOrder& order, const ARIMAParams& params, double* ic, int ic_type); diff --git a/cpp/include/cuml/tsa/batched_kalman.hpp b/cpp/include/cuml/tsa/batched_kalman.hpp index 06b24cc7e7..0f48b6f1fc 100644 --- a/cpp/include/cuml/tsa/batched_kalman.hpp +++ b/cpp/include/cuml/tsa/batched_kalman.hpp @@ -29,6 +29,7 @@ namespace ML { * provide the resulting prediction as well as loglikelihood fit. * * @param[in] handle cuML handle + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_ys_b Batched time series * Shape (nobs, batch_size) (col-major, device) * @param[in] nobs Number of samples per time series @@ -46,13 +47,12 @@ namespace ML { * @param[out] d_lower Lower limit of the prediction interval * @param[out] d_upper Upper limit of the prediction interval */ -void batched_kalman_filter(raft::handle_t& handle, const double* d_ys_b, - int nobs, const ARIMAParams& params, - const ARIMAOrder& order, int batch_size, - double* d_loglike, double* d_vs, int fc_steps = 0, - double* d_fc = nullptr, double level = 0, - double* d_lower = nullptr, - double* d_upper = nullptr); +void batched_kalman_filter( + raft::handle_t& handle, const ARIMAMemory& arima_mem, + const double* d_ys_b, int nobs, const ARIMAParams& params, + const ARIMAOrder& order, int batch_size, double* d_loglike, double* d_vs, + int fc_steps = 0, double* d_fc = nullptr, double level = 0, + double* d_lower = nullptr, double* d_upper = nullptr); /** * Convenience function for batched "jones transform" used in ARIMA to ensure @@ -60,6 +60,7 @@ void batched_kalman_filter(raft::handle_t& handle, const double* d_ys_b, * returns host array) * * @param[in] handle cuML handle + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] order ARIMA hyper-parameters * @param[in] batch_size Number of time series analyzed. * @param[in] isInv Do the inverse transform? 
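Putting the new entry points together, a hedged usage sketch that mirrors the benchmark change at the top of this patch (the device buffers `d_y`, `d_params`, `d_loglike`, `d_residual` and the exact include paths are illustrative assumptions, not part of the diff):

```cpp
#include <cstddef>

#include <cuml/tsa/arima_common.h>
#include <cuml/tsa/batched_arima.hpp>
#include <raft/handle.hpp>

void loglike_with_workspace(raft::handle_t& handle, const ML::ARIMAOrder& order,
                            int batch_size, int n_obs, const double* d_y,
                            const double* d_params, double* d_loglike,
                            double* d_residual) {
  auto allocator = handle.get_device_allocator();
  auto stream = handle.get_stream();

  // 1. Ask once how much temporary memory the whole evaluation needs
  std::size_t temp_buf_size =
    ML::ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
  char* temp_mem = (char*)allocator->allocate(temp_buf_size, stream);

  // 2. Carve the buffer into typed pointers; ownership stays with the caller
  ML::ARIMAMemory<double> arima_mem(order, batch_size, n_obs, temp_mem);

  // 3. Calls inside an optimizer loop now reuse the same workspace
  ML::batched_loglike(handle, arima_mem, d_y, batch_size, n_obs, order,
                      d_params, d_loglike, d_residual,
                      /*trans=*/true, /*host_loglike=*/false);

  // 4. One deallocation at the end instead of one per temporary per call
  allocator->deallocate(temp_mem, temp_buf_size, stream);
}
```

In the benchmark above, the `ARIMAMemory` object is even rebuilt inside the timed loop; that is deliberate and cheap, since constructing it only computes pointer offsets on the host.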
@@ -68,7 +69,9 @@ void batched_kalman_filter(raft::handle_t& handle, const double* d_ys_b, * (expects pre-allocated array of size * (p+q)*batch_size) (host) */ -void batched_jones_transform(raft::handle_t& handle, const ARIMAOrder& order, - int batch_size, bool isInv, const double* h_params, +void batched_jones_transform(raft::handle_t& handle, + const ARIMAMemory& arima_mem, + const ARIMAOrder& order, int batch_size, + bool isInv, const double* h_params, double* h_Tparams); } // namespace ML diff --git a/cpp/src/arima/batched_arima.cu b/cpp/src/arima/batched_arima.cu index b52821a102..3a24b2deb3 100644 --- a/cpp/src/arima/batched_arima.cu +++ b/cpp/src/arima/batched_arima.cu @@ -58,10 +58,11 @@ void batched_diff(raft::handle_t& handle, double* d_y_diff, const double* d_y, order.D, order.s, stream); } -void predict(raft::handle_t& handle, const double* d_y, int batch_size, - int n_obs, int start, int end, const ARIMAOrder& order, - const ARIMAParams& params, double* d_y_p, bool pre_diff, - double level, double* d_lower, double* d_upper) { +void predict(raft::handle_t& handle, const ARIMAMemory& arima_mem, + const double* d_y, int batch_size, int n_obs, int start, int end, + const ARIMAOrder& order, const ARIMAParams& params, + double* d_y_p, bool pre_diff, double level, double* d_lower, + double* d_upper) { ML::PUSH_RANGE(__func__); auto allocator = handle.get_device_allocator(); const auto stream = handle.get_stream(); @@ -71,26 +72,21 @@ void predict(raft::handle_t& handle, const double* d_y, int batch_size, // Prepare data int n_obs_kf; const double* d_y_kf; - MLCommon::device_buffer diff_buffer(allocator, stream); ARIMAOrder order_after_prep = order; if (diff) { n_obs_kf = n_obs - order.n_diff(); - diff_buffer.resize(n_obs_kf * batch_size, stream); - MLCommon::TimeSeries::prepare_data(diff_buffer.data(), d_y, batch_size, - n_obs, order.d, order.D, order.s, - stream); - d_y_kf = diff_buffer.data(); + MLCommon::TimeSeries::prepare_data(arima_mem.y_diff, d_y, batch_size, n_obs, + order.d, order.D, order.s, stream); order_after_prep.d = 0; order_after_prep.D = 0; + + d_y_kf = arima_mem.y_diff; } else { n_obs_kf = n_obs; d_y_kf = d_y; } - // Create temporary array for the residuals - MLCommon::device_buffer v_buffer(allocator, stream, - n_obs_kf * batch_size); - double* d_vs = v_buffer.data(); + double* d_vs = arima_mem.vs; // Create temporary array for the forecasts int num_steps = std::max(end - n_obs, 0); @@ -101,9 +97,9 @@ void predict(raft::handle_t& handle, const double* d_y, int batch_size, // Compute the residual and forecast std::vector loglike = std::vector(batch_size); /// TODO: use device loglike to avoid useless copy ; part of #2233 - batched_loglike(handle, d_y_kf, batch_size, n_obs_kf, order_after_prep, - params, loglike.data(), d_vs, false, true, MLE, 0, num_steps, - d_y_fc, level, d_lower, d_upper); + batched_loglike(handle, arima_mem, d_y_kf, batch_size, n_obs_kf, + order_after_prep, params, loglike.data(), d_vs, false, true, + MLE, 0, num_steps, d_y_fc, level, d_lower, d_upper); auto counting = thrust::make_counting_iterator(0); int predict_ld = end - start; @@ -278,8 +274,9 @@ void conditional_sum_of_squares(raft::handle_t& handle, const double* d_y, ML::POP_RANGE(); } -void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, - int n_obs, const ARIMAOrder& order, +void batched_loglike(raft::handle_t& handle, + const ARIMAMemory& arima_mem, const double* d_y, + int batch_size, int n_obs, const ARIMAOrder& order, const ARIMAParams& params, double* 
loglike, double* d_vs, bool trans, bool host_loglike, LoglikeMethod method, int truncate, int fc_steps, @@ -289,24 +286,18 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, auto allocator = handle.get_device_allocator(); auto stream = handle.get_stream(); - ARIMAParams Tparams; + + ARIMAParams Tparams = { + arima_mem.Tparams_mu, arima_mem.Tparams_ar, arima_mem.Tparams_ma, + arima_mem.Tparams_sar, arima_mem.Tparams_sma, arima_mem.Tparams_sigma2}; ASSERT(method == MLE || fc_steps == 0, "Only MLE method is valid for forecasting"); /* Create log-likelihood device array if host pointer is provided */ - double* d_loglike; - MLCommon::device_buffer loglike_buffer(allocator, stream); - if (host_loglike) { - loglike_buffer.resize(batch_size, stream); - d_loglike = loglike_buffer.data(); - } else { - d_loglike = loglike; - } + double* d_loglike = host_loglike ? arima_mem.loglike : loglike; if (trans) { - Tparams.allocate(order, batch_size, allocator, stream, true); - MLCommon::TimeSeries::batched_jones_transform( order, batch_size, false, params, Tparams, allocator, stream); @@ -320,50 +311,50 @@ void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, conditional_sum_of_squares(handle, d_y, batch_size, n_obs, order, Tparams, d_loglike, truncate); } else { - batched_kalman_filter(handle, d_y, n_obs, Tparams, order, batch_size, - d_loglike, d_vs, fc_steps, d_fc, level, d_lower, - d_upper); + batched_kalman_filter(handle, arima_mem, d_y, n_obs, Tparams, order, + batch_size, d_loglike, d_vs, fc_steps, d_fc, level, + d_lower, d_upper); } if (host_loglike) { /* Tranfer log-likelihood device -> host */ raft::update_host(loglike, d_loglike, batch_size, stream); } - - if (trans) { - Tparams.deallocate(order, batch_size, allocator, stream, true); - } ML::POP_RANGE(); } -void batched_loglike(raft::handle_t& handle, const double* d_y, int batch_size, - int n_obs, const ARIMAOrder& order, const double* d_params, - double* loglike, double* d_vs, bool trans, - bool host_loglike, LoglikeMethod method, int truncate, - int fc_steps, double* d_fc, double level, double* d_lower, - double* d_upper) { +void batched_loglike(raft::handle_t& handle, + const ARIMAMemory& arima_mem, const double* d_y, + int batch_size, int n_obs, const ARIMAOrder& order, + const double* d_params, double* loglike, double* d_vs, + bool trans, bool host_loglike, LoglikeMethod method, + int truncate, int fc_steps, double* d_fc, double level, + double* d_lower, double* d_upper) { ML::PUSH_RANGE(__func__); // unpack parameters auto allocator = handle.get_device_allocator(); auto stream = handle.get_stream(); - ARIMAParams params; - params.allocate(order, batch_size, allocator, stream, false); - params.unpack(order, batch_size, d_params, stream); - batched_loglike(handle, d_y, batch_size, n_obs, order, params, loglike, d_vs, - trans, host_loglike, method, truncate, fc_steps, d_fc, level, - d_lower, d_upper); + ARIMAParams params = {arima_mem.params_mu, arima_mem.params_ar, + arima_mem.params_ma, arima_mem.params_sar, + arima_mem.params_sma, arima_mem.params_sigma2}; + + params.unpack(order, batch_size, d_params, stream); - params.deallocate(order, batch_size, allocator, stream, false); + batched_loglike(handle, arima_mem, d_y, batch_size, n_obs, order, params, + loglike, d_vs, trans, host_loglike, method, truncate, + fc_steps, d_fc, level, d_lower, d_upper); ML::POP_RANGE(); } -void batched_loglike_grad(raft::handle_t& handle, const double* d_y, - int batch_size, int n_obs, const ARIMAOrder& 
order, - const double* d_x, double* d_grad, double h, - bool trans, LoglikeMethod method, int truncate) { +void batched_loglike_grad(raft::handle_t& handle, + const ARIMAMemory& arima_mem, + const double* d_y, int batch_size, int n_obs, + const ARIMAOrder& order, const double* d_x, + double* d_grad, double h, bool trans, + LoglikeMethod method, int truncate) { ML::PUSH_RANGE(__func__); auto allocator = handle.get_device_allocator(); auto stream = handle.get_stream(); @@ -371,20 +362,16 @@ void batched_loglike_grad(raft::handle_t& handle, const double* d_y, int N = order.complexity(); // Initialize the perturbed x vector - MLCommon::device_buffer x_pert(allocator, stream, N * batch_size); - double* d_x_pert = x_pert.data(); + double* d_x_pert = arima_mem.x_pert; raft::copy(d_x_pert, d_x, N * batch_size, stream); - // Create buffers for the log-likelihood and residuals - MLCommon::device_buffer ll_base(allocator, stream, batch_size); - MLCommon::device_buffer ll_pert(allocator, stream, batch_size); - MLCommon::device_buffer res(allocator, stream, n_obs * batch_size); - double* d_ll_base = ll_base.data(); - double* d_ll_pert = ll_pert.data(); + double* d_vs = arima_mem.vs; + double* d_ll_base = arima_mem.loglike_base; + double* d_ll_pert = arima_mem.loglike_pert; // Evaluate the log-likelihood with the given parameter vector - batched_loglike(handle, d_y, batch_size, n_obs, order, d_x, d_ll_base, - res.data(), trans, false, method, truncate); + batched_loglike(handle, arima_mem, d_y, batch_size, n_obs, order, d_x, + d_ll_base, d_vs, trans, false, method, truncate); for (int i = 0; i < N; i++) { // Add the perturbation to the i-th parameter @@ -394,8 +381,8 @@ void batched_loglike_grad(raft::handle_t& handle, const double* d_y, }); // Evaluate the log-likelihood with the positive perturbation - batched_loglike(handle, d_y, batch_size, n_obs, order, d_x_pert, d_ll_pert, - res.data(), trans, false, method, truncate); + batched_loglike(handle, arima_mem, d_y, batch_size, n_obs, order, d_x_pert, + d_ll_pert, d_vs, trans, false, method, truncate); // First derivative with a first-order accuracy thrust::for_each(thrust::cuda::par.on(stream), counting, @@ -413,20 +400,21 @@ void batched_loglike_grad(raft::handle_t& handle, const double* d_y, ML::POP_RANGE(); } -void information_criterion(raft::handle_t& handle, const double* d_y, - int batch_size, int n_obs, const ARIMAOrder& order, +void information_criterion(raft::handle_t& handle, + const ARIMAMemory& arima_mem, + const double* d_y, int batch_size, int n_obs, + const ARIMAOrder& order, const ARIMAParams& params, double* d_ic, int ic_type) { ML::PUSH_RANGE(__func__); auto allocator = handle.get_device_allocator(); auto stream = handle.get_stream(); - MLCommon::device_buffer v_buffer(allocator, stream, - n_obs * batch_size); + double* d_vs = arima_mem.vs; /* Compute log-likelihood in d_ic */ - batched_loglike(handle, d_y, batch_size, n_obs, order, params, d_ic, - v_buffer.data(), false, false, MLE); + batched_loglike(handle, arima_mem, d_y, batch_size, n_obs, order, params, + d_ic, d_vs, false, false, MLE); /* Compute information criterion from log-likelihood and base term */ MLCommon::Metrics::Batched::information_criterion( diff --git a/cpp/src/arima/batched_kalman.cu b/cpp/src/arima/batched_kalman.cu index ba729c1d04..419cfcdc24 100644 --- a/cpp/src/arima/batched_kalman.cu +++ b/cpp/src/arima/batched_kalman.cu @@ -21,7 +21,6 @@ #include #include -#include #include #include @@ -270,6 +269,7 @@ __global__ void batched_kalman_loop_kernel( /** * 
Kalman loop for large matrices (r > 8). * + * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_ys Batched time series * @param[in] nobs Number of observation per series * @param[in] T Batched transition matrix. (r x r) @@ -291,7 +291,7 @@ __global__ void batched_kalman_loop_kernel( * @param[out] d_F_fc Batched variance of forecast errors (fc_steps) */ void _batched_kalman_loop_large( - const double* d_ys, int nobs, + const ARIMAMemory& arima_mem, const double* d_ys, int nobs, const MLCommon::LinAlg::Batched::Matrix& T, const MLCommon::Sparse::Batched::CSR& T_sparse, const MLCommon::LinAlg::Batched::Matrix& Z, @@ -309,14 +309,18 @@ void _batched_kalman_loop_large( auto counting = thrust::make_counting_iterator(0); // Temporary matrices and vectors - MLCommon::LinAlg::Batched::Matrix v_tmp(rd, 1, nb, cublasHandle, - allocator, stream, false); - MLCommon::LinAlg::Batched::Matrix m_tmp(rd, rd, nb, cublasHandle, - allocator, stream, false); - MLCommon::LinAlg::Batched::Matrix K(rd, 1, nb, cublasHandle, - allocator, stream, false); - MLCommon::LinAlg::Batched::Matrix TP(rd, rd, nb, cublasHandle, - allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix v_tmp( + rd, 1, nb, cublasHandle, arima_mem.v_tmp_batches, arima_mem.v_tmp_dense, + allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix m_tmp( + rd, rd, nb, cublasHandle, arima_mem.m_tmp_batches, arima_mem.m_tmp_dense, + allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix K( + rd, 1, nb, cublasHandle, arima_mem.K_batches, arima_mem.K_dense, allocator, + stream, false); + MLCommon::LinAlg::Batched::Matrix TP( + rd, rd, nb, cublasHandle, arima_mem.TP_batches, arima_mem.TP_dense, + allocator, stream, false); // Shortcuts const double* d_Z = Z.raw_data(); @@ -516,7 +520,9 @@ void _batched_kalman_loop_large( } /// Wrapper around functions that execute the Kalman loop (for performance) -void batched_kalman_loop(raft::handle_t& handle, const double* ys, int nobs, +void batched_kalman_loop(raft::handle_t& handle, + const ARIMAMemory& arima_mem, const double* ys, + int nobs, const MLCommon::LinAlg::Batched::Matrix& T, const MLCommon::LinAlg::Batched::Matrix& Z, const MLCommon::LinAlg::Batched::Matrix& RQR, @@ -597,10 +603,11 @@ void batched_kalman_loop(raft::handle_t& handle, const double* ys, int nobs, // Note: not always used MLCommon::Sparse::Batched::CSR T_sparse = MLCommon::Sparse::Batched::CSR::from_dense( - T, T_mask, handle.get_cusolver_sp_handle()); - _batched_kalman_loop_large(ys, nobs, T, T_sparse, Z, RQR, P0, alpha, - intercept, d_mu, rd, vs, Fs, sum_logFs, n_diff, - fc_steps, d_fc, conf_int, d_F_fc); + T, T_mask, handle.get_cusolver_sp_handle(), arima_mem.T_values, + arima_mem.T_col_index, arima_mem.T_row_index); + _batched_kalman_loop_large(arima_mem, ys, nobs, T, T_sparse, Z, RQR, P0, + alpha, intercept, d_mu, rd, vs, Fs, sum_logFs, + n_diff, fc_steps, d_fc, conf_int, d_F_fc); } } @@ -659,17 +666,49 @@ __global__ void confidence_intervals(const double* d_fc, const double* d_sigma2, d_upper[idx] = fc + margin; } +void _lyapunov_wrapper(raft::handle_t& handle, + const ARIMAMemory& arima_mem, + const MLCommon::LinAlg::Batched::Matrix& A, + MLCommon::LinAlg::Batched::Matrix& Q, + MLCommon::LinAlg::Batched::Matrix& X, int r) { + if (r <= 5) { + auto stream = handle.get_stream(); + auto cublasHandle = handle.get_cublas_handle(); + auto allocator = handle.get_device_allocator(); + int batch_size = A.batches(); + int r2 = r * r; + + // + // Use direct solution with Kronecker product + // + + 
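For reference, the direct method below vectorizes the discrete Lyapunov equation using vec(A X Aᵀ) = (A ⊗ A) vec(X):

    A X Aᵀ − X + Q = 0   ⇔   (I − A ⊗ A) vec(X) = vec(Q)

so each batch member needs the r² × r² matrix I − A ⊗ A and its inverse, i.e. O(r⁴) values per series — which is why this branch (and the matching `I_m_AxA*` buffers in `ARIMAMemory`) is only taken for r ≤ 5.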
MLCommon::LinAlg::Batched::Matrix I_m_AxA( + r2, r2, batch_size, cublasHandle, arima_mem.I_m_AxA_batches, + arima_mem.I_m_AxA_dense, allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix I_m_AxA_inv( + r2, r2, batch_size, cublasHandle, arima_mem.I_m_AxA_inv_batches, + arima_mem.I_m_AxA_inv_dense, allocator, stream, false); + + MLCommon::LinAlg::Batched::_direct_lyapunov_helper( + A, Q, X, I_m_AxA, I_m_AxA_inv, arima_mem.I_m_AxA_P, + arima_mem.I_m_AxA_info, r); + } else { + // Note: the other Lyapunov solver is doing temporary mem allocations, + // but when r > 5, allocation overhead shouldn't be a bottleneck + X = MLCommon::LinAlg::Batched::b_lyapunov(A, Q); + } +} + /// Internal Kalman filter implementation that assumes data exists on GPU. -void _batched_kalman_filter(raft::handle_t& handle, const double* d_ys, - int nobs, const ARIMAOrder& order, - const MLCommon::LinAlg::Batched::Matrix& Zb, - const MLCommon::LinAlg::Batched::Matrix& Tb, - const MLCommon::LinAlg::Batched::Matrix& Rb, - std::vector& T_mask, double* d_vs, - double* d_Fs, double* d_loglike, - const double* d_sigma2, bool intercept, - const double* d_mu, int fc_steps, double* d_fc, - double level, double* d_lower, double* d_upper) { +void _batched_kalman_filter( + raft::handle_t& handle, const ARIMAMemory& arima_mem, + const double* d_ys, int nobs, const ARIMAOrder& order, + const MLCommon::LinAlg::Batched::Matrix& Zb, + const MLCommon::LinAlg::Batched::Matrix& Tb, + const MLCommon::LinAlg::Batched::Matrix& Rb, + std::vector& T_mask, double* d_vs, double* d_Fs, double* d_loglike, + const double* d_sigma2, bool intercept, const double* d_mu, int fc_steps, + double* d_fc, double level, double* d_lower, double* d_upper) { const size_t batch_size = Zb.batches(); auto stream = handle.get_stream(); auto cublasHandle = handle.get_cublas_handle(); @@ -681,8 +720,9 @@ void _batched_kalman_filter(raft::handle_t& handle, const double* d_ys, int rd = order.rd(); int r = order.r(); - MLCommon::LinAlg::Batched::Matrix RQb(rd, 1, batch_size, cublasHandle, - allocator, stream, true); + MLCommon::LinAlg::Batched::Matrix RQb( + rd, 1, batch_size, cublasHandle, arima_mem.RQ_batches, arima_mem.RQ_dense, + allocator, stream, true); double* d_RQ = RQb.raw_data(); const double* d_R = Rb.raw_data(); thrust::for_each(thrust::cuda::par.on(stream), counting, @@ -692,13 +732,17 @@ void _batched_kalman_filter(raft::handle_t& handle, const double* d_ys, d_RQ[bid * rd + i] = d_R[bid * rd + i] * sigma2; } }); - MLCommon::LinAlg::Batched::Matrix RQR = - MLCommon::LinAlg::Batched::b_gemm(RQb, Rb, false, true); + MLCommon::LinAlg::Batched::Matrix RQR( + rd, rd, batch_size, cublasHandle, arima_mem.RQR_batches, + arima_mem.RQR_dense, allocator, stream, false); + MLCommon::LinAlg::Batched::b_gemm(false, true, rd, rd, 1, 1.0, RQb, Rb, 0.0, + RQR); // Durbin Koopman "Time Series Analysis" pg 138 ML::PUSH_RANGE("Init P"); - MLCommon::LinAlg::Batched::Matrix P(rd, rd, batch_size, cublasHandle, - allocator, stream, true); + MLCommon::LinAlg::Batched::Matrix P( + rd, rd, batch_size, cublasHandle, arima_mem.P_batches, arima_mem.P_dense, + allocator, stream, true); { double* d_P = P.raw_data(); @@ -715,18 +759,25 @@ void _batched_kalman_filter(raft::handle_t& handle, const double* d_ys, }); // Initialize the stationary part by solving a Lyapunov equation - /// TODO: reduce amount of memory copies - MLCommon::LinAlg::Batched::Matrix Ts = - MLCommon::LinAlg::Batched::b_2dcopy(Tb, n_diff, n_diff, r, r); - MLCommon::LinAlg::Batched::Matrix RQRs = - 
MLCommon::LinAlg::Batched::b_2dcopy(RQR, n_diff, n_diff, r, r); - MLCommon::LinAlg::Batched::Matrix Ps = - MLCommon::LinAlg::Batched::b_lyapunov(Ts, RQRs); + MLCommon::LinAlg::Batched::Matrix Ts( + r, r, batch_size, cublasHandle, arima_mem.Ts_batches, + arima_mem.Ts_dense, allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix RQRs( + r, r, batch_size, cublasHandle, arima_mem.RQRs_batches, + arima_mem.RQRs_dense, allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix Ps( + r, r, batch_size, cublasHandle, arima_mem.Ps_batches, + arima_mem.Ps_dense, allocator, stream, false); + + MLCommon::LinAlg::Batched::b_2dcopy(Tb, Ts, n_diff, n_diff, r, r); + MLCommon::LinAlg::Batched::b_2dcopy(RQR, RQRs, n_diff, n_diff, r, r); + // Ps = MLCommon::LinAlg::Batched::b_lyapunov(Ts, RQRs); + _lyapunov_wrapper(handle, arima_mem, Ts, RQRs, Ps, r); MLCommon::LinAlg::Batched::b_2dcopy(Ps, P, 0, 0, r, r, n_diff, n_diff); } else { // Initialize by solving a Lyapunov equation - /// TODO: avoid copy - P = MLCommon::LinAlg::Batched::b_lyapunov(Tb, RQR); + // P = MLCommon::LinAlg::Batched::b_lyapunov(Tb, RQR); + _lyapunov_wrapper(handle, arima_mem, Tb, RQR, P, rd); } } ML::POP_RANGE(); @@ -739,12 +790,13 @@ void _batched_kalman_filter(raft::handle_t& handle, const double* d_ys, // T* = T[d+s*D:, d+s*D:] // x* = alpha_0[d+s*D:] MLCommon::LinAlg::Batched::Matrix alpha( - rd, 1, batch_size, handle.get_cublas_handle(), - handle.get_device_allocator(), stream, false); + rd, 1, batch_size, handle.get_cublas_handle(), arima_mem.alpha_batches, + arima_mem.alpha_dense, handle.get_device_allocator(), stream, false); if (intercept) { // Compute I-T* MLCommon::LinAlg::Batched::Matrix ImT( - r, r, batch_size, cublasHandle, allocator, stream, false); + r, r, batch_size, cublasHandle, arima_mem.ImT_batches, + arima_mem.ImT_dense, allocator, stream, false); const double* d_T = Tb.raw_data(); double* d_ImT = ImT.raw_data(); thrust::for_each(thrust::cuda::par.on(stream), counting, @@ -770,7 +822,11 @@ void _batched_kalman_filter(raft::handle_t& handle, const double* d_ys, } // Compute (I-T*)^-1 - MLCommon::LinAlg::Batched::Matrix ImT_inv = ImT.inv(); + MLCommon::LinAlg::Batched::Matrix ImT_inv( + r, r, batch_size, cublasHandle, arima_mem.ImT_inv_batches, + arima_mem.ImT_inv_dense, allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix::inv( + ImT, ImT_inv, arima_mem.ImT_inv_P, arima_mem.ImT_inv_info); // Compute (I-T*)^-1 * c -> multiply 1st column by mu const double* d_ImT_inv = ImT_inv.raw_data(); @@ -793,23 +849,21 @@ void _batched_kalman_filter(raft::handle_t& handle, const double* d_ys, sizeof(double) * rd * batch_size, stream)); } - MLCommon::device_buffer sumLogF_buffer(allocator, stream, batch_size); - - batched_kalman_loop(handle, d_ys, nobs, Tb, Zb, RQR, P, alpha, T_mask, - intercept, d_mu, order, d_vs, d_Fs, sumLogF_buffer.data(), - fc_steps, d_fc, level > 0, d_lower); + batched_kalman_loop(handle, arima_mem, d_ys, nobs, Tb, Zb, RQR, P, alpha, + T_mask, intercept, d_mu, order, d_vs, d_Fs, + arima_mem.sumLogF_buffer, fc_steps, d_fc, level > 0, + d_lower); // Finalize loglikelihood and prediction intervals - MLCommon::device_buffer sigma2_buffer(allocator, stream, batch_size); constexpr int NUM_THREADS = 128; batched_kalman_loglike_kernel <<>>( - d_vs, d_Fs, sumLogF_buffer.data(), nobs, batch_size, d_loglike, - sigma2_buffer.data(), n_diff, level); + d_vs, d_Fs, arima_mem.sumLogF_buffer, nobs, batch_size, d_loglike, + arima_mem.sigma2_buffer, n_diff, level); CUDA_CHECK(cudaPeekAtLastError()); if (level 
> 0) { confidence_intervals<<>>( - d_fc, sigma2_buffer.data(), d_lower, d_upper, fc_steps, + d_fc, arima_mem.sigma2_buffer, d_lower, d_upper, fc_steps, sqrt(2.0) * erfinv(level)); CUDA_CHECK(cudaPeekAtLastError()); } @@ -962,12 +1016,11 @@ void init_batched_kalman_matrices(raft::handle_t& handle, const double* d_ar, ML::POP_RANGE(); } -void batched_kalman_filter(raft::handle_t& handle, const double* d_ys, int nobs, - const ARIMAParams& params, - const ARIMAOrder& order, int batch_size, - double* d_loglike, double* d_vs, int fc_steps, - double* d_fc, double level, double* d_lower, - double* d_upper) { +void batched_kalman_filter( + raft::handle_t& handle, const ARIMAMemory& arima_mem, + const double* d_ys, int nobs, const ARIMAParams& params, + const ARIMAOrder& order, int batch_size, double* d_loglike, double* d_vs, + int fc_steps, double* d_fc, double level, double* d_lower, double* d_upper) { ML::PUSH_RANGE(__func__); auto cublasHandle = handle.get_cublas_handle(); @@ -977,12 +1030,15 @@ void batched_kalman_filter(raft::handle_t& handle, const double* d_ys, int nobs, // see (3.18) in TSA by D&K int rd = order.rd(); - MLCommon::LinAlg::Batched::Matrix Zb(1, rd, batch_size, cublasHandle, - allocator, stream, false); - MLCommon::LinAlg::Batched::Matrix Tb(rd, rd, batch_size, cublasHandle, - allocator, stream, false); - MLCommon::LinAlg::Batched::Matrix Rb(rd, 1, batch_size, cublasHandle, - allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix Zb( + 1, rd, batch_size, cublasHandle, arima_mem.Z_batches, arima_mem.Z_dense, + allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix Tb( + rd, rd, batch_size, cublasHandle, arima_mem.T_batches, arima_mem.T_dense, + allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix Rb( + rd, 1, batch_size, cublasHandle, arima_mem.R_batches, arima_mem.R_dense, + allocator, stream, false); std::vector T_mask; init_batched_kalman_matrices(handle, params.ar, params.ma, params.sar, @@ -992,30 +1048,30 @@ void batched_kalman_filter(raft::handle_t& handle, const double* d_ys, int nobs, //////////////////////////////////////////////////////////// // Computation - MLCommon::device_buffer F_buffer(allocator, stream, - nobs * batch_size); - - _batched_kalman_filter(handle, d_ys, nobs, order, Zb, Tb, Rb, T_mask, d_vs, - F_buffer.data(), d_loglike, params.sigma2, - static_cast(order.k), params.mu, fc_steps, d_fc, - level, d_lower, d_upper); + _batched_kalman_filter(handle, arima_mem, d_ys, nobs, order, Zb, Tb, Rb, + T_mask, d_vs, arima_mem.F_buffer, d_loglike, + params.sigma2, static_cast(order.k), params.mu, + fc_steps, d_fc, level, d_lower, d_upper); ML::POP_RANGE(); } -void batched_jones_transform(raft::handle_t& handle, const ARIMAOrder& order, - int batch_size, bool isInv, const double* h_params, +void batched_jones_transform(raft::handle_t& handle, + const ARIMAMemory& arima_mem, + const ARIMAOrder& order, int batch_size, + bool isInv, const double* h_params, double* h_Tparams) { int N = order.complexity(); auto allocator = handle.get_device_allocator(); auto stream = handle.get_stream(); - double* d_params = - (double*)allocator->allocate(N * batch_size * sizeof(double), stream); - double* d_Tparams = - (double*)allocator->allocate(N * batch_size * sizeof(double), stream); - ARIMAParams params, Tparams; - params.allocate(order, batch_size, allocator, stream, false); - Tparams.allocate(order, batch_size, allocator, stream, true); + double* d_params = arima_mem.d_params; + double* d_Tparams = arima_mem.d_Tparams; + ARIMAParams params = 
{arima_mem.params_mu, arima_mem.params_ar, + arima_mem.params_ma, arima_mem.params_sar, + arima_mem.params_sma, arima_mem.params_sigma2}; + ARIMAParams Tparams = { + arima_mem.Tparams_mu, arima_mem.Tparams_ar, arima_mem.Tparams_ma, + arima_mem.Tparams_sar, arima_mem.Tparams_sma, arima_mem.Tparams_sigma2}; raft::update_device(d_params, h_params, N * batch_size, stream); @@ -1028,11 +1084,6 @@ void batched_jones_transform(raft::handle_t& handle, const ARIMAOrder& order, Tparams.pack(order, batch_size, d_Tparams, stream); raft::update_host(h_Tparams, d_Tparams, N * batch_size, stream); - - allocator->deallocate(d_params, N * batch_size * sizeof(double), stream); - allocator->deallocate(d_Tparams, N * batch_size * sizeof(double), stream); - params.deallocate(order, batch_size, allocator, stream, false); - Tparams.deallocate(order, batch_size, allocator, stream, true); } } // namespace ML diff --git a/cpp/src_prims/linalg/batched/matrix.cuh b/cpp/src_prims/linalg/batched/matrix.cuh index 9bce5af745..08c8c6a9df 100644 --- a/cpp/src_prims/linalg/batched/matrix.cuh +++ b/cpp/src_prims/linalg/batched/matrix.cuh @@ -144,15 +144,14 @@ class Matrix { // Fill with zeros if requested if (setZero) CUDA_CHECK(cudaMemsetAsync( - m_dense.data(), 0, + raw_data(), 0, sizeof(T) * m_shape.first * m_shape.second * m_batch_size, m_stream)); // Fill array of pointers to each batch matrix. constexpr int TPB = 256; fill_strided_pointers_kernel<<(m_batch_size, TPB), TPB, 0, m_stream>>>( - m_dense.data(), m_batches.data(), m_batch_size, m_shape.first, - m_shape.second); + raw_data(), data(), m_batch_size, m_shape.first, m_shape.second); CUDA_CHECK(cudaPeekAtLastError()); } @@ -177,7 +176,41 @@ class Matrix { m_stream(stream), m_shape(m, n), m_batches(allocator, stream, batch_size), - m_dense(allocator, stream, m * n * batch_size) { + m_dense(allocator, stream, m * n * batch_size), + d_batches(m_batches.data()), + d_dense(m_dense.data()) { + initialize(setZero); + } + + /** + * @brief Constructor that uses pre-allocated memory. + * @note The given arrays don't need to be initialized prior to constructing this object. + * Memory ownership is retained by the caller, not this object! + * Some methods might still allocate temporary memory with the provided allocator. + * + * @param[in] m Number of rows + * @param[in] n Number of columns + * @param[in] batch_size Number of matrices in the batch + * @param[in] cublasHandle cuBLAS handle + * @param[in] allocator Device memory allocator + * @param[in] d_batches Pre-allocated pointers array: batch_size * sizeof(T*) + * @param[in] d_dense Pre-allocated data array: m * n * batch_size * sizeof(T) + * @param[in] stream CUDA stream + * @param[in] setZero Should matrix be zeroed on allocation? 
+ */ + Matrix(int m, int n, int batch_size, cublasHandle_t cublasHandle, + T** d_batches, T* d_dense, + std::shared_ptr allocator, + cudaStream_t stream, bool setZero = true) + : m_batch_size(batch_size), + m_allocator(allocator), + m_cublasHandle(cublasHandle), + m_stream(stream), + m_shape(m, n), + m_batches(allocator, stream, 0), + m_dense(allocator, stream, 0), + d_batches(d_batches), + d_dense(d_dense) { initialize(setZero); } @@ -193,11 +226,13 @@ class Matrix { m_shape(other.m_shape), m_batches(other.m_allocator, other.m_stream, other.m_batch_size), m_dense(other.m_allocator, other.m_stream, - other.m_shape.first * other.m_shape.second * other.m_batch_size) { + other.m_shape.first * other.m_shape.second * other.m_batch_size), + d_batches(m_batches.data()), + d_dense(m_dense.data()) { initialize(false); // Copy the raw data - raft::copy(m_dense.data(), other.m_dense.data(), + raft::copy(raw_data(), other.raw_data(), m_batch_size * m_shape.first * m_shape.second, m_stream); } @@ -208,10 +243,12 @@ class Matrix { m_batches.resize(m_batch_size, m_stream); m_dense.resize(m_batch_size * m_shape.first * m_shape.second, m_stream); + d_batches = m_batches.data(); + d_dense = m_dense.data(); initialize(false); // Copy the raw data - raft::copy(m_dense.data(), other.m_dense.data(), + raft::copy(raw_data(), other.raw_data(), m_batch_size * m_shape.first * m_shape.second, m_stream); return *this; @@ -234,13 +271,13 @@ class Matrix { //! Return shape const std::pair& shape() const { return m_shape; } - //! Return pointer array - T** data() { return m_batches.data(); } - const T** data() const { return m_batches.data(); } + //! Return array of pointers to the offsets in the data buffer + const T** data() const { return d_batches; } + T** data() { return d_batches; } //! Return pointer to the underlying memory - T* raw_data() { return m_dense.data(); } - const T* raw_data() const { return m_dense.data(); } + const T* raw_data() const { return d_dense; } + T* raw_data() { return d_dense; } /** * @brief Return pointer to the data of a specific matrix @@ -249,7 +286,7 @@ class Matrix { * @return A pointer to the raw data of the matrix */ T* operator[](int id) const { - return &(m_dense.data()[id * m_shape.first * m_shape.second]); + return &(raw_data()[id * m_shape.first * m_shape.second]); } /** @@ -273,7 +310,7 @@ class Matrix { int r = m * n; Matrix toVec(r, 1, m_batch_size, m_cublasHandle, m_allocator, m_stream, false); - raft::copy(toVec[0], m_dense.data(), m_batch_size * r, m_stream); + raft::copy(toVec[0], raw_data(), m_batch_size * r, m_stream); return toVec; } @@ -290,7 +327,7 @@ class Matrix { "ERROR: Size mismatch - Cannot reshape array into desired size"); Matrix toMat(m, n, m_batch_size, m_cublasHandle, m_allocator, m_stream, false); - raft::copy(toMat[0], m_dense.data(), m_batch_size * r, m_stream); + raft::copy(toMat[0], raw_data(), m_batch_size * r, m_stream); return toMat; } @@ -299,7 +336,7 @@ class Matrix { void print(std::string name) const { size_t len = m_shape.first * m_shape.second * m_batch_size; std::vector A(len); - raft::update_host(A.data(), m_dense.data(), len, m_stream); + raft::update_host(A.data(), raw_data(), len, m_stream); std::cout << name << "=\n"; for (int i = 0; i < m_shape.first; i++) { for (int j = 0; j < m_shape.second; j++) { @@ -339,6 +376,25 @@ class Matrix { return out; } + /** + * @brief Compute the inverse of a batched matrix and write it to another matrix + * + * @param[inout] A Matrix to inverse. Overwritten by its LU factorization! 
+ * @param[out] Ainv Inversed matrix + * @param[out] d_P Pre-allocated array of size n * batch_size * sizeof(int) + * @param[out] d_info Pre-allocated array of size batch_size * sizeof(int) + */ + static void inv(Matrix& A, Matrix& Ainv, int* d_P, int* d_info) { + int n = A.m_shape.first; + + CUBLAS_CHECK(raft::linalg::cublasgetrfBatched(A.m_cublasHandle, n, A.data(), + n, d_P, d_info, + A.m_batch_size, A.m_stream)); + CUBLAS_CHECK(raft::linalg::cublasgetriBatched( + A.m_cublasHandle, n, A.data(), n, d_P, Ainv.data(), n, d_info, + A.m_batch_size, A.m_stream)); + } + /** * @brief Compute the inverse of the batched matrix * @@ -358,11 +414,7 @@ class Matrix { Matrix Ainv(n, n, m_batch_size, m_cublasHandle, m_allocator, m_stream, false); - CUBLAS_CHECK(raft::linalg::cublasgetrfBatched( - m_cublasHandle, n, Acopy.data(), n, P, info, m_batch_size, m_stream)); - CUBLAS_CHECK(raft::linalg::cublasgetriBatched( - m_cublasHandle, n, Acopy.data(), n, P, Ainv.data(), n, info, m_batch_size, - m_stream)); + Matrix::inv(Acopy, Ainv, P, info); m_allocator->deallocate(P, sizeof(int) * n * m_batch_size, m_stream); m_allocator->deallocate(info, sizeof(int) * m_batch_size, m_stream); @@ -381,8 +433,8 @@ class Matrix { Matrix At(n, m, m_batch_size, m_cublasHandle, m_allocator, m_stream); - const T* d_A = m_dense.data(); - T* d_At = At.m_dense.data(); + const T* d_A = raw_data(); + T* d_At = At.raw_data(); // Naive batched transpose ; TODO: improve auto counting = thrust::make_counting_iterator(0); @@ -426,11 +478,13 @@ class Matrix { //! Shape (rows, cols) of matrices. We assume all matrices in batch have same shape. std::pair m_shape; - //! Array(pointer) to each matrix. + //! Pointers to each matrix in the contiguous data buffer (strided offsets) device_buffer m_batches; + T** d_batches; // When pre-allocated //! Data pointer to first element of dense matrix data. device_buffer m_dense; + T* d_dense; // When pre-allocated //! 
Number of matrices in batch int m_batch_size; @@ -446,27 +500,28 @@ class Matrix { * @note The block x is the batch id, the thread x is the starting row * in B and the thread y is the starting column in B * - * @param[in] A Pointer to the raw data of matrix `A` - * @param[in] m Number of rows (A) - * @param[in] n Number of columns (A) - * @param[in] B Pointer to the raw data of matrix `B` - * @param[in] p Number of rows (B) - * @param[in] q Number of columns (B) - * @param[out] AkB Pointer to raw data of the result kronecker product - * @param[in] k_m Number of rows of the result (m * p) - * @param[in] k_n Number of columns of the result (n * q) + * @param[in] A Pointer to the raw data of matrix `A` + * @param[in] m Number of rows (A) + * @param[in] n Number of columns (A) + * @param[in] B Pointer to the raw data of matrix `B` + * @param[in] p Number of rows (B) + * @param[in] q Number of columns (B) + * @param[out] AkB Pointer to raw data of the result kronecker product + * @param[in] k_m Number of rows of the result (m * p) + * @param[in] k_n Number of columns of the result (n * q) + * @param[in] alpha Multiplying coefficient */ template __global__ void kronecker_product_kernel(const T* A, int m, int n, const T* B, - int p, int q, T* AkB, int k_m, - int k_n) { + int p, int q, T* AkB, int k_m, int k_n, + T alpha) { const T* A_b = A + blockIdx.x * m * n; const T* B_b = B + blockIdx.x * p * q; T* AkB_b = AkB + blockIdx.x * k_m * k_n; for (int ia = 0; ia < m; ia++) { for (int ja = 0; ja < n; ja++) { - T A_ia_ja = A_b[ia + ja * m]; + T A_ia_ja = alpha * A_b[ia + ja * m]; for (int ib = threadIdx.x; ib < p; ib += blockDim.x) { for (int jb = threadIdx.y; jb < q; jb += blockDim.y) { @@ -699,6 +754,40 @@ Matrix b_solve(const Matrix& A, const Matrix& b) { return x; } +/** + * @brief The batched kroneker product for batched matrices A and B + * + * Calculates AkB = alpha * A (x) B + * + * @param[in] A Matrix A + * @param[in] B Matrix B + * @param[out] AkB A (x) B + * @param[in] alpha Multiplying coefficient + */ +template +void b_kron(const Matrix& A, const Matrix& B, Matrix& AkB, + T alpha = (T)1) { + int m = A.shape().first; + int n = A.shape().second; + + int p = B.shape().first; + int q = B.shape().second; + + // Resulting shape + int k_m = m * p; + int k_n = n * q; + ASSERT(AkB.shape().first == k_m, + "Kronecker product output dimensions mismatch"); + ASSERT(AkB.shape().second == k_n, + "Kronecker product output dimensions mismatch"); + + // Run kronecker + dim3 threads(std::min(p, 32), std::min(q, 32)); + kronecker_product_kernel<<>>( + A.raw_data(), m, n, B.raw_data(), p, q, AkB.raw_data(), k_m, k_n, alpha); + CUDA_CHECK(cudaPeekAtLastError()); +} + /** * @brief The batched kroneker product A (x) B for given batched matrix A * and batched matrix B @@ -722,11 +811,8 @@ Matrix b_kron(const Matrix& A, const Matrix& B) { Matrix AkB(k_m, k_n, A.batches(), A.cublasHandle(), A.allocator(), A.stream()); - // Run kronecker - dim3 threads(std::min(p, 32), std::min(q, 32)); - kronecker_product_kernel<<>>( - A.raw_data(), m, n, B.raw_data(), p, q, AkB.raw_data(), k_m, k_n); - CUDA_CHECK(cudaPeekAtLastError()); + b_kron(A, B, AkB); + return AkB; } @@ -1690,6 +1776,36 @@ Matrix b_trsyl_uplo(const Matrix& R, const Matrix& S, return Y; } +/// Auxiliary function for the direct Lyapunov solver +template +void _direct_lyapunov_helper(const Matrix& A, Matrix& Q, Matrix& X, + Matrix& I_m_AxA, Matrix& I_m_AxA_inv, int* P, + int* info, int r) { + auto stream = A.stream(); + int batch_size = A.batches(); + int r2 = 
r * r; + auto counting = thrust::make_counting_iterator(0); + + b_kron(A, A, I_m_AxA, (T)-1); + + T* d_I_m_AxA = I_m_AxA.raw_data(); + thrust::for_each(thrust::cuda::par.on(stream), counting, + counting + batch_size, [=] __device__(int ib) { + T* b_I_m_AxA = d_I_m_AxA + ib * r2 * r2; + for (int i = 0; i < r2; i++) { + b_I_m_AxA[(r2 + 1) * i] += 1.0; + } + }); + + Matrix::inv(I_m_AxA, I_m_AxA_inv, P, info); + + Q.reshape(r2, 1); + X.reshape(r2, 1); + b_gemm(false, false, r2, 1, r2, (T)1, I_m_AxA_inv, Q, (T)0, X); + Q.reshape(r, r); + X.reshape(r, r); +} + /** * @brief Solve discrete Lyapunov equation A*X*A' - X + Q = 0 * @@ -1711,23 +1827,22 @@ Matrix b_lyapunov(const Matrix& A, Matrix& Q) { int n2 = n * n; auto counting = thrust::make_counting_iterator(0); - if (n <= 4 || (n <= 6 && batch_size <= 96)) { + if (n <= 5) { // // Use direct solution with Kronecker product // - Matrix I_m_AxA = b_kron(-A, A); - T* d_I_m_AxA = I_m_AxA.raw_data(); - thrust::for_each(thrust::cuda::par.on(stream), counting, - counting + batch_size, [=] __device__(int ib) { - T* b_I_m_AxA = d_I_m_AxA + ib * n2 * n2; - for (int i = 0; i < n * n; i++) { - b_I_m_AxA[(n2 + 1) * i] += (T)1; - } - }); - Q.reshape(n2, 1); - Matrix X = b_solve(I_m_AxA, Q); - Q.reshape(n, n); - X.reshape(n, n); + MLCommon::LinAlg::Batched::Matrix I_m_AxA( + n2, n2, batch_size, A.cublasHandle(), allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix I_m_AxA_inv( + n2, n2, batch_size, A.cublasHandle(), allocator, stream, false); + MLCommon::LinAlg::Batched::Matrix X(n, n, batch_size, A.cublasHandle(), + allocator, stream, false); + int* P = (int*)allocator->allocate(sizeof(int) * n * batch_size, stream); + int* info = (int*)allocator->allocate(sizeof(int) * batch_size, stream); + MLCommon::LinAlg::Batched::_direct_lyapunov_helper(A, Q, X, I_m_AxA, + I_m_AxA_inv, P, info, n); + allocator->deallocate(P, sizeof(int) * n * batch_size, stream); + allocator->deallocate(info, sizeof(int) * batch_size, stream); return X; } else { // diff --git a/cpp/src_prims/sparse/batched/csr.cuh b/cpp/src_prims/sparse/batched/csr.cuh index 329b1f2290..6af5eaadbe 100644 --- a/cpp/src_prims/sparse/batched/csr.cuh +++ b/cpp/src_prims/sparse/batched/csr.cuh @@ -152,7 +152,43 @@ class CSR { m_nnz(nnz), m_values(allocator, stream, nnz * batch_size), m_col_index(allocator, stream, nnz), - m_row_index(allocator, stream, m + 1) {} + m_row_index(allocator, stream, m + 1), + d_values(m_values.data()), + d_row_index(m_row_index.data()), + d_col_index(m_col_index.data()) {} + + /** + * @brief Constructor from pre-allocated memory; leaves the matrix uninitialized + * + * @param[in] m Number of rows per matrix + * @param[in] n Number of columns per matrix + * @param[in] nnz Number of non-zero elements per matrix + * @param[in] batch_size Number of matrices in the batch + * @param[in] cublasHandle cuBLAS handle + * @param[in] cusolverSpHandle cuSOLVER sparse handle + * @param[in] d_values Pre-allocated values array + * @param[in] d_col_index Pre-allocated column index array + * @param[in] d_row_index Pre-allocated row index array + * @param[in] allocator Device memory allocator + * @param[in] stream CUDA stream + */ + CSR(int m, int n, int nnz, int batch_size, cublasHandle_t cublasHandle, + cusolverSpHandle_t cusolverSpHandle, T* d_values, int* d_col_index, + int* d_row_index, std::shared_ptr allocator, + cudaStream_t stream) + : m_batch_size(batch_size), + m_allocator(allocator), + m_cublasHandle(cublasHandle), + m_cusolverSpHandle(cusolverSpHandle), + m_stream(stream), + 
m_shape(m, n), + m_nnz(nnz), + m_values(allocator, stream, nnz * batch_size), + m_col_index(allocator, stream, nnz), + m_row_index(allocator, stream, m + 1), + d_values(d_values), + d_col_index(d_col_index), + d_row_index(d_row_index) {} //! Destructor: nothing to destroy explicitely ~CSR() {} @@ -169,12 +205,15 @@ class CSR { m_values(other.m_allocator, other.m_stream, other.m_nnz * other.m_batch_size), m_col_index(other.m_allocator, other.m_stream, other.m_nnz), - m_row_index(other.m_allocator, other.m_stream, other.m_shape.first + 1) { + m_row_index(other.m_allocator, other.m_stream, other.m_shape.first + 1), + d_values(m_values.data()), + d_row_index(m_row_index.data()), + d_col_index(m_col_index.data()) { // Copy the raw data - raft::copy(m_values.data(), other.m_values.data(), m_nnz * m_batch_size, + raft::copy(get_values(), other.get_values(), m_nnz * m_batch_size, m_stream); - raft::copy(m_col_index.data(), other.m_col_index.data(), m_nnz, m_stream); - raft::copy(m_row_index.data(), other.m_row_index.data(), m_shape.first + 1, + raft::copy(get_col_index(), other.get_col_index(), m_nnz, m_stream); + raft::copy(get_row_index(), other.get_row_index(), m_shape.first + 1, m_stream); } @@ -187,12 +226,15 @@ class CSR { m_values.resize(m_nnz * m_batch_size, m_stream); m_col_index.resize(m_nnz, m_stream); m_row_index.resize(m_shape.first + 1, m_stream); + d_values = m_values.data(); + d_col_index = m_col_index.data(); + d_row_index = m_row_index.data(); // Copy the raw data - raft::copy(m_values.data(), other.m_values.data(), m_nnz * m_batch_size, + raft::copy(get_values(), other.get_values(), m_nnz * m_batch_size, m_stream); - raft::copy(m_col_index.data(), other.m_col_index.data(), m_nnz, m_stream); - raft::copy(m_row_index.data(), other.m_row_index.data(), m_shape.first + 1, + raft::copy(get_col_index(), other.get_col_index(), m_nnz, m_stream); + raft::copy(get_row_index(), other.get_row_index(), m_shape.first + 1, m_stream); return *this; @@ -201,19 +243,24 @@ class CSR { /** * @brief Construct from a dense batched matrix and its mask * - * @param[in] dense Dense batched matrix - * @param[in] mask Col-major host device matrix containing a mask of the - * non-zero values common to all matrices in the batch. - * Note: the point of using a mask is that some values - * might be zero in a few matrices but not generally in - * the batch so we shouldn't rely on a single matrix to - * get the mask + * @param[in] dense Dense batched matrix + * @param[in] mask Col-major host device matrix containing a mask of the + * non-zero values common to all matrices in the batch. 
+ * Note: the point of using a mask is that some values + * might be zero in a few matrices but not generally in + * the batch so we shouldn't rely on a single matrix to + * get the mask * @param[in] cusolverSpHandle cusolver sparse handle + * @param[in] d_values Optional pre-allocated values array + * @param[in] d_col_index Optional pre-allocated column index array + * @param[in] d_row_index Optional pre-allocated row index array * @return Batched CSR matrix */ static CSR from_dense(const LinAlg::Batched::Matrix& dense, const std::vector& mask, - cusolverSpHandle_t cusolverSpHandle) { + cusolverSpHandle_t cusolverSpHandle, + T* d_values = nullptr, int* d_col_index = nullptr, + int* d_row_index = nullptr) { std::pair shape = dense.shape(); // Create the index arrays from the mask @@ -231,9 +278,14 @@ class CSR { } h_row_index[shape.first] = nnz; - CSR out = CSR(shape.first, shape.second, nnz, dense.batches(), - dense.cublasHandle(), cusolverSpHandle, - dense.allocator(), dense.stream()); + CSR out = + (d_values == nullptr) + ? CSR(shape.first, shape.second, nnz, dense.batches(), + dense.cublasHandle(), cusolverSpHandle, dense.allocator(), + dense.stream()) + : CSR(shape.first, shape.second, nnz, dense.batches(), + dense.cublasHandle(), cusolverSpHandle, d_values, d_col_index, + d_row_index, dense.allocator(), dense.stream()); // Copy the host index arrays to the device raft::copy(out.get_col_index(), h_col_index.data(), nnz, out.stream()); @@ -265,7 +317,7 @@ class CSR { constexpr int TPB = 256; csr_to_dense_kernel<<(m_batch_size, TPB), TPB, 0, m_stream>>>( - dense.raw_data(), m_col_index.data(), m_row_index.data(), m_values.data(), + dense.raw_data(), get_col_index(), get_row_index(), get_values(), m_batch_size, m_shape.first, m_shape.second, m_nnz); CUDA_CHECK(cudaPeekAtLastError()); @@ -296,16 +348,16 @@ class CSR { const std::pair& shape() const { return m_shape; } //! Return values array - T* get_values() { return m_values.data(); } - const T* get_values() const { return m_values.data(); } + T* get_values() { return d_values; } + const T* get_values() const { return d_values; } //! Return columns index array - int* get_col_index() { return m_col_index.data(); } - const int* get_col_index() const { return m_col_index.data(); } + int* get_col_index() { return d_col_index; } + const int* get_col_index() const { return d_col_index; } //! Return rows index array - int* get_row_index() { return m_row_index.data(); } - const int* get_row_index() const { return m_row_index.data(); } + int* get_row_index() { return d_row_index; } + const int* get_row_index() const { return d_row_index; } protected: //! Shape (rows, cols) of matrices. @@ -316,12 +368,15 @@ class CSR { //! Array(pointer) to the values in all the batched matrices. device_buffer m_values; + T* d_values; //! Array(pointer) to the column index of the CSR. device_buffer m_col_index; + int* d_col_index; //! Array(pointer) to the row index of the CSR. device_buffer m_row_index; + int* d_row_index; //! 
Number of matrices in batch size_t m_batch_size; diff --git a/python/cuml/tsa/arima.pyx b/python/cuml/tsa/arima.pyx index 060ae2e32c..7c04053e77 100644 --- a/python/cuml/tsa/arima.pyx +++ b/python/cuml/tsa/arima.pyx @@ -48,6 +48,13 @@ cdef extern from "cuml/tsa/arima_common.h" namespace "ML": DataT* sma DataT* sigma2 + cdef cppclass ARIMAMemory[DataT]: + ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, + char* in_buf) + + @staticmethod + size_t compute_size(const ARIMAOrder& order, int batch_size, int n_obs) + cdef extern from "cuml/tsa/batched_arima.hpp" namespace "ML": ctypedef enum LoglikeMethod: CSS, MLE @@ -65,32 +72,34 @@ cdef extern from "cuml/tsa/batched_arima.hpp" namespace "ML": int n_obs, const ARIMAOrder& order) void batched_loglike( - handle_t& handle, const double* y, int batch_size, int nobs, - const ARIMAOrder& order, const double* params, double* loglike, - double* d_vs, bool trans, bool host_loglike, LoglikeMethod method, - int truncate) + handle_t& handle, const ARIMAMemory[double]& arima_mem, + const double* y, int batch_size, int nobs, const ARIMAOrder& order, + const double* params, double* loglike, double* d_vs, bool trans, + bool host_loglike, LoglikeMethod method, int truncate) void batched_loglike( - handle_t& handle, const double* y, int batch_size, int n_obs, - const ARIMAOrder& order, const ARIMAParams[double]& params, - double* loglike, double* d_vs, bool trans, bool host_loglike, - LoglikeMethod method, int truncate) + handle_t& handle, const ARIMAMemory[double]& arima_mem, + const double* y, int batch_size, int n_obs, const ARIMAOrder& order, + const ARIMAParams[double]& params, double* loglike, double* d_vs, + bool trans, bool host_loglike, LoglikeMethod method, int truncate) void batched_loglike_grad( - handle_t& handle, const double* d_y, int batch_size, int nobs, - const ARIMAOrder& order, const double* d_x, double* d_grad, double h, - bool trans, LoglikeMethod method, int truncate) + handle_t& handle, const ARIMAMemory[double]& arima_mem, + const double* d_y, int batch_size, int nobs, const ARIMAOrder& order, + const double* d_x, double* d_grad, double h, bool trans, + LoglikeMethod method, int truncate) void cpp_predict "predict" ( - handle_t& handle, const double* d_y, int batch_size, int nobs, - int start, int end, const ARIMAOrder& order, - const ARIMAParams[double]& params, double* d_y_p, bool pre_diff, - double level, double* d_lower, double* d_upper) + handle_t& handle, const ARIMAMemory[double]& arima_mem, + const double* d_y, int batch_size, int nobs, int start, int end, + const ARIMAOrder& order, const ARIMAParams[double]& params, + double* d_y_p, bool pre_diff, double level, double* d_lower, + double* d_upper) void information_criterion( - handle_t& handle, const double* d_y, int batch_size, int nobs, - const ARIMAOrder& order, const ARIMAParams[double]& params, - double* ic, int ic_type) + handle_t& handle, const ARIMAMemory[double]& arima_mem, + const double* d_y, int batch_size, int nobs, const ARIMAOrder& order, + const ARIMAParams[double]& params, double* ic, int ic_type) void estimate_x0( handle_t& handle, ARIMAParams[double]& params, const double* d_y, @@ -100,8 +109,9 @@ cdef extern from "cuml/tsa/batched_arima.hpp" namespace "ML": cdef extern from "cuml/tsa/batched_kalman.hpp" namespace "ML": void batched_jones_transform( - handle_t& handle, const ARIMAOrder& order, int batchSize, - bool isInv, const double* h_params, double* h_Tparams) + handle_t& handle, ARIMAMemory[double]& arima_mem, + const ARIMAOrder& order, int 
batchSize, bool isInv, + const double* h_params, double* h_Tparams) cdef class ARIMAParamsWrapper: @@ -266,6 +276,7 @@ class ARIMA(Base): d_y = CumlArrayDescriptor() # TODO: (MDD) Should this be public? Its not listed in the attributes doc _d_y_diff = CumlArrayDescriptor() + _temp_mem = CumlArrayDescriptor() mu_ = CumlArrayDescriptor() ar_ = CumlArrayDescriptor() @@ -339,6 +350,11 @@ class ARIMA(Base): self.n_obs_diff = self.n_obs - d - D * s + # Allocate temporary storage + temp_mem_size = ARIMAMemory[double].compute_size( + cpp_order, self.batch_size, self.n_obs) + self._temp_mem = CumlArray.empty(temp_mem_size, np.byte) + self._initial_calc() @cuml.internals.api_base_return_any_skipall @@ -379,6 +395,7 @@ class ARIMA(Base): """ cdef handle_t* handle_ = self.handle.getHandle() + cdef ARIMAOrder order = self.order cdef ARIMAOrder order_kf = \ self.order_diff if self.simple_differencing else self.order cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params @@ -398,10 +415,17 @@ class ARIMA(Base): except KeyError as e: raise NotImplementedError("IC type '{}' unknown".format(ic_type)) - information_criterion(handle_[0], d_y_kf_ptr, - self.batch_size, n_obs_kf, - order_kf, cpp_params, d_ic_ptr, - ic_type_id) + cdef uintptr_t d_temp_mem = self._temp_mem.ptr + arima_mem_ptr = new ARIMAMemory[double]( + order, self.batch_size, self.n_obs, + d_temp_mem) + + information_criterion(handle_[0], arima_mem_ptr[0], + d_y_kf_ptr, self.batch_size, + n_obs_kf, order_kf, cpp_params, + d_ic_ptr, ic_type_id) + + del arima_mem_ptr return ic @@ -586,13 +610,20 @@ class ARIMA(Base): cdef uintptr_t d_y_ptr = self.d_y.ptr - cpp_predict(handle_[0], d_y_ptr, self.batch_size, - self.n_obs, start, end, order, - cpp_params, d_y_p_ptr, + cdef uintptr_t d_temp_mem = self._temp_mem.ptr + arima_mem_ptr = new ARIMAMemory[double]( + order, self.batch_size, self.n_obs, + d_temp_mem) + + cpp_predict(handle_[0], arima_mem_ptr[0], d_y_ptr, + self.batch_size, self.n_obs, start, + end, order, cpp_params, d_y_p_ptr, self.simple_differencing, (0 if level is None else level), d_lower_ptr, d_upper_ptr) + del arima_mem_ptr + if level is None: return d_y_p else: @@ -806,6 +837,7 @@ class ARIMA(Base): cdef LoglikeMethod ll_method = CSS if method == "css" else MLE diff = ll_method != MLE or self.simple_differencing + cdef ARIMAOrder order = self.order cdef ARIMAOrder order_kf = self.order_diff if diff else self.order d_x_array, *_ = \ @@ -817,17 +849,25 @@ class ARIMA(Base): cdef handle_t* handle_ = self.handle.getHandle() + # TODO: don't create vs array every time! 
n_obs_kf = (self.n_obs_diff if diff else self.n_obs) d_vs = CumlArray.empty((n_obs_kf, self.batch_size), dtype=np.float64, order="F") cdef uintptr_t d_vs_ptr = d_vs.ptr - batched_loglike(handle_[0], d_y_kf_ptr, + cdef uintptr_t d_temp_mem = self._temp_mem.ptr + arima_mem_ptr = new ARIMAMemory[double]( + order, self.batch_size, self.n_obs, + d_temp_mem) + + batched_loglike(handle_[0], arima_mem_ptr[0], d_y_kf_ptr, self.batch_size, n_obs_kf, order_kf, d_x_ptr, vec_loglike.data(), d_vs_ptr, trans, True, ll_method, truncate) + del arima_mem_ptr + return np.array(vec_loglike, dtype=np.float64) @nvtx.annotate(message="tsa.arima.ARIMA._loglike_grad", @@ -869,6 +909,7 @@ class ARIMA(Base): grad = CumlArray.empty(N * self.batch_size, np.float64) cdef uintptr_t d_grad = grad.ptr + cdef ARIMAOrder order = self.order cdef ARIMAOrder order_kf = self.order_diff if diff else self.order d_x_array, *_ = \ @@ -880,13 +921,20 @@ class ARIMA(Base): cdef handle_t* handle_ = self.handle.getHandle() - batched_loglike_grad(handle_[0], d_y_kf_ptr, - self.batch_size, + cdef uintptr_t d_temp_mem = self._temp_mem.ptr + arima_mem_ptr = new ARIMAMemory[double]( + order, self.batch_size, self.n_obs, + d_temp_mem) + + batched_loglike_grad(handle_[0], arima_mem_ptr[0], + d_y_kf_ptr, self.batch_size, (self.n_obs_diff if diff else self.n_obs), order_kf, d_x_ptr, d_grad, h, trans, ll_method, truncate) + del arima_mem_ptr + return grad.to_output("numpy") @property @@ -902,6 +950,7 @@ class ARIMA(Base): cdef vector[double] vec_loglike vec_loglike.resize(self.batch_size) + cdef ARIMAOrder order = self.order cdef ARIMAOrder order_kf = \ self.order_diff if self.simple_differencing else self.order cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params @@ -919,12 +968,19 @@ class ARIMA(Base): order="F") cdef uintptr_t d_vs_ptr = d_vs.ptr - batched_loglike(handle_[0], d_y_kf_ptr, + cdef uintptr_t d_temp_mem = self._temp_mem.ptr + arima_mem_ptr = new ARIMAMemory[double]( + order, self.batch_size, self.n_obs, + d_temp_mem) + + batched_loglike(handle_[0], arima_mem_ptr[0], d_y_kf_ptr, self.batch_size, n_obs_kf, order_kf, cpp_params, vec_loglike.data(), d_vs_ptr, False, True, ll_method, 0) + del arima_mem_ptr + return np.array(vec_loglike, dtype=np.float64) @nvtx.annotate(message="tsa.arima.ARIMA.unpack", domain="cuml_python") @@ -1000,9 +1056,17 @@ class ARIMA(Base): cdef handle_t* handle_ = self.handle.getHandle() Tx = np.zeros(self.batch_size * N) + cdef uintptr_t d_temp_mem = self._temp_mem.ptr + arima_mem_ptr = new ARIMAMemory[double]( + order, self.batch_size, self.n_obs, + d_temp_mem) + cdef uintptr_t x_ptr = x.ctypes.data cdef uintptr_t Tx_ptr = Tx.ctypes.data - batched_jones_transform(handle_[0], order, self.batch_size, - isInv, x_ptr, Tx_ptr) + batched_jones_transform( + handle_[0], arima_mem_ptr[0], order, self.batch_size, + isInv, x_ptr, Tx_ptr) + + del arima_mem_ptr return (Tx)
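A note on the direct path of the Lyapunov solver above (now taken whenever `n <= 5`): it relies on the column-stacking identity

    vec(A X A') = (A ⊗ A) vec(X)

which turns the discrete Lyapunov equation A*X*A' - X + Q = 0 into the single linear system (I - A ⊗ A) vec(X) = vec(Q). This is exactly what `_direct_lyapunov_helper` assembles: `b_kron(A, A, I_m_AxA, (T)-1)` writes -(A ⊗ A), the per-batch Thrust loop adds 1 on the diagonal to form I - A ⊗ A, and the batched inverse is applied to Q reshaped as an r²-by-1 vector to recover X. Because the helper receives all of its temporaries (`I_m_AxA`, `I_m_AxA_inv`, the pivot and info arrays) as arguments, callers that already own scratch buffers can reuse them instead of allocating on every call.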
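For C++ callers of the updated log-likelihood API, the intended pattern mirrors what the Cython layer above does with `self._temp_mem`: size the scratch buffer once with `ARIMAMemory<double>::compute_size`, keep it alive across calls, and rebuild the cheap `ARIMAMemory` pointer view for every evaluation. The sketch below is illustrative only: the helper name `evaluate_loglike_repeatedly`, the header paths and the `raft::handle_t` type are assumptions, and the device arrays `d_y`, `d_params`, `d_loglike` and `d_vs` are assumed to be allocated and filled elsewhere.

```cpp
#include <cuml/tsa/arima_common.h>     // ML::ARIMAOrder, ML::ARIMAMemory (assumed header path)
#include <cuml/tsa/batched_arima.hpp>  // ML::batched_loglike, ML::LoglikeMethod (assumed header path)

#include <raft/handle.hpp>
#include <cuda_runtime.h>

// Hypothetical helper: evaluates the MLE log-likelihood n_iter times while
// allocating the temporary workspace only once (error checking omitted).
void evaluate_loglike_repeatedly(raft::handle_t& handle,
                                 const ML::ARIMAOrder& order, int batch_size,
                                 int n_obs, const double* d_y,
                                 const double* d_params, double* d_loglike,
                                 double* d_vs, int n_iter) {
  // One-time sizing and allocation of the scratch buffer
  size_t temp_size =
    ML::ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
  char* temp_buf = nullptr;
  cudaMalloc((void**)&temp_buf, temp_size);

  for (int i = 0; i < n_iter; i++) {
    // Cheap per-call construction: only assigns pointers inside temp_buf
    ML::ARIMAMemory<double> arima_mem(order, batch_size, n_obs, temp_buf);
    ML::batched_loglike(handle, arima_mem, d_y, batch_size, n_obs, order,
                        d_params, d_loglike, d_vs,
                        /*trans=*/true, /*host_loglike=*/false, ML::MLE,
                        /*truncate=*/0);
  }

  cudaFree(temp_buf);
}
```

This is the same pre-allocate-once pattern the Python `ARIMA` class follows: the buffer lives in `self._temp_mem` for the lifetime of the model, and each method constructs a fresh `ARIMAMemory[double]` view over it before calling into C++.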