Skip to content

Commit

Permalink
GH-36512: [C++][FlightRPC] Add async GetFlightInfo client call (#36517)
Browse files Browse the repository at this point in the history
### Rationale for this change

Async is a long-requested feature.

### What changes are included in this PR?

Just the C++ implementation of async GetFlightInfo for the client.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

Yes, new APIs.

* Closes: #36512

Authored-by: David Li <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
lidavidm authored Aug 9, 2023
1 parent 1a00fec commit 9f183fc
Show file tree
Hide file tree
Showing 19 changed files with 1,024 additions and 52 deletions.
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ else()
add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc_impl::experimental)
endif()

# Was in a different namespace, or simply not supported, prior to this
if(ARROW_GRPC_VERSION VERSION_GREATER_EQUAL "1.40")
add_definitions(-DGRPC_ENABLE_ASYNC)
endif()

# </KLUDGE> Restore the CXXFLAGS that were modified above
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
#include "arrow/flight/server_middleware.h"
#include "arrow/flight/server_tracing_middleware.h"
#include "arrow/flight/types.h"
#include "arrow/flight/types_async.h"
58 changes: 58 additions & 0 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,56 @@
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/util/future.h"
#include "arrow/util/logging.h"

#include "arrow/flight/client_auth.h"
#include "arrow/flight/serialization_internal.h"
#include "arrow/flight/transport.h"
#include "arrow/flight/transport/grpc/grpc_client.h"
#include "arrow/flight/types.h"
#include "arrow/flight/types_async.h"

namespace arrow {

namespace flight {

namespace {
template <typename T>
class UnaryUnaryAsyncListener : public AsyncListener<T> {
public:
UnaryUnaryAsyncListener() : future_(arrow::Future<T>::Make()) {}

void OnNext(T result) override {
DCHECK(!result_.ok());
result_ = std::move(result);
}

void OnFinish(Status status) override {
if (status.ok()) {
DCHECK(result_.ok());
} else {
// Default-initialized result is not ok
DCHECK(!result_.ok());
result_ = std::move(status);
}
future_.MarkFinished(std::move(result_));
}

static std::pair<std::shared_ptr<AsyncListener<T>>, arrow::Future<T>> Make() {
auto self = std::make_shared<UnaryUnaryAsyncListener<T>>();
// Keep the listener alive by stashing it in the future
self->future_.AddCallback([self](const arrow::Result<T>&) {});
auto future = self->future_;
return std::make_pair(std::move(self), std::move(future));
}

private:
arrow::Result<T> result_;
arrow::Future<T> future_;
};
} // namespace

const char* kWriteSizeDetailTypeId = "flight::FlightWriteSizeStatusDetail";

FlightCallOptions::FlightCallOptions()
Expand Down Expand Up @@ -584,6 +622,24 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightClient::GetFlightInfo(
return info;
}

void FlightClient::GetFlightInfoAsync(
const FlightCallOptions& options, const FlightDescriptor& descriptor,
std::shared_ptr<AsyncListener<FlightInfo>> listener) {
if (auto status = CheckOpen(); !status.ok()) {
listener->OnFinish(std::move(status));
return;
}
transport_->GetFlightInfoAsync(options, descriptor, std::move(listener));
}

arrow::Future<FlightInfo> FlightClient::GetFlightInfoAsync(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
auto [listener, future] = UnaryUnaryAsyncListener<FlightInfo>::Make();
transport_->GetFlightInfoAsync(options, descriptor, std::move(listener));
return future;
}

arrow::Result<std::unique_ptr<SchemaResult>> FlightClient::GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
Expand Down Expand Up @@ -658,6 +714,8 @@ Status FlightClient::Close() {
return Status::OK();
}

bool FlightClient::supports_async() const { return transport_->supports_async(); }

Status FlightClient::CheckOpen() const {
if (closed_) {
return Status::Invalid("FlightClient is closed");
Expand Down
28 changes: 28 additions & 0 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,31 @@ class ARROW_FLIGHT_EXPORT FlightClient {
return GetFlightInfo({}, descriptor);
}

/// \brief Asynchronous GetFlightInfo.
/// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request
/// \param[in] listener Callbacks for response and RPC completion
///
/// This API is EXPERIMENTAL.
void GetFlightInfoAsync(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::shared_ptr<AsyncListener<FlightInfo>> listener);
void GetFlightInfoAsync(const FlightDescriptor& descriptor,
std::shared_ptr<AsyncListener<FlightInfo>> listener) {
return GetFlightInfoAsync({}, descriptor, std::move(listener));
}

/// \brief Asynchronous GetFlightInfo returning a Future.
/// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request
///
/// This API is EXPERIMENTAL.
arrow::Future<FlightInfo> GetFlightInfoAsync(const FlightCallOptions& options,
const FlightDescriptor& descriptor);
arrow::Future<FlightInfo> GetFlightInfoAsync(const FlightDescriptor& descriptor) {
return GetFlightInfoAsync({}, descriptor);
}

/// \brief Request schema for a single flight, which may be an existing
/// dataset or a command to be executed
/// \param[in] options Per-RPC options
Expand Down Expand Up @@ -355,6 +380,9 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \since 8.0.0
Status Close();

/// \brief Whether this client supports asynchronous methods.
bool supports_async() const;

private:
FlightClient();
Status CheckOpen() const;
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/flight/flight_internals_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ void TestRoundtrip(const std::vector<FlightType>& values,
ASSERT_OK(internal::ToProto(values[i], &pb_value));

if constexpr (std::is_same_v<FlightType, FlightInfo>) {
FlightInfo::Data data;
ASSERT_OK(internal::FromProto(pb_value, &data));
FlightInfo value(std::move(data));
ASSERT_OK_AND_ASSIGN(FlightInfo value, internal::FromProto(pb_value));
EXPECT_EQ(values[i], value);
} else if constexpr (std::is_same_v<FlightType, SchemaResult>) {
std::string data;
Expand Down Expand Up @@ -742,5 +740,7 @@ TEST(TransportErrorHandling, ReconstructStatus) {
ASSERT_EQ(detail->extra_info(), "Binary error details");
}

// TODO: test TransportStatusDetail

} // namespace flight
} // namespace arrow
25 changes: 24 additions & 1 deletion cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
#include "arrow/util/base64.h"
#include "arrow/util/future.h"
#include "arrow/util/logging.h"

#ifdef GRPCPP_GRPCPP_H
Expand Down Expand Up @@ -91,9 +92,16 @@ const char kAuthHeader[] = "authorization";
//------------------------------------------------------------
// Common transport tests

#ifdef GRPC_ENABLE_ASYNC
constexpr bool kGrpcSupportsAsync = true;
#else
constexpr bool kGrpcSupportsAsync = false;
#endif

class GrpcConnectivityTest : public ConnectivityTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -102,6 +110,7 @@ ARROW_FLIGHT_TEST_CONNECTIVITY(GrpcConnectivityTest);
class GrpcDataTest : public DataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -110,6 +119,7 @@ ARROW_FLIGHT_TEST_DATA(GrpcDataTest);
class GrpcDoPutTest : public DoPutTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -118,6 +128,7 @@ ARROW_FLIGHT_TEST_DO_PUT(GrpcDoPutTest);
class GrpcAppMetadataTest : public AppMetadataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -126,6 +137,7 @@ ARROW_FLIGHT_TEST_APP_METADATA(GrpcAppMetadataTest);
class GrpcIpcOptionsTest : public IpcOptionsTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -134,6 +146,7 @@ ARROW_FLIGHT_TEST_IPC_OPTIONS(GrpcIpcOptionsTest);
class GrpcCudaDataTest : public CudaDataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -142,11 +155,21 @@ ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest);
class GrpcErrorHandlingTest : public ErrorHandlingTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
ARROW_FLIGHT_TEST_ERROR_HANDLING(GrpcErrorHandlingTest);

class GrpcAsyncClientTest : public AsyncClientTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
ARROW_FLIGHT_TEST_ASYNC_CLIENT(GrpcAsyncClientTest);

//------------------------------------------------------------
// Ad-hoc gRPC-specific tests

Expand Down Expand Up @@ -443,7 +466,7 @@ class TestTls : public ::testing::Test {
Location location_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<FlightServerBase> server_;
bool server_is_initialized_;
bool server_is_initialized_ = false;
};

// A server middleware that rejects all calls.
Expand Down
24 changes: 12 additions & 12 deletions cpp/src/arrow/flight/serialization_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,21 @@ Status ToProto(const FlightDescriptor& descriptor, pb::FlightDescriptor* pb_desc

// FlightInfo

Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) {
RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info->descriptor));
arrow::Result<FlightInfo> FromProto(const pb::FlightInfo& pb_info) {
FlightInfo::Data info;
RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info.descriptor));

info->schema = pb_info.schema();
info.schema = pb_info.schema();

info->endpoints.resize(pb_info.endpoint_size());
info.endpoints.resize(pb_info.endpoint_size());
for (int i = 0; i < pb_info.endpoint_size(); ++i) {
RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info->endpoints[i]));
RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info.endpoints[i]));
}

info->total_records = pb_info.total_records();
info->total_bytes = pb_info.total_bytes();
info->ordered = pb_info.ordered();
return Status::OK();
info.total_records = pb_info.total_records();
info.total_bytes = pb_info.total_bytes();
info.ordered = pb_info.ordered();
return FlightInfo(std::move(info));
}

Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) {
Expand Down Expand Up @@ -291,9 +292,8 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) {

Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request) {
FlightInfo::Data data;
RETURN_NOT_OK(FromProto(pb_request.info(), &data));
request->info = std::make_unique<FlightInfo>(std::move(data));
ARROW_ASSIGN_OR_RAISE(FlightInfo info, FromProto(pb_request.info()));
request->info = std::make_unique<FlightInfo>(std::move(info));
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/serialization_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr);
Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint);
Status FromProto(const pb::RenewFlightEndpointRequest& pb_request,
RenewFlightEndpointRequest* request);
Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info);
arrow::Result<FlightInfo> FromProto(const pb::FlightInfo& pb_info);
Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request);
Status FromProto(const pb::SchemaResult& pb_result, std::string* result);
Expand Down
Loading

0 comments on commit 9f183fc

Please sign in to comment.