diff --git a/nautilus_core/common/src/ffi/msgbus.rs b/nautilus_core/common/src/ffi/msgbus.rs index e1d55a997104..ca65f60c72c0 100644 --- a/nautilus_core/common/src/ffi/msgbus.rs +++ b/nautilus_core/common/src/ffi/msgbus.rs @@ -117,6 +117,44 @@ pub unsafe extern "C" fn msgbus_has_subscribers( bus.has_subscribers(pattern.as_str()) as u8 } +/// # Safety +/// +/// - Assumes `topic_ptr` is a valid C string pointer. +/// - Assumes `handler_id_ptr` is a valid C string pointer. +/// - Assumes `py_callable_ptr` points to a valid Python callable. +#[no_mangle] +pub unsafe extern "C" fn msgbus_is_subscribed( + bus: MessageBus_API, + topic_ptr: *const c_char, + handler_id_ptr: *const c_char, + py_callable_ptr: *mut ffi::PyObject, +) -> u8 { + let topic = cstr_to_ustr(topic_ptr); + let handler_id = cstr_to_ustr(handler_id_ptr); + let py_callable = PyCallableWrapper { + ptr: py_callable_ptr, + }; + let handler = MessageHandler::new(handler_id, Some(py_callable), None); + bus.is_subscribed(topic.as_str(), handler) as u8 +} + +/// # Safety +/// +/// - Assumes `endpoint_ptr` is a valid C string pointer. +#[no_mangle] +pub unsafe extern "C" fn msgbus_is_regsitered( + bus: MessageBus_API, + endpoint_ptr: *const c_char, +) -> u8 { + let endpoint = cstr_to_string(endpoint_ptr); + bus.is_registered(&endpoint) as u8 +} + +#[no_mangle] +pub extern "C" fn msgbus_is_pending_request(bus: MessageBus_API, request_id: &UUID4) -> u8 { + bus.is_pending_response(request_id) as u8 +} + /// # Safety /// /// - Assumes `handler_id_ptr` is a valid C string pointer. @@ -138,6 +176,58 @@ pub unsafe extern "C" fn msgbus_subscribe( bus.subscribe(&topic, handler, Some(priority)); } +/// # Safety +/// +/// - Assumes `handler_id_ptr` is a valid C string pointer. +#[no_mangle] +pub unsafe extern "C" fn msgbus_unsubscribe( + mut bus: MessageBus_API, + topic_ptr: *const c_char, + handler_id_ptr: *const c_char, + py_callable_ptr: *mut ffi::PyObject, +) { + let topic = cstr_to_ustr(topic_ptr); + let handler_id = cstr_to_ustr(handler_id_ptr); + let py_callable = PyCallableWrapper { + ptr: py_callable_ptr, + }; + let handler = MessageHandler::new(handler_id, Some(py_callable), None); + + bus.unsubscribe(&topic, handler); +} + +/// # Safety +/// +/// - Assumes `endpoint_ptr` is a valid C string pointer. +/// - Assumes `handler_id_ptr` is a valid C string pointer. +/// - Assumes `py_callable_ptr` points to a valid Python callable. +#[no_mangle] +pub unsafe extern "C" fn msgbus_register( + mut bus: MessageBus_API, + endpoint_ptr: *const c_char, + handler_id_ptr: *const c_char, + py_callable_ptr: *mut ffi::PyObject, +) { + let endpoint = cstr_to_string(endpoint_ptr); + let handler_id = cstr_to_ustr(handler_id_ptr); + let wrapper = PyCallableWrapper { + ptr: py_callable_ptr, + }; + let handler = MessageHandler::new(handler_id, Some(wrapper), None); + bus.register(&endpoint, handler) +} + +/// # Safety +/// +/// - Assumes `endpoint_ptr` is a valid C string pointer. +/// - Assumes `handler_id_ptr` is a valid C string pointer. +/// - Assumes `py_callable_ptr` points to a valid Python callable. +#[no_mangle] +pub unsafe extern "C" fn msgbus_deregister(mut bus: MessageBus_API, endpoint_ptr: *const c_char) { + let endpoint = cstr_to_string(endpoint_ptr); + bus.deregister(&endpoint) +} + /// # Safety /// /// - Assumes `endpoint_ptr` is a valid C string pointer. @@ -230,226 +320,3 @@ pub unsafe extern "C" fn msgbus_is_matching( is_matching(&topic, &pattern) as u8 } - -//////////////////////////////////////////////////////////////////////////////// -// Tests -//////////////////////////////////////////////////////////////////////////////// -#[cfg(test)] -mod ffi_tests { - use std::{ffi::CString, ptr, rc::Rc}; - - use nautilus_core::message::Message; - use pyo3::FromPyPointer; - use rstest::*; - use ustr::Ustr; - - use super::*; - use crate::handlers::MessageHandler; - // Helper function to create a MessageHandler with a PyCallableWrapper for testing - fn create_handler() -> MessageHandler { - let py_callable_ptr = ptr::null_mut(); // Replace with an actual PyObject pointer if needed - let handler_id = Ustr::from("test_handler"); - MessageHandler::new( - handler_id, - Some(PyCallableWrapper { - ptr: py_callable_ptr, - }), - None, - ) - } - - #[rstest] - fn test_subscribe_rust_handler() { - let trader_id = TraderId::from("trader-001"); - let topic = "my-topic".to_string(); - - // TODO: Create a Python list and pass the message in a closure to the `append` method - let callback = Rc::new(|_m: Message| Python::with_gil(|_| {})); - let handler_id = Ustr::from("id_of_method"); - let handler = MessageHandler::new(handler_id, None, Some(callback)); - - let mut msgbus = MessageBus::new(trader_id, None); - msgbus.subscribe(&topic, handler, None); - - assert!(msgbus.has_subscribers(&topic)); - assert_eq!(msgbus.topics(), vec![topic]); - } - - #[ignore] - #[rstest] - fn test_msgbus_new() { - let trader_id = TraderId::from_str("trader-001").unwrap(); - let name = CString::new("Test MessageBus").unwrap(); - - // Create a new MessageBus using FFI - let bus = unsafe { msgbus_new(trader_id.to_string().as_ptr() as *const i8, name.as_ptr()) }; - - // Verify that the trader ID and name are set correctly - assert_eq!(bus.trader_id.to_string(), "trader-001"); - assert_eq!(bus.name, "Test MessageBus"); - } - - #[ignore] - #[rstest] - fn test_msgbus_endpoints() { - let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); - let endpoint1 = "endpoint1"; - let endpoint2 = "endpoint2"; - - // Register endpoints - bus.register(endpoint1, create_handler()); - bus.register(endpoint2, create_handler()); - - // Call msgbus_endpoints to get endpoints as a Python list - let py_list = msgbus_endpoints(MessageBus_API(Box::new(bus))); - - // Convert the Python list to a Vec of strings - let endpoints: Vec = Python::with_gil(|py| { - let py_list = unsafe { PyList::from_owned_ptr(py, py_list) }; - py_list - .into_iter() - .map(|item| item.extract::().unwrap()) - .collect() - }); - - // Verify that the endpoints are correctly retrieved - assert_eq!(endpoints.len(), 2); - assert!(endpoints.contains(&endpoint1.to_string())); - assert!(endpoints.contains(&endpoint2.to_string())); - } - - #[ignore] - #[rstest] - fn test_msgbus_topics() { - let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); - let topic1 = "topic1"; - let topic2 = "topic2"; - - // Subscribe to topics - bus.subscribe(topic1, create_handler(), None); - bus.subscribe(topic2, create_handler(), None); - - // Call msgbus_topics to get topics as a Python list - let py_list = msgbus_topics(MessageBus_API(Box::new(bus))); - - // Convert the Python list to a Vec of strings - let topics: Vec = Python::with_gil(|py| { - let py_list = unsafe { PyList::from_owned_ptr(py, py_list) }; - py_list - .into_iter() - .map(|item| item.extract::().unwrap()) - .collect() - }); - - // Verify that the topics are correctly retrieved - assert_eq!(topics.len(), 2); - assert!(topics.contains(&topic1.to_string())); - assert!(topics.contains(&topic2.to_string())); - } - - #[ignore] - #[rstest] - fn test_msgbus_subscribe() { - let bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); - let topic = "test-topic"; - - // Subscribe using FFI - unsafe { - let topic_ptr = CString::new(topic).unwrap().clone().as_ptr(); - let handler_id_ptr = CString::new("handler-001").unwrap().clone().as_ptr(); - msgbus_subscribe( - MessageBus_API(Box::new(bus.clone())), - topic_ptr, - handler_id_ptr, - ptr::null_mut(), - 1, - ); - - // Verify that the subscription is added - assert!(msgbus_has_subscribers(MessageBus_API(Box::new(bus)), topic_ptr) != 0); - } - } - - #[ignore] - #[rstest] - fn test_msgbus_get_endpoint() { - let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); - let endpoint = "test-endpoint"; - let handler = create_handler(); - - // Register an endpoint - bus.register(endpoint, handler.clone()); - - // Call msgbus_get_endpoint to get the handler as a PyObject - let py_callable = unsafe { - let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr(); - msgbus_get_endpoint(MessageBus_API(Box::new(bus)), endpoint_ptr) - }; - - // Verify that the PyObject pointer matches the registered handler's PyObject pointer - assert_eq!(py_callable, handler.py_callback.unwrap().ptr); - } - - #[ignore] - #[rstest] - fn test_msgbus_request_handler() { - let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); - let endpoint = "test-endpoint"; - let request_id = UUID4::new(); - - // Register an endpoint - bus.register(endpoint, create_handler()); - - // Call msgbus_request_handler to get the handler as a PyObject - let py_callable = unsafe { - let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr(); - msgbus_request_handler( - MessageBus_API(Box::new(bus.clone())), - endpoint_ptr, - request_id, - ) - }; - - // Verify that the PyObject pointer matches the registered handler's PyObject pointer - assert_eq!( - py_callable, - bus.endpoints[&Ustr::from(endpoint)] - .py_callback - .unwrap() - .ptr - ); - } - - #[ignore] - #[rstest] - fn test_msgbus_response_handler() { - let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); - let correlation_id = UUID4::new(); - - // Register a response handler - let handler = create_handler(); - bus.correlation_index - .insert(correlation_id.clone(), handler.clone()); - - // Call msgbus_response_handler to get the handler as a PyObject - let py_callable = - unsafe { msgbus_response_handler(MessageBus_API(Box::new(bus)), &correlation_id) }; - - assert_eq!(py_callable, handler.py_callback.unwrap().ptr); - } - - #[ignore] - #[rstest] - fn test_msgbus_is_matching() { - let topic = "data.quotes.BINANCE"; - let pattern = "data.*.BINANCE"; - - let result = unsafe { - let topic_ptr = CString::new(topic).unwrap().clone().as_ptr(); - let pattern_ptr = CString::new(pattern).unwrap().clone().as_ptr(); - msgbus_is_matching(topic_ptr, pattern_ptr) - }; - - assert_eq!(result, 1); - } -} diff --git a/nautilus_core/common/src/msgbus.rs b/nautilus_core/common/src/msgbus.rs index 6dad66053202..a69e23715ba1 100644 --- a/nautilus_core/common/src/msgbus.rs +++ b/nautilus_core/common/src/msgbus.rs @@ -145,6 +145,33 @@ impl MessageBus { .collect() } + /// Returns whether there are subscribers for the given `pattern`. + #[must_use] + pub fn has_subscribers(&self, pattern: &str) -> bool { + self.matching_handlers(&Ustr::from(pattern)) + .next() + .is_some() + } + + /// Returns whether there are subscribers for the given `pattern`. + #[must_use] + pub fn is_subscribed(&self, topic: &str, handler: MessageHandler) -> bool { + let sub = Subscription::new(Ustr::from(topic), handler, None); + self.subscriptions.contains_key(&sub) + } + + /// Returns whether there is a pending request for the given `request_id`. + #[must_use] + pub fn is_pending_response(&self, request_id: &UUID4) -> bool { + self.correlation_index.contains_key(request_id) + } + + /// Returns whether there are subscribers for the given `pattern`. + #[must_use] + pub fn is_registered(&self, endpoint: &str) -> bool { + self.endpoints.contains_key(&Ustr::from(endpoint)) + } + /// Registers the given `handler` for the `endpoint` address. pub fn register(&mut self, endpoint: &str, handler: MessageHandler) { // Updates value if key already exists @@ -153,7 +180,7 @@ impl MessageBus { /// Deregisters the given `handler` for the `endpoint` address. pub fn deregister(&mut self, endpoint: &str) { - // removes entry if it exists for endpoint + // Removes entry if it exists for endpoint self.endpoints.remove(&Ustr::from(endpoint)); } @@ -182,7 +209,6 @@ impl MessageBus { /// Unsubscribes the given `handler` from the `topic`. pub fn unsubscribe(&mut self, topic: &str, handler: MessageHandler) { let sub = Subscription::new(Ustr::from(topic), handler, None); - self.subscriptions.remove(&sub); } @@ -192,14 +218,8 @@ impl MessageBus { self.endpoints.get(&Ustr::from(endpoint)) } - /// Returns whether there are subscribers for the given `pattern`. - #[must_use] - pub fn has_subscribers(&self, pattern: &str) -> bool { - self.matching_handlers(&Ustr::from(pattern)) - .next() - .is_some() - } - + /// Returns the handler for the request `endpoint` and adds the request ID to the internal + /// correlation index to match with the expected response. #[must_use] pub fn request_handler( &mut self, @@ -214,6 +234,8 @@ impl MessageBus { } } + /// Returns the handler for the matching response `endpoint` based on the internal correlation + /// index. #[must_use] pub fn response_handler(&mut self, correlation_id: &UUID4) -> Option { self.correlation_index.remove(correlation_id) @@ -336,6 +358,31 @@ mod tests { assert!(!msgbus.has_subscribers("my-topic")); } + #[rstest] + fn test_is_subscribed_when_no_subscriptions() { + let msgbus = stub_msgbus(); + + let callback = stub_rust_callback(); + let handler_id = Ustr::from("1"); + let handler = MessageHandler::new(handler_id, None, Some(callback)); + + assert!(!msgbus.is_subscribed("my-topic", handler)); + } + + #[rstest] + fn test_is_registered_when_no_registrations() { + let msgbus = stub_msgbus(); + + assert!(!msgbus.is_registered("MyEndpoint")); + } + + #[rstest] + fn test_is_pending_response_when_no_requests() { + let msgbus = stub_msgbus(); + + assert!(!msgbus.is_pending_response(&UUID4::default())); + } + #[rstest] fn test_regsiter_endpoint() { let mut msgbus = stub_msgbus(); diff --git a/nautilus_trader/core/includes/common.h b/nautilus_trader/core/includes/common.h index 517e8414ccf3..193b9c7c1c0c 100644 --- a/nautilus_trader/core/includes/common.h +++ b/nautilus_trader/core/includes/common.h @@ -520,6 +520,27 @@ PyObject *msgbus_topics(struct MessageBus_API bus); */ uint8_t msgbus_has_subscribers(struct MessageBus_API bus, const char *pattern_ptr); +/** + * # Safety + * + * - Assumes `topic_ptr` is a valid C string pointer. + * - Assumes `handler_id_ptr` is a valid C string pointer. + * - Assumes `py_callable_ptr` points to a valid Python callable. + */ +uint8_t msgbus_is_subscribed(struct MessageBus_API bus, + const char *topic_ptr, + const char *handler_id_ptr, + PyObject *py_callable_ptr); + +/** + * # Safety + * + * - Assumes `endpoint_ptr` is a valid C string pointer. + */ +uint8_t msgbus_is_regsitered(struct MessageBus_API bus, const char *endpoint_ptr); + +uint8_t msgbus_is_pending_request(struct MessageBus_API bus, const UUID4_t *request_id); + /** * # Safety * @@ -531,6 +552,37 @@ void msgbus_subscribe(struct MessageBus_API bus, PyObject *py_callable_ptr, uint8_t priority); +/** + * # Safety + * + * - Assumes `handler_id_ptr` is a valid C string pointer. + */ +void msgbus_unsubscribe(struct MessageBus_API bus, + const char *topic_ptr, + const char *handler_id_ptr, + PyObject *py_callable_ptr); + +/** + * # Safety + * + * - Assumes `endpoint_ptr` is a valid C string pointer. + * - Assumes `handler_id_ptr` is a valid C string pointer. + * - Assumes `py_callable_ptr` points to a valid Python callable. + */ +void msgbus_register(struct MessageBus_API bus, + const char *endpoint_ptr, + const char *handler_id_ptr, + PyObject *py_callable_ptr); + +/** + * # Safety + * + * - Assumes `endpoint_ptr` is a valid C string pointer. + * - Assumes `handler_id_ptr` is a valid C string pointer. + * - Assumes `py_callable_ptr` points to a valid Python callable. + */ +void msgbus_deregister(struct MessageBus_API bus, const char *endpoint_ptr); + /** * # Safety * diff --git a/nautilus_trader/core/rust/common.pxd b/nautilus_trader/core/rust/common.pxd index 215fef0eb855..e743536f44e5 100644 --- a/nautilus_trader/core/rust/common.pxd +++ b/nautilus_trader/core/rust/common.pxd @@ -369,6 +369,23 @@ cdef extern from "../includes/common.h": # - Assumes `pattern_ptr` is a valid C string pointer. uint8_t msgbus_has_subscribers(MessageBus_API bus, const char *pattern_ptr); + # # Safety + # + # - Assumes `topic_ptr` is a valid C string pointer. + # - Assumes `handler_id_ptr` is a valid C string pointer. + # - Assumes `py_callable_ptr` points to a valid Python callable. + uint8_t msgbus_is_subscribed(MessageBus_API bus, + const char *topic_ptr, + const char *handler_id_ptr, + PyObject *py_callable_ptr); + + # # Safety + # + # - Assumes `endpoint_ptr` is a valid C string pointer. + uint8_t msgbus_is_regsitered(MessageBus_API bus, const char *endpoint_ptr); + + uint8_t msgbus_is_pending_request(MessageBus_API bus, const UUID4_t *request_id); + # # Safety # # - Assumes `handler_id_ptr` is a valid C string pointer. @@ -378,6 +395,31 @@ cdef extern from "../includes/common.h": PyObject *py_callable_ptr, uint8_t priority); + # # Safety + # + # - Assumes `handler_id_ptr` is a valid C string pointer. + void msgbus_unsubscribe(MessageBus_API bus, + const char *topic_ptr, + const char *handler_id_ptr, + PyObject *py_callable_ptr); + + # # Safety + # + # - Assumes `endpoint_ptr` is a valid C string pointer. + # - Assumes `handler_id_ptr` is a valid C string pointer. + # - Assumes `py_callable_ptr` points to a valid Python callable. + void msgbus_register(MessageBus_API bus, + const char *endpoint_ptr, + const char *handler_id_ptr, + PyObject *py_callable_ptr); + + # # Safety + # + # - Assumes `endpoint_ptr` is a valid C string pointer. + # - Assumes `handler_id_ptr` is a valid C string pointer. + # - Assumes `py_callable_ptr` points to a valid Python callable. + void msgbus_deregister(MessageBus_API bus, const char *endpoint_ptr); + # # Safety # # - Assumes `endpoint_ptr` is a valid C string pointer.