Skip to content

Commit

Permalink
Fix set_params for linear models (rapidsai#4096)
Browse files Browse the repository at this point in the history
Closes rapidsai#4089

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4096
  • Loading branch information
lowener authored Jul 28, 2021
1 parent 3fc607b commit dcec9b1
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 30 deletions.
11 changes: 10 additions & 1 deletion python/cuml/linear_model/elastic_net.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ class ElasticNet(Base,
self.solver_model = CD(fit_intercept=self.fit_intercept,
normalize=self.normalize, alpha=self.alpha,
l1_ratio=self.l1_ratio, shuffle=shuffle,
max_iter=self.max_iter, handle=self.handle)
max_iter=self.max_iter, handle=self.handle,
tol=self.tol)

def _check_alpha(self, alpha):
if alpha <= 0.0:
Expand All @@ -222,6 +223,14 @@ class ElasticNet(Base,

return self

def set_params(self, **params):
    """
    Set estimator parameters and forward them to the underlying solver.

    The estimator exposes ``selection`` while the solver is constructed
    with a boolean ``shuffle`` argument, so ``selection`` is translated
    before the parameters are handed to ``self.solver_model``.

    Parameters
    ----------
    **params : dict
        Parameter names mapped to their new values.

    Returns
    -------
    self
    """
    super().set_params(**params)

    # Everything except ``selection`` is passed through unchanged.
    solver_params = {
        key: value for key, value in params.items() if key != 'selection'
    }
    if 'selection' in params:
        # Derived from the freshly updated attribute, not the raw argument.
        solver_params['shuffle'] = self.selection == 'random'

    self.solver_model.set_params(**solver_params)
    return self

def get_param_names(self):
return super().get_param_names() + [
"alpha",
Expand Down
11 changes: 10 additions & 1 deletion python/cuml/linear_model/lasso.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,22 @@ class Lasso(Base,
self.solver_model = CD(fit_intercept=self.fit_intercept,
normalize=self.normalize, alpha=self.alpha,
l1_ratio=1.0, shuffle=shuffle,
max_iter=self.max_iter, handle=self.handle)
max_iter=self.max_iter, handle=self.handle,
tol=self.tol)

def _check_alpha(self, alpha):
    """
    Validate that ``alpha`` is strictly positive.

    Parameters
    ----------
    alpha : float
        Value to validate.

    Raises
    ------
    ValueError
        If ``alpha`` is less than or equal to zero.
    """
    if alpha <= 0.0:
        # The template previously had no placeholder, so ``format(alpha)``
        # silently dropped the offending value from the message.
        msg = "alpha value has to be positive, got {!r}"
        raise ValueError(msg.format(alpha))

def set_params(self, **params):
    """
    Set estimator parameters and forward them to the underlying solver.

    ``selection`` is not forwarded as-is: the solver is constructed with a
    boolean ``shuffle`` argument, so the value is translated first.

    Parameters
    ----------
    **params : dict
        Parameter names mapped to their new values.

    Returns
    -------
    self
    """
    super().set_params(**params)

    # Pass through every parameter except ``selection``.
    solver_params = {
        key: value for key, value in params.items() if key != 'selection'
    }
    if 'selection' in params:
        # Use the attribute that was just updated by the base class.
        solver_params['shuffle'] = self.selection == 'random'

    self.solver_model.set_params(**solver_params)
    return self

@generate_docstring()
def fit(self, X, y, convert_dtype=True) -> "Lasso":
"""
Expand Down
83 changes: 56 additions & 27 deletions python/cuml/linear_model/logistic_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -240,37 +240,12 @@ class LogisticRegression(Base,
raise ValueError(msg.format(l1_ratio))
self.l1_ratio = l1_ratio

if self.penalty == "none":
l1_strength = 0.0
l2_strength = 0.0

elif self.penalty == "l1":
l1_strength = 1.0 / self.C
l2_strength = 0.0

elif self.penalty == "l2":
l1_strength = 0.0
l2_strength = 1.0 / self.C

else:
strength = 1.0 / self.C
l1_strength = self.l1_ratio * strength
l2_strength = (1.0 - self.l1_ratio) * strength
l1_strength, l2_strength = self._get_qn_params()

loss = "sigmoid"

if class_weight is not None:
if class_weight == 'balanced':
self.class_weight_ = 'balanced'
else:
classes = list(class_weight.keys())
weights = list(class_weight.values())
max_class = sorted(classes)[-1]
class_weight = cp.ones(max_class + 1)
class_weight[classes] = weights
self.class_weight_, _, _, _ = input_to_cuml_array(class_weight)
self.expl_spec_weights_, _, _, _ = \
input_to_cuml_array(np.array(classes))
self._build_class_weights(class_weight)
else:
self.class_weight_ = None

Expand Down Expand Up @@ -472,6 +447,60 @@ class LogisticRegression(Base,

return proba

def _get_qn_params(self):
    """
    Translate ``penalty``, ``C`` and ``l1_ratio`` into solver strengths.

    Returns
    -------
    tuple of (float, float)
        ``(l1_strength, l2_strength)``. Both are zero for ``"none"``;
        ``"l1"`` and ``"l2"`` put ``1 / C`` on the matching term; any other
        penalty splits ``1 / C`` between the two according to ``l1_ratio``.
    """
    penalty = self.penalty

    # No regularization: ``C`` is deliberately not read in this case.
    if penalty == "none":
        return 0.0, 0.0

    strength = 1.0 / self.C
    if penalty == "l1":
        return strength, 0.0
    if penalty == "l2":
        return 0.0, strength

    # Remaining case: mix of both terms weighted by ``l1_ratio``.
    l1_strength = self.l1_ratio * strength
    l2_strength = (1.0 - self.l1_ratio) * strength
    return l1_strength, l2_strength

def _build_class_weights(self, class_weight):
    """
    Populate ``self.class_weight_`` from a user-supplied ``class_weight``.

    Parameters
    ----------
    class_weight : None, 'balanced' or dict
        ``None`` disables class weighting, ``'balanced'`` is stored as-is,
        and a dict maps class labels to their weights.

    Notes
    -----
    For a dict, ``self.expl_spec_weights_`` is also set to the array of
    explicitly listed class labels.
    """
    if class_weight is None:
        # Previously None reached the dict branch and failed on ``.keys()``.
        self.class_weight_ = None
        return

    if class_weight == 'balanced':
        self.class_weight_ = 'balanced'
        return

    classes = list(class_weight.keys())
    weights = list(class_weight.values())
    max_class = sorted(classes)[-1]
    # Vector indexed by class label; labels not listed keep a weight of 1.
    class_weight = cp.ones(max_class + 1)
    class_weight[classes] = weights
    self.class_weight_, _, _, _ = input_to_cuml_array(class_weight)
    self.expl_spec_weights_, _, _, _ = \
        input_to_cuml_array(np.array(classes))

def set_params(self, **params):
    """
    Set estimator parameters and keep the underlying solver in sync.

    ``penalty``, ``l1_ratio`` and ``C`` are not forwarded directly; when
    any of them changes they are converted to ``l1_strength`` and
    ``l2_strength`` for the solver. ``class_weight`` is consumed here and
    never forwarded.

    Parameters
    ----------
    **params : dict
        Parameter names mapped to their new values.

    Returns
    -------
    self
    """
    super().set_params(**params)

    # Remove class-specific parameters
    rebuild_params = False
    for param_name in ('penalty', 'l1_ratio', 'C'):
        if param_name in params:
            params.pop(param_name)
            rebuild_params = True

    if rebuild_params:
        # re-build QN solver parameters
        l1_strength, l2_strength = self._get_qn_params()
        params.update({'l1_strength': l1_strength,
                       'l2_strength': l2_strength})

    if 'class_weight' in params:
        # re-build class weight
        class_weight = params.pop('class_weight')
        if class_weight is None:
            # None means "no weighting"; handing it to the builder would
            # fail on ``.keys()``, so clear the fitted weights directly.
            self.class_weight_ = None
        else:
            self._build_class_weights(class_weight)

    # Update solver
    self.solver_model.set_params(**params)
    return self

def get_param_names(self):
return super().get_param_names() + [
"penalty",
Expand Down
5 changes: 5 additions & 0 deletions python/cuml/linear_model/mbsgd_classifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ class MBSGDClassifier(Base,

return preds

def set_params(self, **params):
    """
    Set estimator parameters and apply the same values to the solver.

    Parameters
    ----------
    **params : dict
        Parameter names mapped to their new values.

    Returns
    -------
    self
    """
    super().set_params(**params)
    # The solver gets the same parameters so it does not go stale.
    solver = self.solver_model
    solver.set_params(**params)
    return self

def get_param_names(self):
return super().get_param_names() + [
"loss",
Expand Down
5 changes: 5 additions & 0 deletions python/cuml/linear_model/mbsgd_regressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ class MBSGDRegressor(Base,
convert_dtype=convert_dtype)
return preds

def set_params(self, **params):
    """
    Set estimator parameters and apply the same values to the solver.

    Parameters
    ----------
    **params : dict
        Parameter names mapped to their new values.

    Returns
    -------
    self
    """
    super().set_params(**params)
    # Forward the identical parameters to the wrapped solver instance.
    solver = self.solver_model
    solver.set_params(**params)
    return self

def get_param_names(self):
return super().get_param_names() + [
"loss",
Expand Down
10 changes: 10 additions & 0 deletions python/cuml/linear_model/ridge.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,16 @@ class Ridge(Base,

return self

def set_params(self, **params):
    """
    Set estimator parameters, refreshing ``self.algo`` when ``solver`` changes.

    Parameters
    ----------
    **params : dict
        Parameter names mapped to their new values.

    Returns
    -------
    self

    Raises
    ------
    TypeError
        If ``solver`` is given and is not one of 'svd', 'eig' or 'cd'.
    """
    # Validate before mutating: rejecting the solver after the base class
    # stored it would leave ``self.solver`` and ``self.algo`` out of sync.
    if 'solver' in params and params['solver'] not in ['svd', 'eig', 'cd']:
        msg = "solver {!r} is not supported"
        raise TypeError(msg.format(params['solver']))

    super().set_params(**params)
    if 'solver' in params:
        self.algo = self._get_algorithm_int(params['solver'])
    return self

def get_param_names(self):
    """
    Return the base-class parameter names followed by this estimator's own.
    """
    own_params = ['solver', 'fit_intercept', 'normalize', 'alpha']
    return super().get_param_names() + own_params
24 changes: 23 additions & 1 deletion python/cuml/test/test_coordinate_descent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -189,3 +189,25 @@ def test_lasso_predict_convert_dtype(train_dtype, test_dtype):
clf = cuLasso()
clf.fit(X_train, y_train)
clf.predict(X_test.astype(test_dtype))


@pytest.mark.parametrize('algo', [cuElasticNet, cuLasso])
def test_set_params(algo):
    """
    Check that ``set_params`` gives the same fit as constructor arguments.

    Three models are fitted on the same data: one with the starting
    parameters, one constructed with the target parameters, and one that
    starts from the former and is switched to the latter via ``set_params``.
    """
    x = np.linspace(0, 1, 50)
    y = 2 * x

    def fitted_coef(model):
        # Fit on the shared data and return the learned coefficients.
        model.fit(x, y)
        return model.coef_

    coef_before = fitted_coef(algo(alpha=0.01))
    coef_after = fitted_coef(algo(selection="random", alpha=0.1))

    updated = algo(alpha=0.01)
    updated.set_params(selection="random", alpha=0.1)
    coef_test = fitted_coef(updated)

    assert coef_before != coef_after
    assert coef_after == coef_test
27 changes: 27 additions & 0 deletions python/cuml/test/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,30 @@ def test_logistic_regression_weighting(regression_dataset,
skOut = sklog.predict(data)
assert array_equal(skOut, cuOut, unit_tol=unit_tol,
total_tol=total_tol)


@pytest.mark.parametrize('algo', [cuLog, cuRidge])
def test_linear_models_set_params(algo):
    """
    Check that ``set_params`` gives the same fit as constructor arguments.

    A default model, a model constructed with the target parameters, and a
    default model updated through ``set_params`` are fitted on the same
    data; the last two must agree and differ from the first.
    """
    x = np.linspace(0, 1, 50)
    y = 2 * x

    # Target parameters depend on which estimator is under test.
    if algo == cuLog:
        params = {'penalty': "none", 'C': 1, 'max_iter': 30}
    else:
        params = {'solver': "svd", 'alpha': 0.1}

    baseline = algo()
    baseline.fit(x, y)
    coef_before = baseline.coef_

    configured = algo(**params)
    configured.fit(x, y)
    coef_after = configured.coef_

    updated = algo()
    updated.set_params(**params)
    updated.fit(x, y)
    coef_test = updated.coef_

    assert not array_equal(coef_before, coef_after)
    assert array_equal(coef_after, coef_test)
21 changes: 21 additions & 0 deletions python/cuml/test/test_mbsgd_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,24 @@ def test_mbsgd_classifier_default(make_dataset):
cu_acc = accuracy_score(cp.asnumpy(cu_pred), cp.asnumpy(y_test))

assert cu_acc >= 0.69


def test_mbsgd_classifier_set_params():
    """
    Check that ``set_params`` gives the same fit as constructor arguments.

    Compares a default classifier, one constructed with ``epochs=20`` and
    ``loss='hinge'``, and a default one updated via ``set_params``.
    """
    x = np.linspace(0, 1, 50)
    y = (x > 0.5).astype(cp.int32)

    def fitted_coef(model):
        # Fit on the shared data and return the learned coefficients.
        model.fit(x, y)
        return model.coef_

    coef_before = fitted_coef(cumlMBSGClassifier())
    coef_after = fitted_coef(cumlMBSGClassifier(epochs=20, loss='hinge'))

    updated = cumlMBSGClassifier()
    updated.set_params(epochs=20, loss='hinge')
    coef_test = fitted_coef(updated)

    assert coef_before != coef_after
    assert coef_after == coef_test
21 changes: 21 additions & 0 deletions python/cuml/test/test_mbsgd_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,24 @@ def test_mbsgd_regressor_default(make_dataset):
convert_dtype=datatype)

assert cu_r2 > 0.9


def test_mbsgd_regressor_set_params():
    """
    Check that ``set_params`` gives the same fit as constructor arguments.

    Compares a default regressor, one constructed with ``eta0=0.1`` and
    ``fit_intercept=False``, and a default one updated via ``set_params``.
    """
    x = np.linspace(0, 1, 50)
    y = x * 2

    def fitted_coef(model):
        # Fit on the shared data and return the learned coefficients.
        model.fit(x, y)
        return model.coef_

    coef_before = fitted_coef(cumlMBSGRegressor())
    coef_after = fitted_coef(cumlMBSGRegressor(eta0=0.1, fit_intercept=False))

    updated = cumlMBSGRegressor()
    updated.set_params(eta0=0.1, fit_intercept=False)
    coef_test = fitted_coef(updated)

    assert coef_before != coef_after
    assert coef_after == coef_test

0 comments on commit dcec9b1

Please sign in to comment.