Skip to content

Commit

Permalink
Implement S/G IO for batched sends and eliminate another frame copy (L…
Browse files Browse the repository at this point in the history
  • Loading branch information
cgutman authored and KuleRucket committed Oct 9, 2024
1 parent ed2b1f9 commit a8c5067
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 56 deletions.
15 changes: 9 additions & 6 deletions src/crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ namespace crypto {
* The resulting ciphertext and the GCM tag are written into the tagged_cipher buffer.
*/
int
gcm_t::encrypt(const std::string_view &plaintext, std::uint8_t *tagged_cipher, aes_t *iv) {
gcm_t::encrypt(const std::string_view &plaintext, std::uint8_t *tag, std::uint8_t *ciphertext, aes_t *iv) {
if (!encrypt_ctx && init_encrypt_gcm(encrypt_ctx, &key, iv, padding)) {
return -1;
}
Expand All @@ -196,18 +196,15 @@ namespace crypto {
return -1;
}

auto tag = tagged_cipher;
auto cipher = tag + tag_size;

int update_outlen, final_outlen;

// Encrypt into the caller's buffer
if (EVP_EncryptUpdate(encrypt_ctx.get(), cipher, &update_outlen, (const std::uint8_t *) plaintext.data(), plaintext.size()) != 1) {
if (EVP_EncryptUpdate(encrypt_ctx.get(), ciphertext, &update_outlen, (const std::uint8_t *) plaintext.data(), plaintext.size()) != 1) {
return -1;
}

// GCM encryption won't ever fill ciphertext here but we have to call it anyway
if (EVP_EncryptFinal_ex(encrypt_ctx.get(), cipher + update_outlen, &final_outlen) != 1) {
if (EVP_EncryptFinal_ex(encrypt_ctx.get(), ciphertext + update_outlen, &final_outlen) != 1) {
return -1;
}

Expand All @@ -218,6 +215,12 @@ namespace crypto {
return update_outlen + final_outlen;
}

int
gcm_t::encrypt(const std::string_view &plaintext, std::uint8_t *tagged_cipher, aes_t *iv) {
// This overload handles the common case of [GCM tag][cipher text] buffer layout
return encrypt(plaintext, tagged_cipher, tagged_cipher + tag_size, iv);
}

int
ecb_t::decrypt(const std::string_view &cipher, std::vector<std::uint8_t> &plaintext) {
auto fg = util::fail_guard([this]() {
Expand Down
11 changes: 11 additions & 0 deletions src/crypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ namespace crypto {

gcm_t(const crypto::aes_t &key, bool padding = true);

/**
* @brief Encrypts the plaintext using AES GCM mode.
* @param plaintext The plaintext data to be encrypted.
* @param tag The buffer where the GCM tag will be written.
* @param ciphertext The buffer where the resulting ciphertext will be written.
* @param iv The initialization vector to be used for the encryption.
* @return The total length of the ciphertext and GCM tag. Returns -1 in case of an error.
*/
int
encrypt(const std::string_view &plaintext, std::uint8_t *tag, std::uint8_t *ciphertext, aes_t *iv);

/**
* @brief Encrypts the plaintext using AES GCM mode.
* length of cipher must be at least: round_to_pkcs7_padded(plaintext.size()) + crypto::cipher::tag_size
Expand Down
49 changes: 47 additions & 2 deletions src/platform/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,15 +606,60 @@ namespace platf {
void
restart();

struct batched_send_info_t {
struct buffer_descriptor_t {
const char *buffer;
size_t block_size;
size_t size;

// Constructors required for emplace_back() prior to C++20
buffer_descriptor_t(const char *buffer, size_t size):
buffer(buffer), size(size) {}
buffer_descriptor_t():
buffer(nullptr), size(0) {}
};

struct batched_send_info_t {
// Optional headers to be prepended to each packet
const char *headers;
size_t header_size;

// One or more data buffers to use for the payloads
//
// NB: Data buffers must be aligned to payload size!
std::vector<buffer_descriptor_t> &payload_buffers;
size_t payload_size;

// The offset (in header+payload message blocks) in the header and payload
// buffers to begin sending messages from
size_t block_offset;

// The number of header+payload message blocks to send
size_t block_count;

std::uintptr_t native_socket;
boost::asio::ip::address &target_address;
uint16_t target_port;
boost::asio::ip::address &source_address;

/**
* @brief Returns a payload buffer descriptor for the given payload offset.
* @param offset The offset in the total payload data (bytes).
* @return Buffer descriptor describing the region at the given offset.
*/
buffer_descriptor_t
buffer_for_payload_offset(ptrdiff_t offset) {
for (const auto &desc : payload_buffers) {
if (offset < desc.size) {
return {
desc.buffer + offset,
desc.size - offset,
};
}
else {
offset -= desc.size;
}
}
return {};
}
};
bool
send_batch(batched_send_info_t &send_info);
Expand Down
70 changes: 52 additions & 18 deletions src/platform/linux/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,30 +433,56 @@ namespace platf {
memcpy(CMSG_DATA(pktinfo_cm), &pktInfo, sizeof(pktInfo));
}

auto const max_iovs_per_msg = send_info.payload_buffers.size() + (send_info.headers ? 1 : 0);

#ifdef UDP_SEGMENT
{
struct iovec iov = {};

msg.msg_iov = &iov;
msg.msg_iovlen = 1;

// UDP GSO on Linux currently only supports sending 64K or 64 segments at a time
size_t seg_index = 0;
const size_t seg_max = 65536 / 1500;
struct iovec iovs[(send_info.headers ? std::min(seg_max, send_info.block_count) : 1) * max_iovs_per_msg] = {};
auto msg_size = send_info.header_size + send_info.payload_size;
while (seg_index < send_info.block_count) {
iov.iov_base = (void *) &send_info.buffer[seg_index * send_info.block_size];
iov.iov_len = send_info.block_size * std::min(send_info.block_count - seg_index, seg_max);
int iovlen = 0;
auto segs_in_batch = std::min(send_info.block_count - seg_index, seg_max);
if (send_info.headers) {
// Interleave iovs for headers and payloads
for (auto i = 0; i < segs_in_batch; i++) {
iovs[iovlen].iov_base = (void *) &send_info.headers[(send_info.block_offset + seg_index + i) * send_info.header_size];
iovs[iovlen].iov_len = send_info.header_size;
iovlen++;
auto payload_desc = send_info.buffer_for_payload_offset((send_info.block_offset + seg_index + i) * send_info.payload_size);
iovs[iovlen].iov_base = (void *) payload_desc.buffer;
iovs[iovlen].iov_len = send_info.payload_size;
iovlen++;
}
}
else {
// Translate buffer descriptors into iovs
auto payload_offset = (send_info.block_offset + seg_index) * send_info.payload_size;
auto payload_length = payload_offset + (segs_in_batch * send_info.payload_size);
while (payload_offset < payload_length) {
auto payload_desc = send_info.buffer_for_payload_offset(payload_offset);
iovs[iovlen].iov_base = (void *) payload_desc.buffer;
iovs[iovlen].iov_len = std::min(payload_desc.size, payload_length - payload_offset);
payload_offset += iovs[iovlen].iov_len;
iovlen++;
}
}

msg.msg_iov = iovs;
msg.msg_iovlen = iovlen;

// We should not use GSO if the data is <= one full block size
if (iov.iov_len > send_info.block_size) {
if (segs_in_batch > 1) {
msg.msg_controllen = cmbuflen + CMSG_SPACE(sizeof(uint16_t));

// Enable GSO to perform segmentation of our buffer for us
auto cm = CMSG_NXTHDR(&msg, pktinfo_cm);
cm->cmsg_level = SOL_UDP;
cm->cmsg_type = UDP_SEGMENT;
cm->cmsg_len = CMSG_LEN(sizeof(uint16_t));
*((uint16_t *) CMSG_DATA(cm)) = send_info.block_size;
*((uint16_t *) CMSG_DATA(cm)) = msg_size;
}
else {
msg.msg_controllen = cmbuflen;
Expand All @@ -483,10 +509,11 @@ namespace platf {
continue;
}

BOOST_LOG(verbose) << "sendmsg() failed: "sv << errno;
break;
}

seg_index += bytes_sent / send_info.block_size;
seg_index += bytes_sent / msg_size;
}

// If we sent something, return the status and don't fall back to the non-GSO path.
Expand All @@ -498,18 +525,25 @@ namespace platf {

{
// If GSO is not supported, use sendmmsg() instead.
struct mmsghdr msgs[send_info.block_count];
struct iovec iovs[send_info.block_count];
struct mmsghdr msgs[send_info.block_count] = {};
struct iovec iovs[send_info.block_count * (send_info.headers ? 2 : 1)] = {};
int iov_idx = 0;
for (size_t i = 0; i < send_info.block_count; i++) {
iovs[i] = {};
iovs[i].iov_base = (void *) &send_info.buffer[i * send_info.block_size];
iovs[i].iov_len = send_info.block_size;
msgs[i].msg_hdr.msg_iov = &iovs[iov_idx];
msgs[i].msg_hdr.msg_iovlen = send_info.headers ? 2 : 1;

if (send_info.headers) {
iovs[iov_idx].iov_base = (void *) &send_info.headers[(send_info.block_offset + i) * send_info.header_size];
iovs[iov_idx].iov_len = send_info.header_size;
iov_idx++;
}
auto payload_desc = send_info.buffer_for_payload_offset((send_info.block_offset + i) * send_info.payload_size);
iovs[iov_idx].iov_base = (void *) payload_desc.buffer;
iovs[iov_idx].iov_len = send_info.payload_size;
iov_idx++;

msgs[i] = {};
msgs[i].msg_hdr.msg_name = msg.msg_name;
msgs[i].msg_hdr.msg_namelen = msg.msg_namelen;
msgs[i].msg_hdr.msg_iov = &iovs[i];
msgs[i].msg_hdr.msg_iovlen = 1;
msgs[i].msg_hdr.msg_control = cmbuf.buf;
msgs[i].msg_hdr.msg_controllen = cmbuflen;
}
Expand Down
37 changes: 31 additions & 6 deletions src/platform/windows/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1452,12 +1452,37 @@ namespace platf {
msg.namelen = sizeof(taddr_v4);
}

WSABUF buf;
buf.buf = (char *) send_info.buffer;
buf.len = send_info.block_size * send_info.block_count;
auto const max_bufs_per_msg = send_info.payload_buffers.size() + (send_info.headers ? 1 : 0);

msg.lpBuffers = &buf;
msg.dwBufferCount = 1;
WSABUF bufs[(send_info.headers ? send_info.block_count : 1) * max_bufs_per_msg];
DWORD bufcount = 0;
if (send_info.headers) {
// Interleave buffers for headers and payloads
for (auto i = 0; i < send_info.block_count; i++) {
bufs[bufcount].buf = (char *) &send_info.headers[(send_info.block_offset + i) * send_info.header_size];
bufs[bufcount].len = send_info.header_size;
bufcount++;
auto payload_desc = send_info.buffer_for_payload_offset((send_info.block_offset + i) * send_info.payload_size);
bufs[bufcount].buf = (char *) payload_desc.buffer;
bufs[bufcount].len = send_info.payload_size;
bufcount++;
}
}
else {
// Translate buffer descriptors into WSABUFs
auto payload_offset = send_info.block_offset * send_info.payload_size;
auto payload_length = payload_offset + (send_info.block_count * send_info.payload_size);
while (payload_offset < payload_length) {
auto payload_desc = send_info.buffer_for_payload_offset(payload_offset);
bufs[bufcount].buf = (char *) payload_desc.buffer;
bufs[bufcount].len = std::min(payload_desc.size, payload_length - payload_offset);
payload_offset += bufs[bufcount].len;
bufcount++;
}
}

msg.lpBuffers = bufs;
msg.dwBufferCount = bufcount;
msg.dwFlags = 0;

// At most, one DWORD option and one PKTINFO option
Expand Down Expand Up @@ -1505,7 +1530,7 @@ namespace platf {
cm->cmsg_level = IPPROTO_UDP;
cm->cmsg_type = UDP_SEND_MSG_SIZE;
cm->cmsg_len = WSA_CMSG_LEN(sizeof(DWORD));
*((DWORD *) WSA_CMSG_DATA(cm)) = send_info.block_size;
*((DWORD *) WSA_CMSG_DATA(cm)) = send_info.header_size + send_info.payload_size;
}

msg.Control.len = cmbuflen;
Expand Down
Loading

0 comments on commit a8c5067

Please sign in to comment.