Skip to content

Commit

Permalink
Test column sampler with column-wise data split (#9609)
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou authored Sep 26, 2023
1 parent 1167e6c commit 290b17f
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions tests/cpp/test_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,4 +718,70 @@ INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit,
[](const ::testing::TestParamInfo<TestColumnSplit::ParamType>& info) {
return ObjTestNameGenerator(info);
});

namespace {
void VerifyColumnSplitColumnSampler(std::string const& tree_method, bool use_gpu,
Json const& expected_model) {
Json model{Object{}};
{
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const objective = "reg:logistic";
auto p_fmat = MakeFmatForObjTest(objective);
std::shared_ptr<DMatrix> sliced{p_fmat->SliceCol(world_size, rank)};
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
learner->SetParam("tree_method", tree_method);
if (use_gpu) {
auto gpu_id = common::AllVisibleGPUs() == 1 ? 0 : rank;
learner->SetParam("device", "cuda:" + std::to_string(gpu_id));
}
learner->SetParam("objective", objective);
learner->SetParam("colsample_bytree", "0.5");
learner->SetParam("colsample_bylevel", "0.6");
learner->SetParam("colsample_bynode", "0.7");
learner->UpdateOneIter(0, sliced);
learner->SaveModel(&model);
}
ASSERT_EQ(model, expected_model);
}

void TestColumnSplitColumnSampler(std::string const& tree_method, bool use_gpu) {
Json model{Object{}};
{
auto objective = "reg:logistic";
auto p_fmat = MakeFmatForObjTest(objective);
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
learner->SetParam("tree_method", tree_method);
if (use_gpu) {
learner->SetParam("device", "cuda:0");
}
learner->SetParam("objective", objective);
learner->SetParam("colsample_bytree", "0.5");
learner->SetParam("colsample_bylevel", "0.6");
learner->SetParam("colsample_bynode", "0.7");
learner->UpdateOneIter(0, p_fmat);
learner->SaveModel(&model);
}
auto world_size{3};
if (use_gpu) {
world_size = common::AllVisibleGPUs();
// Simulate MPU on a single GPU.
if (world_size == 1) {
world_size = 3;
}
}
RunWithInMemoryCommunicator(world_size, VerifyColumnSplitColumnSampler, tree_method, use_gpu,
model);
}
} // anonymous namespace

TEST(ColumnSplitColumnSampler, Approx) { TestColumnSplitColumnSampler("approx", false); }

TEST(ColumnSplitColumnSampler, Hist) { TestColumnSplitColumnSampler("hist", false); }

#if defined(XGBOOST_USE_CUDA)
TEST(ColumnSplitColumnSampler, GPUApprox) { TestColumnSplitColumnSampler("approx", true); }

TEST(ColumnSplitColumnSampler, GPUHist) { TestColumnSplitColumnSampler("hist", true); }
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

0 comments on commit 290b17f

Please sign in to comment.