Using pinned host memory for Random Forest and DBSCAN (rapidsai#4215)
Benchmarks show that RF performs consistently better with pinned host memory, while DBSCAN is sometimes better and sometimes worse (within the margin of error), so pinned host memory is now used by default for both of these algorithms.

KMeans and LARS are ignored for now: both show slightly better performance with pinned host memory, but only as the number of columns increases. Since this would need more analysis, and a decision on whether a heuristic is needed for selecting the memory type, it is deferred to 21.12.
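For context, pinned (page-locked) host memory lets the CUDA driver DMA directly from the host buffer instead of staging the copy through an internal pageable-memory bounce buffer, which is why the host-device transfers in these algorithms get cheaper. Below is a minimal, illustrative sketch of that effect — not part of this change; the buffer size is an arbitrary assumption:

```
#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
  const std::size_t n = 1 << 26;  // assumed buffer size (~256 MB of floats)
  float* d_buf;
  cudaMalloc(reinterpret_cast<void**>(&d_buf), n * sizeof(float));

  // Pageable host buffer: the driver must stage the copy internally.
  std::vector<float> pageable(n);
  // Pinned host buffer: DMA can read it directly.
  float* pinned;
  cudaMallocHost(reinterpret_cast<void**>(&pinned), n * sizeof(float));

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  float ms;

  cudaEventRecord(start);
  cudaMemcpy(d_buf, pageable.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  cudaEventElapsedTime(&ms, start, stop);
  std::printf("pageable H2D: %.2f ms\n", ms);

  cudaEventRecord(start);
  cudaMemcpy(d_buf, pinned, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  cudaEventElapsedTime(&ms, start, stop);
  std::printf("pinned   H2D: %.2f ms\n", ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaFreeHost(pinned);
  cudaFree(d_buf);
  return 0;
}
```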

Here are the raw numbers, keyed by (n_rows, n_cols):
1. LARS
Normal memory:
```{'lars': {(100000, 10): 0.12429666519165039, (100000, 100): 0.015396833419799805, (100000, 250): 0.015408039093017578, (250000, 10): 0.00986933708190918, (250000, 100): 0.023822546005249023, (250000, 250): 0.03715157508850098, (500000, 10): 0.013423442840576172, (500000, 100): 0.044762372970581055, (500000, 250): 0.07782578468322754}}```
Pinned memory:
```{'lars': {(100000, 10): 0.12958097457885742, (100000, 100): 0.01501011848449707, (100000, 250): 0.016597509384155273, (250000, 10): 0.01801013946533203, (250000, 100): 0.022644996643066406, (250000, 250): 0.037090301513671875, (500000, 10): 0.020437955856323242, (500000, 100): 0.044635772705078125, (500000, 250): 0.07696056365966797}}```
2. RFR
Normal memory:
```{'rfr': {(100000, 10): 1.1951744556427002, (100000, 100): 5.099738359451294, (100000, 250): 11.32804536819458, (250000, 10): 2.0097765922546387, (250000, 100): 9.109776496887207, (250000, 250): 21.058837890625, (500000, 10): 3.3387184143066406, (500000, 100): 15.802990436553955, (500000, 250): 36.80855870246887}}```
Pinned memory:
```{'rfr': {(100000, 10): 1.1727137565612793, (100000, 100): 4.804195880889893, (100000, 250): 11.621357917785645, (250000, 10): 1.8899295330047607, (250000, 100): 9.16961407661438, (250000, 250): 21.12194561958313, (500000, 10): 3.2937560081481934, (500000, 100): 15.66197681427002, (500000, 250): 36.6080117225647}}```
3. KMeans
Normal memory:
```{'kmeans': {(100000, 10): 0.11008882522583008, (100000, 100): 0.15475797653198242, (100000, 250): 0.15683507919311523, (250000, 10): 0.18775177001953125, (250000, 100): 0.25696277618408203, (250000, 250): 0.40389132499694824, (500000, 10): 0.4578282833099365, (500000, 100): 0.3917391300201416, (500000, 250): 0.6426849365234375}}```
Pinned memory:
```{'kmeans': {(100000, 10): 0.11982870101928711, (100000, 100): 0.16992664337158203, (100000, 250): 0.1021108627319336, (250000, 10): 0.16021251678466797, (250000, 100): 0.31025242805480957, (250000, 250): 0.298201322555542, (500000, 10): 0.21084189414978027, (500000, 100): 0.50473952293396, (500000, 250): 0.6191830635070801}}```
4. DBSCAN
Normal memory:
```{'dbscan': {(100000, 10): 0.4957292079925537, (100000, 100): 0.8680248260498047, (100000, 250): 1.585218906402588, (250000, 10): 4.52524995803833, (250000, 100): 7.175846099853516, (250000, 250): 12.135416269302368, (500000, 10): 26.427770853042603, (500000, 100): 37.57275915145874, (500000, 250): 57.98261737823486}}```
Pinned memory:
```{'dbscan': {(100000, 10): 0.49578166007995605, (100000, 100): 0.8678708076477051, (100000, 250): 1.5854766368865967, (250000, 10): 4.526952505111694, (250000, 100): 7.172863006591797, (250000, 250): 12.145166397094727, (500000, 10): 26.422622680664062, (500000, 100): 37.56665277481079, (500000, 250): 58.02563738822937}}```

Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Rory Mitchell (https://github.com/RAMitchell)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4215
divyegala authored Sep 23, 2021
1 parent 70f97c5 commit bccebf9
Showing 5 changed files with 73 additions and 5 deletions.
64 changes: 64 additions & 0 deletions cpp/include/cuml/common/pinned_host_vector.hpp
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <rmm/mr/host/pinned_memory_resource.hpp>

#include <cstddef>  // std::size_t
#include <memory>   // std::uninitialized_fill

namespace ML {

template <typename T>
class pinned_host_vector {
 public:
  pinned_host_vector() = default;

  explicit pinned_host_vector(std::size_t n)
    : data_{static_cast<T*>(pinned_mr.allocate(n * sizeof(T)))}, size_{n}
  {
    std::uninitialized_fill(data_, data_ + n, static_cast<T>(0));
  }
  ~pinned_host_vector()
  {
    if (data_ != nullptr) { pinned_mr.deallocate(data_, size_ * sizeof(T)); }
  }

  pinned_host_vector(pinned_host_vector const&) = delete;
  pinned_host_vector(pinned_host_vector&&)      = delete;
  pinned_host_vector& operator=(pinned_host_vector const&) = delete;
  pinned_host_vector& operator=(pinned_host_vector&&) = delete;

  void resize(std::size_t n)
  {
    // Release any previous allocation first so repeated resizes do not leak
    // pinned memory; existing contents are not preserved.
    if (data_ != nullptr) { pinned_mr.deallocate(data_, size_ * sizeof(T)); }
    size_ = n;
    data_ = static_cast<T*>(pinned_mr.allocate(n * sizeof(T)));
    std::uninitialized_fill(data_, data_ + n, static_cast<T>(0));
  }

  T* data() { return data_; }

  T* begin() { return data_; }

  T* end() { return data_ + size_; }

  std::size_t size() const { return size_; }

  T operator[](std::size_t idx) const { return *(data_ + idx); }
  T& operator[](std::size_t idx) { return *(data_ + idx); }

 private:
  // Default-initialize members so a default-constructed vector destructs safely.
  rmm::mr::pinned_memory_resource pinned_mr{};
  T* data_{nullptr};
  std::size_t size_{0};
};

} // namespace ML
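A minimal usage sketch of the new vector, mirroring the DBSCAN change below — hypothetical, not part of this diff; `d_labels`, `n`, and `stream` are assumed to exist:

```
// Stage a device array on the host through a pinned buffer so that
// raft::update_host can copy without a pageable bounce buffer.
ML::pinned_host_vector<int> host_labels(n);
raft::update_host(host_labels.data(), d_labels, n, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));  // contents valid after the sync
```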
9 changes: 5 additions & 4 deletions cpp/src/dbscan/adjgraph/naive.cuh
@@ -17,6 +17,7 @@
 #pragma once
 
 #include <raft/cudart_utils.h>
+#include <cuml/common/pinned_host_vector.hpp>
 #include <raft/cuda_utils.cuh>
 #include <vector>
 #include "../common.cuh"
@@ -35,14 +36,14 @@ void launcher(const raft::handle_t& handle,
 {
   Index_ k = 0;
   Index_ N = data.N;
-  std::vector<Index_> host_vd(batch_size + 1);
-  std::vector<char> host_adj(((batch_size * N) / 8) + 1);
-  std::vector<Index_> host_ex_scan(batch_size);
+  ML::pinned_host_vector<Index_> host_vd(batch_size + 1);
+  ML::pinned_host_vector<char> host_adj(((batch_size * N) / 8) + 1);
+  ML::pinned_host_vector<Index_> host_ex_scan(batch_size);
   raft::update_host((bool*)host_adj.data(), data.adj, batch_size * N, stream);
   raft::update_host(host_vd.data(), data.vd, batch_size + 1, stream);
   CUDA_CHECK(cudaStreamSynchronize(stream));
   size_t adjgraph_size = size_t(host_vd[batch_size]);
-  std::vector<Index_> host_adj_graph(adjgraph_size);
+  ML::pinned_host_vector<Index_> host_adj_graph(adjgraph_size);
   for (Index_ i = 0; i < batch_size; i++) {
     for (Index_ j = 0; j < N; j++) {
       /// TODO: change layout or remove; cf #3414
3 changes: 2 additions & 1 deletion cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -27,6 +27,7 @@
 #include <common/grid_sync.cuh>
 #include <common/nvtx.hpp>
 #include <cuml/common/logger.hpp>
+#include <cuml/common/pinned_host_vector.hpp>
 #include <cuml/tree/decisiontree.hpp>
 #include <raft/cuda_utils.cuh>
 #include "input.cuh"
@@ -190,7 +191,7 @@ struct Builder {
   const size_t alignValue = 512;
 
   rmm::device_uvector<char> d_buff;
-  std::vector<char> h_buff;
+  ML::pinned_host_vector<char> h_buff;
 
   Builder(const raft::handle_t& handle,
           IdxT treeid,
1 change: 1 addition & 0 deletions cpp/src/kmeans/common.cuh
@@ -61,6 +61,7 @@
 #include <fstream>
 #include <numeric>
 #include <random>
+#include <vector>
 
 namespace ML {

1 change: 1 addition & 0 deletions cpp/src/solver/lars_impl.cuh
@@ -19,6 +19,7 @@
 #include <iostream>
 #include <limits>
 #include <numeric>
+#include <vector>
 
 #include <raft/cudart_utils.h>
 #include <raft/linalg/cublas_wrappers.h>
