diff --git a/.github/workflows/connect-test-local.yml b/.github/workflows/connect-test-local.yml index cf56cca1..5430e2c9 100644 --- a/.github/workflows/connect-test-local.yml +++ b/.github/workflows/connect-test-local.yml @@ -48,9 +48,12 @@ jobs: working-directory: ./infra run: | ( docker-compose logs --follow & ) | grep -q "database system is ready to accept connections" + - name: Prepare db tables + run: | + cargo run --bin tables_migration - name: run cargo test run: | - cargo test + cargo test -- --test-threads=1 cargo run --bin nightly-connect-server & - name: test base local run: | diff --git a/database/migrations/000000000003_sessions.sql b/database/migrations/000000000003_sessions.sql new file mode 100644 index 00000000..72afc282 --- /dev/null +++ b/database/migrations/000000000003_sessions.sql @@ -0,0 +1,23 @@ +CREATE TABLE sessions( + session_id TEXT NOT NULL UNIQUE, + app_id TEXT NOT NULL, + app_metadata TEXT NOT NULL, + app_ip_address TEXT NOT NULL, + persistent BOOLEAN NOT NULL, + network TEXT NOT NULL, + client_id TEXT, + client_device TEXT, + client_metadata TEXT, + client_notification_endpoint TEXT, + client_connected_at BIGINT, + session_open_timestamp BIGINT NOT NULL, + session_close_timestamp BIGINT +); + +CREATE UNIQUE INDEX sessions_session_id ON sessions(session_id); + +ALTER TABLE sessions +ADD CONSTRAINT fk_sessions_registered_apps +FOREIGN KEY (app_id) +REFERENCES registered_apps (app_id) +ON DELETE CASCADE; \ No newline at end of file diff --git a/database/migrations/000000000004_request_status.sql b/database/migrations/000000000004_request_status.sql new file mode 100644 index 00000000..64b872fe --- /dev/null +++ b/database/migrations/000000000004_request_status.sql @@ -0,0 +1,8 @@ +CREATE TYPE request_status_enum AS ENUM ( + 'Pending', + 'Completed', + 'Failed', + 'Rejected', + 'TimedOut', + 'Unknown' +); \ No newline at end of file diff --git a/database/migrations/000000000005_requests.sql b/database/migrations/000000000005_requests.sql new file mode 100644 index 00000000..380f10b7 --- /dev/null +++ b/database/migrations/000000000005_requests.sql @@ -0,0 +1,16 @@ +CREATE TABLE requests( + request_id TEXT NOT NULL UNIQUE, + request_type TEXT NOT NULL, + session_id TEXT NOT NULL, + request_status request_status_enum NOT NULL, + network TEXT NOT NULL, + creation_timestamp BIGINT NOT NULL +); + +CREATE UNIQUE INDEX requests_request_id ON requests(request_id); + +ALTER TABLE requests +ADD CONSTRAINT fk_requests_sessions +FOREIGN KEY (session_id) +REFERENCES sessions (session_id) +ON DELETE CASCADE; \ No newline at end of file diff --git a/database/src/bin/tables_migration.rs b/database/src/bin/tables_migration.rs index bee88133..0b510165 100644 --- a/database/src/bin/tables_migration.rs +++ b/database/src/bin/tables_migration.rs @@ -2,6 +2,9 @@ use database::db::Db; #[tokio::main] async fn main() { + println!("Connecting to the database..."); let db = Db::connect_to_the_pool().await; + println!("Starting migration of tables..."); db.migrate_tables().await.unwrap(); + println!("Migration completed."); } diff --git a/database/src/db.rs b/database/src/db.rs index 1f5d7440..329efa14 100644 --- a/database/src/db.rs +++ b/database/src/db.rs @@ -32,10 +32,4 @@ impl Db { pub async fn migrate_tables(&self) -> Result<(), sqlx::migrate::MigrateError> { migrate!("./migrations").run(&self.connection_pool).await } - - pub async fn truncate_table(&self, table_name: &str) -> Result<(), sqlx::Error> { - let query = format!("TRUNCATE TABLE {table_name}"); - sqlx::query(&query).execute(&self.connection_pool).await?; - Ok(()) - } } diff --git a/database/src/lib.rs b/database/src/lib.rs index 94110a11..48a019cd 100644 --- a/database/src/lib.rs +++ b/database/src/lib.rs @@ -1,2 +1,3 @@ pub mod db; +pub mod structs; pub mod tables; diff --git a/database/src/structs/client_data.rs b/database/src/structs/client_data.rs new file mode 100644 index 00000000..eab5f3d1 --- /dev/null +++ b/database/src/structs/client_data.rs @@ -0,0 +1,8 @@ +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ClientData { + pub client_id: Option, + pub device: Option, + pub metadata: Option, + pub notification_endpoint: Option, + pub connected_at: u64, // Timestamp of when the client connected to the session +} diff --git a/database/src/structs/consts.rs b/database/src/structs/consts.rs new file mode 100644 index 00000000..0c5e1fb0 --- /dev/null +++ b/database/src/structs/consts.rs @@ -0,0 +1,3 @@ +pub const LAST_24_HOURS: &str = "EXTRACT(EPOCH FROM NOW() - INTERVAL '1 day')::BIGINT * 1000"; +pub const LAST_7_DAYS: &str = "EXTRACT(EPOCH FROM NOW() - INTERVAL '7 days')::BIGINT * 1000"; +pub const LAST_30_DAYS: &str = "EXTRACT(EPOCH FROM NOW() - INTERVAL '30 days')::BIGINT * 1000"; diff --git a/database/src/structs/mod.rs b/database/src/structs/mod.rs new file mode 100644 index 00000000..a9de028b --- /dev/null +++ b/database/src/structs/mod.rs @@ -0,0 +1,4 @@ +pub mod client_data; +pub mod consts; +pub mod request_status; +pub mod subscription; diff --git a/database/src/structs/request_status.rs b/database/src/structs/request_status.rs new file mode 100644 index 00000000..3a47dd14 --- /dev/null +++ b/database/src/structs/request_status.rs @@ -0,0 +1,10 @@ +use sqlx::Type; + +#[derive(Clone, Debug, Eq, PartialEq, Type)] +#[sqlx(type_name = "request_status_enum")] +pub enum RequestStatus { + Pending, + Completed, + Rejected, + TimedOut, +} diff --git a/database/src/structs/subscription.rs b/database/src/structs/subscription.rs new file mode 100644 index 00000000..5bbf4d56 --- /dev/null +++ b/database/src/structs/subscription.rs @@ -0,0 +1,9 @@ +use sqlx::Type; + +#[derive(Clone, Debug, Eq, PartialEq, Type)] +#[sqlx(type_name = "subscription")] +pub struct Subscription { + pub subscription_type: String, + pub valid_from: i64, + pub valid_till: i64, +} diff --git a/database/src/tables/mod.rs b/database/src/tables/mod.rs index 056653ba..0edc4a56 100644 --- a/database/src/tables/mod.rs +++ b/database/src/tables/mod.rs @@ -1 +1,5 @@ pub mod registered_app; +pub mod requests; +pub mod sessions; +pub mod test_utils; +pub mod utils; diff --git a/database/src/tables/registered_app/select.rs b/database/src/tables/registered_app/select.rs index 2bdcef2e..0eb56b72 100644 --- a/database/src/tables/registered_app/select.rs +++ b/database/src/tables/registered_app/select.rs @@ -1,5 +1,6 @@ use super::table_struct::{RegisteredApp, REGISTERED_APPS_TABLE_NAME}; -use crate::db::Db; +use crate::tables::requests::table_struct::REQUESTS_TABLE_NAME; +use crate::{db::Db, tables::requests::table_struct::Request}; use sqlx::query_as; impl Db { @@ -15,4 +16,41 @@ impl Db { .fetch_one(&self.connection_pool) .await; } + + pub async fn get_requests_by_app_id( + &self, + app_id: &String, + ) -> Result, sqlx::Error> { + let query = format!( + "SELECT r.* FROM {REQUESTS_TABLE_NAME} r + INNER JOIN sessions s ON r.session_id = s.session_id + WHERE s.app_id = $1 + ORDER BY r.creation_timestamp DESC" + ); + let typed_query = query_as::<_, Request>(&query); + + return typed_query + .bind(&app_id) + .fetch_all(&self.connection_pool) + .await; + } + + pub async fn get_requests_by_app_id_with_filter( + &self, + app_id: &String, + filter: &str, + ) -> Result, sqlx::Error> { + let query = format!( + "SELECT r.* FROM {REQUESTS_TABLE_NAME} r + INNER JOIN sessions s ON r.session_id = s.session_id + WHERE s.app_id = $1 AND creation_timestamp >= {filter} + ORDER BY r.creation_timestamp DESC" + ); + let typed_query = query_as::<_, Request>(&query); + + return typed_query + .bind(&app_id) + .fetch_all(&self.connection_pool) + .await; + } } diff --git a/database/src/tables/registered_app/table_struct.rs b/database/src/tables/registered_app/table_struct.rs index c0ee7190..47068130 100644 --- a/database/src/tables/registered_app/table_struct.rs +++ b/database/src/tables/registered_app/table_struct.rs @@ -1,12 +1,5 @@ -use sqlx::{postgres::PgRow, FromRow, Row, Type}; - -// TODO move later to a common place -#[derive(Clone, Debug, Eq, PartialEq, Type)] -#[sqlx(type_name = "subscription")] -pub struct Subscription { - pub email: String, - pub subscribed_at: i64, -} +use crate::structs::subscription::Subscription; +use sqlx::{postgres::PgRow, FromRow, Row}; pub const REGISTERED_APPS_TABLE_NAME: &str = "registered_apps"; pub const REGISTERED_APPS_KEYS: &str = "app_id, app_name, whitelisted_domains, subscription, ack_public_keys, email, registration_timestamp, pass_hash"; diff --git a/database/src/tables/registered_app/update.rs b/database/src/tables/registered_app/update.rs index aefc8445..99e5e56e 100644 --- a/database/src/tables/registered_app/update.rs +++ b/database/src/tables/registered_app/update.rs @@ -1,7 +1,5 @@ -use super::table_struct::{ - RegisteredApp, Subscription, REGISTERED_APPS_KEYS, REGISTERED_APPS_TABLE_NAME, -}; -use crate::db::Db; +use super::table_struct::{RegisteredApp, REGISTERED_APPS_KEYS, REGISTERED_APPS_TABLE_NAME}; +use crate::{db::Db, structs::subscription::Subscription}; use sqlx::query; impl Db { @@ -50,13 +48,21 @@ impl Db { #[cfg(test)] mod tests { - use crate::tables::registered_app::table_struct::{RegisteredApp, REGISTERED_APPS_TABLE_NAME}; + use crate::{ + structs::{ + consts::{LAST_24_HOURS, LAST_30_DAYS, LAST_7_DAYS}, + request_status::RequestStatus, + }, + tables::{ + registered_app::table_struct::RegisteredApp, requests::table_struct::Request, + sessions::table_struct::DbNcSession, utils::get_timestamp_in_milliseconds, + }, + }; #[tokio::test] async fn test_register_app() { let db = super::Db::connect_to_the_pool().await; - db.migrate_tables().await.unwrap(); - db.truncate_table(REGISTERED_APPS_TABLE_NAME).await.unwrap(); + db.truncate_all_tables().await.unwrap(); let app = RegisteredApp { app_id: "test_app_id".to_string(), @@ -74,4 +80,195 @@ mod tests { let result = db.get_registered_app_by_app_id(&app.app_id).await.unwrap(); assert_eq!(app, result); } + + #[tokio::test] + async fn test_get_requests() { + let db = super::Db::connect_to_the_pool().await; + db.truncate_all_tables().await.unwrap(); + + // "Register" an app + let app = RegisteredApp { + app_id: "test_app_id".to_string(), + app_name: "test_app_name".to_string(), + whitelisted_domains: vec!["test_domain".to_string()], + subscription: None, + ack_public_keys: vec!["test_key".to_string()], + email: None, + registration_timestamp: 0, + pass_hash: None, + }; + + db.register_new_app(&app).await.unwrap(); + + let result = db.get_registered_app_by_app_id(&app.app_id).await.unwrap(); + assert_eq!(app, result); + + // Create 2 sessions + let session = DbNcSession { + session_id: "test_session_id".to_string(), + app_id: "test_app_id".to_string(), + app_metadata: "test_app_metadata".to_string(), + app_ip_address: "test_app_ip_address".to_string(), + persistent: false, + network: "test_network".to_string(), + client: None, + session_open_timestamp: 10, + session_close_timestamp: None, + }; + + let second_session = DbNcSession { + session_id: "test_session_id_2".to_string(), + app_id: "test_app_id".to_string(), + app_metadata: "test_app_metadata".to_string(), + app_ip_address: "test_app_ip_address".to_string(), + persistent: false, + network: "test_network".to_string(), + client: None, + session_open_timestamp: 12, + session_close_timestamp: None, + }; + + db.save_new_session(&session).await.unwrap(); + db.save_new_session(&second_session).await.unwrap(); + + let result = db.get_sessions_by_app_id(&app.app_id).await.unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(second_session, result[0]); + assert_eq!(session, result[1]); + + // Create 2 requests per session + // First session + let request = Request { + request_id: "test_request_id".to_string(), + session_id: "test_session_id".to_string(), + network: "test_network".to_string(), + creation_timestamp: 10, + request_status: RequestStatus::Pending, + request_type: "test_request_type".to_string(), + }; + + let second_request = Request { + request_id: "test_request_id_2".to_string(), + session_id: "test_session_id".to_string(), + network: "test_network".to_string(), + creation_timestamp: 12, + request_status: RequestStatus::Pending, + request_type: "test_request_type".to_string(), + }; + + db.save_request(&request).await.unwrap(); + db.save_request(&second_request).await.unwrap(); + + // Second session + let third_request = Request { + request_id: "test_request_id_3".to_string(), + session_id: "test_session_id_2".to_string(), + network: "test_network".to_string(), + creation_timestamp: 14, + request_status: RequestStatus::Pending, + request_type: "test_request_type".to_string(), + }; + + let fourth_request = Request { + request_id: "test_request_id_4".to_string(), + session_id: "test_session_id_2".to_string(), + network: "test_network".to_string(), + creation_timestamp: 16, + request_status: RequestStatus::Pending, + request_type: "test_request_type".to_string(), + }; + + db.save_request(&third_request).await.unwrap(); + db.save_request(&fourth_request).await.unwrap(); + + // Get all requests by app_id + let result = db.get_requests_by_app_id(&app.app_id).await.unwrap(); + assert_eq!(result.len(), 4); + + assert_eq!(result[0], fourth_request); + assert_eq!(result[1], third_request); + assert_eq!(result[2], second_request); + assert_eq!(result[3], request); + } + + #[tokio::test] + async fn test_data_ranges() { + let db = super::Db::connect_to_the_pool().await; + db.truncate_all_tables().await.unwrap(); + + // "Register" an app + let app = RegisteredApp { + app_id: "test_app_id".to_string(), + app_name: "test_app_name".to_string(), + whitelisted_domains: vec!["test_domain".to_string()], + subscription: None, + ack_public_keys: vec!["test_key".to_string()], + email: None, + registration_timestamp: 0, + pass_hash: None, + }; + + db.register_new_app(&app).await.unwrap(); + + let result = db.get_registered_app_by_app_id(&app.app_id).await.unwrap(); + assert_eq!(app, result); + + // Create session + let session = DbNcSession { + session_id: "test_session_id".to_string(), + app_id: "test_app_id".to_string(), + app_metadata: "test_app_metadata".to_string(), + app_ip_address: "test_app_ip_address".to_string(), + persistent: false, + network: "test_network".to_string(), + client: None, + session_open_timestamp: 10, + session_close_timestamp: None, + }; + + db.save_new_session(&session).await.unwrap(); + + let result = db.get_sessions_by_app_id(&app.app_id).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(session, result[0]); + + let now = get_timestamp_in_milliseconds(); + // Create requests across last 33 days, 3 requests per day + for i in 0..33 { + for j in 0..3 { + let request = Request { + request_id: format!("test_request_id_{}_{}", i, j), + session_id: "test_session_id".to_string(), + network: "test_network".to_string(), + creation_timestamp: (now - (i * 24 * 60 * 60 * 1000) - ((j + 1) * 10000)) + as u64, + request_status: RequestStatus::Pending, + request_type: "test_request_type".to_string(), + }; + + db.save_request(&request).await.unwrap(); + } + } + + // Query last 30 days + let result = db + .get_requests_by_app_id_with_filter(&app.app_id, LAST_30_DAYS) + .await + .unwrap(); + assert_eq!(result.len(), 30 * 3); + + // Query last 7 days + let result = db + .get_requests_by_app_id_with_filter(&app.app_id, LAST_7_DAYS) + .await + .unwrap(); + assert_eq!(result.len(), 7 * 3); + + // Query last 24 hours + let result = db + .get_requests_by_app_id_with_filter(&app.app_id, LAST_24_HOURS) + .await + .unwrap(); + assert_eq!(result.len(), 3); + } } diff --git a/database/src/tables/requests/mod.rs b/database/src/tables/requests/mod.rs new file mode 100644 index 00000000..4b2d4aa3 --- /dev/null +++ b/database/src/tables/requests/mod.rs @@ -0,0 +1,3 @@ +pub mod select; +pub mod table_struct; +pub mod update; diff --git a/database/src/tables/requests/select.rs b/database/src/tables/requests/select.rs new file mode 100644 index 00000000..92761eac --- /dev/null +++ b/database/src/tables/requests/select.rs @@ -0,0 +1,31 @@ +use super::table_struct::{Request, REQUESTS_TABLE_NAME}; +use crate::db::Db; +use sqlx::query_as; + +impl Db { + pub async fn get_requests_by_session_id( + &self, + session_id: &String, + ) -> Result, sqlx::Error> { + let query = format!("SELECT * FROM {REQUESTS_TABLE_NAME} WHERE session_id = $1 ORDER BY creation_timestamp DESC"); + let typed_query = query_as::<_, Request>(&query); + + return typed_query + .bind(&session_id) + .fetch_all(&self.connection_pool) + .await; + } + + pub async fn get_request_by_request_id( + &self, + request_id: &String, + ) -> Result, sqlx::Error> { + let query = format!("SELECT * FROM {REQUESTS_TABLE_NAME} WHERE request_id = $1"); + let typed_query = query_as::<_, Request>(&query); + + return typed_query + .bind(&request_id) + .fetch_optional(&self.connection_pool) + .await; + } +} diff --git a/database/src/tables/requests/table_struct.rs b/database/src/tables/requests/table_struct.rs new file mode 100644 index 00000000..eb82e1aa --- /dev/null +++ b/database/src/tables/requests/table_struct.rs @@ -0,0 +1,31 @@ +use sqlx::{postgres::PgRow, FromRow, Row}; + +use crate::structs::request_status::RequestStatus; + +pub const REQUESTS_TABLE_NAME: &str = "requests"; +pub const REQUESTS_KEYS: &str = + "request_id, request_type, session_id, request_status, network, creation_timestamp"; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Request { + pub request_id: String, + pub request_type: String, + pub session_id: String, + pub request_status: RequestStatus, + pub network: String, + pub creation_timestamp: u64, +} + +impl FromRow<'_, PgRow> for Request { + fn from_row(row: &sqlx::postgres::PgRow) -> std::result::Result { + let creation_timestamp: i64 = row.get("creation_timestamp"); + Ok(Request { + request_id: row.get("request_id"), + request_type: row.get("request_type"), + session_id: row.get("session_id"), + request_status: row.get("request_status"), + network: row.get("network"), + creation_timestamp: creation_timestamp as u64, + }) + } +} diff --git a/database/src/tables/requests/update.rs b/database/src/tables/requests/update.rs new file mode 100644 index 00000000..1bfe57b1 --- /dev/null +++ b/database/src/tables/requests/update.rs @@ -0,0 +1,148 @@ +use super::table_struct::{Request, REQUESTS_KEYS, REQUESTS_TABLE_NAME}; +use crate::{db::Db, structs::request_status::RequestStatus}; +use sqlx::query; + +impl Db { + pub async fn save_request(&self, request: &Request) -> Result<(), sqlx::Error> { + let query_body = format!( + "INSERT INTO {REQUESTS_TABLE_NAME} ({}) VALUES ($1, $2, $3, $4, $5, $6)", + REQUESTS_KEYS + ); + + let query_result = query(&query_body) + .bind(&request.request_id) + .bind(&request.request_type) + .bind(&request.session_id) + .bind(&request.request_status) + .bind(&request.network) + .bind(&(request.creation_timestamp as i64)) + .execute(&self.connection_pool) + .await; + + match query_result { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + + pub async fn update_request_status( + &self, + request_id: &String, + new_status: &RequestStatus, + ) -> Result<(), sqlx::Error> { + let query_body = + format!("UPDATE {REQUESTS_TABLE_NAME} SET request_status = $1 WHERE request_id = $2"); + let query_result = query(&query_body) + .bind(new_status) + .bind(request_id) + .execute(&self.connection_pool) + .await; + + match query_result { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::{ + structs::client_data::ClientData, + tables::{ + registered_app::table_struct::RegisteredApp, sessions::table_struct::DbNcSession, + }, + }; + + #[tokio::test] + async fn test_requests() { + let db = super::Db::connect_to_the_pool().await; + db.truncate_all_tables().await.unwrap(); + + // Create basic app to satisfy foreign key constraint + let app = RegisteredApp { + app_id: "test_app_id".to_string(), + app_name: "test_app_name".to_string(), + whitelisted_domains: vec!["test_domain".to_string()], + subscription: None, + ack_public_keys: vec!["test_key".to_string()], + email: Some("test_email".to_string()), + registration_timestamp: 10, + pass_hash: Some("test_pass_hash".to_string()), + }; + db.register_new_app(&app).await.unwrap(); + + // Create basic session to satisfy foreign key constraint + let session = DbNcSession { + session_id: "test_session_id".to_string(), + app_id: "test_app_id".to_string(), + app_metadata: "test_app_metadata".to_string(), + app_ip_address: "test_app_ip_address".to_string(), + persistent: false, + network: "test_network".to_string(), + client: Some(ClientData { + client_id: Some("test_client_id".to_string()), + device: Some("test_device".to_string()), + metadata: Some("test_metadata".to_string()), + notification_endpoint: Some("test_notification_endpoint".to_string()), + connected_at: 12, + }), + session_open_timestamp: 10, + session_close_timestamp: None, + }; + + // Create a new session entry + db.save_new_session(&session).await.unwrap(); + + let request = Request { + request_id: "test_request_id".to_string(), + request_type: "test_request_type".to_string(), + session_id: "test_session_id".to_string(), + request_status: RequestStatus::Pending, + network: "test_network".to_string(), + creation_timestamp: 10, + }; + + db.save_request(&request).await.unwrap(); + + let requests = db + .get_requests_by_session_id(&request.session_id) + .await + .unwrap(); + assert_eq!(requests.len(), 1); + assert_eq!(request, requests[0]); + + let second_request = Request { + request_id: "test_request_id2".to_string(), + request_type: "test_request_type".to_string(), + session_id: "test_session_id".to_string(), + request_status: RequestStatus::Pending, + network: "test_network".to_string(), + creation_timestamp: 12, + }; + + db.save_request(&second_request).await.unwrap(); + + let requests = db + .get_requests_by_session_id(&request.session_id) + .await + .unwrap(); + assert_eq!(requests.len(), 2); + assert_eq!(second_request, requests[0]); + assert_eq!(request, requests[1]); + + db.update_request_status(&request.request_id, &RequestStatus::Completed) + .await + .unwrap(); + + let request = db + .get_request_by_request_id(&request.request_id) + .await + .unwrap() + .unwrap(); + + assert_eq!(request.request_status, RequestStatus::Completed); + } +} diff --git a/database/src/tables/sessions/mod.rs b/database/src/tables/sessions/mod.rs new file mode 100644 index 00000000..4b2d4aa3 --- /dev/null +++ b/database/src/tables/sessions/mod.rs @@ -0,0 +1,3 @@ +pub mod select; +pub mod table_struct; +pub mod update; diff --git a/database/src/tables/sessions/select.rs b/database/src/tables/sessions/select.rs new file mode 100644 index 00000000..a9772370 --- /dev/null +++ b/database/src/tables/sessions/select.rs @@ -0,0 +1,45 @@ +use super::table_struct::{DbNcSession, SESSIONS_TABLE_NAME}; +use crate::db::Db; +use crate::tables::requests::table_struct::{Request, REQUESTS_TABLE_NAME}; +use sqlx::query_as; + +impl Db { + pub async fn get_sessions_by_app_id( + &self, + app_id: &String, + ) -> Result, sqlx::Error> { + let query = format!("SELECT * FROM {SESSIONS_TABLE_NAME} WHERE app_id = $1 ORDER BY session_open_timestamp DESC"); + let typed_query = query_as::<_, DbNcSession>(&query); + + return typed_query + .bind(&app_id) + .fetch_all(&self.connection_pool) + .await; + } + + pub async fn get_session_by_session_id( + &self, + session_id: &String, + ) -> Result, sqlx::Error> { + let query = format!("SELECT * FROM {SESSIONS_TABLE_NAME} WHERE session_id = $1"); + let typed_query = query_as::<_, DbNcSession>(&query); + + return typed_query + .bind(&session_id) + .fetch_optional(&self.connection_pool) + .await; + } + + pub async fn get_session_requests( + &self, + session_id: &String, + ) -> Result, sqlx::Error> { + let query = format!("SELECT * FROM {REQUESTS_TABLE_NAME} WHERE session_id = $1"); + let typed_query = query_as::<_, Request>(&query); + + return typed_query + .bind(&session_id) + .fetch_all(&self.connection_pool) + .await; + } +} diff --git a/database/src/tables/sessions/table_struct.rs b/database/src/tables/sessions/table_struct.rs new file mode 100644 index 00000000..d1b436ca --- /dev/null +++ b/database/src/tables/sessions/table_struct.rs @@ -0,0 +1,48 @@ +use crate::structs::client_data::ClientData; +use sqlx::{postgres::PgRow, FromRow, Row}; + +pub const SESSIONS_TABLE_NAME: &str = "sessions"; +pub const SESSIONS_KEYS: &str = + "session_id, app_id, app_metadata, app_ip_address, persistent, network, client_id, client_device, client_metadata, client_notification_endpoint, client_connected_at, session_open_timestamp, session_close_timestamp"; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct DbNcSession { + pub session_id: String, + pub app_id: String, + pub app_metadata: String, + pub app_ip_address: String, + pub persistent: bool, + pub network: String, + pub client: Option, // Some if user has ever connected to the session + pub session_open_timestamp: u64, + pub session_close_timestamp: Option, +} + +impl FromRow<'_, PgRow> for DbNcSession { + fn from_row(row: &sqlx::postgres::PgRow) -> std::result::Result { + let session_open_timestamp: i64 = row.get("session_open_timestamp"); + let session_close_timestamp: Option = row.get("session_close_timestamp"); + let client_connected_at: Option = row.get("client_connected_at"); + Ok(DbNcSession { + app_id: row.get("app_id"), + app_metadata: row.get("app_metadata"), + app_ip_address: row.get("app_ip_address"), + persistent: row.get("persistent"), + network: row.get("network"), + session_id: row.get("session_id"), + // If client has ever connected to the session, return the client data + client: match client_connected_at { + Some(client_connected_at) => Some(ClientData { + client_id: row.get("client_id"), + device: row.get("client_device"), + metadata: row.get("client_metadata"), + notification_endpoint: row.get("client_notification_endpoint"), + connected_at: client_connected_at as u64, + }), + None => None, + }, + session_open_timestamp: session_open_timestamp as u64, + session_close_timestamp: session_close_timestamp.map(|x| x as u64), + }) + } +} diff --git a/database/src/tables/sessions/update.rs b/database/src/tables/sessions/update.rs new file mode 100644 index 00000000..f97ae3a3 --- /dev/null +++ b/database/src/tables/sessions/update.rs @@ -0,0 +1,172 @@ +use super::table_struct::{DbNcSession, SESSIONS_KEYS, SESSIONS_TABLE_NAME}; +use crate::db::Db; +use sqlx::query; + +impl Db { + pub async fn save_new_session(&self, session: &DbNcSession) -> Result<(), sqlx::Error> { + let query_body = format!( + "INSERT INTO {SESSIONS_TABLE_NAME} ({}) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)", + SESSIONS_KEYS + ); + + let (client_id, device, metadata, notification_endpoint, connected_at) = + match &session.client { + Some(client) => ( + &client.client_id, + &client.device, + &client.metadata, + &client.notification_endpoint, + Some(client.connected_at.clone() as i64), + ), + None => (&None, &None, &None, &None, None), + }; + + let query_result = query(&query_body) + .bind(&session.session_id) + .bind(&session.app_id) + .bind(&session.app_metadata) + .bind(&session.app_ip_address) + .bind(&session.persistent) + .bind(&session.network) + .bind(&client_id) + .bind(&device) + .bind(&metadata) + .bind(¬ification_endpoint) + .bind(&connected_at) + .bind(&(session.session_open_timestamp as i64)) + .bind(&None::) + .execute(&self.connection_pool) + .await; + + match query_result { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + + pub async fn close_session( + &self, + session_id: &String, + close_timestamp: u64, + ) -> Result<(), sqlx::Error> { + let query_body = format!( + "UPDATE {SESSIONS_TABLE_NAME} SET session_close_timestamp = $1 WHERE session_id = $2" + ); + + let query_result = query(&query_body) + .bind(close_timestamp as i64) + .bind(session_id) + .execute(&self.connection_pool) + .await; + + match query_result { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::{ + structs::{client_data::ClientData, request_status::RequestStatus}, + tables::{registered_app::table_struct::RegisteredApp, requests::table_struct::Request}, + }; + + #[tokio::test] + async fn test_sessions() { + let db = super::Db::connect_to_the_pool().await; + db.truncate_all_tables().await.unwrap(); + + // Create basic app to satisfy foreign key constraint + let app = RegisteredApp { + app_id: "test_app_id".to_string(), + app_name: "test_app_name".to_string(), + whitelisted_domains: vec!["test_domain".to_string()], + subscription: None, + ack_public_keys: vec!["test_key".to_string()], + email: Some("test_email".to_string()), + registration_timestamp: 10, + pass_hash: Some("test_pass_hash".to_string()), + }; + db.register_new_app(&app).await.unwrap(); + + let session = DbNcSession { + session_id: "test_session_id".to_string(), + app_id: "test_app_id".to_string(), + app_metadata: "test_app_metadata".to_string(), + app_ip_address: "test_app_ip_address".to_string(), + persistent: false, + network: "test_network".to_string(), + client: Some(ClientData { + client_id: Some("test_client_id".to_string()), + device: Some("test_device".to_string()), + metadata: Some("test_metadata".to_string()), + notification_endpoint: Some("test_notification_endpoint".to_string()), + connected_at: 12, + }), + session_open_timestamp: 10, + session_close_timestamp: None, + }; + + // Create a new session entry + db.save_new_session(&session).await.unwrap(); + + // Get all sessions by app_id + let sessions = db.get_sessions_by_app_id(&session.app_id).await.unwrap(); + assert_eq!(sessions.len(), 1); + assert_eq!(session, sessions[0]); + + // Get session by session_id + let session = db + .get_session_by_session_id(&session.session_id) + .await + .unwrap() + .unwrap(); + assert_eq!(session, session); + + // Change the session status to closed + db.close_session(&session.session_id, 15).await.unwrap(); + + // Get session by session_id to check if the session status is closed + let session = db + .get_session_by_session_id(&session.session_id) + .await + .unwrap() + .unwrap(); + assert_eq!(session.session_close_timestamp, Some(15)); + + // Create a few requests for the session + let request = Request { + request_id: "test_request_id".to_string(), + request_type: "test_request_type".to_string(), + session_id: "test_session_id".to_string(), + request_status: RequestStatus::Pending, + network: "test_network".to_string(), + creation_timestamp: 13, + }; + + let second_request = Request { + request_id: "test_request_id2".to_string(), + request_type: "test_request_type".to_string(), + session_id: "test_session_id".to_string(), + request_status: RequestStatus::Pending, + network: "test_network".to_string(), + creation_timestamp: 13, + }; + + db.save_request(&request).await.unwrap(); + db.save_request(&second_request).await.unwrap(); + + // Get all requests by session_id + let requests = db + .get_requests_by_session_id(&request.session_id) + .await + .unwrap(); + + assert_eq!(requests.len(), 2); + assert_eq!(request, requests[0]); + } +} diff --git a/database/src/tables/test_utils.rs b/database/src/tables/test_utils.rs new file mode 100644 index 00000000..f4c29757 --- /dev/null +++ b/database/src/tables/test_utils.rs @@ -0,0 +1,29 @@ +#[cfg(test)] +pub mod test_utils { + use crate::db::Db; + use sqlx::Row; + + impl Db { + pub async fn truncate_all_tables(&self) -> Result<(), sqlx::Error> { + let rows = sqlx::query( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'", + ) + .fetch_all(&self.connection_pool) + .await?; + + // Join all names except _sqlx_migrations into a single string and run single truncate + let tables_names = rows + .iter() + .map(|row| row.get::("table_name")) + .filter(|table_name| !table_name.starts_with("_sqlx_migrations")) + .collect::>() + .join(", "); + + let query = format!("TRUNCATE TABLE {tables_names} CASCADE"); + sqlx::query(&query).execute(&self.connection_pool).await?; + Ok(()) + } + } + + +} diff --git a/database/src/tables/utils.rs b/database/src/tables/utils.rs new file mode 100644 index 00000000..b22a0b48 --- /dev/null +++ b/database/src/tables/utils.rs @@ -0,0 +1,7 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn get_timestamp_in_milliseconds() -> u64 { + let now = SystemTime::now(); + let since_the_epoch = now.duration_since(UNIX_EPOCH).expect("Time went backwards"); + since_the_epoch.as_millis() as u64 +}