From 2a5185dc567bff6607e892361ae261e23aa7059d Mon Sep 17 00:00:00 2001
From: Dante Gama Dessavre <danteg@nvidia.com>
Date: Thu, 23 May 2024 10:14:32 -0500
Subject: [PATCH 1/4] Reduce and rename cudf.pandas integrations jobs (#5890)

cc @vyasr

Authors:
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Ray Douglass (https://github.com/raydouglass)

URL: https://github.com/rapidsai/cuml/pull/5890
---
 .github/workflows/pr.yaml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index 933db0304d..018387ec92 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -33,7 +33,7 @@ jobs:
     with:
       enable_check_generated_files: false
       ignored_pr_jobs: >-
-        conda-python-tests-cudf-pandas-integration
+        optional-job-conda-python-tests-cudf-pandas-integration
   clang-tidy:
     needs: checks
     secrets: inherit
@@ -77,11 +77,12 @@ jobs:
     with:
       build_type: pull-request
       script: "ci/test_python_singlegpu.sh"
-  conda-python-tests-cudf-pandas-integration:
+  optional-job-conda-python-tests-cudf-pandas-integration:
     needs: conda-python-build
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.06
     with:
+      matrix_filter: map(select(.ARCH == "amd64"))
       build_type: pull-request
       script: "ci/test_python_integration.sh"
   conda-python-tests-dask:

From 47416d7f417382e17b3b4c45d098fbeaff640bc9 Mon Sep 17 00:00:00 2001
From: Jinsol Park <soleeep99@gmail.com>
Date: Thu, 23 May 2024 08:25:27 -0700
Subject: [PATCH 2/4] Fix RandomForestClassifier return type (#5896)

Closes #5637

```
import cuml
from cuml.datasets import make_classification

X, y = make_classification()

clf = cuml.ensemble.RandomForestClassifier().fit(X,y)
print(clf.predict(X[:5]).dtype)
```

Result is

```
int64
```

Authors:
  - Jinsol Park (https://github.com/jinsolp)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: https://github.com/rapidsai/cuml/pull/5896
---
 python/cuml/ensemble/randomforestclassifier.pyx | 3 ++-
 python/cuml/tests/test_random_forest.py         | 8 ++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx
index ba16335dad..23a1bae940 100644
--- a/python/cuml/ensemble/randomforestclassifier.pyx
+++ b/python/cuml/ensemble/randomforestclassifier.pyx
@@ -1,6 +1,6 @@
 
 #
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -550,6 +550,7 @@ class RandomForestClassifier(BaseRandomForestModel,
         domain="cuml_python")
     @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
                            return_values=[('dense', '(n_samples, 1)')])
+    @cuml.internals.api_base_return_array(get_output_dtype=True)
     def predict(self, X, predict_model="GPU", threshold=0.5,
                 algo='auto', convert_dtype=True,
                 fil_sparse_format='auto') -> CumlArray:
diff --git a/python/cuml/tests/test_random_forest.py b/python/cuml/tests/test_random_forest.py
index d7f6ff6705..b18d6ec8ab 100644
--- a/python/cuml/tests/test_random_forest.py
+++ b/python/cuml/tests/test_random_forest.py
@@ -1382,3 +1382,11 @@ def test_rf_min_samples_split_with_small_float(estimator, make_data):
 
     # Does not error
     clf.fit(X, y)
+
+
+def test_rf_predict_returns_int():
+    """predict() output has dtype int64 (regression test for #5637)."""
+    features, labels = make_classification()
+    model = cuml.ensemble.RandomForestClassifier().fit(features, labels)
+    predictions = model.predict(features)
+    assert predictions.dtype == np.int64

From 326b049ea9025785cca61fe967c51621b9a652dc Mon Sep 17 00:00:00 2001
From: Tim Head <betatim@gmail.com>
Date: Wed, 29 May 2024 17:48:59 +0200
Subject: [PATCH 3/4] Update scikit-learn to 1.4 (#5851)

This is an attempt to update the scikit-learn dependency from 1.2 to 1.4. Most changes are related to constructor arguments that were deprecated in 1.2 and in 1.4 have changed/been removed.

A question I have: what is cuml's deprecation policy? I've gone with "two releases" for parameters where we can easily do so (deprecated in 24.06 and then removed in 24.10). However, that is only about 4 months of deprecation, which could be a bit short.

Some of the changes would be hard to do as a deprecation (with 1.4 there is no way to provide the "old way"), we'd have to stick with 1.3 for now. I think this is a bit of a bummer but maybe the price to pay for not keeping on top of deprecations. And it seems like there is no deprecation policy in the docs/towards users? So maybe we can play this card once now, to catch up and at the same time introduce a deprecation policy.

The SHAP test needed its reference values updated. I am not sure why; at least, I couldn't quickly find a reason why this was necessary.

I am not sure how possible it would be to support a range of scikit-learn versions (say 1.2 - 1.4). Would be cool but maybe not worth the added complexity?

Todo:
* [x] add deprecation warning in AgglomerativeClustering
* [ ] add tests for deprecations
	* [x] RF regressor
	* [x] RF classifier
	* [ ] ~~LARS~~ - LARS is experimental, so no need for deprecation
	* [x] LogisticRegression
	* [x] OneHotEncoder
	* [x] AgglomerativeClustering
* [ ] think about how to combine this with #5799
* [x] decide deprecation cycle length - copy cudf, so 24.06 -> 24.08
* [x] update "expiry" version in the warnings
* [x] update doc strings

xref #5799

Authors:
  - Tim Head (https://github.com/betatim)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Jake Awe (https://github.com/AyodeAwe)

URL: https://github.com/rapidsai/cuml/pull/5851
---
 .../all_cuda-118_arch-x86_64.yaml             |  2 +-
 .../all_cuda-122_arch-x86_64.yaml             |  2 +-
 dependencies.yaml                             |  2 +-
 .../sklearn/preprocessing/_data.py            | 31 +++++++---
 .../sklearn/preprocessing/_discretization.py  | 16 ++++-
 python/cuml/cluster/agglomerative.pyx         | 53 +++++++++++++---
 python/cuml/ensemble/randomforest_common.pyx  | 18 +++++-
 .../cuml/ensemble/randomforestclassifier.pyx  |  7 ++-
 .../cuml/ensemble/randomforestregressor.pyx   | 10 ++-
 .../cuml/experimental/linear_model/lars.pyx   |  8 ++-
 .../cuml/linear_model/logistic_regression.pyx | 19 ++++--
 python/cuml/preprocessing/encoders.py         | 40 ++++++++++--
 .../dask/test_dask_logistic_regression.py     |  4 +-
 .../tests/dask/test_dask_one_hot_encoder.py   | 28 +++++----
 .../tests/dask/test_dask_random_forest.py     | 10 +--
 .../explainer/test_explainer_kernel_shap.py   | 34 +++++------
 python/cuml/tests/test_agglomerative.py       | 34 ++++++++---
 python/cuml/tests/test_device_selection.py    |  2 +-
 python/cuml/tests/test_kmeans.py              |  8 ++-
 python/cuml/tests/test_lars.py                |  9 ++-
 python/cuml/tests/test_linear_model.py        | 23 ++++---
 python/cuml/tests/test_metrics.py             | 11 ++--
 python/cuml/tests/test_one_hot_encoder.py     | 61 ++++++++++++-------
 python/cuml/tests/test_random_forest.py       | 28 +++++++--
 python/cuml/tests/test_thirdparty.py          |  6 +-
 python/pyproject.toml                         |  2 +-
 26 files changed, 333 insertions(+), 135 deletions(-)

diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml
index cb442367bc..c292f5598b 100644
--- a/conda/environments/all_cuda-118_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -63,7 +63,7 @@ dependencies:
 - recommonmark
 - rmm==24.6.*
 - scikit-build-core>=0.7.0
-- scikit-learn==1.2
+- scikit-learn==1.5
 - scipy>=1.8.0
 - seaborn
 - sphinx-copybutton
diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml
index 29492b1f20..43bf3069b3 100644
--- a/conda/environments/all_cuda-122_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-122_arch-x86_64.yaml
@@ -59,7 +59,7 @@ dependencies:
 - recommonmark
 - rmm==24.6.*
 - scikit-build-core>=0.7.0
-- scikit-learn==1.2
+- scikit-learn==1.5
 - scipy>=1.8.0
 - seaborn
 - sphinx-copybutton
diff --git a/dependencies.yaml b/dependencies.yaml
index 5d8ac0a94e..95514dc299 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -356,7 +356,7 @@ dependencies:
           # https://github.com/pydata/pydata-sphinx-theme/issues/1539
           - pydata-sphinx-theme!=0.14.2
           - recommonmark
-          - &scikit_learn scikit-learn==1.2
+          - &scikit_learn scikit-learn==1.5
           - sphinx<6
           - sphinx-copybutton
           - sphinx-markdown-tables
diff --git a/python/cuml/_thirdparty/sklearn/preprocessing/_data.py b/python/cuml/_thirdparty/sklearn/preprocessing/_data.py
index f1e9eac615..04164604c9 100644
--- a/python/cuml/_thirdparty/sklearn/preprocessing/_data.py
+++ b/python/cuml/_thirdparty/sklearn/preprocessing/_data.py
@@ -14,6 +14,19 @@
 # This code is under BSD 3 clause license.
 # Authors mentioned above do not endorse or promote this production.
 
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from ....internals.memory_utils import using_output_type
 from ....internals import _deprecate_pos_args
@@ -32,6 +45,7 @@
 from ..utils.extmath import _incremental_mean_and_var
 from ..utils.extmath import row_norms
 from ....thirdparty_adapters import check_array
+from sklearn.utils._indexing import resample
 from cuml.internals.mixins import AllowNaNTagMixin, SparseInputTagMixin, \
     StatelessTagMixin
 from ..utils.skl_dependencies import BaseEstimator, TransformerMixin
@@ -2284,17 +2298,14 @@ def _dense_fit(self, X, random_state):
         n_samples, n_features = X.shape
         references = np.asnumpy(self.references_ * 100)
 
-        self.quantiles_ = []
-        for col in X.T:
-            if self.subsample < n_samples:
-                subsample_idx = random_state.choice(n_samples,
-                                                    size=self.subsample,
-                                                    replace=False)
-                col = col.take(subsample_idx)
-            self.quantiles_.append(
-                cpu_np.nanpercentile(np.asnumpy(col), references)
+        X = np.asnumpy(X)
+        if self.subsample is not None and self.subsample < n_samples:
+            # Take a subsample of `X`
+            X = resample(
+                X, replace=False, n_samples=self.subsample, random_state=random_state
             )
-        self.quantiles_ = cpu_np.transpose(self.quantiles_)
+
+        self.quantiles_ = cpu_np.nanpercentile(X, references, axis=0)
         # Due to floating-point precision error in `np.nanpercentile`,
         # make sure that quantiles are monotonically increasing.
         # Upstream issue in numpy:
diff --git a/python/cuml/_thirdparty/sklearn/preprocessing/_discretization.py b/python/cuml/_thirdparty/sklearn/preprocessing/_discretization.py
index ed85c5262f..02762f6585 100644
--- a/python/cuml/_thirdparty/sklearn/preprocessing/_discretization.py
+++ b/python/cuml/_thirdparty/sklearn/preprocessing/_discretization.py
@@ -10,6 +10,20 @@
 # This code is under BSD 3 clause license.
 # Authors mentioned above do not endorse or promote this production.
 
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 
 from ....internals import _deprecate_pos_args
 from ....internals.memory_utils import using_output_type
@@ -240,7 +254,7 @@ def fit(self, X, y=None) -> "KBinsDiscretizer":
         if 'onehot' in self.encode:
             self._encoder = OneHotEncoder(
                 categories=np.array([np.arange(i) for i in self.n_bins_]),
-                sparse=self.encode == 'onehot', output_type='cupy')
+                sparse_output=self.encode == 'onehot', output_type='cupy')
             # Fit the OneHotEncoder with toy datasets
             # so that it's ready for use after the KBinsDiscretizer is fitted
             self._encoder.fit(np.zeros((1, len(self.n_bins_)), dtype=int))
diff --git a/python/cuml/cluster/agglomerative.pyx b/python/cuml/cluster/agglomerative.pyx
index 84cc579201..34150d3f6b 100644
--- a/python/cuml/cluster/agglomerative.pyx
+++ b/python/cuml/cluster/agglomerative.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2022, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 
 # distutils: language = c++
 
+import warnings
+
 from libc.stdint cimport uintptr_t
 
 from cuml.internals.safe_imports import cpu_only_import
@@ -103,6 +105,17 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
         Metric used to compute the linkage. Can be "euclidean", "l1",
         "l2", "manhattan", or "cosine". If connectivity is "knn" only
         "euclidean" is accepted.
+
+        .. deprecated:: 24.06
+            `affinity` was deprecated in version 24.06 and will be
+            removed in 25.08. Use `metric` instead.
+
+    metric : str, default=None
+        Metric used to compute the linkage. Can be "euclidean", "l1",
+        "l2", "manhattan", or "cosine". If set to `None` then "euclidean"
+        is used. If connectivity is "knn" only "euclidean" is accepted.
+        .. versionadded:: 24.06
+
     linkage : {"single"}, default="single"
         Which linkage criterion to use. The linkage criterion determines
         which distance to use between sets of observations. The algorithm
@@ -136,9 +149,9 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
     labels_ = CumlArrayDescriptor()
     children_ = CumlArrayDescriptor()
 
-    def __init__(self, *, n_clusters=2, affinity="euclidean", linkage="single",
-                 handle=None, verbose=False, connectivity='knn',
-                 n_neighbors=10, output_type=None):
+    def __init__(self, *, n_clusters=2, affinity="deprecated", metric=None,
+                 linkage="single", handle=None, verbose=False,
+                 connectivity='knn', n_neighbors=10, output_type=None):
 
         super().__init__(handle=handle,
                          verbose=verbose,
@@ -159,11 +172,12 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
             raise ValueError("'n_neighbors' must be a positive number "
                              "between 2 and 1023")
 
-        if affinity not in _metrics_mapping:
-            raise ValueError("'affinity' %s is not supported." % affinity)
+        if metric is not None and metric not in _metrics_mapping:
+            raise ValueError("Metric '%s' is not supported." % metric)
 
         self.n_clusters = n_clusters
         self.affinity = affinity
+        self.metric = metric
         self.linkage = linkage
         self.n_neighbors = n_neighbors
         self.connectivity = connectivity
@@ -178,6 +192,26 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
         """
         Fit the hierarchical clustering from features.
         """
+        if self.affinity != "deprecated":
+            if self.metric is not None:
+                raise ValueError(
+                    "Both `affinity` and `metric` attributes were set. Attribute"
+                    " `affinity` was deprecated in version 24.06 and will be removed in"
+                    " 25.08. To avoid this error, only set the `metric` attribute."
+                )
+            warnings.warn(
+                (
+                    "Attribute `affinity` was deprecated in version 24.06 and will be"
+                    " removed in 25.08. Use `metric` instead."
+                ),
+                FutureWarning,
+            )
+            metric_name = self.affinity
+        else:
+            if self.metric is None:
+                metric_name = "euclidean"
+            else:
+                metric_name = self.metric
 
         X_m, n_rows, n_cols, self.dtype = \
             input_to_cuml_array(X, order='C',
@@ -209,10 +243,10 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
         linkage_output.labels = <int*>labels_ptr
 
         cdef DistanceType metric
-        if self.affinity in _metrics_mapping:
-            metric = _metrics_mapping[self.affinity]
+        if metric_name in _metrics_mapping:
+            metric = _metrics_mapping[metric_name]
         else:
-            raise ValueError("'affinity' %s not supported." % self.affinity)
+            raise ValueError("Metric '%s' not supported." % metric_name)
 
         if self.connectivity == 'knn':
             single_linkage_neighbors(
@@ -249,6 +283,7 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
         return super().get_param_names() + [
             "n_clusters",
             "affinity",
+            "metric",
             "linkage",
             "connectivity",
             "n_neighbors"
diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx
index eb71f0c78d..2442757c75 100644
--- a/python/cuml/ensemble/randomforest_common.pyx
+++ b/python/cuml/ensemble/randomforest_common.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -68,7 +68,7 @@ class BaseRandomForestModel(Base):
     classes_ = CumlArrayDescriptor()
 
     def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
-                 max_depth=16, handle=None, max_features='auto', n_bins=128,
+                 max_depth=16, handle=None, max_features='sqrt', n_bins=128,
                  bootstrap=True,
                  verbose=False, min_samples_leaf=1, min_samples_split=2,
                  max_samples=1.0, max_leaves=-1, accuracy_metric=None,
@@ -166,8 +166,22 @@ class BaseRandomForestModel(Base):
             return math.log2(self.n_cols)/self.n_cols
         elif self.max_features == 'auto':
             if self.RF_type == CLASSIFICATION:
+                warnings.warn(
+                    "`max_features='auto'` has been deprecated in 24.06 "
+                    "and will be removed in 25.08. To keep the past behaviour "
+                    "and silence this warning, explicitly set "
+                    "`max_features='sqrt'`.",
+                    FutureWarning
+                )
                 return 1/np.sqrt(self.n_cols)
             else:
+                warnings.warn(
+                    "`max_features='auto'` has been deprecated in 24.06 "
+                    "and will be removed in 25.08. To keep the past behaviour "
+                    "and silence this warning, explicitly set "
+                    "`max_features=1.0`.",
+                    FutureWarning
+                )
                 return 1.0
         else:
             raise ValueError(
diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx
index 23a1bae940..45bc4ce2e8 100644
--- a/python/cuml/ensemble/randomforestclassifier.pyx
+++ b/python/cuml/ensemble/randomforestclassifier.pyx
@@ -172,15 +172,18 @@ class RandomForestClassifier(BaseRandomForestModel,
     max_leaves : int (default = -1)
         Maximum leaf nodes per tree. Soft constraint. Unlimited,
         If ``-1``.
-    max_features : int, float, or string (default = 'auto')
+    max_features : int, float, or string (default = 'sqrt')
         Ratio of number of features (columns) to consider per node
         split.\n
          * If type ``int`` then ``max_features`` is the absolute count of
            features to be used
          * If type ``float`` then ``max_features`` is used as a fraction.
-         * If ``'auto'`` then ``max_features=1/sqrt(n_features)``.
          * If ``'sqrt'`` then ``max_features=1/sqrt(n_features)``.
          * If ``'log2'`` then ``max_features=log2(n_features)/n_features``.
+
+        .. versionchanged:: 24.06
+           The default of `max_features` changed from `"auto"` to `"sqrt"`.
+
     n_bins : int (default = 128)
         Maximum number of bins used by the split algorithm per feature.
         For large problems, particularly those with highly-skewed input data,
diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx
index bfa35cdccc..96a197e5c5 100644
--- a/python/cuml/ensemble/randomforestregressor.pyx
+++ b/python/cuml/ensemble/randomforestregressor.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -165,18 +165,22 @@ class RandomForestRegressor(BaseRandomForestModel,
         is not supported.\n
         .. note:: This default differs from scikit-learn's
           random forest, which defaults to unlimited depth.
+
     max_leaves : int (default = -1)
         Maximum leaf nodes per tree. Soft constraint. Unlimited,
         If ``-1``.
-    max_features : int, float, or string (default = 'auto')
+    max_features : int, float, or string (default = 1.0)
         Ratio of number of features (columns) to consider
         per node split.\n
          * If type ``int`` then ``max_features`` is the absolute count of
            features to be used.
          * If type ``float`` then ``max_features`` is used as a fraction.
-         * If ``'auto'`` then ``max_features=1.0``.
          * If ``'sqrt'`` then ``max_features=1/sqrt(n_features)``.
          * If ``'log2'`` then ``max_features=log2(n_features)/n_features``.
+
+        .. versionchanged:: 24.06
+          The default of `max_features` changed from `"auto"` to 1.0.
+
     n_bins : int (default = 128)
         Maximum number of bins used by the split algorithm per feature.
         For large problems, particularly those with highly-skewed input data,
diff --git a/python/cuml/experimental/linear_model/lars.pyx b/python/cuml/experimental/linear_model/lars.pyx
index 25a2ead0ac..9f2da7ea3b 100644
--- a/python/cuml/experimental/linear_model/lars.pyx
+++ b/python/cuml/experimental/linear_model/lars.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -85,11 +85,15 @@ class Lars(Base, RegressorMixin):
     fit_intercept : boolean (default = True)
         If True, Lars tries to correct for the global mean of y.
         If False, the model expects that you have centered the data.
-    normalize : boolean (default = True)
+    normalize : boolean (default = False)
         This parameter is ignored when `fit_intercept` is set to False.
         If True, the predictors in X will be normalized by removing its mean
         and dividing by it's variance. If False, then the solver expects that
         the data is already normalized.
+
+        .. versionchanged:: 24.06
+            The default of `normalize` changed from `True` to `False`.
+
     copy_X : boolean (default = True)
         The solver permutes the columns of X. Set `copy_X` to True to prevent
         changing the input data.
diff --git a/python/cuml/linear_model/logistic_regression.pyx b/python/cuml/linear_model/logistic_regression.pyx
index 92c42c849d..164821a5bd 100644
--- a/python/cuml/linear_model/logistic_regression.pyx
+++ b/python/cuml/linear_model/logistic_regression.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2022, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 
 # distutils: language = c++
 
+import warnings
+
 from cuml.internals.safe_imports import cpu_only_import
 from cuml.internals.safe_imports import gpu_only_import
 import pprint
@@ -36,7 +38,7 @@ cp = gpu_only_import('cupy')
 np = cpu_only_import('numpy')
 
 
-supported_penalties = ["l1", "l2", "none", "elasticnet"]
+supported_penalties = ["l1", "l2", None, "none", "elasticnet"]
 
 supported_solvers = ["qn"]
 
@@ -210,7 +212,7 @@ class LogisticRegression(UniversalBase,
                          output_type=output_type)
 
         if penalty not in supported_penalties:
-            raise ValueError("`penalty` " + str(penalty) + "not supported.")
+            raise ValueError("`penalty` " + str(penalty) + " not supported.")
 
         if solver not in supported_solvers:
             raise ValueError("Only quasi-newton `qn` solver is "
@@ -218,7 +220,16 @@ class LogisticRegression(UniversalBase,
         self.solver = solver
 
         self.C = C
+
+        if penalty == "none":
+            warnings.warn(
+                "The 'none' option was deprecated in version 24.06, and will "
+                "be removed in 25.08. Use None instead.",
+                FutureWarning
+            )
+            penalty = None
         self.penalty = penalty
+
         self.tol = tol
         self.fit_intercept = fit_intercept
         self.max_iter = max_iter
@@ -452,7 +463,7 @@ class LogisticRegression(UniversalBase,
         return proba
 
     def _get_qn_params(self):
-        if self.penalty == "none":
+        if self.penalty is None:
             l1_strength = 0.0
             l2_strength = 0.0
 
diff --git a/python/cuml/preprocessing/encoders.py b/python/cuml/preprocessing/encoders.py
index 46500b766a..01264572e7 100644
--- a/python/cuml/preprocessing/encoders.py
+++ b/python/cuml/preprocessing/encoders.py
@@ -203,10 +203,21 @@ class OneHotEncoder(BaseEncoder):
         - dict/list : ``drop[col]`` is the category in feature col that
           should be dropped.
 
-    sparse : bool, default=True
+    sparse_output : bool, default=True
         This feature is not fully supported by cupy
         yet, causing incorrect values when computing one hot encodings.
         See https://github.com/cupy/cupy/issues/3223
+
+        .. versionadded:: 24.06
+           `sparse` was renamed to `sparse_output`
+
+    sparse : bool, default=True
+        Will return sparse matrix if set True else will return an array.
+
+        .. deprecated:: 24.06
+           `sparse` is deprecated in 24.06 and will be removed in 25.08. Use
+           `sparse_output` instead.
+
     dtype : number type, default=np.float
         Desired datatype of transform's output.
     handle_unknown : {'error', 'ignore'}, default='error'
@@ -246,7 +257,8 @@ def __init__(
         *,
         categories="auto",
         drop=None,
-        sparse=True,
+        sparse="deprecated",
+        sparse_output=True,
         dtype=np.float32,
         handle_unknown="error",
         handle=None,
@@ -257,7 +269,9 @@ def __init__(
             handle=handle, verbose=verbose, output_type=output_type
         )
         self.categories = categories
+        # TODO(24.08): Remove self.sparse
         self.sparse = sparse
+        self.sparse_output = sparse_output
         self.dtype = dtype
         self.handle_unknown = handle_unknown
         self.drop = drop
@@ -266,10 +280,14 @@ def __init__(
         self._features = None
         self._encoders = None
         self.input_type = None
-        if sparse and np.dtype(dtype) not in ["f", "d", "F", "D"]:
+        # Ideally validated in `fit`, not here; hence the awkward `if` clause.
+        # An explicitly set (deprecated) `sparse` overrides `sparse_output`.
+        if (sparse_output if sparse == "deprecated" else sparse) and np.dtype(
+            dtype
+        ) not in ["f", "d", "F", "D"]:
             raise ValueError(
                 "Only float32, float64, complex64 and complex128 "
-                "are supported when using sparse"
+                "are supported when using sparse_output"
             )
 
     def _validate_keywords(self):
@@ -289,6 +307,17 @@ def _validate_keywords(self):
                 "zero."
             )
 
+        if self.sparse != "deprecated":
+            warnings.warn(
+                (
+                    "`sparse` was renamed to `sparse_output` in version 24.06"
+                    " and will be removed in 25.08. `sparse_output` is ignored"
+                    " unless you leave `sparse` set to its default value."
+                ),
+                FutureWarning,
+            )
+            self.sparse_output = self.sparse
+
     def _check_is_fitted(self):
         if not self._fitted:
             msg = (
@@ -440,7 +469,7 @@ def transform(self, X):
                 (val, (rows, cols)), shape=(len(X), j), dtype=self.dtype
             )
 
-            if not self.sparse:
+            if not self.sparse_output:
                 ohe = ohe.toarray()
 
             return ohe
@@ -578,6 +607,7 @@ def get_param_names(self):
             "categories",
             "drop",
             "sparse",
+            "sparse_output",
             "dtype",
             "handle_unknown",
         ]
diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py
index 9d46fa0147..89814365e1 100644
--- a/python/cuml/tests/dask/test_dask_logistic_regression.py
+++ b/python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -907,7 +907,7 @@ def to_dask_data(X_train, X_test, y_train, y_test):
 @pytest.mark.parametrize(
     "reg_dtype",
     [
-        (("none", 1.0, None), np.float64),
+        ((None, 1.0, None), np.float64),
         (("l2", 2.0, None), np.float64),
         (("l1", 2.0, None), np.float32),
         (("elasticnet", 2.0, 0.2), np.float32),
@@ -1005,7 +1005,7 @@ def test_standardization_example(fit_intercept, reg_dtype, client):
 @pytest.mark.parametrize(
     "reg_dtype",
     [
-        (("none", 1.0, None), np.float64),
+        ((None, 1.0, None), np.float64),
         (("l2", 2.0, None), np.float32),
         (("l1", 2.0, None), np.float64),
         (("elasticnet", 2.0, 0.2), np.float32),
diff --git a/python/cuml/tests/dask/test_dask_one_hot_encoder.py b/python/cuml/tests/dask/test_dask_one_hot_encoder.py
index e4d76bd470..64ba9715bc 100644
--- a/python/cuml/tests/dask/test_dask_one_hot_encoder.py
+++ b/python/cuml/tests/dask/test_dask_one_hot_encoder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -41,8 +41,8 @@ def test_onehot_vs_skonehot(client):
     skX = from_df_to_numpy(X)
     X = dask_cudf.from_cudf(X, npartitions=2)
 
-    enc = OneHotEncoder(sparse=False)
-    skohe = SkOneHotEncoder(sparse=False)
+    enc = OneHotEncoder(sparse_output=False)
+    skohe = SkOneHotEncoder(sparse_output=False)
 
     ohe = enc.fit_transform(X)
     ref = skohe.fit_transform(skX)
@@ -71,7 +71,7 @@ def test_onehot_categories(client):
     X = DataFrame({"chars": ["a", "b"], "int": [0, 2]})
     X = dask_cudf.from_cudf(X, npartitions=2)
     cats = DataFrame({"chars": ["a", "b", "c"], "int": [0, 1, 2]})
-    enc = OneHotEncoder(categories=cats, sparse=False)
+    enc = OneHotEncoder(categories=cats, sparse_output=False)
     ref = cp.array(
         [[1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0, 1.0]]
     )
@@ -100,12 +100,12 @@ def test_onehot_transform_handle_unknown(client):
     X = dask_cudf.from_cudf(X, npartitions=2)
     Y = dask_cudf.from_cudf(Y, npartitions=2)
 
-    enc = OneHotEncoder(handle_unknown="error", sparse=False)
+    enc = OneHotEncoder(handle_unknown="error", sparse_output=False)
     enc = enc.fit(X)
     with pytest.raises(KeyError):
         enc.transform(Y).compute()
 
-    enc = OneHotEncoder(handle_unknown="ignore", sparse=False)
+    enc = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
     enc = enc.fit(X)
     ohe = enc.transform(Y)
     ref = cp.array([[0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
@@ -140,8 +140,10 @@ def test_onehot_random_inputs(client, drop, as_array, sparse, n_samples):
     else:
         dX = dask_cudf.from_cudf(X, npartitions=1)
 
-    enc = OneHotEncoder(sparse=sparse, drop=drop, categories="auto")
-    sk_enc = SkOneHotEncoder(sparse=sparse, drop=drop, categories="auto")
+    enc = OneHotEncoder(sparse_output=sparse, drop=drop, categories="auto")
+    sk_enc = SkOneHotEncoder(
+        sparse_output=sparse, drop=drop, categories="auto"
+    )
     ohe = enc.fit_transform(dX)
     ref = sk_enc.fit_transform(ary)
     if sparse:
@@ -159,8 +161,8 @@ def test_onehot_drop_idx_first(client):
     X = DataFrame({"chars": ["c", "b"], "int": [2, 2], "letters": ["a", "b"]})
     ddf = dask_cudf.from_cudf(X, npartitions=2)
 
-    enc = OneHotEncoder(sparse=False, drop="first")
-    sk_enc = SkOneHotEncoder(sparse=False, drop="first")
+    enc = OneHotEncoder(sparse_output=False, drop="first")
+    sk_enc = SkOneHotEncoder(sparse_output=False, drop="first")
     ohe = enc.fit_transform(ddf)
     ref = sk_enc.fit_transform(X_ary)
     cp.testing.assert_array_equal(ohe.compute(), ref)
@@ -177,8 +179,8 @@ def test_onehot_drop_one_of_each(client):
     ddf = dask_cudf.from_cudf(X, npartitions=2)
 
     drop = dict({"chars": "b", "int": 2, "letters": "b"})
-    enc = OneHotEncoder(sparse=False, drop=drop)
-    sk_enc = SkOneHotEncoder(sparse=False, drop=["b", 2, "b"])
+    enc = OneHotEncoder(sparse_output=False, drop=drop)
+    sk_enc = SkOneHotEncoder(sparse_output=False, drop=["b", 2, "b"])
     ohe = enc.fit_transform(ddf)
     ref = sk_enc.fit_transform(X_ary)
     cp.testing.assert_array_equal(ohe.compute(), ref)
@@ -212,7 +214,7 @@ def test_onehot_drop_exceptions(client, drop, pattern):
     X = dask_cudf.from_cudf(X, npartitions=2)
 
     with pytest.raises(ValueError, match=pattern):
-        OneHotEncoder(sparse=False, drop=drop).fit(X)
+        OneHotEncoder(sparse_output=False, drop=drop).fit(X)
 
 
 @pytest.mark.mg
diff --git a/python/cuml/tests/dask/test_dask_random_forest.py b/python/cuml/tests/dask/test_dask_random_forest.py
index c35f5ab21e..38596b2e69 100644
--- a/python/cuml/tests/dask/test_dask_random_forest.py
+++ b/python/cuml/tests/dask/test_dask_random_forest.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
 #
 
 
-# Copyright (c) 2019-2022, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -170,7 +170,7 @@ def test_rf_regression_dask_fil(partitions_per_worker, dtype, client):
     cuml_mod_predict = cuml_mod.predict(X_test_df)
     cuml_mod_predict = cp.asnumpy(cp.array(cuml_mod_predict.compute()))
 
-    acc_score = r2_score(cuml_mod_predict, y_test)
+    acc_score = r2_score(y_test, cuml_mod_predict)
 
     assert acc_score >= 0.59
 
@@ -256,7 +256,7 @@ def test_rf_regression_dask_cpu(partitions_per_worker, client):
 
     cuml_mod_predict = cuml_mod.predict(X_test, predict_model="CPU")
 
-    acc_score = r2_score(cuml_mod_predict, y_test)
+    acc_score = r2_score(y_test, cuml_mod_predict)
 
     assert acc_score >= 0.67
 
@@ -711,7 +711,7 @@ def test_rf_broadcast(model_type, fit_broadcast, transform_broadcast, client):
 
         cuml_mod_predict = cuml_mod_predict.compute()
         cuml_mod_predict = cp.asnumpy(cuml_mod_predict)
-        acc_score = r2_score(cuml_mod_predict, y_test)
+        acc_score = r2_score(y_test, cuml_mod_predict)
         assert acc_score >= 0.72
 
     if transform_broadcast:
diff --git a/python/cuml/tests/explainer/test_explainer_kernel_shap.py b/python/cuml/tests/explainer/test_explainer_kernel_shap.py
index 74c985989f..3f20b7d8a5 100644
--- a/python/cuml/tests/explainer/test_explainer_kernel_shap.py
+++ b/python/cuml/tests/explainer/test_explainer_kernel_shap.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -522,24 +522,24 @@ def test_typeerror_input():
 housing_regression_result = np.array(
     [
         [
-            -0.73860609,
-            0.00557072,
-            -0.05829297,
-            -0.01582018,
-            -0.01010366,
-            -0.23167623,
-            -0.470639,
-            -0.07584473,
+            -0.00182223,
+            -0.01232004,
+            -0.4782278,
+            0.04781425,
+            -0.01337761,
+            -0.34830606,
+            -0.4682865,
+            -0.20812261,
         ],
         [
-            -0.6410764,
-            0.01369913,
-            -0.09492759,
-            0.02654463,
-            -0.00911134,
-            -0.05953105,
-            -0.51266433,
-            -0.0853608,
+            -0.0013606,
+            0.0110372,
+            -0.445176,
+            -0.08268094,
+            0.00406259,
+            -0.02185595,
+            -0.47673094,
+            -0.13557231,
         ],
     ],
     dtype=np.float32,
diff --git a/python/cuml/tests/test_agglomerative.py b/python/cuml/tests/test_agglomerative.py
index c415a6b3c2..7c71be02ec 100644
--- a/python/cuml/tests/test_agglomerative.py
+++ b/python/cuml/tests/test_agglomerative.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,14 +33,14 @@ def test_duplicate_distances(connectivity):
 
     cuml_agg = AgglomerativeClustering(
         n_clusters=2,
-        affinity="euclidean",
+        metric="euclidean",
         linkage="single",
         n_neighbors=3,
         connectivity=connectivity,
     )
 
     sk_agg = cluster.AgglomerativeClustering(
-        n_clusters=2, affinity="euclidean", linkage="single"
+        n_clusters=2, metric="euclidean", linkage="single"
     )
 
     cuml_agg.fit(X)
@@ -64,7 +64,7 @@ def test_single_linkage_sklearn_compare(
 
     cuml_agg = AgglomerativeClustering(
         n_clusters=nclusters,
-        affinity="euclidean",
+        metric="euclidean",
         linkage="single",
         n_neighbors=k,
         connectivity=connectivity,
@@ -73,7 +73,7 @@ def test_single_linkage_sklearn_compare(
     cuml_agg.fit(X)
 
     sk_agg = cluster.AgglomerativeClustering(
-        n_clusters=nclusters, affinity="euclidean", linkage="single"
+        n_clusters=nclusters, metric="euclidean", linkage="single"
     )
     sk_agg.fit(cp.asnumpy(X))
 
@@ -87,9 +87,9 @@ def test_single_linkage_sklearn_compare(
 
 def test_invalid_inputs():
 
-    # Test bad affinity
+    # Test bad metric
     with pytest.raises(ValueError):
-        AgglomerativeClustering(affinity="doesntexist")
+        AgglomerativeClustering(metric="doesntexist")
 
     with pytest.raises(ValueError):
         AgglomerativeClustering(linkage="doesntexist")
@@ -108,3 +108,23 @@ def test_invalid_inputs():
 
     with pytest.raises(ValueError):
         AgglomerativeClustering(n_clusters=500).fit(cp.ones((2, 5)))
+
+
+def test_affinity_deprecation():
+    X = cp.array([[1.0, 2], [3, 4]])
+    y = cp.array([1, 0])
+
+    agg = AgglomerativeClustering(affinity="euclidean")
+    with pytest.warns(
+        FutureWarning,
+        match="Attribute `affinity` was deprecated in version 24.06",
+    ):
+        agg.fit(X, y)
+
+    # Setting both `affinity` and `metric` must raise a ValueError on fit
+    agg = AgglomerativeClustering(affinity="euclidean", metric="euclidean")
+    with pytest.raises(
+        ValueError,
+        match="Both `affinity` and `metric` attributes were set",
+    ):
+        agg.fit(X, y)
diff --git a/python/cuml/tests/test_device_selection.py b/python/cuml/tests/test_device_selection.py
index 96d776909f..e5c2d9ce1a 100644
--- a/python/cuml/tests/test_device_selection.py
+++ b/python/cuml/tests/test_device_selection.py
@@ -236,7 +236,7 @@ def linreg_test_data(request):
     **fixture_generation_helper(
         {
             "input_type": ["numpy", "dataframe", "cupy", "cudf", "numba"],
-            "penalty": ["none", "l2"],
+            "penalty": [None, "l2"],
             "fit_intercept": [False, True],
         }
     )
diff --git a/python/cuml/tests/test_kmeans.py b/python/cuml/tests/test_kmeans.py
index 83c2e4db6a..b05a762177 100644
--- a/python/cuml/tests/test_kmeans.py
+++ b/python/cuml/tests/test_kmeans.py
@@ -236,7 +236,9 @@ def test_kmeans_sklearn_comparison(name, nrows, random_state):
     cu_y_pred = cuml_kmeans.fit_predict(X)
     cu_score = adjusted_rand_score(cu_y_pred, y)
     kmeans = cluster.KMeans(
-        random_state=random_state, n_clusters=params["n_clusters"]
+        random_state=random_state,
+        n_clusters=params["n_clusters"],
+        n_init=10,
     )
     sk_y_pred = kmeans.fit_predict(X)
     sk_score = adjusted_rand_score(sk_y_pred, y)
@@ -278,7 +280,9 @@ def test_kmeans_sklearn_comparison_default(name, nrows, random_state):
     cu_y_pred = cuml_kmeans.fit_predict(X)
     cu_score = adjusted_rand_score(cu_y_pred, y)
     kmeans = cluster.KMeans(
-        random_state=random_state, n_clusters=params["n_clusters"]
+        random_state=random_state,
+        n_clusters=params["n_clusters"],
+        n_init=10,
     )
     sk_y_pred = kmeans.fit_predict(X)
     sk_score = adjusted_rand_score(sk_y_pred, y)
diff --git a/python/cuml/tests/test_lars.py b/python/cuml/tests/test_lars.py
index 274d8f4199..5064ae674b 100644
--- a/python/cuml/tests/test_lars.py
+++ b/python/cuml/tests/test_lars.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -61,15 +61,14 @@ def normalize_data(X, y):
         stress_param([1000, 500]),
     ],
 )
-@pytest.mark.parametrize("normalize", [True, False])
 @pytest.mark.parametrize("precompute", [True, False, "precompute"])
-def test_lars_model(datatype, nrows, column_info, precompute, normalize):
+def test_lars_model(datatype, nrows, column_info, precompute):
     ncols, n_info = column_info
     X_train, X_test, y_train, y_test = make_regression_dataset(
         datatype, nrows, ncols, n_info
     )
 
-    if precompute == "precompute" or not normalize:
+    if precompute == "precompute":
         # Apply normalization manually, because the solver expects normalized
         # input data
         X_train, y_train, x_mean, x_scale, y_mean = normalize_data(
@@ -81,7 +80,7 @@ def test_lars_model(datatype, nrows, column_info, precompute, normalize):
     if precompute == "precompute":
         precompute = np.dot(X_train.T, X_train)
 
-    params = {"precompute": precompute, "normalize": normalize}
+    params = {"precompute": precompute}
 
     # Initialization of cuML's LARS
     culars = cuLars(**params)
diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py
index 365c749c79..74395c15c9 100644
--- a/python/cuml/tests/test_linear_model.py
+++ b/python/cuml/tests/test_linear_model.py
@@ -144,6 +144,15 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None):
 algorithms = st.sampled_from(_ALGORITHMS)
 
 
+# TODO(24.08): remove this test (covers the deprecated penalty="none" option)
+def test_logreg_penalty_deprecation():
+    with pytest.warns(
+        FutureWarning,
+        match="The 'none' option was deprecated in version 24.06",
+    ):
+        cuLog(penalty="none")
+
+
 @pytest.mark.parametrize("ntargets", [1, 2])
 @pytest.mark.parametrize("datatype", [np.float32, np.float64])
 @pytest.mark.parametrize("algorithm", ["eig", "svd"])
@@ -457,11 +466,11 @@ def test_weighted_ridge(datatype, algorithm, fit_intercept, distribution):
     "num_classes, dtype, penalty, l1_ratio, fit_intercept, C, tol",
     [
         # L-BFGS Solver
-        (2, np.float32, "none", 1.0, True, 1.0, 1e-3),
+        (2, np.float32, None, 1.0, True, 1.0, 1e-3),
         (2, np.float64, "l2", 1.0, True, 1.0, 1e-8),
         (10, np.float32, "elasticnet", 0.0, True, 1.0, 1e-3),
-        (10, np.float32, "none", 1.0, False, 1.0, 1e-8),
-        (10, np.float32, "none", 1.0, False, 2.0, 1e-3),
+        (10, np.float32, None, 1.0, False, 1.0, 1e-8),
+        (10, np.float32, None, 1.0, False, 2.0, 1e-3),
         # OWL-QN Solver
         (2, np.float32, "l1", 1.0, True, 1.0, 1e-3),
         (2, np.float64, "elasticnet", 1.0, True, 1.0, 1e-8),
@@ -567,7 +576,7 @@ def test_logistic_regression(
 
 @given(
     dtype=floating_dtypes(sizes=(32, 64)),
-    penalty=st.sampled_from(("none", "l1", "l2", "elasticnet")),
+    penalty=st.sampled_from((None, "l1", "l2", "elasticnet")),
     l1_ratio=st.one_of(st.none(), st.floats(min_value=0.0, max_value=1.0)),
 )
 def test_logistic_regression_unscaled(dtype, penalty, l1_ratio):
@@ -624,7 +633,7 @@ def test_logistic_regression_model_default(dtype):
     order=st.sampled_from(("C", "F")),
     sparse_input=st.booleans(),
     fit_intercept=st.booleans(),
-    penalty=st.sampled_from(("none", "l1", "l2")),
+    penalty=st.sampled_from((None, "l1", "l2")),
 )
 def test_logistic_regression_model_digits(
     dtype, order, sparse_input, fit_intercept, penalty
@@ -927,8 +936,8 @@ def test_linear_models_set_params(algo):
     coef_before = model.coef_
 
     if algo == cuLog:
-        params = {"penalty": "none", "C": 1, "max_iter": 30}
-        model = algo(penalty="none", C=1, max_iter=30)
+        params = {"penalty": None, "C": 1, "max_iter": 30}
+        model = algo(penalty=None, C=1, max_iter=30)
     else:
         model = algo(solver="svd", alpha=0.1)
         params = {"solver": "svd", "alpha": 0.1}
diff --git a/python/cuml/tests/test_metrics.py b/python/cuml/tests/test_metrics.py
index 8748463e6e..6e92535cf7 100644
--- a/python/cuml/tests/test_metrics.py
+++ b/python/cuml/tests/test_metrics.py
@@ -199,7 +199,7 @@ def test_sklearn_search():
     gdf_train = cudf.DataFrame(dict(train=y_train))
 
     sk_cu_grid.fit(gdf_data, gdf_train.train)
-    assert sk_cu_grid.best_params_ == {"alpha": 0.1}
+    assert_almost_equal(sk_cu_grid.best_params_["alpha"], 0.1)
 
 
 @pytest.mark.parametrize(
@@ -960,9 +960,12 @@ def test_log_loss_random(n_samples, dtype):
         lambda rng: rng.randint(0, 10, n_samples).astype(dtype)
     )
 
-    y_pred, _, _, _ = generate_random_labels(
+    _, _, y_pred, _ = generate_random_labels(
         lambda rng: rng.rand(n_samples, 10)
     )
+    # Make sure the probabilities sum to 1 per sample
+    y_pred /= y_pred.sum(axis=1)[:, None]
+    y_pred = cuda.to_device(y_pred)
 
     assert_almost_equal(
         log_loss(y_true, y_pred), sklearn_log_loss(y_true, y_pred)
@@ -1497,8 +1500,8 @@ def test_sparse_pairwise_distances_sklearn_comparison(
         matrix_size[0], matrix_size[1], cp.float64, density, metric
     )
 
-    # For fp64, compare at 9 decimals, (6 places less than the ~15 max)
-    compare_precision = 9
+    # For fp64, compare at 7 decimals (8 places fewer than the ~15 max)
+    compare_precision = 7
 
     # Compare to sklearn, fp64
     S = sparse_pairwise_distances(X, Y, metric=metric)
diff --git a/python/cuml/tests/test_one_hot_encoder.py b/python/cuml/tests/test_one_hot_encoder.py
index 991d42ddb1..9f3b1d2c34 100644
--- a/python/cuml/tests/test_one_hot_encoder.py
+++ b/python/cuml/tests/test_one_hot_encoder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -57,8 +57,8 @@ def test_onehot_vs_skonehot(as_array):
         X = _from_df_to_cupy(X)
         skX = cp.asnumpy(X)
 
-    enc = OneHotEncoder(sparse=True)
-    skohe = SkOneHotEncoder(sparse=True)
+    enc = OneHotEncoder(sparse_output=True)
+    skohe = SkOneHotEncoder(sparse_output=True)
 
     ohe = enc.fit_transform(X)
     ref = skohe.fit_transform(skX)
@@ -89,7 +89,7 @@ def test_onehot_categories(as_array):
         X = _from_df_to_cupy(X)
         categories = _from_df_to_cupy(categories).transpose()
 
-    enc = OneHotEncoder(categories=categories, sparse=False)
+    enc = OneHotEncoder(categories=categories, sparse_output=False)
     ref = cp.array(
         [[1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0, 1.0]]
     )
@@ -124,12 +124,12 @@ def test_onehot_transform_handle_unknown(as_array):
         X = _from_df_to_cupy(X)
         Y = _from_df_to_cupy(Y)
 
-    enc = OneHotEncoder(handle_unknown="error", sparse=False)
+    enc = OneHotEncoder(handle_unknown="error", sparse_output=False)
     enc = enc.fit(X)
     with pytest.raises(KeyError):
         enc.transform(Y)
 
-    enc = OneHotEncoder(handle_unknown="ignore", sparse=False)
+    enc = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
     enc = enc.fit(X)
     ohe = enc.transform(Y)
     ref = cp.array([[0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
@@ -163,8 +163,10 @@ def test_onehot_random_inputs(drop, sparse, n_samples, as_array):
         n_samples=n_samples, as_array=as_array
     )
 
-    enc = OneHotEncoder(sparse=sparse, drop=drop, categories="auto")
-    sk_enc = SkOneHotEncoder(sparse=sparse, drop=drop, categories="auto")
+    enc = OneHotEncoder(sparse_output=sparse, drop=drop, categories="auto")
+    sk_enc = SkOneHotEncoder(
+        sparse_output=sparse, drop=drop, categories="auto"
+    )
     ohe = enc.fit_transform(X)
     ref = sk_enc.fit_transform(ary)
     if sparse:
@@ -183,8 +185,10 @@ def test_onehot_drop_idx_first(as_array):
         X = _from_df_to_cupy(X)
         X_ary = cp.asnumpy(X)
 
-    enc = OneHotEncoder(sparse=False, drop="first", categories="auto")
-    sk_enc = SkOneHotEncoder(sparse=False, drop="first", categories="auto")
+    enc = OneHotEncoder(sparse_output=False, drop="first", categories="auto")
+    sk_enc = SkOneHotEncoder(
+        sparse_output=False, drop="first", categories="auto"
+    )
     ohe = enc.fit_transform(X)
     ref = sk_enc.fit_transform(X_ary)
     cp.testing.assert_array_equal(ohe, ref)
@@ -203,11 +207,11 @@ def test_onehot_drop_one_of_each(as_array):
         X_ary = cp.asnumpy(X)
         drop = drop_ary = _convert_drop(drop)
 
-    enc = OneHotEncoder(sparse=False, drop=drop, categories="auto")
+    enc = OneHotEncoder(sparse_output=False, drop=drop, categories="auto")
     ohe = enc.fit_transform(X)
     print(ohe.dtype)
     ref = SkOneHotEncoder(
-        sparse=False, drop=drop_ary, categories="auto"
+        sparse_output=False, drop=drop_ary, categories="auto"
     ).fit_transform(X_ary)
     cp.testing.assert_array_equal(ohe, ref)
     inv = enc.inverse_transform(ohe)
@@ -240,7 +244,7 @@ def test_onehot_drop_exceptions(drop, pattern, as_array):
         drop = _convert_drop(drop) if not isinstance(drop, DataFrame) else drop
 
     with pytest.raises(ValueError, match=pattern):
-        OneHotEncoder(sparse=False, drop=drop).fit(X)
+        OneHotEncoder(sparse_output=False, drop=drop).fit(X)
 
 
 @pytest.mark.parametrize("as_array", [True, False], ids=["cupy", "cudf"])
@@ -270,8 +274,10 @@ def test_onehot_sparse_drop(as_array):
         ary = cp.asnumpy(X)
         drop = drop_ary = _convert_drop(drop)
 
-    enc = OneHotEncoder(sparse=True, drop=drop, categories="auto")
-    sk_enc = SkOneHotEncoder(sparse=True, drop=drop_ary, categories="auto")
+    enc = OneHotEncoder(sparse_output=True, drop=drop, categories="auto")
+    sk_enc = SkOneHotEncoder(
+        sparse_output=True, drop=drop_ary, categories="auto"
+    )
     ohe = enc.fit_transform(X)
     ref = sk_enc.fit_transform(ary)
     cp.testing.assert_array_equal(ohe.toarray(), ref.toarray())
@@ -286,21 +292,21 @@ def test_onehot_categories_shape_mismatch(as_array):
         categories = _from_df_to_cupy(categories).transpose()
 
     with pytest.raises(ValueError):
-        OneHotEncoder(categories=categories, sparse=False).fit(X)
+        OneHotEncoder(categories=categories, sparse_output=False).fit(X)
 
 
 def test_onehot_category_specific_cases():
     # See this for reasoning: https://github.com/rapidsai/cuml/issues/2690
 
-    # All of these cases use sparse=False, where
-    # test_onehot_category_class_count uses sparse=True
+    # All of these cases use sparse_output=False, where
+    # test_onehot_category_class_count uses sparse_output=True
 
     # ==== 2 Rows (Low before High) ====
     example_df = DataFrame()
     example_df["low_cardinality_column"] = ["A"] * 200 + ["B"] * 56
     example_df["high_cardinality_column"] = cp.linspace(0, 255, 256)
 
-    encoder = OneHotEncoder(handle_unknown="ignore", sparse=False)
+    encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
     encoder.fit_transform(example_df)
 
     # ==== 2 Rows (High before Low, used to fail) ====
@@ -308,7 +314,7 @@ def test_onehot_category_specific_cases():
     example_df["high_cardinality_column"] = cp.linspace(0, 255, 256)
     example_df["low_cardinality_column"] = ["A"] * 200 + ["B"] * 56
 
-    encoder = OneHotEncoder(handle_unknown="ignore", sparse=False)
+    encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
     encoder.fit_transform(example_df)
 
 
@@ -319,9 +325,9 @@ def test_onehot_category_specific_cases():
 )
 def test_onehot_category_class_count(total_classes: int):
     # See this for reasoning: https://github.com/rapidsai/cuml/issues/2690
-    # All tests use sparse=True to avoid memory errors
+    # All tests use sparse_output=True to avoid memory errors
 
-    encoder = OneHotEncoder(handle_unknown="ignore", sparse=True)
+    encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=True)
 
     # ==== 2 Rows ====
     example_df = DataFrame()
@@ -388,3 +394,14 @@ def test_onehot_get_feature_names(as_array):
     ]
     feature_names = enc.get_feature_names(["fruit", "size"])
     assert np.array_equal(feature_names, feature_names_ref)
+
+
+# TODO(24.08): remove this test (covers the renamed `sparse` parameter)
+def test_sparse_deprecation():
+    X = cp.array([[33, 1], [34, 3], [34, 2]])
+    oh = OneHotEncoder(sparse=True)
+
+    with pytest.warns(
+        FutureWarning, match="`sparse` was renamed to `sparse_output`"
+    ):
+        oh.fit(X)
diff --git a/python/cuml/tests/test_random_forest.py b/python/cuml/tests/test_random_forest.py
index b18d6ec8ab..640c22fd67 100644
--- a/python/cuml/tests/test_random_forest.py
+++ b/python/cuml/tests/test_random_forest.py
@@ -275,7 +275,7 @@ def test_tweedie_convergence(max_depth, split_criterion):
     "max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)]
 )
 @pytest.mark.parametrize("datatype", [np.float32, np.float64])
-@pytest.mark.parametrize("max_features", [1.0, "auto", "log2", "sqrt"])
+@pytest.mark.parametrize("max_features", [1.0, "log2", "sqrt"])
 def test_rf_classification(small_clf, datatype, max_samples, max_features):
     use_handle = True
 
@@ -399,7 +399,6 @@ def test_rf_classification_unorder(
     [
         (1.0, 16),
         (1.0, 11),
-        ("auto", 128),
         ("log2", 100),
         ("sqrt", 100),
         (1.0, 17),
@@ -682,7 +681,7 @@ def test_rf_classification_multi_class(mclass_clf, datatype, array_type):
 
 @pytest.mark.parametrize("datatype", [(np.float32, np.float64)])
 @pytest.mark.parametrize("max_samples", [unit_param(1.0), stress_param(0.95)])
-@pytest.mark.parametrize("max_features", [1.0, "auto", "log2", "sqrt"])
+@pytest.mark.parametrize("max_features", [1.0, "log2", "sqrt"])
 def test_rf_classification_proba(
     small_clf, datatype, max_samples, max_features
 ):
@@ -862,7 +861,7 @@ def test_rf_regression_sparse(special_reg, datatype, fil_sparse_format, algo):
             sk_model.fit(X_train, y_train)
             sk_preds = sk_model.predict(X_test)
             sk_r2 = r2_score(y_test, sk_preds, convert_dtype=datatype)
-            assert fil_r2 >= (sk_r2 - 0.07)
+            assert fil_r2 >= (sk_r2 - 0.08)
 
 
 @pytest.mark.xfail(reason="Need rapidsai/rmm#415 to detect memleak robustly")
@@ -915,7 +914,7 @@ def test_for_memory_leak():
         test_for_memory_leak()
 
 
-@pytest.mark.parametrize("max_features", [1.0, "auto", "log2", "sqrt"])
+@pytest.mark.parametrize("max_features", [1.0, "log2", "sqrt"])
 @pytest.mark.parametrize("max_depth", [10, 13, 16])
 @pytest.mark.parametrize("n_estimators", [10, 20, 100])
 @pytest.mark.parametrize("n_bins", [8, 9, 10])
@@ -1384,6 +1383,25 @@ def test_rf_min_samples_split_with_small_float(estimator, make_data):
     clf.fit(X, y)
 
 
+# TODO(24.08): remove this test (covers deprecated max_features="auto")
+@pytest.mark.parametrize(
+    "Estimator",
+    [
+        curfr,
+        curfc,
+    ],
+)
+def test_random_forest_max_features_deprecation(Estimator):
+    X = np.array([[1.0, 2], [3, 4]])
+    y = np.array([1, 0])
+    est = Estimator(max_features="auto")
+
+    error_msg = "`max_features='auto'` has been deprecated in 24.06 "
+
+    with pytest.warns(FutureWarning, match=error_msg):
+        est.fit(X, y)
+
+
 def test_rf_predict_returns_int():
 
     X, y = make_classification()
diff --git a/python/cuml/tests/test_thirdparty.py b/python/cuml/tests/test_thirdparty.py
index 70b4a21ad6..ed23db76fb 100644
--- a/python/cuml/tests/test_thirdparty.py
+++ b/python/cuml/tests/test_thirdparty.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -88,8 +88,8 @@ def test_check_X_y():
 def test_row_norms(failure_logger, sparse_random_dataset, square):
     X_np, X, X_sparse_np, X_sparse = sparse_random_dataset
 
-    cu_norms = cu_row_norms(X_np, squared=square)
-    sk_norms = sk_row_norms(X, squared=square)
+    cu_norms = cu_row_norms(X, squared=square)
+    sk_norms = sk_row_norms(X_np, squared=square)
     assert_allclose(cu_norms, sk_norms)
 
     cu_norms = cu_row_norms(X_sparse, squared=square)
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 9c72a43a9f..a7dd8d8e6a 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -122,7 +122,7 @@ test = [
     "pytest-cov",
     "pytest-xdist",
     "pytest==7.*",
-    "scikit-learn==1.2",
+    "scikit-learn==1.5",
     "seaborn",
     "statsmodels",
     "umap-learn==0.5.3",

From 92f5830d2c5544681f07edd981b33ac1ed37ed94 Mon Sep 17 00:00:00 2001
From: Paul Taylor <178183+trxcllnt@users.noreply.github.com>
Date: Wed, 29 May 2024 09:10:11 -0700
Subject: [PATCH 4/4] Fix building cuml with CCCL main (#5886)

Similar to https://github.com/rapidsai/cudf/pull/15552, we are testing [building RAPIDS with CCCL's main branch](https://github.com/NVIDIA/cccl/pull/1667) to get ahead of any breaking changes.

Authors:
  - Paul Taylor (https://github.com/trxcllnt)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Kyle Edwards (https://github.com/KyleFromNVIDIA)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: https://github.com/rapidsai/cuml/pull/5886
---
 .devcontainer/Dockerfile                       |  5 +++++
 .devcontainer/cuda11.8-conda/devcontainer.json |  2 +-
 .devcontainer/cuda11.8-pip/devcontainer.json   |  8 ++------
 .devcontainer/cuda12.2-conda/devcontainer.json |  2 +-
 .devcontainer/cuda12.2-pip/devcontainer.json   |  8 ++------
 cpp/bench/CMakeLists.txt                       |  6 +++++-
 cpp/cmake/modules/ConfigureCUDA.cmake          |  6 +++---
 cpp/src/arima/batched_arima.cu                 |  2 +-
 cpp/src/hdbscan/condensed_hierarchy.cu         |  3 ++-
 cpp/src/hdbscan/detail/select.cuh              | 15 ++++++++-------
 cpp/src/hdbscan/detail/utils.h                 |  2 +-
 cpp/src/tsne/distances.cuh                     |  4 +++-
 cpp/src/tsne/utils.cuh                         |  1 +
 cpp/test/CMakeLists.txt                        |  4 ++++
 14 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 9d35e3f97f..b50414b08e 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -7,6 +7,11 @@ FROM ${BASE} as pip-base
 
 ENV DEFAULT_VIRTUAL_ENV=rapids
 
+RUN apt update -y \
+ && DEBIAN_FRONTEND=noninteractive apt install -y \
+    libblas-dev liblapack-dev \
+ && rm -rf /tmp/* /var/tmp/* /var/cache/apt/* /var/lib/apt/lists/*;
+
 FROM ${BASE} as conda-base
 
 ENV DEFAULT_CONDA_ENV=rapids
diff --git a/.devcontainer/cuda11.8-conda/devcontainer.json b/.devcontainer/cuda11.8-conda/devcontainer.json
index ee050cc5fc..822b27f3fe 100644
--- a/.devcontainer/cuda11.8-conda/devcontainer.json
+++ b/.devcontainer/cuda11.8-conda/devcontainer.json
@@ -11,7 +11,7 @@
   "runArgs": [
     "--rm",
     "--name",
-    "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-conda"
+    "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-conda"
   ],
   "hostRequirements": {"gpu": "optional"},
   "features": {
diff --git a/.devcontainer/cuda11.8-pip/devcontainer.json b/.devcontainer/cuda11.8-pip/devcontainer.json
index e0fb0b22e0..d2bf9e6dc9 100644
--- a/.devcontainer/cuda11.8-pip/devcontainer.json
+++ b/.devcontainer/cuda11.8-pip/devcontainer.json
@@ -5,19 +5,16 @@
     "args": {
       "CUDA": "11.8",
       "PYTHON_PACKAGE_MANAGER": "pip",
-      "BASE": "rapidsai/devcontainers:24.06-cpp-cuda11.8-ubuntu22.04"
+      "BASE": "rapidsai/devcontainers:24.06-cpp-cuda11.8-ucx1.15.0-openmpi-ubuntu22.04"
     }
   },
   "runArgs": [
     "--rm",
     "--name",
-    "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-pip"
+    "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-pip"
   ],
   "hostRequirements": {"gpu": "optional"},
   "features": {
-    "ghcr.io/rapidsai/devcontainers/features/ucx:24.6": {
-      "version": "1.15.0"
-    },
     "ghcr.io/rapidsai/devcontainers/features/cuda:24.6": {
       "version": "11.8",
       "installcuBLAS": true,
@@ -28,7 +25,6 @@
     "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {}
   },
   "overrideFeatureInstallOrder": [
-    "ghcr.io/rapidsai/devcontainers/features/ucx",
     "ghcr.io/rapidsai/devcontainers/features/cuda",
     "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils"
   ],
diff --git a/.devcontainer/cuda12.2-conda/devcontainer.json b/.devcontainer/cuda12.2-conda/devcontainer.json
index 17ce672f3c..9a0fa0e594 100644
--- a/.devcontainer/cuda12.2-conda/devcontainer.json
+++ b/.devcontainer/cuda12.2-conda/devcontainer.json
@@ -11,7 +11,7 @@
   "runArgs": [
     "--rm",
     "--name",
-    "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-conda"
+    "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-conda"
   ],
   "hostRequirements": {"gpu": "optional"},
   "features": {
diff --git a/.devcontainer/cuda12.2-pip/devcontainer.json b/.devcontainer/cuda12.2-pip/devcontainer.json
index 627f725a2b..4cd630f1c2 100644
--- a/.devcontainer/cuda12.2-pip/devcontainer.json
+++ b/.devcontainer/cuda12.2-pip/devcontainer.json
@@ -5,19 +5,16 @@
     "args": {
       "CUDA": "12.2",
       "PYTHON_PACKAGE_MANAGER": "pip",
-      "BASE": "rapidsai/devcontainers:24.06-cpp-cuda12.2-ubuntu22.04"
+      "BASE": "rapidsai/devcontainers:24.06-cpp-cuda12.2-ucx1.15.0-openmpi-ubuntu22.04"
     }
   },
   "runArgs": [
     "--rm",
     "--name",
-    "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-pip"
+    "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-pip"
   ],
   "hostRequirements": {"gpu": "optional"},
   "features": {
-    "ghcr.io/rapidsai/devcontainers/features/ucx:24.6": {
-      "version": "1.15.0"
-    },
     "ghcr.io/rapidsai/devcontainers/features/cuda:24.6": {
       "version": "12.2",
       "installcuBLAS": true,
@@ -28,7 +25,6 @@
     "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {}
   },
   "overrideFeatureInstallOrder": [
-    "ghcr.io/rapidsai/devcontainers/features/ucx",
     "ghcr.io/rapidsai/devcontainers/features/cuda",
     "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils"
   ],
diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt
index 1eccd65ba4..4f8c312717 100644
--- a/cpp/bench/CMakeLists.txt
+++ b/cpp/bench/CMakeLists.txt
@@ -1,5 +1,5 @@
 #=============================================================================
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -61,6 +61,10 @@ if(BUILD_CUML_BENCH)
   set_target_properties(
     ${CUML_CPP_BENCH_TARGET}
     PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib"
+               CXX_STANDARD                      17
+               CXX_STANDARD_REQUIRED             ON
+               CUDA_STANDARD                     17
+               CUDA_STANDARD_REQUIRED            ON
   )
 
   install(
diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake
index b6c49bb2c2..60cc5dae15 100644
--- a/cpp/cmake/modules/ConfigureCUDA.cmake
+++ b/cpp/cmake/modules/ConfigureCUDA.cmake
@@ -1,5 +1,5 @@
 #=============================================================================
-# Copyright (c) 2018-2022, NVIDIA CORPORATION.
+# Copyright (c) 2018-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,8 +31,8 @@ endif()
 list(APPEND CUML_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations,-Wno-error=sign-compare)
 
 if(DISABLE_DEPRECATION_WARNINGS)
-    list(APPEND CUML_CXX_FLAGS -Wno-deprecated-declarations)
-    list(APPEND CUML_CUDA_FLAGS -Wno-deprecated-declarations -Xcompiler=-Wno-deprecated-declarations)
+    list(APPEND CUML_CXX_FLAGS -Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS)
+    list(APPEND CUML_CUDA_FLAGS -Wno-deprecated-declarations -Xcompiler=-Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS)
 endif()
 
 # make sure we produce smallest binary size
diff --git a/cpp/src/arima/batched_arima.cu b/cpp/src/arima/batched_arima.cu
index 187fc7923b..2b262df412 100644
--- a/cpp/src/arima/batched_arima.cu
+++ b/cpp/src/arima/batched_arima.cu
@@ -84,7 +84,7 @@ struct is_missing {
   typedef T argument_type;
   typedef T result_type;
 
-  __thrust_exec_check_disable__ __device__ const T operator()(const T& x) const { return isnan(x); }
+  __device__ const T operator()(const T& x) const { return isnan(x); }
 };  // end is_missing
 
 bool detect_missing(raft::handle_t& handle, const double* d_y, int n_elem)
diff --git a/cpp/src/hdbscan/condensed_hierarchy.cu b/cpp/src/hdbscan/condensed_hierarchy.cu
index 20f155b012..76f1a19cf8 100644
--- a/cpp/src/hdbscan/condensed_hierarchy.cu
+++ b/cpp/src/hdbscan/condensed_hierarchy.cu
@@ -26,6 +26,7 @@
 #include <rmm/exec_policy.hpp>
 
 #include <cub/cub.cuh>
+#include <cuda/functional>
 #include <thrust/copy.h>
 #include <thrust/device_ptr.h>
 #include <thrust/execution_policy.h>
@@ -156,7 +157,7 @@ void CondensedHierarchy<value_idx, value_t>::condense(value_idx* full_parents,
     thrust::cuda::par.on(stream),
     full_sizes,
     full_sizes + size,
-    [=] __device__(value_idx a) { return a != -1; },
+    cuda::proclaim_return_type<bool>([=] __device__(value_idx a) -> bool { return a != -1; }),
     0,
     thrust::plus<value_idx>());
 
diff --git a/cpp/src/hdbscan/detail/select.cuh b/cpp/src/hdbscan/detail/select.cuh
index 3bf17c437f..36e674e40b 100644
--- a/cpp/src/hdbscan/detail/select.cuh
+++ b/cpp/src/hdbscan/detail/select.cuh
@@ -216,13 +216,14 @@ void excess_of_mass(const raft::handle_t& handle,
     value_t subtree_stability = 0.0;
 
     if (indptr_h[node + 1] - indptr_h[node] > 0) {
-      subtree_stability = thrust::transform_reduce(
-        exec_policy,
-        children + indptr_h[node],
-        children + indptr_h[node + 1],
-        [=] __device__(value_idx a) { return stability[a]; },
-        0.0,
-        thrust::plus<value_t>());
+      subtree_stability =
+        thrust::transform_reduce(exec_policy,
+                                 children + indptr_h[node],
+                                 children + indptr_h[node + 1],
+                                 cuda::proclaim_return_type<value_t>(
+                                   [=] __device__(value_idx a) -> value_t { return stability[a]; }),
+                                 0.0,
+                                 thrust::plus<value_t>());
     }
 
     if (subtree_stability > node_stability || cluster_sizes_h[node] > max_cluster_size) {
diff --git a/cpp/src/hdbscan/detail/utils.h b/cpp/src/hdbscan/detail/utils.h
index 092dc2e673..b151628429 100644
--- a/cpp/src/hdbscan/detail/utils.h
+++ b/cpp/src/hdbscan/detail/utils.h
@@ -114,7 +114,7 @@ Common::CondensedHierarchy<value_idx, value_t> make_cluster_tree(
     thrust_policy,
     sizes,
     sizes + condensed_tree.get_n_edges(),
-    [=] __device__(value_idx a) { return a > 1; },
+    cuda::proclaim_return_type<bool>([=] __device__(value_idx a) -> bool { return a > 1; }),
     0,
     thrust::plus<value_idx>());
 
diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh
index d9e831cc34..a221d70820 100644
--- a/cpp/src/tsne/distances.cuh
+++ b/cpp/src/tsne/distances.cuh
@@ -33,6 +33,7 @@
 #include <rmm/device_uvector.hpp>
 #include <rmm/exec_policy.hpp>
 
+#include <cuda/functional>
 #include <thrust/functional.h>
 #include <thrust/transform_reduce.h>
 
@@ -162,7 +163,8 @@ void get_distances(const raft::handle_t& handle,
 template <typename value_t>
 void normalize_distances(value_t* distances, const size_t total_nn, cudaStream_t stream)
 {
-  auto abs_f      = [] __device__(const value_t& x) { return abs(x); };
+  auto abs_f = cuda::proclaim_return_type<value_t>(
+    [] __device__(const value_t& x) -> value_t { return abs(x); });
   value_t maxNorm = thrust::transform_reduce(rmm::exec_policy(stream),
                                              distances,
                                              distances + total_nn,
diff --git a/cpp/src/tsne/utils.cuh b/cpp/src/tsne/utils.cuh
index 5b3d008c25..895fe412d2 100644
--- a/cpp/src/tsne/utils.cuh
+++ b/cpp/src/tsne/utils.cuh
@@ -39,6 +39,7 @@
 #include <sys/time.h>
 #include <unistd.h>
 
+#include <cfloat>
 #include <chrono>
 #include <iostream>
 
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 0033c844ae..2a04100cdf 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -85,6 +85,10 @@ function(ConfigureTest)
   set_target_properties(
     ${_CUML_TEST_NAME}
     PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib"
+               CXX_STANDARD                      17
+               CXX_STANDARD_REQUIRED             ON
+               CUDA_STANDARD                     17
+               CUDA_STANDARD_REQUIRED            ON
   )
 
   set(_CUML_TEST_COMPONENT_NAME testing)