Skip to content

Commit

Permalink
pw_rpc: Guard the on_error callback
Browse files Browse the repository at this point in the history
- Move the on_error callback to a local variable while the lock is held,
  then invoke it.
- Have decoding errors terminate a call. They call the on_error
  callback currently, which should only be called at most once per call.
- Use the call's active() flag instead of a separate active_ variable in
  client call tests.

Bug: b/234876851
Change-Id: I785dd00813ebb0ab4743c7498ec3bcfaecc2075d
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/125250
Pigweed-Auto-Submit: Wyatt Hepler <[email protected]>
Commit-Queue: Auto-Submit <[email protected]>
Reviewed-by: Alexei Frolov <[email protected]>
  • Loading branch information
255 authored and CQ Bot Account committed Jan 4, 2023
1 parent 72b5d42 commit ca6e845
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 79 deletions.
41 changes: 11 additions & 30 deletions pw_rpc/nanopb/client_call_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ TEST_F(UnaryClientCall, OnlyReceivesOneResponse) {

class ServerStreamingClientCall : public ::testing::Test {
protected:
bool active_ = true;
std::optional<Status> stream_status_;
std::optional<Status> rpc_error_;
int responses_received_ = 0;
Expand Down Expand Up @@ -290,26 +289,23 @@ TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) {
++responses_received_;
last_response_number_ = response.number;
},
[this](Status status) {
active_ = false;
stream_status_ = status;
});
[this](Status status) { stream_status_ = status; });

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 1);
EXPECT_EQ(last_response_number_, 11);

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(last_response_number_, 22);

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 3);
EXPECT_EQ(last_response_number_, 33);
}
Expand All @@ -325,30 +321,27 @@ TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) {
++responses_received_;
last_response_number_ = response.number;
},
[this](Status status) {
active_ = false;
stream_status_ = status;
});
[this](Status status) { stream_status_ = status; });

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());

// Close the stream.
EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound()));

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
EXPECT_FALSE(active_);
EXPECT_FALSE(call.active());

EXPECT_EQ(responses_received_, 2);
}

TEST_F(ServerStreamingClientCall, InvokesErrorCallbackOnInvalidResponses) {
TEST_F(ServerStreamingClientCall, ParseErrorTerminatesCallWithDataLoss) {
ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
Expand All @@ -364,28 +357,16 @@ TEST_F(ServerStreamingClientCall, InvokesErrorCallbackOnInvalidResponses) {

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 1);
EXPECT_EQ(last_response_number_, 11);

constexpr std::byte bad_payload[]{
std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
EXPECT_EQ(OkStatus(), context.SendServerStream(bad_payload));
EXPECT_FALSE(call.active());
EXPECT_EQ(responses_received_, 1);
ASSERT_TRUE(rpc_error_.has_value());
EXPECT_EQ(rpc_error_, Status::DataLoss());

PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
EXPECT_TRUE(active_);
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(last_response_number_, 22);

EXPECT_EQ(OkStatus(),
context.SendPacket(internal::pwpb::PacketType::SERVER_ERROR,
Status::NotFound()));
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(rpc_error_, Status::NotFound());
}

} // namespace
Expand Down
6 changes: 2 additions & 4 deletions pw_rpc/nanopb/public/pw_rpc/nanopb/client_reader_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,8 @@ class NanopbUnaryResponseClientCall : public UnaryResponseClientCall {
if (serde_->DecodeResponse(payload, &response_struct)) {
nanopb_on_completed_local(response_struct, status);
} else {
// TODO(hepler): This should send a DATA_LOSS error and call the
// error callback.
rpc_lock().lock();
CallOnError(Status::DataLoss());
HandleError(Status::DataLoss());
}
}
});
Expand Down Expand Up @@ -211,7 +209,7 @@ class NanopbStreamResponseClientCall : public StreamResponseClientCall {
// TODO(hepler): This should send a DATA_LOSS error and call the
// error callback.
rpc_lock().lock();
CallOnError(Status::DataLoss());
HandleError(Status::DataLoss());
}
}
});
Expand Down
26 changes: 12 additions & 14 deletions pw_rpc/public/pw_rpc/internal/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,19 +259,6 @@ class Call : public IntrusiveList<Call>::Item {
on_error_ = std::move(on_error);
}

// Calls the on_error callback without closing the RPC. This is used when the
// call has already completed.
void CallOnError(Status error) PW_UNLOCK_FUNCTION(rpc_lock()) {
const bool invoke = on_error_ != nullptr;

// TODO(b/234876851): Ensure on_error_ is properly guarded.

rpc_lock().unlock();
if (invoke) {
on_error_(error);
}
}

void MarkClientStreamCompleted() PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
client_stream_state_ = kClientStreamInactive;
}
Expand Down Expand Up @@ -329,6 +316,17 @@ class Call : public IntrusiveList<Call>::Item {
client_stream_state_ = kClientStreamInactive;
}

// Calls the on_error callback without closing the RPC. This is used when the
// call has already completed.
void CallOnError(Status error) PW_UNLOCK_FUNCTION(rpc_lock()) {
auto on_error_local = std::move(on_error_);

rpc_lock().unlock();
if (on_error_local) {
on_error_local(error);
}
}

// Sends a payload with the specified type. The payload may either be in a
// previously acquired buffer or in a standalone buffer.
//
Expand Down Expand Up @@ -358,7 +356,7 @@ class Call : public IntrusiveList<Call>::Item {
} client_stream_state_ PW_GUARDED_BY(rpc_lock());

// Called when the RPC is terminated due to an error.
Function<void(Status error)> on_error_;
Function<void(Status error)> on_error_ PW_GUARDED_BY(rpc_lock());

// Called when a request is received. Only used for RPCs with client streams.
// The raw payload buffer is passed to the callback.
Expand Down
40 changes: 11 additions & 29 deletions pw_rpc/pwpb/client_call_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ TEST_F(UnaryClientCall, OnlyReceivesOneResponse) {

class ServerStreamingClientCall : public ::testing::Test {
protected:
bool active_ = true;
std::optional<Status> stream_status_;
std::optional<Status> rpc_error_;
int responses_received_ = 0;
Expand Down Expand Up @@ -291,26 +290,23 @@ TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) {
++responses_received_;
last_response_number_ = response.number;
},
[this](Status status) {
active_ = false;
stream_status_ = status;
});
[this](Status status) { stream_status_ = status; });

PW_ENCODE_PB(TestStreamResponse, r1, .chunk = {}, .number = 11u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 1);
EXPECT_EQ(last_response_number_, 11);

PW_ENCODE_PB(TestStreamResponse, r2, .chunk = {}, .number = 22u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(last_response_number_, 22);

PW_ENCODE_PB(TestStreamResponse, r3, .chunk = {}, .number = 33u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 3);
EXPECT_EQ(last_response_number_, 33);
}
Expand All @@ -326,30 +322,27 @@ TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) {
++responses_received_;
last_response_number_ = response.number;
},
[this](Status status) {
active_ = false;
stream_status_ = status;
});
[this](Status status) { stream_status_ = status; });

PW_ENCODE_PB(TestStreamResponse, r1, .chunk = {}, .number = 11u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());

PW_ENCODE_PB(TestStreamResponse, r2, .chunk = {}, .number = 22u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());

// Close the stream.
EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound()));

PW_ENCODE_PB(TestStreamResponse, r3, .chunk = {}, .number = 33u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
EXPECT_FALSE(active_);
EXPECT_FALSE(call.active());

EXPECT_EQ(responses_received_, 2);
}

TEST_F(ServerStreamingClientCall, InvokesErrorCallbackOnInvalidResponses) {
TEST_F(ServerStreamingClientCall, ParseErrorTerminatesCallWithDataLoss) {
ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
Expand All @@ -365,27 +358,16 @@ TEST_F(ServerStreamingClientCall, InvokesErrorCallbackOnInvalidResponses) {

PW_ENCODE_PB(TestStreamResponse, r1, .chunk = {}, .number = 11u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
EXPECT_TRUE(active_);
EXPECT_TRUE(call.active());
EXPECT_EQ(responses_received_, 1);
EXPECT_EQ(last_response_number_, 11);

constexpr std::byte bad_payload[]{
std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
EXPECT_EQ(OkStatus(), context.SendServerStream(bad_payload));
EXPECT_FALSE(call.active());
EXPECT_EQ(responses_received_, 1);
ASSERT_TRUE(rpc_error_.has_value());
EXPECT_EQ(rpc_error_, Status::DataLoss());

PW_ENCODE_PB(TestStreamResponse, r2, .chunk = {}, .number = 22u);
EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
EXPECT_TRUE(active_);
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(last_response_number_, 22);

EXPECT_EQ(OkStatus(),
context.SendPacket(PacketType::SERVER_ERROR, Status::NotFound()));
EXPECT_EQ(responses_received_, 2);
EXPECT_EQ(rpc_error_, Status::NotFound());
}

} // namespace
Expand Down
4 changes: 2 additions & 2 deletions pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class PwpbUnaryResponseClientCall : public UnaryResponseClientCall {
pwpb_on_completed_local(response, status);
} else {
rpc_lock().lock();
CallOnError(Status::DataLoss());
HandleError(Status::DataLoss());
}
}
});
Expand Down Expand Up @@ -274,7 +274,7 @@ class PwpbStreamResponseClientCall : public StreamResponseClientCall {
pwpb_on_next_(response);
} else {
rpc_lock().lock();
CallOnError(Status::DataLoss());
HandleError(Status::DataLoss());
}
}
});
Expand Down

0 comments on commit ca6e845

Please sign in to comment.