From 133fd3a495fad2c6740876b8b1350c30c7e51fc3 Mon Sep 17 00:00:00 2001 From: Nasr Date: Thu, 12 Sep 2024 15:10:26 -0400 Subject: [PATCH] fix: member clause deep fields --- crates/torii/grpc/src/server/mod.rs | 330 +++++++++++++--------------- 1 file changed, 153 insertions(+), 177 deletions(-) diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index c737858197..015414ef45 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -258,8 +258,7 @@ impl DojoWorld { "# ); // total count of rows without limit and offset - let total_count: u32 = - sqlx::query_scalar(&count_query).fetch_optional(&self.pool).await?.unwrap_or(0); + let total_count: u32 = sqlx::query_scalar(&count_query).fetch_one(&self.pool).await?; if total_count == 0 { return Ok((Vec::new(), 0)); @@ -377,11 +376,8 @@ impl DojoWorld { } ); - let total_count = sqlx::query_scalar(&count_query) - .bind(&keys_pattern) - .fetch_optional(&self.pool) - .await? - .unwrap_or(0); + let total_count = + sqlx::query_scalar(&count_query).bind(&keys_pattern).fetch_one(&self.pool).await?; if total_count == 0 { return Ok((Vec::new(), 0)); @@ -502,15 +498,53 @@ impl DojoWorld { limit: Option, offset: Option, ) -> Result<(Vec, u32), Error> { - let (where_clause, join_clause, having_clause, comparison_value, model_id) = - build_member_clause(table).await?; + let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) + .expect("invalid comparison operator"); + + let primitive: Primitive = + member_clause.value.ok_or(QueryError::MissingParam("value".into()))?.try_into()?; - let schemas = self.fetch_schemas(table, model_relation_table, model_id).await?; + let comparison_value = primitive.to_sql_value()?; + + let (namespace, model) = member_clause + .model + .split_once('-') + .ok_or(QueryError::InvalidNamespacedModel(member_clause.model.clone()))?; + + let models_query = format!( + r#" + SELECT group_concat({model_relation_table}.model_id) as model_ids + FROM {table} + JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id + GROUP BY {table}.id + HAVING INSTR(model_ids, '{:#x}') > 0 + LIMIT 1 + "#, + compute_selector_from_names(namespace, model) + ); + let (models_str,): (String,) = sqlx::query_as(&models_query).fetch_one(&self.pool).await?; + + let model_ids = models_str + .split(',') + .map(Felt::from_str) + .collect::, _>>() + .map_err(ParseError::FromStr)?; + let schemas = + self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); + + let model = member_clause.model.clone(); + let parts: Vec<&str> = member_clause.member.split('.').collect(); + let (table_name, column_name) = if parts.len() > 1 { + let nested_table = parts[..parts.len() - 1].join("$"); + (format!("{model}${nested_table}"), format!("external_{}", parts.last().unwrap())) + } else { + (format!("{model}"), format!("external_{}", member_clause.member)) + }; let (entity_query, arrays_queries, count_query) = build_sql_query( &schemas, table, entity_relation_column, - Some(&where_clause), + Some(&format!("[{table_name}].{column_name} {comparison_operator} ?")), None, limit, offset, @@ -518,19 +552,19 @@ impl DojoWorld { let total_count = sqlx::query_scalar(&count_query) .bind(comparison_value.clone()) - .fetch_optional(&self.pool) - .await? - .unwrap_or(0); + .fetch_one(&self.pool) + .await?; let db_entities = sqlx::query(&entity_query) - .bind(&comparison_value) + .bind(comparison_value.clone()) .bind(limit) .bind(offset) .fetch_all(&self.pool) .await?; let mut arrays_rows = HashMap::new(); for (name, query) in arrays_queries { - let rows = sqlx::query(&query).bind(&comparison_value).fetch_all(&self.pool).await?; + let rows = + sqlx::query(&query).bind(comparison_value.clone()).fetch_all(&self.pool).await?; arrays_rows.insert(name, rows); } @@ -550,21 +584,9 @@ impl DojoWorld { limit: Option, offset: Option, ) -> Result<(Vec, u32), Error> { - let (where_clause, having_clause, join_clause, bind_values, model_ids) = + let (where_clause, having_clause, join_clause, bind_values) = self.build_composite_clause(table, model_relation_table, &composite)?; - let schemas = self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); - - let (entity_query, arrays_queries, count_query) = build_sql_query( - &schemas, - table, - entity_relation_column, - Some(&where_clause), - Some(&having_clause), - limit, - offset, - )?; - let count_query = format!( r#" SELECT COUNT(DISTINCT [{table}].id) @@ -581,7 +603,7 @@ impl DojoWorld { count_query = count_query.bind(value); } - let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0); + let total_count = count_query.fetch_one(&self.pool).await?; if total_count == 0 { return Ok((Vec::new(), 0)); @@ -642,6 +664,107 @@ impl DojoWorld { Ok((entities, total_count)) } + fn build_composite_clause( + &self, + table: &str, + model_relation_table: &str, + composite: &proto::types::CompositeClause, + ) -> Result<(String, String, String, Vec), Error> { + let is_or = composite.operator == LogicalOperator::Or as i32; + let mut where_clauses = Vec::new(); + let mut join_clauses = Vec::new(); + let mut having_clauses = Vec::new(); + let mut bind_values = Vec::new(); + + for clause in &composite.clauses { + match clause.clause_type.as_ref().unwrap() { + ClauseType::HashedKeys(hashed_keys) => { + let ids = hashed_keys + .hashed_keys + .iter() + .map(|id| { + bind_values.push(Felt::from_bytes_be_slice(id).to_string()); + "?".to_string() + }) + .collect::>() + .join(", "); + where_clauses.push(format!("{table}.id IN ({})", ids)); + } + ClauseType::Keys(keys) => { + let keys_pattern = build_keys_pattern(keys)?; + bind_values.push(keys_pattern); + where_clauses.push(format!("{table}.keys REGEXP ?")); + } + ClauseType::Member(member) => { + let comparison_operator = + ComparisonOperator::from_repr(member.operator as usize) + .expect("invalid comparison operator"); + let value: Primitive = member.value.as_ref().unwrap().clone().try_into()?; + let comparison_value = value.to_sql_value()?; + bind_values.push(comparison_value); + + let model = member.model.clone(); + let parts: Vec<&str> = member.member.split('.').collect(); + let (table_name, column_name) = if parts.len() > 1 { + let nested_table = parts[..parts.len() - 1].join("$"); + ( + format!("[{model}${nested_table}]"), + format!("external_{}", parts.last().unwrap()), + ) + } else { + (format!("[{model}]"), format!("external_{}", member.member)) + }; + + let (namespace, model) = member + .model + .split_once('-') + .ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?; + let model_id = compute_selector_from_names(namespace, model); + join_clauses.push(format!( + "LEFT JOIN {table_name} ON [{table}].id = {table_name}.entity_id" + )); + where_clauses + .push(format!("{table_name}.{column_name} {comparison_operator} ?")); + having_clauses.push(format!( + "INSTR(group_concat({model_relation_table}.model_id), '{:#x}') > 0", + model_id + )); + } + ClauseType::Composite(nested_composite) => { + let (nested_where, nested_having, nested_join, nested_values) = + self.build_composite_clause(table, model_relation_table, nested_composite)?; + where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE "))); + if !nested_having.is_empty() { + having_clauses + .push(nested_having.trim_start_matches("HAVING ").to_string()); + } + join_clauses.extend( + nested_join + .split_whitespace() + .filter(|&s| s.starts_with("LEFT")) + .map(String::from), + ); + bind_values.extend(nested_values); + } + _ => return Err(QueryError::UnsupportedQuery.into()), + } + } + + let join_clause = join_clauses.join(" "); + let where_clause = if !where_clauses.is_empty() { + format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " })) + } else { + String::new() + }; + let having_clause = if !having_clauses.is_empty() { + format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " })) + } else { + String::new() + }; + + Ok((where_clause, having_clause, join_clause, bind_values)) + } + pub async fn model_metadata( &self, namespace: &str, @@ -865,39 +988,6 @@ impl DojoWorld { ) -> Result>, Error> { self.event_manager.add_subscriber(clause.into()).await } - - async fn fetch_schemas( - &self, - table: &str, - model_relation_table: &str, - model_id: Felt, - ) -> Result, Error> { - let models_query = format!( - r#" - SELECT group_concat({model_relation_table}.model_id) as model_ids - FROM {table} - JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id - GROUP BY {table}.id - HAVING INSTR(model_ids, '{:#x}') > 0 - LIMIT 1 - "#, - model_id - ); - let (models_str,): (String,) = - sqlx::query_as(&models_query).fetch_optional(&self.pool).await?; - if models_str.is_none() { - return Ok(vec![]); - } - - let models_str = models_str.unwrap(); - let model_ids = models_str - .split(',') - .map(Felt::from_str) - .collect::, _>>() - .map_err(ParseError::FromStr)?; - - Ok(self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect()) - } } fn process_event_field(data: &str) -> Result>, Error> { @@ -956,120 +1046,6 @@ fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result Result<(String, String, String, String, Felt), Error> { - let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) - .expect("invalid comparison operator"); - - let primitive: Primitive = member_clause - .value - .as_ref() - .ok_or(QueryError::MissingParam("value".into()))? - .clone() - .try_into()?; - - let comparison_value = primitive.to_sql_value()?; - - let (namespace, model) = member_clause - .model - .split_once('-') - .ok_or(QueryError::InvalidNamespacedModel(member_clause.model.clone()))?; - - let model_id = compute_selector_from_names(namespace, model); - - let table_name = &member_clause.model; - let parts: Vec<&str> = member_clause.member.split('.').collect(); - let (join_table_name, column_name) = if parts.len() > 1 { - let nested_table = parts[..parts.len() - 1].join("$"); - (format!("[{table_name}${nested_table}]"), format!("external_{}", parts.last().unwrap())) - } else { - (format!("[{table_name}]"), format!("external_{}", member_clause.member)) - }; - - let where_clause = format!("{join_table_name}.{column_name} {comparison_operator} ?"); - let join_clause = - format!("LEFT JOIN {join_table_name} ON [{{table}}].id = {join_table_name}.entity_id"); - let having_clause = - format!("INSTR(group_concat({{model_relation_table}}.model_id), '{:#x}') > 0", model_id); - - Ok((where_clause, join_clause, having_clause, comparison_value, model_id)) -} - -fn build_composite_clause( - &self, - composite: &proto::types::CompositeClause, -) -> Result<(String, String, String, Vec, Vec), Error> { - let is_or = composite.operator == LogicalOperator::Or as i32; - let mut where_clauses = Vec::new(); - let mut join_clauses = Vec::new(); - let mut having_clauses = Vec::new(); - let mut bind_values = Vec::new(); - let mut model_ids = Vec::new(); - - for clause in &composite.clauses { - match clause.clause_type.as_ref().unwrap() { - ClauseType::HashedKeys(hashed_keys) => { - let ids = hashed_keys - .hashed_keys - .iter() - .map(|id| { - bind_values.push(Felt::from_bytes_be_slice(id).to_string()); - "?".to_string() - }) - .collect::>() - .join(", "); - where_clauses.push(format!("{{table}}.id IN ({})", ids)); - } - ClauseType::Keys(keys) => { - let keys_pattern = build_keys_pattern(keys)?; - bind_values.push(keys_pattern); - where_clauses.push(format!("{{table}}.keys REGEXP ?")); - } - ClauseType::Member(member) => { - let (member_where, member_join, member_having, member_value, member_model_id) = - self.build_member_clause(member)?; - where_clauses.push(member_where); - join_clauses.push(member_join); - having_clauses.push(member_having); - bind_values.push(member_value); - model_ids.push(member_model_id); - } - ClauseType::Composite(nested_composite) => { - let (nested_where, nested_having, nested_join, nested_values, nested_model_ids) = - self.build_composite_clause(nested_composite)?; - where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE "))); - if !nested_having.is_empty() { - having_clauses.push(nested_having.trim_start_matches("HAVING ").to_string()); - } - join_clauses.extend( - nested_join - .split_whitespace() - .filter(|&s| s.starts_with("LEFT")) - .map(String::from), - ); - bind_values.extend(nested_values); - model_ids.extend(nested_model_ids); - } - _ => return Err(QueryError::UnsupportedQuery.into()), - } - } - - let join_clause = join_clauses.join(" "); - let where_clause = if !where_clauses.is_empty() { - format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " })) - } else { - String::new() - }; - let having_clause = if !having_clauses.is_empty() { - format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " })) - } else { - String::new() - }; - - Ok((where_clause, having_clause, join_clause, bind_values, model_ids)) -} - type ServiceResult = Result, Status>; type SubscribeModelsResponseStream = Pin> + Send>>;