Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/logical_plan/extension.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! This module defines the interface for logical nodes
19
use crate::{Expr, LogicalPlan};
20
use datafusion_common::{DFSchema, DFSchemaRef, Result};
21
use std::cmp::Ordering;
22
use std::hash::{Hash, Hasher};
23
use std::{any::Any, collections::HashSet, fmt, sync::Arc};
24
25
/// This defines the interface for [`LogicalPlan`] nodes that can be
26
/// used to extend DataFusion with custom relational operators.
27
///
28
/// The [`UserDefinedLogicalNodeCore`] trait is *the recommended way to implement*
29
/// this trait and avoids having implementing some required boiler plate code.
30
pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
31
    /// Return a reference to self as Any, to support dynamic downcasting
32
    ///
33
    /// Typically this will look like:
34
    ///
35
    /// ```
36
    /// # use std::any::Any;
37
    /// # struct Dummy { }
38
    ///
39
    /// # impl Dummy {
40
    ///   // canonical boiler plate
41
    ///   fn as_any(&self) -> &dyn Any {
42
    ///      self
43
    ///   }
44
    /// # }
45
    /// ```
46
    fn as_any(&self) -> &dyn Any;
47
48
    /// Return the plan's name.
49
    fn name(&self) -> &str;
50
51
    /// Return the logical plan's inputs.
52
    fn inputs(&self) -> Vec<&LogicalPlan>;
53
54
    /// Return the output schema of this logical plan node.
55
    fn schema(&self) -> &DFSchemaRef;
56
57
    /// Returns all expressions in the current logical plan node. This should
58
    /// not include expressions of any inputs (aka non-recursively).
59
    ///
60
    /// These expressions are used for optimizer
61
    /// passes and rewrites. See [`LogicalPlan::expressions`] for more details.
62
    fn expressions(&self) -> Vec<Expr>;
63
64
    /// A list of output columns (e.g. the names of columns in
65
    /// self.schema()) for which predicates can not be pushed below
66
    /// this node without changing the output.
67
    ///
68
    /// By default, this returns all columns and thus prevents any
69
    /// predicates from being pushed below this node.
70
0
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
71
0
        // default (safe) is all columns in the schema.
72
0
        get_all_columns_from_schema(self.schema())
73
0
    }
74
75
    /// Write a single line, human readable string to `f` for use in explain plan.
76
    ///
77
    /// For example: `TopK: k=10`
78
    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
79
80
    #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")]
81
    #[allow(clippy::wrong_self_convention)]
82
0
    fn from_template(
83
0
        &self,
84
0
        exprs: &[Expr],
85
0
        inputs: &[LogicalPlan],
86
0
    ) -> Arc<dyn UserDefinedLogicalNode> {
87
0
        self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec())
88
0
            .unwrap()
89
0
    }
90
91
    /// Create a new `UserDefinedLogicalNode` with the specified children
92
    /// and expressions. This function is used during optimization
93
    /// when the plan is being rewritten and a new instance of the
94
    /// `UserDefinedLogicalNode` must be created.
95
    ///
96
    /// Note that exprs and inputs are in the same order as the result
97
    /// of self.inputs and self.exprs.
98
    ///
99
    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs
100
    fn with_exprs_and_inputs(
101
        &self,
102
        exprs: Vec<Expr>,
103
        inputs: Vec<LogicalPlan>,
104
    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
105
106
    /// Returns the necessary input columns for this node required to compute
107
    /// the columns in the output schema
108
    ///
109
    /// This is used for projection push-down when DataFusion has determined that
110
    /// only a subset of the output columns of this node are needed by its parents.
111
    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
112
    /// needed.
113
    ///
114
    /// Return `None`, the default, if this information can not be determined.
115
    /// Returns `Some(_)` with the column indices for each child of this node that are
116
    /// needed to compute `output_columns`
117
0
    fn necessary_children_exprs(
118
0
        &self,
119
0
        _output_columns: &[usize],
120
0
    ) -> Option<Vec<Vec<usize>>> {
121
0
        None
122
0
    }
123
124
    /// Update the hash `state` with this node requirements from
125
    /// [`Hash`].
126
    ///
127
    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
128
    /// [`UserDefinedLogicalNode`] directly.
129
    ///
130
    /// This method is required to support hashing [`LogicalPlan`]s.  To
131
    /// implement it, typically the type implementing
132
    /// [`UserDefinedLogicalNode`] typically implements [`Hash`] and
133
    /// then the following boiler plate is used:
134
    ///
135
    /// # Example:
136
    /// ```
137
    /// // User defined node that derives Hash
138
    /// #[derive(Hash, Debug, PartialEq, Eq)]
139
    /// struct MyNode {
140
    ///   val: u64
141
    /// }
142
    ///
143
    /// // impl UserDefinedLogicalNode {
144
    /// // ...
145
    /// # impl MyNode {
146
    ///   // Boiler plate to call the derived Hash impl
147
    ///   fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
148
    ///     use std::hash::Hash;
149
    ///     let mut s = state;
150
    ///     self.hash(&mut s);
151
    ///   }
152
    /// // }
153
    /// # }
154
    /// ```
155
    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Hash`]
156
    /// directly because it must remain object safe.
157
    fn dyn_hash(&self, state: &mut dyn Hasher);
158
159
    /// Compare `other`, respecting requirements from [std::cmp::Eq].
160
    ///
161
    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
162
    /// [`UserDefinedLogicalNode`] directly.
163
    ///
164
    /// When `other` has an another type than `self`, then the values
165
    /// are *not* equal.
166
    ///
167
    /// This method is required to support Eq on [`LogicalPlan`]s.  To
168
    /// implement it, typically the type implementing
169
    /// [`UserDefinedLogicalNode`] typically implements [`Eq`] and
170
    /// then the following boiler plate is used:
171
    ///
172
    /// # Example:
173
    /// ```
174
    /// # use datafusion_expr::UserDefinedLogicalNode;
175
    /// // User defined node that derives Eq
176
    /// #[derive(Hash, Debug, PartialEq, Eq)]
177
    /// struct MyNode {
178
    ///   val: u64
179
    /// }
180
    ///
181
    /// // impl UserDefinedLogicalNode {
182
    /// // ...
183
    /// # impl MyNode {
184
    ///   // Boiler plate to call the derived Eq impl
185
    ///   fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
186
    ///     match other.as_any().downcast_ref::<Self>() {
187
    ///       Some(o) => self == o,
188
    ///       None => false,
189
    ///     }
190
    ///   }
191
    /// // }
192
    /// # }
193
    /// ```
194
    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`]
195
    /// directly because it must remain object safe.
196
    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;
197
    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;
198
199
    /// Returns `true` if a limit can be safely pushed down through this
200
    /// `UserDefinedLogicalNode` node.
201
    ///
202
    /// If this method returns `true`, and the query plan contains a limit at
203
    /// the output of this node, DataFusion will push the limit to the input
204
    /// of this node.
205
0
    fn supports_limit_pushdown(&self) -> bool {
206
0
        false
207
0
    }
208
}
209
210
impl Hash for dyn UserDefinedLogicalNode {
211
0
    fn hash<H: Hasher>(&self, state: &mut H) {
212
0
        self.dyn_hash(state);
213
0
    }
214
}
215
216
impl PartialEq for dyn UserDefinedLogicalNode {
217
0
    fn eq(&self, other: &Self) -> bool {
218
0
        self.dyn_eq(other)
219
0
    }
220
}
221
222
impl PartialOrd for dyn UserDefinedLogicalNode {
223
0
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
224
0
        self.dyn_ord(other)
225
0
    }
226
}
227
228
/// Declares that equality of type-erased nodes is a full equivalence
/// relation.
///
/// This marker relies on every [`UserDefinedLogicalNode::dyn_eq`]
/// implementation respecting the requirements of [`Eq`] (notably
/// reflexivity), as that method's contract asks.
impl Eq for dyn UserDefinedLogicalNode {}
229
230
/// This trait facilitates implementation of the [`UserDefinedLogicalNode`].
231
///
232
/// See the example in
233
/// [user_defined_plan.rs](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_plan.rs)
234
/// file for an example of how to use this extension API.
235
pub trait UserDefinedLogicalNodeCore:
236
    fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static
237
{
238
    /// Return the plan's name.
239
    fn name(&self) -> &str;
240
241
    /// Return the logical plan's inputs.
242
    fn inputs(&self) -> Vec<&LogicalPlan>;
243
244
    /// Return the output schema of this logical plan node.
245
    fn schema(&self) -> &DFSchemaRef;
246
247
    /// Returns all expressions in the current logical plan node. This
248
    /// should not include expressions of any inputs (aka
249
    /// non-recursively). These expressions are used for optimizer
250
    /// passes and rewrites.
251
    fn expressions(&self) -> Vec<Expr>;
252
253
    /// A list of output columns (e.g. the names of columns in
254
    /// self.schema()) for which predicates can not be pushed below
255
    /// this node without changing the output.
256
    ///
257
    /// By default, this returns all columns and thus prevents any
258
    /// predicates from being pushed below this node.
259
0
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
260
0
        // default (safe) is all columns in the schema.
261
0
        get_all_columns_from_schema(self.schema())
262
0
    }
263
264
    /// Write a single line, human readable string to `f` for use in explain plan.
265
    ///
266
    /// For example: `TopK: k=10`
267
    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
268
269
    #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")]
270
    #[allow(clippy::wrong_self_convention)]
271
0
    fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
272
0
        self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec())
273
0
            .unwrap()
274
0
    }
275
276
    /// Create a new `UserDefinedLogicalNode` with the specified children
277
    /// and expressions. This function is used during optimization
278
    /// when the plan is being rewritten and a new instance of the
279
    /// `UserDefinedLogicalNode` must be created.
280
    ///
281
    /// Note that exprs and inputs are in the same order as the result
282
    /// of self.inputs and self.exprs.
283
    ///
284
    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs
285
    fn with_exprs_and_inputs(
286
        &self,
287
        exprs: Vec<Expr>,
288
        inputs: Vec<LogicalPlan>,
289
    ) -> Result<Self>;
290
291
    /// Returns the necessary input columns for this node required to compute
292
    /// the columns in the output schema
293
    ///
294
    /// This is used for projection push-down when DataFusion has determined that
295
    /// only a subset of the output columns of this node are needed by its parents.
296
    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
297
    /// needed.
298
    ///
299
    /// Return `None`, the default, if this information can not be determined.
300
    /// Returns `Some(_)` with the column indices for each child of this node that are
301
    /// needed to compute `output_columns`
302
0
    fn necessary_children_exprs(
303
0
        &self,
304
0
        _output_columns: &[usize],
305
0
    ) -> Option<Vec<Vec<usize>>> {
306
0
        None
307
0
    }
308
309
    /// Returns `true` if a limit can be safely pushed down through this
310
    /// `UserDefinedLogicalNode` node.
311
    ///
312
    /// If this method returns `true`, and the query plan contains a limit at
313
    /// the output of this node, DataFusion will push the limit to the input
314
    /// of this node.
315
0
    fn supports_limit_pushdown(&self) -> bool {
316
0
        false // Disallow limit push-down by default
317
0
    }
318
}
319
320
/// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
321
/// to avoid boiler plate for implementing `as_any`, `Hash` and `PartialEq`
322
impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
323
0
    fn as_any(&self) -> &dyn Any {
324
0
        self
325
0
    }
326
327
0
    fn name(&self) -> &str {
328
0
        self.name()
329
0
    }
330
331
0
    fn inputs(&self) -> Vec<&LogicalPlan> {
332
0
        self.inputs()
333
0
    }
334
335
0
    fn schema(&self) -> &DFSchemaRef {
336
0
        self.schema()
337
0
    }
338
339
0
    fn expressions(&self) -> Vec<Expr> {
340
0
        self.expressions()
341
0
    }
342
343
0
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
344
0
        self.prevent_predicate_push_down_columns()
345
0
    }
346
347
0
    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
348
0
        self.fmt_for_explain(f)
349
0
    }
350
351
0
    fn with_exprs_and_inputs(
352
0
        &self,
353
0
        exprs: Vec<Expr>,
354
0
        inputs: Vec<LogicalPlan>,
355
0
    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
356
0
        Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?))
357
0
    }
358
359
0
    fn necessary_children_exprs(
360
0
        &self,
361
0
        output_columns: &[usize],
362
0
    ) -> Option<Vec<Vec<usize>>> {
363
0
        self.necessary_children_exprs(output_columns)
364
0
    }
365
366
0
    fn dyn_hash(&self, state: &mut dyn Hasher) {
367
0
        let mut s = state;
368
0
        self.hash(&mut s);
369
0
    }
370
371
0
    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
372
0
        match other.as_any().downcast_ref::<Self>() {
373
0
            Some(o) => self == o,
374
0
            None => false,
375
        }
376
0
    }
377
378
0
    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> {
379
0
        other
380
0
            .as_any()
381
0
            .downcast_ref::<Self>()
382
0
            .and_then(|other| self.partial_cmp(other))
383
0
    }
384
385
0
    fn supports_limit_pushdown(&self) -> bool {
386
0
        self.supports_limit_pushdown()
387
0
    }
388
}
389
390
0
fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
391
0
    schema.fields().iter().map(|f| f.name().clone()).collect()
392
0
}