Commit

Refactor the TaskDefinition by changing encoding execution plan to the decoded one (#817)

* Revert "Only decode plan in `LaunchMultiTaskParams` once (#743)"

This reverts commit 4e4842c.

* Refactor the TaskDefinition by changing encoding execution plan to the decoded one

* Refine the error handling of run_task in the executor_server

---------

Co-authored-by: yangzhong <[email protected]>
yahoNanJing and kyotoYaho authored Jun 28, 2023
1 parent 553b9a7 commit d7a808c
Showing 5 changed files with 272 additions and 280 deletions.
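
The heart of the change: before this commit, TaskDefinition carried the execution plan as encoded protobuf bytes and each task decoded those bytes on its own; after it, the plan is decoded once while the proto message is converted, and each task receives an already-decoded plan (plus a metrics-reset copy in the multi-task path shown in from_proto.rs below). A minimal sketch of the decode-once idea, using a hypothetical DecodedPlan stand-in rather than Ballista's real types:

use std::sync::Arc;

// DecodedPlan is a hypothetical stand-in for Arc<dyn ExecutionPlan>,
// not a Ballista type.
#[derive(Debug)]
struct DecodedPlan(String);

// Stand-in for the expensive protobuf -> physical-plan decoding step.
fn decode(bytes: &[u8]) -> DecodedPlan {
    DecodedPlan(String::from_utf8_lossy(bytes).into_owned())
}

fn main() {
    let encoded = b"plan-bytes".to_vec();

    // Before: every task carried `encoded` and decoded it independently.
    let before: Vec<DecodedPlan> = (0..4).map(|_| decode(&encoded)).collect();

    // After: decode once, then hand each task a cheap shared handle.
    let shared = Arc::new(decode(&encoded));
    let after: Vec<Arc<DecodedPlan>> = (0..4).map(|_| Arc::clone(&shared)).collect();

    assert_eq!(before.len(), after.len());
}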
194 changes: 135 additions & 59 deletions ballista/core/src/serde/scheduler/from_proto.rs
@@ -16,10 +16,15 @@
 // under the License.
 
 use chrono::{TimeZone, Utc};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::metrics::{
     Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
 };
-use datafusion::physical_plan::Metric;
+use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use datafusion_proto::physical_plan::AsExecutionPlan;
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::sync::Arc;
@@ -28,10 +33,10 @@ use std::time::Duration;
 use crate::error::BallistaError;
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
 };
 
-use crate::serde::protobuf;
+use crate::serde::{protobuf, BallistaCodec};
 use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<Action> for protobuf::Action {
@@ -269,67 +274,138 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
 }
 }
 
-impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
-    type Error = BallistaError;
-
-    fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
-
-        Ok((
-            TaskDefinition {
-                task_id: self.task_id as usize,
-                task_attempt_num: self.task_attempt_num as usize,
-                job_id: self.job_id,
-                stage_id: self.stage_id as usize,
-                stage_attempt_num: self.stage_attempt_num as usize,
-                partition_id: self.partition_id as usize,
-                plan: vec![],
-                session_id: self.session_id,
-                launch_time: self.launch_time,
-                props,
-            },
-            self.plan,
-        ))
-    }
-}
-
-impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
-    type Error = BallistaError;
-
-    fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
-
-        let plan = self.plan;
-        let session_id = self.session_id;
-        let job_id = self.job_id;
-        let stage_id = self.stage_id as usize;
-        let stage_attempt_num = self.stage_attempt_num as usize;
-        let launch_time = self.launch_time;
-        let task_ids = self.task_ids;
-
-        Ok((
-            task_ids
-                .iter()
-                .map(|task_id| TaskDefinition {
-                    task_id: task_id.task_id as usize,
-                    task_attempt_num: task_id.task_attempt_num as usize,
-                    job_id: job_id.clone(),
-                    stage_id,
-                    stage_attempt_num,
-                    partition_id: task_id.partition_id as usize,
-                    plan: vec![],
-                    session_id: session_id.clone(),
-                    launch_time,
-                    props: props.clone(),
-                })
-                .collect(),
-            plan,
-        ))
-    }
-}
+pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
+    task: protobuf::TaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<TaskDefinition, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in task.props {
+        props.insert(kv_pair.key, kv_pair.value);
+    }
+    let props = Arc::new(props);
+
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+    }
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
+
+    let encoded_plan = task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
+
+    let job_id = task.job_id;
+    let stage_id = task.stage_id as usize;
+    let partition_id = task.partition_id as usize;
+    let task_attempt_num = task.task_attempt_num as usize;
+    let stage_attempt_num = task.stage_attempt_num as usize;
+    let launch_time = task.launch_time;
+    let task_id = task.task_id as usize;
+    let session_id = task.session_id;
+
+    Ok(TaskDefinition {
+        task_id,
+        task_attempt_num,
+        job_id,
+        stage_id,
+        stage_attempt_num,
+        partition_id,
+        plan,
+        launch_time,
+        session_id,
+        props,
+        function_registry,
+    })
+}
+
+pub fn get_task_definition_vec<
+    T: 'static + AsLogicalPlan,
+    U: 'static + AsExecutionPlan,
+>(
+    multi_task: protobuf::MultiTaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<Vec<TaskDefinition>, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in multi_task.props {
+        props.insert(kv_pair.key, kv_pair.value);
+    }
+    let props = Arc::new(props);
+
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+    }
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
+
+    let encoded_plan = multi_task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
+
+    let job_id = multi_task.job_id;
+    let stage_id = multi_task.stage_id as usize;
+    let stage_attempt_num = multi_task.stage_attempt_num as usize;
+    let launch_time = multi_task.launch_time;
+    let task_ids = multi_task.task_ids;
+    let session_id = multi_task.session_id;
+
+    task_ids
+        .iter()
+        .map(|task_id| {
+            Ok(TaskDefinition {
+                task_id: task_id.task_id as usize,
+                task_attempt_num: task_id.task_attempt_num as usize,
+                job_id: job_id.clone(),
+                stage_id,
+                stage_attempt_num,
+                partition_id: task_id.partition_id as usize,
+                plan: reset_metrics_for_execution_plan(plan.clone())?,
+                launch_time,
+                session_id: session_id.clone(),
+                props: props.clone(),
+                function_registry: function_registry.clone(),
+            })
+        })
+        .collect()
+}
+
+fn reset_metrics_for_execution_plan(
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
+    plan.transform(&|plan| {
+        let children = plan.children().clone();
+        plan.with_new_children(children).map(Transformed::Yes)
+    })
+    .map_err(BallistaError::DataFusionError)
+}
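
The reset_metrics_for_execution_plan helper above relies on with_new_children constructing a fresh operator instance whose metrics start out empty, so rebuilding every node yields a plan with clean metrics for each task. A minimal sketch of that rebuild-the-tree idea on a hypothetical Node type (not a DataFusion type):

// Node is a hypothetical stand-in for a plan operator; `metrics` models the
// per-node state that rebuilding discards.
#[derive(Clone, Debug)]
struct Node {
    metrics: u64,
    children: Vec<Node>,
}

// Rebuild every node bottom-up, the way plan.transform re-creates operators
// via with_new_children; a freshly constructed node has empty metrics.
fn rebuild(node: &Node) -> Node {
    Node {
        metrics: 0,
        children: node.children.iter().map(rebuild).collect(),
    }
}

fn main() {
    let used = Node {
        metrics: 42,
        children: vec![Node { metrics: 7, children: vec![] }],
    };
    let fresh = rebuild(&used);
    assert_eq!(fresh.metrics, 0);
    assert_eq!(fresh.children[0].metrics, 0);
}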
46 changes: 42 additions & 4 deletions ballista/core/src/serde/scheduler/mod.rs
@@ -15,12 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::HashSet;
+use std::fmt::Debug;
 use std::{collections::HashMap, fmt, sync::Arc};
 
 use datafusion::arrow::array::{
     ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder,
 };
 use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::common::DataFusionError;
+use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_plan::Partitioning;
 use serde::Serialize;
@@ -271,16 +276,49 @@ impl ExecutePartitionResult {
 }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
 pub struct TaskDefinition {
     pub task_id: usize,
     pub task_attempt_num: usize,
     pub job_id: String,
     pub stage_id: usize,
     pub stage_attempt_num: usize,
     pub partition_id: usize,
-    pub plan: Vec<u8>,
-    pub session_id: String,
+    pub plan: Arc<dyn ExecutionPlan>,
     pub launch_time: u64,
-    pub props: HashMap<String, String>,
+    pub session_id: String,
+    pub props: Arc<HashMap<String, String>>,
+    pub function_registry: Arc<SimpleFunctionRegistry>,
 }
 
+#[derive(Debug)]
+pub struct SimpleFunctionRegistry {
+    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+}
+
+impl FunctionRegistry for SimpleFunctionRegistry {
+    fn udfs(&self) -> HashSet<String> {
+        self.scalar_functions.keys().cloned().collect()
+    }
+
+    fn udf(&self, name: &str) -> datafusion::common::Result<Arc<ScalarUDF>> {
+        let result = self.scalar_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
+
+    fn udaf(&self, name: &str) -> datafusion::common::Result<Arc<AggregateUDF>> {
+        let result = self.aggregate_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDAF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
+}
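
SimpleFunctionRegistry is what lets the executor-side decoder resolve UDF and UDAF names referenced by the encoded plan. A minimal standalone sketch of the same lookup behavior, written as a hypothetical free function rather than the trait impl above:

use std::collections::HashMap;
use std::sync::Arc;

use datafusion::common::{DataFusionError, Result};
use datafusion::logical_expr::ScalarUDF;

// Hypothetical helper mirroring SimpleFunctionRegistry::udf: resolve a
// scalar UDF by name or fail with the same Internal error the diff uses.
fn lookup_udf(
    scalar_functions: &HashMap<String, Arc<ScalarUDF>>,
    name: &str,
) -> Result<Arc<ScalarUDF>> {
    scalar_functions.get(name).cloned().ok_or_else(|| {
        DataFusionError::Internal(format!(
            "There is no UDF named \"{name}\" in the TaskContext"
        ))
    })
}

fn main() {
    // With no UDFs registered, the lookup fails with the Internal error.
    let funcs: HashMap<String, Arc<ScalarUDF>> = HashMap::new();
    assert!(lookup_udf(&funcs, "my_udf").is_err());
}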
33 changes: 2 additions & 31 deletions ballista/core/src/serde/scheduler/to_proto.rs
@@ -26,12 +26,10 @@ use datafusion_proto::protobuf as datafusion_protobuf;
 
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats,
 };
 use datafusion::physical_plan::Partitioning;
-use protobuf::{
-    action::ActionType, operator_metric, KeyValuePair, NamedCount, NamedGauge, NamedTime,
-};
+use protobuf::{action::ActionType, operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<protobuf::Action> for Action {
     type Error = BallistaError;
@@ -242,30 +240,3 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
 }
 }
 }
-
-#[allow(clippy::from_over_into)]
-impl Into<protobuf::TaskDefinition> for TaskDefinition {
-    fn into(self) -> protobuf::TaskDefinition {
-        let props = self
-            .props
-            .iter()
-            .map(|(k, v)| KeyValuePair {
-                key: k.to_owned(),
-                value: v.to_owned(),
-            })
-            .collect::<Vec<_>>();
-
-        protobuf::TaskDefinition {
-            task_id: self.task_id as u32,
-            task_attempt_num: self.task_attempt_num as u32,
-            job_id: self.job_id,
-            stage_id: self.stage_id as u32,
-            stage_attempt_num: self.stage_attempt_num as u32,
-            partition_id: self.partition_id as u32,
-            plan: self.plan,
-            session_id: self.session_id,
-            launch_time: self.launch_time,
-            props,
-        }
-    }
-}
7 changes: 0 additions & 7 deletions ballista/executor/src/execution_engine.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::SchemaRef;
 use async_trait::async_trait;
 use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf::ShuffleWritePartition;
@@ -52,8 +51,6 @@ pub trait QueryStageExecutor: Sync + Send + Debug {
     ) -> Result<Vec<ShuffleWritePartition>>;
 
     fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
-
-    fn schema(&self) -> SchemaRef;
 }
 
 pub struct DefaultExecutionEngine {}
@@ -111,10 +108,6 @@ impl QueryStageExecutor for DefaultQueryStageExec {
             .await
     }
 
-    fn schema(&self) -> SchemaRef {
-        self.shuffle_writer.schema()
-    }
-
     fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
         utils::collect_plan_metrics(&self.shuffle_writer)
     }
(The diff for the fifth changed file did not load; per the commit message it refines the error handling of run_task in the executor_server.)