diff --git a/crates/integration_tests/tests/ipc/native_bridge_list_accounts.rs b/crates/integration_tests/tests/ipc/native_bridge_list_accounts.rs index ead1a665b3..c78f251d3b 100644 --- a/crates/integration_tests/tests/ipc/native_bridge_list_accounts.rs +++ b/crates/integration_tests/tests/ipc/native_bridge_list_accounts.rs @@ -1,5 +1,4 @@ use anyhow::Result; -use http::Method; use http::StatusCode; use sos_ipc::{ local_transport::{HttpMessage, LocalRequest}, @@ -83,16 +82,12 @@ async fn integration_ipc_native_bridge_list_accounts() -> Result<()> { tokio::time::sleep(Duration::from_millis(50)).await; - let mut request = LocalRequest { - method: Method::GET, - uri: ACCOUNTS_LIST.parse().unwrap(), - ..Default::default() - }; + let mut request = LocalRequest::get(ACCOUNTS_LIST.parse().unwrap()); request.set_request_id(1); let (command, arguments) = super::native_bridge_cmd(SOCKET_NAME); let mut client = NativeBridgeClient::new(command, arguments).await?; - let response = client.send(&request).await?; + let response = client.send(request).await?; assert_eq!(StatusCode::OK, response.status().unwrap()); client.kill().await?; diff --git a/crates/integration_tests/tests/ipc/native_bridge_probe.rs b/crates/integration_tests/tests/ipc/native_bridge_probe.rs index 74e375994f..7875df43a3 100644 --- a/crates/integration_tests/tests/ipc/native_bridge_probe.rs +++ b/crates/integration_tests/tests/ipc/native_bridge_probe.rs @@ -1,5 +1,4 @@ use anyhow::Result; -use http::Method; use http::StatusCode; use sos_ipc::{ local_transport::{HttpMessage, LocalRequest}, @@ -15,16 +14,12 @@ const SOCKET_NAME: &str = "ipc_native_bridge_probe.sock"; async fn integration_ipc_native_bridge_probe() -> Result<()> { // crate::test_utils::init_tracing(); - let mut request = LocalRequest { - method: Method::GET, - uri: "/probe".parse().unwrap(), - ..Default::default() - }; + let mut request = LocalRequest::get("/probe".parse().unwrap()); request.set_request_id(1); let (command, arguments) = super::native_bridge_cmd(SOCKET_NAME); let mut client = NativeBridgeClient::new(command, arguments).await?; - let response = client.send(&request).await?; + let response = client.send(request).await?; assert_eq!(StatusCode::OK, response.status().unwrap()); client.kill().await?; diff --git a/crates/ipc/src/client.rs b/crates/ipc/src/client.rs index 868e9bae67..f9ddf41a2d 100644 --- a/crates/ipc/src/client.rs +++ b/crates/ipc/src/client.rs @@ -82,10 +82,7 @@ impl LocalSocketClient { /// List accounts. pub async fn list_accounts(&mut self) -> Result> { - let request = LocalRequest { - uri: ACCOUNTS_LIST.parse()?, - ..Default::default() - }; + let request = LocalRequest::get(ACCOUNTS_LIST.parse()?); let response = self.send_request(request).await?; let status = response.status()?; diff --git a/crates/ipc/src/local_transport.rs b/crates/ipc/src/local_transport.rs index 9c4021bf1b..35c4d528b1 100644 --- a/crates/ipc/src/local_transport.rs +++ b/crates/ipc/src/local_transport.rs @@ -1,4 +1,9 @@ -//! Types used for communicating between apps on the same device. +//! Types used for HTTP communication between apps +//! on the same device. +//! +//! Wraps the `http` request and response types so we can +//! serialize and deserialize from JSON for transfer via +//! the browser native messaging API. use crate::{Error, Result}; @@ -39,9 +44,32 @@ pub trait HttpMessage { /// Message body. fn body(&self) -> &[u8]; + /// Mutable message body. + fn body_mut(&mut self) -> &mut Vec; + /// Consume the message body. fn into_body(self) -> Vec; + /// Number of chunks. + fn chunks_len(&self) -> u32; + + /// Zero-based chunk index of this message. + fn chunk_index(&self) -> u32; + + /// Convert this message into a collection of chunks. + /// + /// If the size of the body is less than limit then + /// only this message is included. + /// + /// Conversion is performed on the number of bytes in the + /// body but the native messaging API restricts the serialized + /// JSON to 1MB so it's wise to choose a value smaller + /// than the 1MB limit so there is some headroom for the JSON + /// serialization overhead. + fn into_chunks(self, limit: usize, chunk_size: usize) -> Vec + where + Self: Sized; + /// Extract a request id. /// /// If no header is present or the value is invalid zero @@ -125,6 +153,16 @@ pub trait HttpMessage { } */ + /// Convert the message into parts. + fn into_parts(mut self) -> (Headers, Vec) + where + Self: Sized, + { + let headers = + std::mem::replace(self.headers_mut(), Default::default()); + (headers, self.into_body()) + } + /// Convert the body to bytes. fn bytes(self) -> Bytes where @@ -132,6 +170,25 @@ pub trait HttpMessage { { self.into_body().into() } + + /// Convert from a collection of chunks into a response. + /// + /// # Panics + /// + /// If chunks is empty. + fn from_chunks(mut chunks: Vec) -> Self + where + Self: Sized, + { + chunks.sort_by(|a, b| a.chunk_index().cmp(&b.chunk_index())); + let mut it = chunks.into_iter(); + let mut message = it.next().expect("to have one chunk"); + for chunk in it { + let mut body = chunk.into_body(); + message.body_mut().append(&mut body); + } + message + } } /// Request that can be sent to a local data source. @@ -158,6 +215,8 @@ pub struct LocalRequest { /// Request body. #[serde(default, skip_serializing_if = "Vec::is_empty")] pub body: Vec, + /// Chunk information; length and then index. + chunks: (u32, u32), } impl Default for LocalRequest { @@ -167,11 +226,23 @@ impl Default for LocalRequest { uri: Uri::builder().path_and_query("/").build().unwrap(), headers: Default::default(), body: Default::default(), + chunks: (1, 0), } } } impl LocalRequest { + /// Create a GET request from a URI. + pub fn get(uri: Uri) -> Self { + Self { + method: Method::GET, + uri, + headers: Default::default(), + body: Default::default(), + chunks: (1, 0), + } + } + /// Duration allowed for a request. pub fn timeout_duration(&self) -> Duration { Duration::from_secs(15) @@ -191,9 +262,54 @@ impl HttpMessage for LocalRequest { self.body.as_slice() } + fn body_mut(&mut self) -> &mut Vec { + &mut self.body + } + fn into_body(self) -> Vec { self.body } + + fn chunks_len(&self) -> u32 { + self.chunks.0 + } + + fn chunk_index(&self) -> u32 { + self.chunks.1 + } + + fn into_chunks(self, limit: usize, chunk_size: usize) -> Vec { + if self.body.len() < limit { + vec![self] + } else { + let mut messages = Vec::new(); + let uri = self.uri.clone(); + let method = self.method.clone(); + let (headers, body) = self.into_parts(); + let len = if body.len() > chunk_size { + let mut len = body.len() / chunk_size; + if body.len() % chunk_size != 0 { + len += 1; + } + len + } else { + 1 + }; + for (index, window) in + body.as_slice().chunks(chunk_size).enumerate() + { + let message = Self { + uri: uri.clone(), + method: method.clone(), + body: window.to_owned(), + headers: headers.clone(), + chunks: (len as u32, index as u32), + }; + messages.push(message); + } + messages + } + } } impl fmt::Debug for LocalRequest { @@ -222,6 +338,7 @@ impl From>> for LocalRequest { uri: parts.uri, headers, body, + chunks: (1, 0), } } } @@ -261,6 +378,8 @@ pub struct LocalResponse { /// Response body. #[serde(default, skip_serializing_if = "Vec::is_empty")] pub body: Vec, + /// Chunk information; length and then index. + chunks: (u32, u32), } impl Default for LocalResponse { @@ -269,6 +388,7 @@ impl Default for LocalResponse { status: StatusCode::OK.into(), headers: Default::default(), body: Default::default(), + chunks: (1, 0), } } } @@ -297,6 +417,7 @@ impl From>> for LocalResponse { status: parts.status.into(), headers, body, + chunks: (1, 0), } } } @@ -317,6 +438,7 @@ impl LocalResponse { status: status.into(), headers: Default::default(), body: Default::default(), + chunks: (1, 0), }; res.set_request_id(id); res @@ -350,7 +472,51 @@ impl HttpMessage for LocalResponse { self.body.as_slice() } + fn body_mut(&mut self) -> &mut Vec { + &mut self.body + } + fn into_body(self) -> Vec { self.body } + + fn chunks_len(&self) -> u32 { + self.chunks.0 + } + + fn chunk_index(&self) -> u32 { + self.chunks.1 + } + + fn into_chunks(self, limit: usize, chunk_size: usize) -> Vec { + if self.body.len() < limit { + vec![self] + } else { + let mut messages = Vec::new(); + let status = self.status.clone(); + let (headers, body) = self.into_parts(); + let len = if body.len() > chunk_size { + let mut len = body.len() / chunk_size; + if body.len() % chunk_size != 0 { + len += 1; + } + len + } else { + 1 + }; + for (index, window) in + body.as_slice().chunks(chunk_size).enumerate() + { + let message = Self { + status, + headers: headers.clone(), + body: window.to_owned(), + chunks: (len as u32, index as u32), + }; + messages.push(message); + } + println!("split into chunks: {}", messages.len()); + messages + } + } } diff --git a/crates/ipc/src/native_bridge/client.rs b/crates/ipc/src/native_bridge/client.rs index 7992ffa02d..d277e02c5b 100644 --- a/crates/ipc/src/native_bridge/client.rs +++ b/crates/ipc/src/native_bridge/client.rs @@ -3,14 +3,16 @@ //! //! Used to test the browser native messaging API integration. -use crate::local_transport::{LocalRequest, LocalResponse}; +use crate::local_transport::{HttpMessage, LocalRequest, LocalResponse}; use crate::Result; use futures_util::{SinkExt, StreamExt}; -use http::StatusCode; use std::process::Stdio; +use std::sync::atomic::{AtomicU64, Ordering}; use tokio::process::{Child, Command}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; +use super::{CHUNK_LIMIT, CHUNK_SIZE}; + /// Client that spawns a native bridge and sends /// and receives messages from the spawned executable. /// @@ -19,6 +21,7 @@ pub struct NativeBridgeClient { child: Child, stdin: FramedWrite, stdout: FramedRead, + id: AtomicU64, } impl NativeBridgeClient { @@ -50,27 +53,35 @@ impl NativeBridgeClient { child, stdin, stdout, + id: AtomicU64::new(1), }) } /// Send a request to the spawned native bridge. pub async fn send( &mut self, - request: &LocalRequest, + mut request: LocalRequest, ) -> Result { - let message = serde_json::to_vec(request)?; - self.stdin.send(message.into()).await?; - - let mut res: LocalResponse = StatusCode::IM_A_TEAPOT.into(); + let message_id = self.id.fetch_add(1, Ordering::SeqCst); + request.set_request_id(message_id); + let chunks = request.into_chunks(CHUNK_LIMIT, CHUNK_SIZE); + for request in chunks { + let message = serde_json::to_vec(&request)?; + self.stdin.send(message.into()).await?; + } + let mut chunks: Vec = Vec::new(); while let Some(response) = self.stdout.next().await { let response = response?; let response: LocalResponse = serde_json::from_slice(&response)?; - res = response; - break; + let chunks_len = response.chunks_len(); + chunks.push(response); + if chunks.len() == chunks_len as usize { + break; + } } - Ok(res) + Ok(LocalResponse::from_chunks(chunks)) } /// Kill the child process. diff --git a/crates/ipc/src/native_bridge/mod.rs b/crates/ipc/src/native_bridge/mod.rs index 0dbd5eedba..57cfecb88c 100644 --- a/crates/ipc/src/native_bridge/mod.rs +++ b/crates/ipc/src/native_bridge/mod.rs @@ -4,6 +4,11 @@ //! Used to support the native messaging API provided //! by browser extensions. +/// Body size limit before breaking into chunks. +pub const CHUNK_LIMIT: usize = 256 * 1024; +/// Size of each chunk. +pub const CHUNK_SIZE: usize = 128 * 1024; + #[cfg(feature = "native-bridge-client")] pub mod client; #[cfg(feature = "native-bridge-server")] diff --git a/crates/ipc/src/native_bridge/server.rs b/crates/ipc/src/native_bridge/server.rs index 44fd12794b..a1edf6ce9f 100644 --- a/crates/ipc/src/native_bridge/server.rs +++ b/crates/ipc/src/native_bridge/server.rs @@ -14,9 +14,11 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::{mpsc, Mutex}; use tokio::time::sleep; -use tokio_util::codec::LengthDelimitedCodec; +use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; -const LIMIT: usize = 1024 * 1024; +use super::{CHUNK_LIMIT, CHUNK_SIZE}; + +const HARD_LIMIT: usize = 1024 * 1024; static CONN: Lazy>>> = Lazy::new(|| Arc::new(Mutex::new(None))); @@ -98,49 +100,88 @@ impl NativeBridgeServer { let (tx, mut rx) = mpsc::unbounded_channel::(); + // Read request chunks into a single request + async fn read_chunked_request( + stdin: &mut FramedRead, + ) -> Result { + let mut chunks: Vec = Vec::new(); + while let Some(Ok(buffer)) = stdin.next().await { + let req = serde_json::from_slice::(&buffer)?; + let chunks_len = req.chunks_len(); + chunks.push(req); + if chunks.len() == chunks_len as usize { + break; + } + } + Ok(LocalRequest::from_chunks(chunks)) + } + loop { let channel = tx.clone(); let sock_name = socket_name.clone(); tokio::select! { - Some(Ok(buffer)) = stdin.next() => { + result = read_chunked_request(&mut stdin) => { + let Ok(request) = result else { + let response = LocalResponse::with_id( + StatusCode::BAD_REQUEST, + 0, + ); + let tx = channel.clone(); + if let Err(e) = tx.send(response.into()) { + tracing::warn!( + error = %e, + "native_bridge::response_channel"); + } + continue; + }; + tokio::task::spawn(async move { + let tx = channel.clone(); + + tracing::trace!( + request = ?request, + "sos_native_bridge::request", + ); + + let message_id = request.request_id(); + + // Is this a command we handle internally? + let response = if is_native_request(&request) { + handle_native_request( + request, + ) + .await + } else { + try_send_request(&sock_name, request).await + }; + + let result = match response { + Ok(response) => response, + Err(_) => { + LocalResponse::with_id( + StatusCode::SERVICE_UNAVAILABLE, + message_id, + ) + } + }; + + // Send response in chunks to avoid the 1MB + // hard limit + let chunks = result.into_chunks( + CHUNK_LIMIT, + CHUNK_SIZE, + ); + for chunk in chunks { + if let Err(e) = tx.send(chunk) { + tracing::warn!( + error = %e, + "native_bridge::response_channel"); + } + } + }); + + /* match serde_json::from_slice::(&buffer) { Ok(request) => { - tokio::task::spawn(async move { - let tx = channel.clone(); - - tracing::trace!( - request = ?request, - "sos_native_bridge::request", - ); - - let message_id = request.request_id(); - - // Is this a command we handle internally? - let response = if is_native_request(&request) { - handle_native_request( - request, - ) - .await - } else { - try_send_request(&sock_name, request).await - }; - - let result = match response { - Ok(response) => response, - Err(_) => { - LocalResponse::with_id( - StatusCode::SERVICE_UNAVAILABLE, - message_id, - ) - } - }; - - if let Err(e) = tx.send(result) { - tracing::warn!( - error = %e, - "native_bridge::response_channel"); - } - }); } Err(_) => { let response = LocalResponse::with_id(StatusCode::BAD_REQUEST, 0); @@ -152,6 +193,7 @@ impl NativeBridgeServer { } } } + */ } Some(response) = rx.recv() => { tracing::trace!( @@ -165,7 +207,7 @@ impl NativeBridgeServer { len = %output.len(), "native_bridge::stdout", ); - if output.len() > LIMIT { + if output.len() > HARD_LIMIT { tracing::error!("native_bridge::exceeds_limit"); } if let Err(e) = stdout.send(output.into()).await { @@ -178,6 +220,7 @@ impl NativeBridgeServer { std::process::exit(1); } } + } } }