diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index a1311d489eb..1a5073b2f24 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -243,6 +243,17 @@ cd "$WORKSPACE/python/custreamz" gpuci_logger "Python py.test for cuStreamz" py.test -n 8 --cache-clear --basetemp="$WORKSPACE/custreamz-cuda-tmp" --junitxml="$WORKSPACE/junit-custreamz.xml" -v --cov-config=.coveragerc --cov=custreamz --cov-report=xml:"$WORKSPACE/python/custreamz/custreamz-coverage.xml" --cov-report term custreamz +# Run benchmarks with both cudf and pandas to ensure compatibility is maintained. +# Benchmarks are run in DEBUG_ONLY mode, meaning that only small data sizes are used. +# Therefore, these runs only verify that benchmarks are valid. +# They do not generate meaningful performance measurements. +cd "$WORKSPACE/python/cudf" +gpuci_logger "Python pytest for cuDF benchmarks" +CUDF_BENCHMARKS_DEBUG_ONLY=ON pytest -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" -v --dist=loadscope benchmarks + +gpuci_logger "Python pytest for cuDF benchmarks using pandas" +CUDF_BENCHMARKS_USE_PANDAS=ON CUDF_BENCHMARKS_DEBUG_ONLY=ON pytest -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" -v --dist=loadscope benchmarks + gpuci_logger "Test notebooks" "$WORKSPACE/ci/gpu/test-notebooks.sh" 2>&1 | tee nbtest.log python "$WORKSPACE/ci/utils/nbtestlog2junitxml.py" nbtest.log diff --git a/conda/environments/cudf_dev_cuda11.5.yml b/conda/environments/cudf_dev_cuda11.5.yml index 7d04a7e5758..87dd36776f7 100644 --- a/conda/environments/cudf_dev_cuda11.5.yml +++ b/conda/environments/cudf_dev_cuda11.5.yml @@ -29,6 +29,7 @@ dependencies: - fsspec>=0.6.0 - pytest - pytest-benchmark + - pytest-cases - pytest-xdist - sphinx - sphinxcontrib-websupport diff --git a/python/cudf/benchmarks/API/bench_dataframe.py b/python/cudf/benchmarks/API/bench_dataframe.py new file mode 100644 index 00000000000..d5e175ff85c --- /dev/null +++ b/python/cudf/benchmarks/API/bench_dataframe.py @@ -0,0 +1,117 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +"""Benchmarks of DataFrame methods.""" + +import string + +import numpy +import pytest +from config import cudf, cupy +from utils import benchmark_with_object + + +@pytest.mark.parametrize("N", [100, 1_000_000]) +def bench_construction(benchmark, N): + benchmark(cudf.DataFrame, {None: cupy.random.rand(N)}) + + +@benchmark_with_object(cls="dataframe", dtype="float", cols=6) +@pytest.mark.parametrize( + "expr", ["a+b", "a+b+c+d+e", "a / (sin(a) + cos(b)) * tanh(d*e*f)"] +) +def bench_eval_func(benchmark, expr, dataframe): + benchmark(dataframe.eval, expr) + + +@benchmark_with_object(cls="dataframe", dtype="int", nulls=False, cols=6) +@pytest.mark.parametrize( + "num_key_cols", + [2, 3, 4], +) +def bench_merge(benchmark, dataframe, num_key_cols): + benchmark( + dataframe.merge, dataframe, on=list(dataframe.columns[:num_key_cols]) + ) + + +# TODO: Some of these cases could be generalized to an IndexedFrame benchmark +# instead of a DataFrame benchmark. 
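+# (`isin`, for instance, is also defined on Series and Index, so the cases
+# below would carry over with only the fixture changed.)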
+@benchmark_with_object(cls="dataframe", dtype="int") +@pytest.mark.parametrize( + "values", + [ + range(1000), + {f"key{i}": range(1000) for i in range(10)}, + cudf.DataFrame({f"key{i}": range(1000) for i in range(10)}), + cudf.Series(range(1000)), + ], +) +def bench_isin(benchmark, dataframe, values): + benchmark(dataframe.isin, values) + + +@pytest.fixture( + params=[0, numpy.random.RandomState, cupy.random.RandomState], + ids=["Seed", "NumpyRandomState", "CupyRandomState"], +) +def random_state(request): + rs = request.param + return rs if isinstance(rs, int) else rs(seed=42) + + +@benchmark_with_object(cls="dataframe", dtype="int") +@pytest.mark.parametrize("frac", [0.5]) +def bench_sample(benchmark, dataframe, axis, frac, random_state): + if axis == 1 and isinstance(random_state, cupy.random.RandomState): + pytest.skip("Unsupported params.") + benchmark( + dataframe.sample, frac=frac, axis=axis, random_state=random_state + ) + + +@benchmark_with_object(cls="dataframe", dtype="int", nulls=False, cols=6) +@pytest.mark.parametrize( + "num_key_cols", + [2, 3, 4], +) +def bench_groupby(benchmark, dataframe, num_key_cols): + benchmark(dataframe.groupby, by=list(dataframe.columns[:num_key_cols])) + + +@benchmark_with_object(cls="dataframe", dtype="int", nulls=False, cols=6) +@pytest.mark.parametrize( + "agg", + [ + "sum", + ["sum", "mean"], + { + f"{string.ascii_lowercase[i]}": ["sum", "mean", "count"] + for i in range(6) + }, + ], +) +@pytest.mark.parametrize( + "num_key_cols", + [2, 3, 4], +) +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize("sort", [True, False]) +def bench_groupby_agg(benchmark, dataframe, agg, num_key_cols, as_index, sort): + by = list(dataframe.columns[:num_key_cols]) + benchmark(dataframe.groupby(by=by, as_index=as_index, sort=sort).agg, agg) + + +@benchmark_with_object(cls="dataframe", dtype="int") +@pytest.mark.parametrize("num_cols_to_sort", [1]) +def bench_sort_values(benchmark, dataframe, num_cols_to_sort): + benchmark( + dataframe.sort_values, list(dataframe.columns[:num_cols_to_sort]) + ) + + +@benchmark_with_object(cls="dataframe", dtype="int") +@pytest.mark.parametrize("num_cols_to_sort", [1]) +@pytest.mark.parametrize("n", [10]) +def bench_nsmallest(benchmark, dataframe, num_cols_to_sort, n): + by = list(dataframe.columns[:num_cols_to_sort]) + benchmark(dataframe.nsmallest, n, by) diff --git a/python/cudf/benchmarks/API/bench_frame_or_index.py b/python/cudf/benchmarks/API/bench_frame_or_index.py new file mode 100644 index 00000000000..14b29e8ef75 --- /dev/null +++ b/python/cudf/benchmarks/API/bench_frame_or_index.py @@ -0,0 +1,88 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +"""Benchmarks of methods that exist for both Frame and BaseIndex.""" + +import operator + +import numpy as np +import pytest +from utils import benchmark_with_object, make_gather_map + + +@benchmark_with_object(cls="frame_or_index", dtype="int") +@pytest.mark.parametrize("gather_how", ["sequence", "reverse", "random"]) +@pytest.mark.parametrize("fraction", [0.4]) +def bench_take(benchmark, gather_how, fraction, frame_or_index): + nr = len(frame_or_index) + gather_map = make_gather_map(nr * fraction, nr, gather_how) + benchmark(frame_or_index.take, gather_map) + + +@pytest.mark.pandas_incompatible # Series/Index work, but not DataFrame +@benchmark_with_object(cls="frame_or_index", dtype="int") +def bench_argsort(benchmark, frame_or_index): + benchmark(frame_or_index.argsort) + + +@benchmark_with_object(cls="frame_or_index", dtype="int") +def bench_min(benchmark, frame_or_index): + benchmark(frame_or_index.min) + + +@benchmark_with_object(cls="frame_or_index", dtype="int") +def bench_where(benchmark, frame_or_index): + cond = frame_or_index % 2 == 0 + benchmark(frame_or_index.where, cond, 0) + + +@benchmark_with_object(cls="frame_or_index", dtype="int", nulls=False) +@pytest.mark.pandas_incompatible +def bench_values_host(benchmark, frame_or_index): + benchmark(lambda: frame_or_index.values_host) + + +@benchmark_with_object(cls="frame_or_index", dtype="int", nulls=False) +def bench_values(benchmark, frame_or_index): + benchmark(lambda: frame_or_index.values) + + +@benchmark_with_object(cls="frame_or_index", dtype="int") +def bench_nunique(benchmark, frame_or_index): + benchmark(frame_or_index.nunique) + + +@benchmark_with_object(cls="frame_or_index", dtype="int", nulls=False) +def bench_to_numpy(benchmark, frame_or_index): + benchmark(frame_or_index.to_numpy) + + +@benchmark_with_object(cls="frame_or_index", dtype="int", nulls=False) +@pytest.mark.pandas_incompatible +def bench_to_cupy(benchmark, frame_or_index): + benchmark(frame_or_index.to_cupy) + + +@benchmark_with_object(cls="frame_or_index", dtype="int") +@pytest.mark.pandas_incompatible +def bench_to_arrow(benchmark, frame_or_index): + benchmark(frame_or_index.to_arrow) + + +@benchmark_with_object(cls="frame_or_index", dtype="int") +def bench_astype(benchmark, frame_or_index): + benchmark(frame_or_index.astype, float) + + +@pytest.mark.parametrize("ufunc", [np.add, np.logical_and]) +@benchmark_with_object(cls="frame_or_index", dtype="int") +def bench_ufunc_series_binary(benchmark, frame_or_index, ufunc): + benchmark(ufunc, frame_or_index, frame_or_index) + + +@pytest.mark.parametrize( + "op", + [operator.add, operator.mul, operator.eq], +) +@benchmark_with_object(cls="frame_or_index", dtype="int") +def bench_binops(benchmark, op, frame_or_index): + benchmark(lambda: op(frame_or_index, frame_or_index)) diff --git a/python/cudf/benchmarks/API/bench_functions.py b/python/cudf/benchmarks/API/bench_functions.py new file mode 100644 index 00000000000..a166317a46b --- /dev/null +++ b/python/cudf/benchmarks/API/bench_functions.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +"""Benchmarks of free functions that accept cudf objects.""" + +import pytest +import pytest_cases +from config import cudf, cupy + + +@pytest_cases.parametrize_with_cases("objs", prefix="concat") +@pytest.mark.parametrize( + "axis", + [ + 1, + ], +) +@pytest.mark.parametrize("join", ["inner", "outer"]) +@pytest.mark.parametrize("ignore_index", [True, False]) +def bench_concat_axis_1(benchmark, objs, axis, join, ignore_index): + benchmark( + cudf.concat, objs=objs, axis=axis, join=join, ignore_index=ignore_index + ) + + +@pytest.mark.parametrize("size", [10_000, 100_000]) +@pytest.mark.parametrize("cardinality", [10, 100, 1000]) +@pytest.mark.parametrize("dtype", [cupy.bool_, cupy.float64]) +def bench_get_dummies_high_cardinality(benchmark, size, cardinality, dtype): + """Benchmark when the cardinality of column to encode is high.""" + df = cudf.DataFrame( + { + "col": cudf.Series( + cupy.random.randint(low=0, high=cardinality, size=size) + ).astype("category") + } + ) + benchmark(cudf.get_dummies, df, columns=["col"], dtype=dtype) + + +@pytest.mark.parametrize("prefix", [None, "pre"]) +def bench_get_dummies_simple(benchmark, prefix): + """Benchmark with small input to test the efficiency of the API itself.""" + df = cudf.DataFrame( + { + "col1": list(range(10)), + "col2": list("abcdefghij"), + "col3": cudf.Series(list(range(100, 110)), dtype="category"), + } + ) + benchmark( + cudf.get_dummies, df, columns=["col1", "col2", "col3"], prefix=prefix + ) diff --git a/python/cudf/benchmarks/API/bench_functions_cases.py b/python/cudf/benchmarks/API/bench_functions_cases.py new file mode 100644 index 00000000000..c81f8f20f80 --- /dev/null +++ b/python/cudf/benchmarks/API/bench_functions_cases.py @@ -0,0 +1,148 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +"""Test cases for benchmarks in bench_functions.py.""" + +import pytest_cases +from config import NUM_ROWS, cudf, cupy + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_default_index(nr): + return [ + cudf.DataFrame({"a": cupy.tile([1, 2, 3], nr)}), + cudf.DataFrame({"b": cupy.tile([4, 5, 7], nr)}), + ] + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_contiguous_indexes(nr): + return [ + cudf.DataFrame({"a": cupy.tile([1, 2, 3], nr)}), + cudf.DataFrame( + {"b": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3), + ), + ] + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_contiguous_indexes_different_cols(nr): + return [ + cudf.DataFrame( + {"a": cupy.tile([1, 2, 3], nr), "b": cupy.tile([4, 5, 7], nr)} + ), + cudf.DataFrame( + {"c": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3), + ), + ] + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_string_index(nr): + return [ + cudf.DataFrame( + {"a": cupy.tile([1, 2, 3], nr), "b": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=0, stop=nr * 3).astype("str"), + ), + cudf.DataFrame( + {"c": [4, 5, 7] * nr}, + index=cudf.RangeIndex(start=0, stop=nr * 3).astype("str"), + ), + ] + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_contiguous_string_index_different_col(nr): + return [ + cudf.DataFrame( + {"a": cupy.tile([1, 2, 3], nr), "b": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=0, stop=nr * 3).astype("str"), + ), + cudf.DataFrame( + {"c": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3).astype("str"), + ), + ] + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_complex_string_index(nr): + return [ + cudf.DataFrame( + {"a": cupy.tile([1, 2, 3], nr), "b": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=0, stop=nr * 3).astype("str"), + ), + cudf.DataFrame( + {"c": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3).astype("str"), + ), + cudf.DataFrame( + {"d": cupy.tile([1, 2, 3], nr), "e": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=0, stop=nr * 3).astype("str"), + ), + cudf.DataFrame( + {"f": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3).astype("str"), + ), + cudf.DataFrame( + {"g": cupy.tile([1, 2, 3], nr), "h": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=0, stop=nr * 3).astype("str"), + ), + cudf.DataFrame( + {"i": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3).astype("str"), + ), + ] + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_unique_columns(nr): + # To avoid any edge case bugs, always use at least 10 rows per DataFrame. 
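+    # (For example, nr = 100 gives nr_actual = 10, while nr = 1_000_000
+    # gives nr_actual = 50_000.)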
+ nr_actual = max(10, nr // 20) + return [ + cudf.DataFrame({"a": cupy.tile([1, 2, 3], nr_actual)}), + cudf.DataFrame({"b": cupy.tile([4, 5, 7], nr_actual)}), + cudf.DataFrame({"c": cupy.tile([1, 2, 3], nr_actual)}), + cudf.DataFrame({"d": cupy.tile([4, 5, 7], nr_actual)}), + cudf.DataFrame({"e": cupy.tile([1, 2, 3], nr_actual)}), + cudf.DataFrame({"f": cupy.tile([4, 5, 7], nr_actual)}), + cudf.DataFrame({"g": cupy.tile([1, 2, 3], nr_actual)}), + cudf.DataFrame({"h": cupy.tile([4, 5, 7], nr_actual)}), + cudf.DataFrame({"i": cupy.tile([1, 2, 3], nr_actual)}), + cudf.DataFrame({"j": cupy.tile([4, 5, 7], nr_actual)}), + ] + + +@pytest_cases.parametrize("nr", NUM_ROWS) +def concat_case_unique_columns_with_different_range_index(nr): + return [ + cudf.DataFrame( + {"a": cupy.tile([1, 2, 3], nr), "b": cupy.tile([4, 5, 7], nr)} + ), + cudf.DataFrame( + {"c": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3), + ), + cudf.DataFrame( + {"d": cupy.tile([1, 2, 3], nr), "e": cupy.tile([4, 5, 7], nr)} + ), + cudf.DataFrame( + {"f": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3), + ), + cudf.DataFrame( + {"g": cupy.tile([1, 2, 3], nr), "h": cupy.tile([4, 5, 7], nr)} + ), + cudf.DataFrame( + {"i": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3), + ), + cudf.DataFrame( + {"j": cupy.tile([1, 2, 3], nr), "k": cupy.tile([4, 5, 7], nr)} + ), + cudf.DataFrame( + {"l": cupy.tile([4, 5, 7], nr)}, + index=cudf.RangeIndex(start=nr * 3, stop=nr * 2 * 3), + ), + ] diff --git a/python/cudf/benchmarks/API/bench_index.py b/python/cudf/benchmarks/API/bench_index.py new file mode 100644 index 00000000000..53e617141b6 --- /dev/null +++ b/python/cudf/benchmarks/API/bench_index.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +"""Benchmarks of Index methods.""" + +import pytest +from config import cudf, cupy +from utils import benchmark_with_object + + +@pytest.mark.parametrize("N", [100, 1_000_000]) +def bench_construction(benchmark, N): + benchmark(cudf.Index, cupy.random.rand(N)) + + +@benchmark_with_object(cls="index", dtype="int", nulls=False) +def bench_sort_values(benchmark, index): + benchmark(index.sort_values) diff --git a/python/cudf/benchmarks/API/bench_indexed_frame.py b/python/cudf/benchmarks/API/bench_indexed_frame.py new file mode 100644 index 00000000000..6969121b0da --- /dev/null +++ b/python/cudf/benchmarks/API/bench_indexed_frame.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +"""Benchmarks of IndexedFrame methods.""" + +import pytest +from utils import benchmark_with_object + + +@benchmark_with_object(cls="indexedframe", dtype="int") +@pytest.mark.parametrize("op", ["cumsum", "cumprod", "cummax"]) +def bench_scans(benchmark, op, indexedframe): + benchmark(getattr(indexedframe, op)) + + +@benchmark_with_object(cls="indexedframe", dtype="int") +@pytest.mark.parametrize("op", ["sum", "product", "mean"]) +def bench_reductions(benchmark, op, indexedframe): + benchmark(getattr(indexedframe, op)) + + +@benchmark_with_object(cls="indexedframe", dtype="int") +def bench_drop_duplicates(benchmark, indexedframe): + benchmark(indexedframe.drop_duplicates) + + +@benchmark_with_object(cls="indexedframe", dtype="int") +def bench_rangeindex_replace(benchmark, indexedframe): + # TODO: Consider adding more DataFrame-specific benchmarks for different + # types of valid inputs (dicts, etc). 
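+    # (For example, a dict mapping such as replace({0: 2, 1: 3}), or a pair
+    # of lists such as replace([0, 1], [2, 3]).)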
+    benchmark(indexedframe.replace, 0, 2)
diff --git a/python/cudf/benchmarks/API/bench_multiindex.py b/python/cudf/benchmarks/API/bench_multiindex.py
new file mode 100644
index 00000000000..6268bcc4267
--- /dev/null
+++ b/python/cudf/benchmarks/API/bench_multiindex.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Benchmarks of MultiIndex methods."""
+
+import numpy as np
+import pandas as pd
+import pytest
+from config import cudf
+
+
+@pytest.fixture
+def pidx():
+    num_elements = int(1e3)
+    a = np.random.randint(0, num_elements // 10, num_elements)
+    b = np.random.randint(0, num_elements // 10, num_elements)
+    return pd.MultiIndex.from_arrays([a, b], names=("a", "b"))
+
+
+@pytest.fixture
+def midx():
+    num_elements = int(1e3)
+    a = np.random.randint(0, num_elements // 10, num_elements)
+    b = np.random.randint(0, num_elements // 10, num_elements)
+    df = cudf.DataFrame({"a": a, "b": b})
+    return cudf.MultiIndex.from_frame(df)
+
+
+@pytest.mark.pandas_incompatible
+def bench_from_pandas(benchmark, pidx):
+    benchmark(cudf.MultiIndex.from_pandas, pidx)
+
+
+def bench_constructor(benchmark, midx):
+    benchmark(
+        cudf.MultiIndex, codes=midx.codes, levels=midx.levels, names=midx.names
+    )
+
+
+def bench_from_frame(benchmark, midx):
+    benchmark(cudf.MultiIndex.from_frame, midx.to_frame(index=False))
+
+
+def bench_copy(benchmark, midx):
+    benchmark(midx.copy, deep=False)
diff --git a/python/cudf/benchmarks/API/bench_series.py b/python/cudf/benchmarks/API/bench_series.py
new file mode 100644
index 00000000000..92032da4a2e
--- /dev/null
+++ b/python/cudf/benchmarks/API/bench_series.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Benchmarks of Series methods."""
+
+import pytest
+from config import cudf, cupy
+from utils import benchmark_with_object
+
+
+@pytest.mark.parametrize("N", [100, 1_000_000])
+def bench_construction(benchmark, N):
+    benchmark(cudf.Series, cupy.random.rand(N))
+
+
+@benchmark_with_object(cls="series", dtype="int")
+def bench_sort_values(benchmark, series):
+    benchmark(series.sort_values)
+
+
+@benchmark_with_object(cls="series", dtype="int")
+@pytest.mark.parametrize("n", [10])
+def bench_series_nsmallest(benchmark, series, n):
+    benchmark(series.nsmallest, n)
diff --git a/python/cudf/benchmarks/common/config.py b/python/cudf/benchmarks/common/config.py
new file mode 100644
index 00000000000..305a21d0a29
--- /dev/null
+++ b/python/cudf/benchmarks/common/config.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Module used for global configuration of benchmarks.
+
+This file contains global definitions that are important for configuring all
+benchmarks, such as fixture sizes. In addition, this file supports the
+following features:
+    - Setting the CUDF_BENCHMARKS_USE_PANDAS environment variable will run
+      all benchmarks with pandas instead of cudf (and numpy instead of cupy).
+      This feature enables easy comparisons of benchmarks between cudf and
+      pandas. All common modules (cudf, cupy) should be imported from here
+      by benchmark modules so that this substitution takes effect.
+    - Setting CUDF_BENCHMARKS_DEBUG_ONLY will set the global configuration
+      variables to minimal values, avoiding large benchmark runs and simply
+      ensuring that benchmarks are functional.
+
+This file is also where standard pytest hooks should be overridden.
+While such definitions typically belong in conftest.py, the environment
+variables above can affect test collection and other behaviors, so we define
+the hooks here and import them in conftest.py to ensure that they are applied.
+"""
+import os
+import sys
+
+# Environment variable-based configuration of benchmarking pandas or cudf.
+collect_ignore = []
+if "CUDF_BENCHMARKS_USE_PANDAS" in os.environ:
+    import numpy as cupy
+    import pandas as cudf
+
+    # cudf internals offer no pandas compatibility guarantees, and we also
+    # never need to compare those benchmarks to pandas.
+    collect_ignore.append("internal/")
+
+    # Also filter out benchmarks of APIs that are not compatible with pandas.
+    def is_pandas_compatible(item):
+        return all(m.name != "pandas_incompatible" for m in item.own_markers)

+    def pytest_collection_modifyitems(session, config, items):
+        items[:] = list(filter(is_pandas_compatible, items))
+
+else:
+    import cupy  # noqa: W0611, F401
+
+    import cudf  # noqa: W0611, F401
+
+    def pytest_collection_modifyitems(session, config, items):
+        pass
+
+
+def pytest_sessionstart(session):
+    """Add the common files to the path for all tests to import."""
+    sys.path.insert(0, os.path.join(os.getcwd(), "common"))
+
+
+def pytest_sessionfinish(session, exitstatus):
+    """Clean up sys.path after the session exits."""
+    if "common" in sys.path[0]:
+        del sys.path[0]
+
+
+# Constants used to define benchmarking standards.
+if "CUDF_BENCHMARKS_DEBUG_ONLY" in os.environ:
+    NUM_ROWS = [10, 20]
+    NUM_COLS = [1, 6]
+else:
+    NUM_ROWS = [100, 10_000, 1_000_000]
+    NUM_COLS = [1, 6]
diff --git a/python/cudf/benchmarks/common/utils.py b/python/cudf/benchmarks/common/utils.py
new file mode 100644
index 00000000000..363316f0930
--- /dev/null
+++ b/python/cudf/benchmarks/common/utils.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Common utilities for fixture creation and benchmarking."""
+
+import inspect
+import re
+import textwrap
+from collections.abc import MutableSet
+from itertools import groupby
+from numbers import Real
+
+import pytest_cases
+from config import NUM_COLS, NUM_ROWS, cudf, cupy
+
+
+def make_gather_map(len_gather_map: Real, len_column: Real, how: str):
+    """Create a gather map based on "how" you'd like to gather from input.
+
+    - sequence: gather the first `len_gather_map` rows; the first thread
+      collects the first element
+    - reverse: gather the last `len_gather_map` rows; the first thread
+      collects the last element
+    - random: create a pseudorandom gather map
+
+    `len_gather_map` and `len_column` are rounded to integers.
+    """
+    len_gather_map = round(len_gather_map)
+    len_column = round(len_column)
+
+    rstate = cupy.random.RandomState(seed=0)
+    if how == "sequence":
+        return cudf.Series(cupy.arange(0, len_gather_map))
+    elif how == "reverse":
+        return cudf.Series(
+            cupy.arange(len_column - 1, len_column - len_gather_map - 1, -1)
+        )
+    elif how == "random":
+        return cudf.Series(rstate.randint(0, len_column, len_gather_map))
+
+
+def make_boolean_mask_column(size):
+    rstate = cupy.random.RandomState(seed=0)
+    return cudf.core.column.as_column(rstate.randint(0, 2, size).astype(bool))
+
+
+def benchmark_with_object(
+    cls, *, dtype="int", nulls=None, cols=None, rows=None
+):
+    """Pass "standard" cudf fixtures to functions without renaming parameters.
+
+    The fixture generation logic in conftest.py provides a plethora of useful
+    fixtures to allow developers to easily select an appropriate cross-section
+    of the space of objects to apply a particular benchmark to. However, these
+    fixtures are cumbersome to use directly because constructing them in a
+    principled fashion results in long names and very specific naming schemes.
+    This decorator abstracts that naming logic away from the developer,
+    allowing them to instead focus on defining the fixture semantically by
+    describing its properties.
+
+    Parameters
+    ----------
+    cls : Union[str, Type]
+        The class of object to test. May either be specified as the type
+        itself, or using the name (as a string). If a string, the case is
+        irrelevant as the string will be converted to all lowercase.
+    dtype : Union[str, Iterable[str]], default 'int'
+        The dtype or set of dtypes to use.
+    nulls : Optional[bool], default None
+        Whether to test nullable or non-nullable data. If None, both nullable
+        and non-nullable data are included.
+    cols : Optional[int], default None
+        The number of columns. Only valid if cls == 'dataframe'. If None, use
+        all possible numbers of columns. Specifying multiple values is
+        unsupported.
+    rows : Optional[int], default None
+        The number of rows. If None, use all possible numbers of rows.
+        Specifying multiple values is unsupported.
+
+    Raises
+    ------
+    AssertionError
+        If any of the parameters do not correspond to extant fixtures.
+
+    Examples
+    --------
+    # Note: As an internal function, this example is not meant for doctesting.
+
+    @benchmark_with_object("dataframe", dtype="int", nulls=False)
+    def bench_columns(benchmark, dataframe):
+        benchmark(lambda: dataframe.columns)
+    """
+    if inspect.isclass(cls):
+        cls = cls.__name__
+    cls = cls.lower()
+
+    supported_classes = (
+        "column",
+        "series",
+        "index",
+        "dataframe",
+        "indexedframe",
+        "frame_or_index",
+    )
+    assert cls in supported_classes, (
+        f"cls {cls} is invalid, choose from " f"{', '.join(supported_classes)}"
+    )
+
+    if not isinstance(dtype, list):
+        dtype = [dtype]
+    assert all(dt in column_generators for dt in dtype), (
+        f"The only supported dtypes are " f"{', '.join(column_generators)}"
+    )
+
+    dtype_str = "_dtype_" + "_or_".join(dtype)
+
+    null_str = ""
+    if nulls is not None:
+        null_str = f"_nulls_{nulls}".lower()
+
+    col_str = ""
+    if cols is not None:
+        assert cols in NUM_COLS, (
+            f"You have requested a DataFrame with {cols} columns but fixtures "
+            f"only exist for the values {', '.join(map(str, NUM_COLS))}"
+        )
+        col_str = f"_cols_{cols}"
+
+    row_str = ""
+    if rows is not None:
+        assert rows in NUM_ROWS, (
+            f"You have requested a {cls} with {rows} rows but fixtures "
+            f"only exist for the values {', '.join(map(str, NUM_ROWS))}"
+        )
+        row_str = f"_rows_{rows}"
+
+    fixture_name = f"{cls}{dtype_str}{null_str}{col_str}{row_str}"
+
+    def deco(bm):
+        # pytest's test collection process relies on parsing the globals dict
+        # to find test functions and identify their parameters for the purpose
+        # of fixtures and parameters. Therefore, the primary purpose of this
+        # decorator is to define a new benchmark function with a signature
+        # identical to that of the decorated benchmark, except with the user's
+        # parameter name replaced by the true fixture name based on the
+        # arguments to benchmark_with_object.
+        parameters = inspect.signature(bm).parameters
+
+        # Note: This logic assumes that any benchmark using this fixture has
+        # at least two parameters, since it must use both the pytest-benchmark
+        # `benchmark` fixture and the cudf object.
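+        # For example, given a benchmark
+        #     def bench_sort_values(benchmark, dataframe): ...
+        # decorated with benchmark_with_object(cls="dataframe", dtype="int"),
+        # the generated wrapper below is equivalent to
+        #     def wrapped_bm(benchmark, dataframe_dtype_int):
+        #         return bench_sort_values(
+        #             benchmark, dataframe=dataframe_dtype_int
+        #         )
+        # so that pytest requests the real fixture by name.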
+        params_str = ", ".join(f"{p}" for p in parameters if p != cls)
+        arg_str = ", ".join(f"{p}={p}" for p in parameters if p != cls)
+
+        if params_str:
+            params_str += ", "
+        if arg_str:
+            arg_str += ", "
+
+        params_str += f"{fixture_name}"
+        arg_str += f"{cls}={fixture_name}"
+
+        src = textwrap.dedent(
+            f"""
+            import makefun
+            @makefun.wraps(
+                bm,
+                remove_args=("{cls}",),
+                prepend_args=("{fixture_name}",)
+            )
+            def wrapped_bm({params_str}):
+                return bm({arg_str})
+            """
+        )
+        globals_ = {"bm": bm}
+        exec(src, globals_)
+        wrapped_bm = globals_["wrapped_bm"]
+        # In case marks were applied to the original benchmark, copy them over.
+        if marks := getattr(bm, "pytestmark", None):
+            wrapped_bm.pytestmark = marks
+        wrapped_bm.place_as = bm
+        return wrapped_bm
+
+    return deco
+
+
+class OrderedSet(MutableSet):
+    """A minimal OrderedSet implementation built on a dict.
+
+    This implementation exploits the fact that dicts are ordered as of Python
+    3.7. It is not intended to be performant, so only the minimal set of
+    methods are implemented. We need this class to ensure that fixture names
+    are constructed deterministically, otherwise pytest-xdist will complain if
+    different threads have seemingly different tests.
+    """
+
+    def __init__(self, args=None):
+        args = args or []
+        self._data = {value: None for value in args}
+
+    def __contains__(self, key):
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data)
+
+    def __len__(self):
+        return len(self._data)
+
+    def __repr__(self):
+        # Helpful for debugging.
+        data = ", ".join(str(i) for i in self._data)
+        return f"{self.__class__.__name__}({data})"
+
+    def add(self, value):
+        self._data[value] = None
+
+    def discard(self, value):
+        self._data.pop(value, None)
+
+
+def make_fixture(name, func, globals_, fixtures):
+    """Create a named fixture in `globals_` and save its name in `fixtures`.
+
+    https://github.com/pytest-dev/pytest/issues/2424#issuecomment-333387206
+    explains why this hack is necessary. Essentially, dynamically generated
+    fixtures must exist in globals() to be found by pytest.
+    """
+    globals_[name] = pytest_cases.fixture(name=name)(func)
+    fixtures.add(name)
+
+
+def collapse_fixtures(fixtures, pattern, repl, globals_, idfunc=None):
+    """Create unions of fixtures based on specific name mappings.
+
+    `fixtures` are grouped into unions according to the regex replacement
+    `re.sub(pattern, repl)`, and the resulting unions are added back into
+    `fixtures` (and into `globals_`).
+    """
+
+    def collapser(n):
+        return re.sub(pattern, repl, n)
+
+    # Note: sorted creates a new list, not a view, so it's OK to modify the
+    # list of fixtures while iterating over the sorted result.
+    for name, group in groupby(sorted(fixtures, key=collapser), key=collapser):
+        group = list(group)
+        if len(group) > 1 and name not in fixtures:
+            pytest_cases.fixture_union(name=name, fixtures=group, ids=idfunc)
+            # Need to assign back to the parent scope's globals.
+            globals_[name] = globals()[name]
+            fixtures.add(name)
+
+
+# A dictionary of callables that create a column of a specified length.
+random_state = cupy.random.RandomState(42)
+column_generators = {
+    "int": (lambda nr: random_state.randint(low=0, high=100, size=nr)),
+    "float": (lambda nr: random_state.rand(nr)),
+}
diff --git a/python/cudf/benchmarks/conftest.py b/python/cudf/benchmarks/conftest.py
new file mode 100644
index 00000000000..4f2bb96061f
--- /dev/null
+++ b/python/cudf/benchmarks/conftest.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Defines pytest fixtures for all benchmarks.
+
+Most fixtures defined in this file represent one of the primary classes in the
+cuDF ecosystem such as DataFrame, Series, or Index. These fixtures may in turn
+be broken up into two categories: base fixtures and fixture unions. Each base
+fixture represents a specific type of object as well as certain of its
+properties crucial for benchmarking. Specifically, fixtures must account for
+the following different parameters:
+    - Class of object (DataFrame, Series, Index)
+    - Dtype
+    - Nullability
+    - Size (rows for all, rows/columns for DataFrame)
+
+One such fixture is a series of nullable integer data. Given that we generally
+want data across different sizes, we parametrize all fixtures across different
+numbers of rows rather than generating separate fixtures for each possible
+number of rows. The number of columns is only relevant for DataFrame.
+
+While this core set of fixtures means that any benchmark can be run for any
+combination of these parameters, it also means that we would effectively have
+to parametrize our benchmarks with many fixtures. Not only is parametrizing
+tests with fixtures in this manner unsupported by pytest, it is also an
+inelegant solution leading to cumbersome parameter lists that must be
+maintained across all tests. Instead we make use of the
+`pytest_cases <https://smarie.github.io/python-pytest-cases/>`_ pytest plugin,
+which supports the creation of fixture unions: fixtures that result from
+combining other fixtures together. The result is a set of well-defined fixtures
+that allow us to write benchmarks that naturally express the set of objects for
+which they are valid, e.g. `def bench_sort_values(frame_or_index)`.
+
+The generated fixtures are named according to the following convention:
+`{classname}_dtype_{dtype}[_nulls_{true|false}][_cols_{num_cols}][_rows_{num_rows}]`
+where classname is one of the following: index, series, dataframe,
+indexedframe, frame, frame_or_index. Note that in the case of indexes, to match
+Series/DataFrame we simply set `classname=index` and rely on the
+`dtype_{dtype}` component to delineate which index class is actually in use.
+
+In addition to the above fixtures, we also provide the following more
+specialized fixtures:
+    - rangeindex: Since RangeIndex always holds int64 data we cannot conflate
+      it with index_dtype_int64 (a true Int64Index), and it cannot hold nulls.
+      As a result, it is provided as a separate fixture.
+"""
+
+import os
+import string
+import sys
+
+import pytest_cases
+
+# TODO: Rather than doing this path hacking (including the sessionstart and
+# sessionfinish hooks), we could just make the benchmarks a (sub)package to
+# enable relative imports. A minor change to consider when these are ported
+# into the main repo.
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "common"))
+
+from config import cudf  # noqa: W0611, E402, F401
+from utils import (  # noqa: E402
+    OrderedSet,
+    collapse_fixtures,
+    column_generators,
+    make_fixture,
+)
+
+# Turn off isort until we upgrade to 5.8.0
+# https://github.com/pycqa/isort/issues/1594
+# isort: off
+from config import (  # noqa: W0611, E402, F401
+    NUM_COLS,
+    NUM_ROWS,
+    collect_ignore,
+    pytest_collection_modifyitems,
+    pytest_sessionfinish,
+    pytest_sessionstart,
+)
+
+# isort: on
+
+
+@pytest_cases.fixture(params=[0, 1], ids=["AxisIndex", "AxisColumn"])
+def axis(request):
+    return request.param
+
+
+# First generate all the base fixtures.
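+# As a concrete example of the naming convention described above, the base
+# fixture dataframe_dtype_int_nulls_false_cols_6_rows_10000 is a 6-column,
+# 10_000-row integer DataFrame without nulls, while the union fixture
+# dataframe_dtype_int covers all nullability, column, and row combinations
+# for int data.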
+fixtures = OrderedSet()
+for dtype, column_generator in column_generators.items():
+
+    def make_dataframe(nr, nc, column_generator=column_generator):
+        assert nc <= len(
+            string.ascii_lowercase
+        ), "make_dataframe only supports a maximum of 26 columns"
+        return cudf.DataFrame(
+            {
+                f"{string.ascii_lowercase[i]}": column_generator(nr)
+                for i in range(nc)
+            }
+        )
+
+    for nr in NUM_ROWS:
+        # TODO: pytest_cases.fixture doesn't appear to support lambdas where
+        # pytest does. https://github.com/smarie/python-pytest-cases/issues/278
+        # Once that is fixed we could use lambdas here.
+        # TODO: pytest_cases has a bug where the first argument being a
+        # defaulted kwarg e.g. (nr=nr, nc=nc) raises errors.
+        # https://github.com/smarie/python-pytest-cases/issues/278
+        # Once that is fixed we could remove all the extraneous `request`
+        # fixtures in these fixtures.
+        def series_nulls_false(
+            request, nr=nr, column_generator=column_generator
+        ):
+            return cudf.Series(column_generator(nr))
+
+        make_fixture(
+            f"series_dtype_{dtype}_nulls_false_rows_{nr}",
+            series_nulls_false,
+            globals(),
+            fixtures,
+        )
+
+        def series_nulls_true(
+            request, nr=nr, column_generator=column_generator
+        ):
+            s = cudf.Series(column_generator(nr))
+            s.iloc[::2] = None
+            return s
+
+        make_fixture(
+            f"series_dtype_{dtype}_nulls_true_rows_{nr}",
+            series_nulls_true,
+            globals(),
+            fixtures,
+        )
+
+        # For now, not bothering to include a nullable index fixture.
+        def index_nulls_false(
+            request, nr=nr, column_generator=column_generator
+        ):
+            return cudf.Index(column_generator(nr))
+
+        make_fixture(
+            f"index_dtype_{dtype}_nulls_false_rows_{nr}",
+            index_nulls_false,
+            globals(),
+            fixtures,
+        )
+
+        for nc in NUM_COLS:
+
+            def dataframe_nulls_false(
+                request, nr=nr, nc=nc, make_dataframe=make_dataframe
+            ):
+                return make_dataframe(nr, nc)
+
+            make_fixture(
+                f"dataframe_dtype_{dtype}_nulls_false_cols_{nc}_rows_{nr}",
+                dataframe_nulls_false,
+                globals(),
+                fixtures,
+            )
+
+            def dataframe_nulls_true(
+                request, nr=nr, nc=nc, make_dataframe=make_dataframe
+            ):
+                df = make_dataframe(nr, nc)
+                df.iloc[::2, :] = None
+                return df
+
+            make_fixture(
+                f"dataframe_dtype_{dtype}_nulls_true_cols_{nc}_rows_{nr}",
+                dataframe_nulls_true,
+                globals(),
+                fixtures,
+            )
+
+
+# We define some custom naming functions for use in the creation of fixture
+# unions to create more readable test function names that don't contain the
+# entire union, which quickly becomes intractably long.
+def unique_union_id(val):
+    return val.alternative_name
+
+
+def default_union_id(val):
+    return f"alt{val.get_alternative_idx()}"
+
+
+# Label the first level differently from others since there's no redundancy.
+idfunc = unique_union_id
+num_new_fixtures = len(fixtures)
+
+# Keep trying to merge existing fixtures until no new fixtures are added.
+while num_new_fixtures > 0:
+    num_fixtures = len(fixtures)
+
+    # Note: If we start also introducing unions across dtypes, most likely
+    # those will take the form `*int_and_float*` or similar, since we won't
+    # want to union _all_ dtypes. In that case, the regexes will need to use
+    # suitable lookaheads etc. to avoid infinite loops here.
+    for pat, repl in [
+        ("_nulls_(true|false)", ""),
+        ("series|dataframe", "indexedframe"),
+        ("indexedframe|index", "frame_or_index"),
+        (r"_rows_\d+", ""),
+        (r"_cols_\d+", ""),
+    ]:
+        collapse_fixtures(fixtures, pat, repl, globals(), idfunc)
+
+    num_new_fixtures = len(fixtures) - num_fixtures
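+    # (For example, series_dtype_int_nulls_false and
+    # series_dtype_int_nulls_true collapse into series_dtype_int, which a
+    # later pass merges with dataframe_dtype_int into indexedframe_dtype_int.)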
+    # All subsequent levels get the same (collapsed) labels.
+    idfunc = default_union_id
+
+
+for dtype, column_generator in column_generators.items():
+    # We have to manually add this union because we aren't including nullable
+    # indexes, but we want to be able to run some benchmarks on Series and
+    # DataFrames that may or may not be nullable as well as on Index objects.
+    pytest_cases.fixture_union(
+        name=f"frame_or_index_dtype_{dtype}",
+        fixtures=(
+            f"indexedframe_dtype_{dtype}",
+            f"index_dtype_{dtype}_nulls_false",
+        ),
+        ids=["", f"index_dtype_{dtype}_nulls_false"],
+    )
+
+
+# TODO: Decide where to incorporate RangeIndex and MultiIndex fixtures.
+@pytest_cases.fixture(params=NUM_ROWS)
+def rangeindex(request):
+    return cudf.RangeIndex(request.param)
diff --git a/python/cudf/benchmarks/internal/bench_column.py b/python/cudf/benchmarks/internal/bench_column.py
new file mode 100644
index 00000000000..d4969b39f7f
--- /dev/null
+++ b/python/cudf/benchmarks/internal/bench_column.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Benchmarks of Column methods."""
+
+import pytest
+import pytest_cases
+from utils import (
+    benchmark_with_object,
+    make_boolean_mask_column,
+    make_gather_map,
+)
+
+
+@benchmark_with_object(cls="column", dtype="float")
+def bench_apply_boolean_mask(benchmark, column):
+    mask = make_boolean_mask_column(column.size)
+    benchmark(column.apply_boolean_mask, mask)
+
+
+@benchmark_with_object(cls="column", dtype="float")
+@pytest.mark.parametrize("dropnan", [True, False])
+def bench_dropna(benchmark, column, dropnan):
+    benchmark(column.dropna, drop_nan=dropnan)
+
+
+@benchmark_with_object(cls="column", dtype="float")
+def bench_unique_single_column(benchmark, column):
+    benchmark(column.unique)
+
+
+@benchmark_with_object(cls="column", dtype="float")
+@pytest.mark.parametrize("nullify", [True, False])
+@pytest.mark.parametrize("gather_how", ["sequence", "reverse", "random"])
+def bench_take(benchmark, column, gather_how, nullify):
+    gather_map = make_gather_map(
+        column.size * 0.4, column.size, gather_how
+    )._column
+    benchmark(column.take, gather_map, nullify=nullify)
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_stride_1_slice_scalar(column):
+    return column, slice(None, None, 1), 42
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_stride_2_slice_scalar(column):
+    return column, slice(None, None, 2), 42
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_boolean_column_scalar(column):
+    return column, [True, False] * (len(column) // 2), 42
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_int_column_scalar(column):
+    return column, list(range(len(column))), 42
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_stride_1_slice_align_to_key_size(column):
+    key = slice(None, None, 1)
+    start, stop, stride = key.indices(len(column))
+    materialized_key_size = len(column.slice(start, stop, stride))
+    return column, key, [42] * materialized_key_size
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_stride_2_slice_align_to_key_size(column):
+    key = slice(None, None, 2)
+    start, stop, stride = key.indices(len(column))
+    materialized_key_size = len(column.slice(start, stop, stride))
+    return column, key, [42] * materialized_key_size
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_boolean_column_align_to_col_size(column):
+    size = len(column)
+    return column, [True, False] * (size // 2), [42] * size
+
+
+@benchmark_with_object(cls="column", dtype="int", nulls=False)
+def setitem_case_int_column_align_to_col_size(column):
+    size = len(column)
+    return column, list(range(size)), [42] * size
+
+
+# Benchmark Grid
+# key:   slice == 1 (fill or copy_range shortcut),
+#        slice != 1 (scatter),
+#        column(bool) (boolean_mask_scatter),
+#        column(int) (scatter)
+# value: scalar,
+#        column (len(val) == len(key)),
+#        column (len(val) != len(key) and len == num_true)
+
+
+@pytest_cases.parametrize_with_cases(
+    "column,key,value", cases=".", prefix="setitem"
+)
+def bench_setitem(benchmark, column, key, value):
+    benchmark(column.__setitem__, key, value)
diff --git a/python/cudf/benchmarks/internal/bench_dataframe_internal.py b/python/cudf/benchmarks/internal/bench_dataframe_internal.py
new file mode 100644
index 00000000000..5204d4fb65d
--- /dev/null
+++ b/python/cudf/benchmarks/internal/bench_dataframe_internal.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Benchmarks of internal DataFrame methods."""
+
+from utils import benchmark_with_object, make_boolean_mask_column
+
+
+@benchmark_with_object(cls="dataframe", dtype="int")
+def bench_apply_boolean_mask(benchmark, dataframe):
+    mask = make_boolean_mask_column(len(dataframe))
+    benchmark(dataframe._apply_boolean_mask, mask)
diff --git a/python/cudf/benchmarks/internal/bench_rangeindex_internal.py b/python/cudf/benchmarks/internal/bench_rangeindex_internal.py
new file mode 100644
index 00000000000..c4cf1de8ab9
--- /dev/null
+++ b/python/cudf/benchmarks/internal/bench_rangeindex_internal.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Benchmarks of internal RangeIndex methods."""
+
+
+def bench_column(benchmark, rangeindex):
+    benchmark(lambda: rangeindex._column)
+
+
+def bench_columns(benchmark, rangeindex):
+    benchmark(lambda: rangeindex._columns)
diff --git a/python/cudf/benchmarks/internal/conftest.py b/python/cudf/benchmarks/internal/conftest.py
new file mode 100644
index 00000000000..7351f1d1427
--- /dev/null
+++ b/python/cudf/benchmarks/internal/conftest.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+
+"""Defines pytest fixtures for internal benchmarks."""
+
+from config import NUM_ROWS, cudf
+from utils import (
+    OrderedSet,
+    collapse_fixtures,
+    column_generators,
+    make_fixture,
+)
+
+fixtures = OrderedSet()
+for dtype, column_generator in column_generators.items():
+    for nr in NUM_ROWS:
+
+        # Bind column_generator as a default argument so that each fixture
+        # keeps the generator for its own dtype rather than the last one
+        # bound by the loop.
+        def column_nulls_false(
+            request, nr=nr, column_generator=column_generator
+        ):
+            return cudf.core.column.as_column(column_generator(nr))
+
+        make_fixture(
+            f"column_dtype_{dtype}_nulls_false_rows_{nr}",
+            column_nulls_false,
+            globals(),
+            fixtures,
+        )
+
+        def column_nulls_true(
+            request, nr=nr, column_generator=column_generator
+        ):
+            c = cudf.core.column.as_column(column_generator(nr))
+            c[::2] = None
+            return c
+
+        make_fixture(
+            f"column_dtype_{dtype}_nulls_true_rows_{nr}",
+            column_nulls_true,
+            globals(),
+            fixtures,
+        )
+
+num_new_fixtures = len(fixtures)
+
+# Keep trying to merge existing fixtures until no new fixtures are added.
+while num_new_fixtures > 0:
+    num_fixtures = len(fixtures)
+
+    # Note: If we start also introducing unions across dtypes, most likely
+    # those will take the form `*int_and_float*` or similar, since we won't
+    # want to union _all_ dtypes. In that case, the regexes will need to use
+    # suitable lookaheads etc. to avoid infinite loops here.
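+    # (For example, column_dtype_int_nulls_false_rows_100 and
+    # column_dtype_int_nulls_true_rows_100 collapse into
+    # column_dtype_int_rows_100, and the _rows_* unions then collapse into
+    # column_dtype_int.)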
+ for pat, repl in [ + ("_nulls_(true|false)", ""), + (r"_rows_\d+", ""), + ]: + + collapse_fixtures(fixtures, pat, repl, globals()) + + num_new_fixtures = len(fixtures) - num_fixtures diff --git a/python/cudf/benchmarks/pytest.ini b/python/cudf/benchmarks/pytest.ini new file mode 100644 index 00000000000..db24415ef9e --- /dev/null +++ b/python/cudf/benchmarks/pytest.ini @@ -0,0 +1,8 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +[pytest] +python_files = bench_*.py +python_classes = Bench +python_functions = bench_* +markers = + pandas_incompatible: mark a benchmark that cannot be run with pandas