/Users/andrewlamb/Software/datafusion/datafusion/execution/src/task.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::{ |
19 | | collections::{HashMap, HashSet}, |
20 | | sync::Arc, |
21 | | }; |
22 | | |
23 | | use crate::{ |
24 | | config::SessionConfig, |
25 | | memory_pool::MemoryPool, |
26 | | registry::FunctionRegistry, |
27 | | runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, |
28 | | }; |
29 | | use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; |
30 | | use datafusion_expr::planner::ExprPlanner; |
31 | | use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; |
32 | | |
33 | | /// Task Execution Context |
34 | | /// |
35 | | /// A [`TaskContext`] contains the state required during a single query's |
36 | | /// execution. Please see the documentation on [`SessionContext`] for more |
37 | | /// information. |
38 | | /// |
39 | | /// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html |
40 | | #[derive(Debug)] |
41 | | pub struct TaskContext { |
42 | | /// Session Id |
43 | | session_id: String, |
44 | | /// Optional Task Identify |
45 | | task_id: Option<String>, |
46 | | /// Session configuration |
47 | | session_config: SessionConfig, |
48 | | /// Scalar functions associated with this task context |
49 | | scalar_functions: HashMap<String, Arc<ScalarUDF>>, |
50 | | /// Aggregate functions associated with this task context |
51 | | aggregate_functions: HashMap<String, Arc<AggregateUDF>>, |
52 | | /// Window functions associated with this task context |
53 | | window_functions: HashMap<String, Arc<WindowUDF>>, |
54 | | /// Runtime environment associated with this task context |
55 | | runtime: Arc<RuntimeEnv>, |
56 | | } |
57 | | |
58 | | impl Default for TaskContext { |
59 | 891 | fn default() -> Self { |
60 | 891 | let runtime = RuntimeEnvBuilder::new() |
61 | 891 | .build_arc() |
62 | 891 | .expect("default runtime created successfully"); |
63 | 891 | |
64 | 891 | // Create a default task context, mostly useful for testing |
65 | 891 | Self { |
66 | 891 | session_id: "DEFAULT".to_string(), |
67 | 891 | task_id: None, |
68 | 891 | session_config: SessionConfig::new(), |
69 | 891 | scalar_functions: HashMap::new(), |
70 | 891 | aggregate_functions: HashMap::new(), |
71 | 891 | window_functions: HashMap::new(), |
72 | 891 | runtime, |
73 | 891 | } |
74 | 891 | } |
75 | | } |
76 | | |
77 | | impl TaskContext { |
78 | | /// Create a new [`TaskContext`] instance. |
79 | | /// |
80 | | /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s |
81 | | /// |
82 | | /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx |
83 | 0 | pub fn new( |
84 | 0 | task_id: Option<String>, |
85 | 0 | session_id: String, |
86 | 0 | session_config: SessionConfig, |
87 | 0 | scalar_functions: HashMap<String, Arc<ScalarUDF>>, |
88 | 0 | aggregate_functions: HashMap<String, Arc<AggregateUDF>>, |
89 | 0 | window_functions: HashMap<String, Arc<WindowUDF>>, |
90 | 0 | runtime: Arc<RuntimeEnv>, |
91 | 0 | ) -> Self { |
92 | 0 | Self { |
93 | 0 | task_id, |
94 | 0 | session_id, |
95 | 0 | session_config, |
96 | 0 | scalar_functions, |
97 | 0 | aggregate_functions, |
98 | 0 | window_functions, |
99 | 0 | runtime, |
100 | 0 | } |
101 | 0 | } |
102 | | |
103 | | /// Return the SessionConfig associated with this [TaskContext] |
104 | 2.01k | pub fn session_config(&self) -> &SessionConfig { |
105 | 2.01k | &self.session_config |
106 | 2.01k | } |
107 | | |
108 | | /// Return the `session_id` of this [TaskContext] |
109 | 0 | pub fn session_id(&self) -> String { |
110 | 0 | self.session_id.clone() |
111 | 0 | } |
112 | | |
113 | | /// Return the `task_id` of this [TaskContext] |
114 | 0 | pub fn task_id(&self) -> Option<String> { |
115 | 0 | self.task_id.clone() |
116 | 0 | } |
117 | | |
118 | | /// Return the [`MemoryPool`] associated with this [TaskContext] |
119 | 8.84k | pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> { |
120 | 8.84k | &self.runtime.memory_pool |
121 | 8.84k | } |
122 | | |
123 | | /// Return the [RuntimeEnv] associated with this [TaskContext] |
124 | 203 | pub fn runtime_env(&self) -> Arc<RuntimeEnv> { |
125 | 203 | Arc::clone(&self.runtime) |
126 | 203 | } |
127 | | |
128 | | /// Update the [`SessionConfig`] |
129 | 545 | pub fn with_session_config(mut self, session_config: SessionConfig) -> Self { |
130 | 545 | self.session_config = session_config; |
131 | 545 | self |
132 | 545 | } |
133 | | |
134 | | /// Update the [`RuntimeEnv`] |
135 | 78 | pub fn with_runtime(mut self, runtime: Arc<RuntimeEnv>) -> Self { |
136 | 78 | self.runtime = runtime; |
137 | 78 | self |
138 | 78 | } |
139 | | } |
140 | | |
141 | | impl FunctionRegistry for TaskContext { |
142 | 0 | fn udfs(&self) -> HashSet<String> { |
143 | 0 | self.scalar_functions.keys().cloned().collect() |
144 | 0 | } |
145 | | |
146 | 0 | fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> { |
147 | 0 | let result = self.scalar_functions.get(name); |
148 | 0 |
|
149 | 0 | result.cloned().ok_or_else(|| { |
150 | 0 | plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext") |
151 | 0 | }) |
152 | 0 | } |
153 | | |
154 | 0 | fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> { |
155 | 0 | let result = self.aggregate_functions.get(name); |
156 | 0 |
|
157 | 0 | result.cloned().ok_or_else(|| { |
158 | 0 | plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext") |
159 | 0 | }) |
160 | 0 | } |
161 | | |
162 | 0 | fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> { |
163 | 0 | let result = self.window_functions.get(name); |
164 | 0 |
|
165 | 0 | result.cloned().ok_or_else(|| { |
166 | 0 | DataFusionError::Internal(format!( |
167 | 0 | "There is no UDWF named \"{name}\" in the TaskContext" |
168 | 0 | )) |
169 | 0 | }) |
170 | 0 | } |
171 | 0 | fn register_udaf( |
172 | 0 | &mut self, |
173 | 0 | udaf: Arc<AggregateUDF>, |
174 | 0 | ) -> Result<Option<Arc<AggregateUDF>>> { |
175 | 0 | udaf.aliases().iter().for_each(|alias| { |
176 | 0 | self.aggregate_functions |
177 | 0 | .insert(alias.clone(), Arc::clone(&udaf)); |
178 | 0 | }); |
179 | 0 | Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) |
180 | 0 | } |
181 | 0 | fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> { |
182 | 0 | udwf.aliases().iter().for_each(|alias| { |
183 | 0 | self.window_functions |
184 | 0 | .insert(alias.clone(), Arc::clone(&udwf)); |
185 | 0 | }); |
186 | 0 | Ok(self.window_functions.insert(udwf.name().into(), udwf)) |
187 | 0 | } |
188 | 0 | fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> { |
189 | 0 | udf.aliases().iter().for_each(|alias| { |
190 | 0 | self.scalar_functions |
191 | 0 | .insert(alias.clone(), Arc::clone(&udf)); |
192 | 0 | }); |
193 | 0 | Ok(self.scalar_functions.insert(udf.name().into(), udf)) |
194 | 0 | } |
195 | | |
196 | 0 | fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> { |
197 | 0 | vec![] |
198 | 0 | } |
199 | | } |
200 | | |
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    // Minimal config extension with a single `value` option defaulting to 42.
    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
        }
    }

    // Options of this extension are addressed as `test.<name>`.
    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// A config extension registered before the [`TaskContext`] is built must
    /// still be reachable, with its overridden value, through the context.
    #[test]
    fn task_context_extensions() -> Result<()> {
        // Register the extension, then override its default through the
        // string-keyed configuration API.
        let mut registered = Extensions::new();
        registered.insert(TestExtension::default());
        let mut options = ConfigOptions::new().with_extensions(registered);
        options.set("test.value", "24")?;

        let ctx = TaskContext::new(
            Some(String::from("task_id")),
            String::from("session_id"),
            SessionConfig::from(options),
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
            Arc::new(RuntimeEnv::default()),
        );

        let found = ctx
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>();
        assert!(found.is_some());

        assert_eq!(found.unwrap().value, 24);

        Ok(())
    }
}