From 3b439a09a2a9639da4b877e3ee752e36e4dcd842 Mon Sep 17 00:00:00 2001
From: Micka <9810050+lowener@users.noreply.github.com>
Date: Mon, 9 Aug 2021 22:04:21 +0200
Subject: [PATCH] Add Gaussian Naive Bayes (#4079)

This is a continuation of PRs #1763 and #4053, adding Gaussian Naive
Bayes. It is meant to be merged after #4053.

Here is a comparison of cuML and SKLearn performance on Gaussian NB,
using a synthetic dataset generated by make_regression. The GPU used is
an RTX 8000, and the CPU an i9-10920X @ 3.50GHz.

![gaussian](https://user-images.githubusercontent.com/9810050/126572439-8982faa8-5ad1-4bca-91ab-76704050bf33.png)

Linking issue #1666

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: https://github.com/rapidsai/cuml/pull/4079
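A note for reviewers: the incremental update implemented in
`_update_mean_variance` follows the Chan, Golub, and LeVeque report cited
in the class docstring. Below is a minimal NumPy sketch of that merge
rule, simplified to a single unweighted group of samples; the patch itself
applies it per class, on the GPU, with optional sample weights. The name
`merge_mean_var` is illustrative, not part of the patch.

```python
import numpy as np

def merge_mean_var(n_past, mu, var, batch):
    """Fold a new batch into running (count, mean, variance) following
    Chan, Golub & LeVeque's pairwise update (simplified sketch)."""
    n_new = batch.shape[0]
    new_mu = batch.mean(axis=0)
    new_var = batch.var(axis=0)

    n_total = n_past + n_new
    total_mu = (n_new * new_mu + n_past * mu) / n_total

    # Sums of squared deviations combine additively, plus a correction
    # term that accounts for the shift between the two means.
    total_ssd = (n_past * var + n_new * new_var +
                 (n_past * n_new / n_total) * (mu - new_mu) ** 2)
    return n_total, total_mu, total_ssd / n_total

# Streaming over chunks reproduces the variance of the full dataset.
rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 3))
n, mu, var = 0, np.zeros(3), np.zeros(3)
for chunk in np.array_split(data, 10):
    n, mu, var = merge_mean_var(n, mu, var, chunk)
np.testing.assert_allclose(var, data.var(axis=0))
```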
---
 docs/source/api.rst                    |   3 +
 python/cuml/naive_bayes/__init__.py    |   1 +
 python/cuml/naive_bayes/naive_bayes.py | 379 +++++++++++++++++++++++++
 python/cuml/test/test_naive_bayes.py   | 149 ++++++++++
 4 files changed, 532 insertions(+)

diff --git a/docs/source/api.rst b/docs/source/api.rst
index 81750597aa..889c307b3c 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -267,6 +267,9 @@ Naive Bayes
 .. autoclass:: cuml.naive_bayes.BernoulliNB
     :members:
 
+.. autoclass:: cuml.naive_bayes.GaussianNB
+    :members:
+
 Stochastic Gradient Descent
 ---------------------------
 
diff --git a/python/cuml/naive_bayes/__init__.py b/python/cuml/naive_bayes/__init__.py
index 67202b15bb..73d7b01a46 100644
--- a/python/cuml/naive_bayes/__init__.py
+++ b/python/cuml/naive_bayes/__init__.py
@@ -16,3 +16,4 @@
 
 from cuml.naive_bayes.naive_bayes import MultinomialNB
 from cuml.naive_bayes.naive_bayes import BernoulliNB
+from cuml.naive_bayes.naive_bayes import GaussianNB
diff --git a/python/cuml/naive_bayes/naive_bayes.py b/python/cuml/naive_bayes/naive_bayes.py
index c1f04a6d58..4118d41f04 100644
--- a/python/cuml/naive_bayes/naive_bayes.py
+++ b/python/cuml/naive_bayes/naive_bayes.py
@@ -261,6 +261,385 @@ def predict_proba(self, X) -> CumlArray:
         return result
 
 
+class GaussianNB(_BaseNB):
+    """
+    Gaussian Naive Bayes (GaussianNB)
+
+    Can perform online updates to model parameters via :meth:`partial_fit`.
+    For details on the algorithm used to update feature means and variance
+    online, see Stanford CS tech report STAN-CS-79-773 by Chan, Golub, and
+    LeVeque:
+
+    http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
+
+    Parameters
+    ----------
+    priors : array-like of shape (n_classes,)
+        Prior probabilities of the classes. If specified, the priors are
+        not adjusted according to the data.
+    var_smoothing : float, default=1e-9
+        Portion of the largest variance of all features that is added to
+        variances for calculation stability.
+    output_type : {'input', 'cudf', 'cupy', 'numpy', 'numba'}, default=None
+        Variable to control output type of the results and attributes of
+        the estimator. If None, it'll inherit the output type set at the
+        module level, `cuml.global_settings.output_type`.
+        See :ref:`output-data-type-configuration` for more info.
+    handle : cuml.Handle
+        Specifies the cuml.handle that holds internal CUDA state for
+        computations in this model. Most importantly, this specifies the
+        CUDA stream that will be used for the model's computations, so
+        users can run different models concurrently in different streams
+        by creating handles in several streams.
+        If it is None, a new one is created.
+    verbose : int or boolean, default=False
+        Sets logging level. It must be one of `cuml.common.logger.level_*`.
+        See :ref:`verbosity-levels` for more info.
+
+    Examples
+    --------
+    >>> import cupy as cp
+    >>> X = cp.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]],
+    ...              cp.float32)
+    >>> Y = cp.array([1, 1, 1, 2, 2, 2], cp.float32)
+    >>> from cuml.naive_bayes import GaussianNB
+    >>> clf = GaussianNB()
+    >>> clf.fit(X, Y)
+    GaussianNB()
+    >>> print(clf.predict(cp.array([[-0.8, -1]], cp.float32)))
+    [1]
+    >>> clf_pf = GaussianNB()
+    >>> clf_pf.partial_fit(X, Y, cp.unique(Y))
+    GaussianNB()
+    >>> print(clf_pf.predict(cp.array([[-0.8, -1]], cp.float32)))
+    [1]
+    """
+
+    def __init__(self, *, priors=None, var_smoothing=1e-9,
+                 output_type=None, handle=None, verbose=False):
+
+        super(GaussianNB, self).__init__(handle=handle,
+                                         verbose=verbose,
+                                         output_type=output_type)
+        self.priors = priors
+        self.var_smoothing = var_smoothing
+        self.fit_called_ = False
+        self.classes_ = None
+
+    def fit(self, X, y, sample_weight=None) -> "GaussianNB":
+        """
+        Fit Gaussian Naive Bayes classifier according to X, y.
+
+        Parameters
+        ----------
+        X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
+            Training vectors, where n_samples is the number of samples and
+            n_features is the number of features.
+        y : array-like of shape (n_samples,)
+            Target values.
+        sample_weight : array-like of shape (n_samples,)
+            Weights applied to individual samples (1. for unweighted).
+            Currently sample weight is ignored.
+        """
+        return self._partial_fit(X, y, _classes=cp.unique(y), _refit=True,
+                                 sample_weight=sample_weight)
+
+    @nvtx.annotate(message="naive_bayes.GaussianNB._partial_fit",
+                   domain="cuml_python")
+    def _partial_fit(self, X, y, _classes=None, _refit=False,
+                     sample_weight=None, convert_dtype=True) -> "GaussianNB":
+        if has_scipy():
+            from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix
+        else:
+            from cuml.common.import_utils import dummy_function_always_false \
+                as scipy_sparse_isspmatrix
+
+        if getattr(self, 'classes_') is None and _classes is None:
+            raise ValueError("classes must be passed on the first call "
+                             "to partial_fit.")
+
+        if scipy_sparse_isspmatrix(X) or cp.sparse.isspmatrix(X):
+            X = _convert_x_sparse(X)
+        else:
+            X = input_to_cupy_array(X, order='K',
+                                    check_dtype=[cp.float32, cp.float64,
+                                                 cp.int32]).array
+
+        expected_y_dtype = cp.int32 if X.dtype in [cp.float32,
+                                                   cp.int32] else cp.int64
+        y = input_to_cupy_array(y,
+                                convert_to_dtype=(expected_y_dtype
+                                                  if convert_dtype
+                                                  else False),
+                                check_dtype=expected_y_dtype).array
+
+        if _classes is not None:
+            _classes, *_ = input_to_cuml_array(_classes, order='K',
+                                               convert_to_dtype=(
+                                                   expected_y_dtype
+                                                   if convert_dtype
+                                                   else False))
+
+        Y, label_classes = make_monotonic(y, classes=_classes,
+                                          copy=True)
+        if _refit:
+            self.classes_ = None
+
+        def var_sparse(X, axis=0):
+            # Compute the variance on dense and sparse matrices
+            return ((X - X.mean(axis=axis)) ** 2).mean(axis=axis)
+
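+        # var_smoothing guards against zero variances: epsilon_ is a small
+        # fraction of the largest feature variance, added to every per-class
+        # variance after each update (and removed again before the next
+        # incremental update) to keep the log-likelihood well defined.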
+        self.epsilon_ = self.var_smoothing * var_sparse(X).max()
+
+        if not self.fit_called_:
+            self.fit_called_ = True
+
+            # Original labels are stored on the instance
+            if _classes is not None:
+                check_labels(Y, _classes.to_output('cupy'))
+                self.classes_ = _classes
+            else:
+                self.classes_ = label_classes
+
+            n_features = X.shape[1]
+            n_classes = len(self.classes_)
+
+            self.n_classes_ = n_classes
+            self.n_features_ = n_features
+
+            self.theta_ = cp.zeros((n_classes, n_features))
+            self.sigma_ = cp.zeros((n_classes, n_features))
+
+            self.class_count_ = cp.zeros(n_classes, dtype=X.dtype)
+
+            if self.priors is not None:
+                if len(self.priors) != n_classes:
+                    raise ValueError("Number of priors must match number of"
+                                     " classes.")
+                if not cp.isclose(self.priors.sum(), 1):
+                    raise ValueError('The sum of the priors should be 1.')
+                if (self.priors < 0).any():
+                    raise ValueError('Priors must be non-negative.')
+                self.class_prior, *_ = input_to_cupy_array(
+                    self.priors,
+                    check_dtype=[cp.float32, cp.float64])
+
+        else:
+            self.sigma_[:, :] -= self.epsilon_
+
+        unique_y = cp.unique(y)
+        unique_y_in_classes = cp.in1d(unique_y, cp.array(self.classes_))
+
+        if not cp.all(unique_y_in_classes):
+            raise ValueError("The target label(s) %s in y do not exist "
+                             "in the initial classes %s" %
+                             (unique_y[~unique_y_in_classes], self.classes_))
+
+        self.theta_, self.sigma_ = self._update_mean_variance(X, Y)
+
+        self.sigma_[:, :] += self.epsilon_
+
+        if self.priors is None:
+            self.class_prior = self.class_count_ / self.class_count_.sum()
+
+        return self
+
+    def partial_fit(self, X, y, classes=None,
+                    sample_weight=None) -> "GaussianNB":
+        """
+        Incremental fit on a batch of samples.
+
+        This method is expected to be called several times consecutively on
+        different chunks of a dataset so as to implement out-of-core or
+        online learning. This is especially useful when the whole dataset
+        is too big to fit in memory at once. This method has some
+        performance overhead, so it is better to call partial_fit on chunks
+        of data that are as large as possible (as long as they fit in the
+        memory budget) to hide the overhead.
+
+        Parameters
+        ----------
+        X : {array-like, cupy sparse matrix} of shape (n_samples, n_features)
+            Training vectors, where n_samples is the number of samples and
+            n_features is the number of features. A sparse matrix in COO
+            format is preferred; other formats will be converted to COO.
+        y : array-like of shape (n_samples,)
+            Target values.
+        classes : array-like of shape (n_classes,)
+            List of all the classes that can possibly appear in the y
+            vector. Must be provided at the first call to partial_fit,
+            can be omitted in subsequent calls.
+        sample_weight : array-like of shape (n_samples,)
+            Weights applied to individual samples (1. for unweighted).
+            Currently sample weight is ignored.
+
+        Returns
+        -------
+        self : object
+        """
+        return self._partial_fit(X, y, classes, _refit=False,
+                                 sample_weight=sample_weight)
+
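+    # _update_mean_variance computes the per-class mean and variance of the
+    # incoming batch with CUDA kernels, then folds them into the running
+    # statistics using the online update described in the PR description.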
+    def _update_mean_variance(self, X, Y, sample_weight=None):
+
+        if sample_weight is None:
+            sample_weight = cp.zeros(0)
+
+        labels_dtype = self.classes_.dtype
+
+        mu = self.theta_
+        var = self.sigma_
+
+        early_return = self.class_count_.sum() == 0
+        n_past = cp.expand_dims(self.class_count_, axis=1).copy()
+        tpb = 32
+        n_rows = X.shape[0]
+        n_cols = X.shape[1]
+
+        if X.shape[0] == 0:
+            return mu, var
+
+        # Make sure Y is a cp array, not a CumlArray
+        Y = cp.asarray(Y)
+
+        new_mu = cp.zeros((self.n_classes_, self.n_features_), order="F",
+                          dtype=X.dtype)
+        new_var = cp.zeros((self.n_classes_, self.n_features_), order="F",
+                           dtype=X.dtype)
+        class_counts = cp.zeros(self.n_classes_, order="F", dtype=X.dtype)
+        if cp.sparse.isspmatrix(X):
+            X = X.tocoo()
+
+            count_features_coo = count_features_coo_kernel(X.dtype,
+                                                           labels_dtype)
+
+            # Run once for averages
+            count_features_coo((math.ceil(X.nnz / tpb),), (tpb,),
+                               (new_mu,
+                                X.row,
+                                X.col,
+                                X.data,
+                                X.nnz,
+                                n_rows,
+                                n_cols,
+                                Y,
+                                sample_weight,
+                                sample_weight.shape[0] > 0,
+                                self.n_classes_, False))
+
+            # Run again for variance
+            count_features_coo((math.ceil(X.nnz / tpb),), (tpb,),
+                               (new_var,
+                                X.row,
+                                X.col,
+                                X.data,
+                                X.nnz,
+                                n_rows,
+                                n_cols,
+                                Y,
+                                sample_weight,
+                                sample_weight.shape[0] > 0,
+                                self.n_classes_,
+                                True))
+        else:
+
+            count_features_dense = count_features_dense_kernel(X.dtype,
+                                                               labels_dtype)
+
+            # Run once for averages
+            count_features_dense((math.ceil(n_rows / tpb),
+                                  math.ceil(n_cols / tpb), 1),
+                                 (tpb, tpb, 1),
+                                 (new_mu,
+                                  X,
+                                  n_rows,
+                                  n_cols,
+                                  Y,
+                                  sample_weight,
+                                  sample_weight.shape[0] > 0,
+                                  self.n_classes_,
+                                  False,
+                                  X.flags["C_CONTIGUOUS"]))
+
+            # Run again for variance
+            count_features_dense((math.ceil(n_rows / tpb),
+                                  math.ceil(n_cols / tpb), 1),
+                                 (tpb, tpb, 1),
+                                 (new_var,
+                                  X,
+                                  n_rows,
+                                  n_cols,
+                                  Y,
+                                  sample_weight,
+                                  sample_weight.shape[0] > 0,
+                                  self.n_classes_,
+                                  True,
+                                  X.flags["C_CONTIGUOUS"]))
+
+        count_classes = count_classes_kernel(X.dtype, labels_dtype)
+        count_classes((math.ceil(n_rows / tpb),), (tpb,),
+                      (class_counts, n_rows, Y))
+
+        self.class_count_ += class_counts
+        # Avoid any division by zero
+        class_counts = cp.expand_dims(class_counts, axis=1)
+        class_counts += cp.finfo(X.dtype).eps
+
+        new_mu /= class_counts
+
+        # Construct variance from sums of squares
+        new_var = (new_var / class_counts) - new_mu ** 2
+
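+        # On the first batch there are no past statistics to merge with,
+        # so the batch statistics become the model statistics. Otherwise,
+        # combine old and new values below, weighting each side by its
+        # sample count and correcting for the shift between the two means.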
+        if early_return:
+            return new_mu, new_var
+
+        # Compute (potentially weighted) mean and variance of new datapoints
+        if sample_weight.shape[0] > 0:
+            n_new = float(sample_weight.sum())
+        else:
+            n_new = class_counts
+
+        n_total = n_past + n_new
+        total_mu = (new_mu * n_new + mu * n_past) / n_total
+
+        old_ssd = var * n_past
+        new_ssd = n_new * new_var
+
+        ssd_sum = old_ssd + new_ssd
+        combined_feature_counts = n_new * n_past / n_total
+        mean_adj = (mu - new_mu)**2
+
+        total_ssd = (ssd_sum +
+                     combined_feature_counts *
+                     mean_adj)
+
+        total_var = total_ssd / n_total
+        return total_mu, total_var
+
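+    # Joint log likelihood of each class c given a row x:
+    #   log P(c) - 0.5 * sum_j [log(2 * pi * sigma_cj)
+    #                           + (x_j - theta_cj)^2 / sigma_cj]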
+    def _joint_log_likelihood(self, X):
+        joint_log_likelihood = []
+
+        for i in range(len(self.classes_)):
+            jointi = cp.log(self.class_prior[i])
+
+            n_ij = -0.5 * cp.sum(cp.log(2. * cp.pi * self.sigma_[i, :]))
+
+            centered = (X - self.theta_[i, :]) ** 2
+            zvals = centered / self.sigma_[i, :]
+            summed = cp.sum(zvals, axis=1)
+
+            n_ij = -(0.5 * summed) + n_ij
+            joint_log_likelihood.append(jointi + n_ij)
+
+        return cp.array(joint_log_likelihood).T
+
+    def get_param_names(self):
+        return super().get_param_names() + \
+            [
+                "priors",
+                "var_smoothing"
+            ]
+
+
 class _BaseDiscreteNB(_BaseNB):
 
     def __init__(self, *, class_prior=None, verbose=False,
diff --git a/python/cuml/test/test_naive_bayes.py b/python/cuml/test/test_naive_bayes.py
index 2c4f18fde0..f5b00c1fc6 100644
--- a/python/cuml/test/test_naive_bayes.py
+++ b/python/cuml/test/test_naive_bayes.py
@@ -21,11 +21,14 @@
 from sklearn.metrics import accuracy_score
 from cuml.naive_bayes import MultinomialNB
 from cuml.naive_bayes import BernoulliNB
+from cuml.naive_bayes import GaussianNB
 from cuml.common.input_utils import sparse_scipy_to_cp
 
 from numpy.testing import assert_allclose, assert_array_equal
+from numpy.testing import assert_array_almost_equal, assert_raises
 
 from sklearn.naive_bayes import MultinomialNB as skNB
 from sklearn.naive_bayes import BernoulliNB as skBNB
+from sklearn.naive_bayes import GaussianNB as skGNB
 
 import math
@@ -312,3 +315,149 @@ def test_bernoulli_partial_fit(x_dtype, y_dtype, nlp_20news):
         y_sk = modelsk.predict(X.get())
 
     assert_allclose(y_hat, y_sk)
+
+
+def test_gaussian_basic():
+    # Data is just 6 separable points in the plane
+    X = cp.array([[-2, -1, -1], [-1, -1, -1], [-1, -2, -1],
+                  [1, 1, 1], [1, 2, 1], [2, 1, 1]], dtype=cp.float32)
+    y = cp.array([1, 1, 1, 2, 2, 2])
+
+    skclf = skGNB()
+    skclf.fit(X.get(), y.get())
+
+    clf = GaussianNB()
+    clf.fit(X, y)
+
+    assert_array_almost_equal(clf.theta_.get(), skclf.theta_, 6)
+    assert_array_almost_equal(clf.sigma_.get(), skclf.sigma_, 6)
+
+    y_pred = clf.predict(X)
+    y_pred_proba = clf.predict_proba(X)
+    y_pred_log_proba = clf.predict_log_proba(X)
+    y_pred_proba_sk = skclf.predict_proba(X.get())
+    y_pred_log_proba_sk = skclf.predict_log_proba(X.get())
+
+    assert_array_equal(y_pred.get(), y.get())
+    assert_array_almost_equal(y_pred_proba.get(), y_pred_proba_sk, 8)
+    assert_allclose(y_pred_log_proba.get(), y_pred_log_proba_sk,
+                    atol=1e-2, rtol=1e-2)
+
+
+@pytest.mark.parametrize("x_dtype", [cp.float32, cp.float64])
+@pytest.mark.parametrize("y_dtype", [cp.int32, cp.int64,
+                                     cp.float32, cp.float64])
+@pytest.mark.parametrize("is_sparse", [True, False])
+def test_gaussian_fit_predict(x_dtype, y_dtype, is_sparse,
+                              nlp_20news):
+    """
+    Cupy Test
+    """
+
+    X, y = nlp_20news
+    model = GaussianNB()
+    n_rows = 1000
+
+    X = sparse_scipy_to_cp(X, x_dtype)
+    X = X.tocsr()[:n_rows]
+
+    if is_sparse:
+        y = y.astype(y_dtype)[:n_rows]
+        model.fit(X, y)
+    else:
+        X = X.todense()
+        y = y[:n_rows].astype(y_dtype)
+        model.fit(np.ascontiguousarray(cp.asnumpy(X).astype(x_dtype)), y)
+
+    y_hat = model.predict(X)
+    y_hat = cp.asnumpy(y_hat)
+    y = cp.asnumpy(y)
+
+    assert accuracy_score(y, y_hat) >= 0.99
+
+
+def test_gaussian_partial_fit(nlp_20news):
+    chunk_size = 200
+    n_rows = 1000
+    x_dtype, y_dtype = cp.float32, cp.int32
+
+    X, y = nlp_20news
+
+    X = sparse_scipy_to_cp(X, x_dtype).tocsr()[:n_rows]
+    y = y.astype(y_dtype)[:n_rows]
+
+    model = GaussianNB()
+    modelsk = skGNB()
+
+    classes = np.unique(y)
+
+    total_fit = 0
+
+    for i in range(math.ceil(X.shape[0] / chunk_size)):
+
+        upper = i*chunk_size+chunk_size
+        if upper > X.shape[0]:
+            upper = -1
+
+        if upper > 0:
+            x = X[i*chunk_size:upper]
+            y_c = y[i*chunk_size:upper]
+        else:
+            x = X[i*chunk_size:]
+            y_c = y[i*chunk_size:]
+
+        modelsk.partial_fit(x.get().toarray(),
+                            y_c.get(),
+                            classes=classes.get())
+        model.partial_fit(x, y_c, classes=classes)
+
+        total_fit += (upper - (i*chunk_size))
+
+        if upper == -1:
+            break
+
+    y_hat = model.predict(X)
+    y_sk = modelsk.predict(X.get().toarray())
+
+    y_hat = cp.asnumpy(y_hat)
+    y = cp.asnumpy(y)
+    assert_array_equal(y_hat, y_sk)
+    assert accuracy_score(y, y_hat) >= 0.924
+
+    # A label mismatch between the target y and classes should raise an error
+    assert_raises(ValueError,
+                  GaussianNB().partial_fit, X, y, classes=cp.array([0, 1]))
+    # Raise because classes is required on the first call to partial_fit
+    assert_raises(ValueError, GaussianNB().partial_fit, X, y)
+
+
+@pytest.mark.parametrize("priors", [None, 'balanced', 'unbalanced'])
+@pytest.mark.parametrize("var_smoothing", [1e-5, 1e-7, 1e-9])
+def test_gaussian_parameters(priors, var_smoothing, nlp_20news):
+    x_dtype = cp.float32
+    y_dtype = cp.int32
+    nrows = 150
+
+    X, y = nlp_20news
+
+    X = sparse_scipy_to_cp(X[:nrows], x_dtype).todense()
+    y = y.astype(y_dtype)[:nrows]
+
+    if priors == 'balanced':
+        priors = cp.array([1/20] * 20)
+    elif priors == 'unbalanced':
+        priors = cp.linspace(0.01, 0.09, 20)
+
+    model = GaussianNB(priors=priors, var_smoothing=var_smoothing)
+    model_sk = skGNB(priors=priors.get() if priors is not None else None,
+                     var_smoothing=var_smoothing)
+    model.fit(X, y)
+    model_sk.fit(X.get(), y.get())
+
+    y_hat = model.predict(X)
+    y_hat_sk = model_sk.predict(X.get())
+    y_hat = cp.asnumpy(y_hat)
+    y = cp.asnumpy(y)
+
+    assert_allclose(model.epsilon_.get(), model_sk.epsilon_, rtol=1e-4)
+    assert_array_equal(y_hat, y_hat_sk)
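One closing note for reviewers: a minimal out-of-core training loop against
the new estimator, in the spirit of `test_gaussian_partial_fit`, could look
like the sketch below. The synthetic data, chunk size, and variable names
are made up for illustration.

```python
import cupy as cp
from cuml.naive_bayes import GaussianNB

# Hypothetical synthetic data: two Gaussian blobs, float32 as in the tests.
cp.random.seed(42)
n_rows, n_feats, chunk = 10_000, 20, 1_000
X = cp.random.randn(n_rows, n_feats).astype(cp.float32)
y = (X[:, 0] > 0).astype(cp.int32)
X[y == 1] += 2.0  # shift class 1 so the classes are separable

model = GaussianNB()
classes = cp.unique(y)
for start in range(0, n_rows, chunk):
    # classes must be provided at least on the first partial_fit call
    model.partial_fit(X[start:start + chunk], y[start:start + chunk],
                      classes=classes)

print(model.predict(X[:5]))
```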