From c4830f61f65b3388c78cb86a6bbe18488ea5576e Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Mon, 26 Jul 2021 23:35:50 +0200 Subject: [PATCH] Fix SVM model parameter handling in case n_support=0 (#4097) Fixes #4033 This PR fixes SVM model parameter handling in case the fitted model has no support vectors, only bias. C++ side changes: - The bias calculation is updated to calculate the bias as the average function value in this case. - The prediction function is modified to avoid kernel function calculation in this case. - Added an SVR unit test to check model fitting and prediction. Python side changes: - It was incorrectly assumed that n_support==0 means the model is not fitted correctly, this is removed. - Model attributes (`dual_coef_`, `support_`, `support_vectors_`) are defined as empty arrays in this case. - `coef_` attribute is an array of zeros if there are no support vectors. - Unit test added to check training prediction and model attributes. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/4097 --- cpp/src/svm/results.cuh | 18 ++++++-- cpp/src/svm/smosolver.cuh | 2 +- cpp/src/svm/svc_impl.cuh | 5 +- cpp/test/sg/svc_test.cu | 11 ++++- python/cuml/svm/svm_base.pyx | 90 +++++++++++++++++++++--------------- python/cuml/test/test_svm.py | 22 +++++++++ 6 files changed, 104 insertions(+), 44 deletions(-) diff --git a/cpp/src/svm/results.cuh b/cpp/src/svm/results.cuh index f931851f9d..17e21d6086 100644 --- a/cpp/src/svm/results.cuh +++ b/cpp/src/svm/results.cuh @@ -117,17 +117,17 @@ class Results { { CombineCoefs(alpha, val_tmp.data()); GetDualCoefs(val_tmp.data(), dual_coefs, n_support); + *b = CalcB(alpha, f, *n_support); if (*n_support > 0) { *idx = GetSupportVectorIndices(val_tmp.data(), *n_support); *x_support = CollectSupportVectors(*idx, *n_support); - *b = CalcB(alpha, f); - // Make sure that all pending GPU calculations finished before we return - CUDA_CHECK(cudaStreamSynchronize(stream)); } else { *dual_coefs = nullptr; *idx = nullptr; *x_support = nullptr; } + // Make sure that all pending GPU calculations finished before we return + CUDA_CHECK(cudaStreamSynchronize(stream)); } /** @@ -192,6 +192,7 @@ class Results { *n_support = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data()); *dual_coefs = (math_t*)allocator->allocate(*n_support * sizeof(math_t), stream); raft::copy(*dual_coefs, val_selected.data(), *n_support, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); } /** @@ -218,15 +219,22 @@ class Results { * @param [in] f optimality indicator vector, size [n_rows] * @return the value of b */ - math_t CalcB(const math_t* alpha, const math_t* f) + math_t CalcB(const math_t* alpha, const math_t* f, int n_support) { + if (n_support == 0) { + math_t f_sum; + cub::DeviceReduce::Sum( + cub_storage.data(), cub_bytes, f, d_val_reduced.data(), n_train, stream); + raft::update_host(&f_sum, d_val_reduced.data(), 1, stream); + return -f_sum / n_train; + } // We know that for an unbound support vector i, the decision function // (before taking the sign) has value F(x_i) = y_i, where // F(x_i) = \sum_j y_j \alpha_j K(x_j, x_i) + b, and j runs through all // support vectors. The constant b can be expressed from these formulas. // Note that F and f denote different quantities. The lower case f is the // optimality indicator vector defined as - // f_i = y_i - \sum_j y_j \alpha_j K(x_j, x_i). + // f_i = - y_i + \sum_j y_j \alpha_j K(x_j, x_i). // For unbound support vectors f_i = -b. // Select f for unbound support vectors (0 < alpha < C) diff --git a/cpp/src/svm/smosolver.cuh b/cpp/src/svm/smosolver.cuh index 03ce45a955..b5231ea050 100644 --- a/cpp/src/svm/smosolver.cuh +++ b/cpp/src/svm/smosolver.cuh @@ -371,7 +371,7 @@ class SmoSolver { raft::linalg::unaryOp( f, yr, n_rows, [epsilon] __device__(math_t y) { return epsilon - y; }, stream); - // f_i = epsilon - y_i, for i \in [n_rows..2*n_rows-1] + // f_i = -epsilon - y_i, for i \in [n_rows..2*n_rows-1] raft::linalg::unaryOp( f + n_rows, yr, n_rows, [epsilon] __device__(math_t y) { return -epsilon - y; }, stream); } diff --git a/cpp/src/svm/svc_impl.cuh b/cpp/src/svm/svc_impl.cuh index c9f2ded154..f9f7ea8c37 100644 --- a/cpp/src/svm/svc_impl.cuh +++ b/cpp/src/svm/svc_impl.cuh @@ -122,6 +122,9 @@ void svcPredict(const raft::handle_t& handle, MLCommon::device_buffer K( handle_impl.get_device_allocator(), stream, n_batch * model.n_support); MLCommon::device_buffer y(handle_impl.get_device_allocator(), stream, n_rows); + if (model.n_support == 0) { + CUDA_CHECK(cudaMemsetAsync(y.data(), 0, n_rows * sizeof(math_t), stream)); + } MLCommon::device_buffer x_rbf(handle_impl.get_device_allocator(), stream); MLCommon::device_buffer idx(handle_impl.get_device_allocator(), stream); @@ -137,7 +140,7 @@ void svcPredict(const raft::handle_t& handle, // We process the input data batchwise: // - calculate the kernel values K[x_batch, x_support] // - calculate y(x_batch) = K[x_batch, x_support] * dual_coeffs - for (int i = 0; i < n_rows; i += n_batch) { + for (int i = 0; i < n_rows && model.n_support > 0; i += n_batch) { if (i + n_batch >= n_rows) { n_batch = n_rows - i; } math_t* x_ptr = nullptr; int ld1 = 0; diff --git a/cpp/test/sg/svc_test.cu b/cpp/test/sg/svc_test.cu index 3d378bc70a..15d4c18019 100644 --- a/cpp/test/sg/svc_test.cu +++ b/cpp/test/sg/svc_test.cu @@ -1448,7 +1448,16 @@ class SvrTest : public ::testing::Test { {1, 1, 1, 10, 2, 10, 1} // sample weights }, smoOutput2{ - 6, {}, -15.5, {3.9}, {1.0, 2.0, 3.0, 4.0, 6.0, 7.0}, {0, 1, 2, 3, 5, 6}, {}}}}; + 6, {}, -15.5, {3.9}, {1.0, 2.0, 3.0, 4.0, 6.0, 7.0}, {0, 1, 2, 3, 5, 6}, {}}}, + {SvrInput{ + svmParameter{1, 0, 100, 10, 1e-6, CUML_LEVEL_INFO, 0.1, EPSILON_SVR}, + KernelParams{LINEAR, 3, 1, 0}, + 7, // n_rows + 1, // n_cols + {1, 2, 3, 4, 5, 6, 7}, // x + {2, 2, 2, 2, 2, 2, 2} // y + }, + smoOutput2{0, {}, 2, {}, {}, {}, {}}}}; for (auto d : data) { auto p = d.first; auto exp = d.second; diff --git a/python/cuml/svm/svm_base.pyx b/python/cuml/svm/svm_base.pyx index 1883627693..355c028c9d 100644 --- a/python/cuml/svm/svm_base.pyx +++ b/python/cuml/svm/svm_base.pyx @@ -311,6 +311,8 @@ class SVMBase(Base, return self.gamma def _calc_coef(self): + if (self.n_support_ == 0): + return cupy.zeros((1, self.n_cols), dtype=self.dtype) with using_output_type("cupy"): return cupy.dot(self.dual_coef_, self.support_vectors_) @@ -429,29 +431,28 @@ class SVMBase(Base, if self.dtype == np.float32: model_f = self._model - if model_f.n_support == 0: - self._fit_status_ = 1 # incorrect fit - return self._intercept_ = CumlArray.full(1, model_f.b, np.float32) self.n_support_ = model_f.n_support - self.dual_coef_ = CumlArray( - data=model_f.dual_coefs, - shape=(1, self.n_support_), - dtype=self.dtype, - order='F') + if model_f.n_support > 0: + self.dual_coef_ = CumlArray( + data=model_f.dual_coefs, + shape=(1, self.n_support_), + dtype=self.dtype, + order='F') - self.support_ = CumlArray( - data=model_f.support_idx, - shape=(self.n_support_,), - dtype=np.int32, - order='F') + self.support_ = CumlArray( + data=model_f.support_idx, + shape=(self.n_support_,), + dtype=np.int32, + order='F') + + self.support_vectors_ = CumlArray( + data=model_f.x_support, + shape=(self.n_support_, self.n_cols), + dtype=self.dtype, + order='F') - self.support_vectors_ = CumlArray( - data=model_f.x_support, - shape=(self.n_support_, self.n_cols), - dtype=self.dtype, - order='F') self.n_classes_ = model_f.n_classes if self.n_classes_ > 0: self._unique_labels_ = CumlArray( @@ -463,29 +464,28 @@ class SVMBase(Base, self._unique_labels_ = None else: model_d = self._model - if model_d.n_support == 0: - self._fit_status_ = 1 # incorrect fit - return self._intercept_ = CumlArray.full(1, model_d.b, np.float64) self.n_support_ = model_d.n_support - self.dual_coef_ = CumlArray( - data=model_d.dual_coefs, - shape=(1, self.n_support_), - dtype=self.dtype, - order='F') + if model_d.n_support > 0: + self.dual_coef_ = CumlArray( + data=model_d.dual_coefs, + shape=(1, self.n_support_), + dtype=self.dtype, + order='F') - self.support_ = CumlArray( - data=model_d.support_idx, - shape=(self.n_support_,), - dtype=np.int32, - order='F') + self.support_ = CumlArray( + data=model_d.support_idx, + shape=(self.n_support_,), + dtype=np.int32, + order='F') + + self.support_vectors_ = CumlArray( + data=model_d.x_support, + shape=(self.n_support_, self.n_cols), + dtype=self.dtype, + order='F') - self.support_vectors_ = CumlArray( - data=model_d.x_support, - shape=(self.n_support_, self.n_cols), - dtype=self.dtype, - order='F') self.n_classes_ = model_d.n_classes if self.n_classes_ > 0: self._unique_labels_ = CumlArray( @@ -496,6 +496,24 @@ class SVMBase(Base, else: self._unique_labels_ = None + if self.n_support_ == 0: + self.dual_coef_ = CumlArray.empty( + shape=(1, 0), + dtype=self.dtype, + order='F') + + self.support_ = CumlArray.empty( + shape=(0,), + dtype=np.int32, + order='F') + + # Setting all dims to zero due to issue + # https://github.com/rapidsai/cuml/issues/4095 + self.support_vectors_ = CumlArray.empty( + shape=(0, 0), + dtype=self.dtype, + order='F') + def predict(self, X, predict_class, convert_dtype=True) -> CumlArray: """ Predicts the y for X, where y is either the decision function value diff --git a/python/cuml/test/test_svm.py b/python/cuml/test/test_svm.py index fa18b37f87..a55b841c6d 100644 --- a/python/cuml/test/test_svm.py +++ b/python/cuml/test/test_svm.py @@ -14,6 +14,7 @@ # import pytest +import cupy as cp import numpy as np from numba import cuda @@ -681,3 +682,24 @@ def test_svm_predict_convert_dtype(train_dtype, test_dtype, classifier): clf = cu_svm.SVR() clf.fit(X_train, y_train) clf.predict(X_test.astype(test_dtype)) + + +def test_svm_no_support_vectors(): + n_rows = 10 + n_cols = 3 + X = cp.random.uniform(size=(n_rows, n_cols), dtype=cp.float64) + y = cp.ones((n_rows, 1)) + model = cuml.svm.SVR(kernel="linear", C=10) + model.fit(X, y) + pred = model.predict(X) + + assert array_equal(pred, y, 0) + + assert model.n_support_ == 0 + assert abs(model.intercept_ - 1) <= 1e-6 + assert array_equal(model.coef_, cp.zeros((1, n_cols))) + assert model.dual_coef_.shape == (1, 0) + assert model.support_.shape == (0,) + assert model.support_vectors_.shape[0] == 0 + # Check disabled due to https://github.com/rapidsai/cuml/issues/4095 + # assert model.support_vectors_.shape[1] == n_cols