From e23167ced25e4f665ed2577467b74eabaf02f723 Mon Sep 17 00:00:00 2001
From: Jinfeng Li
Date: Mon, 24 Jul 2023 13:42:11 -0700
Subject: [PATCH] Add multi-node-multi-gpu Logistic Regression in C++ (#5477)

This PR enables multi-node-multi-gpu Logistic Regression. It mostly reuses existing code (i.e. GLMWithData and min_lbfgs) from single-GPU Logistic Regression and makes no changes to any existing code.

Added pytest code for a Spark cluster; the tests run successfully with 2 GPUs on a random dataset, and the resulting coef_ and intercept_ are the same as those of single-GPU cuml.LogisticRegression.fit. The pytest code can be found here:
https://github.com/lijinf2/spark-rapids-ml/blob/lr/python/tests/test_logistic_regression.py

Authors:
  - Jinfeng Li (https://github.com/lijinf2)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: https://github.com/rapidsai/cuml/pull/5477
---
 cpp/CMakeLists.txt                            |   1 +
 cpp/include/cuml/linear_model/qn_mg.hpp       |  56 +++++
 cpp/src/glm/qn/glm_base_mg.cuh                | 166 +++++++++++++++
 cpp/src/glm/qn_mg.cu                          | 157 ++++++++++++++
 python/cuml/dask/common/base.py               |   9 +-
 python/cuml/dask/linear_model/__init__.py     |   1 +
 .../dask/linear_model/logistic_regression.py  |  75 +++++++
 python/cuml/linear_model/CMakeLists.txt       |   1 +
 python/cuml/linear_model/base_mg.pyx          |   4 +-
 .../linear_model/logistic_regression_mg.pyx   | 199 ++++++++++++++++++
 .../dask/test_dask_logistic_regression.py     |  69 ++++++
 11 files changed, 732 insertions(+), 6 deletions(-)
 create mode 100644 cpp/include/cuml/linear_model/qn_mg.hpp
 create mode 100644 cpp/src/glm/qn/glm_base_mg.cuh
 create mode 100644 cpp/src/glm/qn_mg.cu
 create mode 100644 python/cuml/dask/linear_model/logistic_regression.py
 create mode 100644 python/cuml/linear_model/logistic_regression_mg.pyx

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 354d66d59c..8ce0451602 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -524,6 +524,7 @@ if(BUILD_CUML_CPP_LIBRARY)
     src/glm/ols_mg.cu
     src/glm/preprocess_mg.cu
     src/glm/ridge_mg.cu
+    src/glm/qn_mg.cu
     src/kmeans/kmeans_mg.cu
     src/knn/knn_mg.cu
     src/knn/knn_classify_mg.cu
diff --git a/cpp/include/cuml/linear_model/qn_mg.hpp b/cpp/include/cuml/linear_model/qn_mg.hpp
new file mode 100644
index 0000000000..89a79f0677
--- /dev/null
+++ b/cpp/include/cuml/linear_model/qn_mg.hpp
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+using namespace MLCommon;
+
+namespace ML {
+namespace GLM {
+namespace opg {
+
+/**
+ * @brief Performs the MNMG fit operation for logistic regression using quasi-Newton methods
+ * @param[in] handle: the internal cuml handle object
+ * @param[in] input_data: vector holding all partitions for that rank
+ * @param[in] input_desc: PartDescriptor object for the input
+ * @param[in] labels: labels data
+ * @param[out] coef: learned coefficients
+ * @param[in] pams: model parameters
+ * @param[in] X_col_major: true if X is stored column-major
+ * @param[in] n_classes: number of outputs (number of classes or `1` for regression)
+ * @param[out] f: host pointer holding the final objective value
+ * @param[out] num_iters: host pointer holding the actual number of iterations taken
+ */
+void qnFit(raft::handle_t& handle,
+           std::vector<Matrix::Data<float>*>& input_data,
+           Matrix::PartDescriptor& input_desc,
+           std::vector<Matrix::Data<float>*>& labels,
+           float* coef,
+           const qn_params& pams,
+           bool X_col_major,
+           int n_classes,
+           float* f,
+           int* num_iters);
+
+};  // namespace opg
+};  // namespace GLM
+};  // namespace ML
diff --git a/cpp/src/glm/qn/glm_base_mg.cuh b/cpp/src/glm/qn/glm_base_mg.cuh
new file mode 100644
index 0000000000..1304ddaf60
--- /dev/null
+++ b/cpp/src/glm/qn/glm_base_mg.cuh
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace ML {
+namespace GLM {
+namespace opg {
+template <typename T>
+// multi-GPU version of linearBwd
+inline void linearBwdMG(const raft::handle_t& handle,
+                        SimpleDenseMat<T>& G,
+                        const SimpleMat<T>& X,
+                        const SimpleDenseMat<T>& dZ,
+                        bool setZero,
+                        const int64_t n_samples,
+                        const int n_ranks)
+{
+  cudaStream_t stream = handle.get_stream();
+  // Backward pass:
+  // - compute G <- dZ * X.T
+  // - for bias: Gb = mean(dZ, 1)
+
+  const bool has_bias = X.n != G.n;
+  const int D         = X.n;
+  const T beta        = setZero ? T(0) : T(1);
+
+  if (has_bias) {
+    SimpleVec<T> Gbias;
+    SimpleDenseMat<T> Gweights;
+
+    col_ref(G, Gbias, D);
+
+    col_slice(G, Gweights, 0, D);
+
+    // TODO: can this be fused somehow?
+    Gweights.assign_gemm(handle, 1.0 / n_samples, dZ, false, X, false, beta / n_ranks, stream);
+
+    raft::stats::mean(Gbias.data, dZ.data, dZ.m, dZ.n, false, true, stream);
+    T bias_factor = 1.0 * dZ.n / n_samples;
+    raft::linalg::multiplyScalar(Gbias.data, Gbias.data, bias_factor, dZ.m, stream);
+
+  } else {
+    CUML_LOG_DEBUG("has bias not enabled");
+    G.assign_gemm(handle, 1.0 / n_samples, dZ, false, X, false, beta / n_ranks, stream);
+  }
+}
+
+/**
+ * @brief Aggregates local gradient vectors and loss values from local training data. This
+ * class is the multi-node-multi-gpu version of GLMWithData.
+ *
+ * The implementation overrides the existing GLMWithData::operator(). The purpose is to
+ * aggregate local gradient vectors and loss values from distributed X, y, where X represents
+ * the input vectors and y represents the labels.
+ *
+ * GLMWithData::operator() currently invokes three functions: linearFwd, getLossAndDz and
+ * linearBwd. linearFwd multiplies the local input vectors with the coefficient vector (i.e.
+ * coef_) and therefore requires no communication. getLossAndDz computes the local loss, which
+ * requires an allreduce to obtain the global loss. linearBwd computes the local gradient
+ * vector, which likewise requires an allreduce to obtain the global gradient vector. The
+ * global loss and the global gradient vector are used in min_lbfgs to update the coefficients.
+ * The update runs independently on every GPU, and when it finishes all GPUs hold the same
+ * coefficient values.
+ */
+template <typename T, class GLMObjective>
+struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
+  const raft::handle_t* handle_p;
+  int rank;
+  int64_t n_samples;
+  int n_ranks;
+
+  GLMWithDataMG(raft::handle_t const& handle,
+                int rank,
+                int n_ranks,
+                int64_t n_samples,
+                GLMObjective* obj,
+                const SimpleMat<T>& X,
+                const SimpleVec<T>& y,
+                SimpleDenseMat<T>& Z)
+    : ML::GLM::detail::GLMWithData<T, GLMObjective>(obj, X, y, Z)
+  {
+    this->handle_p  = &handle;
+    this->rank      = rank;
+    this->n_ranks   = n_ranks;
+    this->n_samples = n_samples;
+  }
+
+  inline T operator()(const SimpleVec<T>& wFlat,
+                      SimpleVec<T>& gradFlat,
+                      T* dev_scalar,
+                      cudaStream_t stream)
+  {
+    SimpleDenseMat<T> W(wFlat.data, this->C, this->dims);
+    SimpleDenseMat<T> G(gradFlat.data, this->C, this->dims);
+    SimpleVec<T> lossVal(dev_scalar, 1);
+
+    // apply regularization
+    auto regularizer_obj = this->objective;
+    auto lossFunc        = regularizer_obj->loss;
+    auto reg             = regularizer_obj->reg;
+    G.fill(0, stream);
+    reg->reg_grad(dev_scalar, G, W, lossFunc->fit_intercept, stream);
+    T reg_host;
+    raft::update_host(&reg_host, dev_scalar, 1, stream);
+    // note: avoid syncing here because there's a sync before reg_host is used.
+
+    // apply linearFwd, getLossAndDz, linearBwd
+    ML::GLM::detail::linearFwd(
+      lossFunc->handle, *(this->Z), *(this->X), W);  // linear part: forward pass
+
+    raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));
+
+    lossFunc->getLossAndDZ(dev_scalar, *(this->Z), *(this->y), stream);  // loss specific part
+
+    // normalize the local loss before the allreduce sum
+    T factor = 1.0 * (*this->y).len / this->n_samples;
+    raft::linalg::multiplyScalar(dev_scalar, dev_scalar, factor, 1, stream);
+
+    communicator.allreduce(dev_scalar, dev_scalar, 1, raft::comms::op_t::SUM, stream);
+    communicator.sync_stream(stream);
+
+    linearBwdMG(lossFunc->handle,
+                G,
+                *(this->X),
+                *(this->Z),
+                false,
+                n_samples,
+                n_ranks);  // linear part: backward pass
+
+    communicator.allreduce(G.data, G.data, this->C * this->dims, raft::comms::op_t::SUM, stream);
+    communicator.sync_stream(stream);
+
+    T loss_host;
+    raft::update_host(&loss_host, dev_scalar, 1, stream);
+    raft::resource::sync_stream(*(this->handle_p));
+    loss_host += reg_host;
+    // fill with loss_host only: it already includes reg_host, so adding
+    // reg_host again here would double-count the regularization term
+    lossVal.fill(loss_host, stream);
+
+    return loss_host;
+  }
+};
+};  // namespace opg
+};  // namespace GLM
+};  // namespace ML
\ No newline at end of file
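To make the aggregation scheme in GLMWithDataMG concrete: each rank scales its local loss sum by 1 / n_total (via the factor n_local / n_total applied to the local mean loss) and its local gradient by 1 / n_total, so a single allreduce(SUM) on each quantity reproduces exactly what one GPU would compute over the full dataset. The following NumPy simulation is illustrative only; it is not part of this patch and stands in for the two ranks in-process:

# Illustrative sketch, not part of this patch: simulate the per-rank
# loss/gradient scaling plus allreduce(SUM) performed by GLMWithDataMG.
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def local_loss_and_grad(X_part, y_part, w, n_total):
    # local logistic loss and gradient, pre-scaled by the GLOBAL sample count
    p = sigmoid(X_part @ w)
    loss = -np.sum(y_part * np.log(p) + (1 - y_part) * np.log(1 - p)) / n_total
    grad = X_part.T @ (p - y_part) / n_total
    return loss, grad

X = rng.normal(size=(100, 3))
w = rng.normal(size=3)
y = (rng.random(100) < sigmoid(X @ w)).astype(float)

# single-GPU reference over the full data
ref_loss, ref_grad = local_loss_and_grad(X, y, w, len(X))

# two "ranks" holding disjoint partitions; allreduce(SUM) is a plain sum here
parts = [(X[:60], y[:60]), (X[60:], y[60:])]
results = [local_loss_and_grad(Xp, yp, w, len(X)) for Xp, yp in parts]
assert np.isclose(sum(l for l, _ in results), ref_loss)
assert np.allclose(sum(g for _, g in results), ref_grad)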
diff --git a/cpp/src/glm/qn_mg.cu b/cpp/src/glm/qn_mg.cu
new file mode 100644
index 0000000000..2a20be37ae
--- /dev/null
+++ b/cpp/src/glm/qn_mg.cu
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "qn/glm_logistic.cuh"
+#include "qn/glm_regularizer.cuh"
+#include "qn/qn_util.cuh"
+#include "qn/simple_mat/dense.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+using namespace MLCommon;
+
+#include "qn/glm_base_mg.cuh"
+
+#include
+
+namespace ML {
+namespace GLM {
+namespace opg {
+
+template <typename T>
+void qnFit_impl(const raft::handle_t& handle,
+                const qn_params& pams,
+                T* X,
+                bool X_col_major,
+                T* y,
+                size_t N,
+                size_t D,
+                size_t C,
+                T* w0,
+                T* f,
+                int* num_iters,
+                size_t n_samples,
+                int rank,
+                int n_ranks)
+{
+  switch (pams.loss) {
+    case QN_LOSS_LOGISTIC: {
+      RAFT_EXPECTS(
+        C == 2,
+        "qn_mg.cu: only the LOGISTIC loss is supported currently. The number of classes must be 2");
+    } break;
+    default: {
+      RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
+    }
+  }
+
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  auto X_simple       = SimpleDenseMat<T>(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);
+  auto y_simple       = SimpleVec<T>(y, N);
+  SimpleVec<T> coef_simple(w0, D + pams.fit_intercept);
+
+  ML::GLM::detail::LBFGSParam<T> opt_param(pams);
+
+  // prepare the regularized GLM objective regularizer_obj
+  ML::GLM::detail::LogisticLoss<T> loss_func(handle, D, pams.fit_intercept);
+  T l2 = pams.penalty_l2;
+  if (pams.penalty_normalized) {
+    l2 /= n_samples;  // normalize the L2 penalty by the global sample count
+  }
+  ML::GLM::detail::Tikhonov<T> reg(l2);
+  ML::GLM::detail::RegularizedGLM<T, decltype(loss_func), decltype(reg)>
+    regularizer_obj(&loss_func, &reg);
+
+  // prepare GLMWithDataMG
+  int n_targets = C == 2 ? 1 : C;
+  rmm::device_uvector<T> tmp(n_targets * N, stream);
+  SimpleDenseMat<T> Z(tmp.data(), n_targets, N);
+  auto obj_function =
+    GLMWithDataMG(handle, rank, n_ranks, n_samples, &regularizer_obj, X_simple, y_simple, Z);
+
+  // prepare temporary variables fx, k, workspace
+  T fx  = -1;
+  int k = -1;
+  rmm::device_uvector<T> tmp_workspace(lbfgs_workspace_size(opt_param, coef_simple.len),
+                                       stream);
+  SimpleVec<T> workspace(tmp_workspace.data(), tmp_workspace.size());
+
+  // call min_lbfgs (the last argument is the verbosity level)
+  min_lbfgs(opt_param, obj_function, coef_simple, fx, &k, workspace, stream, 5);
+
+  // report the final objective value and iteration count through the
+  // documented output parameters
+  *f         = fx;
+  *num_iters = k;
+}
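+
+// The overload below is the entry point invoked from Cython: it unpacks the
+// single local (X, y) partition described by input_desc and forwards to the
+// rank-local implementation above, passing input_desc.M as the global sample
+// count used to normalize the loss and gradient.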
+template <typename T>
+void qnFit_impl(raft::handle_t& handle,
+                std::vector<Matrix::Data<T>*>& input_data,
+                Matrix::PartDescriptor& input_desc,
+                std::vector<Matrix::Data<T>*>& labels,
+                T* coef,
+                const qn_params& pams,
+                bool X_col_major,
+                int n_classes,
+                T* f,
+                int* num_iters)
+{
+  RAFT_EXPECTS(input_data.size() == 1,
+               "qn_mg.cu currently does not accept more than one input matrix");
+  RAFT_EXPECTS(labels.size() == input_data.size(), "labels size does not equal input_data size");
+
+  auto data_X = input_data[0];
+  auto data_y = labels[0];
+
+  size_t n_samples = 0;
+  for (auto p : input_desc.partsToRanks) {
+    n_samples += p->size;
+  }
+
+  qnFit_impl<T>(handle,
+                pams,
+                data_X->ptr,
+                X_col_major,
+                data_y->ptr,
+                input_desc.totalElementsOwnedBy(input_desc.rank),
+                input_desc.N,
+                n_classes,
+                coef,
+                f,
+                num_iters,
+                input_desc.M,
+                input_desc.rank,
+                input_desc.uniqueRanks().size());
+}
+
+void qnFit(raft::handle_t& handle,
+           std::vector<Matrix::Data<float>*>& input_data,
+           Matrix::PartDescriptor& input_desc,
+           std::vector<Matrix::Data<float>*>& labels,
+           float* coef,
+           const qn_params& pams,
+           bool X_col_major,
+           int n_classes,
+           float* f,
+           int* num_iters)
+{
+  qnFit_impl<float>(
+    handle, input_data, input_desc, labels, coef, pams, X_col_major, n_classes, f, num_iters);
+}
+
+};  // namespace opg
+};  // namespace GLM
+};  // namespace ML
diff --git a/python/cuml/dask/common/base.py b/python/cuml/dask/common/base.py
index 82b5025d51..bbf15098d5 100644
--- a/python/cuml/dask/common/base.py
+++ b/python/cuml/dask/common/base.py
@@ -407,12 +407,13 @@ def _fit(self, model_func, data):
             ]
         )
 
+        fit_func = self._func_fit
         lin_fit = dict(
             [
                 (
                     worker_data[0],
                     self.client.submit(
-                        _func_fit,
+                        fit_func,
                         lin_models[data.worker_info[worker_data[0]]["rank"]],
                         worker_data[1],
                         data.total_rows,
@@ -434,9 +435,9 @@ def _fit(self, model_func, data):
             comms.destroy()
 
         return lin_models
 
-
-def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
-    return f.fit(data, n_rows, n_cols, partsToSizes, rank)
+    @staticmethod
+    def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
+        return f.fit(data, n_rows, n_cols, partsToSizes, rank)
 
 
 def mnmg_import(func):
diff --git a/python/cuml/dask/linear_model/__init__.py b/python/cuml/dask/linear_model/__init__.py
index 4f8594a665..b8f9471d73 100644
--- a/python/cuml/dask/linear_model/__init__.py
+++ b/python/cuml/dask/linear_model/__init__.py
@@ -22,6 +22,7 @@
     from cuml.dask.linear_model.ridge import Ridge
     from cuml.dask.linear_model.lasso import Lasso
     from cuml.dask.linear_model.elastic_net import ElasticNet
+    from cuml.dask.linear_model.logistic_regression import LogisticRegression
 else:
     warnings.warn(
         "Dask not found. All Dask-based multi-GPU operation is disabled."
diff --git a/python/cuml/dask/linear_model/logistic_regression.py b/python/cuml/dask/linear_model/logistic_regression.py
new file mode 100644
index 0000000000..bbd66e6a58
--- /dev/null
+++ b/python/cuml/dask/linear_model/logistic_regression.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from cuml.dask.common.base import BaseEstimator +from cuml.dask.common.base import DelayedPredictionMixin +from cuml.dask.common.base import mnmg_import +from cuml.dask.common.base import SyncFitMixinLinearModel +from raft_dask.common.comms import get_raft_comm_state +from dask.distributed import get_worker + +from cuml.dask.common import parts_to_ranks +from cuml.dask.common.input_utils import DistributedDataHandler, concatenate +from raft_dask.common.comms import Comms +from cuml.dask.common.utils import wait_and_raise_from_futures +from cuml.internals.safe_imports import cpu_only_import +from cuml.internals.safe_imports import gpu_only_import + +cp = gpu_only_import("cupy") +np = cpu_only_import("numpy") + + +class LogisticRegression(BaseEstimator, SyncFitMixinLinearModel): + def __init__(self, *, client=None, verbose=False, **kwargs): + super().__init__(client=client, verbose=verbose, **kwargs) + + def fit(self, X, y): + """ + Fit the model with X and y. + + Parameters + ---------- + X : Dask cuDF dataframe or CuPy backed Dask Array (n_rows, n_features) + Features for regression + y : Dask cuDF dataframe or CuPy backed Dask Array (n_rows, 1) + Labels (outcome values) + """ + + models = self._fit( + model_func=LogisticRegression._create_model, data=(X, y) + ) + + self._set_internal_model(models[0]) + + return self + + def get_param_names(self): + return list(self.kwargs.keys()) + + @staticmethod + @mnmg_import + def _create_model(sessionId, datatype, **kwargs): + from cuml.linear_model.logistic_regression_mg import ( + LogisticRegressionMG, + ) + + handle = get_raft_comm_state(sessionId, get_worker())["handle"] + return LogisticRegressionMG(handle=handle) + + @staticmethod + def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank): + inp_X = concatenate([X for X, _ in data]) + inp_y = concatenate([y for _, y in data]) + return f.fit([(inp_X, inp_y)], n_rows, n_cols, partsToSizes, rank) diff --git a/python/cuml/linear_model/CMakeLists.txt b/python/cuml/linear_model/CMakeLists.txt index 49fdc68fb4..cd5bc60da0 100644 --- a/python/cuml/linear_model/CMakeLists.txt +++ b/python/cuml/linear_model/CMakeLists.txt @@ -27,6 +27,7 @@ if(NOT SINGLEGPU) list(APPEND cython_sources base_mg.pyx linear_regression_mg.pyx + logistic_regression_mg.pyx ridge_mg.pyx ) diff --git a/python/cuml/linear_model/base_mg.pyx b/python/cuml/linear_model/base_mg.pyx index 660a6f2551..c13d0d2de1 100644 --- a/python/cuml/linear_model/base_mg.pyx +++ b/python/cuml/linear_model/base_mg.pyx @@ -34,7 +34,7 @@ from cuml.decomposition.utils cimport * class MGFitMixin(object): @cuml.internals.api_base_return_any_skipall - def fit(self, input_data, n_rows, n_cols, partsToSizes, rank): + def fit(self, input_data, n_rows, n_cols, partsToSizes, rank, order='F'): """ Fit function for MNMG linear regression classes This not meant to be used as @@ -58,7 +58,7 @@ class MGFitMixin(object): check_dtype = self.dtype X_m, _, self.n_cols, _ = \ - input_to_cuml_array(input_data[i][0], 
-                                check_dtype=check_dtype)
+            input_to_cuml_array(input_data[i][0],
+                                check_dtype=check_dtype, order=order)
 
         X_arys.append(X_m)
 
         if i == 0:
diff --git a/python/cuml/linear_model/logistic_regression_mg.pyx b/python/cuml/linear_model/logistic_regression_mg.pyx
new file mode 100644
index 0000000000..83eb915186
--- /dev/null
+++ b/python/cuml/linear_model/logistic_regression_mg.pyx
@@ -0,0 +1,199 @@
+#
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# distutils: language = c++
+
+from cuml.internals.safe_imports import gpu_only_import
+cp = gpu_only_import('cupy')
+from cuml.internals.safe_imports import cpu_only_import
+np = cpu_only_import('numpy')
+
+from libcpp cimport bool
+from libc.stdint cimport uintptr_t
+
+from cuml.common import input_to_cuml_array
+import numpy as np
+
+import cuml.internals
+from cuml.internals.array import CumlArray
+from cuml.linear_model.base_mg import MGFitMixin
+from cuml.linear_model import LogisticRegression
+from cuml.solvers.qn import QNParams
+from cython.operator cimport dereference as deref
+
+from pylibraft.common.handle cimport handle_t
+from cuml.common.opg_data_utils_mg cimport *
+
+# this cdef block is copied from cuml.linear_model.qn
+cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM" nogil:
+
+    # TODO: Use single-GPU version qn_loss_type and qn_params https://github.com/rapidsai/cuml/issues/5502
+    cdef enum qn_loss_type "ML::GLM::qn_loss_type":
+        QN_LOSS_LOGISTIC "ML::GLM::QN_LOSS_LOGISTIC"
+        QN_LOSS_SQUARED "ML::GLM::QN_LOSS_SQUARED"
+        QN_LOSS_SOFTMAX "ML::GLM::QN_LOSS_SOFTMAX"
+        QN_LOSS_SVC_L1 "ML::GLM::QN_LOSS_SVC_L1"
+        QN_LOSS_SVC_L2 "ML::GLM::QN_LOSS_SVC_L2"
+        QN_LOSS_SVR_L1 "ML::GLM::QN_LOSS_SVR_L1"
+        QN_LOSS_SVR_L2 "ML::GLM::QN_LOSS_SVR_L2"
+        QN_LOSS_ABS "ML::GLM::QN_LOSS_ABS"
+        QN_LOSS_UNKNOWN "ML::GLM::QN_LOSS_UNKNOWN"
+
+    cdef struct qn_params:
+        qn_loss_type loss
+        double penalty_l1
+        double penalty_l2
+        double grad_tol
+        double change_tol
+        int max_iter
+        int linesearch_max_iter
+        int lbfgs_memory
+        int verbose
+        bool fit_intercept
+        bool penalty_normalized
+
+
+cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:
+
+    cdef void qnFit(
+        handle_t& handle,
+        vector[floatData_t *] input_data,
+        PartDescriptor &input_desc,
+        vector[floatData_t *] labels,
+        float *coef,
+        const qn_params& pams,
+        bool X_col_major,
+        int n_classes,
+        float *f,
+        int *num_iters) except +
+
+
+class LogisticRegressionMG(MGFitMixin, LogisticRegression):
+
+    def __init__(self, *, handle=None):
+        super().__init__(handle=handle)
+
+    @property
+    @cuml.internals.api_base_return_array_skipall
+    def coef_(self):
+        return self.solver_model.coef_
+
+    @coef_.setter
+    def coef_(self, value):
+        # MGFitMixin assigns a 1-D array of length self.n_cols to self.coef_,
+        # so convert a 1-D value to 2-D here
+        if len(value.shape) == 1:
+            new_shape = (1, value.shape[0])
+            cp_array = value.to_output('array').reshape(new_shape)
+            value, _, _, _ = input_to_cuml_array(cp_array, order='K')
+        if (self.fit_intercept) and (self.solver_model.intercept_ is None):
+            self.solver_model.intercept_ = CumlArray.zeros(shape=(1, 1), dtype=value.dtype)
+
+        self.solver_model.coef_ = value
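+
+    # prepare_for_fit below mirrors the parameter validation in
+    # cuml.solvers.qn: it builds the QNParams struct and sizes the
+    # coefficient array as (n_cols plus 1 for the intercept, _num_classes_dim).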
+
+    def prepare_for_fit(self, n_classes):
+        self.qnparams = QNParams(
+            loss=self.loss,
+            penalty_l1=self.l1_strength,
+            penalty_l2=self.l2_strength,
+            grad_tol=self.tol,
+            change_tol=self.delta
+            if self.delta is not None else (self.tol * 0.01),
+            max_iter=self.max_iter,
+            linesearch_max_iter=self.linesearch_max_iter,
+            lbfgs_memory=self.lbfgs_memory,
+            verbose=self.verbose,
+            fit_intercept=self.fit_intercept,
+            penalty_normalized=self.penalty_normalized
+        )
+
+        # the validation below is adapted from cuml.solvers.qn
+        qnpams = self.qnparams.params
+
+        solves_classification = qnpams['loss'] in {
+            qn_loss_type.QN_LOSS_LOGISTIC,
+            qn_loss_type.QN_LOSS_SOFTMAX,
+            qn_loss_type.QN_LOSS_SVC_L1,
+            qn_loss_type.QN_LOSS_SVC_L2
+        }
+        solves_multiclass = qnpams['loss'] in {
+            qn_loss_type.QN_LOSS_SOFTMAX
+        }
+
+        if solves_classification:
+            self._num_classes = n_classes
+        else:
+            self._num_classes = 1
+
+        if not solves_multiclass and self._num_classes > 2:
+            raise ValueError(
+                f"The selected solver ({self.loss}) does not support"
+                f" more than 2 classes ({self._num_classes} discovered).")
+
+        if qnpams['loss'] == qn_loss_type.QN_LOSS_SOFTMAX \
+                and self._num_classes <= 2:
+            raise ValueError("Two classes or less cannot be trained "
+                             "with softmax (multinomial).")
+
+        if solves_classification and not solves_multiclass:
+            self._num_classes_dim = self._num_classes - 1
+        else:
+            self._num_classes_dim = self._num_classes
+
+        if self.fit_intercept:
+            coef_size = (self.n_cols + 1, self._num_classes_dim)
+        else:
+            coef_size = (self.n_cols, self._num_classes_dim)
+
+        if self.coef_ is None or not self.warm_start:
+            self.solver_model._coef_ = CumlArray.zeros(
+                coef_size, dtype=self.dtype, order='C')
+
+    def fit(self, input_data, n_rows, n_cols, parts_rank_size, rank, convert_dtype=False):
+
+        assert len(input_data) == 1, f"Currently supports only one (X, y) pair in the list. Received {len(input_data)} pairs."
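+        # X partitions are converted to row-major ('C') order below, and
+        # X_col_major=False is passed through to the C++ qnFit accordingly.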
+ self.is_col_major = False + order = 'F' if self.is_col_major else 'C' + super().fit(input_data, n_rows, n_cols, parts_rank_size, rank, order=order) + + @cuml.internals.api_base_return_any_skipall + def _fit(self, X, y, coef_ptr, input_desc): + cdef handle_t* handle_ = self.handle.getHandle() + cdef float objective32 + cdef int num_iters + + # TODO: calculate _num_classes at runtime + self._num_classes = 2 + self.prepare_for_fit(self._num_classes) + cdef uintptr_t mat_coef_ptr = self.coef_.ptr + + cdef qn_params qnpams = self.qnparams.params + + if self.dtype == np.float32: + qnFit( + handle_[0], + deref(X), + deref(input_desc), + deref(y), + mat_coef_ptr, + qnpams, + self.is_col_major, + self._num_classes, + &objective32, + &num_iters) + + self.solver_model._calc_intercept() + + self.handle.sync() diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py index 77a5243f78..4df34d09a0 100644 --- a/python/cuml/tests/dask/test_dask_logistic_regression.py +++ b/python/cuml/tests/dask/test_dask_logistic_regression.py @@ -141,3 +141,72 @@ def imp(): probs_cuml = cuml_model.predict_proba(gX).compute() probs_sk = sk_model.predict_proba(X)[:, 1] assert np.abs(probs_sk - probs_cuml.get()).max() <= 0.05 + + +@pytest.mark.mg +@pytest.mark.parametrize("n_parts", [2]) +@pytest.mark.parametrize("datatype", [np.float32]) +def test_lbfgs_toy(n_parts, datatype, client): + def imp(): + import cuml.comm.serialize # NOQA + + client.run(imp) + + X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], datatype) + y = np.array([1.0, 1.0, 0.0, 0.0], datatype) + + from cuml.dask.linear_model.logistic_regression import ( + LogisticRegression as cumlLBFGS_dask, + ) + + X_df, y_df = _prep_training_data(client, X, y, n_parts) + + lr = cumlLBFGS_dask() + lr.fit(X_df, y_df) + lr_coef = lr.coef_.to_numpy() + lr_intercept = lr.intercept_.to_numpy() + + assert len(lr_coef) == 1 + assert lr_coef[0] == pytest.approx([-0.71483153, 0.7148315], abs=1e-6) + assert lr_intercept == pytest.approx([-2.2614916e-08], abs=1e-6) + + +@pytest.mark.mg +@pytest.mark.parametrize("nrows", [1e5]) +@pytest.mark.parametrize("ncols", [20]) +@pytest.mark.parametrize("n_parts", [2, 23]) +@pytest.mark.parametrize("datatype", [np.float32]) +def test_lbfgs(nrows, ncols, n_parts, datatype, client): + tolerance = 0.005 + + def imp(): + import cuml.comm.serialize # NOQA + + client.run(imp) + + from cuml.dask.linear_model.logistic_regression import ( + LogisticRegression as cumlLBFGS_dask, + ) + + # set n_informative variable for calling sklearn.datasets.make_classification + n_info = 5 + nrows = int(nrows) + ncols = int(ncols) + X, y = make_classification_dataset(datatype, nrows, ncols, n_info) + + X_df, y_df = _prep_training_data(client, X, y, n_parts) + + lr = cumlLBFGS_dask() + lr.fit(X_df, y_df) + lr_coef = lr.coef_.to_numpy() + lr_intercept = lr.intercept_.to_numpy() + + sk_model = skLR() + sk_model.fit(X, y) + sk_coef = sk_model.coef_ + sk_intercept = sk_model.intercept_ + + assert len(lr_coef) == len(sk_coef) + for i in range(len(lr_coef)): + assert lr_coef[i] == pytest.approx(sk_coef[i], abs=tolerance) + assert lr_intercept == pytest.approx(sk_intercept, abs=tolerance)
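For reference, here is a minimal end-to-end usage sketch distilled from the tests above. The LocalCUDACluster setup, worker count, and chunking are assumptions for illustration and are not part of this PR; the tests build their distributed inputs through the _prep_training_data helper instead.

# Minimal usage sketch (assumes a dask-cuda environment with at least
# 2 GPU workers; the cluster setup itself is not part of this PR).
import cupy as cp
import dask.array as da
import numpy as np
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

from cuml.dask.linear_model import LogisticRegression

if __name__ == "__main__":
    with LocalCUDACluster(n_workers=2) as cluster, Client(cluster) as client:
        # the toy dataset from test_lbfgs_toy above
        X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], np.float32)
        y = np.array([1.0, 1.0, 0.0, 0.0], np.float32)

        # CuPy-backed Dask arrays, one chunk per worker
        X_da = da.from_array(X, chunks=(2, 2)).map_blocks(cp.asarray)
        y_da = da.from_array(y, chunks=(2,)).map_blocks(cp.asarray)

        lr = LogisticRegression()
        lr.fit(X_da, y_da)

        # per the PR description these match single-GPU
        # cuml.LogisticRegression.fit: roughly [-0.714832, 0.714832] and ~0.0
        print(lr.coef_, lr.intercept_)

Each worker concatenates its local partitions and hands them to the C++ qnFit through LogisticRegressionMG, as shown in the diffs above, so the printed coefficients agree with a single-GPU fit.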