-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a6d9a06
commit db9ea1e
Showing
4 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/*! | ||
* Copyright 2020 XGBoost contributors | ||
*/ | ||
#include "proxy_dmatrix.h" | ||
#include "device_adapter.cuh" | ||
|
||
namespace xgboost { | ||
namespace data { | ||
|
||
void DMatrixProxy::FromCudaColumnar(std::string interface_str) { | ||
std::shared_ptr<data::CudfAdapter> adapter {new data::CudfAdapter(interface_str)}; | ||
auto const& value = adapter->Value(); | ||
this->batch_ = adapter; | ||
device_ = adapter->DeviceIdx(); | ||
} | ||
|
||
void DMatrixProxy::FromCudaArray(std::string interface_str) { | ||
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str)); | ||
this->batch_ = adapter; | ||
device_ = adapter->DeviceIdx(); | ||
} | ||
|
||
} // namespace data | ||
} // namespace xgboost |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
/*! | ||
* Copyright 2020 XGBoost contributors | ||
*/ | ||
#ifndef XGBOOST_DATA_PROXY_DMATRIX_H_ | ||
#define XGBOOST_DATA_PROXY_DMATRIX_H_ | ||
|
||
#include <dmlc/any.h> | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
|
||
#include "xgboost/data.h" | ||
#include "xgboost/generic_parameters.h" | ||
#include "xgboost/c_api.h" | ||
#include "adapter.h" | ||
|
||
namespace xgboost { | ||
namespace data { | ||
/* | ||
* \brief A proxy to external iterator. | ||
*/ | ||
template <typename ResetFn, typename NextFn> | ||
class DataIterProxy { | ||
DataIterHandle iter_; | ||
ResetFn* reset_; | ||
NextFn* next_; | ||
|
||
public: | ||
DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next) : | ||
iter_{iter}, | ||
reset_{reset}, next_{next} {} | ||
|
||
bool Next() { | ||
return next_(iter_); | ||
} | ||
void Reset() { | ||
reset_(iter_); | ||
} | ||
}; | ||
|
||
/* | ||
* \brief A proxy of DMatrix used by external iterator. | ||
*/ | ||
class DMatrixProxy : public DMatrix { | ||
MetaInfo info_; | ||
dmlc::any batch_; | ||
int32_t device_ { xgboost::GenericParameter::kCpuId }; | ||
|
||
public: | ||
void SetInfo(char const *key, void const *info, DataType type, | ||
size_t len) override { | ||
this->Info().SetInfo(key, info, type, len); | ||
} | ||
void SetInfo(const char* key, std::string const& interface_str) override { | ||
this->Info().SetInfo(key, interface_str); | ||
} | ||
|
||
int DeviceIdx() const { return device_; } | ||
|
||
#if defined(XGBOOST_USE_CUDA) | ||
void FromCudaColumnar(std::string interface_str); | ||
void FromCudaArray(std::string interface_str); | ||
#endif // defined(XGBOOST_USE_CUDA) | ||
|
||
void SetData(char const* c_interface) { | ||
common::AssertGPUSupport(); | ||
#if defined(XGBOOST_USE_CUDA) | ||
std::string interface_str = c_interface; | ||
Json json_array_interface = | ||
Json::Load({interface_str.c_str(), interface_str.size()}); | ||
if (IsA<Array>(json_array_interface)) { | ||
this->FromCudaColumnar(interface_str); | ||
} else { | ||
this->FromCudaArray(interface_str); | ||
} | ||
#endif // defined(XGBOOST_USE_CUDA) | ||
} | ||
|
||
MetaInfo& Info() override { return info_; } | ||
MetaInfo const& Info() const override { return info_; } | ||
bool SingleColBlock() const override { return true; } | ||
bool EllpackExists() const override { return true; } | ||
bool SparsePageExists() const override { return false; } | ||
DMatrix *Slice(common::Span<int32_t const> ridxs) override { | ||
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix."; | ||
return nullptr; | ||
} | ||
BatchSet<SparsePage> GetRowBatches() override { | ||
LOG(FATAL) << "Not implemented."; | ||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr)); | ||
} | ||
BatchSet<CSCPage> GetColumnBatches() override { | ||
LOG(FATAL) << "Not implemented."; | ||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr)); | ||
} | ||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override { | ||
LOG(FATAL) << "Not implemented."; | ||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr)); | ||
} | ||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override { | ||
LOG(FATAL) << "Not implemented."; | ||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(nullptr)); | ||
} | ||
|
||
dmlc::any Adapter() const { | ||
return batch_; | ||
} | ||
}; | ||
} // namespace data | ||
} // namespace xgboost | ||
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#include <gtest/gtest.h> | ||
#include <xgboost/host_device_vector.h> | ||
#include <memory> | ||
#include "../helpers.h" | ||
#include "../../../src/data/device_adapter.cuh" | ||
#include "../../../src/data/proxy_dmatrix.h" | ||
|
||
namespace xgboost { | ||
namespace data { | ||
TEST(ProxyDMatrix, Basic) { | ||
constexpr size_t kRows{100}, kCols{100}; | ||
HostDeviceVector<float> storage; | ||
auto data = RandomDataGenerator(kRows, kCols, 0.5) | ||
.Device(0) | ||
.GenerateArrayInterface(&storage); | ||
std::vector<HostDeviceVector<float>> label_storage(1); | ||
auto labels = RandomDataGenerator(kRows, 1, 0) | ||
.Device(0) | ||
.GenerateColumnarArrayInterface(&label_storage); | ||
|
||
DMatrixProxy proxy; | ||
proxy.FromCudaArray(data); | ||
proxy.SetInfo("label", labels.c_str()); | ||
|
||
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CupyAdapter>)); | ||
ASSERT_EQ(proxy.Info().labels_.Size(), kRows); | ||
ASSERT_EQ(dmlc::get<std::shared_ptr<CupyAdapter>>(proxy.Adapter())->NumRows(), | ||
kRows); | ||
ASSERT_EQ( | ||
dmlc::get<std::shared_ptr<CupyAdapter>>(proxy.Adapter())->NumColumns(), | ||
kCols); | ||
|
||
std::vector<HostDeviceVector<float>> columnar_storage(kCols); | ||
data = RandomDataGenerator(kRows, kCols, 0) | ||
.Device(0) | ||
.GenerateColumnarArrayInterface(&columnar_storage); | ||
proxy.FromCudaColumnar(data); | ||
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CudfAdapter>)); | ||
ASSERT_EQ(dmlc::get<std::shared_ptr<CudfAdapter>>(proxy.Adapter())->NumRows(), | ||
kRows); | ||
ASSERT_EQ( | ||
dmlc::get<std::shared_ptr<CudfAdapter>>(proxy.Adapter())->NumColumns(), | ||
kCols); | ||
} | ||
} // namespace data | ||
} // namespace xgboost |