Skip to content

Commit

Permalink
fix(server): fix leak in FuturesUnordered (#1204)
Browse files Browse the repository at this point in the history
* fix: remove needless clone in ws background task

* fix(server): fix leak in FuturesUnordered

The tokio::spawn handles were never removed from `FutursUnordered`
which this commit fixes.

Reduces the memory usage signficantly but still slightly worse than v0.16.x

* Update server/src/transport/ws.rs

* cargo fmt

* wording
  • Loading branch information
niklasad1 authored Sep 15, 2023
1 parent 5be9866 commit 30c0fbb
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::PingConfig;

use futures_util::future::{self, Either, Fuse};
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::{FuturesOrdered, FuturesUnordered};
use futures_util::stream::FuturesOrdered;
use futures_util::{Future, FutureExt, StreamExt};
use hyper::upgrade::Upgraded;
use jsonrpsee_core::server::helpers::{
Expand Down Expand Up @@ -248,7 +248,11 @@ pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Rec
let (conn_tx, conn_rx) = oneshot::channel();
let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection);
let pending_calls = FuturesUnordered::new();

// On each method call the `pending_calls` is cloned
// then when all pending_calls are dropped
// a graceful shutdown can has occur.
let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);

// Spawn another task that sends out the responses on the Websocket.
let send_task_handle = tokio::spawn(send_task(rx, sender, ping_config.ping_interval(), conn_rx));
Expand Down Expand Up @@ -320,13 +324,14 @@ pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Rec
}
};

pending_calls.push(tokio::spawn(execute_unchecked_call(params.clone(), std::mem::take(&mut data))));
tokio::spawn(execute_unchecked_call(params.clone(), std::mem::take(&mut data), pending_calls.clone()));
};

// Drive all running methods to completion.
// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
// proper drop behaviour.
graceful_shutdown(result, pending_calls, receiver, data, conn_tx, send_task_handle).await;
drop(pending_calls);
graceful_shutdown(result, pending_calls_completed, receiver, data, conn_tx, send_task_handle).await;

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);
Expand Down Expand Up @@ -492,7 +497,11 @@ struct ExecuteCallParams<L: Logger> {
bounded_subscriptions: BoundedSubscriptions,
}

async fn execute_unchecked_call<L: Logger>(params: Arc<ExecuteCallParams<L>>, data: Vec<u8>) {
async fn execute_unchecked_call<L: Logger>(
params: Arc<ExecuteCallParams<L>>,
data: Vec<u8>,
drop_on_completion: mpsc::Sender<()>,
) {
let request_start = params.logger.on_request(TransportProtocol::WebSocket);
let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace());

Expand Down Expand Up @@ -550,6 +559,10 @@ async fn execute_unchecked_call<L: Logger>(params: Arc<ExecuteCallParams<L>>, da
_ = params.sink.send_error(Id::Null, ErrorCode::ParseError.into()).await;
}
};

// NOTE: This channel is only used to indicate that a method call was completed
// thus the drop here tells the main task that method call was completed.
drop(drop_on_completion);
}

#[derive(Debug, Copy, Clone)]
Expand All @@ -561,14 +574,16 @@ pub(crate) enum Shutdown {
/// Enforce a graceful shutdown.
///
/// This will return once the connection has been terminated or all pending calls have been executed.
async fn graceful_shutdown<F: Future>(
async fn graceful_shutdown(
result: Result<Shutdown, SokettoError>,
pending_calls: FuturesUnordered<F>,
pending_calls: mpsc::Receiver<()>,
receiver: Receiver,
data: Vec<u8>,
mut conn_tx: oneshot::Sender<()>,
send_task_handle: tokio::task::JoinHandle<()>,
) {
let pending_calls = ReceiverStream::new(pending_calls);

match result {
Ok(Shutdown::ConnectionClosed) | Err(SokettoError::Closed) => (),
Ok(Shutdown::Stopped) | Err(_) => {
Expand Down

0 comments on commit 30c0fbb

Please sign in to comment.