Skip to content

Commit

Permalink
Avoid exposing model indicator map directly
Browse files Browse the repository at this point in the history
  • Loading branch information
minshao authored and msk committed Feb 20, 2024
1 parent 41ebf11 commit e13263a
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 129 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ Versioning](https://semver.org/spec/v2.0.0.html).
enhance security by preventing direct exposure of `Map`.
- Modified `Filter` struct to include the `username` property, representing the
associated username for the specific `Filter`.
- Changed the return type of `Store::model_indicator_map` to `Table<ModelIndicator>`
to enhance security by preventing direct exposure of `Map`.
- Moved member functions of `ModelIndicator` that are related to database operations
under `Table<ModelIndicator>` to facilitate insert, remove, update, get and
list operations, ensuring a more controlled and secure model indicator management
and improved code organization.
- Modified `ModelIndicator` struct to include the `name` property, representing the
associated name for the specific `ModelIndicator`.

### Deprecated

Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "review-database"
version = "0.25.0-alpha.7"
version = "0.25.0-alpha.8"
edition = "2021"

[dependencies]
Expand Down
14 changes: 7 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ pub use self::migration::{migrate_backend, migrate_data_dir};
pub use self::model::{Digest as ModelDigest, Model};
pub use self::outlier::*;
use self::tables::StateDb;
pub use self::tables::{AccessToken, Filter, IndexedTable, Iterable, Table, UniqueKey};
pub use self::tables::{
AccessToken, Filter, IndexedTable, Iterable, ModelIndicator, Table, UniqueKey,
};
pub use self::ti::{Tidb, TidbKind, TidbRule};
pub use self::time_series::*;
pub use self::time_series::{ColumnTimeSeries, TimeCount, TimeSeriesResult};
Expand All @@ -67,8 +69,8 @@ pub use self::top_n::{
pub use self::traffic_filter::{ProtocolPorts, TrafficFilter};
pub use self::types::{
AttrCmpKind, Confidence, Customer, CustomerNetwork, DataSource, DataType, EventCategory,
HostNetworkGroup, ModelIndicator, PacketAttr, Qualifier, Response, ResponseKind, Status, Ti,
TiCmpKind, TriagePolicy, ValueKind,
HostNetworkGroup, PacketAttr, Qualifier, Response, ResponseKind, Status, Ti, TiCmpKind,
TriagePolicy, ValueKind,
};
use anyhow::{anyhow, Result};
use backends::Value;
Expand Down Expand Up @@ -221,10 +223,8 @@ impl Store {

#[must_use]
#[allow(clippy::missing_panics_doc)]
pub fn model_indicator_map(&self) -> Map {
self.states
.map(tables::MODEL_INDICATORS)
.expect("always available")
pub fn model_indicator_map(&self) -> Table<ModelIndicator> {
self.states.model_indicators()
}

#[must_use]
Expand Down
2 changes: 1 addition & 1 deletion src/migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use tracing::info;
/// // the database format won't be changed in the future alpha or beta versions.
/// const COMPATIBLE_VERSION: &str = ">=0.5.0-alpha.2,<=0.5.0-alpha.4";
/// ```
const COMPATIBLE_VERSION_REQ: &str = ">=0.24.0,<=0.25.0-alpha.7";
const COMPATIBLE_VERSION_REQ: &str = ">=0.24.0,<=0.25.0-alpha.8";

/// Migrates data exists in `PostgresQL` to Rocksdb if necessary.
///
Expand Down
8 changes: 8 additions & 0 deletions src/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod batch_info;
mod category;
mod csv_column_extra;
mod filter;
mod model_indicator;
mod qualifier;
mod scores;
mod status;
Expand All @@ -27,6 +28,7 @@ use std::{

pub use self::access_token::AccessToken;
pub use self::filter::Filter;
pub use self::model_indicator::ModelIndicator;

// Key-value map names in `Database`.
pub(super) const ACCESS_TOKENS: &str = "access_tokens";
Expand Down Expand Up @@ -134,6 +136,12 @@ impl StateDb {
Table::<Filter>::open(inner).expect("{FILTERS} table must be present")
}

#[must_use]
pub(crate) fn model_indicators(&self) -> Table<ModelIndicator> {
let inner = self.inner.as_ref().expect("database must be open");
Table::<ModelIndicator>::open(inner).expect("{MODEL_INDICATORS} table must be present")
}

#[must_use]
pub(crate) fn scores(&self) -> Table<Scores> {
let inner = self.inner.as_ref().expect("database must be open");
Expand Down
206 changes: 206 additions & 0 deletions src/tables/model_indicator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
//! The `model_indicator` map.
use std::{
collections::HashSet,
io::{BufReader, Read},
};

use anyhow::Result;
use chrono::{serde::ts_seconds, DateTime, Utc};
use data_encoding::BASE64;
use flate2::read::GzDecoder;
use rocksdb::OptimisticTransactionDB;
use serde::{Deserialize, Serialize};

use crate::{types::FromKeyValue, Map, Table};

#[derive(Default)]
pub struct ModelIndicator {
pub name: String,
pub description: String,
pub model_id: i32,
pub tokens: HashSet<Vec<String>>,
pub last_modification_time: DateTime<Utc>,
}

impl ModelIndicator {
/// Creates a new `ModelIndicator` from the given data.
///
/// # Errors
///
/// Returns an error if the given data is invalid.
pub fn new(name: &str, data: &str) -> Result<Self> {
let data = BASE64.decode(data.as_bytes())?;
let decoder = GzDecoder::new(&data[..]);
let mut buf = Vec::new();
let mut reader = BufReader::new(decoder);
reader.read_to_end(&mut buf)?;

Self::from_key_value(name.as_bytes(), &buf)
}

fn into_key_value(self) -> Result<(Vec<u8>, Vec<u8>)> {
let key = self.name.as_bytes().to_owned();
let value = Value {
description: self.description,
model_id: self.model_id,
tokens: self.tokens,
last_modification_time: self.last_modification_time,
};
Ok((key, super::serialize(&value)?))
}
}

#[derive(Deserialize, Serialize)]
struct Value {
description: String,
model_id: i32,
tokens: HashSet<Vec<String>>,
#[serde(with = "ts_seconds")]
last_modification_time: DateTime<Utc>,
}

impl FromKeyValue for ModelIndicator {
fn from_key_value(key: &[u8], value: &[u8]) -> Result<Self> {
let name = std::str::from_utf8(key)?.to_string();
let value: Value = super::deserialize(value)?;
Ok(Self {
name,
description: value.description,
model_id: value.model_id,
tokens: value.tokens,
last_modification_time: value.last_modification_time,
})
}
}

/// Functions for the `model_indicator` map.
impl<'d> Table<'d, ModelIndicator> {
/// Opens the `model_indicator` map in the database.
///
/// Returns `None` if the map does not exist.
pub(super) fn open(db: &'d OptimisticTransactionDB) -> Option<Self> {
Map::open(db, super::MODEL_INDICATORS).map(Table::new)
}

/// Returns the `ModelIndicator` with the given name.
///
/// # Errors
///
/// Returns an error if the database query fails.
pub fn get(&self, name: &str) -> Result<Option<ModelIndicator>> {
self.map
.get(name.as_bytes())?
.map(|v| ModelIndicator::from_key_value(name.as_bytes(), v.as_ref()))
.transpose()
}

/// Inserts the `ModelIndicator` into the database.
///
/// # Errors
///
/// Returns an error if the serialization fails or the database operation fails.
pub fn insert(&self, indicator: ModelIndicator) -> Result<()> {
let (key, value) = indicator.into_key_value()?;
self.map.put(&key, &value)
}

/// Removes the `ModelIndicator`s with the given names. The removed names are returned.
///
/// # Errors
///
/// Returns an error if the database operation fails.
pub fn remove<'a>(&self, names: impl Iterator<Item = &'a str>) -> Result<Vec<String>> {
let mut removed = vec![];
for name in names {
self.map.delete(name.as_bytes())?;
removed.push(name.to_string());
}
Ok(removed)
}

/// Updates the `ModelIndicator` in the database.
///
/// # Errors
///
/// Returns an error if the serialization fails or the database operation fails.
pub fn update(&self, indicator: ModelIndicator) -> Result<()> {
self.remove(std::iter::once(indicator.name.as_str()))?;
self.insert(indicator)
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::{ModelIndicator, Store};

#[test]
fn serde() {
use chrono::Utc;
use data_encoding::BASE64;
use flate2::{bufread::GzEncoder, Compression};
use std::collections::HashSet;
use std::io::{Cursor, Read};

let name = "mi_1";
let value = super::Value {
description: "test".to_owned(),
model_id: 123,
tokens: HashSet::new(),
last_modification_time: Utc::now(),
};
let serialized = crate::tables::serialize(&value).unwrap();
let cursor = Cursor::new(serialized);

let mut gz = GzEncoder::new(cursor, Compression::fast());
let mut zipped = Vec::new();
gz.read_to_end(&mut zipped).unwrap();
let encoded = BASE64.encode(&zipped);
let res = super::ModelIndicator::new(name, &encoded);

assert!(res.is_ok());
let indicator = res.unwrap();
assert_eq!(indicator.name, "mi_1");
assert_eq!(indicator.description, "test");
}

#[test]
fn operations() {
use crate::Iterable;
let db_dir = tempfile::tempdir().unwrap();
let backup_dir = tempfile::tempdir().unwrap();
let store = Arc::new(Store::new(db_dir.path(), backup_dir.path()).unwrap());
let table = store.model_indicator_map();

let tester = &["1", "2", "3"];
for name in tester {
let mut mi = ModelIndicator::default();
mi.name = name.to_string();

assert!(table.insert(mi).is_ok());
}

for name in tester {
let res = table.get(name).unwrap().map(|mi| mi.name);
assert_eq!(Some(name.to_string()), res);
}

let res: anyhow::Result<Vec<_>> = table
.iter(crate::Direction::Forward, None)
.map(|r| r.map(|mi| mi.name))
.collect();
assert!(res.is_ok());
let list = res.unwrap();
assert_eq!(
tester.to_vec(),
list.iter().map(String::as_str).collect::<Vec<_>>()
);

let res = table.remove(list.iter().map(String::as_str));
assert!(res.is_ok());
let removed = res.unwrap();
assert_eq!(removed, list);
}
}
Loading

0 comments on commit e13263a

Please sign in to comment.