diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index ee499b4d1a0e..12222cf9d747 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -32,23 +32,23 @@ namespace collective { * @param function The function used to calculate the results. * @param args Arguments to the function. */ -template -void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function, +template +void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& function, Args&&... args) { if (info.IsVerticalFederated()) { // We assume labels are only available on worker 0, so the calculation is done there and result // broadcast to other workers. - std::vector message(1024); + std::string message; if (collective::GetRank() == 0) { try { std::forward(function)(std::forward(args)...); } catch (dmlc::Error& e) { - strncpy(&message[0], e.what(), message.size()); - message.back() = '\0'; + message = e.what(); } } - collective::Broadcast(&message[0], message.size(), 0); - if (strlen(&message[0]) == 0) { + + collective::Broadcast(&message, 0); + if (message.empty()) { collective::Broadcast(buffer, size, 0); } else { LOG(FATAL) << &message[0]; @@ -57,6 +57,5 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& std::forward(function)(std::forward(args)...); } } - } // namespace collective } // namespace xgboost