Support for sample_weight parameter in LogisticRegression (#3572)
Closes #3559

This PR adds the `sample_weight` and `class_weight` parameters to the `LogisticRegression` estimator.
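
For orientation, a minimal usage sketch of the new API (toy data, illustrative only — the `class_weight` constructor argument and the `sample_weight` argument to `fit()` are what this PR adds):

```python
import cupy as cp
from cuml.linear_model import LogisticRegression

# Toy data; for the binary case the labels must be 0/1 (enforced in fit())
X = cp.random.rand(100, 4, dtype=cp.float32)
y = (cp.random.rand(100) > 0.5).astype(cp.float32)
sw = cp.random.rand(100, dtype=cp.float32)  # one weight per sample

# Per-class weights via the constructor, per-sample weights via fit();
# when both are given they are multiplied together.
clf = LogisticRegression(class_weight='balanced')
clf.fit(X, y, sample_weight=sw)
```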

Authors:
  - Victor Lafargue (@viclafargue)

Approvers:
  - Dante Gama Dessavre (@dantegd)

URL: #3572
viclafargue authored Mar 23, 2021
1 parent af0863d commit 5bb564e
Showing 7 changed files with 234 additions and 36 deletions.
4 changes: 2 additions & 2 deletions cpp/include/cuml/linear_model/glm.hpp
@@ -124,12 +124,12 @@ void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D,
int C, bool fit_intercept, float l1, float l2, int max_iter,
float grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters, bool X_col_major,
- int loss_type);
+ int loss_type, float *sample_weight = nullptr);
void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2, int max_iter,
double grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
- bool X_col_major, int loss_type);
+ bool X_col_major, int loss_type, double *sample_weight = nullptr);
/** @} */

/**
10 changes: 5 additions & 5 deletions cpp/src/glm/glm.cu
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -71,20 +71,20 @@ void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D,
int C, bool fit_intercept, float l1, float l2, int max_iter,
float grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters, bool X_col_major,
- int loss_type) {
+ int loss_type, float *sample_weight) {
qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol,
linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters,
- X_col_major, loss_type, cuml_handle.get_stream());
+ X_col_major, loss_type, cuml_handle.get_stream(), sample_weight);
}

void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2, int max_iter,
double grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
- bool X_col_major, int loss_type) {
+ bool X_col_major, int loss_type, double *sample_weight) {
qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol,
linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters,
- X_col_major, loss_type, cuml_handle.get_stream());
+ X_col_major, loss_type, cuml_handle.get_stream(), sample_weight);
}

void qnDecisionFunction(const raft::handle_t &cuml_handle, float *X, int N,
56 changes: 42 additions & 14 deletions cpp/src/glm/qn/glm_base.cuh
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/linalg/add.cuh>
#include <raft/linalg/binary_op.cuh>
+ #include <raft/linalg/map.cuh>
#include <raft/linalg/map_then_reduce.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/stats/mean.cuh>
@@ -96,9 +97,22 @@ struct GLMBase : GLMDims {
typedef SimpleVec<T> Vec;

const raft::handle_t &handle;
+ T *sample_weights;
+ T weights_sum;

GLMBase(const raft::handle_t &handle, int D, int C, bool fit_intercept)
-   : GLMDims(C, D, fit_intercept), handle(handle) {}
+   : GLMDims(C, D, fit_intercept),
+     handle(handle),
+     sample_weights(nullptr),
+     weights_sum(0) {}
+
+ void add_sample_weights(T *sample_weights, int n_samples,
+                         cudaStream_t stream) {
+   this->sample_weights = sample_weights;
+   this->weights_sum =
+     thrust::reduce(thrust::cuda::par.on(stream), sample_weights,
+                    sample_weights + n_samples, (T)0, thrust::plus<T>());
+ }

/*
* Computes the following:
@@ -111,22 +125,36 @@
cudaStream_t stream) {
// Base impl assumes simple case C = 1
Loss *loss = static_cast<Loss *>(this);
- T invN = 1.0 / y.len;
-
- auto f_l = [=] __device__(const T y, const T z) {
-   return loss->lz(y, z) * invN;
- };
-
- // TODO would be nice to have a kernel that fuses these two steps
- // This would be easy, if mapThenSumReduce allowed outputing the result of
- // map (supporting inplace)
- raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data,
-                                Z.data);
-
- auto f_dl = [=] __device__(const T y, const T z) {
-   return loss->dlz(y, z);
- };
- raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, f_dl, stream);
+ if (this->sample_weights) {  // Sample weights are in use
+   T normalization = 1.0 / this->weights_sum;
+   auto f_l = [=] __device__(const T y, const T z, const T weight) {
+     return loss->lz(y, z) * (weight * normalization);
+   };
+   raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data,
+                                  Z.data, sample_weights);
+
+   auto f_dl = [=] __device__(const T y, const T z, const T weight) {
+     return weight * loss->dlz(y, z);
+   };
+   raft::linalg::map(Z.data, y.len, f_dl, stream, y.data, Z.data,
+                     sample_weights);
+ } else {  // Sample weights are not used
+   T normalization = 1.0 / y.len;
+   auto f_l = [=] __device__(const T y, const T z) {
+     return loss->lz(y, z) * normalization;
+   };
+   raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data,
+                                  Z.data);
+
+   auto f_dl = [=] __device__(const T y, const T z) {
+     return loss->dlz(y, z);
+   };
+   raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, f_dl, stream);
+ }
}

inline void loss_grad(T *loss_val, Mat &G, const Mat &W,
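
In effect, the weighted path computes `loss = sum(w_i * l(y_i, z_i)) / sum(w_i)` and `dL/dz_i = w_i * l'(y_i, z_i)`, and the unweighted path is the special case `w_i = 1`, where the normalization reduces to `1 / N`. A NumPy sketch of the same reduction (illustrative only; the log-loss helpers below stand in for the `Loss::lz` / `Loss::dlz` device lambdas):

```python
import numpy as np

def lz(y, z):
    # Binary logistic loss of raw score z against a 0/1 label y
    return np.log1p(np.exp(-(2 * y - 1) * z))

def dlz(y, z):
    # Derivative of the loss w.r.t. z: sigmoid(z) - y
    return 1.0 / (1.0 + np.exp(-z)) - y

y = np.array([0.0, 1.0, 1.0, 0.0])
z = np.array([-0.5, 1.2, 0.3, 0.8])  # linear scores
w = np.array([1.0, 2.0, 1.0, 0.5])   # sample weights

# Weighted branch: normalize by the weight sum (weights_sum above)
loss_val = np.sum(w * lz(y, z)) / np.sum(w)
dz = w * dlz(y, z)

# With all-ones weights this reduces to the unweighted branch (1/N)
ones = np.ones_like(w)
assert np.isclose(np.sum(ones * lz(y, z)) / np.sum(ones),
                  np.sum(lz(y, z)) / len(y))
```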
5 changes: 4 additions & 1 deletion cpp/src/glm/qn/qn.cuh
@@ -64,7 +64,7 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C,
bool fit_intercept, T l1, T l2, int max_iter, T grad_tol,
int linesearch_max_iter, int lbfgs_memory, int verbosity, T *w0,
T *f, int *num_iters, bool X_col_major, int loss_type,
- cudaStream_t stream) {
+ cudaStream_t stream, T *sample_weight = nullptr) {
STORAGE_ORDER ord = X_col_major ? COL_MAJOR : ROW_MAJOR;
int C_len = (loss_type == 0) ? (C - 1) : C;

@@ -75,20 +75,23 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C,
case 0: {
ASSERT(C == 2, "qn.h: logistic loss invalid C");
LogisticLoss<T> loss(handle, D, fit_intercept);
+ if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
} break;
case 1: {
ASSERT(C == 1, "qn.h: squared loss invalid C");
SquaredLoss<T> loss(handle, D, fit_intercept);
+ if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
} break;
case 2: {
ASSERT(C > 2, "qn.h: softmax invalid C");
Softmax<T> loss(handle, D, C, fit_intercept);
+ if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
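
All three losses get the same treatment: the weights are attached to the loss object before `qn_fit` runs, so the solver itself stays weight-agnostic. A runnable Python schematic of the dispatch above (the `Loss` class and factory logic are stand-ins, not cuML API):

```python
# Schematic of qnFit's loss dispatch; Loss is a stand-in, not cuML API
class Loss:
    def __init__(self, name):
        self.name, self.sample_weights = name, None

    def add_sample_weights(self, w):
        # The solver later sees a weighted loss; nothing else changes
        self.sample_weights = w

def qn_fit_dispatch(loss_type, C, sample_weight=None):
    name, valid = {0: ("logistic", C == 2),           # binary
                   1: ("squared", C == 1),            # regression
                   2: ("softmax", C > 2)}[loss_type]  # multiclass
    assert valid, "qn.h: {} loss invalid C".format(name)
    loss = Loss(name)
    if sample_weight is not None:
        loss.add_sample_weights(sample_weight)
    return loss

print(qn_fit_dispatch(0, 2, [1.0, 2.0]).sample_weights)  # [1.0, 2.0]
```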
83 changes: 76 additions & 7 deletions python/cuml/linear_model/logistic_regression.pyx
@@ -16,6 +16,7 @@

# distutils: language = c++

+ import numpy as np
import cupy as cp
import pprint

@@ -126,6 +127,15 @@ class LogisticRegression(Base,
If False, the model expects that you have centered the data.
- class_weight: None
-     Custom class weighs are currently not supported.
+ class_weight: dict or 'balanced', default=None
+     By default all classes have weight one. However, a dictionary
+     can be provided with weights associated with classes
+     in the form ``{class_label: weight}``. The "balanced" mode
+     uses the values of y to automatically adjust weights
+     inversely proportional to class frequencies in the input data
+     as ``n_samples / (n_classes * np.bincount(y))``. Note that
+     these weights will be multiplied with sample_weight
+     (passed through the fit method) if sample_weight is specified.
max_iter: int (default = 1000)
Maximum number of iterations taken for the solvers to converge.
linesearch_max_iter: int (default = 50)
@@ -178,6 +188,8 @@
"""

classes_ = CumlArrayDescriptor()
+ class_weight_ = CumlArrayDescriptor()
+ expl_spec_weights_ = CumlArrayDescriptor()

def __init__(
self,
@@ -199,9 +211,6 @@
handle=handle, verbose=verbose, output_type=output_type
)

- if class_weight:
-     raise ValueError("`class_weight` not supported.")
-
if penalty not in supported_penalties:
raise ValueError("`penalty` " + str(penalty) + "not supported.")

@@ -246,6 +255,21 @@

loss = "sigmoid"

+ if class_weight is not None:
+     if class_weight == 'balanced':
+         self.class_weight_ = 'balanced'
+     else:
+         classes = list(class_weight.keys())
+         weights = list(class_weight.values())
+         max_class = sorted(classes)[-1]
+         class_weight = cp.ones(max_class + 1)
+         class_weight[classes] = weights
+         self.class_weight_, _, _, _ = input_to_cuml_array(class_weight)
+         self.expl_spec_weights_, _, _, _ = \
+             input_to_cuml_array(np.array(classes))
+ else:
+     self.class_weight_ = None
+
self.solver_model = QN(
loss=loss,
fit_intercept=self.fit_intercept,
@@ -267,19 +291,63 @@

@generate_docstring()
@cuml.internals.api_base_return_any(set_output_dtype=True)
- def fit(self, X, y, convert_dtype=True) -> "LogisticRegression":
+ def fit(self, X, y, sample_weight=None,
+         convert_dtype=True) -> "LogisticRegression":
"""
Fit the model with X and y.

"""
# Converting y to device array here to use `unique` function
# since calling input_to_cuml_array again in QN has no cost
# Not needed to check dtype since qn class checks it already
- y_m, _, _, _ = input_to_cuml_array(y)
-
+ y_m, n_rows, _, _ = input_to_cuml_array(y)
self.classes_ = cp.unique(y_m)
self._num_classes = len(self.classes_)

+ if self._num_classes == 2:
+     if self.classes_[0] != 0 or self.classes_[1] != 1:
+         raise ValueError("Only values of 0 and 1 are"
+                          " supported for binary classification.")
+
+ if sample_weight is not None or self.class_weight_ is not None:
+     if sample_weight is None:
+         sample_weight = cp.ones(n_rows)
+
+     sample_weight, n_weights, D, _ = input_to_cuml_array(sample_weight)
+
+     if n_rows != n_weights or D != 1:
+         raise ValueError("sample_weight.shape == {}, "
+                          "expected ({},)!".format(sample_weight.shape,
+                                                   n_rows))
+
+     def check_expl_spec_weights():
+         with cuml.using_output_type("numpy"):
+             for c in self.expl_spec_weights_:
+                 i = np.searchsorted(self.classes_, c)
+                 if i >= self._num_classes or self.classes_[i] != c:
+                     msg = "Class label {} not present.".format(c)
+                     raise ValueError(msg)
+
+     if self.class_weight_ is not None:
+         if self.class_weight_ == 'balanced':
+             class_weight = n_rows / \
+                 (self._num_classes *
+                  cp.bincount(y_m.to_output('cupy')))
+             class_weight = CumlArray(class_weight)
+         else:
+             check_expl_spec_weights()
+             n_explicit = self.class_weight_.shape[0]
+             if n_explicit != self._num_classes:
+                 class_weight = cp.ones(self._num_classes)
+                 class_weight[:n_explicit] = self.class_weight_
+                 class_weight = CumlArray(class_weight)
+                 self.class_weight_ = class_weight
+             else:
+                 class_weight = self.class_weight_
+         out = y_m.to_output('cupy')
+         sample_weight *= class_weight[out].to_output('cupy')
+         sample_weight = CumlArray(sample_weight)
+
if self._num_classes > 2:
loss = "softmax"
else:
@@ -293,7 +361,8 @@
if logger.should_log_for(logger.level_debug):
logger.debug(self.verb_prefix + "Calling QN fit " + str(loss))

- self.solver_model.fit(X, y_m, convert_dtype=convert_dtype)
+ self.solver_model.fit(X, y_m, sample_weight=sample_weight,
+                       convert_dtype=convert_dtype)

# coefficients and intercept are contained in the same array
if logger.should_log_for(logger.level_debug):
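
The 'balanced' branch reproduces scikit-learn's heuristic, `n_samples / (n_classes * np.bincount(y))`, and the result is folded into the per-sample weights by indexing with the labels. A NumPy sketch of that combination (toy labels; `fit()` above does the same with CuPy on device):

```python
import numpy as np

y = np.array([0, 0, 0, 1])        # class 0 is three times as frequent
sample_weight = np.ones(len(y))

n_samples, n_classes = len(y), len(np.unique(y))
class_weight = n_samples / (n_classes * np.bincount(y))
# class_weight == [0.667, 2.0]: the rare class is upweighted

# Fold the class weights into the per-sample weights, as at the end of fit()
sample_weight *= class_weight[y]
print(sample_weight)              # [0.667 0.667 0.667 2.0]
```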
29 changes: 22 additions & 7 deletions python/cuml/solvers/qn.pyx
@@ -53,7 +53,8 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
float *f,
int *num_iters,
bool X_col_major,
- int loss_type) except +
+ int loss_type,
+ float *sample_weight) except +

void qnFit(handle_t& cuml_handle,
double *X,
@@ -73,7 +74,8 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
double *f,
int *num_iters,
bool X_col_major,
- int loss_type) except +
+ int loss_type,
+ double *sample_weight) except +

void qnDecisionFunction(handle_t& cuml_handle,
float *X,
@@ -291,7 +293,7 @@ class QN(Base,
return self._coef_

@generate_docstring()
- def fit(self, X, y, convert_dtype=False) -> "QN":
+ def fit(self, X, y, sample_weight=None, convert_dtype=False) -> "QN":
"""
Fit the model with X and y.

@@ -310,14 +312,25 @@

self._num_classes = len(cp.unique(y_m))

+ cdef uintptr_t sample_weight_ptr = 0
+ if sample_weight is not None:
+     sample_weight, _, _, _ = \
+         input_to_cuml_array(sample_weight,
+                             check_dtype=self.dtype,
+                             check_rows=n_rows, check_cols=1,
+                             convert_to_dtype=(self.dtype
+                                               if convert_dtype
+                                               else None))
+     sample_weight_ptr = sample_weight.ptr
+
self.loss_type = self._get_loss_int(self.loss)
if self.loss_type != 2 and self._num_classes > 2:
raise ValueError("Only softmax (multinomial) loss supports more"
"than 2 classes.")

if self.loss_type == 2 and self._num_classes <= 2:
raise ValueError("Only softmax (multinomial) loss supports more"
"than 2 classes.")
raise ValueError("Two classes or less cannot be trained"
"with softmax (multinomial).")

if self.loss_type == 0:
self._num_classes_dim = self._num_classes - 1
@@ -357,7 +370,8 @@
<float*> &objective32,
<int*> &num_iters,
<bool> True,
- <int> self.loss_type)
+ <int> self.loss_type,
+ <float*>sample_weight_ptr)

self.objective = objective32

@@ -380,7 +394,8 @@
<double*> &objective64,
<int*> &num_iters,
<bool> True,
- <int> self.loss_type)
+ <int> self.loss_type,
+ <double*>sample_weight_ptr)

self.objective = objective64

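
The QN solver can also be driven directly with the new keyword; a minimal sketch (toy data; the weights must match X's dtype, since they are validated with `check_dtype` and handed to C++ as a raw device pointer):

```python
import cupy as cp
from cuml.solvers import QN

X = cp.random.rand(50, 3, dtype=cp.float32)
y = (cp.random.rand(50) > 0.5).astype(cp.float32)
w = cp.random.rand(50, dtype=cp.float32)  # one float32 weight per row

qn = QN(loss='sigmoid', fit_intercept=True)
qn.fit(X, y, sample_weight=w)
```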
(Diff for the remaining changed file was not loaded.)