Skip to content

Commit

Permalink
Make websocket id accessible (#173)
Browse files Browse the repository at this point in the history
* Make websocket id accessible

* Socket ids are now accessible
  • Loading branch information
sansyrox authored Mar 24, 2022
1 parent 020cfd0 commit f970d25
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 15 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ actix-http = "3.0.0-beta.8"
socket2 = { version = "0.4.1", features = ["all"] }
actix = "0.12.0"
actix-web-actors = "4.0.0-beta.1"
uuid = { version = "0.8", features = ["serde", "v4"] }

[package.metadata.maturin]
name = "robyn"
3 changes: 2 additions & 1 deletion integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@


@websocket.on("message")
async def connect():
async def connect(websocket_id):
print(websocket_id)
global i
i += 1
if i == 0:
Expand Down
73 changes: 59 additions & 14 deletions src/web_socket_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ use actix_web::{web, Error, HttpRequest, HttpResponse};
use actix_web_actors::ws;
use actix_web_actors::ws::WebsocketContext;
use pyo3::prelude::*;
use uuid::Uuid;

use std::collections::HashMap;
use std::sync::Arc;

/// Define HTTP actor
#[derive(Clone)]
struct MyWs {
id: Uuid,
router: HashMap<String, (PyFunction, u8)>,
// can probably try removing arc from here
// and use clone_ref()
Expand All @@ -21,22 +23,35 @@ struct MyWs {

fn execute_ws_functionn(
handler_function: &PyFunction,
number_of_params: u8,
event_loop: Arc<PyObject>,
ctx: &mut ws::WebsocketContext<MyWs>,
ws: &MyWs,
// add number of params here
) {
match handler_function {
PyFunction::SyncFunction(handler) => Python::with_gil(|py| {
let handler = handler.as_ref(py);
// call execute function
let op = handler.call0().unwrap();
let op: &str = op.extract().unwrap();
let op: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
1 => handler.call1((ws.id.to_string(),)),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((ws.id.to_string(),)),
};

let op: &str = op.unwrap().extract().unwrap();
ctx.text(op);
}),
PyFunction::CoRoutine(handler) => {
let fut = Python::with_gil(|py| {
let handler = handler.as_ref(py);
let coro = handler.call0().unwrap();
let coro = match number_of_params {
0 => handler.call0().unwrap(),
1 => handler.call1((ws.id.to_string(),)).unwrap(),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((ws.id.to_string(),)).unwrap(),
};
pyo3_asyncio::into_future_with_loop((*(event_loop.clone())).as_ref(py), coro)
.unwrap()
});
Expand All @@ -61,16 +76,28 @@ impl Actor for MyWs {

fn started(&mut self, ctx: &mut WebsocketContext<Self>) {
let handler_function = &self.router.get("connect").unwrap().0;
let _number_of_params = &self.router.get("connect").unwrap().1;
execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self);
let number_of_params = &self.router.get("connect").unwrap().1;
execute_ws_functionn(
handler_function,
*number_of_params,
self.event_loop.clone(),
ctx,
self,
);

println!("Actor is alive");
}

fn stopped(&mut self, ctx: &mut WebsocketContext<Self>) {
let handler_function = &self.router.get("close").expect("No close function").0;
let _number_of_params = &self.router.get("close").unwrap().1;
execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self);
let number_of_params = &self.router.get("close").unwrap().1;
execute_ws_functionn(
handler_function,
*number_of_params,
self.event_loop.clone(),
ctx,
self,
);

println!("Actor is dead");
}
Expand All @@ -87,9 +114,15 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for MyWs {
Ok(ws::Message::Ping(msg)) => {
println!("Ping message {:?}", msg);
let handler_function = &self.router.get("connect").unwrap().0;
let _number_of_params = &self.router.get("connect").unwrap().1;
let number_of_params = &self.router.get("connect").unwrap().1;
println!("{:?}", handler_function);
execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self);
execute_ws_functionn(
handler_function,
*number_of_params,
self.event_loop.clone(),
ctx,
self,
);
ctx.pong(&msg)
}

Expand All @@ -101,16 +134,28 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for MyWs {
Ok(ws::Message::Text(_text)) => {
// need to also passs this text as a param
let handler_function = &self.router.get("message").unwrap().0;
let _number_of_params = &self.router.get("message").unwrap().1;
execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self);
let number_of_params = &self.router.get("message").unwrap().1;
execute_ws_functionn(
handler_function,
*number_of_params,
self.event_loop.clone(),
ctx,
self,
);
}

Ok(ws::Message::Binary(bin)) => ctx.binary(bin),
Ok(ws::Message::Close(_close_reason)) => {
println!("Socket was closed");
let handler_function = &self.router.get("close").expect("No close function").0;
let _number_of_params = &self.router.get("close").unwrap().1;
execute_ws_functionn(handler_function, self.event_loop.clone(), ctx, self);
let number_of_params = &self.router.get("close").unwrap().1;
execute_ws_functionn(
handler_function,
*number_of_params,
self.event_loop.clone(),
ctx,
self,
);
}
_ => (),
}
Expand All @@ -128,10 +173,10 @@ pub async fn start_web_socket(
MyWs {
router,
event_loop,
id: Uuid::new_v4(),
},
&req,
stream,
);
println!("{:?}", resp);
resp
}

0 comments on commit f970d25

Please sign in to comment.