Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

alts: Fix TsiSocket doWrite on short writes #15962

Merged
merged 7 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "common/buffer/buffer_impl.h"
#include "common/common/assert.h"

#include "grpc/slice_buffer.h"
#include "src/core/tsi/transport_security_grpc.h"
#include "src/core/tsi/transport_security_interface.h"

Expand All @@ -15,16 +14,14 @@ namespace Alts {
TsiFrameProtector::TsiFrameProtector(CFrameProtectorPtr&& frame_protector)
: frame_protector_(std::move(frame_protector)) {}

tsi_result TsiFrameProtector::protect(Buffer::Instance& input, Buffer::Instance& output) {
tsi_result TsiFrameProtector::protect(const grpc_slice& input_slice, Buffer::Instance& output) {
ASSERT(frame_protector_);

if (input.length() == 0) {
if (GRPC_SLICE_LENGTH(input_slice) == 0) {
return TSI_OK;
}

grpc_core::ExecCtx exec_ctx;
grpc_slice input_slice = grpc_slice_from_copied_buffer(
reinterpret_cast<char*>(input.linearize(input.length())), input.length());

grpc_slice_buffer message_buffer;
grpc_slice_buffer_init(&message_buffer);
Expand Down Expand Up @@ -58,7 +55,6 @@ tsi_result TsiFrameProtector::protect(Buffer::Instance& input, Buffer::Instance&
});

output.addBufferFragment(*fragment);
input.drain(input.length());

grpc_slice_buffer_destroy(&message_buffer);
grpc_slice_buffer_destroy(&protected_buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "extensions/transport_sockets/alts/grpc_tsi.h"

#include "grpc/slice_buffer.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
Expand All @@ -20,11 +22,12 @@ class TsiFrameProtector final {

/**
* Wrapper for tsi_frame_protector_protect
* @param input supplies the input data to protect, the method will drain it when it is processed.
* @param input_slice supplies the input data to protect. Its ownership will
* be transferred.
* @param output supplies the buffer where the protected data will be stored.
* @return tsi_result the status.
*/
tsi_result protect(Buffer::Instance& input, Buffer::Instance& output);
tsi_result protect(const grpc_slice& input_slice, Buffer::Instance& output);

/**
* Wrapper for tsi_frame_protector_unprotect
Expand Down
92 changes: 76 additions & 16 deletions source/extensions/transport_sockets/alts/tsi_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result

// Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack.
if (raw_write_buffer_.length() > 0) {
return raw_buffer_socket_->doWrite(raw_write_buffer_, false).action_;
Network::IoResult result = raw_buffer_socket_->doWrite(raw_write_buffer_, false);
if (handshake_complete_ && raw_write_buffer_.length() > 0) {
write_buffer_contains_handshake_bytes_ = true;
}
return result.action_;
}

return Network::PostIoAction::KeepOpen;
Expand Down Expand Up @@ -259,28 +263,84 @@ Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
return repeatReadAndUnprotect(buffer, result);
}

Network::IoResult TsiSocket::repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream) {
uint64_t total_bytes_written = 0;
Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};

ASSERT(!write_buffer_contains_handshake_bytes_);
while (true) {
uint64_t bytes_to_drain_this_iteration =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an odd case that merits special consideration:

prev_bytes_to_drain_ == 0 && raw_write_buffer_.length() > 0

I think this can happen in the case where doHandshake adds bytes to raw_write_buffer_ but also completes the handshake. When this happens, the protect call below is skipped, but bytes_to_drain_this_iteration ends up being > 0, which could result in bytes in the input buffer being discarded without being sent.

Ways to detect:
ASSERT((prev_bytes_to_drain_ == 0) == (raw_write_buffer_.length() == 0);
ASSERT(prev_bytes_to_drain_ >= buffer.length());
ASSERT(buffer.length() >= bytes_to_drain_this_iteration) before the call to buffer.drain() further down
Possibly adding an ASSERT to OwnedImpl::drainImpl to detect attempts to drain more bytes than are in the buffer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per this code, I think we have a guarantee that raw_write_buffer_.length() is 0 when entering doWrite() for the first time after handshake completes. Also, it does not make sense that a peer wants to send non-handshake data without first confirming if the handshake completes successfully. In other words, during handshake, peer A will send whatever it receives from peer B to the handshake service in order to get the bytes to send to peer B. Here, peer A will not concatenate any non-handshake data to the data received from peer B, and send them to the handshake service because peer A has not received any confirmation from the handshake service that handshake completes successfully.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raw_write_buffer_ will be non-empty after the doWrite just after handshake completes if that doWrite results in a partial write. I know this may be really unlikely but it is possible for raw_write_buffer_ to be non-empty after handshake completes. I think it also true that the peer that completes the handshake first will need to do a write after handshake completes locally to provide the remote peer the information it needs to complete its own handshake.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible solution: branch on raw_write_buffer_.length() > 0 instead of prev_bytes_to_drain_.

When raw_write_buffer_.length() > 0 then bytes_to_drain_this_iteration = prev_bytes_to_drain_ and we should attempt a write even if bytes_to_drain_this_iteration is 0 which would happen if the bytes in the buffer are handshake bytes.

When raw_write_buffer_.length() > 0 then bytes_to_drain_this_iteration = std::min(buffer.length(), max_unprotected_frame_size_). If >0, attempt to protect and write those bytes, else break.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a slightly different approach by introducing a new field - prev_handshake_bytes_to_drain_ that indicates if we need to drain handshake data before doing regular protect+write operations. PTAL.

prev_bytes_to_drain_ > 0
? prev_bytes_to_drain_
: std::min(buffer.length(), actual_frame_size_to_use_ - frame_overhead_size_);
// Consumed all data. Exit.
if (bytes_to_drain_this_iteration == 0) {
break;
}
// Short write did not occur previously.
if (raw_write_buffer_.length() == 0) {
ASSERT(frame_protector_);
ASSERT(prev_bytes_to_drain_ == 0);

// Do protect.
ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(),
bytes_to_drain_this_iteration);
tsi_result status = frame_protector_->protect(
grpc_slice_from_static_buffer(buffer.linearize(bytes_to_drain_this_iteration),
bytes_to_drain_this_iteration),
raw_write_buffer_);
ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(),
bytes_to_drain_this_iteration, tsi_result_to_string(status));
}

// Write raw_write_buffer_ to network.
ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
raw_write_buffer_.length(), end_stream);
result = raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));

// Short write. Exit.
if (raw_write_buffer_.length() > 0) {
prev_bytes_to_drain_ = bytes_to_drain_this_iteration;
break;
} else {
buffer.drain(bytes_to_drain_this_iteration);
prev_bytes_to_drain_ = 0;
total_bytes_written += bytes_to_drain_this_iteration;
}
}

return {result.action_, total_bytes_written, false};
}

Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) {
if (!handshake_complete_) {
Network::PostIoAction action = doHandshake();
// Envoy ALTS implements asynchronous tsi_handshaker_next() interface
// which returns immediately after scheduling a handshake request to
// the handshake service. The handshake response will be handled by a
// dedicated thread in a seperate API within which handshake_complete_
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Failing format checks due to spell error.

seperate -> separate

// will be set to true if the handshake completes.
ASSERT(!handshake_complete_);
ASSERT(action == Network::PostIoAction::KeepOpen);
// TODO(lizan): Handle synchronous handshake when TsiHandshaker supports it.
}

if (handshake_complete_) {
return {Network::PostIoAction::KeepOpen, 0, false};
} else {
ASSERT(frame_protector_);
ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(),
buffer.length());
tsi_result status = frame_protector_->protect(buffer, raw_write_buffer_);
ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(),
buffer.length(), tsi_result_to_string(status));
}

if (raw_write_buffer_.length() > 0) {
ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
raw_write_buffer_.length(), end_stream);
return raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
// Check if we need to flush outstanding handshake bytes.
if (write_buffer_contains_handshake_bytes_) {
ASSERT(raw_write_buffer_.length() > 0);
ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
raw_write_buffer_.length(), end_stream);
Network::IoResult result =
raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
// Check if short write occurred.
if (raw_write_buffer_.length() > 0) {
return {result.action_, 0, false};
}
write_buffer_contains_handshake_bytes_ = false;
}
return repeatProtectAndWrite(buffer, end_stream);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the raw_write_buffer_.length() > 0 code below this early return?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still need it to send handshake data to its peer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but think that it could be moved to the "if (!handshake_complete_) {" branch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
return {Network::PostIoAction::KeepOpen, 0, false};
}

void TsiSocket::closeSocket(Network::ConnectionEvent) {
Expand Down
13 changes: 13 additions & 0 deletions source/extensions/transport_sockets/alts/tsi_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class TsiSocket : public Network::TransportSocket,

// This API should be called only after ALTS handshake finishes successfully.
size_t actualFrameSizeToUse() { return actual_frame_size_to_use_; }
// Set actual_frame_size_to_use_. Exposed for testing purpose.
void setActualFrameSizeToUse(size_t frame_size) { actual_frame_size_to_use_ = frame_size; }
// Set frame_overhead_size_. Exposed for testing purpose.
void setFrameOverheadSize(size_t overhead_size) { frame_overhead_size_ = overhead_size; }

private:
Network::PostIoAction doHandshake();
Expand All @@ -82,6 +86,8 @@ class TsiSocket : public Network::TransportSocket,

// Helper function to perform repeated read and unprotect operations.
Network::IoResult repeatReadAndUnprotect(Buffer::Instance& buffer, Network::IoResult prev_result);
// Helper function to perform repeated protect and write operations.
Network::IoResult repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream);
// Helper function to read from a raw socket and update status.
Network::IoResult readFromRawSocket();

Expand All @@ -97,6 +103,11 @@ class TsiSocket : public Network::TransportSocket,
// actual_frame_size_to_use_ is the actual frame size used by
// frame protector, which is the result of frame size negotiation.
size_t actual_frame_size_to_use_{0};
// frame_overhead_size_ includes 4 bytes frame message type and 16 bytes tag length.
// It is consistent with gRPC ALTS zero copy frame protector implementation.
// The maximum size of data that can be protected for each frame is equal to
// actual_frame_size_to_use_ - frame_overhead_size_.
size_t frame_overhead_size_{20};

Envoy::Network::TransportSocketCallbacks* callbacks_{};
std::unique_ptr<TsiTransportSocketCallbacks> tsi_callbacks_;
Expand All @@ -107,6 +118,8 @@ class TsiSocket : public Network::TransportSocket,
bool handshake_complete_{};
bool end_stream_read_{};
bool read_error_{};
bool write_buffer_contains_handshake_bytes_{};
uint64_t prev_bytes_to_drain_{};
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class AltsIntegrationTestBase : public Event::TestUsingSimulatedTime,
Network::Address::InstanceConstSharedPtr address = getAddress(version_, lookupPort("http"));
auto client_transport_socket = client_alts_->createTransportSocket(nullptr);
client_tsi_socket_ = dynamic_cast<TsiSocket*>(client_transport_socket.get());
client_tsi_socket_->setActualFrameSizeToUse(16384);
client_tsi_socket_->setFrameOverheadSize(4);
return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(),
std::move(client_transport_socket), nullptr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,28 @@ class TsiFrameProtectorTest : public testing::Test {

TEST_F(TsiFrameProtectorTest, Protect) {
{
Buffer::OwnedImpl input, encrypted;
input.add("foo");
Buffer::OwnedImpl encrypted;

EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));
EXPECT_EQ(TSI_OK, frame_protector_.protect(grpc_slice_from_static_string("foo"), encrypted));
EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString());
}

{
Buffer::OwnedImpl input, encrypted;
input.add("foo");
Buffer::OwnedImpl encrypted;

EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));
EXPECT_EQ(TSI_OK, frame_protector_.protect(grpc_slice_from_static_string("foo"), encrypted));
EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString());

input.add("bar");
EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));
EXPECT_EQ(TSI_OK, frame_protector_.protect(grpc_slice_from_static_string("bar"), encrypted));
EXPECT_EQ("\x07\0\0\0foo\x07\0\0\0bar"s, encrypted.toString());
}

{
Buffer::OwnedImpl input, encrypted;
input.add(std::string(20000, 'a'));
Buffer::OwnedImpl encrypted;

EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));
EXPECT_EQ(TSI_OK,
frame_protector_.protect(
grpc_slice_from_static_string(std::string(20000, 'a').c_str()), encrypted));

// fake frame protector will split long buffer to 2 "encrypted" frames with length 16K.
std::string expected =
Expand All @@ -71,10 +69,10 @@ TEST_F(TsiFrameProtectorTest, ProtectError) {
};
raw_frame_protector_->vtable = &mock_vtable;

Buffer::OwnedImpl input, encrypted;
input.add("foo");
Buffer::OwnedImpl encrypted;

EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.protect(input, encrypted));
EXPECT_EQ(TSI_INTERNAL_ERROR,
frame_protector_.protect(grpc_slice_from_static_string("foo"), encrypted));

raw_frame_protector_->vtable = vtable;
}
Expand Down
Loading