Skip to content

Commit

Permalink
Remove TaskProperties / KV structure (#4382)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb authored Nov 29, 2022
1 parent 02da32e commit 1438bc4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 70 deletions.
110 changes: 41 additions & 69 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1810,22 +1810,14 @@ impl FunctionRegistry for SessionState {
}
}

/// Task Context Properties
///
/// Holds the configuration attached to a task in one of two forms:
/// either an already-built [`SessionConfig`] or a map of string
/// name-value pairs.
pub enum TaskProperties {
/// A ready-made [`SessionConfig`] value
SessionConfig(SessionConfig),
/// Name-value pairs of task properties (both names and values are strings)
KVPairs(HashMap<String, String>),
}

/// Task Execution Context
pub struct TaskContext {
/// Session Id
session_id: String,
/// Optional task identifier
task_id: Option<String>,
/// Task properties
properties: TaskProperties,
/// Session configuration
session_config: SessionConfig,
/// Scalar functions associated with this task context
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions associated with this task context
Expand All @@ -1844,55 +1836,52 @@ impl TaskContext {
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
runtime: Arc<RuntimeEnv>,
) -> Self {
let session_config = if task_props.is_empty() {
SessionConfig::new()
} else {
SessionConfig::new()
.with_batch_size(task_props.get(OPT_BATCH_SIZE).unwrap().parse().unwrap())
.with_target_partitions(
task_props.get(TARGET_PARTITIONS).unwrap().parse().unwrap(),
)
.with_repartition_joins(
task_props.get(REPARTITION_JOINS).unwrap().parse().unwrap(),
)
.with_repartition_aggregations(
task_props
.get(REPARTITION_AGGREGATIONS)
.unwrap()
.parse()
.unwrap(),
)
.with_repartition_windows(
task_props
.get(REPARTITION_WINDOWS)
.unwrap()
.parse()
.unwrap(),
)
.with_parquet_pruning(
task_props.get(PARQUET_PRUNING).unwrap().parse().unwrap(),
)
.with_collect_statistics(
task_props.get(COLLECT_STATISTICS).unwrap().parse().unwrap(),
)
};

Self {
task_id: Some(task_id),
session_id,
properties: TaskProperties::KVPairs(task_props),
session_config,
scalar_functions,
aggregate_functions,
runtime,
}
}

/// Return the SessionConfig associated with the Task
pub fn session_config(&self) -> SessionConfig {
let task_props = &self.properties;
match task_props {
TaskProperties::KVPairs(props) => {
let session_config = SessionConfig::new();
if props.is_empty() {
session_config
} else {
session_config
.with_batch_size(
props.get(OPT_BATCH_SIZE).unwrap().parse().unwrap(),
)
.with_target_partitions(
props.get(TARGET_PARTITIONS).unwrap().parse().unwrap(),
)
.with_repartition_joins(
props.get(REPARTITION_JOINS).unwrap().parse().unwrap(),
)
.with_repartition_aggregations(
props
.get(REPARTITION_AGGREGATIONS)
.unwrap()
.parse()
.unwrap(),
)
.with_repartition_windows(
props.get(REPARTITION_WINDOWS).unwrap().parse().unwrap(),
)
.with_parquet_pruning(
props.get(PARQUET_PRUNING).unwrap().parse().unwrap(),
)
.with_collect_statistics(
props.get(COLLECT_STATISTICS).unwrap().parse().unwrap(),
)
}
}
TaskProperties::SessionConfig(session_config) => session_config.clone(),
}
pub fn session_config(&self) -> &SessionConfig {
&self.session_config
}

/// Return the session_id of this [TaskContext]
Expand All @@ -1914,39 +1903,22 @@ impl TaskContext {
/// Create a new task context instance from SessionContext
impl From<&SessionContext> for TaskContext {
fn from(session: &SessionContext) -> Self {
let session_id = session.session_id.clone();
let (config, scalar_functions, aggregate_functions) = {
let session_state = session.state.read();
(
session_state.config.clone(),
session_state.scalar_functions.clone(),
session_state.aggregate_functions.clone(),
)
};
let runtime = session.runtime_env();
Self {
task_id: None,
session_id,
properties: TaskProperties::SessionConfig(config),
scalar_functions,
aggregate_functions,
runtime,
}
TaskContext::from(&*session.state.read())
}
}

/// Create a new task context instance from SessionState
impl From<&SessionState> for TaskContext {
fn from(state: &SessionState) -> Self {
let session_id = state.session_id.clone();
let config = state.config.clone();
let session_config = state.config.clone();
let scalar_functions = state.scalar_functions.clone();
let aggregate_functions = state.aggregate_functions.clone();
let runtime = state.runtime_env.clone();
Self {
task_id: None,
session_id,
properties: TaskProperties::SessionConfig(config),
session_config,
scalar_functions,
aggregate_functions,
runtime,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ async fn do_sort(
schema.clone(),
expr,
metrics_set,
Arc::new(context.session_config()),
Arc::new(context.session_config().clone()),
context.runtime_env(),
fetch,
);
Expand Down

0 comments on commit 1438bc4

Please sign in to comment.