diff --git a/python/cuml/test/dask/test_naive_bayes.py b/python/cuml/test/dask/test_naive_bayes.py
index d00b9b9d7e..62cf819ce8 100644
--- a/python/cuml/test/dask/test_naive_bayes.py
+++ b/python/cuml/test/dask/test_naive_bayes.py
@@ -14,13 +14,15 @@
 # limitations under the License.
 #
 
-
-from cuml.test.dask.utils import load_text_corpus
-
-from sklearn.metrics import accuracy_score
+import cupy as cp
+import dask.array
+import numpy as np
 
 from cuml.dask.naive_bayes import MultinomialNB
 from cuml.naive_bayes.naive_bayes import MultinomialNB as SGNB
+from cuml.test.dask.utils import load_text_corpus
+from cupy.sparse import csr_matrix as cp_csr_matrix
+from sklearn.metrics import accuracy_score
 
 
 def test_basic_fit_predict(client):
@@ -74,3 +76,20 @@ def test_score(client):
     y_local = y.compute()
 
     assert(accuracy_score(y_hat_local.get(), y_local) == score)
+
+
+def test_model_multiple_chunks(client):
+    # tests naive_bayes with n_chunks being greater than one, related to issue
+    # https://github.com/rapidsai/cuml/issues/3150
+    X = cp.array([[0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 0]])
+
+    X = dask.array.from_array(X, chunks=((1, 1, 1), -1)).astype(cp.int32)
+    y = dask.array.from_array([1, 0, 0], asarray=False,
+                              fancy=False, chunks=(1)).astype(cp.int32)
+
+    model = MultinomialNB()
+    model.fit(X, y)
+
+    # this test is a code coverage test, it is too small to be a numeric test,
+    # but we call score here to exercise the whole model.
+    assert(0 <= model.score(X, y) <= 1)