Add Gaussian Naive Bayes #4079

Merged Aug 9, 2021 · 30 commits

Commits (the diff below shows changes from 1 commit)
d0ecf57
Initial gaussian naive bayes variant
cjnolet Feb 19, 2020
053f8f3
Adding tests for gaussian naive bayes
cjnolet Feb 19, 2020
3d4ca96
Merge branch '013-fea-kmeans-da' into fea-ext-naive_bayes_variants
cjnolet Feb 27, 2020
7a0bf65
A couple changes. Starting to test end to end to at least verify the …
cjnolet Feb 27, 2020
ddda055
Merge branch 'branch-0.15' into fea-ext-naive_bayes_variants
cjnolet Jun 11, 2020
caea9d2
A little progress. There are still some different things to consider …
cjnolet Jun 12, 2020
7477a7b
A few updates to the test. We are producing the right results, just a…
cjnolet Jun 12, 2020
a1b95ad
Initial port of Bernoulli & Categorical naive bayes. Also adding `bin…
cjnolet Jun 24, 2020
2ca8938
Merge branch 'branch-21.08' into fea-ext-naive_bayes_variants
lowener Jun 10, 2021
29d8173
Merge branch 'branch-21.08' into fea-ext-naive_bayes_variants
lowener Jun 15, 2021
e144d28
Update naive bayes refactor
lowener Jun 16, 2021
a89f7cd
Update Gaussian Naive Bayes code
lowener Jun 29, 2021
4a71a32
Adding working version of GNB and Bernoulli and their test
lowener Jul 12, 2021
42dc80c
Update init and binarize primitive
lowener Jul 12, 2021
130d4ae
Adding working CategoricalNB
lowener Jul 13, 2021
f9480d9
Separating code to keep Multinomial and Bernoulli only
lowener Jul 13, 2021
f87cafd
Merge branch 'branch-21.08' into 21.08-multinomial-nb
lowener Jul 13, 2021
9945d93
Fix style
lowener Jul 13, 2021
8f4eee9
Update copyright
lowener Jul 14, 2021
e6cb185
Fix docstring
lowener Jul 14, 2021
3063586
Update tests to compare with sklearn and factorize naive bayes init
lowener Jul 19, 2021
63133c4
Fix style
lowener Jul 19, 2021
defbf41
Fix count_features_dense write order
lowener Jul 20, 2021
c63dab4
Fix class_prior parameter
lowener Jul 21, 2021
5a03348
Add Gaussian NB and tests
lowener Jul 21, 2021
07dcfa8
update api.rst
lowener Jul 22, 2021
e6b70af
fix style
lowener Jul 22, 2021
9ed4984
Merge branch 'branch-21.08' into 21.08-gaussian-nb
lowener Jul 26, 2021
2053118
Fix style
lowener Jul 26, 2021
750779e
Update GaussianNB doc for preferred sparse format
lowener Aug 2, 2021
Add Gaussian NB and tests
lowener committed Jul 21, 2021
commit 5a03348b545f3774ad492d17437fe61f2fc383f5
1 change: 1 addition & 0 deletions python/cuml/naive_bayes/__init__.py
@@ -16,3 +16,4 @@

from cuml.naive_bayes.naive_bayes import MultinomialNB
from cuml.naive_bayes.naive_bayes import BernoulliNB
from cuml.naive_bayes.naive_bayes import GaussianNB
383 changes: 383 additions & 0 deletions python/cuml/naive_bayes/naive_bayes.py
@@ -261,6 +261,389 @@ def predict_proba(self, X) -> CumlArray:
return result


class GaussianNB(_BaseNB):
"""
Gaussian Naive Bayes (GaussianNB)
Can perform online updates to model parameters via :meth:`partial_fit`.
For details on the algorithm used to update feature means and variance online,
see Stanford CS tech report STAN-CS-79-773 by Chan, Golub, and LeVeque:
http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
Parameters
----------
priors : array-like of shape (n_classes,)
Prior probabilities of the classes. If specified the priors are not
adjusted according to the data.
var_smoothing : float, default=1e-9
Portion of the largest variance of all features that is added to
variances for calculation stability.
output_type : {'input', 'cudf', 'cupy', 'numpy', 'numba'}, default=None
Variable to control output type of the results and attributes of
the estimator. If None, it'll inherit the output type set at the
module level, `cuml.global_settings.output_type`.
See :ref:`output-data-type-configuration` for more info.
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the
CUDA stream that will be used for the model's computations, so
users can run different models concurrently in different streams
by creating handles in several streams.
If it is None, a new one is created.
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
Examples
--------
>>> import cupy as cp
>>> X = cp.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]],
...              cp.float32)
>>> Y = cp.array([1, 1, 1, 2, 2, 2], cp.float32)
>>> from cuml.naive_bayes import GaussianNB
>>> clf = GaussianNB()
>>> clf.fit(X, Y)
GaussianNB()
>>> print(clf.predict(cp.array([[-0.8, -1]], cp.float32)))
[1]
>>> clf_pf = GaussianNB()
>>> clf_pf.partial_fit(X, Y, cp.unique(Y))
GaussianNB()
>>> print(clf_pf.predict(cp.array([[-0.8, -1]], cp.float32)))
[1]
"""

def __init__(self, *, priors=None, var_smoothing=1e-9,
output_type=None, handle=None, verbose=False):

super(GaussianNB, self).__init__(handle=handle,
verbose=verbose,
output_type=output_type)
self.priors = priors
self.var_smoothing = var_smoothing
self.fit_called_ = False
self.classes_ = None

def fit(self, X, y, sample_weight=None) -> "GaussianNB":
"""
Fit Gaussian Naive Bayes classifier according to X, y
Parameters
----------
X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
Training vectors, where n_samples is the number of samples and
n_features is the number of features.
y : array-like of shape (n_samples,)
Target values.
sample_weight : array-like of shape (n_samples)
Weights applied to individual samples (1. for unweighted).
Currently sample weight is ignored.
"""
return self._partial_fit(X, y, _classes=cp.unique(y), _refit=True,
sample_weight=sample_weight)

@nvtx.annotate(message="naive_bayes.GaussianNB._partial_fit",
domain="cuml_python")
def _partial_fit(self, X, y, _classes=None, _refit=False,
sample_weight=None, convert_dtype=True) -> "GaussianNB":
if has_scipy():
from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix
else:
from cuml.common.import_utils import dummy_function_always_false \
as scipy_sparse_isspmatrix

if getattr(self, 'classes_') is None and _classes is None:
raise ValueError("classes must be passed on the first call "
"to partial_fit.")

if scipy_sparse_isspmatrix(X) or cp.sparse.isspmatrix(X):
X = _convert_x_sparse(X)
else:
X = input_to_cupy_array(X, order='K',
check_dtype=[cp.float32, cp.float64,
cp.int32]).array

expected_y_dtype = cp.int32 if X.dtype in [cp.float32,
cp.int32] else cp.int64
y = input_to_cupy_array(y,
convert_to_dtype=(expected_y_dtype
if convert_dtype
else False),
check_dtype=expected_y_dtype).array

if _classes is not None:
_classes, *_ = input_to_cuml_array(_classes, order='K',
convert_to_dtype=(
expected_y_dtype
if convert_dtype
else False))

Y, label_classes = make_monotonic(y, classes=_classes,
copy=True)
if _refit:
self.classes_ = None

def var_sparse(X, axis=0):
# Compute the variance on dense and sparse matrices
return ((X - X.mean(axis=axis)) ** 2).mean(axis=axis)

self.epsilon_ = self.var_smoothing * var_sparse(X).max()

if not self.fit_called_:
self.fit_called_ = True

# Original labels are stored on the instance
if _classes is not None:
check_labels(Y, _classes.to_output('cupy'))
self.classes_ = _classes
else:
self.classes_ = label_classes

n_features = X.shape[1]
n_classes = len(self.classes_)

self.n_classes_ = n_classes
self.n_features_ = n_features

self.theta_ = cp.zeros((n_classes, n_features))
self.sigma_ = cp.zeros((n_classes, n_features))

self.class_count_ = cp.zeros(n_classes, dtype=X.dtype)

if self.priors is not None:
if len(self.priors) != n_classes:
raise ValueError("Number of priors must match number of"
" classes.")
if not cp.isclose(self.priors.sum(), 1):
raise ValueError('The sum of the priors should be 1.')
if (self.priors < 0).any():
raise ValueError('Priors must be non-negative.')
self.class_prior, *_ = input_to_cupy_array(
self.priors,
check_dtype=[cp.float32, cp.float64])

else:
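# This is a subsequent partial_fit call: remove the smoothing that was
# added to the variances at the end of the previous call, so that it
# does not compound across batches (it is re-added below after the update).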
self.sigma_[:, :] -= self.epsilon_

unique_y = cp.unique(y)
unique_y_in_classes = cp.in1d(unique_y, cp.array(self.classes_))

if not cp.all(unique_y_in_classes):
raise ValueError("The target label(s) %s in y do not exist "
"in the initial classes %s" %
(unique_y[~unique_y_in_classes], self.classes_))

self.theta_, self.sigma_ = self._update_mean_variance(X, Y)

self.sigma_[:, :] += self.epsilon_

if self.priors is None:
self.class_prior = self.class_count_ / self.class_count_.sum()

return self

def partial_fit(self, X, y, classes=None,
sample_weight=None) -> "GaussianNB":
"""
Incremental fit on a batch of samples.
This method is expected to be called several times consecutively on
different chunks of a dataset so as to implement out-of-core or online
learning.
This is especially useful when the whole dataset is too big to fit in
memory at once.
This method has some performance overhead, so it is better to call
partial_fit on chunks of data that are as large as possible (as long as
they fit in the memory budget) to hide the overhead.
Parameters
----------
X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
[Reviewer comment]
We might want to make a mention here that the "optimal" input for the sparse matrix is COOrdinate format, otherwise it will be copied in order to convert to that format.
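A minimal sketch of what that looks like from the caller's side (hypothetical toy data; `cupyx.scipy.sparse` is assumed for building the input): converting to COO once up front sidesteps the copy, since the sparse kernels in this diff read `X.row`, `X.col`, and `X.data`.

import cupy as cp
import cupyx.scipy.sparse as cps
from cuml.naive_bayes import GaussianNB

# Hypothetical toy data: a small CSR matrix and integer labels.
X_csr = cps.csr_matrix(cp.array([[0., 1.], [2., 0.], [0., 3.]],
                                dtype=cp.float32))
y = cp.array([0, 1, 0], dtype=cp.int32)

# The sparse code path iterates X.row / X.col / X.data (COO attributes),
# so converting once here avoids a copy inside each fit/partial_fit call.
X_coo = X_csr.tocoo()

GaussianNB().fit(X_coo, y)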

Training vectors, where n_samples is the number of samples and
n_features is the number of features
y : array-like of shape (n_samples,)
Target values.
classes : array-like of shape (n_classes)
List of all the classes that can possibly appear in the y
vector. Must be provided at the first call to partial_fit,
can be omitted in subsequent calls.
sample_weight : array-like of shape (n_samples)
Weights applied to individual samples (1. for
unweighted). Currently sample weight is ignored.
Returns
-------
self : object
"""
return self._partial_fit(X, y, classes, _refit=False,
sample_weight=sample_weight)

def _update_mean_variance(self, X, Y, sample_weight=None):

if sample_weight is None:
sample_weight = cp.zeros(0)
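# A zero-length array is the sentinel the count_features kernels below
# check (via sample_weight.shape[0] > 0) to skip per-sample weighting.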

labels_dtype = self.classes_.dtype

mu = self.theta_
var = self.sigma_

early_return = self.class_count_.sum() == 0
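# class_count_ is still all zeros on the first batch, so there are no
# running statistics to merge; the new-batch mean/variance computed
# below will be returned as-is.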
n_past = cp.expand_dims(self.class_count_, axis=1).copy()
tpb = 32
n_rows = X.shape[0]
n_cols = X.shape[1]

if X.shape[0] == 0:
return mu, var

# Make sure Y is a cupy array, not a CumlArray
Y = cp.asarray(Y)

new_mu = cp.zeros((self.n_classes_, self.n_features_), order="F",
dtype=X.dtype)
new_var = cp.zeros((self.n_classes_, self.n_features_), order="F",
dtype=X.dtype)
class_counts = cp.zeros(self.n_classes_, order="F", dtype=X.dtype)
if cp.sparse.isspmatrix(X):
X = X.tocoo()

count_features_coo = count_features_coo_kernel(X.dtype,
labels_dtype)

# Run once for averages
count_features_coo((math.ceil(X.nnz / tpb),), (tpb,),
(new_mu,
X.row,
X.col,
X.data,
X.nnz,
n_rows,
n_cols,
Y,
sample_weight,
sample_weight.shape[0] > 0,
self.n_classes_, False))

# Run again for variance
count_features_coo((math.ceil(X.nnz / tpb),), (tpb,),
(new_var,
X.row,
X.col,
X.data,
X.nnz,
n_rows,
n_cols,
Y,
sample_weight,
sample_weight.shape[0] > 0,
self.n_classes_,
True))
else:

count_features_dense = count_features_dense_kernel(X.dtype,
labels_dtype)

# Run once for averages
count_features_dense((math.ceil(n_rows / tpb),
math.ceil(n_cols / tpb), 1),
(tpb, tpb, 1),
(new_mu,
X,
n_rows,
n_cols,
Y,
sample_weight,
sample_weight.shape[0] > 0,
self.n_classes_,
False,
X.flags["C_CONTIGUOUS"]))

# Run again for variance
count_features_dense((math.ceil(n_rows / tpb),
math.ceil(n_cols / tpb), 1),
(tpb, tpb, 1),
(new_var,
X,
n_rows,
n_cols,
Y,
sample_weight,
sample_weight.shape[0] > 0,
self.n_classes_,
True,
X.flags["C_CONTIGUOUS"]))

count_classes = count_classes_kernel(X.dtype, labels_dtype)
count_classes((math.ceil(n_rows / tpb),), (tpb,),
(class_counts, n_rows, Y))

self.class_count_ += class_counts
# Avoid any division by zero
class_counts = cp.expand_dims(class_counts, axis=1)
class_counts += cp.finfo(X.dtype).eps

new_mu /= class_counts

# Construct variance from sum squares
new_var = (new_var / class_counts) - new_mu ** 2

if early_return:
return new_mu, new_var

# Compute (potentially weighted) mean and variance of new datapoints
if sample_weight.shape[0] > 0:
n_new = float(sample_weight.sum())
else:
n_new = class_counts

n_total = n_past + n_new
total_mu = (new_mu * n_new + mu * n_past) / n_total

old_ssd = var * n_past
new_ssd = n_new * new_var

ssd_sum = old_ssd + new_ssd
combined_feature_counts = n_new * n_past / n_total
mean_adj = (mu - new_mu)**2

total_ssd = (ssd_sum +
combined_feature_counts *
mean_adj)

total_var = total_ssd / n_total
return total_mu, total_var
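For reference, the combining step above is the Chan, Golub, and LeVeque update cited in the class docstring: two (count, mean, variance) summaries are merged exactly, without revisiting earlier batches. A self-contained NumPy sketch of the same arithmetic (function and variable names are illustrative only):

import numpy as np

def pooled_mean_var(n_a, mu_a, var_a, n_b, mu_b, var_b):
    # Merge two (count, mean, population variance) summaries
    # (Chan, Golub & LeVeque, STAN-CS-79-773).
    n = n_a + n_b
    mu = (n_a * mu_a + n_b * mu_b) / n
    # Sums of squared deviations from each part, plus a correction
    # term for the distance between the two batch means.
    ssd = (n_a * var_a + n_b * var_b
           + (n_a * n_b / n) * (mu_a - mu_b) ** 2)
    return mu, ssd / n

# Merging two halves reproduces the full-sample statistics.
a, b = np.random.rand(50), np.random.rand(70)
mu, var = pooled_mean_var(len(a), a.mean(), a.var(),
                          len(b), b.mean(), b.var())
full = np.concatenate([a, b])
assert np.allclose([mu, var], [full.mean(), full.var()])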

def _joint_log_likelihood(self, X):
joint_log_likelihood = []

for i in range(len(self.classes_)):
jointi = cp.log(self.class_prior[i])

n_ij = -0.5 * cp.sum(cp.log(2. * cp.pi * self.sigma_[i, :]))

centered = (X - self.theta_[i, :]) ** 2
zvals = centered / self.sigma_[i, :]
summed = cp.sum(zvals, axis=1)

n_ij = -(0.5 * summed) + n_ij
joint_log_likelihood.append(jointi + n_ij)

return cp.array(joint_log_likelihood).T
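Spelled out, each class c contributes the Gaussian joint log likelihood

log P(c) - 0.5 * sum_j [ log(2 * pi * sigma_[c, j]) + (X[:, j] - theta_[c, j])**2 / sigma_[c, j] ]

where theta_ and sigma_ hold the per-class feature means and variances fitted above; the shared prediction path then takes the argmax of these values over classes.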

def get_param_names(self):
return super().get_param_names() + \
[
"priors",
"var_smoothing"
]


class _BaseDiscreteNB(_BaseNB):

def __init__(self, *, class_prior=None, verbose=False,
149 changes: 149 additions & 0 deletions python/cuml/test/test_naive_bayes.py
@@ -21,11 +21,14 @@
from sklearn.metrics import accuracy_score
from cuml.naive_bayes import MultinomialNB
from cuml.naive_bayes import BernoulliNB
from cuml.naive_bayes import GaussianNB
from cuml.common.input_utils import sparse_scipy_to_cp

from numpy.testing import assert_allclose, assert_array_equal
from numpy.testing import assert_array_almost_equal, assert_raises
from sklearn.naive_bayes import MultinomialNB as skNB
from sklearn.naive_bayes import BernoulliNB as skBNB
from sklearn.naive_bayes import GaussianNB as skGNB

import math

@@ -312,3 +315,149 @@ def test_bernoulli_partial_fit(x_dtype, y_dtype, nlp_20news):
y_sk = modelsk.predict(X.get())

assert_allclose(y_hat, y_sk)


def test_gaussian_basic():
# Data is just 6 separable points in the plane
X = cp.array([[-2, -1, -1], [-1, -1, -1], [-1, -2, -1],
[1, 1, 1], [1, 2, 1], [2, 1, 1]], dtype=cp.float32)
y = cp.array([1, 1, 1, 2, 2, 2])

skclf = skGNB()
skclf.fit(X.get(), y.get())

clf = GaussianNB()
clf.fit(X, y)

assert_array_almost_equal(clf.theta_.get(), skclf.theta_, 6)
assert_array_almost_equal(clf.sigma_.get(), skclf.sigma_, 6)

y_pred = clf.predict(X)
y_pred_proba = clf.predict_proba(X)
y_pred_log_proba = clf.predict_log_proba(X)
y_pred_proba_sk = skclf.predict_proba(X.get())
y_pred_log_proba_sk = skclf.predict_log_proba(X.get())

assert_array_equal(y_pred.get(), y.get())
assert_array_almost_equal(y_pred_proba.get(), y_pred_proba_sk, 8)
assert_allclose(y_pred_log_proba.get(), y_pred_log_proba_sk,
atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("x_dtype", [cp.float32, cp.float64])
@pytest.mark.parametrize("y_dtype", [cp.int32, cp.int64,
cp.float32, cp.float64])
@pytest.mark.parametrize("is_sparse", [True, False])
def test_gaussian_fit_predict(x_dtype, y_dtype, is_sparse,
nlp_20news):
"""
Cupy Test
"""

X, y = nlp_20news
model = GaussianNB()
n_rows = 1000

X = sparse_scipy_to_cp(X, x_dtype)
X = X.tocsr()[:n_rows]

if is_sparse:
y = y.astype(y_dtype)[:n_rows]
model.fit(X, y)
else:
X = X.todense()
y = y[:n_rows].astype(y_dtype)
model.fit(np.ascontiguousarray(cp.asnumpy(X).astype(x_dtype)), y)

y_hat = model.predict(X)
y_hat = cp.asnumpy(y_hat)
y = cp.asnumpy(y)

assert accuracy_score(y, y_hat) >= 0.99


def test_gaussian_partial_fit(nlp_20news):
chunk_size = 200
n_rows = 1000
x_dtype, y_dtype = cp.float32, cp.int32

X, y = nlp_20news

X = sparse_scipy_to_cp(X, x_dtype).tocsr()[:n_rows]
y = y.astype(y_dtype)[:n_rows]

model = GaussianNB()
modelsk = skGNB()

classes = np.unique(y)

total_fit = 0

for i in range(math.ceil(X.shape[0] / chunk_size)):

upper = i*chunk_size+chunk_size
if upper > X.shape[0]:
upper = -1

if upper > 0:
x = X[i*chunk_size:upper]
y_c = y[i*chunk_size:upper]
else:
x = X[i*chunk_size:]
y_c = y[i*chunk_size:]

modelsk.partial_fit(x.get().toarray(),
y_c.get(),
classes=classes.get())
model.partial_fit(x, y_c, classes=classes)

total_fit += (upper - (i*chunk_size))

if upper == -1:
break

y_hat = model.predict(X)
y_sk = modelsk.predict(X.get().toarray())

y_hat = cp.asnumpy(y_hat)
y = cp.asnumpy(y)
assert_array_equal(y_hat, y_sk)
assert accuracy_score(y, y_hat) >= 0.924

# Test whether label mismatch between target y and classes raises an Error
assert_raises(ValueError,
GaussianNB().partial_fit, X, y, classes=cp.array([0, 1]))
# Raise because classes is required on first call of partial_fit
assert_raises(ValueError, GaussianNB().partial_fit, X, y)


@pytest.mark.parametrize("priors", [None, 'balanced', 'unbalanced'])
@pytest.mark.parametrize("var_smoothing", [1e-5, 1e-7, 1e-9])
def test_gaussian_parameters(priors, var_smoothing, nlp_20news):
x_dtype = cp.float32
y_dtype = cp.int32
nrows = 150

X, y = nlp_20news

X = sparse_scipy_to_cp(X[:nrows], x_dtype).todense()
y = y.astype(y_dtype)[:nrows]

if priors == 'balanced':
priors = cp.array([1/20] * 20)
elif priors == 'unbalanced':
priors = cp.linspace(0.01, 0.09, 20)

model = GaussianNB(priors=priors, var_smoothing=var_smoothing)
model_sk = skGNB(priors=priors.get() if priors is not None else None,
var_smoothing=var_smoothing)
model.fit(X, y)
model_sk.fit(X.get(), y.get())

y_hat = model.predict(X)
y_hat_sk = model_sk.predict(X.get())
y_hat = cp.asnumpy(y_hat)
y = cp.asnumpy(y)

assert_allclose(model.epsilon_.get(), model_sk.epsilon_, rtol=1e-4)
assert_array_equal(y_hat, y_hat_sk)