Fix SVM model parameter handling in case n_support=0 #4097

Merged
18 changes: 13 additions & 5 deletions cpp/src/svm/results.cuh
@@ -117,17 +117,17 @@ class Results {
  {
    CombineCoefs(alpha, val_tmp.data());
    GetDualCoefs(val_tmp.data(), dual_coefs, n_support);
+   *b = CalcB(alpha, f, *n_support);
+   if (*n_support > 0) {
      *idx       = GetSupportVectorIndices(val_tmp.data(), *n_support);
      *x_support = CollectSupportVectors(*idx, *n_support);
-   *b = CalcB(alpha, f);
-   // Make sure that all pending GPU calculations finished before we return
-   CUDA_CHECK(cudaStreamSynchronize(stream));
+   } else {
+     *dual_coefs = nullptr;
+     *idx        = nullptr;
+     *x_support  = nullptr;
+   }
+   // Make sure that all pending GPU calculations finished before we return
+   CUDA_CHECK(cudaStreamSynchronize(stream));
  }

/**
@@ -192,6 +192,7 @@ class Results {
    *n_support  = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data());
    *dual_coefs = (math_t*)allocator->allocate(*n_support * sizeof(math_t), stream);
    raft::copy(*dual_coefs, val_selected.data(), *n_support, stream);
+   CUDA_CHECK(cudaStreamSynchronize(stream));
  }

/**
@@ -218,15 +219,22 @@
   * @param [in] f optimality indicator vector, size [n_rows]
   * @return the value of b
   */
- math_t CalcB(const math_t* alpha, const math_t* f)
+ math_t CalcB(const math_t* alpha, const math_t* f, int n_support)
  {
+   if (n_support == 0) {
+     math_t f_sum;
+     cub::DeviceReduce::Sum(
+       cub_storage.data(), cub_bytes, f, d_val_reduced.data(), n_train, stream);
+     raft::update_host(&f_sum, d_val_reduced.data(), 1, stream);
+     return -f_sum / n_train;
+   }
    // We know that for an unbound support vector i, the decision function
    // (before taking the sign) has value F(x_i) = y_i, where
    // F(x_i) = \sum_j y_j \alpha_j K(x_j, x_i) + b, and j runs through all
    // support vectors. The constant b can be expressed from these formulas.
    // Note that F and f denote different quantities. The lower case f is the
    // optimality indicator vector defined as
-   // f_i = y_i - \sum_j y_j \alpha_j K(x_j, x_i).
+   // f_i = -y_i + \sum_j y_j \alpha_j K(x_j, x_i).
    // For unbound support vectors f_i = -b.

    // Select f for unbound support vectors (0 < alpha < C)
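A side note on the new zero-support branch in CalcB, derived from the comments in the hunk above (my reading, not text from the PR): when n_support = 0, every dual coefficient \alpha_j is zero, so the decision function degenerates to the constant F(x) = b and the indicator f_i carries no kernel term. The branch therefore returns the negated mean of f over the whole training set:

    b = -\frac{1}{n_{\mathrm{train}}} \sum_{i=1}^{n_{\mathrm{train}}} f_i

For epsilon-SVR, f is initialized as (\epsilon - y_1, \ldots, \epsilon - y_n, -\epsilon - y_1, \ldots, -\epsilon - y_n) with n_{\mathrm{train}} = 2 n_{\mathrm{rows}} (see the smosolver.cuh hunk below), so the \pm\epsilon terms cancel and b reduces to the mean of the targets. That matches the new svc_test.cu case further down: constant targets y_i = 2 yield b = 2 with zero support vectors.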
2 changes: 1 addition & 1 deletion cpp/src/svm/smosolver.cuh
@@ -371,7 +371,7 @@ class SmoSolver {
    raft::linalg::unaryOp(
      f, yr, n_rows, [epsilon] __device__(math_t y) { return epsilon - y; }, stream);

-   // f_i = epsilon - y_i, for i \in [n_rows..2*n_rows-1]
+   // f_i = -epsilon - y_i, for i \in [n_rows..2*n_rows-1]
    raft::linalg::unaryOp(
      f + n_rows, yr, n_rows, [epsilon] __device__(math_t y) { return -epsilon - y; }, stream);
  }
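As a cross-check of the corrected comment, a minimal NumPy sketch of the same initialization (illustrative only; `init_f_svr` is a made-up name, the real work is done by the two raft::linalg::unaryOp calls above):

```python
import numpy as np

def init_f_svr(yr: np.ndarray, epsilon: float) -> np.ndarray:
    """Initial optimality indicator f for epsilon-SVR.

    The solver trains on 2*n_rows rows that stack the upper and lower
    epsilon-tube constraints:
        f_i =  epsilon - y_i   for i in [0, n_rows)
        f_i = -epsilon - y_i   for i in [n_rows, 2*n_rows)
    """
    return np.concatenate([epsilon - yr, -epsilon - yr])
```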
5 changes: 4 additions & 1 deletion cpp/src/svm/svc_impl.cuh
@@ -122,6 +122,9 @@ void svcPredict(const raft::handle_t& handle,
    MLCommon::device_buffer<math_t> K(
      handle_impl.get_device_allocator(), stream, n_batch * model.n_support);
    MLCommon::device_buffer<math_t> y(handle_impl.get_device_allocator(), stream, n_rows);
+   if (model.n_support == 0) {
+     CUDA_CHECK(cudaMemsetAsync(y.data(), 0, n_rows * sizeof(math_t), stream));
+   }
    MLCommon::device_buffer<math_t> x_rbf(handle_impl.get_device_allocator(), stream);
    MLCommon::device_buffer<int> idx(handle_impl.get_device_allocator(), stream);

@@ -137,7 +140,7 @@
    // We process the input data batchwise:
    //  - calculate the kernel values K[x_batch, x_support]
    //  - calculate y(x_batch) = K[x_batch, x_support] * dual_coeffs
-   for (int i = 0; i < n_rows; i += n_batch) {
+   for (int i = 0; i < n_rows && model.n_support > 0; i += n_batch) {
      if (i + n_batch >= n_rows) { n_batch = n_rows - i; }
      math_t* x_ptr = nullptr;
      int ld1       = 0;
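For orientation, the predict path after this change, condensed into NumPy (a sketch under assumptions: the names are placeholders, the real code evaluates K batch-wise on the GPU, and I am assuming the intercept b is applied after the loop, as the zero-initialized y buffer suggests):

```python
import numpy as np

def svc_predict_sketch(X, x_support, dual_coefs, b, kernel):
    """Decision values y(X) = K[X, x_support] @ dual_coefs + b."""
    y = np.zeros(X.shape[0])      # mirrors the new cudaMemsetAsync path
    if dual_coefs.size > 0:       # loop guard: model.n_support > 0
        K = kernel(X, x_support)  # shape [n_rows, n_support]
        y = K @ dual_coefs
    return y + b                  # with no support vectors: constant b
```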
11 changes: 10 additions & 1 deletion cpp/test/sg/svc_test.cu
@@ -1448,7 +1448,16 @@ class SvrTest : public ::testing::Test {
       {1, 1, 1, 10, 2, 10, 1}  // sample weights
     },
     smoOutput2<math_t>{
-      6, {}, -15.5, {3.9}, {1.0, 2.0, 3.0, 4.0, 6.0, 7.0}, {0, 1, 2, 3, 5, 6}, {}}}};
+      6, {}, -15.5, {3.9}, {1.0, 2.0, 3.0, 4.0, 6.0, 7.0}, {0, 1, 2, 3, 5, 6}, {}}},
+    {SvrInput<math_t>{
+       svmParameter{1, 0, 100, 10, 1e-6, CUML_LEVEL_INFO, 0.1, EPSILON_SVR},
+       KernelParams{LINEAR, 3, 1, 0},
+       7,                      // n_rows
+       1,                      // n_cols
+       {1, 2, 3, 4, 5, 6, 7},  // x
+       {2, 2, 2, 2, 2, 2, 2}   // y
+     },
+     smoOutput2<math_t>{0, {}, 2, {}, {}, {}, {}}}};
  for (auto d : data) {
    auto p   = d.first;
    auto exp = d.second;
90 changes: 54 additions & 36 deletions python/cuml/svm/svm_base.pyx
@@ -311,6 +311,8 @@ class SVMBase(Base,
        return self.gamma

    def _calc_coef(self):
+       if (self.n_support_ == 0):
+           return cupy.zeros((1, self.n_cols), dtype=self.dtype)
        with using_output_type("cupy"):
            return cupy.dot(self.dual_coef_, self.support_vectors_)

@@ -429,29 +431,28 @@ class SVMBase(Base,

        if self.dtype == np.float32:
            model_f = <svmModel[float]*><uintptr_t> self._model
-           if model_f.n_support == 0:
-               self._fit_status_ = 1  # incorrect fit
-               return
            self._intercept_ = CumlArray.full(1, model_f.b, np.float32)
            self.n_support_ = model_f.n_support

-           self.dual_coef_ = CumlArray(
-               data=<uintptr_t>model_f.dual_coefs,
-               shape=(1, self.n_support_),
-               dtype=self.dtype,
-               order='F')
+           if model_f.n_support > 0:
+               self.dual_coef_ = CumlArray(
+                   data=<uintptr_t>model_f.dual_coefs,
+                   shape=(1, self.n_support_),
+                   dtype=self.dtype,
+                   order='F')

-           self.support_ = CumlArray(
-               data=<uintptr_t>model_f.support_idx,
-               shape=(self.n_support_,),
-               dtype=np.int32,
-               order='F')
+               self.support_ = CumlArray(
+                   data=<uintptr_t>model_f.support_idx,
+                   shape=(self.n_support_,),
+                   dtype=np.int32,
+                   order='F')

-           self.support_vectors_ = CumlArray(
-               data=<uintptr_t>model_f.x_support,
-               shape=(self.n_support_, self.n_cols),
-               dtype=self.dtype,
-               order='F')
-
+               self.support_vectors_ = CumlArray(
+                   data=<uintptr_t>model_f.x_support,
+                   shape=(self.n_support_, self.n_cols),
+                   dtype=self.dtype,
+                   order='F')
            self.n_classes_ = model_f.n_classes
            if self.n_classes_ > 0:
                self._unique_labels_ = CumlArray(
@@ -463,29 +464,28 @@
                self._unique_labels_ = None
        else:
            model_d = <svmModel[double]*><uintptr_t> self._model
-           if model_d.n_support == 0:
-               self._fit_status_ = 1  # incorrect fit
-               return
            self._intercept_ = CumlArray.full(1, model_d.b, np.float64)
            self.n_support_ = model_d.n_support

-           self.dual_coef_ = CumlArray(
-               data=<uintptr_t>model_d.dual_coefs,
-               shape=(1, self.n_support_),
-               dtype=self.dtype,
-               order='F')
+           if model_d.n_support > 0:
+               self.dual_coef_ = CumlArray(
+                   data=<uintptr_t>model_d.dual_coefs,
+                   shape=(1, self.n_support_),
+                   dtype=self.dtype,
+                   order='F')

-           self.support_ = CumlArray(
-               data=<uintptr_t>model_d.support_idx,
-               shape=(self.n_support_,),
-               dtype=np.int32,
-               order='F')
+               self.support_ = CumlArray(
+                   data=<uintptr_t>model_d.support_idx,
+                   shape=(self.n_support_,),
+                   dtype=np.int32,
+                   order='F')

-           self.support_vectors_ = CumlArray(
-               data=<uintptr_t>model_d.x_support,
-               shape=(self.n_support_, self.n_cols),
-               dtype=self.dtype,
-               order='F')
-
+               self.support_vectors_ = CumlArray(
+                   data=<uintptr_t>model_d.x_support,
+                   shape=(self.n_support_, self.n_cols),
+                   dtype=self.dtype,
+                   order='F')
            self.n_classes_ = model_d.n_classes
            if self.n_classes_ > 0:
                self._unique_labels_ = CumlArray(
@@ -496,6 +496,24 @@ class SVMBase(Base,
            else:
                self._unique_labels_ = None

+       if self.n_support_ == 0:
+           self.dual_coef_ = CumlArray.empty(
+               shape=(1, 0),
+               dtype=self.dtype,
+               order='F')
+
+           self.support_ = CumlArray.empty(
+               shape=(0,),
+               dtype=np.int32,
+               order='F')
+
+           # Setting all dims to zero due to issue
+           # https://github.com/rapidsai/cuml/issues/4095
+           self.support_vectors_ = CumlArray.empty(
+               shape=(0, 0),
+               dtype=self.dtype,
+               order='F')

    def predict(self, X, predict_class, convert_dtype=True) -> CumlArray:
        """
        Predicts the y for X, where y is either the decision function value
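Taken together, the Python layer now reports a degenerate fit through empty, correctly shaped attributes instead of setting a failure status. A usage sketch of the expected shapes (mirroring the new test below; the commented values are what the asserts imply, not a captured session):

```python
import cupy as cp
from cuml.svm import SVR

X = cp.random.uniform(size=(10, 3), dtype=cp.float64)
y = cp.ones(10)  # constant target: every row fits inside the epsilon tube

model = SVR(kernel="linear", C=10).fit(X, y)

model.n_support_              # 0
model.intercept_              # ~1.0, the mean of y
model.dual_coef_.shape        # (1, 0)
model.support_.shape          # (0,)
model.support_vectors_.shape  # (0, 0) until issue 4095 is fixed
model.coef_                   # (1, 3) of zeros, via _calc_coef
```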
22 changes: 22 additions & 0 deletions python/cuml/test/test_svm.py
@@ -14,6 +14,7 @@
#

import pytest
+import cupy as cp
import numpy as np
from numba import cuda

@@ -681,3 +682,24 @@ def test_svm_predict_convert_dtype(train_dtype, test_dtype, classifier):
    clf = cu_svm.SVR()
    clf.fit(X_train, y_train)
    clf.predict(X_test.astype(test_dtype))
+
+
+def test_svm_no_support_vectors():
+    n_rows = 10
+    n_cols = 3
+    X = cp.random.uniform(size=(n_rows, n_cols), dtype=cp.float64)
+    y = cp.ones((n_rows, 1))
+    model = cuml.svm.SVR(kernel="linear", C=10)
+    model.fit(X, y)
+    pred = model.predict(X)
+
+    assert array_equal(pred, y, 0)
+
+    assert model.n_support_ == 0
+    assert abs(model.intercept_ - 1) <= 1e-6
+    assert array_equal(model.coef_, cp.zeros((1, n_cols)))
+    assert model.dual_coef_.shape == (1, 0)
+    assert model.support_.shape == (0,)
+    assert model.support_vectors_.shape[0] == 0
+    # Check disabled due to https://github.com/rapidsai/cuml/issues/4095
+    # assert model.support_vectors_.shape[1] == n_cols