From 82956e1125226c022780b4b683c25a2b789e911d Mon Sep 17 00:00:00 2001 From: Chris Sellers Date: Mon, 13 Nov 2023 18:25:31 +1100 Subject: [PATCH] Continue core MessageBus with FFI --- nautilus_core/common/src/ffi/msgbus.rs | 258 ++++++++++++++++++++++++- nautilus_core/common/src/handlers.rs | 19 +- nautilus_core/common/src/msgbus.rs | 148 +++++++------- nautilus_trader/core/includes/common.h | 29 ++- nautilus_trader/core/rust/common.pxd | 23 ++- 5 files changed, 392 insertions(+), 85 deletions(-) diff --git a/nautilus_core/common/src/ffi/msgbus.rs b/nautilus_core/common/src/ffi/msgbus.rs index b60361256a70..3edb3f80333c 100644 --- a/nautilus_core/common/src/ffi/msgbus.rs +++ b/nautilus_core/common/src/ffi/msgbus.rs @@ -19,9 +19,12 @@ use std::{ str::FromStr, }; -use nautilus_core::ffi::{ - cvec::CVec, - string::{cstr_to_string, cstr_to_ustr, optional_cstr_to_string}, +use nautilus_core::{ + ffi::{ + cvec::CVec, + string::{cstr_to_string, cstr_to_ustr, optional_cstr_to_string}, + }, + uuid::UUID4, }; use nautilus_model::identifiers::trader_id::TraderId; use pyo3::{ @@ -77,7 +80,7 @@ pub unsafe extern "C" fn msgbus_new( } #[no_mangle] -pub extern "C" fn msgbus_endpoints(bus: MessageBus_API) -> *const ffi::PyObject { +pub extern "C" fn msgbus_endpoints(bus: MessageBus_API) -> *mut ffi::PyObject { Python::with_gil(|py| -> Py { let endpoints: Vec> = bus .endpoints() @@ -90,7 +93,7 @@ pub extern "C" fn msgbus_endpoints(bus: MessageBus_API) -> *const ffi::PyObject } #[no_mangle] -pub extern "C" fn msgbus_topics(bus: MessageBus_API) -> *const ffi::PyObject { +pub extern "C" fn msgbus_topics(bus: MessageBus_API) -> *mut ffi::PyObject { Python::with_gil(|py| -> Py { let topics: Vec> = bus .endpoints() @@ -102,6 +105,18 @@ pub extern "C" fn msgbus_topics(bus: MessageBus_API) -> *const ffi::PyObject { .as_ptr() } +/// # Safety +/// +/// - Assumes `pattern_ptr` is a valid C string pointer. +#[no_mangle] +pub unsafe extern "C" fn msgbus_has_subscribers( + bus: MessageBus_API, + pattern_ptr: *const c_char, +) -> u8 { + let pattern = cstr_to_ustr(pattern_ptr); + bus.has_subscribers(pattern.as_str()) as u8 +} + /// # Safety /// /// - Assumes `handler_id_ptr` is a valid C string pointer. @@ -130,7 +145,7 @@ pub unsafe extern "C" fn msgbus_subscribe( pub unsafe extern "C" fn msgbus_get_endpoint( bus: MessageBus_API, endpoint_ptr: *const c_char, -) -> *const ffi::PyObject { +) -> *mut ffi::PyObject { let endpoint = cstr_to_ustr(endpoint_ptr); match bus.get_endpoint(&endpoint) { @@ -148,7 +163,7 @@ pub unsafe extern "C" fn msgbus_get_matching_callables( pattern_ptr: *const c_char, ) -> CVec { let pattern = cstr_to_ustr(pattern_ptr); - let subs: Vec<&Subscription> = bus.get_matching_subscriptions(&pattern); + let subs: Vec<&Subscription> = bus.matching_subscriptions(&pattern); subs.iter() .map(|s| s.handler.py_callback.unwrap()) @@ -165,6 +180,42 @@ pub extern "C" fn vec_pycallable_drop(v: CVec) { drop(data); // Memory freed here } +/// # Safety +/// +/// - Assumes `pattern_ptr` is a valid C string pointer. +#[no_mangle] +pub unsafe extern "C" fn msgbus_request_handler( + mut bus: MessageBus_API, + endpoint_ptr: *const c_char, + request_id: UUID4, +) -> *mut ffi::PyObject { + let endpoint = cstr_to_ustr(endpoint_ptr); + let handler = bus.request_handler(&endpoint, request_id); + + if let Some(handler) = handler { + handler.py_callback.unwrap().ptr + } else { + ffi::Py_None() + } +} + +/// # Safety +/// +/// - Assumes `pattern_ptr` is a valid C string pointer. +#[no_mangle] +pub unsafe extern "C" fn msgbus_response_handler( + mut bus: MessageBus_API, + correlation_id: &UUID4, +) -> *mut ffi::PyObject { + let handler = bus.response_handler(correlation_id); + + if let Some(handler) = handler { + handler.py_callback.unwrap().ptr + } else { + ffi::Py_None() + } +} + /// # Safety /// /// - Assumes `topic_ptr` is a valid C string pointer. @@ -184,15 +235,28 @@ pub unsafe extern "C" fn msgbus_is_matching( // Tests //////////////////////////////////////////////////////////////////////////////// #[cfg(test)] -mod tests { - use std::rc::Rc; +mod ffi_tests { + use std::{ffi::CString, ptr, rc::Rc}; use nautilus_core::message::Message; + use pyo3::FromPyPointer; use rstest::*; use ustr::Ustr; use super::*; use crate::handlers::MessageHandler; + // Helper function to create a MessageHandler with a PyCallableWrapper for testing + fn create_handler() -> MessageHandler { + let py_callable_ptr = ptr::null_mut(); // Replace with an actual PyObject pointer if needed + let handler_id = Ustr::from("test_handler"); + MessageHandler::new( + handler_id, + Some(PyCallableWrapper { + ptr: py_callable_ptr, + }), + None, + ) + } #[rstest] fn test_subscribe_rust_handler() { @@ -210,4 +274,180 @@ mod tests { assert!(msgbus.has_subscribers(&topic)); assert_eq!(msgbus.topics(), vec![topic]); } + + #[test] + fn test_msgbus_new() { + let trader_id = TraderId::from_str("trader-001").unwrap(); + let name = CString::new("Test MessageBus").unwrap(); + + // Create a new MessageBus using FFI + let bus = unsafe { msgbus_new(trader_id.to_string().as_ptr() as *const i8, name.as_ptr()) }; + + // Verify that the trader ID and name are set correctly + assert_eq!(bus.trader_id.to_string(), "trader-001"); + assert_eq!(bus.name, "Test MessageBus"); + } + + #[ignore] + #[test] + fn test_msgbus_endpoints() { + let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); + let endpoint1 = "endpoint1"; + let endpoint2 = "endpoint2"; + + // Register endpoints + bus.register(endpoint1, create_handler()); + bus.register(endpoint2, create_handler()); + + // Call msgbus_endpoints to get endpoints as a Python list + let py_list = msgbus_endpoints(MessageBus_API(Box::new(bus))); + + // Convert the Python list to a Vec of strings + let endpoints: Vec = Python::with_gil(|py| { + let py_list = unsafe { PyList::from_owned_ptr(py, py_list) }; + py_list + .into_iter() + .map(|item| item.extract::().unwrap()) + .collect() + }); + + // Verify that the endpoints are correctly retrieved + assert_eq!(endpoints.len(), 2); + assert!(endpoints.contains(&endpoint1.to_string())); + assert!(endpoints.contains(&endpoint2.to_string())); + } + + #[ignore] + #[test] + fn test_msgbus_topics() { + let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); + let topic1 = "topic1"; + let topic2 = "topic2"; + + // Subscribe to topics + bus.subscribe(topic1, create_handler(), None); + bus.subscribe(topic2, create_handler(), None); + + // Call msgbus_topics to get topics as a Python list + let py_list = msgbus_topics(MessageBus_API(Box::new(bus))); + + // Convert the Python list to a Vec of strings + let topics: Vec = Python::with_gil(|py| { + let py_list = unsafe { PyList::from_owned_ptr(py, py_list) }; + py_list + .into_iter() + .map(|item| item.extract::().unwrap()) + .collect() + }); + + // Verify that the topics are correctly retrieved + assert_eq!(topics.len(), 2); + assert!(topics.contains(&topic1.to_string())); + assert!(topics.contains(&topic2.to_string())); + } + + #[ignore] + #[test] + fn test_msgbus_subscribe() { + let bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); + let topic = "test-topic"; + + // Subscribe using FFI + unsafe { + let topic_ptr = CString::new(topic).unwrap().clone().as_ptr(); + let handler_id_ptr = CString::new("handler-001").unwrap().clone().as_ptr(); + msgbus_subscribe( + MessageBus_API(Box::new(bus.clone())), + topic_ptr, + handler_id_ptr, + ptr::null_mut(), + 1, + ); + + // Verify that the subscription is added + assert!(msgbus_has_subscribers(MessageBus_API(Box::new(bus)), topic_ptr) != 0); + } + } + + #[ignore] + #[test] + fn test_msgbus_get_endpoint() { + let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); + let endpoint = "test-endpoint"; + let handler = create_handler(); + + // Register an endpoint + bus.register(endpoint, handler.clone()); + + // Call msgbus_get_endpoint to get the handler as a PyObject + let py_callable = unsafe { + let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr(); + msgbus_get_endpoint(MessageBus_API(Box::new(bus)), endpoint_ptr) + }; + + // Verify that the PyObject pointer matches the registered handler's PyObject pointer + assert_eq!(py_callable, handler.py_callback.unwrap().ptr); + } + + #[ignore] + #[test] + fn test_msgbus_request_handler() { + let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); + let endpoint = "test-endpoint"; + let request_id = UUID4::new(); + + // Register an endpoint + bus.register(endpoint, create_handler()); + + // Call msgbus_request_handler to get the handler as a PyObject + let py_callable = unsafe { + let endpoint_ptr = CString::new(endpoint).unwrap().clone().as_ptr(); + msgbus_request_handler( + MessageBus_API(Box::new(bus.clone())), + endpoint_ptr, + request_id, + ) + }; + + // Verify that the PyObject pointer matches the registered handler's PyObject pointer + assert_eq!( + py_callable, + bus.endpoints[&Ustr::from(endpoint)] + .py_callback + .unwrap() + .ptr + ); + } + + #[ignore] + #[test] + fn test_msgbus_response_handler() { + let mut bus = MessageBus::new(TraderId::from_str("trader-001").unwrap(), None); + let correlation_id = UUID4::new(); + + // Register a response handler + let handler = create_handler(); + bus.correlation_index + .insert(correlation_id.clone(), handler.clone()); + + // Call msgbus_response_handler to get the handler as a PyObject + let py_callable = + unsafe { msgbus_response_handler(MessageBus_API(Box::new(bus)), &correlation_id) }; + + assert_eq!(py_callable, handler.py_callback.unwrap().ptr); + } + + #[test] + fn test_msgbus_is_matching() { + let topic = "data.quotes.BINANCE"; + let pattern = "data.*.BINANCE"; + + let result = unsafe { + let topic_ptr = CString::new(topic).unwrap().clone().as_ptr(); + let pattern_ptr = CString::new(pattern).unwrap().clone().as_ptr(); + msgbus_is_matching(topic_ptr, pattern_ptr) + }; + + assert_eq!(result, 1); + } } diff --git a/nautilus_core/common/src/handlers.rs b/nautilus_core/common/src/handlers.rs index 20bbd0550a0c..0404b92a50d0 100644 --- a/nautilus_core/common/src/handlers.rs +++ b/nautilus_core/common/src/handlers.rs @@ -13,7 +13,7 @@ // limitations under the License. // ------------------------------------------------------------------------------------------------- -use std::rc::Rc; +use std::{fmt, rc::Rc}; use nautilus_core::message::Message; use pyo3::{ffi, prelude::*, AsPyPointer}; @@ -22,7 +22,7 @@ use ustr::Ustr; use crate::timer::TimeEvent; #[repr(C)] -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct PyCallableWrapper { pub ptr: *mut ffi::PyObject, } @@ -55,6 +55,21 @@ impl MessageHandler { } } +impl PartialEq for MessageHandler { + fn eq(&self, other: &Self) -> bool { + self.handler_id == other.handler_id + } +} + +impl fmt::Debug for MessageHandler { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(stringify!(MessageHandler)) + .field("handler_id", &self.handler_id) + .field("py_callback", &self.py_callback) + .finish() + } +} + // TODO: Make this more generic #[derive(Clone)] pub struct EventHandler { diff --git a/nautilus_core/common/src/msgbus.rs b/nautilus_core/common/src/msgbus.rs index 8508a2cbbb7b..6dad66053202 100644 --- a/nautilus_core/common/src/msgbus.rs +++ b/nautilus_core/common/src/msgbus.rs @@ -93,6 +93,7 @@ impl Hash for Subscription { /// `camp` and `comp`. The question mark can also be used more than once. /// For example, `c??p` would match both of the above examples and `coop`. #[allow(dead_code)] +#[derive(Clone)] pub struct MessageBus { /// The trader ID for the message bus. pub trader_id: TraderId, @@ -107,11 +108,11 @@ pub struct MessageBus { /// this is updated whenever a new subscription is created. patterns: HashMap>, /// handles a message or a request destined for a specific endpoint. - endpoints: HashMap, + pub endpoints: HashMap, /// Relates a request with a response /// a request maps it's id to a handler so that a response /// with the same id can later be handled. - correlation_index: HashMap, + pub correlation_index: HashMap, } #[allow(dead_code)] @@ -199,46 +200,49 @@ impl MessageBus { .is_some() } - // #[allow(unused_variables)] - // fn request(&mut self, endpoint: &String, request: &Message, callback: T) { - // match request { - // Message::Request { id, ts_init } => { - // if self.correlation_index.contains_key(id) { - // todo!() - // } else { - // self.correlation_index.insert(*id, callback); - // if let Some(handler) = self.endpoints.get(endpoint) { - // handler(request); - // } else { - // // TODO: log error - // } - // } - // } - // _ => unreachable!( - // "message bus request should only be called with Message::Request variant" - // ), - // } - // } - // - // #[allow(unused_variables)] - // fn response(&mut self, response: &Message) { - // match response { - // Message::Response { - // id, - // ts_init, - // correlation_id, - // } => { - // if let Some(callback) = self.correlation_index.get(correlation_id) { - // callback(response); - // } else { - // // TODO: log error - // } - // } - // _ => unreachable!( - // "message bus response should only be called with Message::Response variant" - // ), - // } - // } + #[must_use] + pub fn request_handler( + &mut self, + endpoint: &Ustr, + request_id: UUID4, + ) -> Option<&MessageHandler> { + if let Some(handler) = self.endpoints.get(endpoint) { + self.correlation_index.insert(request_id, handler.clone()); + Some(handler) + } else { + None + } + } + + #[must_use] + pub fn response_handler(&mut self, correlation_id: &UUID4) -> Option { + self.correlation_index.remove(correlation_id) + } + + #[must_use] + pub fn matching_subscriptions<'a>(&'a mut self, pattern: &'a Ustr) -> Vec<&'a Subscription> { + let mut unique_subs = std::collections::HashSet::new(); + + // Collect matching subscriptions from direct subscriptions + unique_subs.extend(self.subscriptions.iter().filter_map(|(sub, _)| { + if is_matching(&sub.topic, pattern) { + Some(sub) + } else { + None + } + })); + + // Collect matching subscriptions from pattern-based subscriptions + for subs in self.patterns.values() { + unique_subs.extend(subs.iter().filter(|sub| is_matching(&sub.topic, pattern))); + } + + // Sort into priority order + let mut matching_subs = unique_subs.into_iter().collect::>(); + matching_subs.sort(); + + matching_subs + } fn matching_handlers<'a>( &'a self, @@ -252,32 +256,6 @@ impl MessageBus { } }) } - - pub fn get_matching_subscriptions<'a>( - &'a mut self, - pattern: &'a Ustr, - ) -> Vec<&'a Subscription> { - let mut matching_subs = self - .subscriptions - .iter() - .filter_map(|(sub, _)| { - if is_matching(&sub.topic, pattern) { - Some(sub) - } else { - None - } - }) - .collect::>(); - - for (p, subs) in &self.patterns { - if is_matching(p, pattern) { - matching_subs.extend(subs.iter()); - } - } - - matching_subs.sort(); - matching_subs - } } /// Match a topic and a string pattern @@ -419,6 +397,40 @@ mod tests { assert!(msgbus.topics().is_empty()); } + #[rstest] + fn test_request_handler() { + let mut msgbus = stub_msgbus(); + let endpoint = "MyEndpoint"; + let request_id = UUID4::new(); + + let callback = stub_rust_callback(); + let handler_id = Ustr::from("1"); + let handler = MessageHandler::new(handler_id.clone(), None, Some(callback)); + + msgbus.register(&endpoint, handler.clone()); + + assert_eq!( + msgbus.request_handler(&Ustr::from(endpoint), request_id.clone()), + Some(&handler) + ); + } + + #[rstest] + fn test_response_handler() { + let mut msgbus = stub_msgbus(); + let correlation_id = UUID4::new(); + + let callback = stub_rust_callback(); + let handler_id = Ustr::from("1"); + let handler = MessageHandler::new(handler_id.clone(), None, Some(callback)); + + msgbus + .correlation_index + .insert(correlation_id.clone(), handler.clone()); + + assert_eq!(msgbus.response_handler(&correlation_id), Some(handler)); + } + #[rstest] #[case("*", "*", true)] #[case("a", "*", true)] diff --git a/nautilus_trader/core/includes/common.h b/nautilus_trader/core/includes/common.h index e840d4fd0b73..517e8414ccf3 100644 --- a/nautilus_trader/core/includes/common.h +++ b/nautilus_trader/core/includes/common.h @@ -509,9 +509,16 @@ void logger_log(struct Logger_API *logger, */ struct MessageBus_API msgbus_new(const char *trader_id_ptr, const char *name_ptr); -const PyObject *msgbus_endpoints(struct MessageBus_API bus); +PyObject *msgbus_endpoints(struct MessageBus_API bus); -const PyObject *msgbus_topics(struct MessageBus_API bus); +PyObject *msgbus_topics(struct MessageBus_API bus); + +/** + * # Safety + * + * - Assumes `pattern_ptr` is a valid C string pointer. + */ +uint8_t msgbus_has_subscribers(struct MessageBus_API bus, const char *pattern_ptr); /** * # Safety @@ -529,7 +536,7 @@ void msgbus_subscribe(struct MessageBus_API bus, * * - Assumes `endpoint_ptr` is a valid C string pointer. */ -const PyObject *msgbus_get_endpoint(struct MessageBus_API bus, const char *endpoint_ptr); +PyObject *msgbus_get_endpoint(struct MessageBus_API bus, const char *endpoint_ptr); /** * # Safety @@ -540,6 +547,22 @@ CVec msgbus_get_matching_callables(struct MessageBus_API bus, const char *patter void vec_pycallable_drop(CVec v); +/** + * # Safety + * + * - Assumes `pattern_ptr` is a valid C string pointer. + */ +PyObject *msgbus_request_handler(struct MessageBus_API bus, + const char *endpoint_ptr, + UUID4_t request_id); + +/** + * # Safety + * + * - Assumes `pattern_ptr` is a valid C string pointer. + */ +PyObject *msgbus_response_handler(struct MessageBus_API bus, const UUID4_t *correlation_id); + /** * # Safety * diff --git a/nautilus_trader/core/rust/common.pxd b/nautilus_trader/core/rust/common.pxd index b7226a108730..215fef0eb855 100644 --- a/nautilus_trader/core/rust/common.pxd +++ b/nautilus_trader/core/rust/common.pxd @@ -360,9 +360,14 @@ cdef extern from "../includes/common.h": # - Assumes `name_ptr` is a valid C string pointer. MessageBus_API msgbus_new(const char *trader_id_ptr, const char *name_ptr); - const PyObject *msgbus_endpoints(MessageBus_API bus); + PyObject *msgbus_endpoints(MessageBus_API bus); - const PyObject *msgbus_topics(MessageBus_API bus); + PyObject *msgbus_topics(MessageBus_API bus); + + # # Safety + # + # - Assumes `pattern_ptr` is a valid C string pointer. + uint8_t msgbus_has_subscribers(MessageBus_API bus, const char *pattern_ptr); # # Safety # @@ -376,7 +381,7 @@ cdef extern from "../includes/common.h": # # Safety # # - Assumes `endpoint_ptr` is a valid C string pointer. - const PyObject *msgbus_get_endpoint(MessageBus_API bus, const char *endpoint_ptr); + PyObject *msgbus_get_endpoint(MessageBus_API bus, const char *endpoint_ptr); # # Safety # @@ -385,6 +390,18 @@ cdef extern from "../includes/common.h": void vec_pycallable_drop(CVec v); + # # Safety + # + # - Assumes `pattern_ptr` is a valid C string pointer. + PyObject *msgbus_request_handler(MessageBus_API bus, + const char *endpoint_ptr, + UUID4_t request_id); + + # # Safety + # + # - Assumes `pattern_ptr` is a valid C string pointer. + PyObject *msgbus_response_handler(MessageBus_API bus, const UUID4_t *correlation_id); + # # Safety # # - Assumes `topic_ptr` is a valid C string pointer.