Skip to content

Commit

Permalink
Improve client implementation.
Browse files Browse the repository at this point in the history
So that the LocalSocketStream is re-used between requests rather than
re-connecting to the named pipe each time. Reconnecting was a regression
introduced when we switched to hyper for the transport.
  • Loading branch information
tmpfs committed Dec 6, 2024
1 parent 1530f6d commit eefa2ff
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 53 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/integration_tests/tests/ipc/app_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async fn integration_ipc_app_info() -> Result<()> {

tokio::time::sleep(Duration::from_millis(250)).await;

let client = LocalSocketClient::connect(&socket_name).await?;
let mut client = LocalSocketClient::connect(&socket_name).await?;
let info = client.info().await?;
assert_eq!(name, &info.name);
assert_eq!(version, &info.version);
Expand Down
2 changes: 1 addition & 1 deletion crates/integration_tests/tests/ipc/list_accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async fn integration_ipc_list_accounts() -> Result<()> {
tokio::time::sleep(Duration::from_millis(250)).await;

// Create a client and list accounts
let client = LocalSocketClient::connect(&socket_name).await?;
let mut client = LocalSocketClient::connect(&socket_name).await?;
let accounts = client.list_accounts().await?;
assert_eq!(2, accounts.len());

Expand Down
4 changes: 2 additions & 2 deletions crates/ipc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ contacts = ["sos-sdk/contacts", "sos-protocol/contacts"]
files = ["sos-sdk/files", "sos-protocol/files"]
migrate = ["sos-sdk/migrate", "sos-protocol/migrate"]
search = ["sos-sdk/search", "sos-protocol/search"]
native-bridge-server = ["client", "open", "tokio/io-std"]
native-bridge-server = ["client", "open", "tokio/io-std", "once_cell"]
native-bridge-client = ["local-transport", "tokio/process", "tokio/io-std", "futures"]

[dependencies]
Expand All @@ -64,7 +64,7 @@ parking_lot.workspace = true
http.workspace = true
bytes.workspace = true
futures = { workspace = true, optional = true }

once_cell = { workspace = true, optional = true }
indexmap = { workspace = true, optional = true }
async-trait = { workspace = true, optional = true }
serde_with = { workspace = true, optional = true }
Expand Down
87 changes: 43 additions & 44 deletions crates/ipc/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::{Result, ServiceAppInfo};
use bytes::Bytes;
use futures::pin_mut;
use http::{Request, Response};
use http_body_util::{BodyExt, Full};
use hyper::client::conn::http1::handshake;
Expand All @@ -12,64 +13,62 @@ use hyper_util::rt::tokio::TokioIo;
use sos_protocol::{constants::routes::v1::ACCOUNTS_LIST, NetworkError};
use sos_sdk::prelude::PublicIdentity;

/// Send a local request.
pub async fn send_local(
socket_name: impl Into<String>,
request: LocalRequest,
) -> Result<LocalResponse> {
let request: Request<Vec<u8>> = request.try_into()?;
let (header, body) = request.into_parts();
let request = Request::from_parts(header, Full::new(Bytes::from(body)));
let response = send_http(socket_name, request).await?;
let (header, body) = response.into_parts();
let bytes = body.collect().await.unwrap().to_bytes();
let response = Response::from_parts(header, bytes.to_vec());
Ok(response.into())
}

/// Send a HTTP request.
pub async fn send_http(
socket_name: impl Into<String>,
request: Request<Full<Bytes>>,
) -> Result<Response<Full<Bytes>>> {
let name = socket_name.into().to_ns_name::<GenericNamespaced>()?;
let io = LocalSocketStream::connect(name).await?;
let socket = TokioIo::new(io);
let (mut sender, conn) = handshake(socket).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
tracing::error!(error = %err, "ipc::client::connection");
}
});
let response = sender.send_request(request).await?;
let (header, body) = response.into_parts();
let bytes = body.collect().await.unwrap().to_bytes();
Ok(Response::from_parts(header, Full::new(bytes)))
}

/// Socket client for inter-process communication.
pub struct LocalSocketClient {
socket_name: String,
socket: TokioIo<LocalSocketStream>,
}

impl LocalSocketClient {
/// Create a client and connect the server.
/// Create a client and connect to the named pipe.
pub async fn connect(socket_name: impl Into<String>) -> Result<Self> {
Ok(Self {
socket_name: socket_name.into(),
})
let name = socket_name.into().to_ns_name::<GenericNamespaced>()?;
let io = LocalSocketStream::connect(name).await?;
let socket = TokioIo::new(io);
Ok(Self { socket })
}

/// Send a local request.
pub async fn send_request(
&self,
&mut self,
request: LocalRequest,
) -> Result<LocalResponse> {
send_local(self.socket_name.clone(), request).await
let request: Request<Vec<u8>> = request.try_into()?;
let (header, body) = request.into_parts();
let request =
Request::from_parts(header, Full::new(Bytes::from(body)));
let response = self.send_http(request).await?;
let (header, body) = response.into_parts();
let bytes = body.collect().await.unwrap().to_bytes();
let response = Response::from_parts(header, bytes.to_vec());
Ok(response.into())
}

/// Send a HTTP request.
pub async fn send_http(
&mut self,
request: Request<Full<Bytes>>,
) -> Result<Response<Full<Bytes>>> {
let (mut sender, conn) = handshake(&mut self.socket).await?;

let conn = Box::pin(async move { conn.await });
let req = Box::pin(async move { sender.send_request(request).await });
pin_mut!(conn);
pin_mut!(req);

let (conn, response) = futures::future::join(conn, req).await;
if let Err(err) = conn {
tracing::error!(error = %err, "ipc::client::connection");
}

let (header, body) = response?.into_parts();
let bytes = body.collect().await.unwrap().to_bytes();
let response = Response::from_parts(header, Full::new(bytes));

Ok(response)
}

/// Get application information.
pub async fn info(&self) -> Result<ServiceAppInfo> {
pub async fn info(&mut self) -> Result<ServiceAppInfo> {
let response = self.send_request(Default::default()).await?;
let status = response.status()?;
if status.is_success() {
Expand All @@ -82,7 +81,7 @@ impl LocalSocketClient {
}

/// List accounts.
pub async fn list_accounts(&self) -> Result<Vec<PublicIdentity>> {
pub async fn list_accounts(&mut self) -> Result<Vec<PublicIdentity>> {
let request = LocalRequest {
uri: ACCOUNTS_LIST.parse()?,
..Default::default()
Expand Down
70 changes: 65 additions & 5 deletions crates/ipc/src/native_bridge/server.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
//! Server for the native messaging API bridge.
use crate::{
client::send_local,
client::LocalSocketClient,
local_transport::{HttpMessage, LocalRequest, LocalResponse},
Result,
Error, Result,
};
use futures_util::{SinkExt, StreamExt};
use http::StatusCode;
use once_cell::sync::Lazy;
use sos_sdk::{logs::Logger, prelude::IPC_GUI_SOCKET_NAME, url::Url, Paths};
use tokio::sync::mpsc;
use std::io::ErrorKind;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex};
use tokio::time::sleep;
use tokio_util::codec::LengthDelimitedCodec;

const LIMIT: usize = 1024 * 1024;

static CONN: Lazy<Arc<Mutex<Option<LocalSocketClient>>>> =
Lazy::new(|| Arc::new(Mutex::new(None)));

/// Options for a native bridge.
#[derive(Debug, Default)]
pub struct NativeBridgeOptions {
Expand Down Expand Up @@ -114,8 +122,7 @@ impl NativeBridgeServer {
)
.await
} else {
send_local(
sock_name.clone(), request).await
try_send_request(&sock_name, request).await
};

let result = match response {
Expand Down Expand Up @@ -177,6 +184,59 @@ impl NativeBridgeServer {
}
}

async fn connect(socket_name: &str) -> Arc<Mutex<Option<LocalSocketClient>>> {
let mut conn = CONN.lock().await;
if conn.is_some() {
return Arc::clone(&*CONN);
}
let socket_client = try_connect(socket_name).await;
*conn = Some(socket_client);
return Arc::clone(&*CONN);
}

async fn try_connect(socket_name: &str) -> LocalSocketClient {
let retry_delay = Duration::from_secs(1);
loop {
match LocalSocketClient::connect(socket_name).await {
Ok(client) => return client,
Err(e) => {
tracing::trace!(
error = %e,
"native_bridge::connect",
);
sleep(retry_delay).await;
}
}
}
}

/// Send an IPC request and reconnect for certain types of IO error.
async fn try_send_request(
socket_name: &str,
request: LocalRequest,
) -> Result<LocalResponse> {
loop {
let conn = connect(socket_name).await;
let mut lock = conn.lock().await;
let client = lock.as_mut().unwrap();
match client.send_request(request.clone()).await {
Ok(response) => return Ok(response),
Err(e) => match e {
Error::Io(io_err) => match io_err.kind() {
ErrorKind::BrokenPipe => {
// Move the broken client out
// so the next attempt to connect
// will create a new client
lock.take();
}
_ => return Err(Error::Io(io_err)),
},
_ => return Err(e),
},
}
}
}

/// Native requests are those handled by this native bridge.
fn is_native_request(request: &LocalRequest) -> bool {
match request.uri.path() {
Expand Down

0 comments on commit eefa2ff

Please sign in to comment.