diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5e53f546cadd..8c0dae0bb979 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,6 +22,25 @@ jobs: env: BUILD_MODE: debug RUST_BACKTRACE: 1 + services: + redis: + image: redis + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + postgres: + image: postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: pass + POSTGRES_DB: nautilus + ports: + - 5432:5432 + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - name: Free disk space (Ubuntu) @@ -105,10 +124,16 @@ jobs: # pre-commit run --hook-stage manual gitlint-ci pre-commit run --all-files - - name: Install Redis (Linux) + - name: Install Nautilus CLI and run init postgres run: | - sudo apt-get install redis-server - redis-server --daemonize yes + make install-cli + nautilus database init --schema ${{ github.workspace }}/schema + env: + POSTGRES_HOST: localhost + POSTGRES_PORT: 5432 + POSTGRES_USERNAME: postgres + POSTGRES_PASSWORD: pass + POSTGRES_DATABASE: nautilus - name: Run nautilus_core cargo tests (Linux) run: | @@ -224,6 +249,25 @@ jobs: env: BUILD_MODE: debug RUST_BACKTRACE: 1 + services: + redis: + image: redis + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + postgres: + image: postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: pass + POSTGRES_DB: nautilus + ports: + - 5432:5432 + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - name: Checkout repository @@ -290,10 +334,16 @@ jobs: # pre-commit run --hook-stage manual gitlint-ci pre-commit run --all-files - - name: Install Redis (macOS) + - name: Install Nautilus CLI and run init postgres run: | - brew install redis - redis-server --daemonize yes + make install-cli + nautilus database init --schema ${{ github.workspace }}/schema + env: + POSTGRES_HOST: localhost + POSTGRES_PORT: 5432 + POSTGRES_USERNAME: postgres + POSTGRES_PASSWORD: pass + POSTGRES_DATABASE: nautilus - name: Run nautilus_core cargo tests (macOS) run: | diff --git a/nautilus_core/cli/Cargo.toml b/nautilus_core/cli/Cargo.toml index 941f013ea46e..163f78e96bb8 100644 --- a/nautilus_core/cli/Cargo.toml +++ b/nautilus_core/cli/Cargo.toml @@ -14,7 +14,7 @@ path = "src/bin/cli.rs" nautilus-common = { path = "../common"} nautilus-model = { path = "../model" } nautilus-core = { path = "../core" } -nautilus-infrastructure = { path = "../infrastructure" , features = ['sql']} +nautilus-infrastructure = { path = "../infrastructure" , features = ['postgres']} anyhow = { workspace = true } tokio = {workspace = true} log = { workspace = true } diff --git a/nautilus_core/cli/src/database/postgres.rs b/nautilus_core/cli/src/database/postgres.rs index 5536276b5499..39f374881df0 100644 --- a/nautilus_core/cli/src/database/postgres.rs +++ b/nautilus_core/cli/src/database/postgres.rs @@ -13,182 +13,12 @@ // limitations under the License. // ------------------------------------------------------------------------------------------------- -use log::{error, info}; -use nautilus_infrastructure::sql::pg::{connect_pg, get_postgres_connect_options}; -use sqlx::PgPool; +use nautilus_infrastructure::sql::pg::{ + connect_pg, drop_postgres, get_postgres_connect_options, init_postgres, +}; use crate::opt::{DatabaseCommand, DatabaseOpt}; -/// Scans current path with keyword `nautilus_trader` and build schema dir -fn get_schema_dir() -> anyhow::Result { - std::env::var("SCHEMA_DIR").or_else(|_| { - let nautilus_git_repo_name = "nautilus_trader"; - let binding = std::env::current_dir().unwrap(); - let current_dir = binding.to_str().unwrap(); - match current_dir.find(nautilus_git_repo_name){ - Some(index) => { - let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema"; - Ok(schema_path) - } - None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable") - } - }) -} - -pub async fn init_postgres(pg: &PgPool, database: String, password: String) -> anyhow::Result<()> { - info!("Initializing Postgres database with target permissions and schema"); - // create public schema - match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;") - .execute(pg) - .await - { - Ok(_) => info!("Schema public created successfully"), - Err(err) => error!("Error creating schema public: {:?}", err), - } - // create role if not exists - match sqlx::query(format!("CREATE ROLE {database} PASSWORD '{password}' LOGIN;").as_str()) - .execute(pg) - .await - { - Ok(_) => info!("Role {} created successfully", database), - Err(err) => { - if err.to_string().contains("already exists") { - info!("Role {} already exists", database); - } else { - error!("Error creating role {}: {:?}", database, err); - } - } - } - // execute all the sql files in schema dir - let schema_dir = get_schema_dir()?; - let mut sql_files = - std::fs::read_dir(schema_dir)?.collect::, std::io::Error>>()?; - for file in &mut sql_files { - let file_name = file.file_name(); - info!("Executing schema file: {:?}", file_name); - let file_path = file.path(); - let sql_content = std::fs::read_to_string(file_path.clone())?; - for sql_statement in sql_content.split(';').filter(|s| !s.trim().is_empty()) { - sqlx::query(sql_statement).execute(pg).await?; - } - } - // grant connect - match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str()) - .execute(pg) - .await - { - Ok(_) => info!("Connect privileges granted to role {}", database), - Err(err) => error!( - "Error granting connect privileges to role {}: {:?}", - database, err - ), - } - // grant all schema privileges to the role - match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str()) - .execute(pg) - .await - { - Ok(_) => info!("All schema privileges granted to role {}", database), - Err(err) => error!( - "Error granting all privileges to role {}: {:?}", - database, err - ), - } - // grant all table privileges to the role - match sqlx::query( - format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(), - ) - .execute(pg) - .await - { - Ok(_) => info!("All tables privileges granted to role {}", database), - Err(err) => error!( - "Error granting all privileges to role {}: {:?}", - database, err - ), - } - // grant all sequence privileges to the role - match sqlx::query( - format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(), - ) - .execute(pg) - .await - { - Ok(_) => info!("All sequences privileges granted to role {}", database), - Err(err) => error!( - "Error granting all privileges to role {}: {:?}", - database, err - ), - } - // grant all function privileges to the role - match sqlx::query( - format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(), - ) - .execute(pg) - .await - { - Ok(_) => info!("All functions privileges granted to role {}", database), - Err(err) => error!( - "Error granting all privileges to role {}: {:?}", - database, err - ), - } - - Ok(()) -} - -pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> { - // execute drop owned - match sqlx::query(format!("DROP OWNED BY {database}").as_str()) - .execute(pg) - .await - { - Ok(_) => info!("Dropped owned objects by role {}", database), - Err(err) => error!("Error dropping owned by role {}: {:?}", database, err), - } - // revoke connect - match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str()) - .execute(pg) - .await - { - Ok(_) => info!("Revoked connect privileges from role {}", database), - Err(err) => error!( - "Error revoking connect privileges from role {}: {:?}", - database, err - ), - } - // revoke privileges - match sqlx::query( - format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(), - ) - .execute(pg) - .await - { - Ok(_) => info!("Revoked all privileges from role {}", database), - Err(err) => error!( - "Error revoking all privileges from role {}: {:?}", - database, err - ), - } - // execute drop schema - match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE") - .execute(pg) - .await - { - Ok(_) => info!("Dropped schema public"), - Err(err) => error!("Error dropping schema public: {:?}", err), - } - // drop role - match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str()) - .execute(pg) - .await - { - Ok(_) => info!("Dropped role {}", database), - Err(err) => error!("Error dropping role {}: {:?}", database, err), - } - Ok(()) -} - pub async fn run_database_command(opt: DatabaseOpt) -> anyhow::Result<()> { let command = opt.command.clone(); @@ -207,6 +37,7 @@ pub async fn run_database_command(opt: DatabaseOpt) -> anyhow::Result<()> { &pg, pg_connect_options.database, pg_connect_options.password, + config.schema, ) .await?; } diff --git a/nautilus_core/cli/src/opt.rs b/nautilus_core/cli/src/opt.rs index 8ef5b9c2b1d2..1df35a0e78f2 100644 --- a/nautilus_core/cli/src/opt.rs +++ b/nautilus_core/cli/src/opt.rs @@ -51,6 +51,9 @@ pub struct DatabaseConfig { /// Password for connecting to the database #[arg(long)] pub password: Option, + /// Directory path to the schema files + #[arg(long)] + pub schema: Option, } #[derive(Parser, Debug, Clone)] diff --git a/nautilus_core/infrastructure/Cargo.toml b/nautilus_core/infrastructure/Cargo.toml index 95a2ad9eff91..594ca1448b37 100644 --- a/nautilus_core/infrastructure/Cargo.toml +++ b/nautilus_core/infrastructure/Cargo.toml @@ -50,4 +50,4 @@ extension-module = [ ] python = ["pyo3"] redis = ["dep:redis"] -sql = ["dep:sqlx"] +postgres = ["dep:sqlx"] diff --git a/nautilus_core/infrastructure/src/lib.rs b/nautilus_core/infrastructure/src/lib.rs index 6f87ec9dc619..e34994a90fa7 100644 --- a/nautilus_core/infrastructure/src/lib.rs +++ b/nautilus_core/infrastructure/src/lib.rs @@ -34,5 +34,5 @@ pub mod python; #[cfg(feature = "redis")] pub mod redis; -#[cfg(feature = "sql")] +#[cfg(feature = "postgres")] pub mod sql; diff --git a/nautilus_core/infrastructure/src/python/mod.rs b/nautilus_core/infrastructure/src/python/mod.rs index 6c07e5f0dbfc..6bf67c7c8340 100644 --- a/nautilus_core/infrastructure/src/python/mod.rs +++ b/nautilus_core/infrastructure/src/python/mod.rs @@ -18,6 +18,9 @@ #[cfg(feature = "redis")] pub mod redis; +#[cfg(feature = "postgres")] +pub mod sql; + use pyo3::{prelude::*, pymodule}; #[pymodule] @@ -26,5 +29,7 @@ pub fn infrastructure(_: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; #[cfg(feature = "redis")] m.add_class::()?; + #[cfg(feature = "postgres")] + m.add_class::()?; Ok(()) } diff --git a/nautilus_core/infrastructure/src/python/sql/cache_database.rs b/nautilus_core/infrastructure/src/python/sql/cache_database.rs new file mode 100644 index 000000000000..f1c7b181eb33 --- /dev/null +++ b/nautilus_core/infrastructure/src/python/sql/cache_database.rs @@ -0,0 +1,90 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +// https://nautechsystems.io +// +// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------------------------------- + +use std::collections::HashMap; + +use nautilus_common::runtime::get_runtime; +use nautilus_core::python::to_pyruntime_err; +use nautilus_model::types::currency::Currency; +use pyo3::prelude::*; + +use crate::sql::{ + cache_database::PostgresCacheDatabase, pg::delete_nautilus_postgres_tables, + queries::DatabaseQueries, +}; + +#[pymethods] +impl PostgresCacheDatabase { + #[staticmethod] + #[pyo3(name = "connect")] + fn py_connect( + host: Option, + port: Option, + username: Option, + password: Option, + database: Option, + ) -> PyResult { + let result = get_runtime().block_on(async { + PostgresCacheDatabase::connect(host, port, username, password, database).await + }); + result.map_err(to_pyruntime_err) + } + + #[pyo3(name = "load")] + fn py_load(slf: PyRef<'_, Self>) -> PyResult>> { + let result = get_runtime().block_on(async { slf.load().await }); + result.map_err(to_pyruntime_err) + } + + #[pyo3(name = "load_currency")] + fn py_load_currency(slf: PyRef<'_, Self>, code: &str) -> PyResult> { + let result = + get_runtime().block_on(async { DatabaseQueries::load_currency(&slf.pool, code).await }); + result.map_err(to_pyruntime_err) + } + + #[pyo3(name = "load_currencies")] + fn py_load_currencies(slf: PyRef<'_, Self>) -> PyResult> { + let result = + get_runtime().block_on(async { DatabaseQueries::load_currencies(&slf.pool).await }); + result.map_err(to_pyruntime_err) + } + + #[pyo3(name = "add")] + fn py_add(slf: PyRef<'_, Self>, key: String, value: Vec) -> PyResult<()> { + let result = get_runtime().block_on(async { slf.add(key, value).await }); + result.map_err(to_pyruntime_err) + } + + #[pyo3(name = "add_currency")] + fn py_add_currency(slf: PyRef<'_, Self>, currency: Currency) -> PyResult<()> { + let result = get_runtime().block_on(async { slf.add_currency(currency).await }); + result.map_err(to_pyruntime_err) + } + + #[pyo3(name = "flush_db")] + fn py_drop_schema(slf: PyRef<'_, Self>) -> PyResult<()> { + let result = + get_runtime().block_on(async { delete_nautilus_postgres_tables(&slf.pool).await }); + result.map_err(to_pyruntime_err) + } + + #[pyo3(name = "truncate")] + fn py_truncate(slf: PyRef<'_, Self>, table: String) -> PyResult<()> { + let result = + get_runtime().block_on(async { DatabaseQueries::truncate(&slf.pool, table).await }); + result.map_err(to_pyruntime_err) + } +} diff --git a/nautilus_core/infrastructure/src/python/sql/mod.rs b/nautilus_core/infrastructure/src/python/sql/mod.rs new file mode 100644 index 000000000000..454f4be6bd37 --- /dev/null +++ b/nautilus_core/infrastructure/src/python/sql/mod.rs @@ -0,0 +1,16 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +// https://nautechsystems.io +// +// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------------------------------- + +pub mod cache_database; diff --git a/nautilus_core/infrastructure/src/sql/cache.rs b/nautilus_core/infrastructure/src/sql/cache.rs deleted file mode 100644 index fecf1072d482..000000000000 --- a/nautilus_core/infrastructure/src/sql/cache.rs +++ /dev/null @@ -1,100 +0,0 @@ -// ------------------------------------------------------------------------------------------------- -// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. -// https://nautechsystems.io -// -// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); -// You may not use this file except in compliance with the License. -// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------------------------------ - -use nautilus_model::identifiers::trader_id::TraderId; -use sqlx::Error; - -use crate::sql::{database::Database, schema::GeneralItem}; - -pub struct SqlCacheDatabase { - trader_id: TraderId, - db: Database, -} - -impl SqlCacheDatabase { - #[must_use] - pub fn new(trader_id: TraderId, database: Database) -> Self { - Self { - trader_id, - db: database, - } - } - #[must_use] - pub fn key_trader(&self) -> String { - format!("trader-{}", self.trader_id) - } - - #[must_use] - pub fn key_general(&self) -> String { - format!("{}:general:", self.key_trader()) - } - - pub async fn add(&self, key: String, value: String) -> Result { - let query = format!( - "INSERT INTO general (key, value) VALUES ('{key}', '{value}') ON CONFLICT (key) DO NOTHING;" - ); - self.db.execute(query.as_str()).await - } - - pub async fn get(&self, key: String) -> Vec { - let query = format!("SELECT * FROM general WHERE key = '{key}'"); - self.db - .fetch_all::(query.as_str()) - .await - .unwrap() - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Tests -//////////////////////////////////////////////////////////////////////////////// -#[cfg(test)] -mod tests { - use nautilus_model::identifiers::stubs::trader_id; - - use super::SqlCacheDatabase; - use crate::sql::database::{init_db_schema, setup_test_database}; - - async fn setup_sql_cache_database() -> SqlCacheDatabase { - let db = setup_test_database().await; - let schema_dir = "../../schema"; - init_db_schema(&db, schema_dir) - .await - .expect("Failed to init db schema"); - let trader = trader_id(); - SqlCacheDatabase::new(trader, db) - } - - #[tokio::test] - async fn test_keys() { - let cache = setup_sql_cache_database().await; - assert_eq!(cache.key_trader(), "trader-TRADER-001"); - assert_eq!(cache.key_general(), "trader-TRADER-001:general:"); - } - - #[tokio::test] - async fn test_add_get_general() { - let cache = setup_sql_cache_database().await; - cache - .add(String::from("key1"), String::from("value1")) - .await - .expect("Failed to add key"); - let value = cache.get(String::from("key1")).await; - assert_eq!(value.len(), 1); - let item = value.first().unwrap(); - assert_eq!(item.key, "key1"); - assert_eq!(item.value, "value1"); - } -} diff --git a/nautilus_core/infrastructure/src/sql/cache_database.rs b/nautilus_core/infrastructure/src/sql/cache_database.rs new file mode 100644 index 000000000000..ed6ab08a9e48 --- /dev/null +++ b/nautilus_core/infrastructure/src/sql/cache_database.rs @@ -0,0 +1,144 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +// https://nautechsystems.io +// +// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------------------------------- + +use std::{ + collections::{HashMap, VecDeque}, + time::{Duration, Instant}, +}; + +use nautilus_model::types::currency::Currency; +use sqlx::{postgres::PgConnectOptions, PgPool}; +use tokio::{ + sync::mpsc::{channel, error::TryRecvError, Receiver, Sender}, + time::sleep, +}; + +use crate::sql::{ + models::general::GeneralRow, + pg::{connect_pg, get_postgres_connect_options}, + queries::DatabaseQueries, +}; + +#[derive(Debug)] +#[cfg_attr( + feature = "python", + pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence") +)] +pub struct PostgresCacheDatabase { + pub pool: PgPool, + tx: Sender, +} + +#[derive(Debug, Clone)] +pub enum DatabaseQuery { + Add(String, Vec), + AddCurrency(Currency), +} + +fn get_buffer_interval() -> Duration { + Duration::from_millis(0) +} + +async fn drain_buffer(pool: &PgPool, buffer: &mut VecDeque) { + for cmd in buffer.drain(..) { + match cmd { + DatabaseQuery::Add(key, value) => { + DatabaseQueries::add(pool, key, value).await.unwrap(); + } + DatabaseQuery::AddCurrency(currency) => { + DatabaseQueries::add_currency(pool, currency).await.unwrap(); + } + } + } +} + +impl PostgresCacheDatabase { + pub async fn connect( + host: Option, + port: Option, + username: Option, + password: Option, + database: Option, + ) -> Result { + let pg_connect_options = + get_postgres_connect_options(host, port, username, password, database).unwrap(); + let pool = connect_pg(pg_connect_options.clone().into()).await.unwrap(); + let (tx, rx) = channel::(1000); + // spawn a thread to handle messages + let _join_handle = tokio::spawn(async move { + PostgresCacheDatabase::handle_message(rx, pg_connect_options.clone().into()).await; + }); + Ok(PostgresCacheDatabase { pool, tx }) + } + + async fn handle_message(mut rx: Receiver, pg_connect_options: PgConnectOptions) { + let pool = connect_pg(pg_connect_options).await.unwrap(); + // Buffering + let mut buffer: VecDeque = VecDeque::new(); + let mut last_drain = Instant::now(); + let buffer_interval = get_buffer_interval(); + let recv_interval = Duration::from_millis(1); + + loop { + if last_drain.elapsed() >= buffer_interval && !buffer.is_empty() { + // drain buffer + drain_buffer(&pool, &mut buffer).await; + last_drain = Instant::now(); + } else { + // Continue to receive and handle messages until channel is hung up + match rx.try_recv() { + Ok(msg) => buffer.push_back(msg), + Err(TryRecvError::Empty) => sleep(recv_interval).await, + Err(TryRecvError::Disconnected) => break, + } + } + } + // rain any remaining message + if !buffer.is_empty() { + drain_buffer(&pool, &mut buffer).await; + } + } + + pub async fn load(&self) -> Result>, sqlx::Error> { + let query = sqlx::query_as::<_, GeneralRow>("SELECT * FROM general"); + let result = query.fetch_all(&self.pool).await; + match result { + Ok(rows) => { + let mut cache: HashMap> = HashMap::new(); + for row in rows { + cache.insert(row.key, row.value); + } + Ok(cache) + } + Err(err) => { + panic!("Failed to load general table: {err}") + } + } + } + + pub async fn add(&self, key: String, value: Vec) -> anyhow::Result<()> { + let query = DatabaseQuery::Add(key, value); + self.tx.send(query).await.map_err(|err| { + anyhow::anyhow!("Failed to send query to database message handler: {err}") + }) + } + + pub async fn add_currency(&self, currency: Currency) -> anyhow::Result<()> { + let query = DatabaseQuery::AddCurrency(currency); + self.tx.send(query).await.map_err(|err| { + anyhow::anyhow!("Failed to query add_currency to database message handler: {err}") + }) + } +} diff --git a/nautilus_core/infrastructure/src/sql/database.rs b/nautilus_core/infrastructure/src/sql/database.rs deleted file mode 100644 index f757025b99ae..000000000000 --- a/nautilus_core/infrastructure/src/sql/database.rs +++ /dev/null @@ -1,226 +0,0 @@ -// ------------------------------------------------------------------------------------------------- -// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. -// https://nautechsystems.io -// -// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); -// You may not use this file except in compliance with the License. -// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------------------------------- - -use std::{path::Path, str::FromStr}; - -use sqlx::{ - any::{install_default_drivers, AnyConnectOptions}, - sqlite::SqliteConnectOptions, - Error, Pool, SqlitePool, -}; - -#[derive(Clone)] -pub struct Database { - pub pool: Pool, -} - -pub enum DatabaseEngine { - POSTGRES, - SQLITE, -} - -fn str_to_database_engine(engine_str: &str) -> DatabaseEngine { - match engine_str { - "POSTGRES" | "postgres" => DatabaseEngine::POSTGRES, - "SQLITE" | "sqlite" => DatabaseEngine::SQLITE, - _ => panic!("Invalid database engine: {engine_str}"), - } -} - -impl Database { - pub async fn new(engine: Option, conn_string: Option<&str>) -> Self { - install_default_drivers(); - let db_options = Self::get_db_options(engine, conn_string); - let db = sqlx::pool::PoolOptions::new() - .max_connections(20) - .connect_with(db_options) - .await; - match db { - Ok(pool) => Self { pool }, - Err(err) => { - panic!("Failed to connect to database: {err}") - } - } - } - - #[must_use] - pub fn get_db_options( - engine: Option, - conn_string: Option<&str>, - ) -> AnyConnectOptions { - let connection_string = match conn_string { - Some(conn_string) => Ok(conn_string.to_string()), - None => std::env::var("DATABASE_URL"), - }; - let database_engine: DatabaseEngine = match engine { - Some(engine) => engine, - None => str_to_database_engine( - std::env::var("DATABASE_ENGINE") - .unwrap_or("SQLITE".to_string()) - .as_str(), - ), - }; - match connection_string { - Ok(connection_string) => match database_engine { - DatabaseEngine::POSTGRES => AnyConnectOptions::from_str(connection_string.as_str()) - .expect("Invalid PostgresSQL connection string"), - DatabaseEngine::SQLITE => AnyConnectOptions::from_str(connection_string.as_str()) - .expect("Invalid SQLITE connection string"), - }, - Err(err) => { - panic!("Failed to connect to database: {err}") - } - } - } - - pub async fn execute(&self, query_str: &str) -> Result { - let result = sqlx::query(query_str).execute(&self.pool).await?; - - Ok(result.rows_affected()) - } - - pub async fn fetch_all(&self, query_str: &str) -> Result, Error> - where - T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Unpin, - { - let rows = sqlx::query(query_str).fetch_all(&self.pool).await?; - - let mut objects = Vec::new(); - for row in rows { - let obj = T::from_row(&row)?; - objects.push(obj); - } - - Ok(objects) - } -} - -pub async fn init_db_schema(db: &Database, schema_dir: &str) -> anyhow::Result<()> { - // scan all the files in the current directory - let mut sql_files = - std::fs::read_dir(schema_dir)?.collect::, std::io::Error>>()?; - - for file in &mut sql_files { - let file_name = file.file_name(); - println!("Executing SQL file: {file_name:?}"); - let file_path = file.path(); - let sql_content = std::fs::read_to_string(file_path.clone())?; - for sql_statement in sql_content.split(';').filter(|s| !s.trim().is_empty()) { - db.execute(sql_statement).await.unwrap_or_else(|e| { - panic!( - "Failed to execute SQL statement: {} with reason {}", - file_path.display(), - e - ) - }); - } - } - Ok(()) -} - -pub async fn setup_test_database() -> Database { - // check if test_db.sqlite exists,if not, create it - let db_path = std::env::var("TEST_DB_PATH").unwrap_or("test_db.sqlite".to_string()); - let db_file_path = Path::new(db_path.as_str()); - let exists = db_file_path.exists(); - if !exists { - SqlitePool::connect_with( - SqliteConnectOptions::new() - .filename(db_file_path) - .create_if_missing(true), - ) - .await - .expect("Failed to create test_db.sqlite"); - } - Database::new(Some(DatabaseEngine::SQLITE), Some("sqlite:test_db.sqlite")).await -} - -//////////////////////////////////////////////////////////////////////////////// -// Tests -//////////////////////////////////////////////////////////////////////////////// -#[cfg(test)] -mod tests { - - use sqlx::{FromRow, Row}; - - use crate::sql::database::{setup_test_database, Database}; - - async fn init_item_table(database: &Database) { - database - .execute("CREATE TABLE IF NOT EXISTS items (key TEXT PRIMARY KEY, value TEXT)") - .await - .expect("Failed to create table item"); - } - - async fn drop_table(database: &Database) { - database - .execute("DROP TABLE items") - .await - .expect("Failed to drop table items"); - } - - #[tokio::test] - async fn test_database() { - let db = setup_test_database().await; - let rows_affected = db.execute("SELECT 1").await.unwrap(); - // it will not fail and give 0 rows affected - assert_eq!(rows_affected, 0); - } - - #[tokio::test] - async fn test_database_fetch_all() { - let db = setup_test_database().await; - struct SimpleValue { - value: i32, - } - impl FromRow<'_, sqlx::any::AnyRow> for SimpleValue { - fn from_row(row: &sqlx::any::AnyRow) -> Result { - Ok(Self { - value: row.try_get(0)?, - }) - } - } - let result = db.fetch_all::("SELECT 3").await.unwrap(); - assert_eq!(result[0].value, 3); - } - - #[tokio::test] - async fn test_insert_and_select() { - let db = setup_test_database().await; - init_item_table(&db).await; - // insert some value - db.execute("INSERT INTO items (key, value) VALUES ('key1', 'value1')") - .await - .unwrap(); - // fetch item, impl Data struct - struct Item { - key: String, - value: String, - } - impl FromRow<'_, sqlx::any::AnyRow> for Item { - fn from_row(row: &sqlx::any::AnyRow) -> Result { - Ok(Self { - key: row.try_get(0)?, - value: row.try_get(1)?, - }) - } - } - let result = db.fetch_all::("SELECT * FROM items").await.unwrap(); - assert_eq!(result.len(), 1); - assert_eq!(result[0].key, "key1"); - assert_eq!(result[0].value, "value1"); - drop_table(&db).await; - } -} diff --git a/nautilus_core/infrastructure/src/sql/mod.rs b/nautilus_core/infrastructure/src/sql/mod.rs index f098c3cf9d49..6e14caa333d3 100644 --- a/nautilus_core/infrastructure/src/sql/mod.rs +++ b/nautilus_core/infrastructure/src/sql/mod.rs @@ -13,12 +13,10 @@ // limitations under the License. // ------------------------------------------------------------------------------------------------- -/// Be careful about ordering and foreign key constraints when deleting data. -/// We can use this list for manual truncation of tables. -pub const NAUTILUS_TABLES: [&str; 5] = - ["general", "instrument", "currency", "order", "order_event"]; +// Be careful about ordering and foreign key constraints when deleting data. +pub const NAUTILUS_TABLES: [&str; 2] = ["general", "currency"]; -pub mod database; +pub mod cache_database; pub mod models; pub mod pg; -pub mod schema; +pub mod queries; diff --git a/nautilus_core/infrastructure/src/sql/schema.rs b/nautilus_core/infrastructure/src/sql/models/general.rs similarity index 91% rename from nautilus_core/infrastructure/src/sql/schema.rs rename to nautilus_core/infrastructure/src/sql/models/general.rs index 2a551437f16c..824714a2c0d3 100644 --- a/nautilus_core/infrastructure/src/sql/schema.rs +++ b/nautilus_core/infrastructure/src/sql/models/general.rs @@ -13,8 +13,8 @@ // limitations under the License. // ------------------------------------------------------------------------------------------------- -#[derive(sqlx::FromRow)] -pub struct GeneralItem { +#[derive(Debug, sqlx::FromRow)] +pub struct GeneralRow { pub key: String, - pub value: String, + pub value: Vec, } diff --git a/nautilus_core/infrastructure/src/sql/models/mod.rs b/nautilus_core/infrastructure/src/sql/models/mod.rs index 02d78ede908a..4fe4acea056d 100644 --- a/nautilus_core/infrastructure/src/sql/models/mod.rs +++ b/nautilus_core/infrastructure/src/sql/models/mod.rs @@ -13,4 +13,6 @@ // limitations under the License. // ------------------------------------------------------------------------------------------------- +pub mod general; pub mod instruments; +pub mod types; diff --git a/nautilus_core/infrastructure/src/sql/models/types.rs b/nautilus_core/infrastructure/src/sql/models/types.rs new file mode 100644 index 000000000000..a2d98170f723 --- /dev/null +++ b/nautilus_core/infrastructure/src/sql/models/types.rs @@ -0,0 +1,43 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +// https://nautechsystems.io +// +// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------------------------------- + +use std::str::FromStr; + +use nautilus_model::{enums::CurrencyType, types::currency::Currency}; +use sqlx::{postgres::PgRow, FromRow, Row}; + +pub struct CurrencyModel(pub Currency); + +impl<'r> FromRow<'r, PgRow> for CurrencyModel { + fn from_row(row: &'r PgRow) -> Result { + let code = row.try_get::("code")?; + let precision = row.try_get::("precision")?; + let iso4217 = row.try_get::("iso4217")?; + let name = row.try_get::("name")?; + let currency_type = row + .try_get::("currency_type") + .map(|res| CurrencyType::from_str(res.as_str()).unwrap())?; + + let currency = Currency::new( + code.as_str(), + precision as u8, + iso4217 as u16, + name.as_str(), + currency_type, + ) + .unwrap(); + Ok(CurrencyModel(currency)) + } +} diff --git a/nautilus_core/infrastructure/src/sql/pg.rs b/nautilus_core/infrastructure/src/sql/pg.rs index c86a5245e0d3..fd31fea4c917 100644 --- a/nautilus_core/infrastructure/src/sql/pg.rs +++ b/nautilus_core/infrastructure/src/sql/pg.rs @@ -14,6 +14,7 @@ // ------------------------------------------------------------------------------------------------- use sqlx::{postgres::PgConnectOptions, query, ConnectOptions, PgPool}; +use tracing::log::{error, info}; use crate::sql::NAUTILUS_TABLES; @@ -112,3 +113,201 @@ pub async fn delete_nautilus_postgres_tables(db: &PgPool) -> anyhow::Result<()> pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result { Ok(PgPool::connect_with(options).await.unwrap()) } + +/// Scans current path with keyword nautilus_trader and build schema dir +fn get_schema_dir() -> anyhow::Result { + std::env::var("SCHEMA_DIR").or_else(|_| { + let nautilus_git_repo_name = "nautilus_trader"; + let binding = std::env::current_dir().unwrap(); + let current_dir = binding.to_str().unwrap(); + match current_dir.find(nautilus_git_repo_name){ + Some(index) => { + let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema"; + Ok(schema_path) + } + None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable") + } + }) +} + +pub async fn init_postgres( + pg: &PgPool, + database: String, + password: String, + schema_dir: Option, +) -> anyhow::Result<()> { + info!("Initializing Postgres database with target permissions and schema"); + // create public schema + match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;") + .execute(pg) + .await + { + Ok(_) => info!("Schema public created successfully"), + Err(err) => error!("Error creating schema public: {:?}", err), + } + // create role if not exists + match sqlx::query(format!("CREATE ROLE {} PASSWORD '{}' LOGIN;", database, password).as_str()) + .execute(pg) + .await + { + Ok(_) => info!("Role {} created successfully", database), + Err(err) => { + if err.to_string().contains("already exists") { + info!("Role {} already exists", database); + } else { + error!("Error creating role {}: {:?}", database, err); + } + } + } + // execute all the sql files in schema dir + let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap()); + let mut sql_files = + std::fs::read_dir(schema_dir)?.collect::, std::io::Error>>()?; + for file in &mut sql_files { + let file_name = file.file_name(); + info!("Executing schema file: {:?}", file_name); + let file_path = file.path(); + let sql_content = std::fs::read_to_string(file_path.clone())?; + for sql_statement in sql_content.split(';').filter(|s| !s.trim().is_empty()) { + sqlx::query(sql_statement) + .execute(pg) + .await + .map_err(|err| { + if err.to_string().contains("already exists") { + info!("Already exists error on statement, skipping"); + } else { + panic!( + "Error executing statement {} with error: {:?}", + sql_statement, err + ) + } + }) + .unwrap(); + } + } + // grant connect + match sqlx::query(format!("GRANT CONNECT ON DATABASE {0} TO {0};", database).as_str()) + .execute(pg) + .await + { + Ok(_) => info!("Connect privileges granted to role {}", database), + Err(err) => error!( + "Error granting connect privileges to role {}: {:?}", + database, err + ), + } + // grant all schema privileges to the role + match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {};", database).as_str()) + .execute(pg) + .await + { + Ok(_) => info!("All schema privileges granted to role {}", database), + Err(err) => error!( + "Error granting all privileges to role {}: {:?}", + database, err + ), + } + // grant all table privileges to the role + match sqlx::query( + format!( + "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {};", + database + ) + .as_str(), + ) + .execute(pg) + .await + { + Ok(_) => info!("All tables privileges granted to role {}", database), + Err(err) => error!( + "Error granting all privileges to role {}: {:?}", + database, err + ), + } + // grant all sequence privileges to the role + match sqlx::query( + format!( + "GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {};", + database + ) + .as_str(), + ) + .execute(pg) + .await + { + Ok(_) => info!("All sequences privileges granted to role {}", database), + Err(err) => error!( + "Error granting all privileges to role {}: {:?}", + database, err + ), + } + // grant all function privileges to the role + match sqlx::query( + format!( + "GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {};", + database + ) + .as_str(), + ) + .execute(pg) + .await + { + Ok(_) => info!("All functions privileges granted to role {}", database), + Err(err) => error!( + "Error granting all privileges to role {}: {:?}", + database, err + ), + } + + Ok(()) +} + +pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> { + // execute drop owned + match sqlx::query(format!("DROP OWNED BY {}", database).as_str()) + .execute(pg) + .await + { + Ok(_) => info!("Dropped owned objects by role {}", database), + Err(err) => error!("Error dropping owned by role {}: {:?}", database, err), + } + // revoke connect + match sqlx::query(format!("REVOKE CONNECT ON DATABASE {0} FROM {0};", database).as_str()) + .execute(pg) + .await + { + Ok(_) => info!("Revoked connect privileges from role {}", database), + Err(err) => error!( + "Error revoking connect privileges from role {}: {:?}", + database, err + ), + } + // revoke privileges + match sqlx::query(format!("REVOKE ALL PRIVILEGES ON DATABASE {0} FROM {0};", database).as_str()) + .execute(pg) + .await + { + Ok(_) => info!("Revoked all privileges from role {}", database), + Err(err) => error!( + "Error revoking all privileges from role {}: {:?}", + database, err + ), + } + // execute drop schema + match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE") + .execute(pg) + .await + { + Ok(_) => info!("Dropped schema public"), + Err(err) => error!("Error dropping schema public: {:?}", err), + } + // drop role + match sqlx::query(format!("DROP ROLE IF EXISTS {};", database).as_str()) + .execute(pg) + .await + { + Ok(_) => info!("Dropped role {}", database), + Err(err) => error!("Error dropping role {}: {:?}", database, err), + } + Ok(()) +} diff --git a/nautilus_core/infrastructure/src/sql/queries.rs b/nautilus_core/infrastructure/src/sql/queries.rs new file mode 100644 index 000000000000..c0db3edc4bc3 --- /dev/null +++ b/nautilus_core/infrastructure/src/sql/queries.rs @@ -0,0 +1,89 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +// https://nautechsystems.io +// +// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------------------------------- + +use std::collections::HashMap; + +use nautilus_model::types::currency::Currency; +use sqlx::PgPool; + +use crate::sql::models::{general::GeneralRow, types::CurrencyModel}; + +pub struct DatabaseQueries; + +impl DatabaseQueries { + pub async fn add(pool: &PgPool, key: String, value: Vec) -> anyhow::Result<()> { + sqlx::query("INSERT INTO general (key, value) VALUES ($1, $2)") + .bind(key) + .bind(value) + .execute(pool) + .await + .map(|_| ()) + .map_err(|err| anyhow::anyhow!("Failed to insert into general table: {err}")) + } + + pub async fn load(pool: &PgPool) -> anyhow::Result>> { + sqlx::query_as::<_, GeneralRow>("SELECT * FROM general") + .fetch_all(pool) + .await + .map(|rows| { + let mut cache: HashMap> = HashMap::new(); + for row in rows { + cache.insert(row.key, row.value); + } + cache + }) + .map_err(|err| anyhow::anyhow!("Failed to load general table: {err}")) + } + + pub async fn add_currency(pool: &PgPool, currency: Currency) -> anyhow::Result<()> { + sqlx::query( + "INSERT INTO currency (code, precision, iso4217, name, currency_type) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (code) DO NOTHING" + ) + .bind(currency.code.as_str()) + .bind(currency.precision as i32) + .bind(currency.iso4217 as i32) + .bind(currency.name.as_str()) + .bind(currency.currency_type.to_string()) + .execute(pool) + .await + .map(|_| ()) + .map_err(|err| anyhow::anyhow!("Failed to insert into currency table: {err}")) + } + + pub async fn load_currencies(pool: &PgPool) -> anyhow::Result> { + sqlx::query_as::<_, CurrencyModel>("SELECT * FROM currency ORDER BY code ASC") + .fetch_all(pool) + .await + .map(|rows| rows.into_iter().map(|row| row.0).collect()) + .map_err(|err| anyhow::anyhow!("Failed to load currencies: {err}")) + } + + pub async fn load_currency(pool: &PgPool, code: &str) -> anyhow::Result> { + sqlx::query_as::<_, CurrencyModel>("SELECT * FROM currency WHERE code = $1") + .bind(code) + .fetch_optional(pool) + .await + .map(|currency| currency.map(|row| row.0)) + .map_err(|err| anyhow::anyhow!("Failed to load currency: {err}")) + } + + pub async fn truncate(pool: &PgPool, table: String) -> anyhow::Result<()> { + sqlx::query(format!("TRUNCATE TABLE {} CASCADE", table).as_str()) + .execute(pool) + .await + .map(|_| ()) + .map_err(|err| anyhow::anyhow!("Failed to truncate table: {err}")) + } +} diff --git a/nautilus_core/infrastructure/tests/test_cache_database_postgres.rs b/nautilus_core/infrastructure/tests/test_cache_database_postgres.rs new file mode 100644 index 000000000000..05d3cc39ff1a --- /dev/null +++ b/nautilus_core/infrastructure/tests/test_cache_database_postgres.rs @@ -0,0 +1,91 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +// https://nautechsystems.io +// +// Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------------------------------- + +use nautilus_infrastructure::sql::{ + cache_database::PostgresCacheDatabase, + pg::{connect_pg, delete_nautilus_postgres_tables, PostgresConnectOptions}, +}; +use sqlx::PgPool; + +pub fn get_test_pg_connect_options(username: &str) -> PostgresConnectOptions { + PostgresConnectOptions::new( + "localhost".to_string(), + 5432, + username.to_string(), + "pass".to_string(), + "nautilus".to_string(), + ) +} +pub async fn get_pg(username: &str) -> PgPool { + let pg_connect_options = get_test_pg_connect_options(username); + connect_pg(pg_connect_options.into()).await.unwrap() +} + +pub async fn initialize() -> anyhow::Result<()> { + // get pg pool with root postgres user to drop & create schema + let pg_pool = get_pg("postgres").await; + delete_nautilus_postgres_tables(&pg_pool).await.unwrap(); + Ok(()) +} + +pub async fn get_pg_cache_database() -> anyhow::Result { + initialize().await.unwrap(); + // run tests as nautilus user + let connect_options = get_test_pg_connect_options("nautilus"); + Ok(PostgresCacheDatabase::connect( + Some(connect_options.host), + Some(connect_options.port), + Some(connect_options.username), + Some(connect_options.password), + Some(connect_options.database), + ) + .await + .unwrap()) +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use crate::get_pg_cache_database; + + #[tokio::test] + async fn test_load_general_objects_when_nothing_in_cache_returns_empty_hashmap() { + let pg_cache = get_pg_cache_database().await.unwrap(); + let result = pg_cache.load().await.unwrap(); + println!("1: {:?}", result); + assert_eq!(result.len(), 0); + } + + #[tokio::test] + async fn test_add_general_object_adds_to_cache() { + let pg_cache = get_pg_cache_database().await.unwrap(); + let test_id_value = String::from("test_value").into_bytes(); + pg_cache + .add(String::from("test_id"), test_id_value.clone()) + .await + .unwrap(); + // sleep with tokio + tokio::time::sleep(Duration::from_secs(1)).await; + let result = pg_cache.load().await.unwrap(); + println!("2: {:?}", result); + assert_eq!(result.keys().len(), 1); + assert_eq!( + result.keys().cloned().collect::>(), + vec![String::from("test_id")] + ); // assert_eq!(result.get(&test_id_key).unwrap().to_owned(),&test_id_value.clone()); + assert_eq!(result.get("test_id").unwrap().to_owned(), test_id_value); + } +} diff --git a/nautilus_core/persistence/Cargo.toml b/nautilus_core/persistence/Cargo.toml index 59e60ca1c617..d95388b88c95 100644 --- a/nautilus_core/persistence/Cargo.toml +++ b/nautilus_core/persistence/Cargo.toml @@ -12,7 +12,7 @@ crate-type = ["rlib", "staticlib", "cdylib"] [dependencies] nautilus-core = { path = "../core" } -nautilus-model = { path = "../model" } +nautilus-model = { path = "../model", features = ["stubs"] } anyhow = { workspace = true } futures = { workspace = true } pyo3 = { workspace = true, optional = true } @@ -35,8 +35,8 @@ procfs = "0.16.0" [features] default = ["ffi", "python"] extension-module = [ - "pyo3/extension-module", - "nautilus-core/extension-module", + "pyo3/extension-module", + "nautilus-core/extension-module", "nautilus-model/extension-module", ] ffi = ["nautilus-core/ffi", "nautilus-model/ffi"] diff --git a/nautilus_trader/cache/postgres/__init__.py b/nautilus_trader/cache/postgres/__init__.py new file mode 100644 index 000000000000..3d34cab4588e --- /dev/null +++ b/nautilus_trader/cache/postgres/__init__.py @@ -0,0 +1,14 @@ +# ------------------------------------------------------------------------------------------------- +# Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +# https://nautechsystems.io +# +# Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------------------------------- diff --git a/nautilus_trader/cache/postgres/adapter.py b/nautilus_trader/cache/postgres/adapter.py new file mode 100644 index 000000000000..fe67593bc242 --- /dev/null +++ b/nautilus_trader/cache/postgres/adapter.py @@ -0,0 +1,57 @@ +# ------------------------------------------------------------------------------------------------- +# Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +# https://nautechsystems.io +# +# Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------------------------------- + +from nautilus_trader.cache.config import CacheConfig +from nautilus_trader.cache.facade import CacheDatabaseFacade +from nautilus_trader.cache.postgres.transformers import transform_currency_from_pyo3 +from nautilus_trader.cache.postgres.transformers import transform_currency_to_pyo3 +from nautilus_trader.core.nautilus_pyo3 import PostgresCacheDatabase +from nautilus_trader.model.objects import Currency + + +class CachePostgresAdapter(CacheDatabaseFacade): + + def __init__( + self, + config: CacheConfig | None = None, + ): + if config: + config = CacheConfig() + super().__init__(config) + self._backing: PostgresCacheDatabase = PostgresCacheDatabase.connect() + + def flush(self): + self._backing.flush_db() + + def load(self): + data = self._backing.load() + return {key: bytes(value) for key, value in data.items()} + + def add(self, key: str, value: bytes): + self._backing.add(key, value) + + def add_currency(self, currency: Currency): + currency_pyo3 = transform_currency_to_pyo3(currency) + self._backing.add_currency(currency_pyo3) + + def load_currencies(self) -> dict[str, Currency]: + currencies = self._backing.load_currencies() + return {currency.code: transform_currency_from_pyo3(currency) for currency in currencies} + + def load_currency(self, code: str) -> Currency | None: + currency_pyo3 = self._backing.load_currency(code) + if currency_pyo3: + return transform_currency_from_pyo3(currency_pyo3) + return None diff --git a/nautilus_trader/cache/postgres/transformers.py b/nautilus_trader/cache/postgres/transformers.py new file mode 100644 index 000000000000..65542490015d --- /dev/null +++ b/nautilus_trader/cache/postgres/transformers.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------------------------------- +# Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +# https://nautechsystems.io +# +# Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------------------------------- + +from nautilus_trader.core import nautilus_pyo3 +from nautilus_trader.model.enums import CurrencyType +from nautilus_trader.model.objects import Currency + + +################################################################################ +# Currency +################################################################################ +def transform_currency_from_pyo3(currency: nautilus_pyo3.Currency) -> Currency: + return Currency( + code=currency.code, + precision=currency.precision, + iso4217=currency.iso4217, + name=currency.name, + currency_type=CurrencyType(currency.currency_type.value), + ) + + +def transform_currency_to_pyo3(currency: Currency) -> nautilus_pyo3.Currency: + return nautilus_pyo3.Currency( + code=currency.code, + precision=currency.precision, + iso4217=currency.iso4217, + name=currency.name, + currency_type=nautilus_pyo3.CurrencyType.from_str(currency.currency_type.name), + ) diff --git a/nautilus_trader/core/nautilus_pyo3.pyi b/nautilus_trader/core/nautilus_pyo3.pyi index ea41259eb7cf..d027c8213435 100644 --- a/nautilus_trader/core/nautilus_pyo3.pyi +++ b/nautilus_trader/core/nautilus_pyo3.pyi @@ -682,6 +682,8 @@ class CurrencyType(Enum): CRYPTO = "CRYPTO" FIAT = "FIAT" COMMODITY_BACKED = "COMMODITY_BACKED" + @classmethod + def from_str(cls, value: str) -> CurrencyType: ... class InstrumentCloseType(Enum): END_OF_SESSION = "END_OF_SESSION" @@ -2250,6 +2252,24 @@ class RedisCacheDatabase: config: dict[str, Any], ) -> None: ... +class PostgresCacheDatabase: + @classmethod + def connect( + cls, + host: str | None = None, + port: str | None = None, + username: str | None = None, + password: str | None = None, + database: str | None = None, + )-> PostgresCacheDatabase: ... + def load(self) -> dict[str,str]: ... + def add(self, key: str, value: bytes) -> None: ... + def add_currency(self,currency: Currency) -> None: ... + def load_currency(self, code: str) -> Currency | None: ... + def load_currencies(self) -> list[Currency]: ... + def flush_db(self) -> None: ... + def truncate(self, table: str) -> None: ... + ################################################################################################### # Network ################################################################################################### diff --git a/tests/integration_tests/infrastructure/test_cache_database_postgres.py b/tests/integration_tests/infrastructure/test_cache_database_postgres.py new file mode 100644 index 000000000000..759a3de3fa75 --- /dev/null +++ b/tests/integration_tests/infrastructure/test_cache_database_postgres.py @@ -0,0 +1,125 @@ +# ------------------------------------------------------------------------------------------------- +# Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved. +# https://nautechsystems.io +# +# Licensed under the GNU Lesser General Public License Version 3.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------------------------------- + +import os + +import pytest + +from nautilus_trader.cache.postgres.adapter import CachePostgresAdapter +from nautilus_trader.common.component import MessageBus +from nautilus_trader.common.component import TestClock +from nautilus_trader.model.enums import CurrencyType +from nautilus_trader.model.objects import Currency +from nautilus_trader.portfolio.portfolio import Portfolio +from nautilus_trader.test_kit.functions import eventually +from nautilus_trader.test_kit.providers import TestInstrumentProvider +from nautilus_trader.test_kit.stubs.component import TestComponentStubs +from nautilus_trader.test_kit.stubs.data import TestDataStubs +from nautilus_trader.test_kit.stubs.identifiers import TestIdStubs +from nautilus_trader.trading.strategy import Strategy + + +AUDUSD_SIM = TestInstrumentProvider.default_fx_ccy("AUD/USD") + + +class TestCachePostgresAdapter: + def setup(self): + # set envs + os.environ["POSTGRES_HOST"] = "localhost" + os.environ["POSTGRES_PORT"] = "5432" + os.environ["POSTGRES_USERNAME"] = "nautilus" + os.environ["POSTGRES_PASSWORD"] = "pass" + os.environ["POSTGRES_DATABASE"] = "nautilus" + self.database: CachePostgresAdapter = CachePostgresAdapter() + # reset database + self.database.flush() + self.clock = TestClock() + + self.trader_id = TestIdStubs.trader_id() + + self.msgbus = MessageBus( + trader_id=self.trader_id, + clock=self.clock, + ) + + self.cache = TestComponentStubs.cache() + + self.portfolio = Portfolio( + msgbus=self.msgbus, + cache=self.cache, + clock=self.clock, + ) + + # Init strategy + self.strategy = Strategy() + self.strategy.register( + trader_id=self.trader_id, + portfolio=self.portfolio, + msgbus=self.msgbus, + cache=self.cache, + clock=self.clock, + ) + + def teardown(self): + self.database.flush() + + @pytest.mark.asyncio + async def test_load_general_objects_when_nothing_in_cache_returns_empty_dict(self): + # Arrange, Act + result = self.database.load() + + # Assert + assert result == {} + + @pytest.mark.asyncio + async def test_add_general_object_adds_to_cache(self): + # Arrange + bar = TestDataStubs.bar_5decimal() + key = str(bar.bar_type) + "-" + str(bar.ts_event) + + # Act + self.database.add(key, str(bar).encode()) + + # Allow MPSC thread to insert + await eventually(lambda: self.database.load()) + + # Assert + assert self.database.load() == {key: str(bar).encode()} + + ################################################################################ + # Currency + ################################################################################ + @pytest.mark.asyncio + async def test_add_currency(self): + # Arrange + currency = Currency( + code="BTC", + precision=8, + iso4217=0, + name="BTC", + currency_type=CurrencyType.CRYPTO, + ) + + # Act + self.database.add_currency(currency) + + # Allow MPSC thread to insert + await eventually(lambda: self.database.load_currency(currency.code)) + + # Assert + assert self.database.load_currency(currency.code) == currency + + currencies = self.database.load_currencies() + assert list(currencies.keys()) == ["BTC"]