diff --git a/crates/sui-graphql-rpc/src/context_data/db_data_provider.rs b/crates/sui-graphql-rpc/src/context_data/db_data_provider.rs index 3f41fb5c6db3b..cf5c872ff44b0 100644 --- a/crates/sui-graphql-rpc/src/context_data/db_data_provider.rs +++ b/crates/sui-graphql-rpc/src/context_data/db_data_provider.rs @@ -93,7 +93,7 @@ use sui_types::{ Identifier, }; -use super::{DEFAULT_PAGE_SIZE, db_query_cost::extract_cost}; +use super::{db_query_cost::extract_cost, DEFAULT_PAGE_SIZE}; use super::sui_sdk_data_provider::convert_to_validators; @@ -120,7 +120,7 @@ pub enum DbValidationError { #[error("Invalid owner type. Must be Address or Object")] InvalidOwnerType, #[error("Query cost exceeded - cost: {0}, limit: {1}")] - QueryCostExceeded(u64, u64) + QueryCostExceeded(u64, u64), } type BalanceQuery<'a> = BoxedSelectStatement< @@ -518,31 +518,42 @@ impl PgManager { execute_fn: EF, ) -> Result where - Q: QueryDsl + RunQueryDsl + diesel::query_builder::QueryFragment + Send + 'static, + Q: QueryDsl + + RunQueryDsl + + diesel::query_builder::QueryFragment + + Send + + 'static, EF: FnOnce(Q) -> F + Send + 'static, F: FnOnce(&mut PgConnection) -> Result + Send + 'static, E: From + std::error::Error + Send + 'static, T: Send + 'static, { let max_db_query_cost = self.limits.max_db_query_cost; - self.inner.spawn_blocking(move |this| { - let cost = extract_cost(&query, &this)?; - if cost > max_db_query_cost as f64 { - return Err(DbValidationError::QueryCostExceeded(cost as u64, max_db_query_cost).into()); - } - let execute_closure = execute_fn(query); - this.run_query(execute_closure).map_err(|e| Error::Internal(e.to_string())) - }).await + self.inner + .spawn_blocking(move |this| { + let cost = extract_cost(&query, &this)?; + if cost > max_db_query_cost as f64 { + return Err(DbValidationError::QueryCostExceeded( + cost as u64, + max_db_query_cost, + ) + .into()); + } + let execute_closure = execute_fn(query); + this.run_query(execute_closure) + .map_err(|e| Error::Internal(e.to_string())) + }) + .await } } /// Implement methods to query db and return StoredData impl PgManager { async fn get_tx(&self, digest: Vec) -> Result, Error> { - self.run_query_async_with_cost( - QueryBuilder::get_tx_by_digest(digest), - |query| move |conn| query.get_result::(conn).optional() - ).await + self.run_query_async_with_cost(QueryBuilder::get_tx_by_digest(digest), |query| { + move |conn| query.get_result::(conn).optional() + }) + .await } async fn get_obj( @@ -550,23 +561,22 @@ impl PgManager { address: Vec, version: Option, ) -> Result, Error> { - self.run_query_async_with_cost( - QueryBuilder::get_obj(address, version), - |query| move |conn| query.get_result::(conn).optional() - ) + self.run_query_async_with_cost(QueryBuilder::get_obj(address, version), |query| { + move |conn| query.get_result::(conn).optional() + }) .await } pub async fn get_epoch(&self, epoch_id: Option) -> Result, Error> { let query = match epoch_id { Some(epoch_id) => QueryBuilder::get_epoch(epoch_id), - None => QueryBuilder::get_latest_epoch() + None => QueryBuilder::get_latest_epoch(), }; - self.run_query_async_with_cost( - query, - |query| move |conn| query.get_result::(conn).optional() - ).await + self.run_query_async_with_cost(query, |query| { + move |conn| query.get_result::(conn).optional() + }) + .await } async fn get_checkpoint( @@ -585,9 +595,10 @@ impl PgManager { _ => QueryBuilder::get_latest_checkpoint(), }; - self.run_query_async_with_cost(query, - |query| move |conn| query.get_result::(conn).optional()) - .await + self.run_query_async_with_cost(query, |query| { + move |conn| query.get_result::(conn).optional() + }) + .await } async fn get_chain_identifier(&self) -> Result { @@ -623,7 +634,7 @@ impl PgManager { let result: Option> = self .run_query_async_with_cost( QueryBuilder::multi_get_coins(cursor, descending_order, limit, address, coin_type), - |query| move |conn| query.load(conn).optional() + |query| move |conn| query.load(conn).optional(), ) .await?; @@ -644,10 +655,9 @@ impl PgManager { address: Vec, coin_type: String, ) -> Result, Option, Option)>, Error> { - self.run_query_async_with_cost( - QueryBuilder::get_balance(address, coin_type), - |query| move |conn| query.get_result(conn).optional() - ) + self.run_query_async_with_cost(QueryBuilder::get_balance(address, coin_type), |query| { + move |conn| query.get_result(conn).optional() + }) .await } @@ -665,12 +675,9 @@ impl PgManager { return Err(DbValidationError::PaginationDisabledOnBalances.into()); } - self.run_query_async_with_cost( - QueryBuilder::multi_get_balances(address), - |query| move |conn| query - .load(conn) - .optional() - ) + self.run_query_async_with_cost(QueryBuilder::multi_get_balances(address), |query| { + move |conn| query.load(conn).optional() + }) .await } @@ -737,12 +744,7 @@ impl PgManager { )?; let result: Option> = self - .run_query_async_with_cost( - query, - |query| move |conn| query - .load(conn) - .optional() - ) + .run_query_async_with_cost(query, |query| move |conn| query.load(conn).optional()) .await?; result @@ -780,9 +782,7 @@ impl PgManager { limit, epoch.map(|e| e as i64), ), - |query| move |conn| query - .load(conn) - .optional() + |query| move |conn| query.load(conn).optional(), ) .await?; @@ -818,8 +818,7 @@ impl PgManager { QueryBuilder::multi_get_objs(cursor, descending_order, limit, filter, owner_type)?; let result: Option> = self - .run_query_async_with_cost(query, - |query| move |conn| query.load(conn).optional()) + .run_query_async_with_cost(query, |query| move |conn| query.load(conn).optional()) .await?; result .map(|mut stored_objs| { diff --git a/crates/sui-graphql-rpc/src/context_data/db_query_cost.rs b/crates/sui-graphql-rpc/src/context_data/db_query_cost.rs index 4020a11b8828f..d7cd4923e533a 100644 --- a/crates/sui-graphql-rpc/src/context_data/db_query_cost.rs +++ b/crates/sui-graphql-rpc/src/context_data/db_query_cost.rs @@ -7,7 +7,7 @@ use diesel::{ PgConnection, RunQueryDsl, }; use regex::Regex; -use sui_indexer::{schema_v2::query_cost, indexer_reader::IndexerReader}; +use sui_indexer::{indexer_reader::IndexerReader, schema_v2::query_cost}; /// Extracts the raw sql query string from a diesel query /// and replaces all the parameters with '0' @@ -49,7 +49,6 @@ pub fn raw_sql_string_values_set( let output = re.replace_all(&sql, "LIMIT 1").to_string(); - let re = Regex::new(r"\$(\d+)") .map_err(|e| crate::error::Error::Internal(format!("Failed create valid regex: {}", e)))?; Ok(re.replace_all(&output, "'0'").to_string()) @@ -61,11 +60,8 @@ pub fn extract_cost( pg_reader: &IndexerReader, ) -> Result { let raw_sql_string = raw_sql_string_values_set(query)?; - pg_reader.run_query(|conn| { - - diesel::select(query_cost(&raw_sql_string)) - .get_result::(conn) - }) + pg_reader + .run_query(|conn| diesel::select(query_cost(&raw_sql_string)).get_result::(conn)) .map_err(|e| { crate::error::Error::Internal(format!( "Unable to run query_cost function to determine query cost for {}: {}", diff --git a/crates/sui-graphql-rpc/tests/e2e_tests.rs b/crates/sui-graphql-rpc/tests/e2e_tests.rs index edada4a578d2f..88524a211b35c 100644 --- a/crates/sui-graphql-rpc/tests/e2e_tests.rs +++ b/crates/sui-graphql-rpc/tests/e2e_tests.rs @@ -1,7 +1,7 @@ // Copyright (c) Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -#[cfg(feature = "pg_integration")] +// #[cfg(feature = "pg_integration")] mod tests { use diesel::OptionalExtension; use diesel::RunQueryDsl; @@ -122,10 +122,10 @@ mod tests { idx_cfg.set_pool_size(20); let reader = IndexerReader::new_with_config(connection_config.db_url(), idx_cfg).unwrap(); reader - .run_query_async(|conn| { - let cost = extract_cost(&query, conn).unwrap(); + .spawn_blocking(move |this| { + let cost = extract_cost(&query, &this).unwrap(); assert!(cost > 0.0); - query.get_result::(conn).optional() + this.run_query(|conn| query.get_result::(conn).optional()) }) .await .unwrap();