diff --git a/cpp/bench/prims/distance_common.cuh b/cpp/bench/prims/distance_common.cuh index 0889ab48c9..dcc8292f82 100644 --- a/cpp/bench/prims/distance_common.cuh +++ b/cpp/bench/prims/distance_common.cuh @@ -49,6 +49,13 @@ struct Distance : public Fixture { workspace.resize(worksize, stream); } + void deallocateBuffers(const ::benchmark::State& state) override + { + x.release(); + y.release(); + out.release(); + workspace.release(); + } void runBenchmark(::benchmark::State& state) override { loopOnState(state, [this]() { diff --git a/cpp/bench/prims/gram_matrix.cu b/cpp/bench/prims/gram_matrix.cu index d4a83a8e31..be1e180bc8 100644 --- a/cpp/bench/prims/gram_matrix.cu +++ b/cpp/bench/prims/gram_matrix.cu @@ -68,7 +68,12 @@ struct GramMatrix : public Fixture { r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream); r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream); } - + void deallocateBuffers(const ::benchmark::State& state) override + { + A.release(); + B.release(); + C.release(); + } void runBenchmark(::benchmark::State& state) override { if (!this->kernel) { state.SkipWithError("Kernel matrix is not initialized"); } diff --git a/cpp/bench/prims/make_blobs.cu b/cpp/bench/prims/make_blobs.cu index 4b4e4932dd..963679d77c 100644 --- a/cpp/bench/prims/make_blobs.cu +++ b/cpp/bench/prims/make_blobs.cu @@ -40,6 +40,11 @@ struct MakeBlobs : public Fixture { labels.resize(params.rows, stream); } + void deallocateBuffers(const ::benchmark::State& state) override + { + data.release(); + labels.release(); + } void runBenchmark(::benchmark::State& state) override { loopOnState(state, [this]() { diff --git a/cpp/bench/sg/benchmark.cuh b/cpp/bench/sg/benchmark.cuh index e02ad12112..ae2c408755 100644 --- a/cpp/bench/sg/benchmark.cuh +++ b/cpp/bench/sg/benchmark.cuh @@ -38,7 +38,7 @@ class Fixture : public MLCommon::Bench::Fixture { void SetUp(const ::benchmark::State& state) override { auto stream_pool = std::make_shared(NumStreams); - handle.reset(new raft::handle_t{stream, stream_pool}); + handle.reset(new raft::handle_t{rmm::cuda_stream_per_thread, stream_pool}); MLCommon::Bench::Fixture::SetUp(state); }