Skip to content

Commit

Permalink
apacheGH-43677: [C++][FlightRPC] Move the FlightTestServer to its own…
Browse files Browse the repository at this point in the history
… .cc and .h files (apache#43678)

### Rationale for this change

One way of learning about a codebase is reading the tests. As it is now, it's hard to see the minimal `FlightServerBase` sub-class in `flight/test_util.cc`, so I moved it to its own file.

### What changes are included in this PR?

 - Renaming `FlightTestServer` to `TestFlightServer`
 - Moving the class to `test_flight_server.{h,cc}`
 - Bonus: Moving the server and client auth handlers to `test_auth_handlers.{h,cc}`

### Are these changes tested?

By existing tests.

### Are there any user-facing changes?

`ExampleTestServer` is removed from the testing library in favor of `FlightTestServer::Make`.
* GitHub Issue: apache#43677

Authored-by: Felipe Oliveira Carvalho <[email protected]>
Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
  • Loading branch information
felipecrv authored Aug 14, 2024
1 parent 01fd7fc commit 69bce8f
Show file tree
Hide file tree
Showing 11 changed files with 759 additions and 560 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ if(ARROW_TESTING)
OUTPUTS
ARROW_FLIGHT_TESTING_LIBRARIES
SOURCES
test_auth_handlers.cc
test_definitions.cc
test_flight_server.cc
test_util.cc
DEPENDENCIES
flight_grpc_gen
Expand Down
8 changes: 5 additions & 3 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
// Include before test_util.h (boost), contains Windows fixes
#include "arrow/flight/platform.h"
#include "arrow/flight/serialization_internal.h"
#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/test_definitions.h"
#include "arrow/flight/test_flight_server.h"
#include "arrow/flight/test_util.h"
// OTel includes must come after any gRPC includes, and
// client_header_internal.h includes gRPC. See:
Expand Down Expand Up @@ -247,7 +249,7 @@ TEST(TestFlight, ConnectUriUnix) {

// CI environments don't have an IPv6 interface configured
TEST(TestFlight, DISABLED_IpV6Port) {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("[::1]", 0));
FlightServerOptions options(location);
Expand All @@ -261,7 +263,7 @@ TEST(TestFlight, DISABLED_IpV6Port) {
}

TEST(TestFlight, ServerCallContextIncomingHeaders) {
auto server = ExampleTestServer();
auto server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
Expand Down Expand Up @@ -290,7 +292,7 @@ TEST(TestFlight, ServerCallContextIncomingHeaders) {
class TestFlightClient : public ::testing::Test {
public:
void SetUp() {
server_ = ExampleTestServer();
server_ = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "arrow/flight/sql/server.h"
#include "arrow/flight/sql/server_session_middleware.h"
#include "arrow/flight/sql/types.h"
#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/ipc/dictionary.h"
Expand Down
141 changes: 141 additions & 0 deletions cpp/src/arrow/flight/test_auth_handlers.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <string>

#include "arrow/flight/client_auth.h"
#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
#include "arrow/status.h"

namespace arrow::flight {

// TestServerAuthHandler

TestServerAuthHandler::TestServerAuthHandler(const std::string& username,
const std::string& password)
: username_(username), password_(password) {}

TestServerAuthHandler::~TestServerAuthHandler() {}

Status TestServerAuthHandler::Authenticate(const ServerCallContext& context,
ServerAuthSender* outgoing,
ServerAuthReader* incoming) {
std::string token;
RETURN_NOT_OK(incoming->Read(&token));
if (token != password_) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
RETURN_NOT_OK(outgoing->Write(username_));
return Status::OK();
}

Status TestServerAuthHandler::IsValid(const ServerCallContext& context,
const std::string& token,
std::string* peer_identity) {
if (token != password_) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
*peer_identity = username_;
return Status::OK();
}

// TestServerBasicAuthHandler

TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username,
const std::string& password) {
basic_auth_.username = username;
basic_auth_.password = password;
}

TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {}

Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context,
ServerAuthSender* outgoing,
ServerAuthReader* incoming) {
std::string token;
RETURN_NOT_OK(incoming->Read(&token));
ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token));
if (incoming_auth.username != basic_auth_.username ||
incoming_auth.password != basic_auth_.password) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
RETURN_NOT_OK(outgoing->Write(basic_auth_.username));
return Status::OK();
}

Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context,
const std::string& token,
std::string* peer_identity) {
if (token != basic_auth_.username) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
*peer_identity = basic_auth_.username;
return Status::OK();
}

// TestClientAuthHandler

TestClientAuthHandler::TestClientAuthHandler(const std::string& username,
const std::string& password)
: username_(username), password_(password) {}

TestClientAuthHandler::~TestClientAuthHandler() {}

Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
ClientAuthReader* incoming) {
RETURN_NOT_OK(outgoing->Write(password_));
std::string username;
RETURN_NOT_OK(incoming->Read(&username));
if (username != username_) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
return Status::OK();
}

Status TestClientAuthHandler::GetToken(std::string* token) {
*token = password_;
return Status::OK();
}

// TestClientBasicAuthHandler

TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username,
const std::string& password) {
basic_auth_.username = username;
basic_auth_.password = password;
}

TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {}

Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing,
ClientAuthReader* incoming) {
ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString());
RETURN_NOT_OK(outgoing->Write(pb_result));
RETURN_NOT_OK(incoming->Read(&token_));
return Status::OK();
}

Status TestClientBasicAuthHandler::GetToken(std::string* token) {
*token = token_;
return Status::OK();
}

} // namespace arrow::flight
89 changes: 89 additions & 0 deletions cpp/src/arrow/flight/test_auth_handlers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <string>

#include "arrow/flight/client_auth.h"
#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
#include "arrow/status.h"

// A pair of authentication handlers that check for a predefined password
// and set the peer identity to a predefined username.

namespace arrow::flight {

class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler {
public:
explicit TestServerAuthHandler(const std::string& username,
const std::string& password);
~TestServerAuthHandler() override;
Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing,
ServerAuthReader* incoming) override;
Status IsValid(const ServerCallContext& context, const std::string& token,
std::string* peer_identity) override;

private:
std::string username_;
std::string password_;
};

class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler {
public:
explicit TestServerBasicAuthHandler(const std::string& username,
const std::string& password);
~TestServerBasicAuthHandler() override;
Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing,
ServerAuthReader* incoming) override;
Status IsValid(const ServerCallContext& context, const std::string& token,
std::string* peer_identity) override;

private:
BasicAuth basic_auth_;
};

class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler {
public:
explicit TestClientAuthHandler(const std::string& username,
const std::string& password);
~TestClientAuthHandler() override;
Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
Status GetToken(std::string* token) override;

private:
std::string username_;
std::string password_;
};

class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler {
public:
explicit TestClientBasicAuthHandler(const std::string& username,
const std::string& password);
~TestClientBasicAuthHandler() override;
Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
Status GetToken(std::string* token) override;

private:
BasicAuth basic_auth_;
std::string token_;
};

} // namespace arrow::flight
15 changes: 8 additions & 7 deletions cpp/src/arrow/flight/test_definitions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "arrow/array/util.h"
#include "arrow/flight/api.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/test_flight_server.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/flight/types_async.h"
Expand All @@ -53,15 +54,15 @@ using arrow::internal::checked_cast;
// Tests of initialization/shutdown

void ConnectivityTest::TestGetPort() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
}
void ConnectivityTest::TestBuilderHook() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
Expand All @@ -80,7 +81,7 @@ void ConnectivityTest::TestShutdown() {
constexpr int kIterations = 10;
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
for (int i = 0; i < kIterations; i++) {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
Expand All @@ -92,7 +93,7 @@ void ConnectivityTest::TestShutdown() {
}
}
void ConnectivityTest::TestShutdownWithDeadline() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
Expand All @@ -105,7 +106,7 @@ void ConnectivityTest::TestShutdownWithDeadline() {
ASSERT_OK(server->Wait());
}
void ConnectivityTest::TestBrokenConnection() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
Expand Down Expand Up @@ -151,7 +152,7 @@ class GetFlightInfoListener : public AsyncListener<FlightInfo> {
} // namespace

void DataTest::SetUpTest() {
server_ = ExampleTestServer();
server_ = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
Expand Down Expand Up @@ -1822,7 +1823,7 @@ void AsyncClientTest::SetUpTest() {

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));

server_ = ExampleTestServer();
server_ = TestFlightServer::Make();
FlightServerOptions server_options(location);
ASSERT_OK(server_->Init(server_options));

Expand Down
Loading

0 comments on commit 69bce8f

Please sign in to comment.