Skip to content

Commit

Permalink
Refine raw socket error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Dec 18, 2023
1 parent b3a8fc5 commit 88a3628
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
53 changes: 27 additions & 26 deletions nautilus_core/network/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{io, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};

use nautilus_core::python::to_pyruntime_err;
use pyo3::{prelude::*, PyObject, Python};
Expand All @@ -26,7 +26,7 @@ use tokio::{
};
use tokio_tungstenite::{
tls::tcp_tls,
tungstenite::{client::IntoClientRequest, stream::Mode},
tungstenite::{client::IntoClientRequest, stream::Mode, Error},
MaybeTlsStream,
};
use tracing::{debug, error};
Expand Down Expand Up @@ -102,15 +102,15 @@ struct SocketClientInner {
}

impl SocketClientInner {
pub async fn connect_url(config: SocketConfig) -> io::Result<Self> {
pub async fn connect_url(config: SocketConfig) -> Result<Self, Error> {
let SocketConfig {
url,
mode,
heartbeat,
suffix,
handler,
} = &config;
let (reader, writer) = Self::tls_connect_with_server(url, *mode).await;
let (reader, writer) = Self::tls_connect_with_server(url, *mode).await?;
let shared_writer = Arc::new(Mutex::new(writer));

// Keep receiving messages from socket pass them as arguments to handler
Expand All @@ -128,16 +128,15 @@ impl SocketClientInner {
})
}

// TODO: handle unwraps properly
pub async fn tls_connect_with_server(url: &str, mode: Mode) -> (TcpReader, TcpWriter) {
pub async fn tls_connect_with_server(
url: &str,
mode: Mode,
) -> Result<(TcpReader, TcpWriter), Error> {
debug!("Connecting to server");
let stream = TcpStream::connect(url).await.unwrap();
let stream = TcpStream::connect(url).await?;
debug!("Making TLS connection");
let request = url.into_client_request().unwrap();
tcp_tls(&request, mode, stream, None)
.await
.map(split)
.unwrap()
let request = url.into_client_request()?;
tcp_tls(&request, mode, stream, None).await.map(split)
}

#[must_use]
Expand Down Expand Up @@ -217,7 +216,7 @@ impl SocketClientInner {
/// the connection might still be alive for some time before terminating.
/// Closing the connection is an async call which cannot be done by the
/// drop method so it must be done explicitly.
pub async fn shutdown(&mut self) {
pub async fn shutdown(&mut self) -> Result<(), std::io::Error> {
debug!("Abort read task");
if !self.read_task.is_finished() {
self.read_task.abort();
Expand All @@ -233,8 +232,7 @@ impl SocketClientInner {

debug!("Shutdown writer");
let mut writer = self.writer.lock().await;
writer.shutdown().await.unwrap();
debug!("Closed connection");
writer.shutdown().await
}

/// Reconnect with server.
Expand All @@ -243,7 +241,7 @@ impl SocketClientInner {
/// to update the shared writer and the read and heartbeat tasks.
///
/// TODO: fix error type
pub async fn reconnect(&mut self) -> Result<(), String> {
pub async fn reconnect(&mut self) -> Result<(), Error> {
let SocketConfig {
url,
mode,
Expand All @@ -252,7 +250,7 @@ impl SocketClientInner {
handler,
} = &self.config;
debug!("Reconnecting client");
let (reader, new_writer) = Self::tls_connect_with_server(url, *mode).await;
let (reader, new_writer) = Self::tls_connect_with_server(url, *mode).await?;

debug!("Use new writer end");
let mut guard = self.writer.lock().await;
Expand Down Expand Up @@ -311,7 +309,7 @@ impl SocketClient {
post_connection: Option<PyObject>,
post_reconnection: Option<PyObject>,
post_disconnection: Option<PyObject>,
) -> io::Result<Self> {
) -> Result<Self, Error> {
let suffix = config.suffix.clone();
let inner = SocketClientInner::connect_url(config).await?;
let writer = inner.writer.clone();
Expand Down Expand Up @@ -346,11 +344,10 @@ impl SocketClient {
*self.disconnect_mode.lock().await = true;
}

// TODO: fix error type
pub async fn send_bytes(&self, data: &[u8]) {
pub async fn send_bytes(&self, data: &[u8]) -> Result<(), std::io::Error> {
let mut writer = self.writer.lock().await;
writer.write_all(data).await.unwrap();
writer.write_all(&self.suffix).await.unwrap();
writer.write_all(data).await?;
writer.write_all(&self.suffix).await
}

#[must_use]
Expand Down Expand Up @@ -394,7 +391,11 @@ impl SocketClient {
},
(true, true) => {
debug!("Shutting down inner client");
inner.shutdown().await;
match inner.shutdown().await {
Ok(_) => debug!("Closed connection"),
Err(e) => error!("Error on `shutdown`: {e}"),
}

if let Some(ref handler) = post_disconnection {
Python::with_gil(|py| match handler.call0(py) {
Ok(_) => debug!("Called `post_disconnection` handler"),
Expand Down Expand Up @@ -632,7 +633,7 @@ counter = Counter()",

// Send messages that increment the count
for _ in 0..N {
client.send_bytes(b"ping".as_slice()).await;
let _ = client.send_bytes(b"ping".as_slice()).await;
}

sleep(Duration::from_secs(1)).await;
Expand All @@ -655,11 +656,11 @@ counter = Counter()",

// close the connection and wait
// client should reconnect automatically
client.send_bytes(b"close".as_slice()).await;
let _ = client.send_bytes(b"close".as_slice()).await;
sleep(Duration::from_secs(2)).await;

for _ in 0..N {
client.send_bytes(b"ping".as_slice()).await;
let _ = client.send_bytes(b"ping".as_slice()).await;
}

// Check count is same as number messages sent
Expand Down
2 changes: 1 addition & 1 deletion nautilus_trader/core/nautilus_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ class SocketConfig:
self,
url: str,
ssl: bool,
suffix: list[int],
suffix: bytes,
handler: Callable[..., Any],
heartbeat: tuple[int, list[int]] | None = None,
) -> None: ...
Expand Down
13 changes: 6 additions & 7 deletions tests/integration_tests/network/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,13 @@
from nautilus_trader.test_kit.functions import eventually


pytestmark = pytest.mark.skip(reason="WIP")


def _config(socket_server, handler):
host, port = socket_server
server_url = f"{host}:{port}"
return SocketConfig(
url=server_url,
handler=handler,
ssl=False,
handler=handler,
suffix=b"\r\n",
)

Expand All @@ -46,8 +43,8 @@ async def test_connect_and_disconnect(socket_server):

# Act, Assert
await eventually(lambda: client.is_alive)
await client.disconnect()
# await eventually(lambda: not client.is_alive)
client.disconnect()
# await eventually(lambda: not client.is_alive) # Investigate why client is staying alive?


@pytest.mark.asyncio()
Expand All @@ -64,7 +61,9 @@ async def test_client_send_recv(socket_server):
for _ in range(num_messages):
await client.send(b"Hello")
await asyncio.sleep(0.1)
await client.disconnect()

client.disconnect()
# await eventually(lambda: not client.is_alive) # Investigate why client is staying alive?

# Assert
await eventually(lambda: store == [b"connected"] + [b"hello"] * 2)
Expand Down

0 comments on commit 88a3628

Please sign in to comment.