Reduce device synchronisation
RAMitchell committed May 5, 2020
1 parent 8de7f19 commit 29b2819
Showing 6 changed files with 249 additions and 116 deletions.
121 changes: 121 additions & 0 deletions src/tree/gpu_hist/driver.cuh
@@ -0,0 +1,121 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#ifndef DRIVER_CUH_
#define DRIVER_CUH_
#include <xgboost/span.h>
#include <queue>
#include "../param.h"
#include "evaluate_splits.cuh"

namespace xgboost {
namespace tree {
struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
ExpandEntry() = default;
  ExpandEntry(int nid, int depth, DeviceSplitCandidate split)
      : nid(nid), depth(depth), split(std::move(split)) {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (split.loss_chg < param.min_split_loss) {
return false;
}
if (param.max_depth > 0 && depth == param.max_depth) {
return false;
}
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
return false;
}
return true;
}

static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}

friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};

inline bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
return lhs.depth > rhs.depth; // favor small depth
}

inline bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
    return lhs.nid > rhs.nid;  // favor small nid (stands in for the removed timestamp)
} else {
return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg
}
}

// Drives execution of tree building on device
class Driver {
using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;

public:
explicit Driver(TrainParam::TreeGrowPolicy policy)
: policy(policy),
queue(policy == TrainParam::kDepthWise ? DepthWise : LossGuide) {}
template <typename EntryIterT>
  void Push(EntryIterT begin, EntryIterT end) {
for (auto it = begin; it != end; ++it) {
const ExpandEntry& e = *it;
if (e.split.loss_chg > kRtEps) {
queue.push(e);
}
}
}
void Push(const std::vector<ExpandEntry> &entries) {
this->Push(entries.begin(), entries.end());
}
// Return the set of nodes to be expanded
  // This set has no dependencies between entries, so they may be expanded in
  // parallel or asynchronously
std::vector<ExpandEntry> Pop() {
if (queue.empty()) return {};
// Return a single entry for loss guided mode
if (policy == TrainParam::kLossGuide) {
ExpandEntry e = queue.top();
queue.pop();
return {e};
}
// Return nodes on same level for depth wise
std::vector<ExpandEntry> result;
ExpandEntry e = queue.top();
int level = e.depth;
while (e.depth == level && !queue.empty()) {
queue.pop();
result.emplace_back(e);
if (!queue.empty()) {
e = queue.top();
}
}
return result;
}

private:
TrainParam::TreeGrowPolicy policy;
ExpandQueue queue;
};
} // namespace tree
} // namespace xgboost

#endif // DRIVER_CUH_
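
A minimal usage sketch of the new Driver (ExpandNode and root_entry are
hypothetical stand-ins for the ApplySplit / BuildHistLeftRight /
EvaluateLeftRightSplits sequence in updater_gpu_hist.cu below; validity
checks are omitted):

Driver driver(TrainParam::kDepthWise);
driver.Push({root_entry});
// Pop() returns every node on the current level; the entries are mutually
// independent, so each could be expanded asynchronously.
auto expand_set = driver.Pop();
while (!expand_set.empty()) {
  std::vector<ExpandEntry> new_candidates;
  for (const auto& candidate : expand_set) {
    // ExpandNode is hypothetical: split this node, return both children.
    std::vector<ExpandEntry> children = ExpandNode(candidate);
    new_candidates.insert(new_candidates.end(), children.begin(),
                          children.end());
  }
  // Push() silently drops entries whose loss_chg is <= kRtEps.
  driver.Push(new_candidates);
  expand_set = driver.Pop();
}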
7 changes: 4 additions & 3 deletions src/tree/gpu_hist/row_partitioner.cuh
@@ -61,6 +61,7 @@ class RowPartitioner
dh::caching_device_vector<int64_t>
left_counts_; // Useful to keep a bunch of zeroed memory for sort position
std::vector<cudaStream_t> streams_;
dh::PinnedMemory pinned_;

public:
RowPartitioner(int device_idx, size_t num_rows);
@@ -129,12 +130,12 @@ class RowPartitioner
d_position[idx] = new_position;
});
// Overlap device to host memory copy (left_count) with sort
int64_t left_count;
int64_t &left_count = pinned_.GetSpan<int64_t>(1)[0];
dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost, streams_[0]));

    SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
                        streams_[1]);

dh::safe_cuda(cudaStreamSynchronize(streams_[0]));
CHECK_LE(left_count, segment.Size());
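
Note on the new pinned_ member: cudaMemcpyAsync only overlaps with other
device work when the host destination is page-locked; the old pageable
stack variable (int64_t left_count) made the "async" copy effectively
synchronous. A standalone sketch of the pattern with the raw CUDA API
rather than the dh:: wrappers (stream names and the sort placeholder are
illustrative):

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  int64_t* d_left_count;
  cudaMalloc(reinterpret_cast<void**>(&d_left_count), sizeof(int64_t));
  cudaMemset(d_left_count, 0, sizeof(int64_t));
  cudaStream_t copy_stream, sort_stream;
  cudaStreamCreate(&copy_stream);
  cudaStreamCreate(&sort_stream);
  // Page-locked host destination: the copy can proceed while other
  // streams keep working.
  int64_t* h_left_count;
  cudaMallocHost(reinterpret_cast<void**>(&h_left_count), sizeof(int64_t));
  cudaMemcpyAsync(h_left_count, d_left_count, sizeof(int64_t),
                  cudaMemcpyDeviceToHost, copy_stream);
  // ... sort kernels would be launched on sort_stream here, overlapping
  // the copy, as SortPositionAndCopy does above ...
  cudaStreamSynchronize(copy_stream);  // wait only for the copy we need
  printf("left_count = %lld\n", static_cast<long long>(*h_left_count));
  cudaFreeHost(h_left_count);
  cudaFree(d_left_count);
  cudaStreamDestroy(copy_stream);
  cudaStreamDestroy(sort_stream);
  return 0;
}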
170 changes: 61 additions & 109 deletions src/tree/updater_gpu_hist.cu
@@ -30,6 +30,7 @@
#include "gpu_hist/row_partitioner.cuh"
#include "gpu_hist/histogram.cuh"
#include "gpu_hist/evaluate_splits.cuh"
#include "gpu_hist/driver.cuh"

namespace xgboost {
namespace tree {
@@ -57,58 +58,6 @@ struct GPUHistMakerTrainParam
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
#endif // !defined(GTEST_TEST)

struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
uint64_t timestamp;
ExpandEntry() = default;
ExpandEntry(int nid, int depth, DeviceSplitCandidate split,
uint64_t timestamp)
: nid(nid), depth(depth), split(std::move(split)), timestamp(timestamp) {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (split.loss_chg < param.min_split_loss) { return false; }
if (param.max_depth > 0 && depth == param.max_depth) {return false; }
if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; }
return true;
}

static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}

friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};

inline static bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.depth > rhs.depth; // favor small depth
}
}
inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg
}
}

/**
* \struct DeviceHistogram
*
@@ -243,18 +192,15 @@ struct GPUHistMakerDevice

GradientSumT histogram_rounding;

dh::PinnedMemory pinned;

std::vector<cudaStream_t> streams{};

common::Monitor monitor;
std::vector<ValueConstraint> node_value_constraints;
common::ColumnSampler column_sampler;
FeatureInteractionConstraintDevice interaction_constraints;

using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;

std::unique_ptr<GradientBasedSampler> sampler;

GPUHistMakerDevice(int _device_id,
@@ -314,11 +260,6 @@ struct GPUHistMakerDevice {
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand.reset(new ExpandQueue(LossGuide));
} else {
qexpand.reset(new ExpandQueue(DepthWise));
}
this->column_sampler.Init(num_columns, param.colsample_bynode,
param.colsample_bylevel, param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
@@ -370,9 +311,9 @@ struct GPUHistMakerDevice {
return result.front();
}

std::vector<DeviceSplitCandidate> EvaluateLeftRightSplits(
void EvaluateLeftRightSplits(
ExpandEntry candidate, int left_nidx, int right_nidx,
const RegTree& tree) {
      const RegTree& tree, common::Span<ExpandEntry> pinned_candidates_out) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
GPUTrainingParam gpu_param(param);
auto left_sampled_features =
@@ -412,12 +353,20 @@ struct GPUHistMakerDevice {
hist.GetNodeHistogram(right_nidx),
node_value_constraints[right_nidx],
dh::ToSpan(monotone_constraints)};
    EvaluateSplits(dh::ToSpan(splits_out), left, right);
    std::vector<DeviceSplitCandidate> result(2);
    dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
                             sizeof(DeviceSplitCandidate) * splits_out.size(),
                             cudaMemcpyDeviceToHost));
    return result;
    auto d_splits_out = dh::ToSpan(splits_out);
    EvaluateSplits(d_splits_out, left, right);
    dh::TemporaryArray<ExpandEntry> entries(2);
    auto d_entries = entries.data().get();
    // Build the two child entries on the device so that a single
    // asynchronous copy into pinned host memory suffices.
    dh::LaunchN(device_id, 1, [=] __device__(size_t idx) {
      d_entries[0] = ExpandEntry(left_nidx, candidate.depth + 1, d_splits_out[0]);
      d_entries[1] = ExpandEntry(right_nidx, candidate.depth + 1, d_splits_out[1]);
    });
    dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(),
                                  sizeof(ExpandEntry) * entries.size(),
                                  cudaMemcpyDeviceToHost));
}

void BuildHist(int nidx) {
@@ -637,7 +586,7 @@ struct GPUHistMakerDevice {
tree[candidate.nid].RightChild());
}

void InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
GradientPair root_sum = thrust::reduce(
@@ -662,61 +611,64 @@

// Generate first split
auto split = this->EvaluateRootSplit(root_sum);
qexpand->push(
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0));
return ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split);
}

void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
RegTree* p_tree, dh::AllReducer* reducer) {
auto& tree = *p_tree;
Driver driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));

monitor.StartCuda("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
monitor.StopCuda("Reset");

monitor.StartCuda("InitRoot");
this->InitRoot(p_tree, reducer);
driver.Push({ this->InitRoot(p_tree, reducer) });
monitor.StopCuda("InitRoot");

auto timestamp = qexpand->size();
auto num_leaves = 1;

while (!qexpand->empty()) {
ExpandEntry candidate = qexpand->top();
qexpand->pop();
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree);

num_leaves++;

int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.StartCuda("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.StopCuda("UpdatePosition");

monitor.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.StopCuda("BuildHist");

monitor.StartCuda("EvaluateSplits");
auto splits = this->EvaluateLeftRightSplits(candidate, left_child_nidx,
right_child_nidx,
*p_tree);
monitor.StopCuda("EvaluateSplits");

qexpand->push(ExpandEntry(left_child_nidx,
tree.GetDepth(left_child_nidx), splits.at(0),
timestamp++));
qexpand->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx),
splits.at(1), timestamp++));
// The set of leaves that can be expanded asynchronously
auto expand_set = driver.Pop();
while (!expand_set.empty()) {
auto new_candidates = pinned.GetSpan<ExpandEntry>(expand_set.size() * 2);
for (auto i = 0ull; i < expand_set.size(); i++) {
auto candidate = expand_set.at(i);
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree);

num_leaves++;

int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.StartCuda("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.StopCuda("UpdatePosition");

monitor.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.StopCuda("BuildHist");

monitor.StartCuda("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx,
right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.StopCuda("EvaluateSplits");
} else {
          // Mark both children invalid; Driver::Push() drops default entries
new_candidates[i * 2] = ExpandEntry();
new_candidates[i * 2 + 1] = ExpandEntry();
}
}
dh::safe_cuda(cudaDeviceSynchronize());
      driver.Push(new_candidates.begin(), new_candidates.end());
expand_set = driver.Pop();
}

monitor.StartCuda("FinalisePosition");
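
Net effect on synchronisation (a summary, assuming a full depth-wise
tree): the old loop paid one blocking cudaMemcpy per expanded node, on
the order of 2^max_depth synchronisations per tree; the new loop queues
asynchronous copies into pinned memory and pays a single
cudaDeviceSynchronize per level, i.e. max_depth synchronisations. A
standalone sketch of the batched pattern (sizes illustrative):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int kCandidates = 8;  // nodes expanded on one level
  int* d_results;
  cudaMalloc(reinterpret_cast<void**>(&d_results), kCandidates * sizeof(int));
  cudaMemset(d_results, 0, kCandidates * sizeof(int));
  // Pinned destination, as with dh::PinnedMemory above.
  int* h_results;
  cudaMallocHost(reinterpret_cast<void**>(&h_results),
                 kCandidates * sizeof(int));
  for (int i = 0; i < kCandidates; ++i) {
    // Each copy is merely queued; the host thread does not block here.
    cudaMemcpyAsync(h_results + i, d_results + i, sizeof(int),
                    cudaMemcpyDeviceToHost);
  }
  cudaDeviceSynchronize();  // one barrier for the whole level
  printf("first result: %d\n", h_results[0]);
  cudaFreeHost(h_results);
  cudaFree(d_results);
  return 0;
}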
