From e13263aad6f1d802f02c20b7854c9dabca1941e2 Mon Sep 17 00:00:00 2001 From: Min Shao Date: Tue, 20 Feb 2024 12:18:02 -0800 Subject: [PATCH] Avoid exposing model indicator map directly --- CHANGELOG.md | 8 ++ Cargo.toml | 2 +- src/lib.rs | 14 +-- src/migration.rs | 2 +- src/tables.rs | 8 ++ src/tables/model_indicator.rs | 206 ++++++++++++++++++++++++++++++++++ src/types.rs | 124 +------------------- 7 files changed, 235 insertions(+), 129 deletions(-) create mode 100644 src/tables/model_indicator.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index aaa7bda..d666e0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,14 @@ Versioning](https://semver.org/spec/v2.0.0.html). enhance security by preventing direct exposure of `Map`. - Modified `Filter` struct to include the `username` property, representing the associated username for the specific `Filter`. +- Changed the return type of `Store::model_indicator_map` to `Table` + to enhance security by preventing direct exposure of `Map`. +- Moved member functions of `ModelIndicator` that are related to database operations + under `Table` to facilitate insert, remove, update, get and + list operations, ensuring a more controlled and secure model indicator management + and improved code organization. +- Modified `ModelIndicator` struct to include the `name` property, representing the + associated name for the specific `ModelIndicator`. ### Deprecated diff --git a/Cargo.toml b/Cargo.toml index b17cc37..58ff1b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "review-database" -version = "0.25.0-alpha.7" +version = "0.25.0-alpha.8" edition = "2021" [dependencies] diff --git a/src/lib.rs b/src/lib.rs index 5e5ce51..5184bd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,7 +55,9 @@ pub use self::migration::{migrate_backend, migrate_data_dir}; pub use self::model::{Digest as ModelDigest, Model}; pub use self::outlier::*; use self::tables::StateDb; -pub use self::tables::{AccessToken, Filter, IndexedTable, Iterable, Table, UniqueKey}; +pub use self::tables::{ + AccessToken, Filter, IndexedTable, Iterable, ModelIndicator, Table, UniqueKey, +}; pub use self::ti::{Tidb, TidbKind, TidbRule}; pub use self::time_series::*; pub use self::time_series::{ColumnTimeSeries, TimeCount, TimeSeriesResult}; @@ -67,8 +69,8 @@ pub use self::top_n::{ pub use self::traffic_filter::{ProtocolPorts, TrafficFilter}; pub use self::types::{ AttrCmpKind, Confidence, Customer, CustomerNetwork, DataSource, DataType, EventCategory, - HostNetworkGroup, ModelIndicator, PacketAttr, Qualifier, Response, ResponseKind, Status, Ti, - TiCmpKind, TriagePolicy, ValueKind, + HostNetworkGroup, PacketAttr, Qualifier, Response, ResponseKind, Status, Ti, TiCmpKind, + TriagePolicy, ValueKind, }; use anyhow::{anyhow, Result}; use backends::Value; @@ -221,10 +223,8 @@ impl Store { #[must_use] #[allow(clippy::missing_panics_doc)] - pub fn model_indicator_map(&self) -> Map { - self.states - .map(tables::MODEL_INDICATORS) - .expect("always available") + pub fn model_indicator_map(&self) -> Table { + self.states.model_indicators() } #[must_use] diff --git a/src/migration.rs b/src/migration.rs index ed96e2a..f888842 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -32,7 +32,7 @@ use tracing::info; /// // the database format won't be changed in the future alpha or beta versions. /// const COMPATIBLE_VERSION: &str = ">=0.5.0-alpha.2,<=0.5.0-alpha.4"; /// ``` -const COMPATIBLE_VERSION_REQ: &str = ">=0.24.0,<=0.25.0-alpha.7"; +const COMPATIBLE_VERSION_REQ: &str = ">=0.24.0,<=0.25.0-alpha.8"; /// Migrates data exists in `PostgresQL` to Rocksdb if necessary. /// diff --git a/src/tables.rs b/src/tables.rs index de06385..34631a1 100644 --- a/src/tables.rs +++ b/src/tables.rs @@ -4,6 +4,7 @@ mod batch_info; mod category; mod csv_column_extra; mod filter; +mod model_indicator; mod qualifier; mod scores; mod status; @@ -27,6 +28,7 @@ use std::{ pub use self::access_token::AccessToken; pub use self::filter::Filter; +pub use self::model_indicator::ModelIndicator; // Key-value map names in `Database`. pub(super) const ACCESS_TOKENS: &str = "access_tokens"; @@ -134,6 +136,12 @@ impl StateDb { Table::::open(inner).expect("{FILTERS} table must be present") } + #[must_use] + pub(crate) fn model_indicators(&self) -> Table { + let inner = self.inner.as_ref().expect("database must be open"); + Table::::open(inner).expect("{MODEL_INDICATORS} table must be present") + } + #[must_use] pub(crate) fn scores(&self) -> Table { let inner = self.inner.as_ref().expect("database must be open"); diff --git a/src/tables/model_indicator.rs b/src/tables/model_indicator.rs new file mode 100644 index 0000000..f375196 --- /dev/null +++ b/src/tables/model_indicator.rs @@ -0,0 +1,206 @@ +//! The `model_indicator` map. + +use std::{ + collections::HashSet, + io::{BufReader, Read}, +}; + +use anyhow::Result; +use chrono::{serde::ts_seconds, DateTime, Utc}; +use data_encoding::BASE64; +use flate2::read::GzDecoder; +use rocksdb::OptimisticTransactionDB; +use serde::{Deserialize, Serialize}; + +use crate::{types::FromKeyValue, Map, Table}; + +#[derive(Default)] +pub struct ModelIndicator { + pub name: String, + pub description: String, + pub model_id: i32, + pub tokens: HashSet>, + pub last_modification_time: DateTime, +} + +impl ModelIndicator { + /// Creates a new `ModelIndicator` from the given data. + /// + /// # Errors + /// + /// Returns an error if the given data is invalid. + pub fn new(name: &str, data: &str) -> Result { + let data = BASE64.decode(data.as_bytes())?; + let decoder = GzDecoder::new(&data[..]); + let mut buf = Vec::new(); + let mut reader = BufReader::new(decoder); + reader.read_to_end(&mut buf)?; + + Self::from_key_value(name.as_bytes(), &buf) + } + + fn into_key_value(self) -> Result<(Vec, Vec)> { + let key = self.name.as_bytes().to_owned(); + let value = Value { + description: self.description, + model_id: self.model_id, + tokens: self.tokens, + last_modification_time: self.last_modification_time, + }; + Ok((key, super::serialize(&value)?)) + } +} + +#[derive(Deserialize, Serialize)] +struct Value { + description: String, + model_id: i32, + tokens: HashSet>, + #[serde(with = "ts_seconds")] + last_modification_time: DateTime, +} + +impl FromKeyValue for ModelIndicator { + fn from_key_value(key: &[u8], value: &[u8]) -> Result { + let name = std::str::from_utf8(key)?.to_string(); + let value: Value = super::deserialize(value)?; + Ok(Self { + name, + description: value.description, + model_id: value.model_id, + tokens: value.tokens, + last_modification_time: value.last_modification_time, + }) + } +} + +/// Functions for the `model_indicator` map. +impl<'d> Table<'d, ModelIndicator> { + /// Opens the `model_indicator` map in the database. + /// + /// Returns `None` if the map does not exist. + pub(super) fn open(db: &'d OptimisticTransactionDB) -> Option { + Map::open(db, super::MODEL_INDICATORS).map(Table::new) + } + + /// Returns the `ModelIndicator` with the given name. + /// + /// # Errors + /// + /// Returns an error if the database query fails. + pub fn get(&self, name: &str) -> Result> { + self.map + .get(name.as_bytes())? + .map(|v| ModelIndicator::from_key_value(name.as_bytes(), v.as_ref())) + .transpose() + } + + /// Inserts the `ModelIndicator` into the database. + /// + /// # Errors + /// + /// Returns an error if the serialization fails or the database operation fails. + pub fn insert(&self, indicator: ModelIndicator) -> Result<()> { + let (key, value) = indicator.into_key_value()?; + self.map.put(&key, &value) + } + + /// Removes the `ModelIndicator`s with the given names. The removed names are returned. + /// + /// # Errors + /// + /// Returns an error if the database operation fails. + pub fn remove<'a>(&self, names: impl Iterator) -> Result> { + let mut removed = vec![]; + for name in names { + self.map.delete(name.as_bytes())?; + removed.push(name.to_string()); + } + Ok(removed) + } + + /// Updates the `ModelIndicator` in the database. + /// + /// # Errors + /// + /// Returns an error if the serialization fails or the database operation fails. + pub fn update(&self, indicator: ModelIndicator) -> Result<()> { + self.remove(std::iter::once(indicator.name.as_str()))?; + self.insert(indicator) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ModelIndicator, Store}; + + #[test] + fn serde() { + use chrono::Utc; + use data_encoding::BASE64; + use flate2::{bufread::GzEncoder, Compression}; + use std::collections::HashSet; + use std::io::{Cursor, Read}; + + let name = "mi_1"; + let value = super::Value { + description: "test".to_owned(), + model_id: 123, + tokens: HashSet::new(), + last_modification_time: Utc::now(), + }; + let serialized = crate::tables::serialize(&value).unwrap(); + let cursor = Cursor::new(serialized); + + let mut gz = GzEncoder::new(cursor, Compression::fast()); + let mut zipped = Vec::new(); + gz.read_to_end(&mut zipped).unwrap(); + let encoded = BASE64.encode(&zipped); + let res = super::ModelIndicator::new(name, &encoded); + + assert!(res.is_ok()); + let indicator = res.unwrap(); + assert_eq!(indicator.name, "mi_1"); + assert_eq!(indicator.description, "test"); + } + + #[test] + fn operations() { + use crate::Iterable; + let db_dir = tempfile::tempdir().unwrap(); + let backup_dir = tempfile::tempdir().unwrap(); + let store = Arc::new(Store::new(db_dir.path(), backup_dir.path()).unwrap()); + let table = store.model_indicator_map(); + + let tester = &["1", "2", "3"]; + for name in tester { + let mut mi = ModelIndicator::default(); + mi.name = name.to_string(); + + assert!(table.insert(mi).is_ok()); + } + + for name in tester { + let res = table.get(name).unwrap().map(|mi| mi.name); + assert_eq!(Some(name.to_string()), res); + } + + let res: anyhow::Result> = table + .iter(crate::Direction::Forward, None) + .map(|r| r.map(|mi| mi.name)) + .collect(); + assert!(res.is_ok()); + let list = res.unwrap(); + assert_eq!( + tester.to_vec(), + list.iter().map(String::as_str).collect::>() + ); + + let res = table.remove(list.iter().map(String::as_str)); + assert!(res.is_ok()); + let removed = res.unwrap(); + assert_eq!(removed, list); + } +} diff --git a/src/types.rs b/src/types.rs index 5b12e18..154e9f6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,24 +1,12 @@ -use super::{Indexable, IterableMap, NetworkType, Store, TrafficDirection}; +use super::{Indexable, NetworkType, TrafficDirection}; pub use crate::account::{Account, Role}; -use anyhow::{bail, Context, Result}; +use anyhow::Result; use bincode::Options; -use chrono::{ - naive::serde::ts_nanoseconds_option, serde::ts_seconds, DateTime, NaiveDateTime, Utc, -}; -use data_encoding::BASE64; -use flate2::read::GzDecoder; +use chrono::{naive::serde::ts_nanoseconds_option, DateTime, NaiveDateTime, Utc}; use ipnet::IpNet; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; -use std::{ - borrow::Cow, - cmp::Ordering, - collections::HashSet, - convert::TryFrom, - io::{BufReader, Read}, - net::IpAddr, - ops::RangeInclusive, -}; +use std::{borrow::Cow, cmp::Ordering, convert::TryFrom, net::IpAddr, ops::RangeInclusive}; use strum_macros::Display; pub trait FromKeyValue: Sized { @@ -286,110 +274,6 @@ impl HostNetworkGroup { } } -#[derive(Deserialize, Serialize)] -pub struct ModelIndicator { - pub description: String, - pub model_id: i32, - pub tokens: HashSet>, - #[serde(with = "ts_seconds")] - pub last_modification_time: DateTime, -} - -impl ModelIndicator { - /// Creates a new `ModelIndicator` from the given data. - /// - /// # Errors - /// - /// Returns an error if the given data is invalid. - pub fn new(data: &str) -> Result { - let data = BASE64.decode(data.as_bytes())?; - let decoder = GzDecoder::new(&data[..]); - let mut buf = Vec::new(); - let mut reader = BufReader::new(decoder); - reader.read_to_end(&mut buf)?; - let indicator = match bincode::deserialize::(&buf) { - Ok(v) => v, - Err(e) => bail!("failed to deserialize. {:?}", e), - }; - Ok(indicator) - } - - /// Gets the `ModelIndicator` with the given name. - /// - /// # Errors - /// - /// Returns an error if the database operation fails or the value in the database is invalid. - pub fn get(store: &Store, name: &str) -> Result> { - let map = store.model_indicator_map(); - Ok(match map.get(name.as_bytes())? { - Some(v) => Some( - bincode::DefaultOptions::new() - .deserialize::(v.as_ref()) - .context("invalid value in database")?, - ), - None => None, - }) - } - - /// Gets the list of all `ModelIndicator`s, sorted by name. - /// - /// # Errors - /// - /// Returns an error if the database operation fails or the value in the database is invalid. - pub fn get_list(store: &Store) -> Result> { - let map = store.model_indicator_map(); - let mut indicators = Vec::new(); - for (name, value) in map.iter_forward()? { - let indicator = bincode::DefaultOptions::new() - .deserialize::(value.as_ref()) - .context("invalid value in database")?; - indicators.push((String::from_utf8_lossy(&name).to_string(), indicator)); - } - indicators.sort_unstable_by(|a, b| a.0.cmp(&b.0)); - Ok(indicators) - } - - /// Removes the `ModelIndicator`s with the given names. The removed names are returned. - /// - /// # Errors - /// - /// Returns an error if the database operation fails. - pub fn remove(store: &Store, names: &[String]) -> Result> { - let map = store.model_indicator_map(); - let mut removed = Vec::with_capacity(names.len()); - for name in names { - map.delete(name.as_bytes())?; - removed.push(name.to_string()); - } - Ok(removed) - } - - /// Inserts the `ModelIndicator` into the database. - /// - /// # Errors - /// - /// Returns an error if the serialization fails or the database operation fails. - pub fn insert(&self, store: &Store, name: &str) -> Result { - let map = store.model_indicator_map(); - let value = bincode::DefaultOptions::new().serialize(self)?; - map.put(name.as_bytes(), &value)?; - Ok(name.to_string()) - } - - /// Updates the `ModelIndicator` in the database. - /// - /// # Errors - /// - /// Returns an error if the serialization fails or the database operation fails. - pub fn update(&self, store: &Store, name: &str) -> Result { - let map = store.model_indicator_map(); - map.delete(name.as_bytes())?; - let value = bincode::DefaultOptions::new().serialize(&self)?; - map.put(name.as_bytes(), &value)?; - Ok(name.to_string()) - } -} - #[derive(Deserialize)] pub struct Outlier { pub id: i32,