/Users/andrewlamb/Software/datafusion/datafusion/expr/src/logical_plan/extension.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! This module defines the interface for logical nodes |
19 | | use crate::{Expr, LogicalPlan}; |
20 | | use datafusion_common::{DFSchema, DFSchemaRef, Result}; |
21 | | use std::cmp::Ordering; |
22 | | use std::hash::{Hash, Hasher}; |
23 | | use std::{any::Any, collections::HashSet, fmt, sync::Arc}; |
24 | | |
25 | | /// This defines the interface for [`LogicalPlan`] nodes that can be |
26 | | /// used to extend DataFusion with custom relational operators. |
27 | | /// |
28 | | /// The [`UserDefinedLogicalNodeCore`] trait is *the recommended way to implement* |
29 | | /// this trait and avoids having to implement some required boilerplate code.
30 | | pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { |
31 | | /// Return a reference to self as Any, to support dynamic downcasting |
32 | | /// |
33 | | /// Typically this will look like: |
34 | | /// |
35 | | /// ``` |
36 | | /// # use std::any::Any; |
37 | | /// # struct Dummy { } |
38 | | /// |
39 | | /// # impl Dummy { |
40 | | /// // canonical boiler plate |
41 | | /// fn as_any(&self) -> &dyn Any { |
42 | | /// self |
43 | | /// } |
44 | | /// # } |
45 | | /// ``` |
46 | | fn as_any(&self) -> &dyn Any; |
47 | | |
48 | | /// Return the plan's name. |
49 | | fn name(&self) -> &str; |
50 | | |
51 | | /// Return the logical plan's inputs. |
52 | | fn inputs(&self) -> Vec<&LogicalPlan>; |
53 | | |
54 | | /// Return the output schema of this logical plan node. |
55 | | fn schema(&self) -> &DFSchemaRef; |
56 | | |
57 | | /// Returns all expressions in the current logical plan node. This should |
58 | | /// not include expressions of any inputs (aka non-recursively). |
59 | | /// |
60 | | /// These expressions are used for optimizer |
61 | | /// passes and rewrites. See [`LogicalPlan::expressions`] for more details. |
62 | | fn expressions(&self) -> Vec<Expr>; |
63 | | |
64 | | /// A list of output columns (e.g. the names of columns in |
65 | | /// self.schema()) for which predicates can not be pushed below |
66 | | /// this node without changing the output. |
67 | | /// |
68 | | /// By default, this returns all columns and thus prevents any |
69 | | /// predicates from being pushed below this node. |
70 | 0 | fn prevent_predicate_push_down_columns(&self) -> HashSet<String> { |
71 | 0 | // default (safe) is all columns in the schema. |
72 | 0 | get_all_columns_from_schema(self.schema()) |
73 | 0 | } |
74 | | |
75 | | /// Write a single line, human readable string to `f` for use in explain plan. |
76 | | /// |
77 | | /// For example: `TopK: k=10` |
78 | | fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result; |
79 | | |
80 | | #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")] |
81 | | #[allow(clippy::wrong_self_convention)] |
82 | 0 | fn from_template( |
83 | 0 | &self, |
84 | 0 | exprs: &[Expr], |
85 | 0 | inputs: &[LogicalPlan], |
86 | 0 | ) -> Arc<dyn UserDefinedLogicalNode> { |
87 | 0 | self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) |
88 | 0 | .unwrap() |
89 | 0 | } |
90 | | |
91 | | /// Create a new `UserDefinedLogicalNode` with the specified children |
92 | | /// and expressions. This function is used during optimization |
93 | | /// when the plan is being rewritten and a new instance of the |
94 | | /// `UserDefinedLogicalNode` must be created. |
95 | | /// |
96 | | /// Note that exprs and inputs are in the same order as the result |
97 | | /// of self.inputs and self.exprs. |
98 | | /// |
99 | | /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs`
100 | | fn with_exprs_and_inputs( |
101 | | &self, |
102 | | exprs: Vec<Expr>, |
103 | | inputs: Vec<LogicalPlan>, |
104 | | ) -> Result<Arc<dyn UserDefinedLogicalNode>>; |
105 | | |
106 | | /// Returns the necessary input columns for this node required to compute |
107 | | /// the columns in the output schema |
108 | | /// |
109 | | /// This is used for projection push-down when DataFusion has determined that |
110 | | /// only a subset of the output columns of this node are needed by its parents. |
111 | | /// This API is used to tell DataFusion which, if any, of the input columns are no longer |
112 | | /// needed. |
113 | | /// |
114 | | /// Return `None`, the default, if this information can not be determined. |
115 | | /// Returns `Some(_)` with the column indices for each child of this node that are |
116 | | /// needed to compute `output_columns` |
117 | 0 | fn necessary_children_exprs( |
118 | 0 | &self, |
119 | 0 | _output_columns: &[usize], |
120 | 0 | ) -> Option<Vec<Vec<usize>>> { |
121 | 0 | None |
122 | 0 | } |
123 | | |
124 | | /// Update the hash `state` with this node requirements from |
125 | | /// [`Hash`]. |
126 | | /// |
127 | | /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of |
128 | | /// [`UserDefinedLogicalNode`] directly. |
129 | | /// |
130 | | /// This method is required to support hashing [`LogicalPlan`]s. To |
131 | | /// implement it, typically the type implementing |
132 | | /// [`UserDefinedLogicalNode`] implements [`Hash`] and
133 | | /// then the following boiler plate is used: |
134 | | /// |
135 | | /// # Example: |
136 | | /// ``` |
137 | | /// // User defined node that derives Hash |
138 | | /// #[derive(Hash, Debug, PartialEq, Eq)] |
139 | | /// struct MyNode { |
140 | | /// val: u64 |
141 | | /// } |
142 | | /// |
143 | | /// // impl UserDefinedLogicalNode { |
144 | | /// // ... |
145 | | /// # impl MyNode { |
146 | | /// // Boiler plate to call the derived Hash impl |
147 | | /// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { |
148 | | /// use std::hash::Hash; |
149 | | /// let mut s = state; |
150 | | /// self.hash(&mut s); |
151 | | /// } |
152 | | /// // } |
153 | | /// # } |
154 | | /// ``` |
155 | | /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Hash`] |
156 | | /// directly because it must remain object safe. |
157 | | fn dyn_hash(&self, state: &mut dyn Hasher); |
158 | | |
159 | | /// Compare `other`, respecting requirements from [std::cmp::Eq]. |
160 | | /// |
161 | | /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of |
162 | | /// [`UserDefinedLogicalNode`] directly. |
163 | | /// |
164 | | /// When `other` has a different type than `self`, then the values
165 | | /// are *not* equal. |
166 | | /// |
167 | | /// This method is required to support Eq on [`LogicalPlan`]s. To |
168 | | /// implement it, typically the type implementing |
169 | | /// [`UserDefinedLogicalNode`] implements [`Eq`] and
170 | | /// then the following boiler plate is used: |
171 | | /// |
172 | | /// # Example: |
173 | | /// ``` |
174 | | /// # use datafusion_expr::UserDefinedLogicalNode; |
175 | | /// // User defined node that derives Eq |
176 | | /// #[derive(Hash, Debug, PartialEq, Eq)] |
177 | | /// struct MyNode { |
178 | | /// val: u64 |
179 | | /// } |
180 | | /// |
181 | | /// // impl UserDefinedLogicalNode { |
182 | | /// // ... |
183 | | /// # impl MyNode { |
184 | | /// // Boiler plate to call the derived Eq impl |
185 | | /// fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { |
186 | | /// match other.as_any().downcast_ref::<Self>() { |
187 | | /// Some(o) => self == o, |
188 | | /// None => false, |
189 | | /// } |
190 | | /// } |
191 | | /// // } |
192 | | /// # } |
193 | | /// ``` |
194 | | /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`] |
195 | | /// directly because it must remain object safe. |
196 | | fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool; |
197 | | fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>; |
198 | | |
199 | | /// Returns `true` if a limit can be safely pushed down through this |
200 | | /// `UserDefinedLogicalNode` node. |
201 | | /// |
202 | | /// If this method returns `true`, and the query plan contains a limit at |
203 | | /// the output of this node, DataFusion will push the limit to the input |
204 | | /// of this node. |
205 | 0 | fn supports_limit_pushdown(&self) -> bool { |
206 | 0 | false |
207 | 0 | } |
208 | | } |
209 | | |
210 | | impl Hash for dyn UserDefinedLogicalNode { |
211 | 0 | fn hash<H: Hasher>(&self, state: &mut H) { |
212 | 0 | self.dyn_hash(state); |
213 | 0 | } |
214 | | } |
215 | | |
216 | | impl PartialEq for dyn UserDefinedLogicalNode { |
217 | 0 | fn eq(&self, other: &Self) -> bool { |
218 | 0 | self.dyn_eq(other) |
219 | 0 | } |
220 | | } |
221 | | |
222 | | impl PartialOrd for dyn UserDefinedLogicalNode { |
223 | 0 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { |
224 | 0 | self.dyn_ord(other) |
225 | 0 | } |
226 | | } |
227 | | |
228 | | impl Eq for dyn UserDefinedLogicalNode {} |
229 | | |
230 | | /// This trait facilitates implementation of the [`UserDefinedLogicalNode`]. |
231 | | /// |
232 | | /// See the example in |
233 | | /// [user_defined_plan.rs](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_plan.rs) |
234 | | /// file for an example of how to use this extension API. |
235 | | pub trait UserDefinedLogicalNodeCore: |
236 | | fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static |
237 | | { |
238 | | /// Return the plan's name. |
239 | | fn name(&self) -> &str; |
240 | | |
241 | | /// Return the logical plan's inputs. |
242 | | fn inputs(&self) -> Vec<&LogicalPlan>; |
243 | | |
244 | | /// Return the output schema of this logical plan node. |
245 | | fn schema(&self) -> &DFSchemaRef; |
246 | | |
247 | | /// Returns all expressions in the current logical plan node. This |
248 | | /// should not include expressions of any inputs (aka |
249 | | /// non-recursively). These expressions are used for optimizer |
250 | | /// passes and rewrites. |
251 | | fn expressions(&self) -> Vec<Expr>; |
252 | | |
253 | | /// A list of output columns (e.g. the names of columns in |
254 | | /// self.schema()) for which predicates can not be pushed below |
255 | | /// this node without changing the output. |
256 | | /// |
257 | | /// By default, this returns all columns and thus prevents any |
258 | | /// predicates from being pushed below this node. |
259 | 0 | fn prevent_predicate_push_down_columns(&self) -> HashSet<String> { |
260 | 0 | // default (safe) is all columns in the schema. |
261 | 0 | get_all_columns_from_schema(self.schema()) |
262 | 0 | } |
263 | | |
264 | | /// Write a single line, human readable string to `f` for use in explain plan. |
265 | | /// |
266 | | /// For example: `TopK: k=10` |
267 | | fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result; |
268 | | |
269 | | #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")] |
270 | | #[allow(clippy::wrong_self_convention)] |
271 | 0 | fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { |
272 | 0 | self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) |
273 | 0 | .unwrap() |
274 | 0 | } |
275 | | |
276 | | /// Create a new `UserDefinedLogicalNode` with the specified children |
277 | | /// and expressions. This function is used during optimization |
278 | | /// when the plan is being rewritten and a new instance of the |
279 | | /// `UserDefinedLogicalNode` must be created. |
280 | | /// |
281 | | /// Note that exprs and inputs are in the same order as the result |
282 | | /// of self.inputs and self.exprs. |
283 | | /// |
284 | | /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs`
285 | | fn with_exprs_and_inputs( |
286 | | &self, |
287 | | exprs: Vec<Expr>, |
288 | | inputs: Vec<LogicalPlan>, |
289 | | ) -> Result<Self>; |
290 | | |
291 | | /// Returns the necessary input columns for this node required to compute |
292 | | /// the columns in the output schema |
293 | | /// |
294 | | /// This is used for projection push-down when DataFusion has determined that |
295 | | /// only a subset of the output columns of this node are needed by its parents. |
296 | | /// This API is used to tell DataFusion which, if any, of the input columns are no longer |
297 | | /// needed. |
298 | | /// |
299 | | /// Return `None`, the default, if this information can not be determined. |
300 | | /// Returns `Some(_)` with the column indices for each child of this node that are |
301 | | /// needed to compute `output_columns` |
302 | 0 | fn necessary_children_exprs( |
303 | 0 | &self, |
304 | 0 | _output_columns: &[usize], |
305 | 0 | ) -> Option<Vec<Vec<usize>>> { |
306 | 0 | None |
307 | 0 | } |
308 | | |
309 | | /// Returns `true` if a limit can be safely pushed down through this |
310 | | /// `UserDefinedLogicalNode` node. |
311 | | /// |
312 | | /// If this method returns `true`, and the query plan contains a limit at |
313 | | /// the output of this node, DataFusion will push the limit to the input |
314 | | /// of this node. |
315 | 0 | fn supports_limit_pushdown(&self) -> bool { |
316 | 0 | false // Disallow limit push-down by default |
317 | 0 | } |
318 | | } |
319 | | |
320 | | /// Blanket implementation of `UserDefinedLogicalNode` for every `UserDefinedLogicalNodeCore`,
321 | | /// avoiding boilerplate for implementing `as_any`, `Hash` and `PartialEq`
322 | | impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T { |
323 | 0 | fn as_any(&self) -> &dyn Any { |
324 | 0 | self |
325 | 0 | } |
326 | | |
327 | 0 | fn name(&self) -> &str { |
328 | 0 | self.name() |
329 | 0 | } |
330 | | |
331 | 0 | fn inputs(&self) -> Vec<&LogicalPlan> { |
332 | 0 | self.inputs() |
333 | 0 | } |
334 | | |
335 | 0 | fn schema(&self) -> &DFSchemaRef { |
336 | 0 | self.schema() |
337 | 0 | } |
338 | | |
339 | 0 | fn expressions(&self) -> Vec<Expr> { |
340 | 0 | self.expressions() |
341 | 0 | } |
342 | | |
343 | 0 | fn prevent_predicate_push_down_columns(&self) -> HashSet<String> { |
344 | 0 | self.prevent_predicate_push_down_columns() |
345 | 0 | } |
346 | | |
347 | 0 | fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { |
348 | 0 | self.fmt_for_explain(f) |
349 | 0 | } |
350 | | |
351 | 0 | fn with_exprs_and_inputs( |
352 | 0 | &self, |
353 | 0 | exprs: Vec<Expr>, |
354 | 0 | inputs: Vec<LogicalPlan>, |
355 | 0 | ) -> Result<Arc<dyn UserDefinedLogicalNode>> { |
356 | 0 | Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?)) |
357 | 0 | } |
358 | | |
359 | 0 | fn necessary_children_exprs( |
360 | 0 | &self, |
361 | 0 | output_columns: &[usize], |
362 | 0 | ) -> Option<Vec<Vec<usize>>> { |
363 | 0 | self.necessary_children_exprs(output_columns) |
364 | 0 | } |
365 | | |
366 | 0 | fn dyn_hash(&self, state: &mut dyn Hasher) { |
367 | 0 | let mut s = state; |
368 | 0 | self.hash(&mut s); |
369 | 0 | } |
370 | | |
371 | 0 | fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { |
372 | 0 | match other.as_any().downcast_ref::<Self>() { |
373 | 0 | Some(o) => self == o, |
374 | 0 | None => false, |
375 | | } |
376 | 0 | } |
377 | | |
378 | 0 | fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> { |
379 | 0 | other |
380 | 0 | .as_any() |
381 | 0 | .downcast_ref::<Self>() |
382 | 0 | .and_then(|other| self.partial_cmp(other)) |
383 | 0 | } |
384 | | |
385 | 0 | fn supports_limit_pushdown(&self) -> bool { |
386 | 0 | self.supports_limit_pushdown() |
387 | 0 | } |
388 | | } |
389 | | |
390 | 0 | fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> { |
391 | 0 | schema.fields().iter().map(|f| f.name().clone()).collect() |
392 | 0 | } |