From 0cb75a4913facc53f22613c82be4b326eb5e7eb4 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Mon, 11 Apr 2022 14:42:49 -0400
Subject: [PATCH] Choose cupy dlpack method based on version (#10631)
This PR adds a CuPy version check to the 10 minute CuPy notebook, using this to decide which DLPack input method to use - this should allow us to support older CuPy versions that don't yet have `from_dlpack` while also supporting newer versions which will eventually deprecate `fromDlpack`.
Closes #10612
Authors:
- Charles Blackmon-Luca (https://github.com/charlesbluca)
Approvers:
- GALI PREM SAGAR (https://github.com/galipremsagar)
---
.../source/user_guide/10min-cudf-cupy.ipynb | 247 +++++++++---------
1 file changed, 123 insertions(+), 124 deletions(-)
diff --git a/docs/cudf/source/user_guide/10min-cudf-cupy.ipynb b/docs/cudf/source/user_guide/10min-cudf-cupy.ipynb
index b34bbd3f193..1bcb9335256 100644
--- a/docs/cudf/source/user_guide/10min-cudf-cupy.ipynb
+++ b/docs/cudf/source/user_guide/10min-cudf-cupy.ipynb
@@ -16,9 +16,15 @@
"outputs": [],
"source": [
"import timeit\n",
+ "from packaging import version\n",
"\n",
"import cupy as cp\n",
- "import cudf"
+ "import cudf\n",
+ "\n",
+ "if version.parse(cp.__version__) >= version.parse(\"10.0.0\"):\n",
+ " cupy_from_dlpack = cp.from_dlpack\n",
+ "else:\n",
+ " cupy_from_dlpack = cp.fromDlpack"
]
},
{
@@ -45,9 +51,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "167 µs ± 789 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
- "497 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
- "502 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
+ "183 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
+ "553 µs ± 6.25 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
+ "546 µs ± 2.25 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
@@ -58,7 +64,7 @@
" 'c':range(1000, nelem + 1000)}\n",
" )\n",
"\n",
- "%timeit arr_cupy = cp.from_dlpack(df.to_dlpack())\n",
+ "%timeit arr_cupy = cupy_from_dlpack(df.to_dlpack())\n",
"%timeit arr_cupy = df.values\n",
"%timeit arr_cupy = df.to_cupy()"
]
@@ -86,7 +92,7 @@
}
],
"source": [
- "arr_cupy = cp.from_dlpack(df.to_dlpack())\n",
+ "arr_cupy = cupy_from_dlpack(df.to_dlpack())\n",
"arr_cupy"
]
},
@@ -117,9 +123,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "75.2 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
- "185 µs ± 630 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
- "169 µs ± 1.24 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
+ "76.8 µs ± 636 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
+ "198 µs ± 2.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
+ "181 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
@@ -127,7 +133,7 @@
"col = 'a'\n",
"\n",
"%timeit cola_cupy = cp.asarray(df[col])\n",
- "%timeit cola_cupy = cp.from_dlpack(df[col].to_dlpack())\n",
+ "%timeit cola_cupy = cupy_from_dlpack(df[col].to_dlpack())\n",
"%timeit cola_cupy = df[col].values"
]
},
@@ -256,7 +262,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "22 ms ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+ "23.9 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
@@ -510,7 +516,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "8.6 ms ± 33.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+ "9.15 ms ± 131 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
@@ -530,7 +536,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "5.56 ms ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+ "5.74 ms ± 29.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
@@ -1023,7 +1029,7 @@
}
],
"source": [
- "new_arr = cp.from_dlpack(reshaped_df.to_dlpack())\n",
+ "new_arr = cupy_from_dlpack(reshaped_df.to_dlpack())\n",
"new_arr.sum(axis=1)"
]
},
@@ -1137,116 +1143,116 @@
"
\n",
" \n",
" 0 | \n",
+ " 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
- " 0.0 | \n",
- " 3.380014 | \n",
- " 0.0 | \n",
" 0.000000 | \n",
- " 11.030136 | \n",
" 0.0 | \n",
+ " 9.37476 | \n",
" 0.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
" 0.000000 | \n",
+ " 6.237859 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
" 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.00000 | \n",
- " 5.726806 | \n",
" 0.0 | \n",
" 0.0 | \n",
+ " 0.00000 | \n",
+ " 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 1 | \n",
- " 0.0 | \n",
+ " 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
" 0.0 | \n",
- " 0.000000 | \n",
+ " 0.00000 | \n",
" 0.000000 | \n",
" 0.0 | \n",
+ " 0.0 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 5.917846 | \n",
- " 0.000000 | \n",
- " 5.90886 | \n",
- " 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
+ " 0.065878 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 12.35705 | \n",
+ " 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 2 | \n",
- " 0.0 | \n",
+ " 3.232751 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
" 0.0 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
+ " 0.00000 | \n",
+ " 8.341915 | \n",
+ " 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 6.646564 | \n",
- " 0.00000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
" 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
+ " 0.00000 | \n",
" 0.0 | \n",
- " 3.399164 | \n",
+ " 0.0 | \n",
+ " 3.110362 | \n",
"
\n",
" \n",
" 3 | \n",
- " 0.0 | \n",
+ " 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
" 0.0 | \n",
- " 0.000000 | \n",
+ " 0.00000 | \n",
" 0.000000 | \n",
" 0.0 | \n",
- " 14.092100 | \n",
- " 0.000000 | \n",
- " 0.378781 | \n",
- " 0.420953 | \n",
+ " 0.0 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
- " 0.00000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
" 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
+ " 0.00000 | \n",
+ " 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 4 | \n",
+ " 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
+ " 7.743024 | \n",
" 0.0 | \n",
+ " 0.00000 | \n",
" 0.000000 | \n",
" 0.0 | \n",
- " 0.109242 | \n",
- " 2.541798 | \n",
" 0.0 | \n",
- " 0.071563 | \n",
- " 8.223387 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
+ " 5.987098 | \n",
" 0.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
" 0.000000 | \n",
- " 0.00000 | \n",
- " 10.744624 | \n",
" 0.0 | \n",
" 0.0 | \n",
+ " 0.00000 | \n",
+ " 0.0 | \n",
" 0.0 | \n",
" 0.000000 | \n",
"
\n",
@@ -1255,26 +1261,19 @@
""
],
"text/plain": [
- " a0 a1 a2 a3 a4 a5 a6 a7 a8 \\\n",
- "0 0.0 0.0 0.0 3.380014 0.0 0.000000 11.030136 0.0 0.000000 \n",
- "1 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n",
- "2 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n",
- "3 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 14.092100 \n",
- "4 0.0 0.0 0.0 0.000000 0.0 0.109242 2.541798 0.0 0.071563 \n",
- "\n",
- " a9 a10 a11 a12 a13 a14 a15 a16 \\\n",
- "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.00000 5.726806 0.0 \n",
- "1 0.000000 0.000000 0.000000 5.917846 0.000000 5.90886 0.000000 0.0 \n",
- "2 0.000000 0.000000 0.000000 0.000000 6.646564 0.00000 0.000000 0.0 \n",
- "3 0.000000 0.378781 0.420953 0.000000 0.000000 0.00000 0.000000 0.0 \n",
- "4 8.223387 0.000000 0.000000 0.000000 0.000000 0.00000 10.744624 0.0 \n",
+ " a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 \\\n",
+ "0 0.000000 0.0 0.0 0.000000 0.0 9.37476 0.000000 0.0 0.0 0.000000 \n",
+ "1 0.000000 0.0 0.0 0.000000 0.0 0.00000 0.000000 0.0 0.0 0.000000 \n",
+ "2 3.232751 0.0 0.0 0.000000 0.0 0.00000 8.341915 0.0 0.0 0.000000 \n",
+ "3 0.000000 0.0 0.0 0.000000 0.0 0.00000 0.000000 0.0 0.0 0.000000 \n",
+ "4 0.000000 0.0 0.0 7.743024 0.0 0.00000 0.000000 0.0 0.0 5.987098 \n",
"\n",
- " a17 a18 a19 \n",
- "0 0.0 0.0 0.000000 \n",
- "1 0.0 0.0 0.000000 \n",
- "2 0.0 0.0 3.399164 \n",
- "3 0.0 0.0 0.000000 \n",
- "4 0.0 0.0 0.000000 "
+ " a10 a11 a12 a13 a14 a15 a16 a17 a18 a19 \n",
+ "0 6.237859 0.0 0.0 0.000000 0.0 0.0 0.00000 0.0 0.0 0.000000 \n",
+ "1 0.000000 0.0 0.0 0.065878 0.0 0.0 12.35705 0.0 0.0 0.000000 \n",
+ "2 0.000000 0.0 0.0 0.000000 0.0 0.0 0.00000 0.0 0.0 3.110362 \n",
+ "3 0.000000 0.0 0.0 0.000000 0.0 0.0 0.00000 0.0 0.0 0.000000 \n",
+ "4 0.000000 0.0 0.0 0.000000 0.0 0.0 0.00000 0.0 0.0 0.000000 "
]
},
"execution_count": 20,
@@ -1295,57 +1294,57 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " (896, 0)\t0.7194778152522069\n",
- " (385, 0)\t5.061243119202521\n",
- " (899, 0)\t8.032932656540671\n",
- " (1028, 0)\t10.072155866140903\n",
- " (133, 0)\t13.27741318265092\n",
- " (518, 0)\t2.242099518010387\n",
- " (647, 0)\t6.487369007371155\n",
- " (776, 0)\t5.621989952370181\n",
- " (9, 0)\t8.833796529523534\n",
- " (521, 0)\t7.719749292928572\n",
- " (777, 0)\t7.4610987015782975\n",
- " (394, 0)\t10.09026095476732\n",
- " (140, 0)\t2.974228870142501\n",
- " (653, 0)\t4.520704347545524\n",
- " (1037, 0)\t4.53896886415556\n",
- " (400, 0)\t4.0198547103826705\n",
- " (401, 0)\t-0.2557920447399875\n",
- " (1041, 0)\t1.8627471984893114\n",
- " (146, 0)\t9.834516073722536\n",
- " (1042, 0)\t7.850006814937681\n",
- " (275, 0)\t1.5747512513374389\n",
- " (662, 0)\t6.717038670488377\n",
- " (25, 0)\t7.311464380885098\n",
- " (281, 0)\t3.5147599499072024\n",
- " (409, 0)\t1.121874214291239\n",
+ " (2, 0)\t3.2327506467190874\n",
+ " (259, 0)\t10.723428115951062\n",
+ " (643, 0)\t0.47763624588488707\n",
+ " (899, 0)\t8.857065309921685\n",
+ " (516, 0)\t8.792407143276648\n",
+ " (262, 0)\t2.1900894573805396\n",
+ " (390, 0)\t5.007630701229646\n",
+ " (646, 0)\t6.630703075588639\n",
+ " (392, 0)\t5.573713453854357\n",
+ " (776, 0)\t10.501281989515688\n",
+ " (904, 0)\t8.261890175181366\n",
+ " (1033, 0)\t-0.41106824704220446\n",
+ " (522, 0)\t12.619952511457068\n",
+ " (139, 0)\t12.753348070606792\n",
+ " (141, 0)\t4.936902335394504\n",
+ " (270, 0)\t-1.7695949916946174\n",
+ " (782, 0)\t4.378746787324408\n",
+ " (15, 0)\t8.554141682891935\n",
+ " (527, 0)\t5.1994882136423\n",
+ " (912, 0)\t2.6101212854793125\n",
+ " (401, 0)\t5.614628764689268\n",
+ " (403, 0)\t9.999468341523317\n",
+ " (787, 0)\t7.6170790481600985\n",
+ " (404, 0)\t5.105328903336744\n",
+ " (916, 0)\t1.395526391114967\n",
" :\t:\n",
- " (8290, 19)\t19.23532976720017\n",
- " (8679, 19)\t3.9092712623274224\n",
- " (8935, 19)\t0.8411008847310036\n",
- " (9063, 19)\t12.010953214709328\n",
- " (9319, 19)\t3.470064419440258\n",
- " (8683, 19)\t14.397876149427695\n",
- " (8300, 19)\t10.524275022546979\n",
- " (8301, 19)\t0.6266917401191829\n",
- " (8557, 19)\t-0.4554588974911311\n",
- " (9197, 19)\t12.379896820812874\n",
- " (8304, 19)\t1.3276250981825033\n",
- " (8563, 19)\t-1.579631321204169\n",
- " (8442, 19)\t6.881252269650868\n",
- " (8315, 19)\t0.5811637925849389\n",
- " (8575, 19)\t15.52855242553137\n",
- " (9343, 19)\t-0.12679091919544638\n",
- " (9569, 19)\t9.316119794827424\n",
- " (9570, 19)\t10.791371431930969\n",
- " (9443, 19)\t4.7035396189880645\n",
- " (9452, 19)\t-0.9924476181662789\n",
- " (9713, 19)\t-3.2038209275781346\n",
- " (9719, 19)\t0.6578276176100656\n",
- " (9847, 19)\t9.57555910183088\n",
- " (9724, 19)\t0.990362915454171\n",
- " (9855, 19)\t1.153449284622368\n"
+ " (9328, 19)\t5.938629381103238\n",
+ " (9457, 19)\t4.463547879031807\n",
+ " (9458, 19)\t-0.8034946631917106\n",
+ " (8051, 19)\t-1.904327616912268\n",
+ " (8819, 19)\t8.314944347687199\n",
+ " (7543, 19)\t1.4303204025224376\n",
+ " (8824, 19)\t5.1559713157589\n",
+ " (7673, 19)\t7.478681299798863\n",
+ " (7802, 19)\t0.502526238006068\n",
+ " (8186, 19)\t-3.824944685072472\n",
+ " (8570, 19)\t8.442324394481236\n",
+ " (8571, 19)\t6.204199957873215\n",
+ " (7420, 19)\t0.297737356585836\n",
+ " (9212, 19)\t3.934797966994188\n",
+ " (7421, 19)\t14.26161925450462\n",
+ " (8574, 19)\t5.826108027573207\n",
+ " (9214, 19)\t7.209975861932724\n",
+ " (9825, 19)\t11.155342644729613\n",
+ " (9702, 19)\t3.55144040779287\n",
+ " (9578, 19)\t12.638681362546228\n",
+ " (9712, 19)\t2.3542852760656348\n",
+ " (9969, 19)\t-2.645175092587592\n",
+ " (9973, 19)\t-2.2666402312025213\n",
+ " (9851, 19)\t-4.293381721466055\n",
+ " (9596, 19)\t6.6580506888430415\n"
]
}
],