Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/recursive_query.rs

Every instrumented line in this file reports an execution count of 0 in this run, i.e. nothing in recursive_query.rs is exercised by the instrumented test build; the file's own #[cfg(test)] mod tests {} at the bottom is empty. The full source is listed below.

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the recursive query plan

use std::any::Any;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::{
    metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
    work_table::{ReservedBatches, WorkTable, WorkTableExec},
    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use crate::{DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};

use futures::{ready, Stream, StreamExt};

/// Recursive query execution plan.
///
/// This plan has two components: a base part (the static term) and
/// a dynamic part (the recursive term). The execution will start from
/// the base, and as long as the previous iteration produced at least
/// a single new row (taking care of the distinction) the recursive
/// part will be continuously executed.
///
/// Before each execution of the dynamic part, the rows from the previous
/// iteration will be available in a "working table" (not a real table,
/// can be only accessed using a continuance operation).
///
/// Note that there won't be any limit or checks applied to detect
/// an infinite recursion, so it is up to the planner to ensure that
/// it won't happen.
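// Illustrative mapping, not taken from this file: for a recursive CTE such as
//
//   WITH RECURSIVE t AS (
//       SELECT 1 AS n                        -- planned as the static term
//       UNION ALL
//       SELECT n + 1 FROM t WHERE n < 10     -- planned as the recursive term
//   )
//   SELECT * FROM t
//
// the self-reference to `t` inside the recursive term is represented by a
// `WorkTableExec` that reads the rows buffered from the previous iteration.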
#[derive(Debug)]
pub struct RecursiveQueryExec {
    /// Name of the query handler
    name: String,
    /// The working table of cte
    work_table: Arc<WorkTable>,
    /// The base part (static term)
    static_term: Arc<dyn ExecutionPlan>,
    /// The dynamic part (recursive term)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// Distinction
    is_distinct: bool,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl RecursiveQueryExec {
    /// Create a new RecursiveQueryExec
    pub fn try_new(
        name: String,
        static_term: Arc<dyn ExecutionPlan>,
        recursive_term: Arc<dyn ExecutionPlan>,
        is_distinct: bool,
    ) -> Result<Self> {
        // Each recursive query needs its own work table
        let work_table = Arc::new(WorkTable::new());
        // Use the same work table for both the WorkTableExec and the recursive term
        let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
        let cache = Self::compute_properties(static_term.schema());
        Ok(RecursiveQueryExec {
            name,
            static_term,
            recursive_term,
            is_distinct,
            work_table,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(schema: SchemaRef) -> PlanProperties {
        let eq_properties = EquivalenceProperties::new(schema);

        PlanProperties::new(
            eq_properties,
            Partitioning::UnknownPartitioning(1),
            ExecutionMode::Bounded,
        )
    }
}

impl ExecutionPlan for RecursiveQueryExec {
    fn name(&self) -> &'static str {
        "RecursiveQueryExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.static_term, &self.recursive_term]
    }

    // TODO: control these hints and see whether we can
    // infer some from the child plans (static/recursive terms).
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![false, false]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false, false]
    }

    fn required_input_distribution(&self) -> Vec<datafusion_physical_expr::Distribution> {
        vec![
            datafusion_physical_expr::Distribution::SinglePartition,
            datafusion_physical_expr::Distribution::SinglePartition,
        ]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        RecursiveQueryExec::try_new(
            self.name.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.is_distinct,
        )
        .map(|e| Arc::new(e) as _)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // TODO: we might be able to handle multiple partitions in the future.
        if partition != 0 {
            return Err(DataFusionError::Internal(format!(
                "RecursiveQueryExec got an invalid partition {} (expected 0)",
                partition
            )));
        }

        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        Ok(Box::pin(RecursiveQueryStream::new(
            context,
            Arc::clone(&self.work_table),
            Arc::clone(&self.recursive_term),
            static_stream,
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&self.schema()))
    }
}

impl DisplayAs for RecursiveQueryExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "RecursiveQueryExec: name={}, is_distinct={}",
                    self.name, self.is_distinct
                )
            }
        }
    }
}

/// The actual logic of the recursive queries happens during the streaming
/// process. A simplified version of the algorithm is the following:
///
/// buffer = []
///
/// while batch := static_stream.next():
///    buffer.push(batch)
///    yield buffer
///
/// while buffer.len() > 0:
///    sender, receiver = Channel()
///    register_continuation(handle_name, receiver)
///    sender.send(buffer.drain())
///    recursive_stream = recursive_term.execute()
///    while batch := recursive_stream.next():
///        buffer.append(batch)
///        yield buffer
///
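// In terms of the implementation below: the first loop of the sketch
// corresponds to draining `static_stream` in `poll_next` (each batch is both
// yielded and saved via `push_batch`), and the second loop corresponds to
// `poll_next_iteration`, which moves the buffered batches into the work table,
// re-executes the recursive term, and ends the recursion once an iteration
// produces no new rows.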
struct RecursiveQueryStream {
    /// The context to be used for managing handlers & executing new tasks
    task_context: Arc<TaskContext>,
    /// The working table state, representing the self referencing cte table
    work_table: Arc<WorkTable>,
    /// The dynamic part (recursive term) as is (without being executed)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// The static part (static term) as a stream. If the processing of this
    /// part is completed, then it will be None.
    static_stream: Option<SendableRecordBatchStream>,
    /// The dynamic part (recursive term) as a stream. If the processing of this
    /// part has not started yet, or has been completed, then it will be None.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// The schema of the output.
    schema: SchemaRef,
    /// In-memory buffer for storing a copy of the current results. Will be
    /// cleared after each iteration.
    buffer: Vec<RecordBatch>,
    /// Tracks the memory used by the buffer
    reservation: MemoryReservation,
    // /// Metrics.
    _baseline_metrics: BaselineMetrics,
}

impl RecursiveQueryStream {
    /// Create a new recursive query stream
    fn new(
        task_context: Arc<TaskContext>,
        work_table: Arc<WorkTable>,
        recursive_term: Arc<dyn ExecutionPlan>,
        static_stream: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
    ) -> Self {
        let schema = static_stream.schema();
        let reservation =
            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
        Self {
            task_context,
            work_table,
            recursive_term,
            static_stream: Some(static_stream),
            recursive_stream: None,
            schema,
            buffer: vec![],
            reservation,
            _baseline_metrics: baseline_metrics,
        }
    }

    /// Push a clone of the given batch to the in memory buffer, and then return
    /// a poll with it.
    fn push_batch(
        mut self: std::pin::Pin<&mut Self>,
        batch: RecordBatch,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
            return Poll::Ready(Some(Err(e)));
        }

        self.buffer.push(batch.clone());
        Poll::Ready(Some(Ok(batch)))
    }

    /// Start polling for the next iteration; called either after the static term
    /// is completed or after an iteration of the recursive term is completed.
    /// It follows the algorithm above to check whether the recursion has ended.
    fn poll_next_iteration(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let total_length = self
            .buffer
            .iter()
            .fold(0, |acc, batch| acc + batch.num_rows());

        if total_length == 0 {
            return Poll::Ready(None);
        }

        // Update the work table with the current buffer
        let reserved_batches = ReservedBatches::new(
            std::mem::take(&mut self.buffer),
            self.reservation.take(),
        );
        self.work_table.update(reserved_batches);

        // We always execute (and re-execute iteratively) the first partition.
        // Downstream plans should not expect any partitioning.
        let partition = 0;

        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
        self.recursive_stream =
            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
        self.poll_next(cx)
    }
}

fn assign_work_table(
    plan: Arc<dyn ExecutionPlan>,
    work_table: Arc<WorkTable>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let mut work_table_refs = 0;
    plan.transform_down(|plan| {
        if let Some(exec) = plan.as_any().downcast_ref::<WorkTableExec>() {
            if work_table_refs > 0 {
                not_impl_err!(
                    "Multiple recursive references to the same CTE are not supported"
                )
            } else {
                work_table_refs += 1;
                Ok(Transformed::yes(Arc::new(
                    exec.with_work_table(Arc::clone(&work_table)),
                )))
            }
        } else if plan.as_any().is::<RecursiveQueryExec>() {
            not_impl_err!("Recursive queries cannot be nested")
        } else {
            Ok(Transformed::no(plan))
        }
    })
    .data()
}

/// Some plans will change their internal states after execution, making them unable to be executed again.
/// This function uses `ExecutionPlan::with_new_children` to fork a new plan with initial states.
///
/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan.
/// However, if the data of the left table is derived from the work table, it will become outdated
/// as the work table changes. When the next iteration executes this plan again, we must clear the left table.
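// `poll_next_iteration` above calls this on the recursive term before every
// re-execution, so stateful operators start each iteration from a clean state,
// while the already-updated `WorkTableExec` is deliberately left untouched.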
fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    plan.transform_up(|plan| {
        // WorkTableExec's states have already been updated correctly.
        if plan.as_any().is::<WorkTableExec>() {
            Ok(Transformed::no(plan))
        } else {
            let new_plan = Arc::clone(&plan)
                .with_new_children(plan.children().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        }
    })
    .data()
}

impl Stream for RecursiveQueryStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // TODO: we should use this poll to record some metrics!
        if let Some(static_stream) = &mut self.static_stream {
            // While the static term's stream is available, we'll be forwarding the batches from it (also
            // saving them for the initial iteration of the recursive term).
            let batch_result = ready!(static_stream.poll_next_unpin(cx));
            match &batch_result {
                None => {
                    // Once this is done, we can start running the setup for the recursive term.
                    self.static_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch.clone()),
                _ => Poll::Ready(batch_result),
            }
        } else if let Some(recursive_stream) = &mut self.recursive_stream {
            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
            match batch_result {
                None => {
                    self.recursive_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch),
                _ => Poll::Ready(batch_result),
            }
        } else {
            Poll::Ready(None)
        }
    }
}

impl RecordBatchStream for RecursiveQueryStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {}
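
Since nothing in this run constructs a RecursiveQueryExec, a minimal construction sketch follows for orientation. It assumes the public datafusion-physical-plan crate layout at this snapshot (the empty and recursive_query modules and their paths are assumptions, not part of the report), and it uses EmptyExec only as a stand-in for real input plans; a real recursive term would contain the WorkTableExec that assign_work_table rewires.

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_plan::empty::EmptyExec;
use datafusion_physical_plan::recursive_query::RecursiveQueryExec;
use datafusion_physical_plan::ExecutionPlan;

fn build_recursive_query_plan() -> Result<Arc<dyn ExecutionPlan>> {
    // Schema shared by both terms; the operator reports the static term's
    // schema as its output schema (see compute_properties in the listing).
    let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int64, false)]));

    // Placeholder inputs for the base query and the self-referencing query.
    let static_term: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&schema)));
    let recursive_term: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(schema));

    // is_distinct = false corresponds to a UNION ALL recursive CTE.
    let exec: Arc<dyn ExecutionPlan> = Arc::new(RecursiveQueryExec::try_new(
        "numbers".to_string(),
        static_term,
        recursive_term,
        false,
    )?);

    // Both terms are exposed as children; execution always uses partition 0.
    assert_eq!(exec.children().len(), 2);
    Ok(exec)
}

Executing partition 0 of the returned plan with a TaskContext would drive the RecursiveQueryStream shown in the listing above.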