Skip to content

Commit

Permalink
Implement MessageBus buffered pipelining
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Dec 5, 2023
1 parent 0af0770 commit 1200460
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 102 deletions.
2 changes: 2 additions & 0 deletions examples/live/binance_spot_market_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@
),
# cache_database=CacheDatabaseConfig(
# type="redis",
# # encoding="json",
# buffer_interval_ms=100,
# ),
# message_bus=MessageBusConfig(
# database=DatabaseConfig(),
# encoding="json",
# buffer_interval_ms=100,
# stream="quoters",
# use_instance_id=False,
# timestamps_as_iso8601=True,
Expand Down
110 changes: 96 additions & 14 deletions nautilus_core/common/src/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
// limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{collections::HashMap, sync::mpsc::Receiver, time::Duration};
use std::{
collections::{HashMap, VecDeque},
sync::mpsc::{Receiver, TryRecvError},
thread,
time::{Duration, Instant},
};

use nautilus_core::{time::duration_since_unix_epoch, uuid::UUID4};
use nautilus_model::identifiers::trader_id::TraderId;
Expand All @@ -34,27 +39,76 @@ pub fn handle_messages_with_redis(
) {
let redis_url = get_redis_url(&config);
let client = redis::Client::open(redis_url).unwrap();
let mut conn = client.get_connection().unwrap();
let stream_name = get_stream_name(trader_id, instance_id, &config);

// Autotrimming
let autotrim_mins = config
.get("autotrim_mins")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let autotrim_duration = Duration::from_secs(autotrim_mins as u64 * 60);
let autotrim_duration = if autotrim_mins > 0 {
Some(Duration::from_secs(autotrim_mins as u64 * 60))
} else {
None
};
let mut last_trim_index: HashMap<String, usize> = HashMap::new();
let mut conn = client.get_connection().unwrap();

// Continue to receive and handle bus messages until channel is hung up
while let Ok(msg) = rx.recv() {
// Buffering
let mut buffer: VecDeque<BusMessage> = VecDeque::new();
let mut last_drain = Instant::now();
let recv_interval = Duration::from_millis(1);
let buffer_interval = get_buffer_interval(&config);

loop {
if last_drain.elapsed() >= buffer_interval && !buffer.is_empty() {
drain_buffer(
&mut conn,
&stream_name,
autotrim_duration,
&mut last_trim_index,
&mut buffer,
);
last_drain = Instant::now();
} else {
// Continue to receive and handle messages until channel is hung up
match rx.try_recv() {
Ok(msg) => buffer.push_back(msg),
Err(TryRecvError::Empty) => thread::sleep(recv_interval),
Err(TryRecvError::Disconnected) => break, // Channel hung up
}
}
}

// Drain any remaining messages
if !buffer.is_empty() {
drain_buffer(
&mut conn,
&stream_name,
autotrim_duration,
&mut last_trim_index,
&mut buffer,
);
}
}

fn drain_buffer(
conn: &mut Connection,
stream_name: &str,
autotrim_duration: Option<Duration>,
last_trim_index: &mut HashMap<String, usize>,
buffer: &mut VecDeque<BusMessage>,
) {
let mut pipe = redis::pipe();
pipe.atomic();

for msg in buffer.drain(..) {
let key = format!("{stream_name}{}", &msg.topic);
let items: Vec<(&str, &Vec<u8>)> = vec![("payload", &msg.payload)];
let result: Result<(), redis::RedisError> = conn.xadd(&key, "*", &items);
pipe.xadd(&key, "*", &items);

if let Err(e) = result {
eprintln!("Error publishing message: {e}");
}

if autotrim_mins == 0 {
return; // Nothing else to do
if autotrim_duration.is_none() {
continue; // Nothing else to do
}

// Autotrim stream
Expand All @@ -63,12 +117,13 @@ pub fn handle_messages_with_redis(

// Improve efficiency of this by batching
if *last_trim_ms < (unix_duration_now - Duration::from_secs(60)).as_millis() as usize {
let min_timestamp_ms = (unix_duration_now - autotrim_duration).as_millis() as usize;
let min_timestamp_ms =
(unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
let result: Result<(), redis::RedisError> = redis::cmd(XTRIM)
.arg(&key)
.arg(MINID)
.arg(min_timestamp_ms)
.query(&mut conn);
.query(conn);

if let Err(e) = result {
eprintln!("Error trimming stream '{key}': {e}");
Expand All @@ -77,6 +132,10 @@ pub fn handle_messages_with_redis(
}
}
}

if let Err(e) = pipe.query::<()>(conn) {
eprintln!("{e}");
}
}

pub fn get_redis_url(config: &HashMap<String, Value>) -> String {
Expand Down Expand Up @@ -109,6 +168,13 @@ pub fn get_redis_url(config: &HashMap<String, Value>) -> String {
)
}

/// Returns the buffer drain interval from `config`.
///
/// Reads the optional `buffer_interval_ms` key; a missing key or a value that
/// is not an unsigned integer yields a zero duration (i.e. no buffering delay).
pub fn get_buffer_interval(config: &HashMap<String, Value>) -> Duration {
    // `and_then` collapses the missing-key and wrong-type cases into one
    // default, mirroring how `autotrim_mins` is parsed elsewhere in this file.
    let buffer_interval_ms = config
        .get("buffer_interval_ms")
        .and_then(|v| v.as_u64())
        .unwrap_or(0);
    Duration::from_millis(buffer_interval_ms)
}

fn get_stream_name(
trader_id: TraderId,
instance_id: UUID4,
Expand Down Expand Up @@ -156,4 +222,20 @@ mod tests {
assert!(key.starts_with("quoters:tester-123:"));
assert!(key.ends_with(&expected_suffix));
}

#[rstest]
fn test_get_buffer_interval_default() {
let config = HashMap::new();
let buffer_interval = get_buffer_interval(&config);
assert_eq!(buffer_interval, Duration::from_millis(0));
}

#[rstest]
fn test_get_buffer_interval() {
let mut config = HashMap::new();
config.insert("buffer_interval_ms".to_string(), json!(100));

let buffer_interval = get_buffer_interval(&config);
assert_eq!(buffer_interval, Duration::from_millis(100));
}
}
33 changes: 8 additions & 25 deletions nautilus_core/infrastructure/src/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::{
};

use anyhow::{anyhow, bail, Result};
use nautilus_common::redis::get_buffer_interval;
use nautilus_core::uuid::UUID4;
use nautilus_model::identifiers::trader_id::TraderId;
use pyo3::prelude::*;
Expand Down Expand Up @@ -168,7 +169,7 @@ impl CacheDatabase for RedisCacheDatabase {
let client = redis::Client::open(redis_url).unwrap();
let mut conn = client.get_connection().unwrap();

// Buffering machinery
// Buffering
let mut buffer: VecDeque<DatabaseCommand> = VecDeque::new();
let mut last_drain = Instant::now();
let recv_interval = Duration::from_millis(1);
Expand All @@ -183,10 +184,15 @@ impl CacheDatabase for RedisCacheDatabase {
match rx.try_recv() {
Ok(msg) => buffer.push_back(msg),
Err(TryRecvError::Empty) => thread::sleep(recv_interval),
Err(TryRecvError::Disconnected) => return, // Channel hung up
Err(TryRecvError::Disconnected) => break, // Channel hung up
}
}
}

// Drain any remaining messages
if !buffer.is_empty() {
drain_buffer(&mut conn, &trader_key, &mut buffer);
}
}
}

Expand Down Expand Up @@ -521,13 +527,6 @@ fn get_redis_url(config: &HashMap<String, Value>) -> String {
)
}

/// Returns the buffer drain interval from `config`.
///
/// Reads the optional `buffer_interval_ms` key; a missing key or a value that
/// is not an unsigned integer yields a zero duration (i.e. no buffering delay).
fn get_buffer_interval(config: &HashMap<String, Value>) -> Duration {
    // `and_then` collapses the missing-key and wrong-type cases into one
    // default instead of double-defaulting through `map` + two `unwrap_or`s.
    let buffer_interval_ms = config
        .get("buffer_interval_ms")
        .and_then(|v| v.as_u64())
        .unwrap_or(0);
    Duration::from_millis(buffer_interval_ms)
}

fn get_trader_key(
trader_id: TraderId,
instance_id: UUID4,
Expand Down Expand Up @@ -617,22 +616,6 @@ mod tests {
assert!(key.ends_with(&instance_id.to_string()));
}

#[rstest]
fn test_get_buffer_interval_default() {
    // An empty config should produce a zero interval (buffering disabled).
    let config = HashMap::new();
    let buffer_interval = get_buffer_interval(&config);
    assert_eq!(buffer_interval, Duration::from_millis(0));
}

#[rstest]
fn test_get_buffer_interval() {
    // A numeric `buffer_interval_ms` value should round-trip as milliseconds.
    let mut config = HashMap::new();
    config.insert("buffer_interval_ms".to_string(), json!(100));

    let buffer_interval = get_buffer_interval(&config);
    assert_eq!(buffer_interval, Duration::from_millis(100));
}

#[rstest]
fn test_get_collection_key_valid() {
let key = "collection:123";
Expand Down
4 changes: 2 additions & 2 deletions nautilus_trader/adapters/binance/futures/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _parse_instrument(
taker_fee=taker_fee,
ts_event=ts_event,
ts_init=ts_init,
info=self._decoder.decode(self._encoder.encode(symbol_info)),
info=msgspec.structs.asdict(symbol_info),
)
self.add_currency(currency=instrument.base_currency)
elif contract_type in (
Expand Down Expand Up @@ -358,7 +358,7 @@ def _parse_instrument(
taker_fee=taker_fee,
ts_event=ts_event,
ts_init=ts_init,
info=self._decoder.decode(self._encoder.encode(symbol_info)),
info=msgspec.structs.asdict(symbol_info),
)
self.add_currency(currency=instrument.underlying)
else:
Expand Down
2 changes: 1 addition & 1 deletion nautilus_trader/adapters/binance/spot/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def _parse_instrument(
taker_fee=taker_fee,
ts_event=min(ts_event, ts_init),
ts_init=ts_init,
info=self._decoder.decode(self._encoder.encode(symbol_info)),
info=msgspec.structs.asdict(symbol_info),
)
self.add_currency(currency=instrument.base_currency)
self.add_currency(currency=instrument.quote_currency)
Expand Down
14 changes: 12 additions & 2 deletions nautilus_trader/cache/database.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,22 @@ cdef class CacheDatabaseAdapter(CacheDatabaseFacade):
Condition.type(config, CacheDatabaseConfig, "config")
super().__init__(logger, config)

if config.buffer_interval_ms and config.buffer_interval_ms >= 1000:
# Validate configuration
if config.buffer_interval_ms and config.buffer_interval_ms > 1000:
self._log.warning(
f"High `buffer_interval_ms` at {config.buffer_interval_ms}, "
"recommended range is [10, 100] milliseconds.",
"recommended range is [10, 1000] milliseconds.",
)

# Configuration
self._log.info(f"{config.type=}", LogColor.BLUE)
self._log.info(f"{config.encoding=}", LogColor.BLUE)
self._log.info(f"{config.buffer_interval_ms=}", LogColor.BLUE)
self._log.info(f"{config.flush_on_start=}", LogColor.BLUE)
self._log.info(f"{config.use_trader_prefix=}", LogColor.BLUE)
self._log.info(f"{config.use_instance_id=}", LogColor.BLUE)
self._log.info(f"{config.timestamps_as_iso8601=}", LogColor.BLUE)

# Database keys
self._key_trader = f"{_TRADER}-{trader_id}" # noqa
self._key_general = f"{self._key_trader}:{_GENERAL}:" # noqa
Expand Down
24 changes: 16 additions & 8 deletions nautilus_trader/common/component.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -730,12 +730,6 @@ cdef class MessageBus:
config = MessageBusConfig()
Condition.type(config, MessageBusConfig, "config")

if (snapshot_orders or snapshot_positions) and not config.stream:
raise InvalidConfiguration(
"Invalid `MessageBusConfig`: Cannot configure snapshots without providing a `stream` name. "
"This is because currently the message bus will write to the same snapshot keys as the cache.",
)

self.trader_id = trader_id
self.serializer = serializer
self.has_backing = config.database is not None
Expand All @@ -745,14 +739,28 @@ cdef class MessageBus:
self._clock = clock
self._log = LoggerAdapter(component_name=name, logger=logger)

# Validate configuration
if config.buffer_interval_ms and config.buffer_interval_ms > 1000:
self._log.warning(
f"High `buffer_interval_ms` at {config.buffer_interval_ms}, "
"recommended range is [10, 1000] milliseconds.",
)

if (snapshot_orders or snapshot_positions) and not config.stream:
raise InvalidConfiguration(
"Invalid `MessageBusConfig`: Cannot configure snapshots without providing a `stream` name. "
"This is because currently the message bus will write to the same snapshot keys as the cache.",
)

# Configuration
self._log.info(f"{config.database=}", LogColor.BLUE)
self._log.info(f"{config.encoding=}", LogColor.BLUE)
self._log.info(f"{config.buffer_interval_ms=}", LogColor.BLUE)
self._log.info(f"{config.autotrim_mins=}", LogColor.BLUE)
self._log.info(f"{config.stream=}", LogColor.BLUE)
self._log.info(f"{config.use_instance_id=}", LogColor.BLUE)
self._log.info(f"{config.encoding=}", LogColor.BLUE)
self._log.info(f"{config.timestamps_as_iso8601=}", LogColor.BLUE)
self._log.info(f"{config.types_filter=}", LogColor.BLUE)
self._log.info(f"{config.autotrim_mins=}", LogColor.BLUE)

# Copy and clear `types_filter` before passing down to the core MessageBus
cdef list types_filter = copy.copy(config.types_filter)
Expand Down
Loading

0 comments on commit 1200460

Please sign in to comment.