From cb4d4e03569c0aa5b55c09720822e9dcac612345 Mon Sep 17 00:00:00 2001 From: Chris Sellers Date: Mon, 13 Nov 2023 16:49:16 +1100 Subject: [PATCH] Refine core MessageBus subscription matching --- nautilus_core/common/src/ffi/msgbus.rs | 36 ++--------- nautilus_core/common/src/handlers.rs | 2 +- nautilus_core/common/src/msgbus.rs | 89 +++++++++++--------------- nautilus_trader/core/includes/common.h | 10 +-- nautilus_trader/core/rust/common.pxd | 8 +-- 5 files changed, 44 insertions(+), 101 deletions(-) diff --git a/nautilus_core/common/src/ffi/msgbus.rs b/nautilus_core/common/src/ffi/msgbus.rs index 15d07ba6e5c5..b60361256a70 100644 --- a/nautilus_core/common/src/ffi/msgbus.rs +++ b/nautilus_core/common/src/ffi/msgbus.rs @@ -119,6 +119,7 @@ pub unsafe extern "C" fn msgbus_subscribe( ptr: py_callable_ptr, }; let handler = MessageHandler::new(handler_id, Some(py_callable), None); + bus.subscribe(&topic, handler, Some(priority)); } @@ -131,6 +132,7 @@ pub unsafe extern "C" fn msgbus_get_endpoint( endpoint_ptr: *const c_char, ) -> *const ffi::PyObject { let endpoint = cstr_to_ustr(endpoint_ptr); + match bus.get_endpoint(&endpoint) { Some(handler) => handler.py_callback.unwrap().ptr, None => ffi::Py_None(), @@ -141,12 +143,12 @@ pub unsafe extern "C" fn msgbus_get_endpoint( /// /// - Assumes `pattern_ptr` is a valid C string pointer. #[no_mangle] -pub unsafe extern "C" fn msgbus_get_matching_handlers( +pub unsafe extern "C" fn msgbus_get_matching_callables( mut bus: MessageBus_API, pattern_ptr: *const c_char, ) -> CVec { let pattern = cstr_to_ustr(pattern_ptr); - let subs: Vec<&Subscription> = bus.get_matching_handlers(&pattern); + let subs: Vec<&Subscription> = bus.get_matching_subscriptions(&pattern); subs.iter() .map(|s| s.handler.py_callback.unwrap()) @@ -154,36 +156,6 @@ pub unsafe extern "C" fn msgbus_get_matching_handlers( .into() } -/// # Safety -/// -/// - Assumes any registered handler has a Python callable. -/// - Assumes `endpoint_ptr` is a valid C string pointer. -// pub unsafe extern "C" fn msgbus_send( -// bus: MessageBus_API, -// endpoint_ptr: *const c_char, -// msg: *mut ffi::PyObject, -// ) { -// let endpoint = cstr_to_ustr(endpoint_ptr); -// -// if let Some(handler) = bus.get_endpoint(&endpoint) { -// let callable_ptr = handler.py_callback.unwrap().ptr; -// Python::with_gil(|py| { -// let callable = PyObject::from_borrowed_ptr(py, callable_ptr); -// let msg = PyObject::from_borrowed_ptr(py, msg); -// callable.call1(py, msg.into_py(py)); -// }); -// } -// } - -#[allow(clippy::drop_non_drop)] -#[no_mangle] -pub extern "C" fn vec_msgbus_handlers_drop(v: CVec) { - let CVec { ptr, len, cap } = v; - let data: Vec = - unsafe { Vec::from_raw_parts(ptr.cast::(), len, cap) }; - drop(data); // Memory freed here -} - #[allow(clippy::drop_non_drop)] #[no_mangle] pub extern "C" fn vec_pycallable_drop(v: CVec) { diff --git a/nautilus_core/common/src/handlers.rs b/nautilus_core/common/src/handlers.rs index d4e42fc06567..20bbd0550a0c 100644 --- a/nautilus_core/common/src/handlers.rs +++ b/nautilus_core/common/src/handlers.rs @@ -49,7 +49,7 @@ impl MessageHandler { } } - pub fn as_ptr(self) -> *const ffi::PyObject { + pub fn as_ptr(self) -> *mut ffi::PyObject { // SAFETY: Will panic if `unwrap` is called on None self.py_callback.unwrap().ptr } diff --git a/nautilus_core/common/src/msgbus.rs b/nautilus_core/common/src/msgbus.rs index 63d7cb46cda5..8508a2cbbb7b 100644 --- a/nautilus_core/common/src/msgbus.rs +++ b/nautilus_core/common/src/msgbus.rs @@ -18,7 +18,7 @@ use std::{ hash::{Hash, Hasher}, }; -use nautilus_core::{message::Message, uuid::UUID4}; +use nautilus_core::uuid::UUID4; use nautilus_model::identifiers::trader_id::TraderId; use ustr::Ustr; @@ -145,9 +145,9 @@ impl MessageBus { } /// Registers the given `handler` for the `endpoint` address. - pub fn register(&mut self, endpoint: String, handler: MessageHandler) { - // updates value if key already exists - self.endpoints.insert(Ustr::from(&endpoint), handler); + pub fn register(&mut self, endpoint: &str, handler: MessageHandler) { + // Updates value if key already exists + self.endpoints.insert(Ustr::from(endpoint), handler); } /// Deregisters the given `handler` for the `endpoint` address. @@ -199,14 +199,6 @@ impl MessageBus { .is_some() } - // fn send(&self, endpoint: &Ustr, msg: &Message) { - // if let Some(handler) = self.endpoints.get(endpoint) { - // if let Some(py_callable) = handler.py_callback { - // Python::with_gil(|| msg) - // } - // } - // } - // #[allow(unused_variables)] // fn request(&mut self, endpoint: &String, request: &Message, callback: T) { // match request { @@ -248,11 +240,6 @@ impl MessageBus { // } // } - // TODO: This is the modified version of matching_subscriptions - // Since we've separated subscription and handler we can choose to return - // one of those fields or reconstruct the subscription as a tuple and - // return that. - // Depends on on how the output of this function is meant to be used fn matching_handlers<'a>( &'a self, pattern: &'a Ustr, @@ -266,37 +253,34 @@ impl MessageBus { }) } - // TODO: Need to improve the efficiency of this - pub fn get_matching_handlers<'a>(&'a mut self, pattern: &'a Ustr) -> Vec<&'a Subscription> { - let matching_handlers = || { - self.subscriptions - .iter() - .filter_map(|(sub, _)| { - if is_matching(&sub.topic, pattern) { - Some(sub) - } else { - None - } - }) - .collect::>() - }; - - matching_handlers() - - // self.patterns - // .entry(*pattern) - // .or_insert_with(matching_handlers) - } - - pub fn publish(&mut self, pattern: Ustr, _msg: &Message) { - let _handlers = self.get_matching_handlers(&pattern); + pub fn get_matching_subscriptions<'a>( + &'a mut self, + pattern: &'a Ustr, + ) -> Vec<&'a Subscription> { + let mut matching_subs = self + .subscriptions + .iter() + .filter_map(|(sub, _)| { + if is_matching(&sub.topic, pattern) { + Some(sub) + } else { + None + } + }) + .collect::>(); + + for (p, subs) in &self.patterns { + if is_matching(p, pattern) { + matching_subs.extend(subs.iter()); + } + } - // call matched handlers - // handlers.iter().for_each(|handler| handler(msg)); + matching_subs.sort(); + matching_subs } } -/// match a topic and a string pattern +/// Match a topic and a string pattern /// pattern can contains - /// '*' - match 0 or more characters after this /// '?' - match any character once @@ -334,6 +318,7 @@ pub fn is_matching(topic: &Ustr, pattern: &Ustr) -> bool { mod tests { use std::rc::Rc; + use nautilus_core::message::Message; use rstest::*; use super::*; @@ -376,13 +361,13 @@ mod tests { #[rstest] fn test_regsiter_endpoint() { let mut msgbus = stub_msgbus(); - let endpoint = "MyEndpoint".to_string(); + let endpoint = "MyEndpoint"; let callback = stub_rust_callback(); let handler_id = Ustr::from("1"); let handler = MessageHandler::new(handler_id, None, Some(callback)); - msgbus.register(endpoint.clone(), handler.clone()); + msgbus.register(&endpoint, handler); assert_eq!(msgbus.endpoints(), vec!["MyEndpoint".to_string()]); assert!(msgbus.get_endpoint(&Ustr::from(&endpoint)).is_some()); @@ -391,13 +376,13 @@ mod tests { #[rstest] fn test_deregsiter_endpoint() { let mut msgbus = stub_msgbus(); - let endpoint = "MyEndpoint".to_string(); + let endpoint = "MyEndpoint"; let callback = stub_rust_callback(); let handler_id = Ustr::from("1"); let handler = MessageHandler::new(handler_id, None, Some(callback)); - msgbus.register(endpoint.clone(), handler.clone()); + msgbus.register(&endpoint, handler); msgbus.deregister(&endpoint); assert!(msgbus.endpoints().is_empty()); @@ -406,13 +391,13 @@ mod tests { #[rstest] fn test_subscribe() { let mut msgbus = stub_msgbus(); - let topic = "my-topic".to_string(); + let topic = "my-topic"; let callback = stub_rust_callback(); let handler_id = Ustr::from("1"); let handler = MessageHandler::new(handler_id, None, Some(callback)); - msgbus.subscribe(&topic, handler.clone(), Some(1)); + msgbus.subscribe(&topic, handler, Some(1)); assert!(msgbus.has_subscribers(&topic)); assert_eq!(msgbus.topics(), vec![topic]); @@ -421,14 +406,14 @@ mod tests { #[rstest] fn test_unsubscribe() { let mut msgbus = stub_msgbus(); - let topic = "my-topic".to_string(); + let topic = "my-topic"; let callback = stub_rust_callback(); let handler_id = Ustr::from("1"); let handler = MessageHandler::new(handler_id, None, Some(callback)); msgbus.subscribe(&topic, handler.clone(), None); - msgbus.unsubscribe(&topic, handler.clone()); + msgbus.unsubscribe(&topic, handler); assert!(!msgbus.has_subscribers(&topic)); assert!(msgbus.topics().is_empty()); diff --git a/nautilus_trader/core/includes/common.h b/nautilus_trader/core/includes/common.h index 149e4a731a77..e840d4fd0b73 100644 --- a/nautilus_trader/core/includes/common.h +++ b/nautilus_trader/core/includes/common.h @@ -536,15 +536,7 @@ const PyObject *msgbus_get_endpoint(struct MessageBus_API bus, const char *endpo * * - Assumes `pattern_ptr` is a valid C string pointer. */ -CVec msgbus_get_matching_handlers(struct MessageBus_API bus, const char *pattern_ptr); - -/** - * # Safety - * - * - Assumes any registered handler has a Python callable. - * - Assumes `endpoint_ptr` is a valid C string pointer. - */ -void vec_msgbus_handlers_drop(CVec v); +CVec msgbus_get_matching_callables(struct MessageBus_API bus, const char *pattern_ptr); void vec_pycallable_drop(CVec v); diff --git a/nautilus_trader/core/rust/common.pxd b/nautilus_trader/core/rust/common.pxd index 7b98642e735d..b7226a108730 100644 --- a/nautilus_trader/core/rust/common.pxd +++ b/nautilus_trader/core/rust/common.pxd @@ -381,13 +381,7 @@ cdef extern from "../includes/common.h": # # Safety # # - Assumes `pattern_ptr` is a valid C string pointer. - CVec msgbus_get_matching_handlers(MessageBus_API bus, const char *pattern_ptr); - - # # Safety - # - # - Assumes any registered handler has a Python callable. - # - Assumes `endpoint_ptr` is a valid C string pointer. - void vec_msgbus_handlers_drop(CVec v); + CVec msgbus_get_matching_callables(MessageBus_API bus, const char *pattern_ptr); void vec_pycallable_drop(CVec v);