Skip to content

Commit

Permalink
Continue core MessageBus with FFI
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Nov 13, 2023
1 parent cb4d4e0 commit 82956e1
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 85 deletions.
258 changes: 249 additions & 9 deletions nautilus_core/common/src/ffi/msgbus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ use std::{
str::FromStr,
};

use nautilus_core::ffi::{
cvec::CVec,
string::{cstr_to_string, cstr_to_ustr, optional_cstr_to_string},
use nautilus_core::{
ffi::{
cvec::CVec,
string::{cstr_to_string, cstr_to_ustr, optional_cstr_to_string},
},
uuid::UUID4,
};
use nautilus_model::identifiers::trader_id::TraderId;
use pyo3::{
Expand Down Expand Up @@ -77,7 +80,7 @@ pub unsafe extern "C" fn msgbus_new(
}

#[no_mangle]
pub extern "C" fn msgbus_endpoints(bus: MessageBus_API) -> *const ffi::PyObject {
pub extern "C" fn msgbus_endpoints(bus: MessageBus_API) -> *mut ffi::PyObject {
Python::with_gil(|py| -> Py<PyList> {
let endpoints: Vec<Py<PyString>> = bus
.endpoints()
Expand All @@ -90,7 +93,7 @@ pub extern "C" fn msgbus_endpoints(bus: MessageBus_API) -> *const ffi::PyObject
}

#[no_mangle]
pub extern "C" fn msgbus_topics(bus: MessageBus_API) -> *const ffi::PyObject {
pub extern "C" fn msgbus_topics(bus: MessageBus_API) -> *mut ffi::PyObject {
Python::with_gil(|py| -> Py<PyList> {
let topics: Vec<Py<PyString>> = bus
.endpoints()
Expand All @@ -102,6 +105,18 @@ pub extern "C" fn msgbus_topics(bus: MessageBus_API) -> *const ffi::PyObject {
.as_ptr()
}

/// # Safety
///
/// - Assumes `pattern_ptr` is a valid C string pointer.
#[no_mangle]
pub unsafe extern "C" fn msgbus_has_subscribers(
bus: MessageBus_API,
pattern_ptr: *const c_char,
) -> u8 {
let pattern = cstr_to_ustr(pattern_ptr);
bus.has_subscribers(pattern.as_str()) as u8
}

/// # Safety
///
/// - Assumes `handler_id_ptr` is a valid C string pointer.
Expand Down Expand Up @@ -130,7 +145,7 @@ pub unsafe extern "C" fn msgbus_subscribe(
pub unsafe extern "C" fn msgbus_get_endpoint(
bus: MessageBus_API,
endpoint_ptr: *const c_char,
) -> *const ffi::PyObject {
) -> *mut ffi::PyObject {
let endpoint = cstr_to_ustr(endpoint_ptr);

match bus.get_endpoint(&endpoint) {
Expand All @@ -148,7 +163,7 @@ pub unsafe extern "C" fn msgbus_get_matching_callables(
pattern_ptr: *const c_char,
) -> CVec {
let pattern = cstr_to_ustr(pattern_ptr);
let subs: Vec<&Subscription> = bus.get_matching_subscriptions(&pattern);
let subs: Vec<&Subscription> = bus.matching_subscriptions(&pattern);

subs.iter()
.map(|s| s.handler.py_callback.unwrap())
Expand All @@ -165,6 +180,42 @@ pub extern "C" fn vec_pycallable_drop(v: CVec) {
drop(data); // Memory freed here
}

/// # Safety
///
/// - Assumes `pattern_ptr` is a valid C string pointer.
#[no_mangle]
pub unsafe extern "C" fn msgbus_request_handler(
mut bus: MessageBus_API,
endpoint_ptr: *const c_char,
request_id: UUID4,
) -> *mut ffi::PyObject {
let endpoint = cstr_to_ustr(endpoint_ptr);
let handler = bus.request_handler(&endpoint, request_id);

if let Some(handler) = handler {
handler.py_callback.unwrap().ptr
} else {
ffi::Py_None()
}
}

/// # Safety
///
/// - Assumes `pattern_ptr` is a valid C string pointer.
#[no_mangle]
pub unsafe extern "C" fn msgbus_response_handler(
mut bus: MessageBus_API,
correlation_id: &UUID4,
) -> *mut ffi::PyObject {
let handler = bus.response_handler(correlation_id);

if let Some(handler) = handler {
handler.py_callback.unwrap().ptr
} else {
ffi::Py_None()
}
}

/// # Safety
///
/// - Assumes `topic_ptr` is a valid C string pointer.
Expand All @@ -184,15 +235,28 @@ pub unsafe extern "C" fn msgbus_is_matching(
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
use std::rc::Rc;
mod ffi_tests {
use std::{ffi::CString, ptr, rc::Rc};

use nautilus_core::message::Message;
use pyo3::FromPyPointer;
use rstest::*;
use ustr::Ustr;

use super::*;
use crate::handlers::MessageHandler;
// Helper function to create a MessageHandler with a PyCallableWrapper for testing
fn create_handler() -> MessageHandler {
let py_callable_ptr = ptr::null_mut(); // Replace with an actual PyObject pointer if needed
let handler_id = Ustr::from("test_handler");
MessageHandler::new(
handler_id,
Some(PyCallableWrapper {
ptr: py_callable_ptr,
}),
None,
)
}

#[rstest]
fn test_subscribe_rust_handler() {
Expand All @@ -210,4 +274,180 @@ mod tests {
assert!(msgbus.has_subscribers(&topic));
assert_eq!(msgbus.topics(), vec![topic]);
}

#[test]
fn test_msgbus_new() {
let trader_id = TraderId::from_str("trader-001").unwrap();
let name = CString::new("Test MessageBus").unwrap();

// Create a new MessageBus using FFI
let bus = unsafe { msgbus_new(trader_id.to_string().as_ptr() as *const i8, name.as_ptr()) };

// Verify that the trader ID and name are set correctly
assert_eq!(bus.trader_id.to_string(), "trader-001");
assert_eq!(bus.name, "Test MessageBus");
}

#[ignore]
#[test]
fn test_msgbus_endpoints() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let endpoint1 = "endpoint1";
let endpoint2 = "endpoint2";

// Register endpoints
bus.register(endpoint1, create_handler());
bus.register(endpoint2, create_handler());

// Call msgbus_endpoints to get endpoints as a Python list
let py_list = msgbus_endpoints(MessageBus_API(Box::new(bus)));

// Convert the Python list to a Vec of strings
let endpoints: Vec<String> = Python::with_gil(|py| {
let py_list = unsafe { PyList::from_owned_ptr(py, py_list) };
py_list
.into_iter()
.map(|item| item.extract::<String>().unwrap())
.collect()
});

// Verify that the endpoints are correctly retrieved
assert_eq!(endpoints.len(), 2);
assert!(endpoints.contains(&endpoint1.to_string()));
assert!(endpoints.contains(&endpoint2.to_string()));
}

#[ignore]
#[test]
fn test_msgbus_topics() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let topic1 = "topic1";
let topic2 = "topic2";

// Subscribe to topics
bus.subscribe(topic1, create_handler(), None);
bus.subscribe(topic2, create_handler(), None);

// Call msgbus_topics to get topics as a Python list
let py_list = msgbus_topics(MessageBus_API(Box::new(bus)));

// Convert the Python list to a Vec of strings
let topics: Vec<String> = Python::with_gil(|py| {
let py_list = unsafe { PyList::from_owned_ptr(py, py_list) };
py_list
.into_iter()
.map(|item| item.extract::<String>().unwrap())
.collect()
});

// Verify that the topics are correctly retrieved
assert_eq!(topics.len(), 2);
assert!(topics.contains(&topic1.to_string()));
assert!(topics.contains(&topic2.to_string()));
}

#[ignore]
#[test]
fn test_msgbus_subscribe() {
let bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let topic = "test-topic";

// Subscribe using FFI
unsafe {
let topic_ptr = CString::new(topic).unwrap().clone().as_ptr();
let handler_id_ptr = CString::new("handler-001").unwrap().clone().as_ptr();
msgbus_subscribe(
MessageBus_API(Box::new(bus.clone())),
topic_ptr,
handler_id_ptr,
ptr::null_mut(),
1,
);

// Verify that the subscription is added
assert!(msgbus_has_subscribers(MessageBus_API(Box::new(bus)), topic_ptr) != 0);
}
}

#[ignore]
#[test]
fn test_msgbus_get_endpoint() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let endpoint = "test-endpoint";
let handler = create_handler();

// Register an endpoint
bus.register(endpoint, handler.clone());

// Call msgbus_get_endpoint to get the handler as a PyObject
let py_callable = unsafe {
let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr();
msgbus_get_endpoint(MessageBus_API(Box::new(bus)), endpoint_ptr)
};

// Verify that the PyObject pointer matches the registered handler's PyObject pointer
assert_eq!(py_callable, handler.py_callback.unwrap().ptr);
}

#[ignore]
#[test]
fn test_msgbus_request_handler() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let endpoint = "test-endpoint";
let request_id = UUID4::new();

// Register an endpoint
bus.register(endpoint, create_handler());

// Call msgbus_request_handler to get the handler as a PyObject
let py_callable = unsafe {
let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr();
msgbus_request_handler(
MessageBus_API(Box::new(bus.clone())),
endpoint_ptr,
request_id,
)
};

// Verify that the PyObject pointer matches the registered handler's PyObject pointer
assert_eq!(
py_callable,
bus.endpoints[&Ustr::from(endpoint)]
.py_callback
.unwrap()
.ptr
);
}

#[ignore]
#[test]
fn test_msgbus_response_handler() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let correlation_id = UUID4::new();

// Register a response handler
let handler = create_handler();
bus.correlation_index
.insert(correlation_id.clone(), handler.clone());

// Call msgbus_response_handler to get the handler as a PyObject
let py_callable =
unsafe { msgbus_response_handler(MessageBus_API(Box::new(bus)), &correlation_id) };

assert_eq!(py_callable, handler.py_callback.unwrap().ptr);
}

#[test]
fn test_msgbus_is_matching() {
let topic = "data.quotes.BINANCE";
let pattern = "data.*.BINANCE";

let result = unsafe {
let topic_ptr = CString::new(topic).unwrap().clone().as_ptr();
let pattern_ptr = CString::new(pattern).unwrap().clone().as_ptr();
msgbus_is_matching(topic_ptr, pattern_ptr)
};

assert_eq!(result, 1);
}
}
19 changes: 17 additions & 2 deletions nautilus_core/common/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::rc::Rc;
use std::{fmt, rc::Rc};

use nautilus_core::message::Message;
use pyo3::{ffi, prelude::*, AsPyPointer};
Expand All @@ -22,7 +22,7 @@ use ustr::Ustr;
use crate::timer::TimeEvent;

#[repr(C)]
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct PyCallableWrapper {
pub ptr: *mut ffi::PyObject,
}
Expand Down Expand Up @@ -55,6 +55,21 @@ impl MessageHandler {
}
}

impl PartialEq for MessageHandler {
fn eq(&self, other: &Self) -> bool {
self.handler_id == other.handler_id
}
}

impl fmt::Debug for MessageHandler {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct(stringify!(MessageHandler))
.field("handler_id", &self.handler_id)
.field("py_callback", &self.py_callback)
.finish()
}
}

// TODO: Make this more generic
#[derive(Clone)]
pub struct EventHandler {
Expand Down
Loading

0 comments on commit 82956e1

Please sign in to comment.