Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move TaskContext to datafusion-execution #5677

Merged
merged 4 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,11 @@ pub trait ExtensionOptions: Send + Sync + std::fmt::Debug + 'static {
pub struct Extensions(BTreeMap<&'static str, ExtensionBox>);

impl Extensions {
/// Create a new, empty [`Extensions`]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was part of #5676 that I didn't want to lose

/// Construct an [`Extensions`] that holds no registered extension options.
pub fn new() -> Self {
    let registry = BTreeMap::new();
    Self(registry)
}

/// Registers a [`ConfigExtension`] with this [`Extensions`]
pub fn insert<T: ConfigExtension>(&mut self, extension: T) {
assert_ne!(T::PREFIX, "datafusion");
Expand Down
161 changes: 11 additions & 150 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ use crate::physical_plan::PhysicalPlanner;
use crate::variable::{VarProvider, VarType};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::{config::Extensions, OwnedTableReference};
use datafusion_common::OwnedTableReference;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
Expand All @@ -94,7 +94,6 @@ use url::Url;
use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA};
use crate::catalog::listing_schema::ListingSchemaProvider;
use crate::datasource::object_store::ObjectStoreUrl;
use crate::execution::memory_pool::MemoryPool;
use crate::physical_optimizer::global_sort_selection::GlobalSortSelection;
use crate::physical_optimizer::pipeline_checker::PipelineChecker;
use crate::physical_optimizer::pipeline_fixer::PipelineFixer;
Expand All @@ -105,6 +104,7 @@ use uuid::Uuid;

// backwards compatibility
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;

use super::options::{
AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, ReadOptions,
Expand Down Expand Up @@ -1802,74 +1802,6 @@ impl OptimizerConfig for SessionState {
}
}

/// Task Execution Context
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR simply moves this code into the `datafusion_execution` crate.

/// Task execution context.
///
/// Bundles the state a single task carries: the identifiers of the session
/// and (optionally) the task, the session configuration, the registered
/// scalar and aggregate user-defined functions, and the runtime environment.
pub struct TaskContext {
/// Identifier of the session this task belongs to
session_id: String,
/// Optional task identifier (`Some` when built through `try_new`)
task_id: Option<String>,
/// Session configuration
session_config: SessionConfig,
/// Scalar functions associated with this task context
// NOTE(review): presumably keyed by function name — confirm against the registry lookup code
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions associated with this task context
// NOTE(review): presumably keyed by function name — confirm against the registry lookup code
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Runtime environment associated with this task context
runtime: Arc<RuntimeEnv>,
}

impl TaskContext {
    /// Build a [`TaskContext`] for an identified task.
    ///
    /// The session configuration starts from a fresh [`ConfigOptions`] that
    /// carries `extensions`; every `(key, value)` pair in `task_props` is then
    /// applied on top of it with `ConfigOptions::set`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `ConfigOptions::set` for the first
    /// property it rejects.
    pub fn try_new(
        task_id: String,
        session_id: String,
        task_props: HashMap<String, String>,
        scalar_functions: HashMap<String, Arc<ScalarUDF>>,
        aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
        runtime: Arc<RuntimeEnv>,
        extensions: Extensions,
    ) -> Result<Self> {
        let mut options = ConfigOptions::new().with_extensions(extensions);
        // Stop at (and propagate) the first property that fails to apply.
        task_props
            .into_iter()
            .try_for_each(|(key, value)| options.set(&key, &value))?;

        let session_config: SessionConfig = options.into();
        Ok(Self {
            session_id,
            task_id: Some(task_id),
            session_config,
            scalar_functions,
            aggregate_functions,
            runtime,
        })
    }

    /// Borrow the [`SessionConfig`] this task runs with.
    pub fn session_config(&self) -> &SessionConfig {
        &self.session_config
    }

    /// Return an owned copy of the `session_id` of this [TaskContext].
    pub fn session_id(&self) -> String {
        self.session_id.clone()
    }

    /// Return an owned copy of the `task_id` of this [TaskContext], if any.
    pub fn task_id(&self) -> Option<String> {
        self.task_id.clone()
    }

    /// Borrow the [`MemoryPool`] held by this task's runtime environment.
    pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
        &self.runtime.memory_pool
    }

    /// Return a new handle to the shared [RuntimeEnv] of this [TaskContext].
    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
        Arc::clone(&self.runtime)
    }
}

/// Create a new task context instance from SessionContext
impl From<&SessionContext> for TaskContext {
fn from(session: &SessionContext) -> Self {
Expand All @@ -1880,45 +1812,15 @@ impl From<&SessionContext> for TaskContext {
/// Create a new task context instance from SessionState
impl From<&SessionState> for TaskContext {
fn from(state: &SessionState) -> Self {
let session_id = state.session_id.clone();
let session_config = state.config.clone();
let scalar_functions = state.scalar_functions.clone();
let aggregate_functions = state.aggregate_functions.clone();
let runtime = state.runtime_env.clone();
Self {
task_id: None,
session_id,
session_config,
scalar_functions,
aggregate_functions,
runtime,
}
}
}

impl FunctionRegistry for TaskContext {
fn udfs(&self) -> HashSet<String> {
self.scalar_functions.keys().cloned().collect()
}

fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
let result = self.scalar_functions.get(name);

result.cloned().ok_or_else(|| {
DataFusionError::Internal(format!(
"There is no UDF named \"{name}\" in the TaskContext"
))
})
}

fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
let result = self.aggregate_functions.get(name);

result.cloned().ok_or_else(|| {
DataFusionError::Internal(format!(
"There is no UDAF named \"{name}\" in the TaskContext"
))
})
let task_id = None;
TaskContext::new(
task_id,
state.session_id.clone(),
state.config.clone(),
state.scalar_functions.clone(),
state.aggregate_functions.clone(),
state.runtime_env.clone(),
)
}
}

Expand All @@ -1936,8 +1838,6 @@ mod tests {
use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_common::config::ConfigExtension;
use datafusion_common::extensions_options;
use datafusion_expr::{create_udaf, create_udf, Expr, Volatility};
use datafusion_physical_expr::functions::make_scalar_function;
use std::fs::File;
Expand Down Expand Up @@ -2605,43 +2505,4 @@ mod tests {
.unwrap()
}
}

// Minimal config extension used by the `task_context_extensions` test.
// It declares a single option, `value`, whose default is 42.
extensions_options! {
struct TestExtension {
// The only option of this extension; defaults to 42.
value: usize, default = 42
}
}

impl ConfigExtension for TestExtension {
// NOTE(review): presumably the key namespace for this extension's options —
// the `task_context_extensions` test addresses the field as `test.value`.
const PREFIX: &'static str = "test";
}

/// Checks that a task property addressed to a registered extension
/// (`test.value`) replaces that extension's default value when a
/// [`TaskContext`] is built through [`TaskContext::try_new`].
#[test]
fn task_context_extensions() -> Result<()> {
    // Register the extension with its default settings.
    let mut extensions = Extensions::default();
    extensions.insert(TestExtension::default());

    // Override the extension's option through the string-typed task properties.
    let task_props =
        HashMap::from([(String::from("test.value"), String::from("24"))]);

    let task_context = TaskContext::try_new(
        String::from("task_id"),
        String::from("session_id"),
        task_props,
        HashMap::default(),
        HashMap::default(),
        Arc::new(RuntimeEnv::default()),
        extensions,
    )?;

    let session_config = task_context.session_config();
    let extension = session_config.options().extensions.get::<TestExtension>();

    // The extension must be present and must carry the overridden value.
    assert!(extension.is_some());
    assert_eq!(extension.unwrap().value, 24);

    Ok(())
}
}
3 changes: 3 additions & 0 deletions datafusion/execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ pub mod memory_pool;
pub mod object_store;
pub mod registry;
pub mod runtime_env;
mod task;

pub use task::TaskContext;
Loading