Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the TaskDefinition by changing encoding execution plan to the decoded one #817

Merged
merged 3 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 135 additions & 59 deletions ballista/core/src/serde/scheduler/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
// under the License.

use chrono::{TimeZone, Utc};
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
use datafusion::physical_plan::metrics::{
Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
};
use datafusion::physical_plan::Metric;
use datafusion::physical_plan::{ExecutionPlan, Metric};
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use std::collections::HashMap;
use std::convert::TryInto;
use std::sync::Arc;
Expand All @@ -28,10 +33,10 @@ use std::time::Duration;
use crate::error::BallistaError;
use crate::serde::scheduler::{
Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
PartitionLocation, PartitionStats, TaskDefinition,
PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
};

use crate::serde::protobuf;
use crate::serde::{protobuf, BallistaCodec};
use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};

impl TryInto<Action> for protobuf::Action {
Expand Down Expand Up @@ -269,67 +274,138 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
}
}

impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
type Error = BallistaError;

fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
let mut props = HashMap::new();
for kv_pair in self.props {
props.insert(kv_pair.key, kv_pair.value);
}
pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
task: protobuf::TaskDefinition,
runtime: Arc<RuntimeEnv>,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
codec: BallistaCodec<T, U>,
) -> Result<TaskDefinition, BallistaError> {
let mut props = HashMap::new();
for kv_pair in task.props {
props.insert(kv_pair.key, kv_pair.value);
}
let props = Arc::new(props);

Ok((
TaskDefinition {
task_id: self.task_id as usize,
task_attempt_num: self.task_attempt_num as usize,
job_id: self.job_id,
stage_id: self.stage_id as usize,
stage_attempt_num: self.stage_attempt_num as usize,
partition_id: self.partition_id as usize,
plan: vec![],
session_id: self.session_id,
launch_time: self.launch_time,
props,
},
self.plan,
))
let mut task_scalar_functions = HashMap::new();
let mut task_aggregate_functions = HashMap::new();
// TODO combine the functions from Executor's functions and TaskDefinition's function resources
for scalar_func in scalar_functions {
task_scalar_functions.insert(scalar_func.0, scalar_func.1);
}
}
for agg_func in aggregate_functions {
task_aggregate_functions.insert(agg_func.0, agg_func.1);
}
let function_registry = Arc::new(SimpleFunctionRegistry {
scalar_functions: task_scalar_functions,
aggregate_functions: task_aggregate_functions,
});

impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
type Error = BallistaError;
let encoded_plan = task.plan.as_slice();
let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
function_registry.as_ref(),
runtime.as_ref(),
codec.physical_extension_codec(),
)
})?;

fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
let mut props = HashMap::new();
for kv_pair in self.props {
props.insert(kv_pair.key, kv_pair.value);
}
let job_id = task.job_id;
let stage_id = task.stage_id as usize;
let partition_id = task.partition_id as usize;
let task_attempt_num = task.task_attempt_num as usize;
let stage_attempt_num = task.stage_attempt_num as usize;
let launch_time = task.launch_time;
let task_id = task.task_id as usize;
let session_id = task.session_id;

let plan = self.plan;
let session_id = self.session_id;
let job_id = self.job_id;
let stage_id = self.stage_id as usize;
let stage_attempt_num = self.stage_attempt_num as usize;
let launch_time = self.launch_time;
let task_ids = self.task_ids;
Ok(TaskDefinition {
task_id,
task_attempt_num,
job_id,
stage_id,
stage_attempt_num,
partition_id,
plan,
launch_time,
session_id,
props,
function_registry,
})
}

Ok((
task_ids
.iter()
.map(|task_id| TaskDefinition {
task_id: task_id.task_id as usize,
task_attempt_num: task_id.task_attempt_num as usize,
job_id: job_id.clone(),
stage_id,
stage_attempt_num,
partition_id: task_id.partition_id as usize,
plan: vec![],
session_id: session_id.clone(),
launch_time,
props: props.clone(),
})
.collect(),
plan,
))
pub fn get_task_definition_vec<
T: 'static + AsLogicalPlan,
U: 'static + AsExecutionPlan,
>(
multi_task: protobuf::MultiTaskDefinition,
runtime: Arc<RuntimeEnv>,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
codec: BallistaCodec<T, U>,
) -> Result<Vec<TaskDefinition>, BallistaError> {
let mut props = HashMap::new();
for kv_pair in multi_task.props {
props.insert(kv_pair.key, kv_pair.value);
}
let props = Arc::new(props);

let mut task_scalar_functions = HashMap::new();
let mut task_aggregate_functions = HashMap::new();
// TODO combine the functions from Executor's functions and TaskDefinition's function resources
for scalar_func in scalar_functions {
task_scalar_functions.insert(scalar_func.0, scalar_func.1);
}
for agg_func in aggregate_functions {
task_aggregate_functions.insert(agg_func.0, agg_func.1);
}
let function_registry = Arc::new(SimpleFunctionRegistry {
scalar_functions: task_scalar_functions,
aggregate_functions: task_aggregate_functions,
});

let encoded_plan = multi_task.plan.as_slice();
let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
function_registry.as_ref(),
runtime.as_ref(),
codec.physical_extension_codec(),
)
})?;

let job_id = multi_task.job_id;
let stage_id = multi_task.stage_id as usize;
let stage_attempt_num = multi_task.stage_attempt_num as usize;
let launch_time = multi_task.launch_time;
let task_ids = multi_task.task_ids;
let session_id = multi_task.session_id;

task_ids
.iter()
.map(|task_id| {
Ok(TaskDefinition {
task_id: task_id.task_id as usize,
task_attempt_num: task_id.task_attempt_num as usize,
job_id: job_id.clone(),
stage_id,
stage_attempt_num,
partition_id: task_id.partition_id as usize,
plan: reset_metrics_for_execution_plan(plan.clone())?,
launch_time,
session_id: session_id.clone(),
props: props.clone(),
function_registry: function_registry.clone(),
})
})
.collect()
}

fn reset_metrics_for_execution_plan(
plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
plan.transform(&|plan| {
let children = plan.children().clone();
plan.with_new_children(children).map(Transformed::Yes)
})
.map_err(BallistaError::DataFusionError)
}
46 changes: 42 additions & 4 deletions ballista/core/src/serde/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashSet;
use std::fmt::Debug;
use std::{collections::HashMap, fmt, sync::Arc};

use datafusion::arrow::array::{
ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder,
};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::DataFusionError;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
use serde::Serialize;
Expand Down Expand Up @@ -271,16 +276,49 @@ impl ExecutePartitionResult {
}
}

#[derive(Debug, Clone)]
#[derive(Clone, Debug)]
pub struct TaskDefinition {
pub task_id: usize,
pub task_attempt_num: usize,
pub job_id: String,
pub stage_id: usize,
pub stage_attempt_num: usize,
pub partition_id: usize,
pub plan: Vec<u8>,
pub session_id: String,
pub plan: Arc<dyn ExecutionPlan>,
pub launch_time: u64,
pub props: HashMap<String, String>,
pub session_id: String,
pub props: Arc<HashMap<String, String>>,
pub function_registry: Arc<SimpleFunctionRegistry>,
}

#[derive(Debug)]
pub struct SimpleFunctionRegistry {
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
}

impl FunctionRegistry for SimpleFunctionRegistry {
fn udfs(&self) -> HashSet<String> {
self.scalar_functions.keys().cloned().collect()
}

fn udf(&self, name: &str) -> datafusion::common::Result<Arc<ScalarUDF>> {
let result = self.scalar_functions.get(name);

result.cloned().ok_or_else(|| {
DataFusionError::Internal(format!(
"There is no UDF named \"{name}\" in the TaskContext"
))
})
}

fn udaf(&self, name: &str) -> datafusion::common::Result<Arc<AggregateUDF>> {
let result = self.aggregate_functions.get(name);

result.cloned().ok_or_else(|| {
DataFusionError::Internal(format!(
"There is no UDAF named \"{name}\" in the TaskContext"
))
})
}
}
33 changes: 2 additions & 31 deletions ballista/core/src/serde/scheduler/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ use datafusion_proto::protobuf as datafusion_protobuf;

use crate::serde::scheduler::{
Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
PartitionLocation, PartitionStats, TaskDefinition,
PartitionLocation, PartitionStats,
};
use datafusion::physical_plan::Partitioning;
use protobuf::{
action::ActionType, operator_metric, KeyValuePair, NamedCount, NamedGauge, NamedTime,
};
use protobuf::{action::ActionType, operator_metric, NamedCount, NamedGauge, NamedTime};

impl TryInto<protobuf::Action> for Action {
type Error = BallistaError;
Expand Down Expand Up @@ -242,30 +240,3 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
}
}
}

#[allow(clippy::from_over_into)]
impl Into<protobuf::TaskDefinition> for TaskDefinition {
fn into(self) -> protobuf::TaskDefinition {
let props = self
.props
.iter()
.map(|(k, v)| KeyValuePair {
key: k.to_owned(),
value: v.to_owned(),
})
.collect::<Vec<_>>();

protobuf::TaskDefinition {
task_id: self.task_id as u32,
task_attempt_num: self.task_attempt_num as u32,
job_id: self.job_id,
stage_id: self.stage_id as u32,
stage_attempt_num: self.stage_attempt_num as u32,
partition_id: self.partition_id as u32,
plan: self.plan,
session_id: self.session_id,
launch_time: self.launch_time,
props,
}
}
}
7 changes: 0 additions & 7 deletions ballista/executor/src/execution_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use ballista_core::execution_plans::ShuffleWriterExec;
use ballista_core::serde::protobuf::ShuffleWritePartition;
Expand Down Expand Up @@ -52,8 +51,6 @@ pub trait QueryStageExecutor: Sync + Send + Debug {
) -> Result<Vec<ShuffleWritePartition>>;

fn collect_plan_metrics(&self) -> Vec<MetricsSet>;

fn schema(&self) -> SchemaRef;
}

pub struct DefaultExecutionEngine {}
Expand Down Expand Up @@ -111,10 +108,6 @@ impl QueryStageExecutor for DefaultQueryStageExec {
.await
}

fn schema(&self) -> SchemaRef {
self.shuffle_writer.schema()
}

fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
utils::collect_plan_metrics(&self.shuffle_writer)
}
Expand Down
Loading