Skip to content

Commit

Permalink
Refine CacheDatabaseAdapter
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Dec 3, 2023
1 parent 1ff2784 commit 8d62d07
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 56 deletions.
92 changes: 58 additions & 34 deletions nautilus_core/common/src/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,14 @@ use std::{collections::HashMap, sync::mpsc::Receiver, time::Duration};
use nautilus_core::{time::duration_since_unix_epoch, uuid::UUID4};
use nautilus_model::identifiers::trader_id::TraderId;
use redis::*;
use serde_json::Value;
use serde_json::{json, Value};

use crate::msgbus::BusMessage;

const DELIMITER: char = ':';
const XTRIM: &str = "XTRIM";
const MINID: &str = "MINID";

pub fn get_redis_url(config: &HashMap<String, Value>) -> String {
let empty = Value::Object(serde_json::Map::new());
let database = config.get("database").unwrap_or(&empty);

let host = database
.get("host")
.map(|v| v.as_str().unwrap_or("127.0.0.1"));
let port = database.get("port").map(|v| v.as_str().unwrap_or("6379"));
let username = database
.get("username")
.map(|v| v.as_str().unwrap_or_default());
let password = database
.get("password")
.map(|v| v.as_str().unwrap_or_default());
let use_ssl = database.get("ssl").unwrap_or(&Value::Bool(false));

format!(
"redis{}://{}:{}@{}:{}",
if use_ssl.as_bool().unwrap_or(false) {
"s"
} else {
""
},
username.unwrap_or(""),
password.unwrap_or(""),
host.unwrap(),
port.unwrap(),
)
}

pub fn handle_messages_with_redis(
rx: Receiver<BusMessage>,
trader_id: TraderId,
Expand All @@ -64,7 +34,7 @@ pub fn handle_messages_with_redis(
) {
let redis_url = get_redis_url(&config);
let client = redis::Client::open(redis_url).unwrap();
let stream_name = get_stream_name(&config, trader_id, instance_id);
let stream_name = get_stream_name(trader_id, instance_id, &config);
let autotrim_mins = config
.get("autotrim_mins")
.and_then(|v| v.as_u64())
Expand Down Expand Up @@ -109,10 +79,40 @@ pub fn handle_messages_with_redis(
}
}

pub fn get_redis_url(config: &HashMap<String, Value>) -> String {
let empty = Value::Object(serde_json::Map::new());
let database = config.get("database").unwrap_or(&empty);

let host = database
.get("host")
.map(|v| v.as_str().unwrap_or("127.0.0.1"));
let port = database.get("port").map(|v| v.as_str().unwrap_or("6379"));
let username = database
.get("username")
.map(|v| v.as_str().unwrap_or_default());
let password = database
.get("password")
.map(|v| v.as_str().unwrap_or_default());
let use_ssl = database.get("ssl").unwrap_or(&json!(false));

format!(
"redis{}://{}:{}@{}:{}",
if use_ssl.as_bool().unwrap_or(false) {
"s"
} else {
""
},
username.unwrap_or(""),
password.unwrap_or(""),
host.unwrap(),
port.unwrap(),
)
}

fn get_stream_name(
config: &HashMap<String, Value>,
trader_id: TraderId,
instance_id: UUID4,
config: &HashMap<String, Value>,
) -> String {
let mut stream_name = String::new();

Expand All @@ -126,10 +126,34 @@ fn get_stream_name(
stream_name.push_str(trader_id.value.as_str());
stream_name.push(DELIMITER);

if let Some(Value::Bool(true)) = config.get("use_instance_id") {
if let Some(json!(true)) = config.get("use_instance_id") {
stream_name.push_str(&format!("{instance_id}"));
stream_name.push(DELIMITER);
}

stream_name
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use rstest::rstest;
use serde_json::json;

use super::*;

#[rstest]
fn test_get_stream_name_with_stream_prefix_and_instance_id() {
let trader_id = TraderId::from("tester-123");
let instance_id = UUID4::new();
let mut config = HashMap::new();
config.insert("stream".to_string(), json!("quoters"));
config.insert("use_instance_id".to_string(), json!(true));

let key = get_stream_name(trader_id, instance_id, &config);
let expected_suffix = format!("{instance_id}:");
assert!(key.starts_with("quoters:tester-123:"));
assert!(key.ends_with(&expected_suffix));
}
}
8 changes: 8 additions & 0 deletions nautilus_core/infrastructure/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@ use nautilus_core::uuid::UUID4;
use nautilus_model::identifiers::trader_id::TraderId;
use serde_json::Value;

/// A type of database operation.
#[derive(Clone, Debug)]
pub enum DatabaseOperation {
Insert,
Update,
Delete,
}

/// Represents a database command to be performed which may be executed 'remotely' across a thread.
#[derive(Clone, Debug)]
pub struct DatabaseCommand {
/// The database operation type.
pub op_type: DatabaseOperation,
/// The primary key for the operation.
pub key: String,
/// The data payload for the operation.
pub payload: Option<Vec<Vec<u8>>>,
}

Expand All @@ -46,6 +51,9 @@ impl DatabaseCommand {

/// Provides a generic cache database facade.
///
/// The main operations take a consistent `key` and `payload` which should provide enough
/// information to implement the cache database in many different technologies.
///
/// Delete operations may need a `payload` to target specific values.
pub trait CacheDatabase {
type DatabaseType;
Expand Down
112 changes: 90 additions & 22 deletions nautilus_core/infrastructure/src/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use nautilus_core::uuid::UUID4;
use nautilus_model::identifiers::trader_id::TraderId;
use pyo3::prelude::*;
use redis::{Commands, Connection};
use serde_json::Value;
use serde_json::{json, Value};

use crate::cache::{CacheDatabase, DatabaseCommand, DatabaseOperation};

Expand Down Expand Up @@ -86,7 +86,6 @@ impl CacheDatabase for RedisCacheDatabase {
let conn = client.get_connection().unwrap();

let (tx, rx) = channel::<DatabaseCommand>();
let _encoding = get_encoding(&config);
let trader_key = get_trader_key(trader_id, instance_id, &config);
let trader_key_clone = trader_key.clone();

Expand Down Expand Up @@ -117,7 +116,7 @@ impl CacheDatabase for RedisCacheDatabase {
}

fn read(&mut self, key: &str) -> Result<Vec<Vec<u8>>> {
let collection = get_collection_key(key);
let collection = get_collection_key(key)?;
let key = format!("{}{DELIMITER}{}", self.trader_key, key);

match collection {
Expand Down Expand Up @@ -170,23 +169,30 @@ impl CacheDatabase for RedisCacheDatabase {

// Continue to receive and handle bus messages until channel is hung up
while let Ok(msg) = rx.recv() {
let collection = get_collection_key(&msg.key);
let collection = match get_collection_key(&msg.key) {
Ok(collection) => collection,
Err(e) => {
eprintln!("{e}");
continue; // Continue to next message
}
};

let key = format!("{trader_key}{DELIMITER}{}", msg.key);

match msg.op_type {
DatabaseOperation::Insert => {
if let Err(e) = insert(&mut conn, collection, &key, msg.payload) {
eprintln!("{}", e);
eprintln!("{e}");
}
}
DatabaseOperation::Update => {
if let Err(e) = update(&mut conn, collection, &key, msg.payload) {
eprintln!("{}", e);
eprintln!("{e}");
}
}
DatabaseOperation::Delete => {
if let Err(e) = delete(&mut conn, collection, &key, msg.payload) {
eprintln!("{}", e);
eprintln!("{e}");
}
}
}
Expand Down Expand Up @@ -296,7 +302,7 @@ fn insert_set(conn: &mut Connection, key: &str, value: &Vec<u8>) -> Result<()> {

fn insert_hset(conn: &mut Connection, key: &str, name: &Vec<u8>, value: &Vec<u8>) -> Result<()> {
conn.hset(key, name, value)
.map_err(|e| anyhow!("Failed to sadd '{key}' in Redis: {e}"))
.map_err(|e| anyhow!("Failed to hset '{key}' in Redis: {e}"))
}

fn insert_list(conn: &mut Connection, key: &str, value: &Vec<u8>) -> Result<()> {
Expand Down Expand Up @@ -338,7 +344,7 @@ fn delete(
INDEX => remove_index(conn, key, value),
ACTORS => delete_string(conn, key),
STRATEGIES => delete_string(conn, key),
_ => bail!("Collection '{collection}' not recognized for `delete`"),
_ => bail!("Unsupported operation: `delete` for collection '{collection}'"),
}
}

Expand All @@ -364,7 +370,7 @@ fn remove_from_set(conn: &mut Connection, key: &str, member: &Vec<u8>) -> Result

fn delete_string(conn: &mut Connection, key: &str) -> Result<()> {
conn.del(key)
.map_err(|e| anyhow!("Failed to delete '{key}' in Redis: {e}"))
.map_err(|e| anyhow!("Failed to del '{key}' in Redis: {e}"))
}

fn get_redis_url(config: &HashMap<String, Value>) -> String {
Expand All @@ -378,7 +384,7 @@ fn get_redis_url(config: &HashMap<String, Value>) -> String {
let password = config
.get("password")
.map(|v| v.as_str().unwrap_or_default());
let use_ssl = config.get("ssl").unwrap_or(&Value::Bool(false));
let use_ssl = config.get("ssl").unwrap_or(&json!(false));

format!(
"redis{}://{}:{}@{}:{}",
Expand All @@ -401,32 +407,34 @@ fn get_trader_key(
) -> String {
let mut key = String::new();

if let Some(Value::Bool(true)) = config.get("use_trader_prefix") {
if let Some(json!(true)) = config.get("use_trader_prefix") {
key.push_str("trader-");
}

key.push_str(trader_id.value.as_str());

if let Some(Value::Bool(true)) = config.get("use_instance_id") {
if let Some(json!(true)) = config.get("use_instance_id") {
key.push(DELIMITER);
key.push_str(&format!("{instance_id}"));
}

key
}

fn get_collection_key(key: &str) -> &str {
fn get_collection_key(key: &str) -> Result<&str> {
key.split_once(DELIMITER)
.unwrap_or_else(|| panic!("Invalid `key` '{}'", key))
.0
.map(|(collection, _)| collection)
.ok_or_else(|| anyhow!("Invalid `key`, missing a '{DELIMITER}' delimiter, was {key}"))
}

fn get_index_key(key: &str) -> Result<&str> {
key.split_once(':')
key.split_once(DELIMITER)
.map(|(_, index_key)| index_key)
.ok_or_else(|| anyhow!("Invalid key, missing ':' delimiter, was {}", key))
.ok_or_else(|| anyhow!("Invalid `key`, missing a '{DELIMITER}' delimiter, was {key}"))
}

// This function can be used when we handle cache serialization in Rust
#[allow(dead_code)]
fn get_encoding(config: &HashMap<String, Value>) -> String {
config
.get("encoding")
Expand All @@ -435,13 +443,73 @@ fn get_encoding(config: &HashMap<String, Value>) -> String {
.to_string()
}

// This function can be used when we handle cache serialization in Rust
#[allow(dead_code)]
fn deserialize_payload(encoding: &str, payload: &[u8]) -> Result<HashMap<String, Value>, String> {
fn deserialize_payload(encoding: &str, payload: &[u8]) -> Result<HashMap<String, Value>> {
match encoding {
"msgpack" => rmp_serde::from_slice(payload)
.map_err(|e| format!("Failed to deserialize msgpack `payload`: {e}")),
.map_err(|e| anyhow!("Failed to deserialize msgpack `payload`: {e}")),
"json" => serde_json::from_slice(payload)
.map_err(|e| format!("Failed to deserialize json `payload`: {e}")),
_ => Err(format!("Unsupported encoding: {encoding}")),
.map_err(|e| anyhow!("Failed to deserialize json `payload`: {e}")),
_ => Err(anyhow!("Unsupported encoding: {encoding}")),
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use rstest::rstest;
use serde_json::json;

use super::*;

#[rstest]
fn test_get_redis_url() {
let mut config = HashMap::new();
config.insert("host".to_string(), json!("localhost"));
config.insert("port".to_string(), json!("1234"));
config.insert("username".to_string(), json!("user"));
config.insert("password".to_string(), json!("pass"));
config.insert("ssl".to_string(), json!(true));

assert_eq!(get_redis_url(&config), "rediss://user:pass@localhost:1234");
}

#[rstest]
fn test_get_trader_key_with_prefix_and_instance_id() {
let trader_id = TraderId::from("tester-123");
let instance_id = UUID4::new();
let mut config = HashMap::new();
config.insert("use_trader_prefix".to_string(), json!(true));
config.insert("use_instance_id".to_string(), json!(true));

let key = get_trader_key(trader_id, instance_id, &config);
assert!(key.starts_with("trader-tester-123:"));
assert!(key.ends_with(&instance_id.to_string()));
}

#[rstest]
fn test_get_collection_key_valid() {
let key = "collection:123";
assert_eq!(get_collection_key(key).unwrap(), "collection");
}

#[rstest]
fn test_get_collection_key_invalid() {
let key = "no_delimiter";
assert!(get_collection_key(key).is_err());
}

#[rstest]
fn test_get_index_key_valid() {
let key = "index:123";
assert_eq!(get_index_key(key).unwrap(), "123");
}

#[rstest]
fn test_get_index_key_invalid() {
let key = "no_delimiter";
assert!(get_index_key(key).is_err());
}
}

0 comments on commit 8d62d07

Please sign in to comment.