diff --git a/Cargo.toml b/Cargo.toml index fc5dffa02..a10d74ae6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ tracing = { version = "0.1", default-features = false, features = ["attributes", rust_decimal = { version = "1", default-features = false, optional = true } bigdecimal = { version = "0.3", default-features = false, optional = true } sea-orm-macros = { version = "0.12.0-rc.3", path = "sea-orm-macros", default-features = false, features = ["strum"] } -sea-query = { version = "0.29.0-rc.2", features = ["thread-safe"] } +sea-query = { version = "0.29.0-rc.2", features = ["thread-safe", "hashable-value"] } sea-query-binder = { version = "0.4.0-rc.2", default-features = false, optional = true } strum = { version = "0.24", default-features = false } serde = { version = "1.0", default-features = false } diff --git a/src/executor/select.rs b/src/executor/select.rs index 73815f1e2..42c994dcc 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -4,7 +4,8 @@ use crate::{ SelectB, SelectTwo, SelectTwoMany, Statement, StreamTrait, TryGetableMany, }; use futures::{Stream, TryStreamExt}; -use sea_query::SelectStatement; +use sea_query::{SelectStatement, Value}; +use std::collections::HashMap; use std::marker::PhantomData; use std::pin::Pin; @@ -991,37 +992,39 @@ where } fn consolidate_query_result( - rows: Vec<(L::Model, Option)>, + mut rows: Vec<(L::Model, Option)>, ) -> Vec<(L::Model, Vec)> where L: EntityTrait, R: EntityTrait, { - let mut acc: Vec<(L::Model, Vec)> = Vec::new(); - for (l, r) in rows { - if let Some((last_l, last_r)) = acc.last_mut() { - let mut same_l = true; - for pk_col in ::iter() { - let col = pk_col.into_column(); - let val = l.get(col); - let last_val = last_l.get(col); - if !val.eq(&last_val) { - same_l = false; - break; - } - } - if same_l { - if let Some(r) = r { - last_r.push(r); - continue; - } + //todo: could take not iter + let pkcol = ::iter() + .next() + .expect("should have primary key") + .into_column(); + + let mut hashmap: HashMap> = rows.iter_mut().fold( + HashMap::>::new(), + |mut acc: HashMap>, row: &mut (L::Model, Option)| { + let key = row.0.get(pkcol); + let value = row.1.take().expect("should have a linked entity"); + let vec: Option<&mut Vec> = acc.get_mut(&key); + if let Some(vec) = vec { + vec.push(value) + } else { + acc.insert(key, vec![value]); } - } - let rows = match r { - Some(r) => vec![r], - None => vec![], - }; - acc.push((l, rows)); - } - acc + + acc + }, + ); + + rows.into_iter() + .filter_map(|(l_model, _)| { + let l_pk = l_model.get(pkcol); + let r_models = hashmap.remove(&l_pk); + r_models.map(|r_models| (l_model, r_models)) + }) + .collect() } diff --git a/src/query/combine.rs b/src/query/combine.rs index 9ffb8e6ed..61d957f74 100644 --- a/src/query/combine.rs +++ b/src/query/combine.rs @@ -119,12 +119,16 @@ where F: EntityTrait, { pub(crate) fn new(query: SelectStatement) -> Self { + Self::new_without_prepare(query) + .prepare_select() + .prepare_order_by() + } + + pub(crate) fn new_without_prepare(query: SelectStatement) -> Self { Self { query, entity: PhantomData, } - .prepare_select() - .prepare_order_by() } fn prepare_select(mut self) -> Self { diff --git a/src/query/join.rs b/src/query/join.rs index 09f2e3857..6ed56de0c 100644 --- a/src/query/join.rs +++ b/src/query/join.rs @@ -107,6 +107,52 @@ where } select_two } + + /// Left Join with a Linked Entity and select Entity as a `Vec`. + pub fn find_with_linked(self, l: L) -> SelectTwoMany + where + L: Linked, + T: EntityTrait, + { + let mut slf = self; + for (i, mut rel) in l.link().into_iter().enumerate() { + let to_tbl = Alias::new(format!("r{i}")).into_iden(); + let from_tbl = if i > 0 { + Alias::new(format!("r{}", i - 1)).into_iden() + } else { + unpack_table_ref(&rel.from_tbl) + }; + let table_ref = rel.to_tbl; + + let mut condition = Condition::all().add(join_tbl_on_condition( + SeaRc::clone(&from_tbl), + SeaRc::clone(&to_tbl), + rel.from_col, + rel.to_col, + )); + if let Some(f) = rel.on_condition.take() { + condition = condition.add(f(SeaRc::clone(&from_tbl), SeaRc::clone(&to_tbl))); + } + + slf.query() + .join_as(JoinType::LeftJoin, table_ref, to_tbl, condition); + } + slf = slf.apply_alias(SelectA.as_str()); + let mut select_two_many = SelectTwoMany::new_without_prepare(slf.query); + for col in ::iter() { + let alias = format!("{}{}", SelectB.as_str(), col.as_str()); + let expr = Expr::col(( + Alias::new(format!("r{}", l.link().len() - 1)).into_iden(), + col.into_iden(), + )); + select_two_many.query().expr(SelectExpr { + expr: col.select_as(expr), + alias: Some(SeaRc::new(Alias::new(alias))), + window: None, + }); + } + select_two_many + } } #[cfg(test)] diff --git a/tests/relational_tests.rs b/tests/relational_tests.rs index 959f05577..6b0c6950b 100644 --- a/tests/relational_tests.rs +++ b/tests/relational_tests.rs @@ -2,6 +2,7 @@ pub mod common; pub use chrono::offset::Utc; pub use common::{bakery_chain::*, setup::*, TestContext}; +use pretty_assertions::assert_eq; pub use rust_decimal::prelude::*; pub use rust_decimal_macros::dec; use sea_orm::{entity::*, query::*, DbErr, DerivePartialModel, FromQueryResult}; @@ -747,6 +748,85 @@ pub async fn linked() -> Result<(), DbErr> { }] ); + let select_baker_with_customer = Baker::find() + .find_with_linked(baker::BakedForCustomer) + .order_by_asc(baker::Column::Id) + .order_by_asc(Expr::col((Alias::new("r4"), customer::Column::Id))); + + assert_eq!( + select_baker_with_customer + .build(sea_orm::DatabaseBackend::MySql) + .to_string(), + [ + // FIXME: This might be faulty! + "SELECT `baker`.`id` AS `A_id`,", + "`baker`.`name` AS `A_name`,", + "`baker`.`contact_details` AS `A_contact_details`,", + "`baker`.`bakery_id` AS `A_bakery_id`,", + "`r4`.`id` AS `B_id`,", + "`r4`.`name` AS `B_name`,", + "`r4`.`notes` AS `B_notes`", + "FROM `baker`", + "LEFT JOIN `cakes_bakers` AS `r0` ON `baker`.`id` = `r0`.`baker_id`", + "LEFT JOIN `cake` AS `r1` ON `r0`.`cake_id` = `r1`.`id`", + "LEFT JOIN `lineitem` AS `r2` ON `r1`.`id` = `r2`.`cake_id`", + "LEFT JOIN `order` AS `r3` ON `r2`.`order_id` = `r3`.`id`", + "LEFT JOIN `customer` AS `r4` ON `r3`.`customer_id` = `r4`.`id`", + "ORDER BY `baker`.`id` ASC, `r4`.`id` ASC" + ] + .join(" ") + ); + + assert_eq!( + select_baker_with_customer.all(&ctx.db).await?, + [ + ( + baker::Model { + id: 1, + name: "Baker Bob".into(), + contact_details: serde_json::json!({ + "mobile": "+61424000000", + "home": "0395555555", + "address": "12 Test St, Testville, Vic, Australia", + }), + bakery_id: Some(1), + }, + vec![customer::Model { + id: 2, + name: "Kara".into(), + notes: Some("Loves all cakes".into()), + }] + ), + ( + baker::Model { + id: 2, + name: "Baker Bobby".into(), + contact_details: serde_json::json!({ + "mobile": "+85212345678", + }), + bakery_id: Some(1), + }, + vec![ + customer::Model { + id: 1, + name: "Kate".into(), + notes: Some("Loves cheese cake".into()), + }, + customer::Model { + id: 1, + name: "Kate".into(), + notes: Some("Loves cheese cake".into()), + }, + customer::Model { + id: 2, + name: "Kara".into(), + notes: Some("Loves all cakes".into()), + }, + ] + ), + ] + ); + ctx.delete().await; Ok(())