diff --git a/dask_cuda/benchmarks/local_cudf_shuffle.py b/dask_cuda/benchmarks/local_cudf_shuffle.py index cf5fec8a4..f329aa92b 100644 --- a/dask_cuda/benchmarks/local_cudf_shuffle.py +++ b/dask_cuda/benchmarks/local_cudf_shuffle.py @@ -22,28 +22,16 @@ from dask_cuda.utils import all_to_all -def shuffle_dask(args, df, write_profile): - if write_profile is None: - ctx = contextlib.nullcontext() - else: - ctx = performance_report(filename=args.profile) - - # Execute the operations to benchmark - with ctx: - t1 = clock() - wait(shuffle(df, index="data", shuffle="tasks").persist()) - return clock() - t1 +def shuffle_dask(df): + wait(shuffle(df, index="data", shuffle="tasks").persist()) -def shuffle_explicit_comms(args, df): - t1 = clock() +def shuffle_explicit_comms(df): wait( dask_cuda.explicit_comms.dataframe.shuffle.shuffle( df, column_names="data" ).persist() ) - took = clock() - t1 - return took def run(client, args, n_workers, write_profile=None): @@ -63,12 +51,20 @@ def run(client, args, n_workers, write_profile=None): wait(df) data_processed = len(df) * sum([t.itemsize for t in df.dtypes]) - if args.backend == "dask": - took = shuffle_dask(args, df, write_profile) + if write_profile is None: + ctx = contextlib.nullcontext() else: - took = shuffle_explicit_comms(args, df) + ctx = performance_report(filename=args.profile) + + with ctx: + t1 = clock() + if args.backend == "dask": + shuffle_dask(df) + else: + shuffle_explicit_comms(df) + t2 = clock() - return (data_processed, took) + return (data_processed, t2 - t1) def main(args):