diff --git a/Cargo.lock b/Cargo.lock index 46d3e9610..cdb31bafa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1250,6 +1250,7 @@ dependencies = [ "pyo3-asyncio", "socket2", "tokio", + "uuid", ] [[package]] @@ -1527,6 +1528,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "uuid" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" +dependencies = [ + "getrandom", + "serde", +] + [[package]] name = "version_check" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index 1a5b3aed4..043fc243f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ actix-http = "3.0.0-beta.8" socket2 = { version = "0.4.1", features = ["all"] } actix = "0.12.0" actix-web-actors = "4.0.0-beta.1" +uuid = { version = "0.8", features = ["serde", "v4"] } [package.metadata.maturin] name = "robyn" diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index c2ecda18d..e3cfd6610 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -13,7 +13,8 @@ @websocket.on("message") -async def connect(): +async def connect(websocket_id): + print(websocket_id) global i i += 1 if i == 0: diff --git a/src/web_socket_connection.rs b/src/web_socket_connection.rs index bb1845729..227c838bf 100644 --- a/src/web_socket_connection.rs +++ b/src/web_socket_connection.rs @@ -6,6 +6,7 @@ use actix_web::{web, Error, HttpRequest, HttpResponse}; use actix_web_actors::ws; use actix_web_actors::ws::WebsocketContext; use pyo3::prelude::*; +use uuid::Uuid; use std::collections::HashMap; use std::sync::Arc; @@ -13,6 +14,7 @@ use std::sync::Arc; /// Define HTTP actor #[derive(Clone)] struct MyWs { + id: Uuid, router: HashMap, // can probably try removing arc from here // and use clone_ref() @@ -21,22 +23,35 @@ struct MyWs { fn execute_ws_functionn( handler_function: &PyFunction, + number_of_params: u8, event_loop: Arc, ctx: &mut ws::WebsocketContext, ws: &MyWs, + // add number of params here ) { match handler_function { PyFunction::SyncFunction(handler) => Python::with_gil(|py| { let handler = handler.as_ref(py); // call execute function - let op = handler.call0().unwrap(); - let op: &str = op.extract().unwrap(); + let op: PyResult<&PyAny> = match number_of_params { + 0 => handler.call0(), + 1 => handler.call1((ws.id.to_string(),)), + // this is done to accomodate any future params + 2_u8..=u8::MAX => handler.call1((ws.id.to_string(),)), + }; + + let op: &str = op.unwrap().extract().unwrap(); ctx.text(op); }), PyFunction::CoRoutine(handler) => { let fut = Python::with_gil(|py| { let handler = handler.as_ref(py); - let coro = handler.call0().unwrap(); + let coro = match number_of_params { + 0 => handler.call0().unwrap(), + 1 => handler.call1((ws.id.to_string(),)).unwrap(), + // this is done to accomodate any future params + 2_u8..=u8::MAX => handler.call1((ws.id.to_string(),)).unwrap(), + }; pyo3_asyncio::into_future_with_loop((*(event_loop.clone())).as_ref(py), coro) .unwrap() }); @@ -61,16 +76,28 @@ impl Actor for MyWs { fn started(&mut self, ctx: &mut WebsocketContext) { let handler_function = &self.router.get("connect").unwrap().0; - let _number_of_params = &self.router.get("connect").unwrap().1; - execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self); + let number_of_params = &self.router.get("connect").unwrap().1; + execute_ws_functionn( + handler_function, + *number_of_params, + self.event_loop.clone(), + ctx, + self, + ); println!("Actor is alive"); } fn stopped(&mut self, ctx: &mut WebsocketContext) { let handler_function = &self.router.get("close").expect("No close function").0; - let _number_of_params = &self.router.get("close").unwrap().1; - execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self); + let number_of_params = &self.router.get("close").unwrap().1; + execute_ws_functionn( + handler_function, + *number_of_params, + self.event_loop.clone(), + ctx, + self, + ); println!("Actor is dead"); } @@ -87,9 +114,15 @@ impl StreamHandler> for MyWs { Ok(ws::Message::Ping(msg)) => { println!("Ping message {:?}", msg); let handler_function = &self.router.get("connect").unwrap().0; - let _number_of_params = &self.router.get("connect").unwrap().1; + let number_of_params = &self.router.get("connect").unwrap().1; println!("{:?}", handler_function); - execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self); + execute_ws_functionn( + handler_function, + *number_of_params, + self.event_loop.clone(), + ctx, + self, + ); ctx.pong(&msg) } @@ -101,16 +134,28 @@ impl StreamHandler> for MyWs { Ok(ws::Message::Text(_text)) => { // need to also passs this text as a param let handler_function = &self.router.get("message").unwrap().0; - let _number_of_params = &self.router.get("message").unwrap().1; - execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self); + let number_of_params = &self.router.get("message").unwrap().1; + execute_ws_functionn( + handler_function, + *number_of_params, + self.event_loop.clone(), + ctx, + self, + ); } Ok(ws::Message::Binary(bin)) => ctx.binary(bin), Ok(ws::Message::Close(_close_reason)) => { println!("Socket was closed"); let handler_function = &self.router.get("close").expect("No close function").0; - let _number_of_params = &self.router.get("close").unwrap().1; - execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self); + let number_of_params = &self.router.get("close").unwrap().1; + execute_ws_functionn( + handler_function, + *number_of_params, + self.event_loop.clone(), + ctx, + self, + ); } _ => (), } @@ -128,10 +173,10 @@ pub async fn start_web_socket( MyWs { router, event_loop, + id: Uuid::new_v4(), }, &req, stream, ); - println!("{:?}", resp); resp }