From 91a0fd0eecf02227528a74df4eae3e9e0458338e Mon Sep 17 00:00:00 2001
From: Larko <59736843+Larkooo@users.noreply.github.com>
Date: Mon, 16 Sep 2024 21:42:37 -0400
Subject: [PATCH] feat(torii-core): parallelization (#2423)

---
 bin/torii/src/main.rs                         |  38 +++--
 crates/torii/core/src/cache.rs                |   5 +
 crates/torii/core/src/engine.rs               | 133 ++++++++++++++----
 crates/torii/core/src/processors/mod.rs       |  11 +-
 crates/torii/core/src/sql.rs                  |  60 +++++++-
 crates/torii/core/src/sql_test.rs             |  67 +++++----
 .../torii/graphql/src/tests/entities_test.rs  |  11 +-
 crates/torii/graphql/src/tests/mod.rs         |  21 +--
 crates/torii/graphql/src/tests/models_test.rs | 128 +++++++++++++----
 .../grpc/src/server/tests/entities_test.rs    |   4 +-
 crates/torii/libp2p/src/tests.rs              |   1 +
 11 files changed, 356 insertions(+), 123 deletions(-)

diff --git a/bin/torii/src/main.rs b/bin/torii/src/main.rs
index 94b6e63d70..374cce2829 100644
--- a/bin/torii/src/main.rs
+++ b/bin/torii/src/main.rs
@@ -10,6 +10,7 @@
 //! documentation for usage details. This is **not recommended on Windows**. See [here](https://rust-lang.github.io/rfcs/1974-global-allocators.html#jemalloc)
 //! for more info.
 
+use std::cmp;
 use std::net::SocketAddr;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -125,6 +126,10 @@ struct Args {
     /// Polling interval in ms
     #[arg(long, default_value = "500")]
     polling_interval: u64,
+
+    /// Max concurrent tasks
+    #[arg(long, default_value = "100")]
+    max_concurrent_tasks: usize,
 }
 
 #[tokio::main]
@@ -157,32 +162,34 @@ async fn main() -> anyhow::Result<()> {
         .connect_with(options)
         .await?;
 
-    if args.database == ":memory:" {
-        // Disable auto-vacuum
-        sqlx::query("PRAGMA auto_vacuum = NONE;").execute(&pool).await?;
+    // Disable auto-vacuum
+    sqlx::query("PRAGMA auto_vacuum = NONE;").execute(&pool).await?;
+    sqlx::query("PRAGMA journal_mode = WAL;").execute(&pool).await?;
+    sqlx::query("PRAGMA synchronous = NORMAL;").execute(&pool).await?;
 
-        // Switch DELETE journal mode
-        sqlx::query("PRAGMA journal_mode=DELETE;").execute(&pool).await?;
-    }
+    // Set the number of threads based on CPU count
+    let cpu_count = std::thread::available_parallelism().unwrap().get();
+    let thread_count = cmp::min(cpu_count, 8);
+    sqlx::query(&format!("PRAGMA threads = {};", thread_count)).execute(&pool).await?;
 
     sqlx::migrate!("../../crates/torii/migrations").run(&pool).await?;
 
     let provider: Arc<_> = JsonRpcClient::new(HttpTransport::new(args.rpc)).into();
 
     // Get world address
-    let world = WorldContractReader::new(args.world_address, &provider);
+    let world = WorldContractReader::new(args.world_address, provider.clone());
 
     let db = Sql::new(pool.clone(), args.world_address).await?;
 
     let processors = Processors {
         event: generate_event_processors_map(vec![
-            Box::new(RegisterModelProcessor),
-            Box::new(StoreSetRecordProcessor),
-            Box::new(MetadataUpdateProcessor),
-            Box::new(StoreDelRecordProcessor),
-            Box::new(EventMessageProcessor),
-            Box::new(StoreUpdateRecordProcessor),
-            Box::new(StoreUpdateMemberProcessor),
+            Arc::new(RegisterModelProcessor),
+            Arc::new(StoreSetRecordProcessor),
+            Arc::new(MetadataUpdateProcessor),
+            Arc::new(StoreDelRecordProcessor),
+            Arc::new(EventMessageProcessor),
+            Arc::new(StoreUpdateRecordProcessor),
+            Arc::new(StoreUpdateMemberProcessor),
         ])?,
         transaction: vec![Box::new(StoreTransactionProcessor)],
         ..Processors::default()
     };
@@ -193,9 +200,10 @@ async fn main() -> anyhow::Result<()> {
     let mut engine = Engine::new(
         world,
         db.clone(),
-        &provider,
+        provider.clone(),
         processors,
        EngineConfig {
+            max_concurrent_tasks: args.max_concurrent_tasks,
            start_block: args.start_block,
            events_chunk_size: args.events_chunk_size,
            index_pending: args.index_pending,
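Aside: the pragma block above trades a little durability for indexing throughput — WAL lets readers proceed while the indexer writes, and synchronous = NORMAL drops one fsync per commit, which is a safe combination with WAL except across power loss. A minimal standalone sketch of the same setup using sqlx's typed SQLite options rather than raw queries (not code from this patch; the database URL and pool size are placeholders, and tokio, anyhow, and sqlx with the `sqlite` feature are assumed as dependencies):

    use std::cmp;
    use std::str::FromStr;

    use sqlx::sqlite::{
        SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous,
    };

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Mirror the patch: cap SQLite's internal worker threads at min(CPUs, 8).
        let threads = cmp::min(std::thread::available_parallelism()?.get(), 8);

        // WAL keeps readers unblocked while the single writer appends;
        // NORMAL skips an fsync per commit, which WAL makes corruption-safe.
        let options = SqliteConnectOptions::from_str("sqlite://indexer.db")?
            .create_if_missing(true)
            .journal_mode(SqliteJournalMode::Wal)
            .synchronous(SqliteSynchronous::Normal)
            .pragma("threads", threads.to_string());

        let pool = SqlitePoolOptions::new().max_connections(5).connect_with(options).await?;
        sqlx::query("SELECT 1").execute(&pool).await?;
        Ok(())
    }
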
diff --git a/crates/torii/core/src/cache.rs b/crates/torii/core/src/cache.rs
index a77c88642f..f5afab2103 100644
--- a/crates/torii/core/src/cache.rs
+++ b/crates/torii/core/src/cache.rs
@@ -113,6 +113,11 @@ impl ModelCache {
         Ok(model)
     }
 
+    pub async fn set(&self, selector: Felt, model: Model) {
+        let mut cache = self.cache.write().await;
+        cache.insert(selector, model);
+    }
+
     pub async fn clear(&self) {
         self.cache.write().await.clear();
     }
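Aside: `ModelCache::set` gives writers a way to publish a schema immediately instead of waiting for a read to miss and reload it from the database. A toy version of the same read/write cache shape, using tokio's async RwLock as torii does (types simplified to stand-ins; not the crate's actual API):

    use std::collections::HashMap;
    use std::sync::Arc;

    use tokio::sync::RwLock;

    // Stand-ins for torii's Felt selector and Model metadata.
    type Selector = u64;
    type Model = String;

    #[derive(Default)]
    struct ModelCache {
        cache: RwLock<HashMap<Selector, Model>>,
    }

    impl ModelCache {
        // Mirrors the added `set`: take the write lock and upsert the entry.
        async fn set(&self, selector: Selector, model: Model) {
            self.cache.write().await.insert(selector, model);
        }

        async fn get(&self, selector: Selector) -> Option<Model> {
            self.cache.read().await.get(&selector).cloned()
        }
    }

    #[tokio::main]
    async fn main() {
        let cache = Arc::new(ModelCache::default());
        cache.set(1, "Position".to_string()).await;
        assert_eq!(cache.get(1).await.as_deref(), Some("Position"));
    }
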
diff --git a/crates/torii/core/src/engine.rs b/crates/torii/core/src/engine.rs
index a12420d5a0..8b2f5685e9 100644
--- a/crates/torii/core/src/engine.rs
+++ b/crates/torii/core/src/engine.rs
@@ -1,5 +1,7 @@
 use std::collections::{BTreeMap, HashMap};
 use std::fmt::Debug;
+use std::hash::{DefaultHasher, Hash, Hasher};
+use std::sync::Arc;
 use std::time::Duration;
 
 use anyhow::Result;
@@ -13,6 +15,8 @@ use starknet::core::types::{
 use starknet::providers::Provider;
 use tokio::sync::broadcast::Sender;
 use tokio::sync::mpsc::Sender as BoundedSender;
+use tokio::sync::Semaphore;
+use tokio::task::JoinSet;
 use tokio::time::sleep;
 use tracing::{debug, error, info, trace, warn};
 
@@ -21,14 +25,14 @@ use crate::processors::{BlockProcessor, EventProcessor, TransactionProcessor};
 use crate::sql::Sql;
 
 #[allow(missing_debug_implementations)]
-pub struct Processors<P: Provider + Sync> {
+pub struct Processors<P: Provider + Sync + Send + 'static> {
     pub block: Vec<Box<dyn BlockProcessor<P>>>,
     pub transaction: Vec<Box<dyn TransactionProcessor<P>>>,
-    pub event: HashMap<Felt, Box<dyn EventProcessor<P>>>,
+    pub event: HashMap<Felt, Arc<dyn EventProcessor<P>>>,
     pub catch_all_event: Box<dyn EventProcessor<P>>,
 }
 
-impl<P: Provider + Sync> Default for Processors<P> {
+impl<P: Provider + Sync + Send + 'static> Default for Processors<P> {
     fn default() -> Self {
         Self {
             block: vec![],
@@ -48,6 +52,7 @@ pub struct EngineConfig {
     pub start_block: u64,
     pub events_chunk_size: u64,
     pub index_pending: bool,
+    pub max_concurrent_tasks: usize,
 }
 
 impl Default for EngineConfig {
@@ -57,6 +62,7 @@ impl Default for EngineConfig {
             start_block: 0,
             events_chunk_size: 1024,
             index_pending: true,
+            max_concurrent_tasks: 100,
         }
     }
 }
@@ -83,15 +89,24 @@ pub struct FetchPendingResult {
     pub block_number: u64,
 }
 
+#[derive(Debug)]
+pub struct ParallelizedEvent {
+    pub block_number: u64,
+    pub block_timestamp: u64,
+    pub event_id: String,
+    pub event: Event,
+}
+
 #[allow(missing_debug_implementations)]
-pub struct Engine<P: Provider + Sync> {
-    world: WorldContractReader<P>,
+pub struct Engine<P: Provider + Sync + Send + 'static> {
+    world: Arc<WorldContractReader<P>>,
     db: Sql,
     provider: Box<P>,
-    processors: Processors<P>,
+    processors: Arc<Processors<P>>,
     config: EngineConfig,
     shutdown_tx: Sender<()>,
     block_tx: Option<BoundedSender<u64>>,
+    tasks: HashMap<u64, Vec<ParallelizedEvent>>,
 }
 
 struct UnprocessedEvent {
@@ -99,7 +114,7 @@ struct UnprocessedEvent {
     data: Vec<String>,
 }
 
-impl<P: Provider + Sync> Engine<P> {
+impl<P: Provider + Sync + Send + 'static> Engine<P> {
     pub fn new(
         world: WorldContractReader<P>,
         db: Sql,
@@ -109,7 +124,16 @@ impl<P: Provider + Sync> Engine<P> {
         shutdown_tx: Sender<()>,
         block_tx: Option<BoundedSender<u64>>,
     ) -> Self {
-        Self { world, db, provider: Box::new(provider), processors, config, shutdown_tx, block_tx }
+        Self {
+            world: Arc::new(world),
+            db,
+            provider: Box::new(provider),
+            processors: Arc::new(processors),
+            config,
+            shutdown_tx,
+            block_tx,
+            tasks: HashMap::new(),
+        }
     }
 
     pub async fn start(&mut self) -> Result<()> {
@@ -397,11 +421,14 @@ impl<P: Provider + Sync> Engine<P> {
             }
         }
 
+        // Process parallelized events
+        self.process_tasks().await?;
+
         // Set the head to the last processed pending transaction
         // Head block number should still be latest block number
         self.db.set_head(data.block_number - 1, last_pending_block_world_tx, last_pending_block_tx);
 
-        self.db.execute().await?;
+
         Ok(())
     }
 
@@ -436,11 +463,8 @@ impl<P: Provider + Sync> Engine<P> {
             }
         }
 
-        // We return None for the pending_block_tx because our process_range
-        // gets only specific events from the world. so some transactions
-        // might get ignored and wont update the cursor.
-        // so once the sync range is done, we assume all of the tx of the block
-        // have been processed.
+        // Process parallelized events
+        self.process_tasks().await?;
 
         self.db.set_head(data.latest_block_number, None, None);
         self.db.execute().await?;
@@ -448,6 +472,46 @@ impl<P: Provider + Sync> Engine<P> {
         Ok(())
     }
 
+    async fn process_tasks(&mut self) -> Result<()> {
+        // We use a semaphore to limit the number of concurrent tasks
+        let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent_tasks));
+
+        // Run all tasks concurrently
+        let mut set = JoinSet::new();
+        for (task_id, events) in self.tasks.drain() {
+            let db = self.db.clone();
+            let world = self.world.clone();
+            let processors = self.processors.clone();
+            let semaphore = semaphore.clone();
+
+            set.spawn(async move {
+                let _permit = semaphore.acquire().await.unwrap();
+                let mut local_db = db.clone();
+                for ParallelizedEvent { event_id, event, block_number, block_timestamp } in events {
+                    if let Some(processor) = processors.event.get(&event.keys[0]) {
+                        debug!(target: LOG_TARGET, event_name = processor.event_key(), task_id = %task_id, "Processing parallelized event.");
+
+                        if let Err(e) = processor
+                            .process(&world, &mut local_db, block_number, block_timestamp, &event_id, &event)
+                            .await
+                        {
+                            error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, task_id = %task_id, "Processing parallelized event.");
+                        }
+                    }
+                }
+                Ok::<_, anyhow::Error>(local_db)
+            });
+        }
+
+        // Join all tasks
+        while let Some(result) = set.join_next().await {
+            let local_db = result??;
+            self.db.merge(local_db)?;
+        }
+
+        Ok(())
+    }
+
     async fn get_block_timestamp(&self, block_number: u64) -> Result<u64> {
         match self.provider.get_block_with_tx_hashes(BlockId::Number(block_number)).await? {
             MaybePendingBlockWithTxHashes::Block(block) => Ok(block.timestamp),
@@ -477,7 +541,7 @@ impl<P: Provider + Sync> Engine<P> {
                     block_timestamp,
                     &event_id,
                     &event,
-                    transaction_hash,
+                    // transaction_hash,
                 )
                 .await?;
             }
@@ -527,7 +591,7 @@ impl<P: Provider + Sync> Engine<P> {
                         block_timestamp,
                         &event_id,
                         event,
-                        *transaction_hash,
+                        // *transaction_hash,
                     )
                     .await?;
                 }
@@ -587,9 +651,9 @@ impl<P: Provider + Sync> Engine<P> {
         block_timestamp: u64,
         event_id: &str,
         event: &Event,
-        transaction_hash: Felt,
+        // transaction_hash: Felt,
     ) -> Result<()> {
-        self.db.store_event(event_id, event, transaction_hash, block_timestamp);
+        // self.db.store_event(event_id, event, transaction_hash, block_timestamp);
         let event_key = event.keys[0];
 
         let Some(processor) = self.processors.event.get(&event_key) else {
@@ -627,14 +691,33 @@ impl<P: Provider + Sync> Engine<P> {
             return Ok(());
         };
 
-        // if processor.validate(event) {
-        if let Err(e) = processor
-            .process(&self.world, &mut self.db, block_number, block_timestamp, event_id, event)
-            .await
-        {
-            error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, "Processing event.");
+        let task_identifier = match processor.event_key().as_str() {
+            "StoreSetRecord" | "StoreUpdateRecord" | "StoreUpdateMember" | "StoreDelRecord" => {
+                let mut hasher = DefaultHasher::new();
+                event.data[0].hash(&mut hasher);
+                event.data[1].hash(&mut hasher);
+                hasher.finish()
+            }
+            _ => 0,
+        };
+
+        // If we have a task identifier, we queue the event to be parallelized
+        if task_identifier != 0 {
+            self.tasks.entry(task_identifier).or_default().push(ParallelizedEvent {
+                event_id: event_id.to_string(),
+                event: event.clone(),
+                block_number,
+                block_timestamp,
+            });
+        } else {
+            // If we don't have a task identifier, we process the event immediately
+            if let Err(e) = processor
+                .process(&self.world, &mut self.db, block_number, block_timestamp, event_id, event)
+                .await
+            {
+                error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, "Processing event.");
+            }
         }
-        // }
 
         Ok(())
     }
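Aside: the two pieces above work together — `process_event` buckets the Store* events by a hash of their first two data felts (in practice the model selector and the entity id), and `process_tasks` later drains those buckets into Tokio tasks capped by the semaphore, merging each task's local database back at the end. Events for the same entity hash to the same bucket, so their relative order is preserved while different entities run in parallel. A standalone sketch of that fan-out/fan-in shape (illustrative only — the `task_id` function and string events are stand-ins, not torii types; assumes tokio and anyhow):

    use std::collections::HashMap;
    use std::hash::{DefaultHasher, Hash, Hasher};
    use std::sync::Arc;

    use tokio::sync::Semaphore;
    use tokio::task::JoinSet;

    fn task_id(entity_key: &str) -> u64 {
        // Same entity -> same hash -> same task, so its events stay ordered.
        let mut hasher = DefaultHasher::new();
        entity_key.hash(&mut hasher);
        hasher.finish()
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Group events by hashed key, like Engine::tasks.
        let mut tasks: HashMap<u64, Vec<String>> = HashMap::new();
        for event in ["player-1:spawn", "player-2:spawn", "player-1:move"] {
            let key = event.split(':').next().unwrap();
            tasks.entry(task_id(key)).or_default().push(event.to_string());
        }

        // Cap concurrency the way process_tasks does with max_concurrent_tasks.
        let semaphore = Arc::new(Semaphore::new(2));
        let mut set = JoinSet::new();
        for (id, events) in tasks.drain() {
            let semaphore = semaphore.clone();
            set.spawn(async move {
                let _permit = semaphore.acquire().await?;
                // Stand-in for processor.process(); a real task would write to
                // its own cloned Sql here and hand it back for merging.
                Ok::<_, anyhow::Error>((id, events.len()))
            });
        }

        // Fan back in: surface task errors and collect results.
        while let Some(result) = set.join_next().await {
            let (id, n) = result??;
            println!("task {id:#x} processed {n} event(s)");
        }
        Ok(())
    }
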
diff --git a/crates/torii/core/src/processors/mod.rs b/crates/torii/core/src/processors/mod.rs
index c4a02da631..c6a8f13af5 100644
--- a/crates/torii/core/src/processors/mod.rs
+++ b/crates/torii/core/src/processors/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use anyhow::{Error, Result};
 use async_trait::async_trait;
@@ -23,7 +24,7 @@ const ENTITY_ID_INDEX: usize = 1;
 const NUM_KEYS_INDEX: usize = 2;
 
 #[async_trait]
-pub trait EventProcessor<P>
+pub trait EventProcessor<P>: Send + Sync
 where
     P: Provider + Sync,
 {
@@ -48,7 +49,7 @@ where
 }
 
 #[async_trait]
-pub trait BlockProcessor<P: Provider + Sync> {
+pub trait BlockProcessor<P: Provider + Sync>: Send + Sync {
     fn get_block_number(&self) -> String;
     async fn process(
         &self,
@@ -60,7 +61,7 @@ pub trait BlockProcessor<P: Provider + Sync> {
 }
 
 #[async_trait]
-pub trait TransactionProcessor<P: Provider + Sync> {
+pub trait TransactionProcessor<P: Provider + Sync>: Send + Sync {
     #[allow(clippy::too_many_arguments)]
     async fn process(
         &self,
@@ -75,8 +76,8 @@ pub trait TransactionProcessor<P: Provider + Sync> {
 
 /// Given a list of event processors, generate a map of event keys to the event processor
 pub fn generate_event_processors_map<P: Provider + Send + Sync>(
-    event_processor: Vec<Box<dyn EventProcessor<P>>>,
-) -> Result<HashMap<Felt, Box<dyn EventProcessor<P>>>> {
+    event_processor: Vec<Arc<dyn EventProcessor<P>>>,
+) -> Result<HashMap<Felt, Arc<dyn EventProcessor<P>>>> {
     let mut event_processors = HashMap::new();
 
     for processor in event_processor {
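Aside: the switch from `Box` to `Arc` in these signatures is what lets one processor registry serve many concurrent tasks — a `Box<dyn Trait>` has a single owner, while cloning an `Arc<dyn Trait>` is just a reference-count bump, and the new `Send + Sync` supertraits make the shared references safe to move across threads. A self-contained illustration (the trait and processor here are mock-ups, not torii's):

    use std::sync::Arc;

    trait EventProcessor: Send + Sync {
        fn event_key(&self) -> String;
    }

    struct RegisterModel;

    impl EventProcessor for RegisterModel {
        fn event_key(&self) -> String {
            "ModelRegistered".to_string()
        }
    }

    #[tokio::main]
    async fn main() {
        let processor: Arc<dyn EventProcessor> = Arc::new(RegisterModel);

        let mut handles = Vec::new();
        for _ in 0..4 {
            // Cheap shared handle; a Box would have to be moved or deep-copied.
            let processor = Arc::clone(&processor);
            handles.push(tokio::spawn(async move { processor.event_key() }));
        }
        for handle in handles {
            assert_eq!(handle.await.unwrap(), "ModelRegistered");
        }
    }
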
diff --git a/crates/torii/core/src/sql.rs b/crates/torii/core/src/sql.rs
index e53a116889..bb518b1bc2 100644
--- a/crates/torii/core/src/sql.rs
+++ b/crates/torii/core/src/sql.rs
@@ -5,7 +5,7 @@ use std::sync::Arc;
 use anyhow::{anyhow, Result};
 use chrono::Utc;
 use dojo_types::primitive::Primitive;
-use dojo_types::schema::{EnumOption, Member, Ty};
+use dojo_types::schema::{EnumOption, Member, Struct, Ty};
 use dojo_world::contracts::abi::model::Layout;
 use dojo_world::contracts::naming::compute_selector_from_names;
 use dojo_world::metadata::WorldMetadata;
@@ -13,7 +13,7 @@ use sqlx::pool::PoolConnection;
 use sqlx::{Pool, Sqlite};
 use starknet::core::types::{Event, Felt, InvokeTransaction, Transaction};
 use starknet_crypto::poseidon_hash_many;
-use tracing::debug;
+use tracing::{debug, warn};
 
 use crate::cache::{Model, ModelCache};
 use crate::query_queue::{Argument, BrokerMessage, DeleteEntityQuery, QueryQueue, QueryType};
@@ -32,7 +32,7 @@ pub const FELT_DELIMITER: &str = "/";
 #[path = "sql_test.rs"]
 mod test;
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct Sql {
     world_address: Felt,
     pub pool: Pool<Sqlite>,
@@ -40,6 +40,17 @@ pub struct Sql {
     model_cache: Arc<ModelCache>,
 }
 
+impl Clone for Sql {
+    fn clone(&self) -> Self {
+        Self {
+            world_address: self.world_address,
+            pool: self.pool.clone(),
+            query_queue: QueryQueue::new(self.pool.clone()),
+            model_cache: self.model_cache.clone(),
+        }
+    }
+}
+
 impl Sql {
     pub async fn new(pool: Pool<Sqlite>, world_address: Felt) -> Result<Self> {
         let mut query_queue = QueryQueue::new(pool.clone());
@@ -65,6 +76,22 @@ impl Sql {
         })
     }
 
+    pub fn merge(&mut self, other: Sql) -> Result<()> {
+        // Merge query queue
+        self.query_queue.queue.extend(other.query_queue.queue);
+        self.query_queue.publish_queue.extend(other.query_queue.publish_queue);
+
+        // This should never happen
+        if self.world_address != other.world_address {
+            warn!(
+                "Merging Sql instances with different world addresses: {} and {}",
+                self.world_address, other.world_address
+            );
+        }
+
+        Ok(())
+    }
+
     pub async fn head(&self) -> Result<(u64, Option<Felt>, Option<Felt>)> {
         let mut conn: PoolConnection<Sqlite> = self.pool.acquire().await?;
         let indexer_query =
@@ -123,6 +150,7 @@ impl Sql {
         block_timestamp: u64,
     ) -> Result<()> {
         let selector = compute_selector_from_names(namespace, &model.name());
+        let namespaced_name = format!("{}-{}", namespace, model.name());
 
         let insert_models =
             "INSERT INTO models (id, namespace, name, class_hash, contract_address, layout, \
@@ -149,13 +177,35 @@ impl Sql {
         self.build_register_queries_recursive(
             selector,
             &model,
-            vec![format!("{}-{}", namespace, model.name())],
+            vec![namespaced_name.clone()],
             &mut model_idx,
             block_timestamp,
             &mut 0,
             &mut 0,
         );
-        self.execute().await?;
+
+        // We set the model in the cache directly, because entities might be
+        // using it before the query queue is processed
+        self.model_cache
+            .set(
+                selector,
+                Model {
+                    namespace: namespace.to_string(),
+                    name: model.name().to_string(),
+                    selector,
+                    class_hash,
+                    contract_address,
+                    packed_size,
+                    unpacked_size,
+                    layout,
+                    // We need to update the name of the struct to include the namespace
+                    schema: Ty::Struct(Struct {
+                        name: namespaced_name,
+                        children: model.as_struct().unwrap().children.clone(),
+                    }),
+                },
+            )
+            .await;
 
         self.query_queue.push_publish(BrokerMessage::ModelRegistered(model_registered));
 
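Aside: the hand-written `Clone` above is what makes the per-task databases cheap — every clone shares the same sqlx pool and model cache but starts with an empty query queue, and `merge()` re-attaches a task's buffered queries and broker messages to the main instance, so nothing hits SQLite until the single `execute()` call. The same shape in miniature (field types simplified to strings; not the real QueryQueue):

    // Toy stand-in for torii's buffered write queue.
    #[derive(Debug, Default)]
    struct WriteQueue {
        queue: Vec<String>,
    }

    #[derive(Debug, Default)]
    struct Db {
        queue: WriteQueue,
    }

    impl Clone for Db {
        // A clone shares the underlying storage (the sqlx pool in torii) but
        // starts with an empty queue, so each task buffers its own writes.
        fn clone(&self) -> Self {
            Self { queue: WriteQueue::default() }
        }
    }

    impl Db {
        fn merge(&mut self, other: Db) {
            // Re-attach the task-local buffer; nothing is written until flush.
            self.queue.queue.extend(other.queue.queue);
        }
    }

    fn main() {
        let mut main_db = Db::default();
        let mut task_db = main_db.clone();
        task_db.queue.queue.push("INSERT INTO entities ...".to_string());
        main_db.merge(task_db);
        assert_eq!(main_db.queue.queue.len(), 1);
    }
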
diff --git a/crates/torii/core/src/sql_test.rs b/crates/torii/core/src/sql_test.rs
index db60d738ec..b60ea3de36 100644
--- a/crates/torii/core/src/sql_test.rs
+++ b/crates/torii/core/src/sql_test.rs
@@ -1,4 +1,5 @@
 use std::str::FromStr;
+use std::sync::Arc;
 
 use cainome::cairo_serde::ContractAddress;
 use camino::Utf8PathBuf;
@@ -10,10 +11,11 @@ use dojo_world::contracts::world::{WorldContract, WorldContractReader};
 use katana_runner::{KatanaRunner, KatanaRunnerConfig};
 use scarb::compiler::Profile;
 use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
-use starknet::accounts::{Account, ConnectedAccount};
+use starknet::accounts::Account;
 use starknet::core::types::{Call, Felt};
 use starknet::core::utils::{get_contract_address, get_selector_from_name};
-use starknet::providers::Provider;
+use starknet::providers::jsonrpc::HttpTransport;
+use starknet::providers::{JsonRpcClient, Provider};
 use starknet_crypto::poseidon_hash_many;
 use tokio::sync::broadcast;
 
@@ -32,7 +34,7 @@ pub async fn bootstrap_engine<P>(
     provider: P,
 ) -> Result<Engine<P>, Box<dyn std::error::Error>>
 where
-    P: Provider + Send + Sync + core::fmt::Debug,
+    P: Provider + Send + Sync + core::fmt::Debug + Clone + 'static,
 {
     let (shutdown_tx, _) = broadcast::channel(1);
     let to = provider.block_hash_and_number().await?.block_number;
@@ -42,11 +44,11 @@ where
         provider,
         Processors {
             event: generate_event_processors_map(vec![
-                Box::new(RegisterModelProcessor),
-                Box::new(StoreSetRecordProcessor),
-                Box::new(StoreUpdateRecordProcessor),
-                Box::new(StoreUpdateMemberProcessor),
-                Box::new(StoreDelRecordProcessor),
+                Arc::new(RegisterModelProcessor),
+                Arc::new(StoreSetRecordProcessor),
+                Arc::new(StoreUpdateRecordProcessor),
+                Arc::new(StoreUpdateMemberProcessor),
+                Arc::new(StoreDelRecordProcessor),
             ])?,
             ..Processors::default()
         },
@@ -80,6 +82,7 @@ async fn test_load_from_remote() {
     let sequencer = KatanaRunner::new_with_config(seq_config).expect("Failed to start runner.");
 
     let account = sequencer.account(0);
+    let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(sequencer.url())));
 
     let (strat, _) = prepare_migration_with_world_and_seed(
         manifest_path,
@@ -106,7 +109,7 @@ async fn test_load_from_remote() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(res.transaction_hash, &provider).await.unwrap();
 
     // spawn
     let tx = &account
@@ -119,13 +122,13 @@ async fn test_load_from_remote() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(tx.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(tx.transaction_hash, &provider).await.unwrap();
 
-    let world_reader = WorldContractReader::new(strat.world_address, account.provider());
+    let world_reader = WorldContractReader::new(strat.world_address, Arc::clone(&provider));
 
     let mut db = Sql::new(pool.clone(), world_reader.address).await.unwrap();
 
-    let _ = bootstrap_engine(world_reader, db.clone(), account.provider()).await.unwrap();
+    let _ = bootstrap_engine(world_reader, db.clone(), provider).await.unwrap();
 
     let _block_timestamp = 1710754478_u64;
     let models = sqlx::query("SELECT * FROM models").fetch_all(&pool).await.unwrap();
@@ -214,6 +217,7 @@ async fn test_load_from_remote_del() {
     let sequencer = KatanaRunner::new_with_config(seq_config).expect("Failed to start runner.");
 
     let account = sequencer.account(0);
+    let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(sequencer.url())));
 
     let (strat, _) = prepare_migration_with_world_and_seed(
         manifest_path,
@@ -239,7 +243,7 @@ async fn test_load_from_remote_del() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(res.transaction_hash, &provider).await.unwrap();
 
     // spawn
     let res = account
@@ -252,7 +256,7 @@ async fn test_load_from_remote_del() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(res.transaction_hash, &provider).await.unwrap();
 
     // Set player config.
     let res = account
@@ -266,7 +270,7 @@ async fn test_load_from_remote_del() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(res.transaction_hash, &provider).await.unwrap();
 
     let res = account
         .execute_v1(vec![Call {
@@ -278,13 +282,13 @@ async fn test_load_from_remote_del() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(res.transaction_hash, &provider).await.unwrap();
 
-    let world_reader = WorldContractReader::new(strat.world_address, account.provider());
+    let world_reader = WorldContractReader::new(strat.world_address, Arc::clone(&provider));
 
     let mut db = Sql::new(pool.clone(), world_reader.address).await.unwrap();
 
-    let _ = bootstrap_engine(world_reader, db.clone(), account.provider()).await;
+    let _ = bootstrap_engine(world_reader, db.clone(), provider).await;
 
     assert_eq!(count_table("dojo_examples-PlayerConfig", &pool).await, 0);
     assert_eq!(count_table("dojo_examples-PlayerConfig$favorite_item", &pool).await, 0);
@@ -296,16 +300,13 @@ async fn test_load_from_remote_del() {
     db.execute().await.unwrap();
 }
 
-// Start of Selection
 #[tokio::test(flavor = "multi_thread")]
 async fn test_update_with_set_record() {
-    // Initialize the SQLite in-memory database
     let options =
         SqliteConnectOptions::from_str("sqlite::memory:").unwrap().create_if_missing(true);
     let pool = SqlitePoolOptions::new().max_connections(5).connect_with(options).await.unwrap();
     sqlx::migrate!("../migrations").run(&pool).await.unwrap();
 
-    // Set up the compiler test environment
     let setup = CompilerTestSetup::from_examples("../../dojo-core", "../../../examples/");
 
     let config = setup.build_test_config("spawn-and-move", Profile::DEV);
@@ -313,14 +314,10 @@ async fn test_update_with_set_record() {
     let manifest_path = Utf8PathBuf::from(config.manifest_path().parent().unwrap());
     let target_dir = Utf8PathBuf::from(ws.target_dir().to_string()).join("dev");
 
-    // Configure and start the KatanaRunner
     let seq_config = KatanaRunnerConfig { n_accounts: 10, ..Default::default() }
         .with_db_dir(copy_spawn_and_move_db().as_str());
-
     let sequencer = KatanaRunner::new_with_config(seq_config).expect("Failed to start runner.");
-    let account = sequencer.account(0);
 
-    // Prepare migration with world and seed
     let (strat, _) = prepare_migration_with_world_and_seed(
         manifest_path,
         target_dir,
@@ -338,16 +335,18 @@ async fn test_update_with_set_record() {
         strat.world_address,
     );
 
+    let account = sequencer.account(0);
+    let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(sequencer.url())));
+
     let world = WorldContract::new(strat.world_address, &account);
 
-    // Grant writer permissions
     let res = world
         .grant_writer(&compute_bytearray_hash("dojo_examples"), &ContractAddress(actions_address))
         .send_with_cfg(&TxnConfig::init_wait())
         .await
         .unwrap();
 
-    TransactionWaiter::new(res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(res.transaction_hash, &provider).await.unwrap();
 
     // Send spawn transaction
     let spawn_res = account
@@ -360,7 +359,7 @@ async fn test_update_with_set_record() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(spawn_res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(spawn_res.transaction_hash, &provider).await.unwrap();
 
     // Send move transaction
     let move_res = account
@@ -373,15 +372,15 @@ async fn test_update_with_set_record() {
         .await
         .unwrap();
 
-    TransactionWaiter::new(move_res.transaction_hash, &account.provider()).await.unwrap();
+    TransactionWaiter::new(move_res.transaction_hash, &provider).await.unwrap();
 
-    let world_reader = WorldContractReader::new(strat.world_address, account.provider());
+    let world_reader = WorldContractReader::new(strat.world_address, Arc::clone(&provider));
 
-    let db = Sql::new(pool.clone(), world_reader.address).await.unwrap();
+    let mut db = Sql::new(pool.clone(), world_reader.address).await.unwrap();
+
+    let _ = bootstrap_engine(world_reader, db.clone(), Arc::clone(&provider)).await.unwrap();
 
-    // Expect bootstrap_engine to error out due to the existing bug
-    let result = bootstrap_engine(world_reader, db.clone(), account.provider()).await;
-    assert!(result.is_ok(), "bootstrap_engine should not fail");
+    db.execute().await.unwrap();
 }
 
 /// Count the number of rows in a table.
diff --git a/crates/torii/graphql/src/tests/entities_test.rs b/crates/torii/graphql/src/tests/entities_test.rs
index 4722a7a26c..6138aac846 100644
--- a/crates/torii/graphql/src/tests/entities_test.rs
+++ b/crates/torii/graphql/src/tests/entities_test.rs
@@ -106,8 +106,15 @@ mod tests {
         let last_entity = connection.edges.last().unwrap();
         assert_eq!(connection.edges.len(), 2);
         assert_eq!(connection.total_count, 2);
-        assert_eq!(first_entity.node.keys.clone().unwrap(), vec!["0x0", "0x1"]);
-        assert_eq!(last_entity.node.keys.clone().unwrap(), vec!["0x0"]);
+        // Due to parallelization, the order of the entities is nondeterministic
+        assert!(
+            first_entity.node.keys.clone().unwrap() == vec!["0x0", "0x1"]
+                || first_entity.node.keys.clone().unwrap() == vec!["0x0"]
+        );
+        assert!(
+            last_entity.node.keys.clone().unwrap() == vec!["0x0", "0x1"]
+                || last_entity.node.keys.clone().unwrap() == vec!["0x0"]
+        );
 
         // double key param - returns all entities with `0x0` as first key and `0x1` as second key
         let entities = entities_query(&schema, "(keys: [\"0x0\", \"0x1\"])").await;
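Aside: the relaxed assertions above accept either row order because parallel tasks may commit in any sequence. An alternative that stays strict about contents while ignoring order is to compare as sets, sketched here with plain data standing in for the GraphQL entity nodes (not code from the patch):

    use std::collections::HashSet;

    fn main() {
        let got: Vec<Vec<String>> = vec![
            vec!["0x0".to_string()],
            vec!["0x0".to_string(), "0x1".to_string()],
        ];
        let expected: Vec<Vec<String>> = vec![
            vec!["0x0".to_string(), "0x1".to_string()],
            vec!["0x0".to_string()],
        ];

        let got_set: HashSet<Vec<String>> = got.into_iter().collect();
        let expected_set: HashSet<Vec<String>> = expected.into_iter().collect();
        // Passes regardless of the order the indexer inserted the rows in.
        assert_eq!(got_set, expected_set);
    }
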
diff --git a/crates/torii/graphql/src/tests/mod.rs b/crates/torii/graphql/src/tests/mod.rs
index 26ff6870df..133b46075e 100644
--- a/crates/torii/graphql/src/tests/mod.rs
+++ b/crates/torii/graphql/src/tests/mod.rs
@@ -1,4 +1,5 @@
 use std::str::FromStr;
+use std::sync::Arc;
 
 use anyhow::Result;
 use async_graphql::dynamic::Schema;
@@ -21,7 +22,8 @@ use sqlx::SqlitePool;
 use starknet::accounts::{Account, ConnectedAccount};
 use starknet::core::types::{Call, Felt, InvokeTransactionResult};
 use starknet::macros::selector;
-use starknet::providers::Provider;
+use starknet::providers::jsonrpc::HttpTransport;
+use starknet::providers::{JsonRpcClient, Provider};
 use tokio::sync::broadcast;
 use tokio_stream::StreamExt;
 use torii_core::engine::{Engine, EngineConfig, Processors};
@@ -268,6 +270,8 @@ pub async fn model_fixtures(db: &mut Sql) {
     )
     .await
     .unwrap();
+
+    db.execute().await.unwrap();
 }
 
 pub async fn spinup_types_test() -> Result<SqlitePool> {
@@ -290,6 +294,7 @@ pub async fn spinup_types_test() -> Result<SqlitePool> {
     let sequencer = KatanaRunner::new_with_config(seq_config).expect("Failed to start runner.");
 
     let account = sequencer.account(0);
+    let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(sequencer.url())));
 
     let (strat, _) = prepare_migration_with_world_and_seed(
         manifest_path,
@@ -328,7 +333,7 @@ pub async fn spinup_types_test() -> Result<SqlitePool> {
         .await
         .unwrap();
 
-    TransactionWaiter::new(transaction_hash, &account.provider()).await?;
+    TransactionWaiter::new(transaction_hash, &provider).await?;
 
     // Execute `delete` and delete Record with id 20
     let InvokeTransactionResult { transaction_hash } = account
@@ -341,9 +346,9 @@ pub async fn spinup_types_test() -> Result<SqlitePool> {
         .await
         .unwrap();
 
-    TransactionWaiter::new(transaction_hash, &account.provider()).await?;
+    TransactionWaiter::new(transaction_hash, &provider).await?;
 
-    let world = WorldContractReader::new(strat.world_address, account.provider());
+    let world = WorldContractReader::new(strat.world_address, Arc::clone(&provider));
 
     let db = Sql::new(pool.clone(), strat.world_address).await.unwrap();
 
@@ -351,12 +356,12 @@ pub async fn spinup_types_test() -> Result<SqlitePool> {
     let mut engine = Engine::new(
         world,
         db,
-        account.provider(),
+        Arc::clone(&provider),
         Processors {
             event: generate_event_processors_map(vec![
-                Box::new(RegisterModelProcessor),
-                Box::new(StoreSetRecordProcessor),
-                Box::new(StoreDelRecordProcessor),
+                Arc::new(RegisterModelProcessor),
+                Arc::new(StoreSetRecordProcessor),
+                Arc::new(StoreDelRecordProcessor),
             ])
             .unwrap(),
             ..Processors::default()
diff --git a/crates/torii/graphql/src/tests/models_test.rs b/crates/torii/graphql/src/tests/models_test.rs
index 66bda6902b..163d9afc41 100644
--- a/crates/torii/graphql/src/tests/models_test.rs
+++ b/crates/torii/graphql/src/tests/models_test.rs
@@ -169,8 +169,11 @@ mod tests {
         let pool = spinup_types_test().await?;
         let schema = build_schema(&pool).await.unwrap();
 
+        // We need to order all the records, because insertions are done in parallel
+        // and the resulting default order is nondeterministic
         // default params, test entity relationship, test nested types
-        let records = records_model_query(&schema, "").await;
+        let records =
+            records_model_query(&schema, "(order: { direction: DESC, field: RECORD_ID })").await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let record = connection.edges.last().unwrap();
         let entity = record.node.entity.as_ref().unwrap();
@@ -193,25 +196,41 @@ mod tests {
         // *** WHERE FILTER TESTING ***
 
         // where filter EQ on record_id
-        let records = records_model_query(&schema, "(where: { record_id: 0 })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { record_id: 0 }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let first_record = connection.edges.first().unwrap();
         assert_eq!(connection.total_count, 1);
         assert_eq!(first_record.node.type_u8, 0);
 
         // where filter GTE on u16
-        let records = records_model_query(&schema, "(where: { type_u16GTE: 5 })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_u16GTE: 5 }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         assert_eq!(connection.total_count, 5);
 
         // where filter LTE on u32
-        let records = records_model_query(&schema, "(where: { type_u32LTE: 4 })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_u32LTE: 4 }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         assert_eq!(connection.total_count, 5);
 
         // where filter LT and GT
-        let records =
-            records_model_query(&schema, "(where: { type_u32GT: 2, type_u16LT: 4 })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_u32GT: 2, type_u16LT: 4 }, order: { direction: DESC, field: RECORD_ID \
+             })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let first_record = connection.edges.first().unwrap();
         assert_eq!(first_record.node.type_u16, 3);
@@ -224,7 +243,8 @@ mod tests {
         let records = records_model_query(
             &schema,
             &format!(
-                "(where: {{ type_class_hash: \"{}\", type_contract_address: \"{}\" }})",
+                "(where: {{ type_class_hash: \"{}\", type_contract_address: \"{}\" }}, order: {{ \
+                 direction: DESC, field: RECORD_ID }})",
                 felt_str_0x5, felt_int_5
             ),
         )
@@ -234,9 +254,14 @@ mod tests {
         let first_record = connection.edges.first().unwrap();
         assert_eq!(first_record.node.type_class_hash, "0x5");
 
         // where filter EQ on u64 (string)
-        let records =
-            records_model_query(&schema, &format!("(where: {{ type_u64: \"{}\" }})", felt_str_0x5))
-                .await;
+        let records = records_model_query(
+            &schema,
+            &format!(
+                "(where: {{ type_u64: \"{}\" }}, order: {{ direction: DESC, field: RECORD_ID }})",
+                felt_str_0x5
+            ),
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let first_record = connection.edges.first().unwrap();
         assert_eq!(first_record.node.type_u64, "0x5");
 
@@ -244,7 +269,11 @@ mod tests {
         // where filter GTE on u128 (string)
         let records = records_model_query(
             &schema,
-            &format!("(where: {{ type_u128GTE: \"{}\" }})", felt_str_0x5),
+            &format!(
+                "(where: {{ type_u128GTE: \"{}\" }}, order: {{ direction: DESC, field: RECORD_ID \
+                 }})",
+                felt_str_0x5
+            ),
         )
         .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
@@ -257,7 +286,11 @@ mod tests {
         // where filter LT on u256 (string)
         let records = records_model_query(
             &schema,
-            &format!("(where: {{ type_u256LT: \"{}\" }})", felt_int_5),
+            &format!(
+                "(where: {{ type_u256LT: \"{}\" }}, order: {{ direction: DESC, field: RECORD_ID \
+                 }})",
+                felt_int_5
+            ),
         )
         .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
@@ -268,30 +301,42 @@ mod tests {
         assert_eq!(connection.total_count, 5);
         assert_eq!(first_record.node.type_u256, "0x4");
         assert_eq!(last_record.node.type_u256, "0x0");
 
         // where filter on true bool
-        let records = records_model_query(&schema, "(where: { type_bool: true })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_bool: true }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let first_record = connection.edges.first().unwrap();
         assert_eq!(connection.total_count, 5);
         assert!(first_record.node.type_bool, "should be true");
 
         // where filter on false bool
-        let records = records_model_query(&schema, "(where: { type_bool: false })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_bool: false }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let first_record = connection.edges.first().unwrap();
         assert_eq!(connection.total_count, 5);
         assert!(!first_record.node.type_bool, "should be false");
 
         // where filter on In
-        let records =
-            records_model_query(&schema, "(where: { type_feltIN: [\"0x5\", \"0x6\", \"0x7\"] })")
-                .await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_feltIN: [\"0x5\", \"0x6\", \"0x7\"] }, order: { direction: DESC, \
+             field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         assert_eq!(connection.total_count, 3);
 
         // where filter on NotIn
         let records = records_model_query(
             &schema,
-            "(where: { type_feltNOTIN: [\"0x5\", \"0x6\", \"0x7\"] })",
+            "(where: { type_feltNOTIN: [\"0x5\", \"0x6\", \"0x7\"] }, order: { direction: DESC, \
+             field: RECORD_ID })",
         )
         .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
@@ -339,7 +384,11 @@ mod tests {
 
         // *** WHERE FILTER + PAGINATION TESTING ***
 
-        let records = records_model_query(&schema, "(where: { type_u8GTE: 5 })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_u8GTE: 5 }, order: { field: TYPE_U8, direction: DESC })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let one = connection.edges.first().unwrap();
         let two = connection.edges.get(1).unwrap();
         let three = connection.edges.get(2).unwrap();
@@ -348,7 +397,11 @@ mod tests {
         let five = connection.edges.get(4).unwrap();
 
         // cursor based pagination
-        let records = records_model_query(&schema, "(where: { type_u8GTE: 5 }, first: 2)").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_u8GTE: 5 }, first: 2, order: { field: TYPE_U8, direction: DESC })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let first_record = connection.edges.first().unwrap();
         let last_record = connection.edges.last().unwrap();
@@ -359,7 +412,11 @@ mod tests {
 
         let records = records_model_query(
             &schema,
-            &format!("(where: {{ type_u8GTE: 5 }}, first: 3, after: \"{}\")", last_record.cursor),
+            &format!(
+                "(where: {{ type_u8GTE: 5 }}, first: 3, after: \"{}\", order: {{ field: TYPE_U8, \
+                 direction: DESC }})",
+                last_record.cursor
+            ),
         )
         .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
@@ -371,8 +428,12 @@ mod tests {
         assert_eq!(second_record, five);
 
         // offset/limit base pagination
-        let records =
-            records_model_query(&schema, "(where: { type_u8GTE: 5 }, limit: 2, offset: 2)").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { type_u8GTE: 5 }, limit: 2, offset: 2, order: { field: TYPE_U8, direction: \
+             DESC })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         let first_record = connection.edges.first().unwrap();
         let last_record = connection.edges.last().unwrap();
@@ -425,7 +486,8 @@ mod tests {
         assert_eq!(connection.total_count, 10);
 
         // *** SUBRECORD TESTING ***
-        let subrecord = subrecord_model_query(&schema, "").await;
+        let subrecord =
+            subrecord_model_query(&schema, "(order: { direction: DESC, field: RECORD_ID })").await;
         let connection: Connection<Subrecord> = serde_json::from_value(subrecord).unwrap();
         let last_record = connection.edges.first().unwrap();
         assert_eq!(last_record.node.record_id, 18);
@@ -433,17 +495,29 @@ mod tests {
         // *** DELETE TESTING ***
 
         // where filter EQ on record_id, test Record with id 20 is deleted
-        let records = records_model_query(&schema, "(where: { record_id: 20 })").await;
+        let records = records_model_query(
+            &schema,
+            "(where: { record_id: 20 }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Record> = serde_json::from_value(records).unwrap();
         assert_eq!(connection.edges.len(), 0);
 
         // where filter GTE on record_id, test Sibling with id 20 is deleted
-        let sibling = record_sibling_query(&schema, "(where: { record_id: 20 })").await;
+        let sibling = record_sibling_query(
+            &schema,
+            "(where: { record_id: 20 }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<RecordSibling> = serde_json::from_value(sibling).unwrap();
         assert_eq!(connection.edges.len(), 0);
 
         // where filter GTE on record_id, test Subrecord with id 20 is deleted
-        let subrecord = subrecord_model_query(&schema, "(where: { record_id: 20 })").await;
+        let subrecord = subrecord_model_query(
+            &schema,
+            "(where: { record_id: 20 }, order: { direction: DESC, field: RECORD_ID })",
+        )
+        .await;
         let connection: Connection<Subrecord> = serde_json::from_value(subrecord).unwrap();
         assert_eq!(connection.edges.len(), 0);
diff --git a/crates/torii/grpc/src/server/tests/entities_test.rs b/crates/torii/grpc/src/server/tests/entities_test.rs
index f1d60f80c8..ba35e24b8a 100644
--- a/crates/torii/grpc/src/server/tests/entities_test.rs
+++ b/crates/torii/grpc/src/server/tests/entities_test.rs
@@ -104,8 +104,8 @@ async fn test_entities_queries() {
         Arc::clone(&provider),
         Processors {
             event: generate_event_processors_map(vec![
-                Box::new(RegisterModelProcessor),
-                Box::new(StoreSetRecordProcessor),
+                Arc::new(RegisterModelProcessor),
+                Arc::new(StoreSetRecordProcessor),
             ])
             .unwrap(),
             ..Processors::default()
diff --git a/crates/torii/libp2p/src/tests.rs b/crates/torii/libp2p/src/tests.rs
index 552b240590..7ef1472068 100644
--- a/crates/torii/libp2p/src/tests.rs
+++ b/crates/torii/libp2p/src/tests.rs
@@ -588,6 +588,7 @@ mod test {
         )
         .await
         .unwrap();
+        db.execute().await.unwrap();
 
         // Initialize the relay server
         let mut relay_server = Relay::new(db, provider, 9900, 9901, 9902, None, None)?;