Skip to content

Commit

Permalink
dnsdist: Avoid a few more allocations in the DoQ code
Browse files Browse the repository at this point in the history
  • Loading branch information
rgacogne committed Dec 26, 2023
1 parent ddc643a commit 4cfae5d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 42 deletions.
34 changes: 19 additions & 15 deletions pdns/dnsdistdist/doh3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ static void flushStalledResponses(H3Connection& conn)
}
}

static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
{
auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
DEBUGLOG(msg);
Expand Down Expand Up @@ -719,7 +719,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
handleImmediateError("Unsupported HTTP method");
}

static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event, PacketBuffer& buffer)
{
auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
DEBUGLOG(msg);
Expand All @@ -739,14 +739,14 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
return;
}

PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
buffer.resize(std::numeric_limits<uint16_t>::max());
auto& streamBuffer = conn.d_streamBuffers[streamID];

while (true) {
buffer.resize(std::numeric_limits<uint16_t>::max());
ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
conn.d_conn.get(), streamID,
buffer.data(), buffer.capacity());
buffer.data(), buffer.size());

if (len <= 0) {
break;
Expand All @@ -771,7 +771,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
conn.d_streamBuffers.erase(streamID);
}

static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID)
static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, PacketBuffer& buffer)
{
std::map<std::string, std::string> headers;
while (true) {
Expand All @@ -791,7 +791,7 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3
break;
}
case QUICHE_H3_EVENT_DATA: {
processH3DataEvent(clientState, frontend, conn, client, serverConnID, headers, streamID, event);
processH3DataEvent(clientState, frontend, conn, client, serverConnID, headers, streamID, event, buffer);
break;
}
case QUICHE_H3_EVENT_FINISHED:
Expand All @@ -807,6 +807,11 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3

static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientState, Socket& sock, PacketBuffer& buffer)
{
// destination connection ID, will have to be sent as original destination connection ID
PacketBuffer serverConnID;
// source connection ID, will have to be sent as destination connection ID
PacketBuffer clientConnID;
PacketBuffer tokenBuf;
while (true) {
ComboAddress client;
buffer.resize(4096);
Expand Down Expand Up @@ -834,29 +839,28 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
continue;
}

// destination connection ID, will have to be sent as original destination connection ID
PacketBuffer serverConnID(dcid.begin(), dcid.begin() + dcid_len);
serverConnID.assign(dcid.begin(), dcid.begin() + dcid_len);
// source connection ID, will have to be sent as destination connection ID
PacketBuffer clientConnID(scid.begin(), scid.begin() + scid_len);
clientConnID.assign(scid.begin(), scid.begin() + scid_len);
auto conn = getConnection(frontend.d_server_config->d_connections, serverConnID);

if (!conn) {
DEBUGLOG("Connection not found");
if (!quiche_version_is_supported(version)) {
DEBUGLOG("Unsupported version");
++frontend.d_doh3UnsupportedVersionErrors;
handleVersionNegociation(sock, clientConnID, serverConnID, client);
handleVersionNegociation(sock, clientConnID, serverConnID, client, buffer);
continue;
}

if (token_len == 0) {
/* stateless retry */
DEBUGLOG("No token received");
handleStatelessRetry(sock, clientConnID, serverConnID, client, version);
handleStatelessRetry(sock, clientConnID, serverConnID, client, version, buffer);
continue;
}

PacketBuffer tokenBuf(token.begin(), token.begin() + token_len);
tokenBuf.assign(token.begin(), token.begin() + token_len);
auto originalDestinationID = validateToken(tokenBuf, client);
if (!originalDestinationID) {
++frontend.d_doh3InvalidTokensReceived;
Expand Down Expand Up @@ -897,9 +901,9 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
DEBUGLOG("Successfully created HTTP/3 connection");
}

processH3Events(clientState, frontend, conn->get(), client, serverConnID);
processH3Events(clientState, frontend, conn->get(), client, serverConnID, buffer);

flushEgress(sock, conn->get().d_conn, client);
flushEgress(sock, conn->get().d_conn, client, buffer);
}
else {
DEBUGLOG("Connection not established");
Expand Down Expand Up @@ -944,7 +948,7 @@ void doh3Thread(ClientState* clientState)
for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
quiche_conn_on_timeout(conn->second.d_conn.get());

flushEgress(sock, conn->second.d_conn, conn->second.d_peer);
flushEgress(sock, conn->second.d_conn, conn->second.d_peer, buffer);

if (quiche_conn_is_closed(conn->second.d_conn.get())) {
#ifdef DEBUGLOG_ENABLED
Expand Down
33 changes: 18 additions & 15 deletions pdns/dnsdistdist/doq-common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
const auto encryptedToken = sodEncryptSym(std::string_view(reinterpret_cast<const char*>(plainTextToken.data()), plainTextToken.size()), s_quicRetryTokenKey, nonce, false);
// a bit sad, let's see if we can do better later
auto encryptedTokenPacket = PacketBuffer(encryptedToken.begin(), encryptedToken.end());
PacketBuffer encryptedTokenPacket;
encryptedTokenPacket.reserve(encryptedToken.size() + nonce.value.size());
encryptedTokenPacket.insert(encryptedTokenPacket.begin(), encryptedToken.begin(), encryptedToken.end());
encryptedTokenPacket.insert(encryptedTokenPacket.begin(), nonce.value.begin(), nonce.value.end());
return encryptedTokenPacket;
}
Expand Down Expand Up @@ -98,7 +100,7 @@ std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const Combo

memcpy(nonce.value.data(), token.data(), nonce.value.size());

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
auto cipher = std::string_view(reinterpret_cast<const char*>(&token.at(nonce.value.size())), token.size() - nonce.value.size());
auto plainText = sodDecryptSym(cipher, s_quicRetryTokenKey, nonce, false);

Expand All @@ -124,7 +126,7 @@ std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const Combo
}
}

void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version)
void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version, PacketBuffer& buffer)
{
auto newServerConnID = getCID();
if (!newServerConnID) {
Expand All @@ -133,46 +135,46 @@ void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const

auto token = mintToken(serverConnID, peer);

PacketBuffer out(MAX_DATAGRAM_SIZE);
buffer.resize(MAX_DATAGRAM_SIZE);
auto written = quiche_retry(clientConnID.data(), clientConnID.size(),
serverConnID.data(), serverConnID.size(),
newServerConnID->data(), newServerConnID->size(),
token.data(), token.size(),
version,
out.data(), out.size());
buffer.data(), buffer.size());

if (written < 0) {
DEBUGLOG("failed to create retry packet " << written);
return;
}

out.resize(written);
sock.sendTo(std::string(out.begin(), out.end()), peer);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
sock.sendTo(reinterpret_cast<const char*>(buffer.data()), static_cast<size_t>(written), peer);
}

void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer)
void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, PacketBuffer& buffer)
{
PacketBuffer out(MAX_DATAGRAM_SIZE);
buffer.resize(MAX_DATAGRAM_SIZE);

auto written = quiche_negotiate_version(clientConnID.data(), clientConnID.size(),
serverConnID.data(), serverConnID.size(),
out.data(), out.size());
buffer.data(), buffer.size());

if (written < 0) {
DEBUGLOG("failed to create vneg packet " << written);
return;
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
sock.sendTo(reinterpret_cast<const char*>(out.data()), written, peer);
sock.sendTo(reinterpret_cast<const char*>(buffer.data()), static_cast<size_t>(written), peer);
}

void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer)
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, PacketBuffer& buffer)
{
std::array<uint8_t, MAX_DATAGRAM_SIZE> out{};
buffer.resize(MAX_DATAGRAM_SIZE);
quiche_send_info send_info;

while (true) {
auto written = quiche_conn_send(conn.get(), out.data(), out.size(), &send_info);
auto written = quiche_conn_send(conn.get(), buffer.data(), buffer.size(), &send_info);
if (written == QUICHE_ERR_DONE) {
return;
}
Expand All @@ -182,7 +184,7 @@ void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer)
}
// FIXME pacing (as send_info.at should tell us when to send the packet) ?
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
sock.sendTo(reinterpret_cast<const char*>(out.data()), written, peer);
sock.sendTo(reinterpret_cast<const char*>(buffer.data()), static_cast<size_t>(written), peer);
}
}

Expand All @@ -203,6 +205,7 @@ void configureQuiche(QuicheConfig& config, const QuicheParams& params)

{
auto res = quiche_config_set_application_protos(config.get(),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
reinterpret_cast<const uint8_t*>(params.d_alpn.data()),
params.d_alpn.size());
if (res != 0) {
Expand Down
6 changes: 3 additions & 3 deletions pdns/dnsdistdist/doq-common.hh
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ void fillRandom(PacketBuffer& buffer, size_t size);
std::optional<PacketBuffer> getCID();
PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer);
std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const ComboAddress& peer);
void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version);
void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer);
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer);
void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version, PacketBuffer& buffer);
void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, PacketBuffer& buffer);
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, PacketBuffer& buffer);
void configureQuiche(QuicheConfig& config, const QuicheParams& params);

};
Expand Down
21 changes: 12 additions & 9 deletions pdns/dnsdistdist/doq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,11 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState

static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState, Socket& sock, PacketBuffer& buffer)
{
// destination connection ID, will have to be sent as original destination connection ID
PacketBuffer serverConnID;
// source connection ID, will have to be sent as destination connection ID
PacketBuffer clientConnID;
PacketBuffer tokenBuf;
while (true) {
ComboAddress client;
buffer.resize(4096);
Expand Down Expand Up @@ -655,29 +660,27 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState
continue;
}

// destination connection ID, will have to be sent as original destination connection ID
PacketBuffer serverConnID(dcid.begin(), dcid.begin() + dcid_len);
// source connection ID, will have to be sent as destination connection ID
PacketBuffer clientConnID(scid.begin(), scid.begin() + scid_len);
serverConnID.assign(dcid.begin(), dcid.begin() + dcid_len);
clientConnID.assign(scid.begin(), scid.begin() + scid_len);
auto conn = getConnection(frontend.d_server_config->d_connections, serverConnID);

if (!conn) {
DEBUGLOG("Connection not found");
if (!quiche_version_is_supported(version)) {
DEBUGLOG("Unsupported version");
++frontend.d_doqUnsupportedVersionErrors;
handleVersionNegociation(sock, clientConnID, serverConnID, client);
handleVersionNegociation(sock, clientConnID, serverConnID, client, buffer);
continue;
}

if (token_len == 0) {
/* stateless retry */
DEBUGLOG("No token received");
handleStatelessRetry(sock, clientConnID, serverConnID, client, version);
handleStatelessRetry(sock, clientConnID, serverConnID, client, version, buffer);
continue;
}

PacketBuffer tokenBuf(token.begin(), token.begin() + token_len);
tokenBuf.assign(token.begin(), token.begin() + token_len);
auto originalDestinationID = validateToken(tokenBuf, client);
if (!originalDestinationID) {
++frontend.d_doqInvalidTokensReceived;
Expand Down Expand Up @@ -714,7 +717,7 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState
handleReadableStream(frontend, clientState, *conn, streamID, client, serverConnID);
}

flushEgress(sock, conn->get().d_conn, client);
flushEgress(sock, conn->get().d_conn, client, buffer);
}
else {
DEBUGLOG("Connection not established");
Expand Down Expand Up @@ -759,7 +762,7 @@ void doqThread(ClientState* clientState)
for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
quiche_conn_on_timeout(conn->second.d_conn.get());

flushEgress(sock, conn->second.d_conn, conn->second.d_peer);
flushEgress(sock, conn->second.d_conn, conn->second.d_peer, buffer);

if (quiche_conn_is_closed(conn->second.d_conn.get())) {
#ifdef DEBUGLOG_ENABLED
Expand Down

0 comments on commit 4cfae5d

Please sign in to comment.