Skip to content

Commit

Permalink
Pass headers in websocket connection (#1299)
Browse files Browse the repository at this point in the history
Co-authored-by: ruthvik125 <[email protected]>
  • Loading branch information
twitu and ruthvik125 authored Oct 23, 2023
1 parent 1ac4da2 commit 773b9d7
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 67 deletions.
3 changes: 2 additions & 1 deletion nautilus_core/network/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use http::{HttpClient, HttpMethod, HttpResponse};
use pyo3::prelude::*;
use ratelimiter::quota::Quota;
use socket::{SocketClient, SocketConfig};
use websocket::WebSocketClient;
use websocket::{WebSocketClient, WebSocketConfig};

/// Loaded as nautilus_pyo3.network
#[pymodule]
Expand All @@ -33,6 +33,7 @@ pub fn network(_: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<Quota>()?;
m.add_class::<HttpResponse>()?;
m.add_class::<WebSocketClient>()?;
m.add_class::<WebSocketConfig>()?;
m.add_class::<SocketClient>()?;
m.add_class::<SocketConfig>()?;
Ok(())
Expand Down
170 changes: 124 additions & 46 deletions nautilus_core/network/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
// limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{sync::Arc, time::Duration};
use std::{str::FromStr, sync::Arc, time::Duration};

use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use hyper::header::HeaderName;
use nautilus_core::python::to_pyruntime_err;
use pyo3::{exceptions::PyException, prelude::*, types::PyBytes, PyObject, Python};
use tokio::{net::TcpStream, sync::Mutex, task, time::sleep};
use tokio_tungstenite::{
connect_async,
tungstenite::{Error, Message},
tungstenite::{client::IntoClientRequest, http::HeaderValue, Error, Message},
MaybeTlsStream, WebSocketStream,
};
use tracing::{debug, error};
Expand All @@ -34,6 +35,36 @@ type SharedMessageWriter =
Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

#[derive(Debug, Clone)]
#[cfg_attr(
feature = "python",
pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
url: String,
handler: PyObject,
headers: Vec<(String, String)>,
heartbeat: Option<u64>,
}

#[pymethods]
impl WebSocketConfig {
#[new]
fn new(
url: String,
handler: PyObject,
headers: Vec<(String, String)>,
heartbeat: Option<u64>,
) -> Self {
Self {
url,
handler,
headers,
heartbeat,
}
}
}

/// `WebSocketClient` connects to a websocket server to read and send messages.
///
/// The client is opinionated about how messages are read and written. It
Expand All @@ -50,44 +81,53 @@ type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
struct WebSocketClientInner {
config: WebSocketConfig,
read_task: task::JoinHandle<()>,
heartbeat_task: Option<task::JoinHandle<()>>,
writer: SharedMessageWriter,
url: String,
handler: PyObject,
heartbeat: Option<u64>,
}

impl WebSocketClientInner {
/// Create an inner websocket client.
pub async fn connect_url(
url: &str,
handler: PyObject,
heartbeat: Option<u64>,
) -> Result<Self, Error> {
let (writer, reader) = Self::connect_with_server(url).await?;
pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
let WebSocketConfig {
url,
handler,
heartbeat,
headers,
} = &config;
let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
let writer = Arc::new(Mutex::new(writer));
let handler_clone = handler.clone();

// Keep receiving messages from socket and pass them as arguments to handler
let read_task = Self::spawn_read_task(reader, handler);
let read_task = Self::spawn_read_task(reader, handler.clone());

let heartbeat_task = Self::spawn_heartbeat_task(heartbeat, writer.clone());
let heartbeat_task = Self::spawn_heartbeat_task(*heartbeat, writer.clone());

Ok(Self {
config,
read_task,
heartbeat_task,
writer,
url: url.to_string(),
handler: handler_clone,
heartbeat,
})
}

/// Connects with the server creating a tokio-tungstenite websocket stream.
#[inline]
pub async fn connect_with_server(url: &str) -> Result<(MessageWriter, MessageReader), Error> {
connect_async(url).await.map(|resp| resp.0.split())
pub async fn connect_with_server(
url: &str,
headers: Vec<(String, String)>,
) -> Result<(MessageWriter, MessageReader), Error> {
let mut request = url.into_client_request()?;
let req_headers = request.headers_mut();

headers.into_iter().for_each(|(key, val)| {
let header_value = HeaderValue::from_str(&val).unwrap();
let header_name = HeaderName::from_str(&key).unwrap();
req_headers.insert(header_name, header_value);
});

connect_async(request).await.map(|resp| resp.0.split())
}

/// Optionally spawn a hearbeat task to periodically ping the server.
Expand Down Expand Up @@ -188,13 +228,15 @@ impl WebSocketClientInner {
/// Make a new connection with server. Use the new read and write halves
/// to update self writer and read and heartbeat tasks.
pub async fn reconnect(&mut self) -> Result<(), Error> {
let (new_writer, reader) = Self::connect_with_server(&self.url).await?;
let (new_writer, reader) =
Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
let mut guard = self.writer.lock().await;
*guard = new_writer;
drop(guard);

self.read_task = Self::spawn_read_task(reader, self.handler.clone());
self.heartbeat_task = Self::spawn_heartbeat_task(self.heartbeat, self.writer.clone());
self.read_task = Self::spawn_read_task(reader, self.config.handler.clone());
self.heartbeat_task =
Self::spawn_heartbeat_task(self.config.heartbeat, self.writer.clone());

Ok(())
}
Expand Down Expand Up @@ -242,14 +284,12 @@ impl WebSocketClient {
/// Creates an inner client and controller task to reconnect or disconnect
/// the client. Also assumes ownership of writer from inner client
pub async fn connect(
url: &str,
handler: PyObject,
heartbeat: Option<u64>,
config: WebSocketConfig,
post_connection: Option<PyObject>,
post_reconnection: Option<PyObject>,
post_disconnection: Option<PyObject>,
) -> Result<Self, Error> {
let inner = WebSocketClientInner::connect_url(url, handler, heartbeat).await?;
let inner = WebSocketClientInner::connect_url(config).await?;
let writer = inner.writer.clone();
let disconnect_mode = Arc::new(Mutex::new(false));
let controller_task = Self::spawn_controller_task(
Expand All @@ -264,7 +304,7 @@ impl WebSocketClient {
Ok(_) => debug!("Called post_connection handler"),
Err(e) => error!("Error calling post_connection handler: {e}"),
});
}
};

Ok(Self {
writer,
Expand Down Expand Up @@ -364,19 +404,15 @@ impl WebSocketClient {
#[staticmethod]
#[pyo3(name = "connect")]
fn py_connect(
url: String,
handler: PyObject,
heartbeat: Option<u64>,
config: WebSocketConfig,
post_connection: Option<PyObject>,
post_reconnection: Option<PyObject>,
post_disconnection: Option<PyObject>,
py: Python<'_>,
) -> PyResult<&PyAny> {
pyo3_asyncio::tokio::future_into_py(py, async move {
Self::connect(
&url,
handler,
heartbeat,
config,
post_connection,
post_reconnection,
post_disconnection,
Expand Down Expand Up @@ -472,28 +508,69 @@ mod tests {
task::{self, JoinHandle},
time::{sleep, Duration},
};
use tokio_tungstenite::accept_async;
use tokio_tungstenite::{
accept_hdr_async,
tungstenite::{
handshake::server::{self, Callback},
http::HeaderValue,
},
};
use tracing::debug;
use tracing_test::traced_test;

use crate::websocket::WebSocketClient;
use crate::websocket::{WebSocketClient, WebSocketConfig};

struct TestServer {
task: JoinHandle<()>,
port: u16,
}

#[derive(Debug, Clone)]
struct TestCallback {
key: String,
value: HeaderValue,
}

impl Callback for TestCallback {
fn on_request(
self,
request: &server::Request,
response: server::Response,
) -> Result<server::Response, server::ErrorResponse> {
let _ = response;
let value = request.headers().get(&self.key);
assert!(value.is_some());

match request.headers().get(&self.key) {
Some(value) => {
assert_eq!(value, self.value);
()
}
_ => (),
}

Ok(response)
}
}

impl TestServer {
async fn setup() -> Self {
async fn setup(key: String, value: String) -> Self {
let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = TcpListener::local_addr(&server).unwrap().port();

let test_call_back = TestCallback {
key,
value: HeaderValue::from_str(&value).unwrap(),
};

// Setup test server
let task = task::spawn(async move {
// keep accepting connections
loop {
let (conn, _) = server.accept().await.unwrap();
let mut websocket = accept_async(conn).await.unwrap();
let mut websocket = accept_hdr_async(conn, test_call_back.clone())
.await
.unwrap();

task::spawn(async move {
loop {
Expand Down Expand Up @@ -529,9 +606,11 @@ mod tests {

const N: usize = 10;
let mut success_count = 0;
let header_key = "hello-custom-key".to_string();
let header_value = "hello-custom-value".to_string();

// Initialize test server
let server = TestServer::setup().await;
let server = TestServer::setup(header_key.clone(), header_value.clone()).await;

// Create counter class and handler that increments it
let (counter, handler) = Python::with_gil(|py| {
Expand Down Expand Up @@ -561,16 +640,15 @@ counter = Counter()",
(counter, handler)
});

let client = WebSocketClient::connect(
&format!("ws://127.0.0.1:{}", server.port),
let config = WebSocketConfig::new(
format!("ws://127.0.0.1:{}", server.port),
handler.clone(),
vec![(header_key, header_value)],
None,
None,
None,
None,
)
.await
.unwrap();
);
let client = WebSocketClient::connect(config, None, None, None)
.await
.unwrap();

// Send messages that increment the count
for _ in 0..N {
Expand Down
15 changes: 12 additions & 3 deletions nautilus_trader/core/nautilus_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ from nautilus_trader.core.data import Data


# Python Interface typing:
# We will eventually separate these into separate .pyi files per module, for now this at least
# provides import resolution as well as docstring.
# We will eventually separate these into a .pyi file per module, for now this at least
# provides import resolution as well as docstrings.

###################################################################################################
# Core
Expand Down Expand Up @@ -639,12 +639,21 @@ class Quota:
@classmethod
def rate_per_hour(cls, max_burst: int) -> Quota: ...

class WebSocketConfig:
def __init__(
self,
url: str,
handler: Callable[..., Any],
headers: list[tuple[str, str]],
heartbeat: int | None = None,
) -> None: ...

class WebSocketClient:
@classmethod
def connect(
cls,
url: str,
handler: Callable[[Any], Any],
handler: Callable[..., Any],
heartbeat: int | None = None,
post_connection: Callable[..., None] | None = None,
post_reconnection: Callable[..., None] | None = None,
Expand Down
Loading

0 comments on commit 773b9d7

Please sign in to comment.