From c7def77c0d0758156be2add76dd64ddab43ced73 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 19 Nov 2024 18:27:05 +0100 Subject: [PATCH] almost there... just need to select key columns and hook it up --- crates/polars-core/src/datatypes/field.rs | 6 + crates/polars-core/src/frame/mod.rs | 9 ++ .../polars-expr/src/chunked_idx_table/mod.rs | 2 +- .../src/chunked_idx_table/row_encoded.rs | 10 +- .../src/nodes/joins/equi_join.rs | 126 ++++++++++++++---- crates/polars-stream/src/physical_plan/fmt.rs | 8 +- crates/polars-stream/src/physical_plan/mod.rs | 10 +- .../src/physical_plan/to_graph.rs | 25 ++++ .../polars-utils/src/idx_map/bytes_idx_map.rs | 8 ++ 9 files changed, 175 insertions(+), 29 deletions(-) diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index b85caeec0a2e..7ff81d7277ea 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -96,6 +96,12 @@ impl Field { pub fn set_name(&mut self, name: PlSmallStr) { self.name = name; } + + /// Returns this `Field`, renamed. + pub fn with_name(mut self, name: PlSmallStr) -> Self { + self.name = name; + self + } /// Converts the `Field` to an `arrow::datatypes::Field`. /// diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index a09350242ae6..eb14812fc376 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -380,6 +380,15 @@ impl DataFrame { unsafe { DataFrame::new_no_checks(0, cols) } } + /// Create a new `DataFrame` with the given schema, only containing nulls. + pub fn full_null(schema: &Schema, height: usize) -> Self { + let columns = schema + .iter_fields() + .map(|f| Column::full_null(f.name.clone(), height, f.dtype())) + .collect(); + DataFrame { height, columns } + } + /// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty. /// /// # Example diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index 14c7e036bb31..141072083d8a 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -55,7 +55,7 @@ pub trait ChunkedIdxTable: Any + Send + Sync { ) -> IdxSize; /// Get the ChunkIds for each key which was never marked during probing. - fn unmarked_keys(&self, out: &mut Vec>); + fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize); } pub fn new_chunked_idx_table(key_schema: Arc) -> Box { diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index 950e6a03844a..d82bb7d4425c 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -252,8 +252,10 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } } - fn unmarked_keys(&self, out: &mut Vec>) { - for chunk_ids in self.idx_map.iter_values() { + fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize) { + out.clear(); + + while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset) { let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) }; let first_chunk_val = first_chunk_id.load(Ordering::Acquire); if first_chunk_val >> 63 == 0 { @@ -263,6 +265,10 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { out.push(chunk_id); } } + + if out.len() >= limit as usize { + break; + } } } } diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 73af8e73cd8b..4087b4796856 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use polars_core::prelude::{PlHashSet, PlRandomState}; -use polars_core::schema::Schema; +use polars_core::schema::{Schema, SchemaExt}; use polars_core::series::IsSorted; use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_expr::chunked_idx_table::{new_chunked_idx_table, ChunkedIdxTable}; @@ -15,7 +15,8 @@ use polars_utils::{format_pl_smallstr, IdxSize}; use rayon::prelude::*; use crate::async_primitives::connector::{Receiver, Sender}; -use crate::morsel::get_ideal_morsel_size; +use crate::async_primitives::wait_group::WaitGroup; +use crate::morsel::{get_ideal_morsel_size, SourceToken}; use crate::nodes::compute_node_prelude::*; /// A payload selector contains for each column whether that column should be @@ -56,6 +57,13 @@ fn compute_payload_selector( .collect() } +fn select_schema(schema: &Schema, selector: &[Option]) -> Schema { + schema.iter_fields() + .zip(selector) + .filter_map(|(f, name)| Some(f.with_name(name.clone()?))) + .collect() +} + fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { // Maintain height of zero-width dataframes. if df.width() == 0 { @@ -248,43 +256,77 @@ impl ProbeState { Ok(()) } +} + +struct EmitUnmatchedState { + partitions: Vec, + active_partition_idx: usize, + offset_in_active_p: usize, +} +impl EmitUnmatchedState { async fn emit_unmatched( + &mut self, mut send: Sender, - partitions: &[ProbeTable], params: &EquiJoinParams, + num_pipelines: usize, ) -> PolarsResult<()> { + let total_len: usize = self.partitions.iter().map(|p| p.table.num_keys() as usize).sum(); + let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1); + let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); + let morsel_size = total_len.div_ceil(morsel_count).max(1); + + let mut morsel_seq = MorselSeq::default(); + let wait_group = WaitGroup::default(); let source_token = SourceToken::new(); let mut unmarked_idxs = Vec::new(); - unsafe { - for p in partitions { - p.table.unmarked_keys(&mut unmarked_idxs); - let build_df = p.df.take_chunked_unchecked(&table_match, IsSorted::Not); + while let Some(p) = self.partitions.get(self.active_partition_idx) { + loop { + p.table.unmarked_keys(&mut unmarked_idxs, self.offset_in_active_p as IdxSize, morsel_size as IdxSize); + self.offset_in_active_p += unmarked_idxs.len(); + if unmarked_idxs.is_empty() { + break; + } - let out_df = if params.left_is_build { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df + let out_df = unsafe { + let mut build_df = p.df.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not); + let len = build_df.height(); + if params.left_is_build { + let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, len); + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + } }; + let mut morsel = Morsel::new(out_df, morsel_seq, source_token.clone()); + morsel_seq = morsel_seq.successor(); + morsel.set_consume_token(wait_group.token()); + if send.send(morsel).await.is_err() { + return Ok(()); + } - - let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1); - let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); - self.morsel_size = len.div_ceil(morsel_count).max(1); - - + wait_group.wait().await; + if source_token.stop_requested() { + return Ok(()); + } } + + self.active_partition_idx += 1; + self.offset_in_active_p = 0; } + + Ok(()) } } enum EquiJoinState { Build(BuildState), Probe(ProbeState), - EmitUnmatchedBuild(ProbeState), + EmitUnmatchedBuild(EmitUnmatchedState), Done, } @@ -292,6 +334,8 @@ struct EquiJoinParams { left_is_build: bool, left_payload_select: Vec>, right_payload_select: Vec>, + left_payload_schema: Schema, + right_payload_schema: Schema, args: JoinArgs, random_state: PlRandomState, } @@ -341,6 +385,9 @@ impl EquiJoinNode { compute_payload_selector(&left_input_schema, &right_input_schema, true, &args); let right_payload_select = compute_payload_selector(&right_input_schema, &left_input_schema, false, &args); + + let left_payload_schema = select_schema(&left_input_schema, &left_payload_select); + let right_payload_schema = select_schema(&right_input_schema, &right_payload_select); Self { state: EquiJoinState::Build(BuildState { partitions_per_worker: Vec::new(), @@ -350,6 +397,8 @@ impl EquiJoinNode { left_is_build, left_payload_select, right_payload_select, + left_payload_schema, + right_payload_schema, args, random_state: PlRandomState::new(), }, @@ -373,9 +422,8 @@ impl ComputeNode for EquiJoinNode { let build_idx = if self.params.left_is_build { 0 } else { 1 }; let probe_idx = 1 - build_idx; - // If the output doesn't want any more data, or the probe side is done, - // transition to being done. - if send[0] == PortState::Done || recv[probe_idx] == PortState::Done { + // If the output doesn't want any more data, transition to being done. + if send[0] == PortState::Done { self.state = EquiJoinState::Done; } @@ -385,6 +433,29 @@ impl ComputeNode for EquiJoinNode { self.state = EquiJoinState::Probe(build_state.finalize(&*self.table)); } } + + // If we are probing and the probe input is done, emit unmatched if + // necessary, otherwise we're done. + if let EquiJoinState::Probe(probe_state) = &mut self.state { + if recv[probe_idx] == PortState::Done { + if self.params.emit_unmatched_build() { + self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState { + partitions: core::mem::take(&mut probe_state.table_per_partition), + active_partition_idx: 0, + offset_in_active_p: 0, + }); + } else { + self.state = EquiJoinState::Done; + } + } + } + + // Finally, check if we are done emitting unmatched keys. + if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state { + if emit_state.active_partition_idx >= emit_state.partitions.len() { + self.state = EquiJoinState::Done; + } + } match &mut self.state { EquiJoinState::Build(_) => { @@ -471,6 +542,15 @@ impl ComputeNode for EquiJoinNode { )); } }, + EquiJoinState::EmitUnmatchedBuild(emit_state) => { + assert!(recv_ports[build_idx].is_none()); + assert!(recv_ports[probe_idx].is_none()); + let send = send_ports[0].take().unwrap().serial(); + join_handles.push(scope.spawn_task( + TaskPriority::Low, + emit_state.emit_unmatched(send, &self.params, self.num_pipelines) + )); + }, EquiJoinState::Done => unreachable!(), } } diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 7ef74d5b0ad9..57ae8119db11 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -214,8 +214,12 @@ fn visualize_plan_rec( left_on, right_on, args, - } => { - let mut label = "in-memory-join".to_string(); + } | PhysNodeKind::EquiJoin { input_left, input_right, left_on, right_on, args } => { + let mut label = if matches!(phys_sm[node_key].kind, PhysNodeKind::EquiJoin { .. }) { + "equi-join".to_string() + } else { + "in-memory-join".to_string() + }; write!(label, r"\nleft_on:\n{}", fmt_exprs(left_on, expr_arena)).unwrap(); write!(label, r"\nright_on:\n{}", fmt_exprs(right_on, expr_arena)).unwrap(); write!( diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index 707c2a53dec2..aa821e6b0a38 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -153,6 +153,14 @@ pub enum PhysNodeKind { key: Vec, aggs: Vec, }, + + EquiJoin { + input_left: PhysNodeKey, + input_right: PhysNodeKey, + left_on: Vec, + right_on: Vec, + args: JoinArgs, + }, /// Generic fallback for (as-of-yet) unsupported streaming joins. /// Fully sinks all data to in-memory data frames and uses the in-memory @@ -213,7 +221,7 @@ fn insert_multiplexers( insert_multiplexers(*input, phys_sm, referenced); }, - PhysNodeKind::InMemoryJoin { + PhysNodeKind::InMemoryJoin { input_left, input_right, .. } | PhysNodeKind::EquiJoin { input_left, input_right, .. diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 66bb1f4180a8..f5dd8d02b94c 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -23,6 +23,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; +use crate::nodes::joins::equi_join::EquiJoinNode; use crate::physical_plan::lower_expr::compute_output_schema; use crate::utils::late_materialized_df::LateMaterializedDataFrame; @@ -503,6 +504,30 @@ fn to_graph_rec<'a>( [left_input_key, right_input_key], ) }, + + EquiJoin { + input_left, + input_right, + left_on, + right_on, + args, + } => { + let args = args.clone(); + let left_input_key = to_graph_rec(*input_left, ctx)?; + let right_input_key = to_graph_rec(*input_right, ctx)?; + let left_input_schema = ctx.phys_sm[*input_left].output_schema.clone(); + let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone(); + + todo!() + // ctx.graph.add_node( + // nodes::joins::equi_join::EquiJoinNode::new( + // left_input_schema, + // right_input_schema, + // args, + // ), + // [left_input_key, right_input_key], + // ) + }, }; ctx.phys_to_graph.insert(phys_node_key, graph_key); diff --git a/crates/polars-utils/src/idx_map/bytes_idx_map.rs b/crates/polars-utils/src/idx_map/bytes_idx_map.rs index 61848af1d8df..c362361e2620 100644 --- a/crates/polars-utils/src/idx_map/bytes_idx_map.rs +++ b/crates/polars-utils/src/idx_map/bytes_idx_map.rs @@ -102,10 +102,18 @@ impl BytesIndexMap { } } + /// Gets the hash, key and value at the given index by insertion order. + #[inline(always)] + pub fn get_index(&self, idx: IdxSize) -> Option<(u64, &[u8], &V)> { + let t = self.tuples.get(idx as usize)?; + Some((t.0.key_hash, unsafe { t.0.get(&self.key_data) }, &t.1)) + } + /// Gets the hash, key and value at the given index by insertion order. /// /// # Safety /// The index must be less than len(). + #[inline(always)] pub unsafe fn get_index_unchecked(&self, idx: IdxSize) -> (u64, &[u8], &V) { let t = self.tuples.get_unchecked(idx as usize); (t.0.key_hash, t.0.get(&self.key_data), &t.1)