Skip to content

Commit

Permalink
Continue core MessageBus with FFI
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Nov 13, 2023
1 parent 9789a99 commit 626b29b
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 233 deletions.
313 changes: 90 additions & 223 deletions nautilus_core/common/src/ffi/msgbus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,44 @@ pub unsafe extern "C" fn msgbus_has_subscribers(
bus.has_subscribers(pattern.as_str()) as u8
}

/// # Safety
///
/// - Assumes `topic_ptr` is a valid C string pointer.
/// - Assumes `handler_id_ptr` is a valid C string pointer.
/// - Assumes `py_callable_ptr` points to a valid Python callable.
#[no_mangle]
pub unsafe extern "C" fn msgbus_is_subscribed(
bus: MessageBus_API,
topic_ptr: *const c_char,
handler_id_ptr: *const c_char,
py_callable_ptr: *mut ffi::PyObject,
) -> u8 {
let topic = cstr_to_ustr(topic_ptr);
let handler_id = cstr_to_ustr(handler_id_ptr);
let py_callable = PyCallableWrapper {
ptr: py_callable_ptr,
};
let handler = MessageHandler::new(handler_id, Some(py_callable), None);
bus.is_subscribed(topic.as_str(), handler) as u8
}

/// # Safety
///
/// - Assumes `endpoint_ptr` is a valid C string pointer.
#[no_mangle]
pub unsafe extern "C" fn msgbus_is_regsitered(
bus: MessageBus_API,
endpoint_ptr: *const c_char,
) -> u8 {
let endpoint = cstr_to_string(endpoint_ptr);
bus.is_registered(&endpoint) as u8
}

#[no_mangle]
pub extern "C" fn msgbus_is_pending_request(bus: MessageBus_API, request_id: &UUID4) -> u8 {
bus.is_pending_response(request_id) as u8
}

/// # Safety
///
/// - Assumes `handler_id_ptr` is a valid C string pointer.
Expand All @@ -138,6 +176,58 @@ pub unsafe extern "C" fn msgbus_subscribe(
bus.subscribe(&topic, handler, Some(priority));
}

/// # Safety
///
/// - Assumes `handler_id_ptr` is a valid C string pointer.
#[no_mangle]
pub unsafe extern "C" fn msgbus_unsubscribe(
mut bus: MessageBus_API,
topic_ptr: *const c_char,
handler_id_ptr: *const c_char,
py_callable_ptr: *mut ffi::PyObject,
) {
let topic = cstr_to_ustr(topic_ptr);
let handler_id = cstr_to_ustr(handler_id_ptr);
let py_callable = PyCallableWrapper {
ptr: py_callable_ptr,
};
let handler = MessageHandler::new(handler_id, Some(py_callable), None);

bus.unsubscribe(&topic, handler);
}

/// # Safety
///
/// - Assumes `endpoint_ptr` is a valid C string pointer.
/// - Assumes `handler_id_ptr` is a valid C string pointer.
/// - Assumes `py_callable_ptr` points to a valid Python callable.
#[no_mangle]
pub unsafe extern "C" fn msgbus_register(
mut bus: MessageBus_API,
endpoint_ptr: *const c_char,
handler_id_ptr: *const c_char,
py_callable_ptr: *mut ffi::PyObject,
) {
let endpoint = cstr_to_string(endpoint_ptr);
let handler_id = cstr_to_ustr(handler_id_ptr);
let wrapper = PyCallableWrapper {
ptr: py_callable_ptr,
};
let handler = MessageHandler::new(handler_id, Some(wrapper), None);
bus.register(&endpoint, handler)
}

/// # Safety
///
/// - Assumes `endpoint_ptr` is a valid C string pointer.
/// - Assumes `handler_id_ptr` is a valid C string pointer.
/// - Assumes `py_callable_ptr` points to a valid Python callable.
#[no_mangle]
pub unsafe extern "C" fn msgbus_deregister(mut bus: MessageBus_API, endpoint_ptr: *const c_char) {
let endpoint = cstr_to_string(endpoint_ptr);
bus.deregister(&endpoint)
}

/// # Safety
///
/// - Assumes `endpoint_ptr` is a valid C string pointer.
Expand Down Expand Up @@ -230,226 +320,3 @@ pub unsafe extern "C" fn msgbus_is_matching(

is_matching(&topic, &pattern) as u8
}

////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod ffi_tests {
use std::{ffi::CString, ptr, rc::Rc};

use nautilus_core::message::Message;
use pyo3::FromPyPointer;
use rstest::*;
use ustr::Ustr;

use super::*;
use crate::handlers::MessageHandler;
// Helper function to create a MessageHandler with a PyCallableWrapper for testing
fn create_handler() -> MessageHandler {
let py_callable_ptr = ptr::null_mut(); // Replace with an actual PyObject pointer if needed
let handler_id = Ustr::from("test_handler");
MessageHandler::new(
handler_id,
Some(PyCallableWrapper {
ptr: py_callable_ptr,
}),
None,
)
}

#[rstest]
fn test_subscribe_rust_handler() {
let trader_id = TraderId::from("trader-001");
let topic = "my-topic".to_string();

// TODO: Create a Python list and pass the message in a closure to the `append` method
let callback = Rc::new(|_m: Message| Python::with_gil(|_| {}));
let handler_id = Ustr::from("id_of_method");
let handler = MessageHandler::new(handler_id, None, Some(callback));

let mut msgbus = MessageBus::new(trader_id, None);
msgbus.subscribe(&topic, handler, None);

assert!(msgbus.has_subscribers(&topic));
assert_eq!(msgbus.topics(), vec![topic]);
}

#[ignore]
#[rstest]
fn test_msgbus_new() {
let trader_id = TraderId::from_str("trader-001").unwrap();
let name = CString::new("Test MessageBus").unwrap();

// Create a new MessageBus using FFI
let bus = unsafe { msgbus_new(trader_id.to_string().as_ptr() as *const i8, name.as_ptr()) };

// Verify that the trader ID and name are set correctly
assert_eq!(bus.trader_id.to_string(), "trader-001");
assert_eq!(bus.name, "Test MessageBus");
}

#[ignore]
#[rstest]
fn test_msgbus_endpoints() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let endpoint1 = "endpoint1";
let endpoint2 = "endpoint2";

// Register endpoints
bus.register(endpoint1, create_handler());
bus.register(endpoint2, create_handler());

// Call msgbus_endpoints to get endpoints as a Python list
let py_list = msgbus_endpoints(MessageBus_API(Box::new(bus)));

// Convert the Python list to a Vec of strings
let endpoints: Vec<String> = Python::with_gil(|py| {
let py_list = unsafe { PyList::from_owned_ptr(py, py_list) };
py_list
.into_iter()
.map(|item| item.extract::<String>().unwrap())
.collect()
});

// Verify that the endpoints are correctly retrieved
assert_eq!(endpoints.len(), 2);
assert!(endpoints.contains(&endpoint1.to_string()));
assert!(endpoints.contains(&endpoint2.to_string()));
}

#[ignore]
#[rstest]
fn test_msgbus_topics() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let topic1 = "topic1";
let topic2 = "topic2";

// Subscribe to topics
bus.subscribe(topic1, create_handler(), None);
bus.subscribe(topic2, create_handler(), None);

// Call msgbus_topics to get topics as a Python list
let py_list = msgbus_topics(MessageBus_API(Box::new(bus)));

// Convert the Python list to a Vec of strings
let topics: Vec<String> = Python::with_gil(|py| {
let py_list = unsafe { PyList::from_owned_ptr(py, py_list) };
py_list
.into_iter()
.map(|item| item.extract::<String>().unwrap())
.collect()
});

// Verify that the topics are correctly retrieved
assert_eq!(topics.len(), 2);
assert!(topics.contains(&topic1.to_string()));
assert!(topics.contains(&topic2.to_string()));
}

#[ignore]
#[rstest]
fn test_msgbus_subscribe() {
let bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let topic = "test-topic";

// Subscribe using FFI
unsafe {
let topic_ptr = CString::new(topic).unwrap().clone().as_ptr();
let handler_id_ptr = CString::new("handler-001").unwrap().clone().as_ptr();
msgbus_subscribe(
MessageBus_API(Box::new(bus.clone())),
topic_ptr,
handler_id_ptr,
ptr::null_mut(),
1,
);

// Verify that the subscription is added
assert!(msgbus_has_subscribers(MessageBus_API(Box::new(bus)), topic_ptr) != 0);
}
}

#[ignore]
#[rstest]
fn test_msgbus_get_endpoint() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let endpoint = "test-endpoint";
let handler = create_handler();

// Register an endpoint
bus.register(endpoint, handler.clone());

// Call msgbus_get_endpoint to get the handler as a PyObject
let py_callable = unsafe {
let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr();
msgbus_get_endpoint(MessageBus_API(Box::new(bus)), endpoint_ptr)
};

// Verify that the PyObject pointer matches the registered handler's PyObject pointer
assert_eq!(py_callable, handler.py_callback.unwrap().ptr);
}

#[ignore]
#[rstest]
fn test_msgbus_request_handler() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let endpoint = "test-endpoint";
let request_id = UUID4::new();

// Register an endpoint
bus.register(endpoint, create_handler());

// Call msgbus_request_handler to get the handler as a PyObject
let py_callable = unsafe {
let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr();
msgbus_request_handler(
MessageBus_API(Box::new(bus.clone())),
endpoint_ptr,
request_id,
)
};

// Verify that the PyObject pointer matches the registered handler's PyObject pointer
assert_eq!(
py_callable,
bus.endpoints[&Ustr::from(endpoint)]
.py_callback
.unwrap()
.ptr
);
}

#[ignore]
#[rstest]
fn test_msgbus_response_handler() {
let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None);
let correlation_id = UUID4::new();

// Register a response handler
let handler = create_handler();
bus.correlation_index
.insert(correlation_id.clone(), handler.clone());

// Call msgbus_response_handler to get the handler as a PyObject
let py_callable =
unsafe { msgbus_response_handler(MessageBus_API(Box::new(bus)), &correlation_id) };

assert_eq!(py_callable, handler.py_callback.unwrap().ptr);
}

#[ignore]
#[rstest]
fn test_msgbus_is_matching() {
let topic = "data.quotes.BINANCE";
let pattern = "data.*.BINANCE";

let result = unsafe {
let topic_ptr = CString::new(topic).unwrap().clone().as_ptr();
let pattern_ptr = CString::new(pattern).unwrap().clone().as_ptr();
msgbus_is_matching(topic_ptr, pattern_ptr)
};

assert_eq!(result, 1);
}
}
Loading

0 comments on commit 626b29b

Please sign in to comment.