From 388c1396d0ff0b18705dce7be07b49e0be3eab69 Mon Sep 17 00:00:00 2001
From: achirkin
Date: Mon, 14 Feb 2022 13:41:49 +0100
Subject: [PATCH 1/2] Add QN solver to ElasticNet and Lasso models

---
 python/cuml/linear_model/elastic_net.pyx      |  59 +++++++--
 .../cuml/linear_model/{lasso.pyx => lasso.py} | 115 +++++-------------
 python/cuml/test/test_linear_model.py         |  37 +++++-
 3 files changed, 118 insertions(+), 93 deletions(-)
 rename python/cuml/linear_model/{lasso.pyx => lasso.py} (61%)

diff --git a/python/cuml/linear_model/elastic_net.pyx b/python/cuml/linear_model/elastic_net.pyx
index ef187c5cf7..d6bfa7c3cd 100644
--- a/python/cuml/linear_model/elastic_net.pyx
+++ b/python/cuml/linear_model/elastic_net.pyx
@@ -16,11 +16,15 @@
 
 # distutils: language = c++
 
-from cuml.solvers import CD
+from inspect import signature
+
+from cuml.solvers import CD, QN
 from cuml.common.base import Base
 from cuml.common.mixins import RegressorMixin
 from cuml.common.doc_utils import generate_docstring
+from cuml.common.array import CumlArray
 from cuml.common.array_descriptor import CumlArrayDescriptor
+from cuml.common.logger import warn
 from cuml.common.mixins import FMajorInputTagMixin
 from cuml.linear_model.base import LinearPredictMixin
 
@@ -117,6 +121,14 @@ class ElasticNet(Base,
         The tolerance for the optimization: if the updates are
         smaller than tol, the optimization code checks the dual gap for
         optimality and continues until it is smaller than tol.
+    solver : {'cd', 'qn'} (default='cd')
+        Choose an algorithm:
+
+          * 'cd' - coordinate descent
+          * 'qn' - quasi-Newton
+
+        The alternative 'qn' algorithm may be faster when the number of
+        features is large but the number of samples is small.
     selection : {'cyclic', 'random'} (default='cyclic')
         If set to ‘random’, a random coefficient is updated every iteration
         rather than looping over features sequentially by default.
@@ -154,7 +166,8 @@ class ElasticNet(Base,
     coef_ = CumlArrayDescriptor()
 
     def __init__(self, *, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
-                 normalize=False, max_iter=1000, tol=1e-3, selection='cyclic',
+                 normalize=False, max_iter=1000, tol=1e-3,
+                 solver='cd', selection='cyclic',
                  handle=None, output_type=None, verbose=False):
         """
         Initializes the elastic-net regression class.
@@ -167,6 +180,7 @@ class ElasticNet(Base,
         normalize: boolean.
         max_iter: int
         tol: float or double.
+        solver: str, 'cd' or 'qn'
         selection : str, ‘cyclic’, or 'random'
 
         For additional docs, see `scikitlearn's ElasticNet
@@ -184,6 +198,7 @@ class ElasticNet(Base,
         self.alpha = alpha
         self.l1_ratio = l1_ratio
         self.fit_intercept = fit_intercept
+        self.solver = solver
         self.normalize = normalize
         self.max_iter = max_iter
         self.tol = tol
@@ -200,11 +215,31 @@ class ElasticNet(Base,
         if self.selection == 'random':
             shuffle = True
 
-        self.solver_model = CD(fit_intercept=self.fit_intercept,
-                               normalize=self.normalize, alpha=self.alpha,
-                               l1_ratio=self.l1_ratio, shuffle=shuffle,
-                               max_iter=self.max_iter, handle=self.handle,
-                               tol=self.tol)
+        if solver == 'qn':
+            pams = signature(self.__init__).parameters
+            if (pams['selection'].default != selection):
+                warn("Parameter 'selection' has no effect "
+                     "when 'qn' solver is used.")
+            if (pams['normalize'].default != normalize):
+                warn("Parameter 'normalize' has no effect "
+                     "when 'qn' solver is used.")
+
+            self.solver_model = QN(
+                fit_intercept=self.fit_intercept,
+                l1_strength=self.alpha * self.l1_ratio,
+                l2_strength=self.alpha * (1.0 - self.l1_ratio),
+                max_iter=self.max_iter, handle=self.handle,
+                loss='l2', tol=self.tol, penalty_normalized=False,
+                verbose=self.verbose)
+        elif solver == 'cd':
+            self.solver_model = CD(
+                fit_intercept=self.fit_intercept,
+                normalize=self.normalize, alpha=self.alpha,
+                l1_ratio=self.l1_ratio, shuffle=shuffle,
+                max_iter=self.max_iter, handle=self.handle,
+                tol=self.tol)
+        else:
+            raise TypeError(f"solver {solver} is not supported")
 
     def _check_alpha(self, alpha):
         if alpha <= 0.0:
@@ -223,6 +258,15 @@ class ElasticNet(Base,
         """
         self.solver_model.fit(X, y, convert_dtype=convert_dtype)
 
+        if isinstance(self.solver_model, QN):
+            self.coef_ = CumlArray(
+                data=self.solver_model.coef_,
+                index=self.solver_model.coef_._index,
+                dtype=self.solver_model.coef_.dtype,
+                order=self.solver_model.coef_.order,
+                shape=(self.solver_model.coef_.shape[0],)
+            )
+            self.intercept_ = self.solver_model.intercept_.item()
 
         return self
 
@@ -242,5 +286,6 @@ class ElasticNet(Base,
         "normalize",
         "max_iter",
         "tol",
+        "solver",
         "selection",
     ]
diff --git a/python/cuml/linear_model/lasso.pyx b/python/cuml/linear_model/lasso.py
similarity index 61%
rename from python/cuml/linear_model/lasso.pyx
rename to python/cuml/linear_model/lasso.py
index ba418be80c..632395270f 100644
--- a/python/cuml/linear_model/lasso.pyx
+++ b/python/cuml/linear_model/lasso.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2021, NVIDIA CORPORATION.
+# Copyright (c) 2019-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,20 +14,10 @@
 # limitations under the License.
 #
 
-# distutils: language = c++
+from cuml.linear_model.elastic_net import ElasticNet
 
-from cuml.solvers import CD
-from cuml.common.base import Base
-from cuml.common.mixins import RegressorMixin
-from cuml.common.doc_utils import generate_docstring
-from cuml.common.mixins import FMajorInputTagMixin
-from cuml.linear_model.base import LinearPredictMixin
-
-class Lasso(Base,
-            LinearPredictMixin,
-            RegressorMixin,
-            FMajorInputTagMixin):
+
+class Lasso(ElasticNet):
 
     """
     Lasso extends LinearRegression by providing L1 regularization on the
@@ -92,23 +82,34 @@ class Lasso(Base,
     alpha : float (default = 1.0)
         Constant that multiplies the L1 term.
        alpha = 0 is equivalent to an ordinary least square, solved by the
-        LinearRegression class.
+        LinearRegression object.
-        For numerical reasons, using alpha = 0 with the Lasso class is not
+        For numerical reasons, using alpha = 0 with the Lasso object is not
         advised.
-        Given this, you should use the LinearRegression class.
+        Given this, you should use the LinearRegression object.
     fit_intercept : boolean (default = True)
         If True, Lasso tries to correct for the global mean of y.
         If False, the model expects that you have centered the data.
     normalize : boolean (default = False)
-        If True, the predictors in X will be normalized by dividing by it's L2
-        norm.
+        If True, the predictors in X will be normalized by dividing by the
+        column-wise standard deviation.
         If False, no scaling will be done.
-    max_iter : int
+        Note: this differs from sklearn's deprecated `normalize` flag, which
+        divides by the column-wise L2 norm; it is instead equivalent to
+        applying sklearn's StandardScaler.
+    max_iter : int (default = 1000)
         The maximum number of iterations
     tol : float (default = 1e-3)
         The tolerance for the optimization: if the updates are
         smaller than tol, the optimization code checks the dual gap for
         optimality and continues until it is smaller than tol.
+    solver : {'cd', 'qn'} (default='cd')
+        Choose an algorithm:
+
+          * 'cd' - coordinate descent
+          * 'qn' - quasi-Newton
+
+        The alternative 'qn' algorithm may be faster when the number of
+        features is large but the number of samples is small.
     selection : {'cyclic', 'random'} (default='cyclic')
         If set to ‘random’, a random coefficient is updated every iteration
         rather than looping over features sequentially by default.
@@ -143,69 +144,13 @@ class Lasso(Base,
         `_.
     """
 
-    def __init__(self, *, alpha=1.0, fit_intercept=True, normalize=False,
-                 max_iter=1000, tol=1e-3, selection='cyclic', handle=None,
-                 output_type=None, verbose=False):
-
-        # Hard-code verbosity as CoordinateDescent does not have verbosity
-        super().__init__(handle=handle,
-                         verbose=verbose,
-                         output_type=output_type)
-
-        self._check_alpha(alpha)
-        self.alpha = alpha
-        self.fit_intercept = fit_intercept
-        self.normalize = normalize
-        self.max_iter = max_iter
-        self.tol = tol
-        self.solver_model = None
-        if selection in ['cyclic', 'random']:
-            self.selection = selection
-        else:
-            msg = "selection {!r} is not supported"
-            raise TypeError(msg.format(selection))
-
-        self.intercept_value = 0.0
-
-        shuffle = False
-        if self.selection == 'random':
-            shuffle = True
-
-        self.solver_model = CD(fit_intercept=self.fit_intercept,
-                               normalize=self.normalize, alpha=self.alpha,
-                               l1_ratio=1.0, shuffle=shuffle,
-                               max_iter=self.max_iter, handle=self.handle,
-                               tol=self.tol)
-
-    def _check_alpha(self, alpha):
-        if alpha <= 0.0:
-            msg = "alpha value has to be positive"
-            raise ValueError(msg.format(alpha))
-
-    def set_params(self, **params):
-        super().set_params(**params)
-        if 'selection' in params:
-            params.pop('selection')
-            params['shuffle'] = self.selection == 'random'
-        self.solver_model.set_params(**params)
-        return self
-
-    @generate_docstring()
-    def fit(self, X, y, convert_dtype=True) -> "Lasso":
-        """
-        Fit the model with X and y.
-
-        """
-        self.solver_model.fit(X, y, convert_dtype=convert_dtype)
-
-        return self
-
-    def get_param_names(self):
-        return super().get_param_names() + [
-            "alpha",
-            "fit_intercept",
-            "normalize",
-            "max_iter",
-            "tol",
-            "selection",
-        ]
+    def __init__(self, *, alpha=1.0, fit_intercept=True,
+                 normalize=False, max_iter=1000, tol=1e-3,
+                 solver='cd', selection='cyclic',
+                 handle=None, output_type=None, verbose=False):
+        # Lasso is just a special case of ElasticNet
+        super().__init__(
+            l1_ratio=1.0, alpha=alpha, fit_intercept=fit_intercept,
+            normalize=normalize, max_iter=max_iter, tol=tol,
+            solver=solver, selection=selection,
+            handle=handle, output_type=output_type, verbose=verbose)
diff --git a/python/cuml/test/test_linear_model.py b/python/cuml/test/test_linear_model.py
index 8f06fb7331..4099eeb1dc 100644
--- a/python/cuml/test/test_linear_model.py
+++ b/python/cuml/test/test_linear_model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2021, NVIDIA CORPORATION.
+# Copyright (c) 2019-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ import pytest
 from distutils.version import LooseVersion
 
 import cudf
+from cuml import ElasticNet as cuElasticNet
 from cuml import LinearRegression as cuLinearRegression
 from cuml import LogisticRegression as cuLog
 from cuml import Ridge as cuRidge
@@ -671,3 +672,37 @@ def test_linear_models_set_params(algo):
 
     assert not array_equal(coef_before, coef_after)
     assert array_equal(coef_after, coef_test)
+
+
+@pytest.mark.parametrize("datatype", [np.float32, np.float64])
+@pytest.mark.parametrize("alpha", [0.1, 1.0, 10.0])
+@pytest.mark.parametrize("l1_ratio", [0.1, 0.5, 0.9])
+@pytest.mark.parametrize(
+    "nrows", [unit_param(1000), quality_param(5000), stress_param(500000)]
+)
+@pytest.mark.parametrize(
+    "column_info",
+    [
+        unit_param([20, 10]),
+        quality_param([100, 50]),
+        stress_param([1000, 500])
+    ],
+)
+def test_elasticnet_solvers_eq(datatype, alpha, l1_ratio, nrows, column_info):
+
+    ncols, n_info = column_info
+    X_train, X_test, y_train, y_test = make_regression_dataset(
+        datatype, nrows, ncols, n_info
+    )
+
+    kwargs = {'alpha': alpha, 'l1_ratio': l1_ratio}
+    cd = cuElasticNet(solver='cd', **kwargs)
+    cd.fit(X_train, y_train)
+    cd_res = cd.predict(X_test)
+
+    qn = cuElasticNet(solver='qn', **kwargs)
+    qn.fit(X_train, y_train)
+    # the results of the two models should be close (even if both are bad)
+    assert qn.score(X_test, cd_res) > 0.95
+    # coefficients of the two models should be close
+    assert np.corrcoef(cd.coef_, qn.coef_)[0, 1] > 0.98

From d4c40d57ef7edcd1c6e20d09433572941169e0ba Mon Sep 17 00:00:00 2001
From: achirkin
Date: Mon, 14 Feb 2022 18:23:37 +0100
Subject: [PATCH 2/2] Fix get_param_names

---
 python/cuml/linear_model/lasso.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/cuml/linear_model/lasso.py b/python/cuml/linear_model/lasso.py
index 632395270f..55048aeaf8 100644
--- a/python/cuml/linear_model/lasso.py
+++ b/python/cuml/linear_model/lasso.py
@@ -154,3 +154,6 @@ def __init__(self, *, alpha=1.0, fit_intercept=True,
             normalize=normalize, max_iter=max_iter, tol=tol,
             solver=solver, selection=selection,
             handle=handle, output_type=output_type, verbose=verbose)
+
+    def get_param_names(self):
+        return list(set(super().get_param_names()) - {'l1_ratio'})
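
Note on the re-parameterization in PATCH 1: the heart of the ElasticNet change is translating the scikit-learn-style (alpha, l1_ratio) pair into the two separate penalty strengths the QN solver expects. Spelled out as plain arithmetic (the concrete values below are illustrative, not taken from the patch):

    # Mapping applied in ElasticNet.__init__ above when solver='qn'
    alpha, l1_ratio = 0.5, 0.2
    l1_strength = alpha * l1_ratio          # 0.5 * 0.2 = 0.1 -> L1 (lasso) term
    l2_strength = alpha * (1.0 - l1_ratio)  # 0.5 * 0.8 = 0.4 -> L2 (ridge) term
    # Lasso is the l1_ratio == 1.0 special case: l2_strength collapses to 0.0,
    # which is why the rewritten Lasso can simply subclass ElasticNet.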
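
A minimal usage sketch of the new solver option, assuming a cuml build with both patches applied; sklearn's make_regression and all parameter values here are illustrative stand-ins for real data:

    import numpy as np
    from sklearn.datasets import make_regression
    from cuml.linear_model import ElasticNet

    # a small dense regression problem as toy input
    X, y = make_regression(n_samples=500, n_features=20,
                           n_informative=10, random_state=42)
    X, y = X.astype(np.float32), y.astype(np.float32)

    # the default coordinate-descent solver
    enet_cd = ElasticNet(alpha=0.1, l1_ratio=0.5, solver='cd').fit(X, y)

    # the new quasi-Newton solver: same model, different optimizer
    enet_qn = ElasticNet(alpha=0.1, l1_ratio=0.5, solver='qn').fit(X, y)

    # as in test_elasticnet_solvers_eq, the two fits should agree closely
    print(np.corrcoef(enet_cd.coef_, enet_qn.coef_)[0, 1])

Note that passing a non-default selection or normalize together with solver='qn' logs a warning, since those two parameters only affect the 'cd' path.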