[REVIEW] Support for sample_weight parameter in LogisticRegression #3572

Merged
4 changes: 2 additions & 2 deletions cpp/include/cuml/linear_model/glm.hpp
@@ -124,12 +124,12 @@ void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D,
int C, bool fit_intercept, float l1, float l2, int max_iter,
float grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters, bool X_col_major,
int loss_type);
int loss_type, float *sample_weight = nullptr);
void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2, int max_iter,
double grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
bool X_col_major, int loss_type);
bool X_col_major, int loss_type, double *sample_weight = nullptr);
/** @} */

/**
10 changes: 5 additions & 5 deletions cpp/src/glm/glm.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020, NVIDIA CORPORATION.
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -71,20 +71,20 @@ void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D,
int C, bool fit_intercept, float l1, float l2, int max_iter,
float grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters, bool X_col_major,
int loss_type) {
int loss_type, float *sample_weight) {
qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol,
linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters,
X_col_major, loss_type, cuml_handle.get_stream());
X_col_major, loss_type, cuml_handle.get_stream(), sample_weight);
}

void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2, int max_iter,
double grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
bool X_col_major, int loss_type) {
bool X_col_major, int loss_type, double *sample_weight) {
qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol,
linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters,
X_col_major, loss_type, cuml_handle.get_stream());
X_col_major, loss_type, cuml_handle.get_stream(), sample_weight);
}

void qnDecisionFunction(const raft::handle_t &cuml_handle, float *X, int N,
56 changes: 42 additions & 14 deletions cpp/src/glm/qn/glm_base.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020, NVIDIA CORPORATION.
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/linalg/add.cuh>
#include <raft/linalg/binary_op.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/map_then_reduce.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/stats/mean.cuh>
@@ -96,9 +97,22 @@ struct GLMBase : GLMDims {
typedef SimpleVec<T> Vec;

const raft::handle_t &handle;
T *sample_weights;
T weights_sum;

GLMBase(const raft::handle_t &handle, int D, int C, bool fit_intercept)
: GLMDims(C, D, fit_intercept), handle(handle) {}
: GLMDims(C, D, fit_intercept),
handle(handle),
sample_weights(nullptr),
weights_sum(0) {}

void add_sample_weights(T *sample_weights, int n_samples,
cudaStream_t stream) {
this->sample_weights = sample_weights;
this->weights_sum =
thrust::reduce(thrust::cuda::par.on(stream), sample_weights,
sample_weights + n_samples, (T)0, thrust::plus<T>());
}

/*
* Computes the following:
@@ -111,22 +125,36 @@
cudaStream_t stream) {
// Base impl assumes simple case C = 1
Loss *loss = static_cast<Loss *>(this);
T invN = 1.0 / y.len;

auto f_l = [=] __device__(const T y, const T z) {
return loss->lz(y, z) * invN;
};

// TODO would be nice to have a kernel that fuses these two steps
// This would be easy, if mapThenSumReduce allowed outputing the result of
// map (supporting inplace)
raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data,
Z.data);

auto f_dl = [=] __device__(const T y, const T z) {
return loss->dlz(y, z);
};
raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, f_dl, stream);
if (this->sample_weights) { // Sample weights are in use
T normalization = 1.0 / this->weights_sum;
auto f_l = [=] __device__(const T y, const T z, const T weight) {
return loss->lz(y, z) * (weight * normalization);
};
raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data,
Z.data, sample_weights);

auto f_dl = [=] __device__(const T y, const T z, const T weight) {
return weight * loss->dlz(y, z);
};
raft::linalg::map(Z.data, y.len, f_dl, stream, y.data, Z.data,
sample_weights);
} else { // Sample weights are not used
T normalization = 1.0 / y.len;
auto f_l = [=] __device__(const T y, const T z) {
return loss->lz(y, z) * normalization;
};
raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data,
Z.data);

auto f_dl = [=] __device__(const T y, const T z) {
return loss->dlz(y, z);
};
raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, f_dl, stream);
}
}

inline void loss_grad(T *loss_val, Mat &G, const Mat &W,
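For reference, my reading of what the new weighted branch of getLossAndDZ computes (a sketch inferred from the code above, not wording taken from the PR). With per-sample weights s_i, per-sample loss \ell, and weights_sum = \sum_i s_i accumulated by add_sample_weights via thrust::reduce:

    L_w(Z) = \frac{1}{\sum_{j=1}^{N} s_j} \sum_{i=1}^{N} s_i \, \ell(y_i, z_i),
    \qquad
    Z_i \leftarrow s_i \, \frac{\partial \ell}{\partial z}(y_i, z_i)

The unweighted branch is the special case s_i \equiv 1, where \sum_j s_j = N recovers the original 1/N factor; in both branches the derivative written back into Z carries no normalization.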
5 changes: 4 additions & 1 deletion cpp/src/glm/qn/qn.cuh
@@ -64,7 +64,7 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C,
bool fit_intercept, T l1, T l2, int max_iter, T grad_tol,
int linesearch_max_iter, int lbfgs_memory, int verbosity, T *w0,
T *f, int *num_iters, bool X_col_major, int loss_type,
cudaStream_t stream) {
cudaStream_t stream, T *sample_weight = nullptr) {
STORAGE_ORDER ord = X_col_major ? COL_MAJOR : ROW_MAJOR;
int C_len = (loss_type == 0) ? (C - 1) : C;

@@ -75,20 +75,23 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C,
case 0: {
ASSERT(C == 2, "qn.h: logistic loss invalid C");
LogisticLoss<T> loss(handle, D, fit_intercept);
if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
} break;
case 1: {
ASSERT(C == 1, "qn.h: squared loss invalid C");
SquaredLoss<T> loss(handle, D, fit_intercept);
if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
} break;
case 2: {
ASSERT(C > 2, "qn.h: softmax invalid C");
Softmax<T> loss(handle, D, C, fit_intercept);
if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
83 changes: 76 additions & 7 deletions python/cuml/linear_model/logistic_regression.pyx
@@ -16,6 +16,7 @@

# distutils: language = c++

import numpy as np
import cupy as cp
import pprint

@@ -126,6 +127,15 @@ class LogisticRegression(Base,
If False, the model expects that you have centered the data.
class_weight: None
Custom class weighs are currently not supported.
class_weight: dict or 'balanced', default=None
By default all classes have a weight one. However, a dictionary
can be provided with weights associated with classes
in the form ``{class_label: weight}``. The "balanced" mode
uses the values of y to automatically adjust weights
inversely proportional to class frequencies in the input data
as ``n_samples / (n_classes * np.bincount(y))``. Note that
these weights will be multiplied with sample_weight
(passed through the fit method) if sample_weight is specified.
max_iter: int (default = 1000)
Maximum number of iterations taken for the solvers to converge.
linesearch_max_iter: int (default = 50)
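A quick numeric illustration of the 'balanced' rule documented in the class_weight entry above; this is a hedged sketch with made-up labels, not code from the PR:

    import numpy as np

    # Toy labels: 8 samples of class 0, 2 samples of class 1 (imbalanced).
    y = np.array([0] * 8 + [1] * 2)

    n_samples = y.shape[0]                    # 10
    n_classes = np.unique(y).shape[0]         # 2
    balanced = n_samples / (n_classes * np.bincount(y))
    print(balanced)                           # [0.625 2.5] -> minority class is up-weighted

    # If sample_weight is also passed to fit(), the effective per-sample
    # weight is the product of the two:
    sample_weight = np.ones(n_samples)
    effective = sample_weight * balanced[y]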
@@ -178,6 +188,8 @@
"""

classes_ = CumlArrayDescriptor()
class_weight_ = CumlArrayDescriptor()
expl_spec_weights_ = CumlArrayDescriptor()

def __init__(
self,
@@ -199,9 +211,6 @@
handle=handle, verbose=verbose, output_type=output_type
)

if class_weight:
raise ValueError("`class_weight` not supported.")

if penalty not in supported_penalties:
raise ValueError("`penalty` " + str(penalty) + "not supported.")

@@ -246,6 +255,21 @@

loss = "sigmoid"

if class_weight is not None:
if class_weight == 'balanced':
self.class_weight_ = 'balanced'
else:
classes = list(class_weight.keys())
[Review thread]

Member: Should we save the classes here to check later that the user is indeed passing the classes we expect?

Contributor (author): I don't think that we need to. If not given, all classes are supposed to have weight one. The sample_weight *= class_weight[out] operation should not crash thanks to earlier code that fills ones when necessary:

    n_explicit = self.class_weight_.shape[0]
    if n_explicit != self._num_classes:
        class_weight = cp.ones(self._num_classes)
        class_weight[:n_explicit] = self.class_weight_
        class_weight = CumlArray(class_weight)
        self.class_weight_ = class_weight

Member: Indeed, but my comment applies more to whether the classes that the dictionary of weights has (as keys) have to coincide with the classes of the operation self.classes_ = cp.unique(y_m), no? i.e. if the user passes a class in fit that was not in the dict of class_weights (when it was passed as a dict) it would be an error, which would be consistent with what scikit-learn does if I'm not mistaken: https://github.com/scikit-learn/scikit-learn/blob/95119c13af77c76e150b753485c662b7c52a41a2/sklearn/utils/class_weight.py#L67
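To make the behavior discussed in this thread concrete, a hedged sketch (toy data and dict values are made up; the error text comes from check_expl_spec_weights further down in this diff, and it assumes a cuML build containing this change):

    import cupy as cp
    from cuml.linear_model import LogisticRegression

    X = cp.random.rand(6, 2, dtype=cp.float32)
    y = cp.array([0, 0, 0, 1, 1, 1], dtype=cp.float32)

    # Classes missing from the dict keep an implicit weight of one.
    clf = LogisticRegression(class_weight={1: 5.0})
    clf.fit(X, y)        # class 0 -> weight 1.0, class 1 -> weight 5.0

    # A key that never appears in y is rejected, mirroring scikit-learn:
    clf = LogisticRegression(class_weight={2: 3.0})
    clf.fit(X, y)        # ValueError: Class label 2 not present.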

weights = list(class_weight.values())
max_class = sorted(classes)[-1]
class_weight = cp.ones(max_class + 1)
class_weight[classes] = weights
self.class_weight_, _, _, _ = input_to_cuml_array(class_weight)
self.expl_spec_weights_, _, _, _ = \
input_to_cuml_array(np.array(classes))
else:
self.class_weight_ = None

self.solver_model = QN(
loss=loss,
fit_intercept=self.fit_intercept,
@@ -267,19 +291,63 @@

@generate_docstring()
@cuml.internals.api_base_return_any(set_output_dtype=True)
def fit(self, X, y, convert_dtype=True) -> "LogisticRegression":
def fit(self, X, y, sample_weight=None,
convert_dtype=True) -> "LogisticRegression":
"""
Fit the model with X and y.

"""
# Converting y to device array here to use `unique` function
# since calling input_to_cuml_array again in QN has no cost
# Not needed to check dtype since qn class checks it already
y_m, _, _, _ = input_to_cuml_array(y)

y_m, n_rows, _, _ = input_to_cuml_array(y)
self.classes_ = cp.unique(y_m)
self._num_classes = len(self.classes_)

if self._num_classes == 2:
if self.classes_[0] != 0 or self.classes_[1] != 1:
raise ValueError("Only values of 0 and 1 are"
" supported for binary classification.")

if sample_weight is not None or self.class_weight_ is not None:
if sample_weight is None:
sample_weight = cp.ones(n_rows)

sample_weight, n_weights, D, _ = input_to_cuml_array(sample_weight)

if n_rows != n_weights or D != 1:
raise ValueError("sample_weight.shape == {}, "
"expected ({},)!".format(sample_weight.shape,
n_rows))

def check_expl_spec_weights():
with cuml.using_output_type("numpy"):
for c in self.expl_spec_weights_:
i = np.searchsorted(self.classes_, c)
if i >= self._num_classes or self.classes_[i] != c:
msg = "Class label {} not present.".format(c)
raise ValueError(msg)

if self.class_weight_ is not None:
if self.class_weight_ == 'balanced':
class_weight = n_rows / \
(self._num_classes *
cp.bincount(y_m.to_output('cupy')))
class_weight = CumlArray(class_weight)
else:
check_expl_spec_weights()
n_explicit = self.class_weight_.shape[0]
if n_explicit != self._num_classes:
class_weight = cp.ones(self._num_classes)
class_weight[:n_explicit] = self.class_weight_
class_weight = CumlArray(class_weight)
self.class_weight_ = class_weight
else:
class_weight = self.class_weight_
out = y_m.to_output('cupy')
sample_weight *= class_weight[out].to_output('cupy')
sample_weight = CumlArray(sample_weight)

if self._num_classes > 2:
loss = "softmax"
else:
Expand All @@ -293,7 +361,8 @@ class LogisticRegression(Base,
if logger.should_log_for(logger.level_debug):
logger.debug(self.verb_prefix + "Calling QN fit " + str(loss))

self.solver_model.fit(X, y_m, convert_dtype=convert_dtype)
self.solver_model.fit(X, y_m, sample_weight=sample_weight,
convert_dtype=convert_dtype)

# coefficients and intercept are contained in the same array
if logger.should_log_for(logger.level_debug):
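Putting the Python-level changes together, a minimal usage sketch of the new fit signature (not taken from the PR; data and shapes are made up, and it assumes a cuML build that includes this change):

    import cupy as cp
    from cuml.linear_model import LogisticRegression

    X = cp.random.rand(100, 4, dtype=cp.float32)
    y = cp.zeros(100, dtype=cp.float32)
    y[:20] = 1                                    # imbalanced 0/1 labels
    w = cp.random.rand(100, dtype=cp.float32)     # per-sample weights

    # class_weight='balanced' is multiplied with the explicit sample_weight
    # before being forwarded to the QN solver.
    clf = LogisticRegression(class_weight='balanced')
    clf.fit(X, y, sample_weight=w)
    preds = clf.predict(X)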
29 changes: 22 additions & 7 deletions python/cuml/solvers/qn.pyx
@@ -53,7 +53,8 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
float *f,
int *num_iters,
bool X_col_major,
int loss_type) except +
int loss_type,
float *sample_weight) except +

void qnFit(handle_t& cuml_handle,
double *X,
@@ -73,7 +74,8 @@
double *f,
int *num_iters,
bool X_col_major,
int loss_type) except +
int loss_type,
double *sample_weight) except +

void qnDecisionFunction(handle_t& cuml_handle,
float *X,
@@ -291,7 +293,7 @@ class QN(Base,
return self._coef_

@generate_docstring()
def fit(self, X, y, convert_dtype=False) -> "QN":
def fit(self, X, y, sample_weight=None, convert_dtype=False) -> "QN":
"""
Fit the model with X and y.

@@ -310,14 +312,25 @@

self._num_classes = len(cp.unique(y_m))

cdef uintptr_t sample_weight_ptr = 0
if sample_weight is not None:
sample_weight, _, _, _ = \
input_to_cuml_array(sample_weight,
check_dtype=self.dtype,
check_rows=n_rows, check_cols=1,
convert_to_dtype=(self.dtype
if convert_dtype
else None))
sample_weight_ptr = sample_weight.ptr

self.loss_type = self._get_loss_int(self.loss)
if self.loss_type != 2 and self._num_classes > 2:
raise ValueError("Only softmax (multinomial) loss supports more"
"than 2 classes.")

if self.loss_type == 2 and self._num_classes <= 2:
raise ValueError("Only softmax (multinomial) loss supports more"
"than 2 classes.")
raise ValueError("Two classes or less cannot be trained"
"with softmax (multinomial).")

if self.loss_type == 0:
self._num_classes_dim = self._num_classes - 1
@@ -357,7 +370,8 @@
<float*> &objective32,
<int*> &num_iters,
<bool> True,
<int> self.loss_type)
<int> self.loss_type,
<float*>sample_weight_ptr)

self.objective = objective32

@@ -380,7 +394,8 @@
<double*> &objective64,
<int*> &num_iters,
<bool> True,
<int> self.loss_type)
<int> self.loss_type,
<double*>sample_weight_ptr)

self.objective = objective64

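Finally, the QN solver can be driven directly with the new argument. A hedged sketch: the import path is my assumption, while the loss and fit_intercept constructor arguments, the sample_weight keyword, and the objective attribute all appear in the diff above:

    import cupy as cp
    from cuml.solvers import QN

    X = cp.random.rand(50, 3, dtype=cp.float32)
    y = cp.zeros(50, dtype=cp.float32)
    y[:10] = 1
    w = cp.linspace(0.5, 1.5, 50, dtype=cp.float32)

    qn = QN(loss='sigmoid', fit_intercept=True)
    qn.fit(X, y, sample_weight=w, convert_dtype=True)
    print(qn.objective)    # final objective value stored by fit()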