Skip to content

Commit

Permalink
feat: rpc response message chunking
Browse files Browse the repository at this point in the history
Adds "chunking" protocol for large RPC responses.
  • Loading branch information
sdbondi committed Sep 12, 2021
1 parent a268406 commit 312aa8d
Show file tree
Hide file tree
Showing 17 changed files with 533 additions and 213 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1003,11 +1003,7 @@ impl tari_rpc::base_node_server::BaseNode for BaseNodeGrpcServer {
.state_info
.get_block_sync_info()
.map(|info| {
let node_ids = info
.sync_peers
.iter()
.map(|x| x.to_string().as_bytes().to_vec())
.collect();
let node_ids = info.sync_peers.iter().map(|x| x.to_string().into_bytes()).collect();
tari_rpc::SyncInfoResponse {
tip_height: info.tip_height,
local_height: info.local_height,
Expand Down
1 change: 0 additions & 1 deletion applications/tari_base_node/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#![recursion_limit = "1024"]
// Copyright 2019. The Tari Project
//
// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ impl wallet_server::Wallet for WalletGrpcServer {
async fn identify(&self, _: Request<GetIdentityRequest>) -> Result<Response<GetIdentityResponse>, Status> {
let identity = self.wallet.comms.node_identity();
Ok(Response::new(GetIdentityResponse {
public_key: identity.public_key().to_string().as_bytes().to_vec(),
public_key: identity.public_key().to_string().into_bytes(),
public_address: identity.public_address().to_string(),
node_id: identity.node_id().to_string().as_bytes().to_vec(),
node_id: identity.node_id().to_string().into_bytes(),
}))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1546,8 +1546,7 @@ struct KeyManagerStateUpdateSql {
impl Encryptable<Aes256Gcm> for KeyManagerStateSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), Error> {
let encrypted_master_key = encrypt_bytes_integral_nonce(&cipher, self.master_key.clone())?;
let encrypted_branch_seed =
encrypt_bytes_integral_nonce(&cipher, self.branch_seed.clone().as_bytes().to_vec())?;
let encrypted_branch_seed = encrypt_bytes_integral_nonce(&cipher, self.branch_seed.clone().into_bytes())?;
self.master_key = encrypted_master_key;
self.branch_seed = encrypted_branch_seed.to_hex();
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion base_layer/wallet/src/storage/sqlite_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ impl ClientKeyValueSql {
impl Encryptable<Aes256Gcm> for ClientKeyValueSql {
#[allow(unused_assignments)]
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_value = encrypt_bytes_integral_nonce(&cipher, self.clone().value.as_bytes().to_vec())?;
let encrypted_value = encrypt_bytes_integral_nonce(&cipher, self.value.as_bytes().to_vec())?;
self.value = encrypted_value.to_hex();
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1028,8 +1028,7 @@ impl InboundTransactionSql {

impl Encryptable<Aes256Gcm> for InboundTransactionSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_protocol =
encrypt_bytes_integral_nonce(&cipher, self.receiver_protocol.clone().as_bytes().to_vec())?;
let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.receiver_protocol.as_bytes().to_vec())?;
self.receiver_protocol = encrypted_protocol.to_hex();
Ok(())
}
Expand Down Expand Up @@ -1211,8 +1210,7 @@ impl OutboundTransactionSql {

impl Encryptable<Aes256Gcm> for OutboundTransactionSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_protocol =
encrypt_bytes_integral_nonce(&cipher, self.sender_protocol.clone().as_bytes().to_vec())?;
let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.sender_protocol.as_bytes().to_vec())?;
self.sender_protocol = encrypted_protocol.to_hex();
Ok(())
}
Expand Down Expand Up @@ -1534,8 +1532,7 @@ impl CompletedTransactionSql {

impl Encryptable<Aes256Gcm> for CompletedTransactionSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_protocol =
encrypt_bytes_integral_nonce(&cipher, self.transaction_protocol.clone().as_bytes().to_vec())?;
let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.transaction_protocol.as_bytes().to_vec())?;
self.transaction_protocol = encrypted_protocol.to_hex();
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/store_forward/saf_handler/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ mod test {
dht_header: DhtMessageHeader,
stored_at: NaiveDateTime,
) -> StoredMessage {
let body = message.as_bytes().to_vec();
let body = message.into_bytes();
let body_hash = hex::to_hex(&Challenge::new().chain(body.clone()).finalize());
StoredMessage {
id: 1,
Expand Down
4 changes: 2 additions & 2 deletions comms/src/proto/rpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ message RpcRequest {
uint64 deadline = 4;

// The message payload
bytes message = 10;
bytes payload = 10;
}

// Message type for all RPC responses
Expand All @@ -29,7 +29,7 @@ message RpcResponse {
uint32 flags = 3;

// The message payload. If the status is non-zero, this contains additional error details.
bytes message = 10;
bytes payload = 10;
}

// Message sent by the client when negotiating an RPC session. A server may close the substream if it does
Expand Down
11 changes: 7 additions & 4 deletions comms/src/protocol/rpc/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ impl BodyBytes {
pub fn into_vec(self) -> Vec<u8> {
self.0.map(|bytes| bytes.to_vec()).unwrap_or_else(Vec::new)
}

pub fn into_bytes(self) -> Option<Bytes> {
self.0
}
}

#[allow(clippy::from_over_into)]
Expand All @@ -186,10 +190,9 @@ impl Into<Bytes> for BodyBytes {
}
}

#[allow(clippy::from_over_into)]
impl Into<Vec<u8>> for BodyBytes {
fn into(self) -> Vec<u8> {
self.into_vec()
impl From<BodyBytes> for Vec<u8> {
fn from(body: BodyBytes) -> Self {
body.into_vec()
}
}

Expand Down
160 changes: 111 additions & 49 deletions comms/src/protocol/rpc/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use crate::{
Response,
RpcError,
RpcStatus,
RPC_CHUNKING_MAX_CHUNKS,
},
ProtocolId,
},
Expand Down Expand Up @@ -239,7 +240,7 @@ where TClient: From<RpcClient> + NamedProtocolService
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub struct RpcClientConfig {
pub deadline: Option<Duration>,
pub deadline_grace_period: Duration,
Expand Down Expand Up @@ -489,7 +490,8 @@ impl RpcClientWorker {
self.protocol_name(),
start.elapsed()
);
let resp = match self.read_reply().await {
let mut reader = RpcResponseReader::new(&mut self.framed, self.config, 0);
let resp = match reader.read_ack().await {
Ok(resp) => resp,
Err(RpcError::ReplyTimeout) => {
debug!(
Expand Down Expand Up @@ -529,7 +531,7 @@ impl RpcClientWorker {
Ok(())
}

#[tracing::instrument(name = "rpc_do_request_response", skip(self, reply))]
#[tracing::instrument(name = "rpc_do_request_response", skip(self, reply, request), fields(request_method = ?request.method, request_size = request.message.len()))]
async fn do_request_response(
&mut self,
request: BaseRequest<Bytes>,
Expand All @@ -542,7 +544,7 @@ impl RpcClientWorker {
method,
deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
flags: 0,
message: request.message.to_vec(),
payload: request.message.to_vec(),
};

debug!(target: LOG_TARGET, "Sending request: {}", req);
Expand Down Expand Up @@ -575,14 +577,14 @@ impl RpcClientWorker {
}

loop {
let resp = match self.read_reply().await {
let resp = match self.read_response(request_id).await {
Ok(resp) => {
let latency = start.elapsed();
event!(Level::TRACE, "Message received");
trace!(
target: LOG_TARGET,
"Received response ({} byte(s)) from request #{} (protocol = {}, method={}) in {:.0?}",
resp.message.len(),
resp.payload.len(),
request_id,
self.protocol_name(),
method,
Expand Down Expand Up @@ -617,12 +619,19 @@ impl RpcClientWorker {
break;
},
Err(err) => {
event!(Level::ERROR, "Errored:{}", err);
event!(
Level::WARN,
"Request {} (method={}) returned an error after {:.0?}: {}",
request_id,
method,
start.elapsed(),
err
);
return Err(err);
},
};

match Self::convert_to_result(resp, request_id) {
match Self::convert_to_result(resp) {
Ok(Ok(resp)) => {
// The consumer may drop the receiver before all responses are received.
// We just ignore that as we still want obey the protocol and receive messages until the FIN flag or
Expand Down Expand Up @@ -665,27 +674,10 @@ impl RpcClientWorker {
Ok(())
}

async fn read_reply(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
// Wait until the timeout, allowing an extra grace period to account for latency
let next_msg_fut = match self.config.timeout_with_grace_period() {
Some(timeout) => Either::Left(time::timeout(timeout, self.framed.next())),
None => Either::Right(self.framed.next().map(Ok)),
};

let result = tokio::select! {
biased;
_ = &mut self.shutdown_signal => {
return Err(RpcError::ClientClosed);
}
result = next_msg_fut => result,
};

match result {
Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?),
Ok(Some(Err(err))) => Err(err.into()),
Ok(None) => Err(RpcError::ServerClosedRequest),
Err(_) => Err(RpcError::ReplyTimeout),
}
async fn read_response(&mut self, request_id: u16) -> Result<proto::rpc::RpcResponse, RpcError> {
let mut reader = RpcResponseReader::new(&mut self.framed, self.config, request_id);
let resp = reader.read_response().await?;
Ok(resp)
}

fn next_request_id(&mut self) -> u16 {
Expand All @@ -695,33 +687,15 @@ impl RpcClientWorker {
next_id
}

fn convert_to_result(
resp: proto::rpc::RpcResponse,
request_id: u16,
) -> Result<Result<Response<Bytes>, RpcStatus>, RpcError> {
let resp_id = u16::try_from(resp.request_id)
.map_err(|_| RpcStatus::protocol_error(format!("invalid request_id: must be less than {}", u16::MAX)))?;

let flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8);
if flags.contains(RpcMessageFlags::ACK) {
return Err(RpcError::UnexpectedAckResponse);
}

if resp_id != request_id {
return Err(RpcError::ResponseIdDidNotMatchRequest {
expected: request_id,
actual: resp.request_id as u16,
});
}

fn convert_to_result(resp: proto::rpc::RpcResponse) -> Result<Result<Response<Bytes>, RpcStatus>, RpcError> {
let status = RpcStatus::from(&resp);
if !status.is_ok() {
return Ok(Err(status));
}

let resp = Response {
flags: resp.flags(),
message: resp.message.into(),
payload: resp.payload.into(),
};

Ok(Ok(resp))
Expand All @@ -736,3 +710,91 @@ pub enum ClientRequest {
GetLastRequestLatency(oneshot::Sender<Option<Duration>>),
SendPing(oneshot::Sender<Result<Duration, RpcStatus>>),
}

struct RpcResponseReader<'a> {
framed: &'a mut CanonicalFraming<Substream>,
config: RpcClientConfig,
request_id: u16,
}
impl<'a> RpcResponseReader<'a> {
pub fn new(framed: &'a mut CanonicalFraming<Substream>, config: RpcClientConfig, request_id: u16) -> Self {
Self {
framed,
config,
request_id,
}
}

pub async fn read_response(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
let mut resp = self.next().await?;
self.check_response(&resp)?;
let mut chunk_count = 1;
let mut last_chunk_flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8);
let mut last_chunk_size = resp.payload.len();
loop {
trace!(
target: LOG_TARGET,
"Chunk {} received (flags={:?}, {} bytes, {} total)",
chunk_count,
last_chunk_flags,
last_chunk_size,
resp.payload.len()
);
if !last_chunk_flags.is_more() {
return Ok(resp);
}

if chunk_count >= RPC_CHUNKING_MAX_CHUNKS {
return Err(RpcError::ExceededMaxChunkCount {
expected: RPC_CHUNKING_MAX_CHUNKS,
});
}

let msg = self.next().await?;
last_chunk_flags = RpcMessageFlags::from_bits_truncate(msg.flags as u8);
last_chunk_size = msg.payload.len();
self.check_response(&resp)?;
resp.payload.extend(msg.payload);
chunk_count += 1;
}
}

pub async fn read_ack(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
let resp = self.next().await?;
Ok(resp)
}

fn check_response(&self, resp: &proto::rpc::RpcResponse) -> Result<(), RpcError> {
let resp_id = u16::try_from(resp.request_id)
.map_err(|_| RpcStatus::protocol_error(format!("invalid request_id: must be less than {}", u16::MAX)))?;

let flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8);
if flags.contains(RpcMessageFlags::ACK) {
return Err(RpcError::UnexpectedAckResponse);
}

if resp_id != self.request_id {
return Err(RpcError::ResponseIdDidNotMatchRequest {
expected: self.request_id,
actual: resp.request_id as u16,
});
}

Ok(())
}

async fn next(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
// Wait until the timeout, allowing an extra grace period to account for latency
let next_msg_fut = match self.config.timeout_with_grace_period() {
Some(timeout) => Either::Left(time::timeout(timeout, self.framed.next())),
None => Either::Right(self.framed.next().map(Ok)),
};

match next_msg_fut.await {
Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?),
Ok(Some(Err(err))) => Err(err.into()),
Ok(None) => Err(RpcError::ServerClosedRequest),
Err(_) => Err(RpcError::ReplyTimeout),
}
}
}
2 changes: 2 additions & 0 deletions comms/src/protocol/rpc/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ pub enum RpcError {
InvalidPingResponse,
#[error("Unexpected ACK response. This is likely because of a previous ACK timeout")]
UnexpectedAckResponse,
#[error("Attempted to send more than {expected} payload chunks")]
ExceededMaxChunkCount { expected: usize },
#[error(transparent)]
UnknownError(#[from] anyhow::Error),
}
Expand Down
Loading

0 comments on commit 312aa8d

Please sign in to comment.