diff --git a/Cargo.lock b/Cargo.lock index 72fd624a05..996ea4e87c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4613,8 +4613,10 @@ name = "sos-ipc" version = "0.16.0" dependencies = [ "async-trait", + "bytes", "futures-util", "http", + "http-body-util", "hyper", "interprocess", "matchit", diff --git a/crates/integration_tests/tests/ipc/app_info.rs b/crates/integration_tests/tests/ipc/app_info.rs index 2d17eb6f02..be44a72286 100644 --- a/crates/integration_tests/tests/ipc/app_info.rs +++ b/crates/integration_tests/tests/ipc/app_info.rs @@ -1,7 +1,7 @@ use anyhow::Result; use sos_ipc::{ - remove_socket_file, AppIntegration, Error, IpcService, IpcServiceOptions, - ServiceAppInfo, SocketClient, SocketServer, + remove_socket_file, AppIntegration, Error, ServiceAppInfo, + ServiceOptions, SocketClient, SocketServer, }; use sos_net::sdk::{prelude::LocalAccountSwitcher, Paths}; use sos_test_utils::teardown; @@ -35,21 +35,25 @@ async fn integration_ipc_app_info() -> Result<()> { let build_number = 1u32; // Start the IPC service + /* let service = Arc::new(RwLock::new(IpcService::new( ipc_accounts, - IpcServiceOptions { - app_info: Some(ServiceAppInfo { - name: name.to_string(), - version: version.to_string(), - build_number, - }), - ..Default::default() - }, ))); + */ + + let options = ServiceOptions { + app_info: Some(ServiceAppInfo { + name: name.to_string(), + version: version.to_string(), + build_number, + }), + ..Default::default() + }; let server_socket_name = socket_name.clone(); tokio::task::spawn(async move { - SocketServer::listen(&server_socket_name, service).await?; + SocketServer::listen(&server_socket_name, ipc_accounts, options) + .await?; Ok::<(), Error>(()) }); diff --git a/crates/integration_tests/tests/ipc/list_accounts.rs b/crates/integration_tests/tests/ipc/list_accounts.rs index 935d4fa869..35225ba3ed 100644 --- a/crates/integration_tests/tests/ipc/list_accounts.rs +++ b/crates/integration_tests/tests/ipc/list_accounts.rs @@ -1,7 +1,6 @@ use anyhow::Result; use sos_ipc::{ - remove_socket_file, AppIntegration, Error, IpcService, SocketClient, - SocketServer, + remove_socket_file, AppIntegration, Error, SocketClient, SocketServer, }; use sos_net::sdk::{ crypto::AccessKey, @@ -64,15 +63,16 @@ async fn integration_ipc_list_accounts() -> Result<()> { accounts.add_account(auth_account); accounts.add_account(unauth_account); - // Start the IPC service - let service = Arc::new(RwLock::new(IpcService::new( - Arc::new(RwLock::new(accounts)), - Default::default(), - ))); + let ipc_accounts = Arc::new(RwLock::new(accounts)); let server_socket_name = socket_name.clone(); tokio::task::spawn(async move { - SocketServer::listen(&server_socket_name, service).await?; + SocketServer::listen( + &server_socket_name, + ipc_accounts, + Default::default(), + ) + .await?; Ok::<(), Error>(()) }); diff --git a/crates/integration_tests/tests/ipc/local_sync.rs b/crates/integration_tests/tests/ipc/local_sync.rs index 674d276bde..8a543314c1 100644 --- a/crates/integration_tests/tests/ipc/local_sync.rs +++ b/crates/integration_tests/tests/ipc/local_sync.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use sos_ipc::{remove_socket_file, Error, IpcService, SocketServer}; +use sos_ipc::{remove_socket_file, Error, SocketServer}; use sos_net::{ protocol::{ integration::{LinkedAccount, LocalClient, LocalIntegration}, @@ -64,15 +64,16 @@ async fn integration_ipc_local_sync() -> Result<()> { local_accounts.switch_account(&address); let local_accounts = Arc::new(RwLock::new(local_accounts)); - // Start the IPC service - let service = Arc::new(RwLock::new(IpcService::new( - local_accounts.clone(), - Default::default(), - ))); + let ipc_accounts = local_accounts.clone(); let server_socket_name = socket_name.clone(); tokio::task::spawn(async move { - SocketServer::listen(&server_socket_name, service).await?; + SocketServer::listen( + &server_socket_name, + ipc_accounts, + Default::default(), + ) + .await?; Ok::<(), Error>(()) }); diff --git a/crates/ipc/Cargo.toml b/crates/ipc/Cargo.toml index d10ee878bc..16f15e317a 100644 --- a/crates/ipc/Cargo.toml +++ b/crates/ipc/Cargo.toml @@ -42,6 +42,8 @@ http.workspace = true tower = { version = "0.5", features = ["util"]} hyper = { version = "1" } matchit = "0.7" +bytes.workspace = true +http-body-util = "0.1" [build-dependencies] rustc_version = "0.4.1" diff --git a/crates/ipc/src/io.rs b/crates/ipc/src/io.rs index 1aae060576..6acd84f412 100644 --- a/crates/ipc/src/io.rs +++ b/crates/ipc/src/io.rs @@ -1,36 +1,31 @@ use pin_project_lite::pin_project; -use std::io::{Read, Write}; use std::pin::Pin; use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; pin_project! { #[derive(Debug)] - pub struct TokioIo { + pub struct TokioAdapter { #[pin] inner: T, } } -impl TokioIo { +impl TokioAdapter { pub fn new(inner: T) -> Self { Self { inner } } - - pub fn inner(self) -> T { - self.inner - } } -impl hyper::rt::Read for TokioIo +impl hyper::rt::Read for TokioAdapter where - T: Read, + T: AsyncRead, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, mut buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { - /* let n = unsafe { let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); match tokio::io::AsyncRead::poll_read( @@ -47,43 +42,37 @@ where buf.advance(n); } Poll::Ready(Ok(())) - */ - todo!(); } } -impl hyper::rt::Write for TokioIo +impl hyper::rt::Write for TokioAdapter where - T: Write, + T: AsyncWrite, { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - // tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) - todo!(); + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) } fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - // tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) - todo!(); + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - // tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) - todo!(); + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) } fn is_write_vectored(&self) -> bool { - // tokio::io::AsyncWrite::is_write_vectored(&self.inner) - todo!(); + tokio::io::AsyncWrite::is_write_vectored(&self.inner) } fn poll_write_vectored( @@ -91,13 +80,10 @@ where cx: &mut Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll> { - /* tokio::io::AsyncWrite::poll_write_vectored( self.project().inner, cx, bufs, ) - */ - todo!(); } } diff --git a/crates/ipc/src/lib.rs b/crates/ipc/src/lib.rs index 769de09f0a..ffa8ac448b 100644 --- a/crates/ipc/src/lib.rs +++ b/crates/ipc/src/lib.rs @@ -1,5 +1,4 @@ #![deny(missing_docs)] -#![forbid(unsafe_code)] #![cfg_attr(all(doc, CHANNEL_NIGHTLY), feature(doc_auto_cfg))] //! Inter-process communcation library for [Save Our Secrets](https://saveoursecrets.com/). //! @@ -25,10 +24,12 @@ pub mod native_bridge; #[cfg(any(feature = "client", feature = "server"))] pub(crate) mod io; +#[cfg(feature = "server")] +mod local_server; #[cfg(feature = "server")] mod server; #[cfg(feature = "server")] -mod service; +pub(crate) use local_server::LocalServer; pub use error::Error; @@ -41,9 +42,7 @@ pub type Result = std::result::Result; pub use client::{AppIntegration, SocketClient}; #[cfg(feature = "server")] -pub use server::SocketServer; -#[cfg(feature = "server")] -pub use service::{IpcService, IpcServiceOptions}; +pub use server::{ServiceOptions, SocketServer}; /// Information about the service. #[typeshare::typeshare] diff --git a/crates/ipc/src/service/local_server/account.rs b/crates/ipc/src/local_server/account.rs similarity index 94% rename from crates/ipc/src/service/local_server/account.rs rename to crates/ipc/src/local_server/account.rs index fd4a71d084..cd6ed6b0c7 100644 --- a/crates/ipc/src/service/local_server/account.rs +++ b/crates/ipc/src/local_server/account.rs @@ -1,4 +1,5 @@ use http::{Request, Response, StatusCode}; +use http_body_util::Full; use hyper::body::Bytes; use sos_protocol::{ server_helpers, Merge, SyncPacket, SyncStorage, WireEncodeDecode, @@ -9,7 +10,8 @@ use tokio::sync::RwLock; use super::{ bad_request, forbidden, internal_server_error, json, not_found, ok, - parse_account_id, protobuf, protobuf_compress, Body, Incoming, + parse_account_id, protobuf, protobuf_compress, read_bytes, Body, + Incoming, }; /// List of account public identities. @@ -133,7 +135,7 @@ where return internal_server_error("fetch_account::encode"); }; - protobuf_compress(buffer) + protobuf_compress(Full::new(Bytes::from(buffer))) } Err(e) => internal_server_error(e), } @@ -190,7 +192,7 @@ where let Ok(buffer) = status.encode().await else { return internal_server_error("sync_status::encode"); }; - protobuf(buffer) + protobuf(Full::new(Bytes::from(buffer))) } Err(e) => internal_server_error(e), } @@ -224,7 +226,7 @@ where if let Some(account) = accounts.iter_mut().find(|a| a.address() == &account_id) { - let buf: Bytes = req.into_body().into(); + let buf: Bytes = read_bytes(req).await?; let Ok(packet) = SyncPacket::decode(buf).await else { return bad_request(); }; @@ -235,7 +237,7 @@ where return internal_server_error("sync_account::encode"); }; - protobuf(response) + protobuf(Full::new(Bytes::from(response))) } Err(e) => internal_server_error(e), } diff --git a/crates/ipc/src/service/local_server/common.rs b/crates/ipc/src/local_server/common.rs similarity index 91% rename from crates/ipc/src/service/local_server/common.rs rename to crates/ipc/src/local_server/common.rs index 2b79f6abcc..0bdc860eeb 100644 --- a/crates/ipc/src/service/local_server/common.rs +++ b/crates/ipc/src/local_server/common.rs @@ -1,7 +1,9 @@ +use bytes::Bytes; use http::{ header::{CONTENT_ENCODING, CONTENT_TYPE}, Request, Response, StatusCode, }; +use http_body_util::{BodyExt, Full}; use serde::Serialize; use sos_protocol::constants::{ ENCODING_ZLIB, ENCODING_ZSTD, MIME_TYPE_JSON, MIME_TYPE_PROTOBUF, @@ -17,6 +19,10 @@ struct ErrorReply { message: String, } +pub async fn read_bytes(req: Request) -> hyper::Result { + Ok(req.collect().await?.to_bytes()) +} + pub fn parse_account_id(req: &Request) -> Option
{ let Some(Ok(account_id)) = req.headers().get(X_SOS_ACCOUNT_ID).map(|v| v.to_str()) @@ -78,7 +84,7 @@ pub fn json( let response = Response::builder() .status(status) .header(CONTENT_TYPE, MIME_TYPE_JSON) - .body(body) + .body(Full::new(Bytes::from(body))) .unwrap(); Ok(response) } @@ -93,15 +99,17 @@ pub fn protobuf(body: Body) -> hyper::Result> { } pub fn protobuf_compress(body: Body) -> hyper::Result> { + /* use sos_protocol::compression::zlib; let Ok(buf) = zlib::encode_all(body.as_slice()) else { return internal_server_error("zlib::compress"); }; + */ Ok(Response::builder() .status(StatusCode::OK) - .header(CONTENT_ENCODING, ENCODING_ZLIB) + // .header(CONTENT_ENCODING, ENCODING_ZLIB) .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF) - .body(buf) + .body(body) .unwrap()) } diff --git a/crates/ipc/src/service/local_server/events.rs b/crates/ipc/src/local_server/events.rs similarity index 90% rename from crates/ipc/src/service/local_server/events.rs rename to crates/ipc/src/local_server/events.rs index 334d2190a2..d11c9e5e71 100644 --- a/crates/ipc/src/service/local_server/events.rs +++ b/crates/ipc/src/local_server/events.rs @@ -1,5 +1,6 @@ +use bytes::Bytes; use http::{Request, Response}; -use hyper::body::Bytes; +use http_body_util::Full; use sos_protocol::{ server_helpers, DiffRequest, Merge, PatchRequest, ScanRequest, SyncStorage, WireEncodeDecode, @@ -10,7 +11,7 @@ use tokio::sync::RwLock; use super::{ bad_request, internal_server_error, not_found, parse_account_id, - protobuf, Body, Incoming, + protobuf, read_bytes, Body, Incoming, }; pub async fn event_scan( @@ -37,7 +38,7 @@ where if let Some(account) = accounts.iter().find(|a| a.address() == &account_id) { - let buf: Bytes = req.into_body().into(); + let buf: Bytes = read_bytes(req).await?; let Ok(packet) = ScanRequest::decode(buf).await else { return bad_request(); }; @@ -53,7 +54,7 @@ where return internal_server_error("event_scan::encode"); }; - protobuf(buffer) + protobuf(Full::new(Bytes::from(buffer))) } Err(e) => internal_server_error(e), } @@ -86,7 +87,7 @@ where if let Some(account) = accounts.iter().find(|a| a.address() == &account_id) { - let buf: Bytes = req.into_body().into(); + let buf: Bytes = read_bytes(req).await?; let Ok(packet) = DiffRequest::decode(buf).await else { return bad_request(); }; @@ -97,7 +98,7 @@ where return internal_server_error("event_diff::encode"); }; - protobuf(buffer) + protobuf(Full::new(Bytes::from(buffer))) } Err(e) => internal_server_error(e), } @@ -131,7 +132,7 @@ where if let Some(account) = accounts.iter_mut().find(|a| a.address() == &account_id) { - let buf: Bytes = req.into_body().into(); + let buf: Bytes = read_bytes(req).await?; let Ok(packet) = PatchRequest::decode(buf).await else { return bad_request(); }; @@ -141,7 +142,7 @@ where let Ok(buffer) = response.encode().await else { return internal_server_error("event_patch::encode"); }; - protobuf(buffer) + protobuf(Full::new(Bytes::from(buffer))) } Err(e) => internal_server_error(e), } diff --git a/crates/ipc/src/service/local_server/mod.rs b/crates/ipc/src/local_server/mod.rs similarity index 89% rename from crates/ipc/src/service/local_server/mod.rs rename to crates/ipc/src/local_server/mod.rs index 599870e578..9c6c35661b 100644 --- a/crates/ipc/src/service/local_server/mod.rs +++ b/crates/ipc/src/local_server/mod.rs @@ -1,14 +1,17 @@ +use bytes::Bytes; use http::{Method, Request, Response, StatusCode}; +use http_body_util::Full; +use hyper::body::Incoming; +use hyper::service::Service as HyperService; use parking_lot::Mutex; use sos_protocol::{ constants::routes::v1::{ ACCOUNTS_LIST, SYNC_ACCOUNT, SYNC_ACCOUNT_EVENTS, SYNC_ACCOUNT_STATUS, }, - local_transport::{LocalRequest, LocalResponse}, Merge, SyncStorage, }; use sos_sdk::prelude::{Account, AccountSwitcher}; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; use tokio::sync::RwLock; use tower::service_fn; use tower::util::BoxCloneService; @@ -16,8 +19,8 @@ use tower::Service as _; use crate::ServiceAppInfo; -type Incoming = Vec; -type Body = Vec; +// type Incoming = Vec; +type Body = Full; // Need the Mutex as BoxCloneService does not implement Sync type Service = @@ -44,6 +47,7 @@ async fn index( /// /// We avoid using axum directly as we need the `Sync` bound /// but `axum::Body` is `!Sync`. +#[derive(Clone)] pub(crate) struct LocalServer { /// Service router. router: Arc, @@ -253,6 +257,7 @@ impl LocalServer { } } + /* pub async fn handle(&self, req: LocalRequest) -> LocalResponse { let res = match req.try_into() { Ok(req) => self.call(req).await, @@ -263,13 +268,17 @@ impl LocalServer { }; res.into() } + */ - async fn call(&self, req: Request) -> Response { + pub(crate) async fn call( + &self, + req: Request, + ) -> hyper::Result> { let router = self.router.clone(); - match Self::route(router, req).await { + Ok(match Self::route(router, req).await { Ok(result) => result, Err(e) => internal_server_error(e).unwrap(), - } + }) } async fn route( @@ -296,3 +305,20 @@ impl LocalServer { } } } + +impl HyperService> for LocalServer { + type Response = Response>; + type Error = hyper::Error; + type Future = Pin< + Box< + dyn Future< + Output = std::result::Result, + > + Send, + >, + >; + + fn call(&self, req: Request) -> Self::Future { + let router = self.router.clone(); + Box::pin(async move { LocalServer::route(router, req).await }) + } +} diff --git a/crates/ipc/src/server.rs b/crates/ipc/src/server.rs index 73ea5e5f08..65e08cb992 100644 --- a/crates/ipc/src/server.rs +++ b/crates/ipc/src/server.rs @@ -1,41 +1,42 @@ -use http::StatusCode; +use crate::{io::TokioAdapter, LocalServer, Result, ServiceAppInfo}; +use hyper::server::conn::http1::Builder; use interprocess::local_socket::{ tokio::prelude::*, GenericNamespaced, ListenerOptions, }; -use sos_protocol::local_transport::{LocalRequest, LocalResponse}; -use std::{pin::Pin, sync::Arc}; +use sos_protocol::{Merge, SyncStorage}; +use sos_sdk::prelude::{Account, AccountSwitcher}; +use std::sync::Arc; +use tokio::sync::RwLock; -use futures_util::sink::SinkExt; -use hyper::{ - body::Incoming, server::conn::http1::Builder, service::HttpService, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::RwLock, - time::timeout, -}; -use tokio_stream::StreamExt; -use tokio_util::{ - bytes::BytesMut, - codec::{Framed, LengthDelimitedCodec}, -}; - -use crate::{ - codec, decode_proto, encode_proto, io::TokioIo, io_err, IpcService, - WireLocalRequest, WireLocalResponse, -}; - -use crate::Result; +/// Options for an IPC service. +#[derive(Default, Clone)] +pub struct ServiceOptions { + /// Application info. + pub app_info: Option, +} /// Socket server for inter-process communication. pub struct SocketServer; impl SocketServer { /// Listen on a bind address. - pub async fn listen(socket_name: &str, service: S) -> Result<()> + pub async fn listen( + socket_name: &str, + accounts: Arc>>, + options: ServiceOptions, + ) -> Result<()> where - S: HttpService + Send + 'static, - // B: Incoming + Send + Sync + 'static, + A: Account + + SyncStorage + + Merge + + Sync + + Send + + 'static, + R: 'static, + E: std::fmt::Debug + + From + + From + + 'static, { let name = socket_name.to_ns_name::()?; let opts = ListenerOptions::new().name(name); @@ -49,118 +50,21 @@ impl SocketServer { x => x?, }; + let service = + LocalServer::new(options.app_info.unwrap_or_default(), accounts); + let svc = Arc::new(service); + loop { let socket = listener.accept().await?; - // let service = service.clone(); - // - // socket.foo(); - + let svc = svc.clone(); tokio::spawn(async move { - let socket = TokioIo::new(socket); - // handle_conn(service, socket).await; - + let socket = TokioAdapter::new(socket); let http = Builder::new(); - let conn = http.serve_connection(socket, service); - - /* + let conn = http.serve_connection(socket, svc); if let Err(e) = conn.await { - eprintln!("server connection error: {}", e); + tracing::error!(error = %e); } - */ - - todo!(); }); } } } - -/* -async fn handle_conn(service: Arc>, socket: T) -where - T: AsyncRead + AsyncWrite + Sized, -{ - let io = Box::pin(socket); - let mut framed = codec::framed(io); - - while let Some(message) = framed.next().await { - match message { - Ok(bytes) => { - tracing::debug!( - len = bytes.len(), - "socket_server::socket_recv" - ); - - if let Err(err) = - handle_request(service.clone(), &mut framed, bytes).await - { - // Internal error, try to send a response and close - // the connection if we error here - let response = LocalResponse::new_internal_error(err); - let response: WireLocalResponse = response.into(); - match encode_proto(&response) { - Ok(buffer) => { - match framed.send(buffer.into()).await { - Err(err) => { - tracing::error!( - error = ?err, - "socket_server::internal_error::close_connection" - ); - break; - } - _ => {} - } - } - Err(err) => { - tracing::error!( - error = ?err, - "socket_server::internal_error::close_connection" - ); - break; - } - } - } - } - Err(err) => { - tracing::error!( - error = ?err, - "socket_server::socket_error", - ); - } - } - } - tracing::debug!("socket_server::socket_closed"); -} - -async fn handle_request( - service: Arc>, - channel: &mut Framed>, LengthDelimitedCodec>, - bytes: BytesMut, -) -> Result<()> -where - T: AsyncRead + AsyncWrite + Sized, -{ - let request: WireLocalRequest = decode_proto(&bytes).map_err(io_err)?; - let request: LocalRequest = request.try_into().map_err(io_err)?; - tracing::debug!( - request = ?request, - "socket_server::socket_request" - ); - let message_id = request.request_id(); - let handler = service.read().await; - let duration = request.timeout_duration(); - let response = match timeout(duration, handler.handle(request)).await { - Ok(res) => res, - Err(_) => { - tracing::debug!( - duration = ?duration, - "socket_server::request_timeout"); - LocalResponse::with_id(StatusCode::REQUEST_TIMEOUT, message_id) - } - }; - - let response: WireLocalResponse = response.into(); - let buffer = encode_proto(&response).map_err(io_err)?; - channel.send(buffer.into()).await?; - Ok(()) -} -*/ diff --git a/crates/ipc/src/service/mod.rs b/crates/ipc/src/service/mod.rs deleted file mode 100644 index 797ad2a038..0000000000 --- a/crates/ipc/src/service/mod.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::ServiceAppInfo; -use sos_sdk::account::{Account, AccountSwitcher}; - -use sos_protocol::{ - local_transport::{LocalRequest, LocalResponse}, - Merge, SyncStorage, -}; -use std::sync::Arc; -use tokio::sync::RwLock; - -// mod delegate; -mod local_server; - -// pub use delegate::*; -use local_server::LocalServer; - -/// Options for an IPC service. -#[derive(Default)] -pub struct IpcServiceOptions { - /// Application info. - pub app_info: Option, -} - -/// Handler for IPC requests. -pub struct IpcService { - options: IpcServiceOptions, - server: LocalServer, -} - -impl IpcService { - /// Create a new service handler. - pub fn new( - accounts: Arc>>, - options: IpcServiceOptions, - ) -> Self - where - A: Account - + SyncStorage - + Merge - + Sync - + Send - + 'static, - R: 'static, - E: std::error::Error - + std::fmt::Debug - + From - + From - + 'static, - { - let app_info = options.app_info.clone().unwrap_or_default(); - - Self { - server: LocalServer::new(app_info, accounts), - options, - } - } - - /// Handle an incoming request. - pub async fn handle(&self, request: LocalRequest) -> LocalResponse { - self.server.handle(request).await - } -}