Skip to content

Commit

Permalink
diff
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Zhan <[email protected]>
  • Loading branch information
jayzhan-synnada committed Dec 12, 2024
1 parent 437cbf8 commit fda9e7d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 38 deletions.
6 changes: 3 additions & 3 deletions datafusion/core/src/physical_optimizer/sanity_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ mod tests {

let test2 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Unbounded),
expect_fail: true,
expect_fail: false,
};
let test3 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Bounded),
Expand Down Expand Up @@ -290,7 +290,7 @@ mod tests {
};
let test2 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Unbounded),
expect_fail: true,
expect_fail: false,
};
let test3 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Bounded),
Expand Down Expand Up @@ -668,4 +668,4 @@ mod tests {
assert_sanity_check(&smj, false);
Ok(())
}
}
}
89 changes: 54 additions & 35 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,27 @@ use datafusion_physical_expr_common::datum::compare_op_for_nested;
use futures::{ready, Stream, StreamExt, TryStreamExt};
use parking_lot::Mutex;

// THESE IMPORTS ARE ARAS ONLY
use super::utils::{build_join_watermark_schema, generate_join_watermark};
use crate::watermark::is_record_batch_a_watermark;
use crate::{
CheckpointCommon, CheckpointMode, CheckpointingState, RecoveryMode, WatermarkMode,
};

use datafusion_common::{exec_err, plan_datafusion_err};
use datafusion_execution::state::checkpoint_client::{
generate_checkpoint_message, parse_checkpoint_message,
};
use datafusion_execution::state::CheckpointLevel;
use datafusion_state_proto::state_protobuf;

use futures::FutureExt;
use prost::Message;

type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;

/// THIS STRUCT IS COMMON, MODIFIED BY ARAS
///
/// HashTable and input data for the left (build side) of a join
struct JoinLeftData {
/// The hash table with indices into `batch`
Expand All @@ -90,9 +109,11 @@ struct JoinLeftData {
/// Counter of running probe-threads, potentially
/// able to update `visited_indices_bitmap`
probe_threads_counter: AtomicUsize,
/// Memory reservation that tracks memory used by the `hash_map` hash table
/// and `batch`. Cleared on drop.
_reservation: MemoryReservation,
/// THIS MEMBER IS ARAS ONLY
///
/// Hash values stored in the hash table. Outer vector runs over batches,
/// inner vector runs over rows.
batches_hash_values: Vec<Vec<u64>>,
}

impl JoinLeftData {
Expand All @@ -102,14 +123,14 @@ impl JoinLeftData {
batch: RecordBatch,
visited_indices_bitmap: SharedBitmapBuilder,
probe_threads_counter: AtomicUsize,
reservation: MemoryReservation,
batches_hash_values: Vec<Vec<u64>>,
) -> Self {
Self {
hash_map,
batch,
visited_indices_bitmap,
probe_threads_counter,
_reservation: reservation,
batches_hash_values,
}
}

Expand Down Expand Up @@ -525,18 +546,10 @@ impl HashJoinExec {
};

// Determine execution mode by checking whether this join is pipeline
// breaking. This happens when the left side is unbounded, or the right
// side is unbounded with `Left`, `Full`, `LeftAnti` or `LeftSemi` join types.
let pipeline_breaking = left.execution_mode().is_unbounded()
|| (right.execution_mode().is_unbounded()
&& matches!(
join_type,
JoinType::Left
| JoinType::Full
| JoinType::LeftAnti
| JoinType::LeftSemi
| JoinType::LeftMark
));
// breaking, which happens when the left side is unbounded. If the left
// side is bounded, we can stream results for every join type, regardless
// of whether the right side is bounded or unbounded.
let pipeline_breaking = left.execution_mode().is_unbounded();

let mode = if pipeline_breaking {
ExecutionMode::PipelineBreaking
Expand Down Expand Up @@ -866,6 +879,8 @@ async fn collect_left_input(

// Updating hashmap starting from the last batch
let batches_iter = batches.iter().rev();
let mut batches_hash_values = Vec::with_capacity(batches.len());

for batch in batches_iter.clone() {
hashes_buffer.clear();
hashes_buffer.resize(batch.num_rows(), 0);
Expand All @@ -879,6 +894,7 @@ async fn collect_left_input(
0,
true,
)?;
batches_hash_values.push(hashes_buffer.clone());
offset += batch.num_rows();
}
// Merge all batches into a single batch, so we can directly index into the arrays
Expand All @@ -902,7 +918,7 @@ async fn collect_left_input(
single_batch,
Mutex::new(visited_indices_bitmap),
AtomicUsize::new(probe_threads_count),
reservation,
batches_hash_values,
);

Ok(data)
Expand Down Expand Up @@ -1043,7 +1059,10 @@ impl HashJoinStreamState {
}
}

/// THIS STRUCT IS COMMON, MODIFIED BY ARAS
///
/// Container for HashJoinStreamState::ProcessProbeBatch related data
#[derive(Debug, Clone)]
struct ProcessProbeBatchState {
/// Current probe-side batch
batch: RecordBatch,
Expand Down Expand Up @@ -1333,25 +1352,25 @@ impl HashJoinStream {
}
Some(Ok(batch)) => {
// Precalculate hash values for fetched batch
let keys_values = self
.on_right
.iter()
.map(|c| c.evaluate(&batch)?.into_array(batch.num_rows()))
.collect::<Result<Vec<_>>>()?;

self.hashes_buffer.clear();
self.hashes_buffer.resize(batch.num_rows(), 0);
let keys_values = self
.on_right
.iter()
.map(|c| c.evaluate(&batch)?.into_array(batch.num_rows()))
.collect::<Result<Vec<_>>>()?;

self.hashes_buffer.clear();
self.hashes_buffer.resize(batch.num_rows(), 0);
create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;

self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());

self.state =
HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
batch,
offset: (0, None),
joined_probe_idx: None,
});
self.state =
HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
batch,
offset: (0, None),
joined_probe_idx: None,
});
}
Some(Err(err)) => return Poll::Ready(Err(err)),
};
Expand Down Expand Up @@ -4083,4 +4102,4 @@ mod tests {
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
}
}
}

0 comments on commit fda9e7d

Please sign in to comment.