Skip to content

Commit

Permalink
Replacing sklearn functions with cuml in RF MNMG notebook (#3408)
Browse files Browse the repository at this point in the history
Closes #2864.

This PR replaces `datasets.make_blobs` and `metrics.accuracy_score` from sklearn with their cuml equivalents.

Authors:
  - Micka (@lowener)

Approvers:
  - John Zedlewski (@JohnZed)

URL: #3408
  • Loading branch information
lowener authored Feb 5, 2021
1 parent 39c7262 commit dec134a
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions notebooks/random_forest_mnmg_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
"import cudf\n",
"import cuml\n",
"\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn import model_selection, datasets\n",
"from sklearn import model_selection\n",
"\n",
"from cuml import datasets\n",
"from cuml.metrics import accuracy_score\n",
"from cuml.dask.common import utils as dask_utils\n",
"from dask.distributed import Client, wait\n",
"from dask_cuda import LocalCUDACluster\n",
Expand Down Expand Up @@ -132,7 +133,7 @@
"\n",
"def distribute(X, y):\n",
" # First convert to cudf (with real data, you would likely load in cuDF format to start)\n",
" X_cudf = cudf.DataFrame.from_pandas(pd.DataFrame(X))\n",
" X_cudf = cudf.DataFrame(X)\n",
" y_cudf = cudf.Series(y)\n",
"\n",
" # Partition with Dask\n",
Expand Down Expand Up @@ -169,7 +170,7 @@
"\n",
"# Use all avilable CPU cores\n",
"skl_model = sklRF(max_depth=max_depth, n_estimators=n_trees, n_jobs=-1)\n",
"skl_model.fit(X_train, y_train)"
"skl_model.fit(X_train.get(), y_train.get())"
]
},
{
Expand Down Expand Up @@ -206,7 +207,7 @@
"metadata": {},
"outputs": [],
"source": [
"skl_y_pred = skl_model.predict(X_test)\n",
"skl_y_pred = skl_model.predict(X_test.get())\n",
"cuml_y_pred = cuml_model.predict(X_test_dask).compute().to_array()\n",
"\n",
"# Due to randomness in the algorithm, you may see slight variation in accuracies\n",
Expand Down

0 comments on commit dec134a

Please sign in to comment.