diff --git a/ballista/core/src/serde/scheduler/from_proto.rs b/ballista/core/src/serde/scheduler/from_proto.rs
index 17875e2b1..545896d8b 100644
--- a/ballista/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/core/src/serde/scheduler/from_proto.rs
@@ -16,10 +16,15 @@
 // under the License.
 
 use chrono::{TimeZone, Utc};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::metrics::{
     Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
 };
-use datafusion::physical_plan::Metric;
+use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use datafusion_proto::physical_plan::AsExecutionPlan;
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::sync::Arc;
@@ -28,10 +33,10 @@ use std::time::Duration;
 use crate::error::BallistaError;
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
 };
 
-use crate::serde::protobuf;
+use crate::serde::{protobuf, BallistaCodec};
 use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<Action> for protobuf::Action {
@@ -269,67 +274,138 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
     }
 }
 
-impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
-    type Error = BallistaError;
-
-    fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
+    task: protobuf::TaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<TaskDefinition, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in task.props {
+        props.insert(kv_pair.key, kv_pair.value);
+    }
+    let props = Arc::new(props);
 
-        Ok((
-            TaskDefinition {
-                task_id: self.task_id as usize,
-                task_attempt_num: self.task_attempt_num as usize,
-                job_id: self.job_id,
-                stage_id: self.stage_id as usize,
-                stage_attempt_num: self.stage_attempt_num as usize,
-                partition_id: self.partition_id as usize,
-                plan: vec![],
-                session_id: self.session_id,
-                launch_time: self.launch_time,
-                props,
-            },
-            self.plan,
-        ))
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
     }
-}
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
 
-impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
-    type Error = BallistaError;
+    let encoded_plan = task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
 
-    fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
-        let mut props = HashMap::new();
-        for kv_pair in self.props {
-            props.insert(kv_pair.key, kv_pair.value);
-        }
+    let job_id = task.job_id;
+    let stage_id = task.stage_id as usize;
+    let partition_id = task.partition_id as usize;
+    let task_attempt_num = task.task_attempt_num as usize;
+    let stage_attempt_num = task.stage_attempt_num as usize;
+    let launch_time = task.launch_time;
+    let task_id = task.task_id as usize;
+    let session_id = task.session_id;
 
-        let plan = self.plan;
-        let session_id = self.session_id;
-        let job_id = self.job_id;
-        let stage_id = self.stage_id as usize;
-        let stage_attempt_num = self.stage_attempt_num as usize;
-        let launch_time = self.launch_time;
-        let task_ids = self.task_ids;
+    Ok(TaskDefinition {
+        task_id,
+        task_attempt_num,
+        job_id,
+        stage_id,
+        stage_attempt_num,
+        partition_id,
+        plan,
+        launch_time,
+        session_id,
+        props,
+        function_registry,
+    })
+}
 
-        Ok((
-            task_ids
-                .iter()
-                .map(|task_id| TaskDefinition {
-                    task_id: task_id.task_id as usize,
-                    task_attempt_num: task_id.task_attempt_num as usize,
-                    job_id: job_id.clone(),
-                    stage_id,
-                    stage_attempt_num,
-                    partition_id: task_id.partition_id as usize,
-                    plan: vec![],
-                    session_id: session_id.clone(),
-                    launch_time,
-                    props: props.clone(),
-                })
-                .collect(),
-            plan,
-        ))
+pub fn get_task_definition_vec<
+    T: 'static + AsLogicalPlan,
+    U: 'static + AsExecutionPlan,
+>(
+    multi_task: protobuf::MultiTaskDefinition,
+    runtime: Arc<RuntimeEnv>,
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+    codec: BallistaCodec<T, U>,
+) -> Result<Vec<TaskDefinition>, BallistaError> {
+    let mut props = HashMap::new();
+    for kv_pair in multi_task.props {
+        props.insert(kv_pair.key, kv_pair.value);
     }
+    let props = Arc::new(props);
+
+    let mut task_scalar_functions = HashMap::new();
+    let mut task_aggregate_functions = HashMap::new();
+    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+    for scalar_func in scalar_functions {
+        task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+    }
+    for agg_func in aggregate_functions {
+        task_aggregate_functions.insert(agg_func.0, agg_func.1);
+    }
+    let function_registry = Arc::new(SimpleFunctionRegistry {
+        scalar_functions: task_scalar_functions,
+        aggregate_functions: task_aggregate_functions,
+    });
+
+    let encoded_plan = multi_task.plan.as_slice();
+    let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+        proto.try_into_physical_plan(
+            function_registry.as_ref(),
+            runtime.as_ref(),
+            codec.physical_extension_codec(),
+        )
+    })?;
+
+    let job_id = multi_task.job_id;
+    let stage_id = multi_task.stage_id as usize;
+    let stage_attempt_num = multi_task.stage_attempt_num as usize;
+    let launch_time = multi_task.launch_time;
+    let task_ids = multi_task.task_ids;
+    let session_id = multi_task.session_id;
+
+    task_ids
+        .iter()
+        .map(|task_id| {
+            Ok(TaskDefinition {
+                task_id: task_id.task_id as usize,
+                task_attempt_num: task_id.task_attempt_num as usize,
+                job_id: job_id.clone(),
+                stage_id,
+                stage_attempt_num,
+                partition_id: task_id.partition_id as usize,
+                plan: reset_metrics_for_execution_plan(plan.clone())?,
+                launch_time,
+                session_id: session_id.clone(),
+                props: props.clone(),
+                function_registry: function_registry.clone(),
+            })
+        })
+        .collect()
+}
+
+fn reset_metrics_for_execution_plan(
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
+    plan.transform(&|plan| {
+        let children = plan.children().clone();
+        plan.with_new_children(children).map(Transformed::Yes)
+    })
+    .map_err(BallistaError::DataFusionError)
 }
diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs
index 6e9440a36..96c4e0fa7 100644
--- a/ballista/core/src/serde/scheduler/mod.rs
+++ b/ballista/core/src/serde/scheduler/mod.rs
@@ -15,12 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::HashSet;
+use std::fmt::Debug;
 use std::{collections::HashMap, fmt, sync::Arc};
 
 use datafusion::arrow::array::{
     ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder,
 };
 use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::common::DataFusionError;
+use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_plan::Partitioning;
 use serde::Serialize;
@@ -271,7 +276,7 @@ impl ExecutePartitionResult {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
 pub struct TaskDefinition {
     pub task_id: usize,
     pub task_attempt_num: usize,
@@ -279,8 +284,41 @@ pub struct TaskDefinition {
     pub stage_id: usize,
     pub stage_attempt_num: usize,
     pub partition_id: usize,
-    pub plan: Vec<u8>,
-    pub session_id: String,
+    pub plan: Arc<dyn ExecutionPlan>,
     pub launch_time: u64,
-    pub props: HashMap<String, String>,
+    pub session_id: String,
+    pub props: Arc<HashMap<String, String>>,
+    pub function_registry: Arc<SimpleFunctionRegistry>,
+}
+
+#[derive(Debug)]
+pub struct SimpleFunctionRegistry {
+    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+}
+
+impl FunctionRegistry for SimpleFunctionRegistry {
+    fn udfs(&self) -> HashSet<String> {
+        self.scalar_functions.keys().cloned().collect()
+    }
+
+    fn udf(&self, name: &str) -> datafusion::common::Result<Arc<ScalarUDF>> {
+        let result = self.scalar_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
+
+    fn udaf(&self, name: &str) -> datafusion::common::Result<Arc<AggregateUDF>> {
+        let result = self.aggregate_functions.get(name);
+
+        result.cloned().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "There is no UDAF named \"{name}\" in the TaskContext"
+            ))
+        })
+    }
 }
diff --git a/ballista/core/src/serde/scheduler/to_proto.rs b/ballista/core/src/serde/scheduler/to_proto.rs
index ccb5ec427..6ceb1dd6e 100644
--- a/ballista/core/src/serde/scheduler/to_proto.rs
+++ b/ballista/core/src/serde/scheduler/to_proto.rs
@@ -26,12 +26,10 @@ use datafusion_proto::protobuf as datafusion_protobuf;
 
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats,
 };
 use datafusion::physical_plan::Partitioning;
-use protobuf::{
-    action::ActionType, operator_metric, KeyValuePair, NamedCount, NamedGauge, NamedTime,
-};
+use protobuf::{action::ActionType, operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<protobuf::Action> for Action {
     type Error = BallistaError;
@@ -242,30 +240,3 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
         }
     }
 }
-
-#[allow(clippy::from_over_into)]
-impl Into<protobuf::TaskDefinition> for TaskDefinition {
-    fn into(self) -> protobuf::TaskDefinition {
-        let props = self
-            .props
-            .iter()
-            .map(|(k, v)| KeyValuePair {
-                key: k.to_owned(),
-                value: v.to_owned(),
-            })
-            .collect::<Vec<_>>();
-
-        protobuf::TaskDefinition {
-            task_id: self.task_id as u32,
-            task_attempt_num: self.task_attempt_num as u32,
-            job_id: self.job_id,
-            stage_id: self.stage_id as u32,
-            stage_attempt_num: self.stage_attempt_num as u32,
-            partition_id: self.partition_id as u32,
-            plan: self.plan,
-            session_id: self.session_id,
-            launch_time: self.launch_time,
-            props,
-        }
-    }
-}
diff --git a/ballista/executor/src/execution_engine.rs b/ballista/executor/src/execution_engine.rs
index 965153298..5121f016b 100644
--- a/ballista/executor/src/execution_engine.rs
+++ b/ballista/executor/src/execution_engine.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::SchemaRef;
 use async_trait::async_trait;
 use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf::ShuffleWritePartition;
@@ -52,8 +51,6 @@ pub trait QueryStageExecutor: Sync + Send + Debug {
     ) -> Result<Vec<ShuffleWritePartition>>;
 
     fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
-
-    fn schema(&self) -> SchemaRef;
 }
 
 pub struct DefaultExecutionEngine {}
@@ -111,10 +108,6 @@ impl QueryStageExecutor for DefaultQueryStageExec {
             .await
     }
 
-    fn schema(&self) -> SchemaRef {
-        self.shuffle_writer.schema()
-    }
-
     fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
         utils::collect_plan_metrics(&self.shuffle_writer)
     }
diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 9102923f6..2892cb0b1 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -16,11 +16,8 @@
 // under the License.
 
 use ballista_core::BALLISTA_VERSION;
-use datafusion::config::ConfigOptions;
-use datafusion::prelude::SessionConfig;
 use std::collections::HashMap;
 use std::convert::TryInto;
-use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -41,18 +38,22 @@ use ballista_core::serde::protobuf::{
     LaunchTaskResult, RegisterExecutorParams, RemoveJobDataParams, RemoveJobDataResult,
     StopExecutorParams, StopExecutorResult, TaskStatus, UpdateTaskStatusParams,
 };
+use ballista_core::serde::scheduler::from_proto::{
+    get_task_definition, get_task_definition_vec,
+};
 use ballista_core::serde::scheduler::PartitionId;
 use ballista_core::serde::scheduler::TaskDefinition;
 use ballista_core::serde::BallistaCodec;
 use ballista_core::utils::{create_grpc_client_connection, create_grpc_server};
 use dashmap::DashMap;
-use datafusion::execution::context::TaskContext;
+use datafusion::config::ConfigOptions;
+use datafusion::execution::TaskContext;
+use datafusion::prelude::SessionConfig;
 use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan};
 use tokio::sync::mpsc::error::TryRecvError;
 use tokio::task::JoinHandle;
 
 use crate::cpu_bound_executor::DedicatedExecutor;
-use crate::execution_engine::QueryStageExecutor;
 use crate::executor::Executor;
 use crate::executor_process::ExecutorProcessConfig;
 use crate::shutdown::ShutdownNotifier;
@@ -65,8 +66,7 @@ type SchedulerClients = Arc<DashMap<String, SchedulerGrpcClient<Channel>>>;
 #[derive(Debug)]
 struct CuratorTaskDefinition {
     scheduler_id: String,
-    plan: Vec<u8>,
-    tasks: Vec<TaskDefinition>,
+    task: TaskDefinition,
 }
 
 /// Wrap TaskStatus with its curator scheduler id for task update to its specific curator scheduler later
@@ -298,100 +298,22 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T, U>
-    async fn decode_task(
-        &self,
-        curator_task: TaskDefinition,
-        plan: &[u8],
-    ) -> Result<Arc<dyn QueryStageExecutor>, BallistaError> {
-        let task = curator_task;
-        let task_identity = task_identity(&task);
-        let task_props = task.props;
-        let mut config = ConfigOptions::new();
-        for (k, v) in task_props {
-            config.set(&k, &v)?;
-        }
-        let session_config = SessionConfig::from(config);
-
-        let mut task_scalar_functions = HashMap::new();
-        let mut task_aggregate_functions = HashMap::new();
-        for scalar_func in self.executor.scalar_functions.clone() {
-            task_scalar_functions.insert(scalar_func.0, scalar_func.1);
-        }
-        for agg_func in self.executor.aggregate_functions.clone() {
-            task_aggregate_functions.insert(agg_func.0, agg_func.1);
-        }
-
-        let task_context = Arc::new(TaskContext::new(
-            Some(task_identity),
-            task.session_id.clone(),
-            session_config,
-            task_scalar_functions,
-            task_aggregate_functions,
-            self.executor.runtime.clone(),
-        ));
-
-        let plan = U::try_decode(plan).and_then(|proto| {
-            proto.try_into_physical_plan(
-                task_context.deref(),
-                &self.executor.runtime,
-                self.codec.physical_extension_codec(),
-            )
-        })?;
-
-        Ok(self.executor.execution_engine.create_query_stage_exec(
-            task.job_id,
-            task.stage_id,
-            plan,
-            &self.executor.work_dir,
-        )?)
-    }
-
-    async fn run_task(
-        &self,
-        task_identity: &str,
-        scheduler_id: String,
-        curator_task: TaskDefinition,
-        query_stage_exec: Arc<dyn QueryStageExecutor>,
-    ) -> Result<(), BallistaError> {
+    /// This method should not return Err. If task fails, a failure task status should be sent
+    /// to the channel to notify the scheduler.
+    async fn run_task(&self, task_identity: String, curator_task: CuratorTaskDefinition) {
         let start_exec_time = SystemTime::now()
             .duration_since(UNIX_EPOCH)
             .unwrap()
             .as_millis() as u64;
         info!("Start to run task {}", task_identity);
-        let task = curator_task;
-        let task_props = task.props;
-        let mut config = ConfigOptions::new();
-        for (k, v) in task_props {
-            config.set(&k, &v)?;
-        }
-        let session_config = SessionConfig::from(config);
-
-        let mut task_scalar_functions = HashMap::new();
-        let mut task_aggregate_functions = HashMap::new();
-        // TODO combine the functions from Executor's functions and TaskDefintion's function resources
-        for scalar_func in self.executor.scalar_functions.clone() {
-            task_scalar_functions.insert(scalar_func.0, scalar_func.1);
-        }
-        for agg_func in self.executor.aggregate_functions.clone() {
-            task_aggregate_functions.insert(agg_func.0, agg_func.1);
-        }
-
-        let session_id = task.session_id;
-        let runtime = self.executor.runtime.clone();
-        let task_context = Arc::new(TaskContext::new(
-            Some(task_identity.to_string()),
-            session_id,
-            session_config,
-            task_scalar_functions,
-            task_aggregate_functions,
-            runtime.clone(),
-        ));
+        let task = curator_task.task;
         let task_id = task.task_id;
         let job_id = task.job_id;
         let stage_id = task.stage_id;
         let stage_attempt_num = task.stage_attempt_num;
         let partition_id = task.partition_id;
+        let plan = task.plan;
 
         let part = PartitionId {
             job_id: job_id.clone(),
@@ -399,6 +321,40 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T, U>
-            .collect::<Result<Vec<_>, BallistaError>>()?;
+            .collect::<Result<Vec<_>, BallistaError>>()
+        {
+            Ok(metrics) => Some(metrics),
+            Err(_) => None,
+        };
         let executor_id = &self.executor.metadata.id;
         let end_exec_time = SystemTime::now()
@@ -436,10 +396,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T, U>
     executor_server: Arc<ExecutorServer<T, U>>,
 }
 
-fn task_identity(task: &TaskDefinition) -> String {
-    format!(
-        "TID {} {}/{}.{}/{}.{}",
-        &task.task_id,
-        &task.job_id,
-        &task.stage_id,
-        &task.stage_attempt_num,
-        &task.partition_id,
-        &task.task_attempt_num,
-    )
-}
-
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T, U> {
     fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
         Self { executor_server }
     }
@@ -638,64 +586,22 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T, U>
-                    let out: tokio::task::JoinHandle<
-                        Result<Arc<dyn QueryStageExecutor>, BallistaError>,
-                    > = dedicated_executor.spawn(async move {
-                        server.decode_task(curator_task, &plan).await
+                    dedicated_executor.spawn(async move {
+                        server.run_task(task_identity.clone(), curator_task).await;
                     });
-
-                    let plan = out.await;
-
-                    let plan = match plan {
-                        Ok(Ok(plan)) => plan,
-                        Ok(Err(e)) => {
-                            error!(
-                                "Failed to decode the plan of task {:?} due to {:?}",
-                                task_identity(&task.tasks[0]),
-                                e
-                            );
-                            return;
-                        }
-                        Err(e) => {
-                            error!(
-                                "Failed to receive error plan of task {:?} due to {:?}",
-                                task_identity(&task.tasks[0]),
-                                e
-                            );
-                            return;
-                        }
-                    };
-                    let scheduler_id = task.scheduler_id.clone();
-
-                    for curator_task in task.tasks {
-                        let plan = plan.clone();
-                        let scheduler_id = scheduler_id.clone();
-
-                        let task_identity = task_identity(&curator_task);
-                        info!("Received task {:?}", &task_identity);
-
-                        let server = executor_server.clone();
-                        dedicated_executor.spawn(async move {
-                            server
-                                .run_task(
-                                    &task_identity,
-                                    scheduler_id,
-                                    curator_task,
-                                    plan,
-                                )
-                                .await
-                                .unwrap_or_else(|e| {
-                                    error!(
-                                        "Fail to run the task {:?} due to {:?}",
-                                        task_identity, e
-                                    );
-                                });
-                        });
-                    }
                 } else {
                     info!("Channel is closed and will exit the task receive loop");
                     drop(task_runner_complete);
@@ -720,15 +626,17 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc for ExecutorServer<T, U>
         } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for task in tasks {
-            let (task_def, plan) = task
-                .try_into()
-                .map_err(|e| Status::invalid_argument(format!("{e}")))?;
-
             task_sender
                 .send(CuratorTaskDefinition {
                     scheduler_id: scheduler_id.clone(),
-                    plan,
-                    tasks: vec![task_def],
+                    task: get_task_definition(
+                        task,
+                        self.executor.runtime.clone(),
+                        self.executor.scalar_functions.clone(),
+                        self.executor.aggregate_functions.clone(),
+                        self.codec.clone(),
+                    )
+                    .map_err(|e| Status::invalid_argument(format!("{e}")))?,
                 })
                 .await
                 .unwrap();
@@ -748,17 +656,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc for ExecutorServer<T, U>
         } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for multi_task in multi_tasks {
-            let (multi_task, plan): (Vec<TaskDefinition>, Vec<u8>) = multi_task
-                .try_into()
-                .map_err(|e| Status::invalid_argument(format!("{e}")))?;
-            task_sender
-                .send(CuratorTaskDefinition {
-                    scheduler_id: scheduler_id.clone(),
-                    plan,
-                    tasks: multi_task,
-                })
-                .await
-                .unwrap();
+            let multi_task: Vec<TaskDefinition> = get_task_definition_vec(
+                multi_task,
+                self.executor.runtime.clone(),
+                self.executor.scalar_functions.clone(),
+                self.executor.aggregate_functions.clone(),
+                self.codec.clone(),
+            )
+            .map_err(|e| Status::invalid_argument(format!("{e}")))?;
+            for task in multi_task {
+                task_sender
+                    .send(CuratorTaskDefinition {
+                        scheduler_id: scheduler_id.clone(),
+                        task,
+                    })
+                    .await
+                    .unwrap();
+            }
         }
         Ok(Response::new(LaunchMultiTaskResult { success: true }))
     }