Skip to content

Commit

Permalink
Cache encoded stage plan (#393)
Browse files Browse the repository at this point in the history
* Cache encoded stage plan

Co-authored-by: yangzhong <[email protected]>
  • Loading branch information
yahoNanJing and kyotoYaho authored Oct 19, 2022
1 parent 5121053 commit f888506
Showing 1 changed file with 86 additions and 35 deletions.
121 changes: 86 additions & 35 deletions ballista/scheduler/src/state/task_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use crate::state::execution_graph::{
use crate::state::executor_manager::{ExecutorManager, ExecutorReservation};
use crate::state::{decode_protobuf, encode_protobuf, with_lock, with_locks};
use ballista_core::config::BallistaConfig;
#[cfg(not(test))]
use ballista_core::error::BallistaError;
use ballista_core::error::Result;

Expand All @@ -47,7 +46,7 @@ use std::sync::Arc;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
type ExecutionGraphCache = Arc<DashMap<String, Arc<RwLock<ExecutionGraph>>>>;
type ActiveJobCache = Arc<DashMap<String, JobInfoCache>>;

// TODO move to configuration file
/// Default max failure attempts for task level retry
Expand All @@ -61,8 +60,25 @@ pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
session_builder: SessionBuilder,
codec: BallistaCodec<T, U>,
scheduler_id: String,
// Cache for active jobs curated by this scheduler
active_job_cache: ActiveJobCache,
}

#[derive(Clone)]
struct JobInfoCache {
// Cache for active execution graphs curated by this scheduler
active_job_cache: ExecutionGraphCache,
execution_graph: Arc<RwLock<ExecutionGraph>>,
// Cache for encoded execution stage plan to avoid duplicated encoding for multiple tasks
encoded_stage_plans: HashMap<usize, Vec<u8>>,
}

impl JobInfoCache {
fn new(graph: ExecutionGraph) -> Self {
Self {
execution_graph: Arc::new(RwLock::new(graph)),
encoded_stage_plans: HashMap::new(),
}
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -113,7 +129,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>

graph.revive();
self.active_job_cache
.insert(job_id.to_owned(), Arc::new(RwLock::new(graph)));
.insert(job_id.to_owned(), JobInfoCache::new(graph));

Ok(())
}
Expand Down Expand Up @@ -266,8 +282,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
let mut pending_tasks = 0usize;
let mut assign_tasks = 0usize;
for pairs in self.active_job_cache.iter() {
let (_job_id, graph) = pairs.pair();
let mut graph = graph.write().await;
let (_job_id, job_info) = pairs.pair();
let mut graph = job_info.execution_graph.write().await;
for reservation in free_reservations.iter().skip(assign_tasks) {
if let Some(task) = graph.pop_next_task(&reservation.executor_id)? {
assignments.push((reservation.executor_id.clone(), task));
Expand Down Expand Up @@ -476,8 +492,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
let updated_graphs: DashMap<String, ExecutionGraph> = DashMap::new();
{
for pairs in self.active_job_cache.iter() {
let (job_id, graph) = pairs.pair();
let mut graph = graph.write().await;
let (job_id, job_info) = pairs.pair();
let mut graph = job_info.execution_graph.write().await;
let reset = graph.reset_stages_on_lost_executor(executor_id)?;
if !reset.0.is_empty() {
updated_graphs.insert(job_id.to_owned(), graph.clone());
Expand Down Expand Up @@ -557,39 +573,74 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
task: TaskDescription,
) -> Result<TaskDefinition> {
debug!("Preparing task definition for {:?}", task);
let mut plan_buf: Vec<u8> = vec![];
let plan_proto =
U::try_from_physical_plan(task.plan, self.codec.physical_extension_codec())?;
plan_proto.try_encode(&mut plan_buf)?;

let output_partitioning =
hash_partitioning_to_proto(task.output_partitioning.as_ref())?;

let task_definition = TaskDefinition {
task_id: task.task_id as u32,
task_attempt_num: task.task_attempt as u32,
job_id: task.partition.job_id.clone(),
stage_id: task.partition.stage_id as u32,
stage_attempt_num: task.stage_attempt_num as u32,
partition_id: task.partition.partition_id as u32,
plan: plan_buf,
output_partitioning,
session_id: task.session_id,
launch_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
props: vec![],
};
Ok(task_definition)

let job_id = task.partition.job_id.clone();
let stage_id = task.partition.stage_id;

if let Some(mut job_info) = self.active_job_cache.get_mut(&job_id) {
let plan = if let Some(plan) = job_info.encoded_stage_plans.get(&stage_id) {
plan.clone()
} else {
let mut plan_buf: Vec<u8> = vec![];
let plan_proto = U::try_from_physical_plan(
task.plan,
self.codec.physical_extension_codec(),
)?;
plan_proto.try_encode(&mut plan_buf)?;

job_info
.encoded_stage_plans
.insert(stage_id, plan_buf.clone());

plan_buf
};

let output_partitioning =
hash_partitioning_to_proto(task.output_partitioning.as_ref())?;

let task_definition = TaskDefinition {
task_id: task.task_id as u32,
task_attempt_num: task.task_attempt as u32,
job_id,
stage_id: stage_id as u32,
stage_attempt_num: task.stage_attempt_num as u32,
partition_id: task.partition.partition_id as u32,
plan,
output_partitioning,
session_id: task.session_id,
launch_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
props: vec![],
};
Ok(task_definition)
} else {
Err(BallistaError::General(format!(
"Cannot prepare task definition for job {} which is not in active cache",
job_id
)))
}
}

/// Get the `ExecutionGraph` for the given job ID from cache
///
/// Looks up `job_id` in `active_job_cache` and returns a clone of the shared
/// graph handle held by that entry, or `None` when the job has no entry.
pub(crate) async fn get_active_execution_graph(
    &self,
    job_id: &str,
) -> Option<Arc<RwLock<ExecutionGraph>>> {
    self.active_job_cache
        .get(job_id)
        // Cache entries are `JobInfoCache` values; callers only need the graph handle.
        .map(|job_info| Arc::clone(&job_info.execution_graph))
}

/// Evict the entry for the given job ID from the active job cache
///
/// Returns the `ExecutionGraph` handle that the evicted entry held, or `None`
/// when the job ID has no entry. The rest of the entry is dropped.
pub(crate) async fn remove_active_execution_graph(
    &self,
    job_id: &str,
) -> Option<Arc<RwLock<ExecutionGraph>>> {
    // `remove` yields the (key, value) pair; only the graph handle is handed back.
    let (_removed_id, job_info) = self.active_job_cache.remove(job_id)?;
    Some(job_info.execution_graph)
}

/// Remove the `ExecutionGraph` for the given job ID from cache
Expand Down Expand Up @@ -658,7 +709,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>

async fn clean_up_job_data(
state: Arc<dyn StateBackendClient>,
active_job_cache: ExecutionGraphCache,
active_job_cache: ActiveJobCache,
failed: bool,
job_id: String,
) -> Result<()> {
Expand Down

0 comments on commit f888506

Please sign in to comment.