Skip to content

Commit

Permalink
??
Browse files Browse the repository at this point in the history
  • Loading branch information
wlmyng committed Oct 26, 2023
1 parent d42db04 commit 45b69af
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 60 deletions.
97 changes: 48 additions & 49 deletions crates/sui-graphql-rpc/src/context_data/db_data_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ use sui_types::{
Identifier,
};

use super::{DEFAULT_PAGE_SIZE, db_query_cost::extract_cost};
use super::{db_query_cost::extract_cost, DEFAULT_PAGE_SIZE};

use super::sui_sdk_data_provider::convert_to_validators;

Expand All @@ -120,7 +120,7 @@ pub enum DbValidationError {
#[error("Invalid owner type. Must be Address or Object")]
InvalidOwnerType,
#[error("Query cost exceeded - cost: {0}, limit: {1}")]
QueryCostExceeded(u64, u64)
QueryCostExceeded(u64, u64),
}

type BalanceQuery<'a> = BoxedSelectStatement<
Expand Down Expand Up @@ -518,55 +518,65 @@ impl PgManager {
execute_fn: EF,
) -> Result<T, Error>
where
Q: QueryDsl + RunQueryDsl<Pg> + diesel::query_builder::QueryFragment<diesel::pg::Pg> + Send + 'static,
Q: QueryDsl
+ RunQueryDsl<Pg>
+ diesel::query_builder::QueryFragment<diesel::pg::Pg>
+ Send
+ 'static,
EF: FnOnce(Q) -> F + Send + 'static,
F: FnOnce(&mut PgConnection) -> Result<T, E> + Send + 'static,
E: From<diesel::result::Error> + std::error::Error + Send + 'static,
T: Send + 'static,
{
let max_db_query_cost = self.limits.max_db_query_cost;
self.inner.spawn_blocking(move |this| {
let cost = extract_cost(&query, &this)?;
if cost > max_db_query_cost as f64 {
return Err(DbValidationError::QueryCostExceeded(cost as u64, max_db_query_cost).into());
}
let execute_closure = execute_fn(query);
this.run_query(execute_closure).map_err(|e| Error::Internal(e.to_string()))
}).await
self.inner
.spawn_blocking(move |this| {
let cost = extract_cost(&query, &this)?;
if cost > max_db_query_cost as f64 {
return Err(DbValidationError::QueryCostExceeded(
cost as u64,
max_db_query_cost,
)
.into());
}
let execute_closure = execute_fn(query);
this.run_query(execute_closure)
.map_err(|e| Error::Internal(e.to_string()))
})
.await
}
}

/// Implement methods to query db and return StoredData
impl PgManager {
async fn get_tx(&self, digest: Vec<u8>) -> Result<Option<StoredTransaction>, Error> {
self.run_query_async_with_cost(
QueryBuilder::get_tx_by_digest(digest),
|query| move |conn| query.get_result::<StoredTransaction>(conn).optional()
).await
self.run_query_async_with_cost(QueryBuilder::get_tx_by_digest(digest), |query| {
move |conn| query.get_result::<StoredTransaction>(conn).optional()
})
.await
}

async fn get_obj(
&self,
address: Vec<u8>,
version: Option<i64>,
) -> Result<Option<StoredObject>, Error> {
self.run_query_async_with_cost(
QueryBuilder::get_obj(address, version),
|query| move |conn| query.get_result::<StoredObject>(conn).optional()
)
self.run_query_async_with_cost(QueryBuilder::get_obj(address, version), |query| {
move |conn| query.get_result::<StoredObject>(conn).optional()
})
.await
}

pub async fn get_epoch(&self, epoch_id: Option<i64>) -> Result<Option<StoredEpochInfo>, Error> {
let query = match epoch_id {
Some(epoch_id) => QueryBuilder::get_epoch(epoch_id),
None => QueryBuilder::get_latest_epoch()
None => QueryBuilder::get_latest_epoch(),
};

self.run_query_async_with_cost(
query,
|query| move |conn| query.get_result::<StoredEpochInfo>(conn).optional()
).await
self.run_query_async_with_cost(query, |query| {
move |conn| query.get_result::<StoredEpochInfo>(conn).optional()
})
.await
}

async fn get_checkpoint(
Expand All @@ -585,9 +595,10 @@ impl PgManager {
_ => QueryBuilder::get_latest_checkpoint(),
};

self.run_query_async_with_cost(query,
|query| move |conn| query.get_result::<StoredCheckpoint>(conn).optional())
.await
self.run_query_async_with_cost(query, |query| {
move |conn| query.get_result::<StoredCheckpoint>(conn).optional()
})
.await
}

async fn get_chain_identifier(&self) -> Result<ChainIdentifier, Error> {
Expand Down Expand Up @@ -623,7 +634,7 @@ impl PgManager {
let result: Option<Vec<StoredObject>> = self
.run_query_async_with_cost(
QueryBuilder::multi_get_coins(cursor, descending_order, limit, address, coin_type),
|query| move |conn| query.load(conn).optional()
|query| move |conn| query.load(conn).optional(),
)
.await?;

Expand All @@ -644,10 +655,9 @@ impl PgManager {
address: Vec<u8>,
coin_type: String,
) -> Result<Option<(Option<i64>, Option<i64>, Option<String>)>, Error> {
self.run_query_async_with_cost(
QueryBuilder::get_balance(address, coin_type),
|query| move |conn| query.get_result(conn).optional()
)
self.run_query_async_with_cost(QueryBuilder::get_balance(address, coin_type), |query| {
move |conn| query.get_result(conn).optional()
})
.await
}

Expand All @@ -665,12 +675,9 @@ impl PgManager {
return Err(DbValidationError::PaginationDisabledOnBalances.into());
}

self.run_query_async_with_cost(
QueryBuilder::multi_get_balances(address),
|query| move |conn| query
.load(conn)
.optional()
)
self.run_query_async_with_cost(QueryBuilder::multi_get_balances(address), |query| {
move |conn| query.load(conn).optional()
})
.await
}

Expand Down Expand Up @@ -737,12 +744,7 @@ impl PgManager {
)?;

let result: Option<Vec<StoredTransaction>> = self
.run_query_async_with_cost(
query,
|query| move |conn| query
.load(conn)
.optional()
)
.run_query_async_with_cost(query, |query| move |conn| query.load(conn).optional())
.await?;

result
Expand Down Expand Up @@ -780,9 +782,7 @@ impl PgManager {
limit,
epoch.map(|e| e as i64),
),
|query| move |conn| query
.load(conn)
.optional()
|query| move |conn| query.load(conn).optional(),
)
.await?;

Expand Down Expand Up @@ -818,8 +818,7 @@ impl PgManager {
QueryBuilder::multi_get_objs(cursor, descending_order, limit, filter, owner_type)?;

let result: Option<Vec<StoredObject>> = self
.run_query_async_with_cost(query,
|query| move |conn| query.load(conn).optional())
.run_query_async_with_cost(query, |query| move |conn| query.load(conn).optional())
.await?;
result
.map(|mut stored_objs| {
Expand Down
10 changes: 3 additions & 7 deletions crates/sui-graphql-rpc/src/context_data/db_query_cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use diesel::{
PgConnection, RunQueryDsl,
};
use regex::Regex;
use sui_indexer::{schema_v2::query_cost, indexer_reader::IndexerReader};
use sui_indexer::{indexer_reader::IndexerReader, schema_v2::query_cost};

/// Extracts the raw sql query string from a diesel query
/// and replaces all the parameters with '0'
Expand Down Expand Up @@ -49,7 +49,6 @@ pub fn raw_sql_string_values_set(

let output = re.replace_all(&sql, "LIMIT 1").to_string();


let re = Regex::new(r"\$(\d+)")
.map_err(|e| crate::error::Error::Internal(format!("Failed create valid regex: {}", e)))?;
Ok(re.replace_all(&output, "'0'").to_string())
Expand All @@ -61,11 +60,8 @@ pub fn extract_cost(
pg_reader: &IndexerReader,
) -> Result<f64, crate::error::Error> {
let raw_sql_string = raw_sql_string_values_set(query)?;
pg_reader.run_query(|conn| {

diesel::select(query_cost(&raw_sql_string))
.get_result::<f64>(conn)
})
pg_reader
.run_query(|conn| diesel::select(query_cost(&raw_sql_string)).get_result::<f64>(conn))
.map_err(|e| {
crate::error::Error::Internal(format!(
"Unable to run query_cost function to determine query cost for {}: {}",
Expand Down
8 changes: 4 additions & 4 deletions crates/sui-graphql-rpc/tests/e2e_tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

#[cfg(feature = "pg_integration")]
// #[cfg(feature = "pg_integration")]
mod tests {
use diesel::OptionalExtension;
use diesel::RunQueryDsl;
Expand Down Expand Up @@ -122,10 +122,10 @@ mod tests {
idx_cfg.set_pool_size(20);
let reader = IndexerReader::new_with_config(connection_config.db_url(), idx_cfg).unwrap();
reader
.run_query_async(|conn| {
let cost = extract_cost(&query, conn).unwrap();
.spawn_blocking(move |this| {
let cost = extract_cost(&query, &this).unwrap();
assert!(cost > 0.0);
query.get_result::<StoredObject>(conn).optional()
this.run_query(|conn| query.get_result::<StoredObject>(conn).optional())
})
.await
.unwrap();
Expand Down

0 comments on commit 45b69af

Please sign in to comment.