From dec134aed25ae182098687904194a7259f96d31c Mon Sep 17 00:00:00 2001 From: Micka <9810050+lowener@users.noreply.github.com> Date: Fri, 5 Feb 2021 05:08:56 +0100 Subject: [PATCH] Replacing sklearn functions with cuml in RF MNMG notebook (#3408) Closes #2864. This PR will replace `datasets.make_blobs` and `metrics.accuracy_score` from sklearn to cuml. Authors: - Micka (@lowener) Approvers: - John Zedlewski (@JohnZed) URL: https://github.com/rapidsai/cuml/pull/3408 --- notebooks/random_forest_mnmg_demo.ipynb | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/notebooks/random_forest_mnmg_demo.ipynb b/notebooks/random_forest_mnmg_demo.ipynb index 0c5be5ec03..6797707b95 100755 --- a/notebooks/random_forest_mnmg_demo.ipynb +++ b/notebooks/random_forest_mnmg_demo.ipynb @@ -30,9 +30,10 @@ "import cudf\n", "import cuml\n", "\n", - "from sklearn.metrics import accuracy_score\n", - "from sklearn import model_selection, datasets\n", + "from sklearn import model_selection\n", "\n", + "from cuml import datasets\n", + "from cuml.metrics import accuracy_score\n", "from cuml.dask.common import utils as dask_utils\n", "from dask.distributed import Client, wait\n", "from dask_cuda import LocalCUDACluster\n", @@ -132,7 +133,7 @@ "\n", "def distribute(X, y):\n", " # First convert to cudf (with real data, you would likely load in cuDF format to start)\n", - " X_cudf = cudf.DataFrame.from_pandas(pd.DataFrame(X))\n", + " X_cudf = cudf.DataFrame(X)\n", " y_cudf = cudf.Series(y)\n", "\n", " # Partition with Dask\n", @@ -169,7 +170,7 @@ "\n", "# Use all avilable CPU cores\n", "skl_model = sklRF(max_depth=max_depth, n_estimators=n_trees, n_jobs=-1)\n", - "skl_model.fit(X_train, y_train)" + "skl_model.fit(X_train.get(), y_train.get())" ] }, { @@ -206,7 +207,7 @@ "metadata": {}, "outputs": [], "source": [ - "skl_y_pred = skl_model.predict(X_test)\n", + "skl_y_pred = skl_model.predict(X_test.get())\n", "cuml_y_pred = cuml_model.predict(X_test_dask).compute().to_array()\n", "\n", "# Due to randomness in the algorithm, you may see slight variation in accuracies\n",