Skip to content

Commit

Permalink
Refine core MessageBus subscription matching
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Nov 13, 2023
1 parent de85a33 commit cb4d4e0
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 101 deletions.
36 changes: 4 additions & 32 deletions nautilus_core/common/src/ffi/msgbus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ pub unsafe extern "C" fn msgbus_subscribe(
ptr: py_callable_ptr,
};
let handler = MessageHandler::new(handler_id, Some(py_callable), None);

bus.subscribe(&topic, handler, Some(priority));
}

Expand All @@ -131,6 +132,7 @@ pub unsafe extern "C" fn msgbus_get_endpoint(
endpoint_ptr: *const c_char,
) -> *const ffi::PyObject {
let endpoint = cstr_to_ustr(endpoint_ptr);

match bus.get_endpoint(&endpoint) {
Some(handler) => handler.py_callback.unwrap().ptr,
None => ffi::Py_None(),
Expand All @@ -141,49 +143,19 @@ pub unsafe extern "C" fn msgbus_get_endpoint(
///
/// - Assumes `pattern_ptr` is a valid C string pointer.
#[no_mangle]
pub unsafe extern "C" fn msgbus_get_matching_handlers(
pub unsafe extern "C" fn msgbus_get_matching_callables(
mut bus: MessageBus_API,
pattern_ptr: *const c_char,
) -> CVec {
let pattern = cstr_to_ustr(pattern_ptr);
let subs: Vec<&Subscription> = bus.get_matching_handlers(&pattern);
let subs: Vec<&Subscription> = bus.get_matching_subscriptions(&pattern);

subs.iter()
.map(|s| s.handler.py_callback.unwrap())
.collect::<Vec<PyCallableWrapper>>()
.into()
}

/// # Safety
///
/// - Assumes any registered handler has a Python callable.
/// - Assumes `endpoint_ptr` is a valid C string pointer.
// pub unsafe extern "C" fn msgbus_send(
// bus: MessageBus_API,
// endpoint_ptr: *const c_char,
// msg: *mut ffi::PyObject,
// ) {
// let endpoint = cstr_to_ustr(endpoint_ptr);
//
// if let Some(handler) = bus.get_endpoint(&endpoint) {
// let callable_ptr = handler.py_callback.unwrap().ptr;
// Python::with_gil(|py| {
// let callable = PyObject::from_borrowed_ptr(py, callable_ptr);
// let msg = PyObject::from_borrowed_ptr(py, msg);
// callable.call1(py, msg.into_py(py));
// });
// }
// }

#[allow(clippy::drop_non_drop)]
#[no_mangle]
pub extern "C" fn vec_msgbus_handlers_drop(v: CVec) {
let CVec { ptr, len, cap } = v;
let data: Vec<ffi::PyObject> =
unsafe { Vec::from_raw_parts(ptr.cast::<ffi::PyObject>(), len, cap) };
drop(data); // Memory freed here
}

#[allow(clippy::drop_non_drop)]
#[no_mangle]
pub extern "C" fn vec_pycallable_drop(v: CVec) {
Expand Down
2 changes: 1 addition & 1 deletion nautilus_core/common/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl MessageHandler {
}
}

pub fn as_ptr(self) -> *const ffi::PyObject {
pub fn as_ptr(self) -> *mut ffi::PyObject {
// SAFETY: Will panic if `unwrap` is called on None
self.py_callback.unwrap().ptr
}
Expand Down
89 changes: 37 additions & 52 deletions nautilus_core/common/src/msgbus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::{
hash::{Hash, Hasher},
};

use nautilus_core::{message::Message, uuid::UUID4};
use nautilus_core::uuid::UUID4;
use nautilus_model::identifiers::trader_id::TraderId;
use ustr::Ustr;

Expand Down Expand Up @@ -145,9 +145,9 @@ impl MessageBus {
}

/// Registers the given `handler` for the `endpoint` address.
pub fn register(&mut self, endpoint: String, handler: MessageHandler) {
// updates value if key already exists
self.endpoints.insert(Ustr::from(&endpoint), handler);
pub fn register(&mut self, endpoint: &str, handler: MessageHandler) {
// Updates value if key already exists
self.endpoints.insert(Ustr::from(endpoint), handler);
}

/// Deregisters the given `handler` for the `endpoint` address.
Expand Down Expand Up @@ -199,14 +199,6 @@ impl MessageBus {
.is_some()
}

// fn send(&self, endpoint: &Ustr, msg: &Message) {
// if let Some(handler) = self.endpoints.get(endpoint) {
// if let Some(py_callable) = handler.py_callback {
// Python::with_gil(|| msg)
// }
// }
// }

// #[allow(unused_variables)]
// fn request(&mut self, endpoint: &String, request: &Message, callback: T) {
// match request {
Expand Down Expand Up @@ -248,11 +240,6 @@ impl MessageBus {
// }
// }

// TODO: This is the modified version of matching_subscriptions
// Since we've separated subscription and handler we can choose to return
// one of those fields or reconstruct the subscription as a tuple and
// return that.
// Depends on on how the output of this function is meant to be used
fn matching_handlers<'a>(
&'a self,
pattern: &'a Ustr,
Expand All @@ -266,37 +253,34 @@ impl MessageBus {
})
}

// TODO: Need to improve the efficiency of this
pub fn get_matching_handlers<'a>(&'a mut self, pattern: &'a Ustr) -> Vec<&'a Subscription> {
let matching_handlers = || {
self.subscriptions
.iter()
.filter_map(|(sub, _)| {
if is_matching(&sub.topic, pattern) {
Some(sub)
} else {
None
}
})
.collect::<Vec<&'a Subscription>>()
};

matching_handlers()

// self.patterns
// .entry(*pattern)
// .or_insert_with(matching_handlers)
}

pub fn publish(&mut self, pattern: Ustr, _msg: &Message) {
let _handlers = self.get_matching_handlers(&pattern);
pub fn get_matching_subscriptions<'a>(
&'a mut self,
pattern: &'a Ustr,
) -> Vec<&'a Subscription> {
let mut matching_subs = self
.subscriptions
.iter()
.filter_map(|(sub, _)| {
if is_matching(&sub.topic, pattern) {
Some(sub)
} else {
None
}
})
.collect::<Vec<&'a Subscription>>();

for (p, subs) in &self.patterns {
if is_matching(p, pattern) {
matching_subs.extend(subs.iter());
}
}

// call matched handlers
// handlers.iter().for_each(|handler| handler(msg));
matching_subs.sort();
matching_subs
}
}

/// match a topic and a string pattern
/// Match a topic and a string pattern
/// pattern can contains -
/// '*' - match 0 or more characters after this
/// '?' - match any character once
Expand Down Expand Up @@ -334,6 +318,7 @@ pub fn is_matching(topic: &Ustr, pattern: &Ustr) -> bool {
mod tests {
use std::rc::Rc;

use nautilus_core::message::Message;
use rstest::*;

use super::*;
Expand Down Expand Up @@ -376,13 +361,13 @@ mod tests {
#[rstest]
fn test_regsiter_endpoint() {
let mut msgbus = stub_msgbus();
let endpoint = "MyEndpoint".to_string();
let endpoint = "MyEndpoint";

let callback = stub_rust_callback();
let handler_id = Ustr::from("1");
let handler = MessageHandler::new(handler_id, None, Some(callback));

msgbus.register(endpoint.clone(), handler.clone());
msgbus.register(&endpoint, handler);

assert_eq!(msgbus.endpoints(), vec!["MyEndpoint".to_string()]);
assert!(msgbus.get_endpoint(&Ustr::from(&endpoint)).is_some());
Expand All @@ -391,13 +376,13 @@ mod tests {
#[rstest]
fn test_deregsiter_endpoint() {
let mut msgbus = stub_msgbus();
let endpoint = "MyEndpoint".to_string();
let endpoint = "MyEndpoint";

let callback = stub_rust_callback();
let handler_id = Ustr::from("1");
let handler = MessageHandler::new(handler_id, None, Some(callback));

msgbus.register(endpoint.clone(), handler.clone());
msgbus.register(&endpoint, handler);
msgbus.deregister(&endpoint);

assert!(msgbus.endpoints().is_empty());
Expand All @@ -406,13 +391,13 @@ mod tests {
#[rstest]
fn test_subscribe() {
let mut msgbus = stub_msgbus();
let topic = "my-topic".to_string();
let topic = "my-topic";

let callback = stub_rust_callback();
let handler_id = Ustr::from("1");
let handler = MessageHandler::new(handler_id, None, Some(callback));

msgbus.subscribe(&topic, handler.clone(), Some(1));
msgbus.subscribe(&topic, handler, Some(1));

assert!(msgbus.has_subscribers(&topic));
assert_eq!(msgbus.topics(), vec![topic]);
Expand All @@ -421,14 +406,14 @@ mod tests {
#[rstest]
fn test_unsubscribe() {
let mut msgbus = stub_msgbus();
let topic = "my-topic".to_string();
let topic = "my-topic";

let callback = stub_rust_callback();
let handler_id = Ustr::from("1");
let handler = MessageHandler::new(handler_id, None, Some(callback));

msgbus.subscribe(&topic, handler.clone(), None);
msgbus.unsubscribe(&topic, handler.clone());
msgbus.unsubscribe(&topic, handler);

assert!(!msgbus.has_subscribers(&topic));
assert!(msgbus.topics().is_empty());
Expand Down
10 changes: 1 addition & 9 deletions nautilus_trader/core/includes/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,15 +536,7 @@ const PyObject *msgbus_get_endpoint(struct MessageBus_API bus, const char *endpo
*
* - Assumes `pattern_ptr` is a valid C string pointer.
*/
CVec msgbus_get_matching_handlers(struct MessageBus_API bus, const char *pattern_ptr);

/**
* # Safety
*
* - Assumes any registered handler has a Python callable.
* - Assumes `endpoint_ptr` is a valid C string pointer.
*/
void vec_msgbus_handlers_drop(CVec v);
CVec msgbus_get_matching_callables(struct MessageBus_API bus, const char *pattern_ptr);

void vec_pycallable_drop(CVec v);

Expand Down
8 changes: 1 addition & 7 deletions nautilus_trader/core/rust/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,7 @@ cdef extern from "../includes/common.h":
# # Safety
#
# - Assumes `pattern_ptr` is a valid C string pointer.
CVec msgbus_get_matching_handlers(MessageBus_API bus, const char *pattern_ptr);

# # Safety
#
# - Assumes any registered handler has a Python callable.
# - Assumes `endpoint_ptr` is a valid C string pointer.
void vec_msgbus_handlers_drop(CVec v);
CVec msgbus_get_matching_callables(MessageBus_API bus, const char *pattern_ptr);

void vec_pycallable_drop(CVec v);

Expand Down

0 comments on commit cb4d4e0

Please sign in to comment.