Skip to content

Commit

Permalink
[graphql] query costing (#14463)
Browse files Browse the repository at this point in the history
  • Loading branch information
wlmyng authored Oct 27, 2023
1 parent d08ac84 commit 8904957
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 65 deletions.
2 changes: 1 addition & 1 deletion crates/sui-graphql-rpc/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::functional_group::FunctionalGroup;
// TODO: calculate proper cost limits
const MAX_QUERY_DEPTH: u32 = 10;
const MAX_QUERY_NODES: u32 = 100;
const MAX_DB_QUERY_COST: u64 = 50; // Max DB query cost (normally f64) truncated
const MAX_DB_QUERY_COST: u64 = 50000; // Max DB query cost (normally f64) truncated
const MAX_QUERY_VARIABLES: u32 = 50;
const MAX_QUERY_FRAGMENTS: u32 = 50;

Expand Down
131 changes: 75 additions & 56 deletions crates/sui-graphql-rpc/src/context_data/db_data_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ use sui_types::{
Identifier,
};

use super::DEFAULT_PAGE_SIZE;
use super::{db_query_cost::extract_cost, DEFAULT_PAGE_SIZE};

use super::sui_sdk_data_provider::convert_to_validators;

Expand All @@ -119,6 +119,8 @@ pub enum DbValidationError {
PaginationDisabledOnBalances,
#[error("Invalid owner type. Must be Address or Object")]
InvalidOwnerType,
#[error("Query cost exceeded - cost: {0}, limit: {1}")]
QueryCostExceeded(u64, u64),
}

type BalanceQuery<'a> = BoxedSelectStatement<
Expand Down Expand Up @@ -479,12 +481,12 @@ impl QueryBuilder {

pub(crate) struct PgManager {
pub inner: IndexerReader,
pub _limits: Limits,
pub limits: Limits,
}

impl PgManager {
pub(crate) fn new(inner: IndexerReader, _limits: Limits) -> Self {
Self { inner, _limits }
pub(crate) fn new(inner: IndexerReader, limits: Limits) -> Self {
Self { inner, limits }
}

/// Create a new underlying reader, which is used by this type as well as other data providers.
Expand All @@ -506,15 +508,51 @@ impl PgManager {
.await
.map_err(|e| Error::Internal(e.to_string()))
}

/// Takes a query fragment and a lambda that executes the query
/// Spawns a blocking task that determines the cost of the query fragment
/// And if within limits, executes the query
async fn run_query_async_with_cost<T, Q, EF, E, F>(
&self,
query: Q,
execute_fn: EF,
) -> Result<T, Error>
where
Q: QueryDsl
+ RunQueryDsl<Pg>
+ diesel::query_builder::QueryFragment<diesel::pg::Pg>
+ Send
+ 'static,
EF: FnOnce(Q) -> F + Send + 'static,
F: FnOnce(&mut PgConnection) -> Result<T, E> + Send + 'static,
E: From<diesel::result::Error> + std::error::Error + Send + 'static,
T: Send + 'static,
{
let max_db_query_cost = self.limits.max_db_query_cost;
self.inner
.spawn_blocking(move |this| {
let cost = extract_cost(&query, &this)?;
if cost > max_db_query_cost as f64 {
return Err(DbValidationError::QueryCostExceeded(
cost as u64,
max_db_query_cost,
)
.into());
}

let execute_closure = execute_fn(query);
this.run_query(execute_closure)
.map_err(|e| Error::Internal(e.to_string()))
})
.await
}
}

/// Implement methods to query db and return StoredData
impl PgManager {
async fn get_tx(&self, digest: Vec<u8>) -> Result<Option<StoredTransaction>, Error> {
self.run_query_async(|conn| {
QueryBuilder::get_tx_by_digest(digest)
.get_result::<StoredTransaction>(conn) // Expect exactly 0 to 1 result
.optional()
self.run_query_async_with_cost(QueryBuilder::get_tx_by_digest(digest), |query| {
move |conn| query.get_result::<StoredTransaction>(conn).optional()
})
.await
}
Expand All @@ -524,32 +562,22 @@ impl PgManager {
address: Vec<u8>,
version: Option<i64>,
) -> Result<Option<StoredObject>, Error> {
self.run_query_async(move |conn| {
QueryBuilder::get_obj(address, version)
.get_result::<StoredObject>(conn)
.optional()
self.run_query_async_with_cost(QueryBuilder::get_obj(address, version), |query| {
move |conn| query.get_result::<StoredObject>(conn).optional()
})
.await
}

pub async fn get_epoch(&self, epoch_id: Option<i64>) -> Result<Option<StoredEpochInfo>, Error> {
match epoch_id {
Some(epoch_id) => {
self.run_query_async(move |conn| {
QueryBuilder::get_epoch(epoch_id)
.get_result::<StoredEpochInfo>(conn)
.optional()
})
.await
}
None => Some(
self.run_query_async(|conn| {
QueryBuilder::get_latest_epoch().first::<StoredEpochInfo>(conn)
})
.await,
)
.transpose(),
}
let query = match epoch_id {
Some(epoch_id) => QueryBuilder::get_epoch(epoch_id),
None => QueryBuilder::get_latest_epoch(),
};

self.run_query_async_with_cost(query, |query| {
move |conn| query.get_result::<StoredEpochInfo>(conn).optional()
})
.await
}

async fn get_checkpoint(
Expand All @@ -568,8 +596,10 @@ impl PgManager {
_ => QueryBuilder::get_latest_checkpoint(),
};

self.run_query_async(|conn| query.get_result::<StoredCheckpoint>(conn).optional())
.await
self.run_query_async_with_cost(query, |query| {
move |conn| query.get_result::<StoredCheckpoint>(conn).optional()
})
.await
}

async fn get_chain_identifier(&self) -> Result<ChainIdentifier, Error> {
Expand Down Expand Up @@ -603,11 +633,10 @@ impl PgManager {
let limit = first.or(last).unwrap_or(DEFAULT_PAGE_SIZE) as i64;

let result: Option<Vec<StoredObject>> = self
.run_query_async(move |conn| {
QueryBuilder::multi_get_coins(cursor, descending_order, limit, address, coin_type)
.load(conn)
.optional()
})
.run_query_async_with_cost(
QueryBuilder::multi_get_coins(cursor, descending_order, limit, address, coin_type),
|query| move |conn| query.load(conn).optional(),
)
.await?;

result
Expand All @@ -627,10 +656,8 @@ impl PgManager {
address: Vec<u8>,
coin_type: String,
) -> Result<Option<(Option<i64>, Option<i64>, Option<String>)>, Error> {
self.run_query_async(move |conn| {
QueryBuilder::get_balance(address, coin_type)
.get_result(conn)
.optional()
self.run_query_async_with_cost(QueryBuilder::get_balance(address, coin_type), |query| {
move |conn| query.get_result(conn).optional()
})
.await
}
Expand All @@ -649,10 +676,8 @@ impl PgManager {
return Err(DbValidationError::PaginationDisabledOnBalances.into());
}

self.run_query_async(move |conn| {
QueryBuilder::multi_get_balances(address)
.load(conn)
.optional()
self.run_query_async_with_cost(QueryBuilder::multi_get_balances(address), |query| {
move |conn| query.load(conn).optional()
})
.await
}
Expand Down Expand Up @@ -720,12 +745,7 @@ impl PgManager {
)?;

let result: Option<Vec<StoredTransaction>> = self
.run_query_async(move |conn| {
query
.select(transactions::all_columns)
.load(conn)
.optional()
})
.run_query_async_with_cost(query, |query| move |conn| query.load(conn).optional())
.await?;

result
Expand Down Expand Up @@ -756,16 +776,15 @@ impl PgManager {
let limit = first.or(last).unwrap_or(DEFAULT_PAGE_SIZE) as i64;

let result: Option<Vec<StoredCheckpoint>> = self
.run_query_async(move |conn| {
.run_query_async_with_cost(
QueryBuilder::multi_get_checkpoints(
cursor,
descending_order,
limit,
epoch.map(|e| e as i64),
)
.load(conn)
.optional()
})
),
|query| move |conn| query.load(conn).optional(),
)
.await?;

result
Expand Down Expand Up @@ -800,7 +819,7 @@ impl PgManager {
QueryBuilder::multi_get_objs(cursor, descending_order, limit, filter, owner_type)?;

let result: Option<Vec<StoredObject>> = self
.run_query_async(move |conn| query.load(conn).optional())
.run_query_async_with_cost(query, |query| move |conn| query.load(conn).optional())
.await?;
result
.map(|mut stored_objs| {
Expand Down
32 changes: 27 additions & 5 deletions crates/sui-graphql-rpc/src/context_data/db_query_cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use diesel::{
PgConnection, RunQueryDsl,
};
use regex::Regex;
use sui_indexer::schema_v2::query_cost;
use sui_indexer::{indexer_reader::IndexerReader, schema_v2::query_cost};

use crate::context_data::DEFAULT_PAGE_SIZE;

/// Extracts the raw sql query string from a diesel query
/// and replaces all the parameters with '0'
Expand Down Expand Up @@ -43,18 +45,38 @@ pub fn raw_sql_string_values_set(
})?;
let sql: String = query_builder.finish();

// handle limits, as '0' is invalid - set to DEFAULT_PAGE_SIZE instead
let re = Regex::new(r"(LIMIT\s+)\$(\d+)")
.map_err(|e| crate::error::Error::Internal(format!("Failed create valid regex: {}", e)))?;
let replacement_string = format!("LIMIT {}", DEFAULT_PAGE_SIZE);
let output = re
.replace_all(&sql, replacement_string.as_str())
.to_string();

// handle matching column against ANY value in input array
let re = Regex::new(r"ANY\(\$(\d+)\)")
.map_err(|e| crate::error::Error::Internal(format!("Failed create valid regex: {}", e)))?;
let nums: Vec<String> = (1..=50).map(|n| n.to_string()).collect();
let nums_str = nums.join(", ");
let replacement_string = format!("ANY ('{{{}}}')", nums_str);
let output = re
.replace_all(&output, replacement_string.as_str())
.to_string();

let re = Regex::new(r"\$(\d+)")
.map_err(|e| crate::error::Error::Internal(format!("Failed create valid regex: {}", e)))?;
Ok(re.replace_all(&sql, "'0'").to_string())

Ok(re.replace_all(&output, "'0'").to_string())
}

pub fn extract_cost(
query: &dyn QueryFragment<Pg>,
pg_connection: &mut PgConnection,
pg_reader: &IndexerReader,
) -> Result<f64, crate::error::Error> {
let raw_sql_string = raw_sql_string_values_set(query)?;
diesel::select(query_cost(&raw_sql_string))
.get_result::<f64>(pg_connection)
// Use IndexerReader.run_query so we get alerted when blocking calls are made in an async thread
pg_reader
.run_query(|conn| diesel::select(query_cost(&raw_sql_string)).get_result::<f64>(conn))
.map_err(|e| {
crate::error::Error::Internal(format!(
"Unable to run query_cost function to determine query cost for {}: {}",
Expand Down
6 changes: 3 additions & 3 deletions crates/sui-graphql-rpc/tests/e2e_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ mod tests {
idx_cfg.set_pool_size(20);
let reader = IndexerReader::new_with_config(connection_config.db_url(), idx_cfg).unwrap();
reader
.run_query_async(|conn| {
let cost = extract_cost(&query, conn).unwrap();
.spawn_blocking(move |this| {
let cost = extract_cost(&query, &this).unwrap();
assert!(cost > 0.0);
query.get_result::<StoredObject>(conn).optional()
this.run_query(|conn| query.get_result::<StoredObject>(conn).optional())
})
.await
.unwrap();
Expand Down

1 comment on commit 8904957

@vercel
Copy link

@vercel vercel bot commented on 8904957 Oct 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.