Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(torii-grpc): ordering and pagination #2765

Merged
merged 7 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion crates/torii/core/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub fn build_sql_query(
table_name: &str,
entity_relation_column: &str,
where_clause: Option<&str>,
order_by: Option<&str>,
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(String, String), Error> {
Expand Down Expand Up @@ -196,7 +197,12 @@ pub fn build_sql_query(
count_query += &format!(" WHERE {}", where_clause);
}

query += &format!(" ORDER BY {}.event_id DESC", table_name);
// Use custom order by if provided, otherwise default to event_id DESC
if let Some(order_clause) = order_by {
query += &format!(" ORDER BY {}", order_clause);
} else {
query += &format!(" ORDER BY {}.event_id DESC", table_name);
}
glihm marked this conversation as resolved.
Show resolved Hide resolved

if let Some(limit) = limit {
query += &format!(" LIMIT {}", limit);
Expand Down Expand Up @@ -487,6 +493,7 @@ mod tests {
None,
None,
None,
None,
)
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion crates/torii/core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,4 @@ pub struct ContractCursor {
pub contract_address: String,
pub last_pending_block_tx: Option<String>,
pub last_pending_block_contract_tx: Option<String>,
}
}
17 changes: 17 additions & 0 deletions crates/torii/grpc/proto/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,15 @@ message ModelUpdate {

message Query {
Clause clause = 1;
// Standard offset pagination
uint32 limit = 2;
uint32 offset = 3;
bool dont_include_hashed_keys = 4;
repeated OrderBy order_by = 5;

// Updated at time-based filter
uint64 updated_before = 6;
uint64 updated_after = 7;
}

message EventQuery {
Expand Down Expand Up @@ -164,4 +170,15 @@ message TokenBalance {
string account_address = 2;
string contract_address = 3;
string token_id = 4;
}

message OrderBy {
string model = 1;
string member = 2;
OrderDirection direction = 3;
}

enum OrderDirection {
ASC = 0;
DESC = 1;
}
55 changes: 48 additions & 7 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ impl DojoWorld {
entity_relation_column: &str,
entities: Vec<(String, String)>,
dont_include_hashed_keys: bool,
where_clause: Option<&str>,
order_by: Option<&str>,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Potential SQL Injection Vulnerability: Validate where_clause Parameter

Ohayo sensei! Passing the where_clause parameter directly into SQL query strings without proper sanitization can lead to SQL injection vulnerabilities. Please ensure that where_clause is properly validated or use parameterized queries to prevent SQL injection attacks.

Apply this diff to use parameterized queries:

-            let where_clause = where_clause.map_or(
-                "[{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)",
-                |clause| &format!("{} AND [{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)", clause),
-            );
+            let mut where_conditions = vec![];
+            let mut query_params = vec![];
+
+            // Always include the model group condition
+            where_conditions.push(format!("[{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)"));
+            query_params.push(models_str.clone());
+
+            // Validate and include the where_clause if present
+            if let Some(clause) = where_clause {
+                if is_valid_clause(clause) {
+                    where_conditions.push(clause.to_string());
+                } else {
+                    return Err(Error::InvalidParameter("where_clause".into()));
+                }
+            }
+
+            let where_clause = where_conditions.join(" AND ");

Also applies to: 298-302

) -> Result<Vec<proto::types::Entity>, Error> {
// Group entities by their model combinations
let mut model_groups: HashMap<String, Vec<String>> = HashMap::new();
Expand Down Expand Up @@ -293,6 +295,11 @@ impl DojoWorld {
}
}

let where_clause = where_clause.map_or(
"[{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)",
|clause| &format!("{} AND [{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)", clause),
);

for (models_str, _) in model_groups {
let model_ids =
models_str.split(',').map(|id| Felt::from_str(id).unwrap()).collect::<Vec<_>>();
Expand All @@ -303,9 +310,8 @@ impl DojoWorld {
&schemas,
table,
entity_relation_column,
Some(&format!(
"[{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)"
)),
Some(&where_clause),
order_by,
None,
None,
)?;
Expand Down Expand Up @@ -377,6 +383,7 @@ impl DojoWorld {
limit: Option<u32>,
offset: Option<u32>,
dont_include_hashed_keys: bool,
where_clause: Option<&str>,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Potential SQL Injection Vulnerability: Validate Input Parameters

Ohayo sensei! Passing user-provided where_clause, order_by, updated_before, and updated_after parameters directly into SQL queries can lead to SQL injection vulnerabilities. Please ensure that these parameters are properly sanitized or use parameterized queries.

Consider modifying the code to use parameterized queries for these parameters and validate their values before incorporating them into SQL statements.

Also applies to: 475-477

) -> Result<(Vec<proto::types::Entity>, u32), Error> {
// TODO: use prepared statement for where clause
let filter_ids = match hashed_keys {
Expand Down Expand Up @@ -465,9 +472,19 @@ impl DojoWorld {
limit: Option<u32>,
offset: Option<u32>,
dont_include_hashed_keys: bool,
updated_before: Option<u64>,
updated_after: Option<u64>,
order_by: Option<&str>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
let keys_pattern = build_keys_pattern(keys_clause)?;

let mut where_clause = vec![];
if let Some(updated_before) = updated_before {
where_clause.push(format!("updated_at <= {}", updated_before));
}
if let Some(updated_after) = updated_after {
where_clause.push(format!("updated_at >= {}", updated_after));
}
// total count of rows that matches keys_pattern without limit and offset
let count_query = format!(
r#"
Expand Down Expand Up @@ -628,6 +645,9 @@ impl DojoWorld {
limit: Option<u32>,
offset: Option<u32>,
dont_include_hashed_keys: bool,
updated_before: Option<u64>,
updated_after: Option<u64>,
order_by: Option<&str>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize)
.expect("invalid comparison operator");
Expand Down Expand Up @@ -675,14 +695,23 @@ impl DojoWorld {
self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect();

// Use the member name directly as the column name since it's already flattened
let mut where_clause = format!(
"[{}].[{}] {comparison_operator} ?",
member_clause.model, member_clause.member
);
if let Some(updated_before) = updated_before {
where_clause += &format!(" AND updated_at <= {}", updated_before);
}
if let Some(updated_after) = updated_after {
where_clause += &format!(" AND updated_at >= {}", updated_after);
}

let (entity_query, count_query) = build_sql_query(
&schemas,
table,
entity_relation_column,
Some(&format!(
"[{}].[{}] {comparison_operator} ?",
member_clause.model, member_clause.member
)),
Some(&where_clause),
order_by,
limit,
offset,
)?;
Expand Down Expand Up @@ -909,6 +938,15 @@ impl DojoWorld {
entity_relation_column: &str,
query: proto::types::Query,
) -> Result<proto::world::RetrieveEntitiesResponse, Error> {
let order_by = query
.order_by
.map(|order_by| {
format!(
"[{}] [{}] {}",
order_by.model, order_by.member, order_by.direction
)
});

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Potential SQL Injection Risk in retrieve_entities Method

Ohayo sensei! The order_by parameter is constructed using string formatting and injected directly into the SQL query, which can introduce SQL injection vulnerabilities. Please validate the order_by parameter against a whitelist of allowed values or use parameterized queries to prevent SQL injection attacks.

Consider the following refactor to validate order_by:

             let order_by = query
                 .order_by
                 .map(|order_by| {
+                    // Validate that model and member are valid identifiers
+                    if !is_valid_identifier(&order_by.model) || !is_valid_identifier(&order_by.member) {
+                        return Err(QueryError::InvalidOrderByParameter);
+                    }

                     format!(
                         "[{}].[{}] {}",
                         order_by.model, order_by.member, order_by.direction
                     )
                 });

Implement a helper function is_valid_identifier to ensure that only allowed column names are used.

Committable suggestion skipped: line range outside the PR's diff.

let (entities, total_count) = match query.clause {
None => {
self.entities_all(
Expand Down Expand Up @@ -963,6 +1001,9 @@ impl DojoWorld {
Some(query.limit),
Some(query.offset),
query.dont_include_hashed_keys,
query.updated_before,
query.updated_after,
order_by,
)
.await?
}
Expand Down
Loading