diff --git a/kuksa_databroker/databroker/src/viss/server.rs b/kuksa_databroker/databroker/src/viss/server.rs index 6b2c640d..09e05f07 100644 --- a/kuksa_databroker/databroker/src/viss/server.rs +++ b/kuksa_databroker/databroker/src/viss/server.rs @@ -24,8 +24,8 @@ use tracing::{debug, error, info}; use futures::{channel::mpsc, Sink}; use futures::{stream::StreamExt, Stream}; -use crate::broker; use crate::viss::v2; +use crate::{broker, viss::v2::VissService}; pub async fn serve( addr: impl Into, @@ -35,7 +35,7 @@ pub async fn serve( // signal: F ) -> Result<(), Box> { let app = Router::new() - .route("/", get(handle_upgrade::)) + .route("/", get(handle_upgrade)) .with_state(broker); let addr = addr.into(); @@ -52,20 +52,17 @@ pub async fn serve( } // Handle upgrade request -async fn handle_upgrade( +async fn handle_upgrade( ws: WebSocketUpgrade, axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo, - axum::extract::State(state): axum::extract::State, -) -> impl IntoResponse -where - T: v2::VissService, -{ + axum::extract::State(state): axum::extract::State, +) -> impl IntoResponse { debug!("Received websocket upgrade request"); ws.on_upgrade(move |socket| handle_websocket(socket, addr, state)) } // Handle websocket (one per connection) -async fn handle_websocket(socket: WebSocket, addr: SocketAddr, state: impl v2::VissService) { +async fn handle_websocket(socket: WebSocket, addr: SocketAddr, broker: broker::DataBroker) { let valid_subprotocol = match socket.protocol() { Some(subprotocol) => match subprotocol.to_str() { Ok("VISSv2") => true, @@ -92,14 +89,14 @@ async fn handle_websocket(socket: WebSocket, addr: SocketAddr, state: impl v2::V let (write, read) = socket.split(); - handle_viss_v2(write, read, addr, state).await; + handle_viss_v2(write, read, addr, broker).await; } async fn handle_viss_v2( write: W, mut read: R, client_addr: SocketAddr, - viss: impl v2::VissService, + broker: broker::DataBroker, ) where W: Sink + Unpin + Send + 'static, >::Error: Send, @@ -109,6 +106,8 @@ async fn handle_viss_v2( // single consumer will write to the socket. let (sender, receiver) = mpsc::channel::(10); + let server = v2::Server::new(broker); + let mut write_task = tokio::spawn(async move { let _ = receiver.map(Ok).forward(write).await; }); @@ -125,24 +124,25 @@ async fn handle_viss_v2( Ok(request) => match request { v2::Request::Get(request) => { debug!("Get request parsed successfully"); - match viss.get(request).await { + match server.get(request).await { Ok(response) => serde_json::to_string(&response), Err(error_response) => serde_json::to_string(&error_response), } } v2::Request::Set(request) => { - debug!("Set request parsed successfully"); - match viss.set(request).await { + debug!("Set request successfully parsed"); + match server.set(request).await { Ok(response) => serde_json::to_string(&response), Err(error_response) => serde_json::to_string(&error_response), } } v2::Request::Subscribe(request) => { - debug!("Subscribe request parsed successfully"); - match viss.subscribe(request).await { + debug!("Subscribe request successfully parsed"); + match server.subscribe(request).await { Ok((response, stream)) => { // Setup background stream let mut subscription_sender = sender.clone(); + tokio::spawn(async move { let mut stream = Box::pin(stream); while let Some(item) = stream.next().await { @@ -154,14 +154,15 @@ async fn handle_viss_v2( }; if let Ok(serialized) = serialized { + debug!("Sending notification: {}", serialized); match subscription_sender .try_send(Message::Text(serialized)) { Ok(_) => { - debug!("Successfully sent response") + debug!("Successfully sent notification") } Err(err) => { - debug!("Failed to send response: {err}") + debug!("Failed to send notification: {err}") } }; } @@ -174,12 +175,20 @@ async fn handle_viss_v2( Err(error_response) => serde_json::to_string(&error_response), } } + v2::Request::Unsubscribe(request) => { + debug!("Unsubscribe request successfully parsed"); + match server.unsubscribe(request).await { + Ok(response) => serde_json::to_string(&response), + Err(error_response) => serde_json::to_string(&error_response), + } + } }, Err(err) => serde_json::to_string(&err), }; // Send it if let Ok(serialized) = serialized { + debug!("Sending response: {}", serialized); let mut sender = sender; match sender.try_send(Message::Text(serialized)) { Ok(_) => debug!("Successfully sent response"), diff --git a/kuksa_databroker/databroker/src/viss/v2/server.rs b/kuksa_databroker/databroker/src/viss/v2/server.rs index ed5a0d93..2cd56f40 100644 --- a/kuksa_databroker/databroker/src/viss/v2/server.rs +++ b/kuksa_databroker/databroker/src/viss/v2/server.rs @@ -15,20 +15,20 @@ use std::{ collections::{HashMap, HashSet}, convert::{TryFrom, TryInto}, pin::Pin, + sync::Arc, time::SystemTime, }; -use futures::{Stream, StreamExt}; +use futures::{ + stream::{AbortHandle, Abortable}, + Stream, StreamExt, +}; +use tokio::sync::RwLock; use crate::{broker, permissions}; use super::types::*; -pub fn parse_request(msg: &str) -> Result { - let request: Request = serde_json::from_str(msg).map_err(|_err| Error::BadRequest)?; - Ok(request) -} - #[tonic::async_trait] pub(crate) trait VissService: Send + Sync + 'static { async fn get(&self, request: GetRequest) -> Result; @@ -42,52 +42,37 @@ pub(crate) trait VissService: Send + Sync + 'static { &self, request: SubscribeRequest, ) -> Result<(SubscribeSuccessResponse, Self::SubscribeStream), SubscribeErrorResponse>; + + async fn unsubscribe( + &self, + request: UnsubscribeRequest, + ) -> Result; } -fn convert_to_viss_stream( - subscription_id: SubscriptionId, - input: impl Stream, -) -> impl Stream> { - input.map(move |item| { - let ts = SystemTime::now().into(); - let subscription_id = subscription_id.clone(); - match item.updates.get(0) { - Some(item) => match (&item.update.path, &item.update.datapoint) { - (Some(path), Some(datapoint)) => match datapoint.clone().try_into() { - Ok(dp) => Ok(SubscriptionEvent { - subscription_id, - data: Data::Object(DataObject { - path: path.clone().into(), - dp, - }), - ts, - }), - Err(error) => Err(SubscriptionErrorEvent { - subscription_id, - error, - ts, - }), - }, - (_, _) => Err(SubscriptionErrorEvent { - subscription_id, - error: Error::InternalServerError, - ts, - }), - }, - None => Err(SubscriptionErrorEvent { - subscription_id, - error: Error::InternalServerError, - ts, - }), +pub struct Server { + broker: broker::DataBroker, + subscriptions: Arc>>, +} + +impl Server { + pub fn new(broker: broker::DataBroker) -> Self { + Self { + broker, + subscriptions: Arc::new(RwLock::new(HashMap::new())), } - }) + } +} + +pub fn parse_request(msg: &str) -> Result { + let request: Request = serde_json::from_str(msg).map_err(|_err| Error::BadRequest)?; + Ok(request) } #[tonic::async_trait] -impl VissService for broker::DataBroker { +impl VissService for Server { async fn get(&self, request: GetRequest) -> Result { let permissions = &permissions::ALLOW_ALL; - let broker = self.authorized_access(permissions); + let broker = self.broker.authorized_access(permissions); match broker.get_datapoint_by_path(request.path.as_ref()).await { Ok(datapoint) => match datapoint.try_into() { @@ -118,7 +103,7 @@ impl VissService for broker::DataBroker { async fn set(&self, request: SetRequest) -> Result { let permissions = &permissions::ALLOW_ALL; - let broker = self.authorized_access(permissions); + let broker = self.broker.authorized_access(permissions); match broker.get_metadata_by_path(request.path.as_ref()).await { Some(metadata) => { @@ -235,12 +220,36 @@ impl VissService for broker::DataBroker { >, >; + async fn unsubscribe( + &self, + request: UnsubscribeRequest, + ) -> Result { + let subscription_id = request.subscription_id; + let request_id = request.request_id; + match self.subscriptions.read().await.get(&subscription_id) { + Some(abort_handle) => { + abort_handle.abort(); + Ok(UnsubscribeSuccessResponse { + request_id, + subscription_id, + ts: SystemTime::now().into(), + }) + } + None => Err(UnsubscribeErrorResponse { + request_id, + subscription_id, + error: Error::BadRequest, + ts: SystemTime::now().into(), + }), + } + } + async fn subscribe( &self, request: SubscribeRequest, ) -> Result<(SubscribeSuccessResponse, Self::SubscribeStream), SubscribeErrorResponse> { let permissions = &permissions::ALLOW_ALL; - let broker = self.authorized_access(permissions); + let broker = self.broker.authorized_access(permissions); let entries = HashMap::from([( Into::::into(request.path), @@ -248,15 +257,28 @@ impl VissService for broker::DataBroker { )]); match broker.subscribe(entries).await { Ok(stream) => { - let subscription_id = SubscriptionId::from(request.request_id.as_ref()); + let subscription_id = SubscriptionId::from(request.request_id.as_ref().to_owned()); + + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + + // Make the stream abortable + let stream = Abortable::new(stream, abort_registration); + + // Register abort handle + self.subscriptions + .write() + .await + .insert(subscription_id.clone(), abort_handle); + + let stream = convert_to_viss_stream(subscription_id.clone(), stream); Ok(( SubscribeSuccessResponse { request_id: request.request_id, - subscription_id: subscription_id.clone(), + subscription_id, ts: SystemTime::now().into(), }, - Box::pin(convert_to_viss_stream(subscription_id, stream)), + Box::pin(stream), )) } Err(err) => Err(SubscribeErrorResponse { @@ -271,3 +293,42 @@ impl VissService for broker::DataBroker { } } } + +fn convert_to_viss_stream( + subscription_id: SubscriptionId, + stream: impl Stream, +) -> impl Stream> { + stream.map(move |item| { + let ts = SystemTime::now().into(); + let subscription_id = subscription_id.clone(); + match item.updates.get(0) { + Some(item) => match (&item.update.path, &item.update.datapoint) { + (Some(path), Some(datapoint)) => match datapoint.clone().try_into() { + Ok(dp) => Ok(SubscriptionEvent { + subscription_id, + data: Data::Object(DataObject { + path: path.clone().into(), + dp, + }), + ts, + }), + Err(error) => Err(SubscriptionErrorEvent { + subscription_id, + error, + ts, + }), + }, + (_, _) => Err(SubscriptionErrorEvent { + subscription_id, + error: Error::InternalServerError, + ts, + }), + }, + None => Err(SubscriptionErrorEvent { + subscription_id, + error: Error::InternalServerError, + ts, + }), + } + }) +} diff --git a/kuksa_databroker/databroker/src/viss/v2/types.rs b/kuksa_databroker/databroker/src/viss/v2/types.rs index af169993..408a7da1 100644 --- a/kuksa_databroker/databroker/src/viss/v2/types.rs +++ b/kuksa_databroker/databroker/src/viss/v2/types.rs @@ -24,6 +24,8 @@ pub enum Request { Set(SetRequest), #[serde(rename = "subscribe")] Subscribe(SubscribeRequest), + #[serde(rename = "unsubscribe")] + Unsubscribe(UnsubscribeRequest), } #[derive(Deserialize)] @@ -115,6 +117,31 @@ pub struct SubscriptionErrorEvent { pub ts: Timestamp, } +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UnsubscribeRequest { + pub request_id: RequestId, + pub subscription_id: SubscriptionId, + // authorization: Option, +} + +#[derive(Serialize)] +#[serde(tag = "action", rename = "unsubscribe", rename_all = "camelCase")] +pub struct UnsubscribeSuccessResponse { + pub request_id: RequestId, + pub subscription_id: SubscriptionId, + pub ts: Timestamp, +} + +#[derive(Serialize)] +#[serde(tag = "action", rename = "unsubscribe", rename_all = "camelCase")] +pub struct UnsubscribeErrorResponse { + pub request_id: RequestId, + pub subscription_id: SubscriptionId, + pub error: Error, + pub ts: Timestamp, +} + // Unique id value specified by the client. Returned by the server in the // response and used by the client to link the request and response messages. // The value MAY be an integer or a Universally Unique Identifier (UUID). @@ -137,7 +164,7 @@ pub struct RequestId(String); pub struct Path(String); // A handle identifying a subscription session. -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Eq, Hash, PartialEq)] #[serde(transparent)] pub struct SubscriptionId(String); @@ -297,12 +324,15 @@ impl From for Path { } } -impl From for SubscriptionId -where - S: Into, -{ - fn from(value: S) -> Self { - SubscriptionId(value.into()) +impl From for SubscriptionId { + fn from(value: String) -> Self { + SubscriptionId(value) + } +} + +impl From for String { + fn from(value: SubscriptionId) -> Self { + value.0 } }