diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 52e46d6e8..cd6081a98 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -402,7 +402,10 @@ jobs: # # ignore ./Cargo.toml because putting workspaces in multi-line lists make it easy to read ci/scripts/rust_toml_fmt.sh - git diff --exit-code + if test -f "./Cargo.toml.bak"; then + echo "cargo tomlfmt found format violations" + exit 1 + fi env: CARGO_HOME: "/github/home/.cargo" CARGO_TARGET_DIR: "/github/home/target" diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto index 0f122b850..c169791d0 100644 --- a/ballista/core/proto/ballista.proto +++ b/ballista/core/proto/ballista.proto @@ -289,6 +289,7 @@ message ExecutorMetadata { ExecutorSpecification specification = 5; } + // Used by grpc message ExecutorRegistration { string id = 1; @@ -336,6 +337,15 @@ message ExecutorResource { } } +message AvailableTaskSlots { + string executor_id = 1; + uint32 slots = 2; + } + +message ExecutorTaskSlots { + repeated AvailableTaskSlots task_slots = 1; +} + message ExecutorData { string executor_id = 1; repeated ExecutorResourcePair resources = 2; @@ -544,18 +554,33 @@ message GetJobStatusParams { message SuccessfulJob { repeated PartitionLocation partition_location = 1; + uint64 queued_at = 2; + uint64 started_at = 3; + uint64 ended_at = 4; } -message QueuedJob {} +message QueuedJob { + uint64 queued_at = 1; +} // TODO: add progress report -message RunningJob {} +message RunningJob { + uint64 queued_at = 1; + uint64 started_at = 2; + string scheduler = 3; +} message FailedJob { string error = 1; + uint64 queued_at = 2; + uint64 started_at = 3; + uint64 ended_at = 4; } message JobStatus { + string job_id = 5; + string job_name = 6; + oneof status { QueuedJob queued = 1; RunningJob running = 2; diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs index c23b43e5a..28236ad04 100644 --- a/ballista/core/src/serde/generated/ballista.rs +++ b/ballista/core/src/serde/generated/ballista.rs @@ -581,6 +581,20 @@ pub mod executor_resource { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct AvailableTaskSlots { + #[prost(string, tag = "1")] + pub executor_id: ::prost::alloc::string::String, + #[prost(uint32, tag = "2")] + pub slots: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ExecutorTaskSlots { + #[prost(message, repeated, tag = "1")] + pub task_slots: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecutorData { #[prost(string, tag = "1")] pub executor_id: ::prost::alloc::string::String, @@ -933,23 +947,49 @@ pub struct GetJobStatusParams { pub struct SuccessfulJob { #[prost(message, repeated, tag = "1")] pub partition_location: ::prost::alloc::vec::Vec, + #[prost(uint64, tag = "2")] + pub queued_at: u64, + #[prost(uint64, tag = "3")] + pub started_at: u64, + #[prost(uint64, tag = "4")] + pub ended_at: u64, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct QueuedJob {} +pub struct QueuedJob { + #[prost(uint64, tag = "1")] + pub queued_at: u64, +} /// TODO: add progress report #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct RunningJob {} +pub struct RunningJob { + #[prost(uint64, tag = "1")] + pub queued_at: u64, 
+ #[prost(uint64, tag = "2")] + pub started_at: u64, + #[prost(string, tag = "3")] + pub scheduler: ::prost::alloc::string::String, +} #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FailedJob { #[prost(string, tag = "1")] pub error: ::prost::alloc::string::String, + #[prost(uint64, tag = "2")] + pub queued_at: u64, + #[prost(uint64, tag = "3")] + pub started_at: u64, + #[prost(uint64, tag = "4")] + pub ended_at: u64, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct JobStatus { + #[prost(string, tag = "5")] + pub job_id: ::prost::alloc::string::String, + #[prost(string, tag = "6")] + pub job_name: ::prost::alloc::string::String, #[prost(oneof = "job_status::Status", tags = "1, 2, 3, 4")] pub status: ::core::option::Option, } diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 1b2770182..dd4dc162d 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -26,11 +26,13 @@ use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; use datafusion_proto::common::proto_error; use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning; +use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use datafusion_proto::{ convert_required, logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec}, physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}, }; + use prost::Message; use std::fmt::Debug; use std::marker::PhantomData; @@ -69,16 +71,17 @@ pub fn decode_protobuf(bytes: &[u8]) -> Result { } #[derive(Clone, Debug)] -pub struct BallistaCodec { +pub struct BallistaCodec< + T: 'static + AsLogicalPlan = LogicalPlanNode, + U: 'static + AsExecutionPlan = PhysicalPlanNode, +> { logical_extension_codec: Arc, physical_extension_codec: Arc, logical_plan_repr: PhantomData, physical_plan_repr: PhantomData, } -impl Default - for BallistaCodec -{ +impl Default for BallistaCodec { fn default() -> Self { Self { logical_extension_codec: Arc::new(DefaultLogicalExtensionCodec {}), diff --git a/ballista/scheduler/scheduler_config_spec.toml b/ballista/scheduler/scheduler_config_spec.toml index a03977711..f6575b437 100644 --- a/ballista/scheduler/scheduler_config_spec.toml +++ b/ballista/scheduler/scheduler_config_spec.toml @@ -31,17 +31,10 @@ doc = "Route for proxying flight results via scheduler. Should be of the form 'I [[param]] abbr = "b" -name = "config_backend" -type = "ballista_scheduler::state::backend::StateBackend" -doc = "The configuration backend for the scheduler, possible values: etcd, memory, sled. Default: sled" -default = "ballista_scheduler::state::backend::StateBackend::Sled" - -[[param]] -abbr = "c" name = "cluster_backend" -type = "ballista_scheduler::state::backend::StateBackend" +type = "ballista_scheduler::cluster::ClusterStorage" doc = "The configuration backend for the scheduler cluster state, possible values: etcd, memory, sled. 
Default: sled" -default = "ballista_scheduler::state::backend::StateBackend::Sled" +default = "ballista_scheduler::cluster::ClusterStorage::Sled" [[param]] abbr = "n" diff --git a/ballista/scheduler/src/api/handlers.rs b/ballista/scheduler/src/api/handlers.rs index 63ae2436b..94b2c90d3 100644 --- a/ballista/scheduler/src/api/handlers.rs +++ b/ballista/scheduler/src/api/handlers.rs @@ -209,6 +209,7 @@ pub(crate) async fn get_query_stages( { Ok(warp::reply::json(&QueryStagesResponse { stages: graph + .as_ref() .stages() .iter() .map(|(id, stage)| { @@ -303,7 +304,7 @@ pub(crate) async fn get_job_dot_graph( .await .map_err(|_| warp::reject())? { - ExecutionGraphDot::generate(graph).map_err(|_| warp::reject()) + ExecutionGraphDot::generate(graph.as_ref()).map_err(|_| warp::reject()) } else { Ok("Not Found".to_string()) } @@ -322,7 +323,7 @@ pub(crate) async fn get_query_stage_dot_graph Result<()> { // parse options @@ -61,25 +55,14 @@ async fn main() -> Result<()> { std::process::exit(0); } - let config_backend = init_kv_backend(&opt.config_backend, &opt).await?; - - let cluster_state = if opt.cluster_backend == opt.config_backend { - Arc::new(DefaultClusterState::new(config_backend.clone())) - } else { - let cluster_kv_store = init_kv_backend(&opt.cluster_backend, &opt).await?; - - Arc::new(DefaultClusterState::new(cluster_kv_store)) - }; - let special_mod_log_level = opt.log_level_setting; - let namespace = opt.namespace; - let external_host = opt.external_host; - let bind_host = opt.bind_host; - let port = opt.bind_port; let log_dir = opt.log_dir; let print_thread_info = opt.print_thread_info; - let log_file_name_prefix = format!("scheduler_{namespace}_{external_host}_{port}"); - let scheduler_name = format!("{external_host}:{port}"); + + let log_file_name_prefix = format!( + "scheduler_{}_{}_{}", + opt.namespace, opt.external_host, opt.bind_port + ); let rust_log = env::var(EnvFilter::DEFAULT_ENV); let log_filter = EnvFilter::new(rust_log.unwrap_or(special_mod_log_level)); @@ -117,10 +100,13 @@ async fn main() -> Result<()> { .init(); } - let addr = format!("{bind_host}:{port}"); + let addr = format!("{}:{}", opt.bind_host, opt.bind_port); let addr = addr.parse()?; let config = SchedulerConfig { + namespace: opt.namespace, + external_host: opt.external_host, + bind_port: opt.bind_port, scheduling_policy: opt.scheduler_policy, event_loop_buffer_size: opt.event_loop_buffer_size, executor_slots_policy: opt.executor_slots_policy, @@ -129,54 +115,13 @@ async fn main() -> Result<()> { finished_job_state_clean_up_interval_seconds: opt .finished_job_state_clean_up_interval_seconds, advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint, + cluster_storage: ClusterStorageConfig::Memory, job_resubmit_interval_ms: (opt.job_resubmit_interval_ms > 0) .then_some(opt.job_resubmit_interval_ms), }; - start_server(scheduler_name, config_backend, cluster_state, addr, config).await?; - Ok(()) -} -async fn init_kv_backend( - backend: &StateBackend, - opt: &Config, -) -> Result> { - let cluster_backend: Arc = match backend { - #[cfg(feature = "etcd")] - StateBackend::Etcd => { - let etcd = etcd_client::Client::connect(&[opt.etcd_urls.clone()], None) - .await - .context("Could not connect to etcd")?; - Arc::new(EtcdClient::new(opt.namespace.clone(), etcd)) - } - #[cfg(not(feature = "etcd"))] - StateBackend::Etcd => { - unimplemented!( - "build the scheduler with the `etcd` feature to use the etcd config backend" - ) - } - #[cfg(feature = "sled")] - StateBackend::Sled => { - if opt.sled_dir.is_empty() { 
- Arc::new( - SledClient::try_new_temporary() - .context("Could not create sled config backend")?, - ) - } else { - println!("{}", opt.sled_dir); - Arc::new( - SledClient::try_new(opt.sled_dir.clone()) - .context("Could not create sled config backend")?, - ) - } - } - #[cfg(not(feature = "sled"))] - StateBackend::Sled => { - unimplemented!( - "build the scheduler with the `sled` feature to use the sled config backend" - ) - } - StateBackend::Memory => Arc::new(MemoryBackendClient::new()), - }; + let cluster = BallistaCluster::new_from_config(&config).await?; - Ok(cluster_backend) + start_server(cluster, addr, config).await?; + Ok(()) } diff --git a/ballista/scheduler/src/cluster/event/mod.rs b/ballista/scheduler/src/cluster/event/mod.rs new file mode 100644 index 000000000..88b3c2972 --- /dev/null +++ b/ballista/scheduler/src/cluster/event/mod.rs @@ -0,0 +1,318 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use futures::Stream; +use log::debug; +use parking_lot::RwLock; +use std::collections::BTreeMap; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; +use tokio::sync::broadcast; +use tokio::sync::broadcast::error::TryRecvError; + +// TODO make configurable +const EVENT_BUFFER_SIZE: usize = 256; + +static ID_GEN: AtomicUsize = AtomicUsize::new(0); + +#[derive(Default)] +struct Shared { + subscriptions: AtomicUsize, + wakers: RwLock>, +} + +impl Shared { + pub fn register(&self, subscriber_id: usize, waker: Waker) { + self.wakers.write().insert(subscriber_id, waker); + } + + pub fn deregister(&self, subscriber_id: usize) { + self.wakers.write().remove(&subscriber_id); + } + + pub fn notify(&self) { + let guard = self.wakers.read(); + for waker in guard.values() { + waker.wake_by_ref(); + } + } +} + +pub(crate) struct ClusterEventSender { + sender: broadcast::Sender, + shared: Arc, +} + +impl ClusterEventSender { + pub fn new(capacity: usize) -> Self { + let (sender, _) = broadcast::channel(capacity); + + Self { + sender, + shared: Arc::new(Shared::default()), + } + } + + pub fn send(&self, event: &T) { + if self.shared.subscriptions.load(Ordering::Acquire) > 0 { + if let Err(e) = self.sender.send(event.clone()) { + debug!("Failed to send event to channel: {}", e); + return; + } + + self.shared.notify(); + } + } + + pub fn subscribe(&self) -> EventSubscriber { + self.shared.subscriptions.fetch_add(1, Ordering::AcqRel); + let id = ID_GEN.fetch_add(1, Ordering::AcqRel); + + EventSubscriber { + id, + receiver: self.sender.subscribe(), + shared: self.shared.clone(), + registered: false, + } + } + + #[cfg(test)] + pub fn registered_wakers(&self) -> usize { + self.shared.wakers.read().len() + } +} + +impl Default for ClusterEventSender { + fn default() -> Self { + 
Self::new(EVENT_BUFFER_SIZE) + } +} + +pub struct EventSubscriber { + id: usize, + receiver: broadcast::Receiver, + shared: Arc, + registered: bool, +} + +impl EventSubscriber { + pub fn register(&mut self, waker: Waker) { + if !self.registered { + self.shared.register(self.id, waker); + self.registered = true; + } + } +} + +impl Drop for EventSubscriber { + fn drop(&mut self) { + self.shared.deregister(self.id); + } +} + +impl Stream for EventSubscriber { + type Item = T; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match self.receiver.try_recv() { + Ok(event) => { + self.register(cx.waker().clone()); + return Poll::Ready(Some(event)); + } + Err(TryRecvError::Closed) => return Poll::Ready(None), + Err(TryRecvError::Lagged(n)) => { + debug!("Subscriber lagged by {} message", n); + self.register(cx.waker().clone()); + continue; + } + Err(TryRecvError::Empty) => { + self.register(cx.waker().clone()); + return Poll::Pending; + } + } + } + } +} + +#[cfg(test)] +mod test { + use crate::cluster::event::{ClusterEventSender, EventSubscriber}; + use futures::stream::FuturesUnordered; + use futures::StreamExt; + + async fn collect_events(mut rx: EventSubscriber) -> Vec { + let mut events = vec![]; + while let Some(event) = rx.next().await { + events.push(event); + } + + events + } + + #[tokio::test] + async fn test_event_subscription() { + let sender = ClusterEventSender::new(100); + + let rx = vec![sender.subscribe(), sender.subscribe(), sender.subscribe()]; + + let mut tasks: FuturesUnordered<_> = rx + .into_iter() + .map(|rx| async move { collect_events(rx).await }) + .collect(); + + let handle = tokio::spawn(async move { + let mut results = vec![]; + while let Some(result) = tasks.next().await { + results.push(result) + } + results + }); + + tokio::spawn(async move { + for i in 0..100 { + sender.send(&i); + } + }); + + let expected: Vec = (0..100).into_iter().collect(); + + let results = handle.await.unwrap(); + assert_eq!(results.len(), 3); + + for res in results { + assert_eq!(res, expected); + } + } + + #[tokio::test] + async fn test_event_lagged() { + // Created sender with a buffer for only 8 events + let sender = ClusterEventSender::new(8); + + let rx = vec![sender.subscribe(), sender.subscribe(), sender.subscribe()]; + + let mut tasks: FuturesUnordered<_> = rx + .into_iter() + .map(|rx| async move { collect_events(rx).await }) + .collect(); + + let handle = tokio::spawn(async move { + let mut results = vec![]; + while let Some(result) = tasks.next().await { + results.push(result) + } + results + }); + + // Send events faster than they can be consumed by subscribers + tokio::spawn(async move { + for i in 0..100 { + sender.send(&i); + } + }); + + // When we reach capacity older events should be dropped so we only see + // the last 8 events in our subscribers + let expected: Vec = (92..100).into_iter().collect(); + + let results = handle.await.unwrap(); + assert_eq!(results.len(), 3); + + for res in results { + assert_eq!(res, expected); + } + } + + #[tokio::test] + async fn test_event_skip_unsubscribed() { + let sender = ClusterEventSender::new(100); + + // There are no subscribers yet so this event should be ignored + sender.send(&0); + + let rx = vec![sender.subscribe(), sender.subscribe(), sender.subscribe()]; + + let mut tasks: FuturesUnordered<_> = rx + .into_iter() + .map(|rx| async move { collect_events(rx).await }) + .collect(); + + let handle = tokio::spawn(async move { + let mut results = vec![]; + while let Some(result) = 
tasks.next().await { + results.push(result) + } + results + }); + + tokio::spawn(async move { + for i in 1..=100 { + sender.send(&i); + } + }); + + let expected: Vec = (1..=100).into_iter().collect(); + + let results = handle.await.unwrap(); + assert_eq!(results.len(), 3); + + for res in results { + assert_eq!(res, expected); + } + } + + #[tokio::test] + async fn test_event_register_wakers() { + let sender = ClusterEventSender::new(100); + + let mut rx_1 = sender.subscribe(); + let mut rx_2 = sender.subscribe(); + let mut rx_3 = sender.subscribe(); + + sender.send(&0); + + // Subscribers haven't been polled yet so expect not registered wakers + assert_eq!(sender.registered_wakers(), 0); + + let event = rx_1.next().await; + assert_eq!(event, Some(0)); + assert_eq!(sender.registered_wakers(), 1); + + let event = rx_2.next().await; + assert_eq!(event, Some(0)); + assert_eq!(sender.registered_wakers(), 2); + + let event = rx_3.next().await; + assert_eq!(event, Some(0)); + assert_eq!(sender.registered_wakers(), 3); + + drop(rx_1); + assert_eq!(sender.registered_wakers(), 2); + + drop(rx_2); + assert_eq!(sender.registered_wakers(), 1); + + drop(rx_3); + assert_eq!(sender.registered_wakers(), 0); + } +} diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs new file mode 100644 index 000000000..e28699778 --- /dev/null +++ b/ballista/scheduler/src/cluster/kv.rs @@ -0,0 +1,781 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
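// Illustrative, self-contained sketch (not part of the diff itself): the `KeyValueState`
// defined in this file keeps shared cluster state in a `KeyValueStore` and updates slot
// accounting with a lock -> get -> decode -> mutate -> encode -> put sequence (see
// `reserve_slots` below). The code here mirrors only the mutate step for the
// `TaskDistribution::Bias` policy, using plain Rust types and a simplified return shape
// as stand-ins for the prost-generated `ExecutorTaskSlots`; the real `reserve_slots_bias`
// helper is referenced but not shown in this diff, so treat this as an assumed reading
// of it, not the actual implementation.
use std::collections::HashMap;

#[derive(Clone, Debug)]
struct AvailableTaskSlots {
    executor_id: String,
    slots: u32,
}

/// Bias sketch: fill the request from the executors with the most free slots first,
/// draining each executor before moving on to the next.
fn reserve_bias(
    executors: &mut [AvailableTaskSlots],
    num_slots: u32,
) -> HashMap<String, u32> {
    // Matches the sort performed in `reserve_slots`: most available slots first.
    executors.sort_by(|a, b| b.slots.cmp(&a.slots));

    let mut reserved = HashMap::new();
    let mut remaining = num_slots;
    for entry in executors.iter_mut() {
        if remaining == 0 {
            break;
        }
        let take = entry.slots.min(remaining);
        if take > 0 {
            entry.slots -= take; // decrement the decoded state before it is re-encoded
            remaining -= take;
            reserved.insert(entry.executor_id.clone(), take);
        }
    }
    reserved
}

fn main() {
    let mut executors = vec![
        AvailableTaskSlots { executor_id: "exec-1".into(), slots: 4 },
        AvailableTaskSlots { executor_id: "exec-2".into(), slots: 2 },
    ];
    // Asking for 5 slots: 4 come from exec-1, the remaining 1 from exec-2.
    println!("{:?}", reserve_bias(&mut executors, 5));
    // The real implementation would now re-encode the mutated slots and `put` them
    // back into the store while still holding the distributed lock.
}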
+ +use crate::cluster::storage::{KeyValueStore, Keyspace, Lock, Operation, WatchEvent}; +use crate::cluster::{ + reserve_slots_bias, reserve_slots_round_robin, ClusterState, ExecutorHeartbeatStream, + JobState, JobStateEvent, JobStateEventStream, JobStatus, TaskDistribution, +}; +use crate::scheduler_server::SessionBuilder; +use crate::state::execution_graph::ExecutionGraph; +use crate::state::executor_manager::ExecutorReservation; +use crate::state::session_manager::create_datafusion_context; +use crate::state::{decode_into, decode_protobuf}; +use async_trait::async_trait; +use ballista_core::config::BallistaConfig; +use ballista_core::error::{BallistaError, Result}; +use ballista_core::serde::protobuf::job_status::Status; +use ballista_core::serde::protobuf::{ + self, AvailableTaskSlots, ExecutorHeartbeat, ExecutorTaskSlots, FailedJob, + KeyValuePair, QueuedJob, +}; +use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; +use ballista_core::serde::BallistaCodec; +use dashmap::DashMap; +use datafusion::prelude::SessionContext; +use datafusion_proto::logical_plan::AsLogicalPlan; +use datafusion_proto::physical_plan::AsExecutionPlan; +use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; +use futures::StreamExt; +use itertools::Itertools; +use log::warn; +use prost::Message; +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// State implementation based on underlying `KeyValueStore` +pub struct KeyValueState< + S: KeyValueStore, + T: 'static + AsLogicalPlan = LogicalPlanNode, + U: 'static + AsExecutionPlan = PhysicalPlanNode, +> { + /// Underlying `KeyValueStore` + store: S, + /// Codec used to serialize/deserialize execution plan + codec: BallistaCodec, + /// Name of current scheduler. Should be `{host}:{port}` + #[allow(dead_code)] + scheduler: String, + /// In-memory store of queued jobs. 
Map from Job ID -> (Job Name, queued_at timestamp) + queued_jobs: DashMap, + //// `SessionBuilder` for constructing `SessionContext` from stored `BallistaConfig` + session_builder: SessionBuilder, +} + +impl + KeyValueState +{ + pub fn new( + scheduler: impl Into, + store: S, + codec: BallistaCodec, + session_builder: SessionBuilder, + ) -> Self { + Self { + store, + scheduler: scheduler.into(), + codec, + queued_jobs: DashMap::new(), + session_builder, + } + } +} + +#[async_trait] +impl + ClusterState for KeyValueState +{ + async fn reserve_slots( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option>, + ) -> Result> { + let lock = self.store.lock(Keyspace::Slots, "global").await?; + + with_lock(lock, async { + let resources = self.store.get(Keyspace::Slots, "all").await?; + + let mut slots = + ExecutorTaskSlots::decode(resources.as_slice()).map_err(|err| { + BallistaError::Internal(format!( + "Unexpected value in executor slots state: {err:?}" + )) + })?; + + let mut available_slots: Vec<&mut AvailableTaskSlots> = slots + .task_slots + .iter_mut() + .filter_map(|data| { + (data.slots > 0 + && executors + .as_ref() + .map(|executors| executors.contains(&data.executor_id)) + .unwrap_or(true)) + .then_some(data) + }) + .collect(); + + available_slots.sort_by(|a, b| Ord::cmp(&b.slots, &a.slots)); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(available_slots, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(available_slots, num_slots) + } + }; + + if !reservations.is_empty() { + self.store + .put(Keyspace::Slots, "all".to_owned(), slots.encode_to_vec()) + .await? + } + + Ok(reservations) + }) + .await + } + + async fn reserve_slots_exact( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option>, + ) -> Result> { + let lock = self.store.lock(Keyspace::Slots, "global").await?; + + with_lock(lock, async { + let resources = self.store.get(Keyspace::Slots, "all").await?; + + let mut slots = + ExecutorTaskSlots::decode(resources.as_slice()).map_err(|err| { + BallistaError::Internal(format!( + "Unexpected value in executor slots state: {err:?}" + )) + })?; + + let mut available_slots: Vec<&mut AvailableTaskSlots> = slots + .task_slots + .iter_mut() + .filter_map(|data| { + (data.slots > 0 + && executors + .as_ref() + .map(|executors| executors.contains(&data.executor_id)) + .unwrap_or(true)) + .then_some(data) + }) + .collect(); + + available_slots.sort_by(|a, b| Ord::cmp(&b.slots, &a.slots)); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(available_slots, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(available_slots, num_slots) + } + }; + + if reservations.len() == num_slots as usize { + self.store + .put(Keyspace::Slots, "all".to_owned(), slots.encode_to_vec()) + .await?; + Ok(reservations) + } else { + Ok(vec![]) + } + }) + .await + } + + async fn cancel_reservations( + &self, + reservations: Vec, + ) -> Result<()> { + let lock = self.store.lock(Keyspace::Slots, "all").await?; + + with_lock(lock, async { + let resources = self.store.get(Keyspace::Slots, "all").await?; + + let mut slots = + ExecutorTaskSlots::decode(resources.as_slice()).map_err(|err| { + BallistaError::Internal(format!( + "Unexpected value in executor slots state: {err:?}" + )) + })?; + + let mut increments = HashMap::new(); + for ExecutorReservation { executor_id, .. 
} in reservations { + if let Some(inc) = increments.get_mut(&executor_id) { + *inc += 1; + } else { + increments.insert(executor_id, 1usize); + } + } + + for executor_slots in slots.task_slots.iter_mut() { + if let Some(slots) = increments.get(&executor_slots.executor_id) { + executor_slots.slots += *slots as u32; + } + } + + self.store + .put(Keyspace::Slots, "all".to_string(), slots.encode_to_vec()) + .await + }) + .await + } + + async fn register_executor( + &self, + metadata: ExecutorMetadata, + spec: ExecutorData, + reserve: bool, + ) -> Result> { + let executor_id = metadata.id.clone(); + + let current_ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| { + BallistaError::Internal(format!("Error getting current timestamp: {e:?}")) + })? + .as_secs(); + + //TODO this should be in a transaction + // Now that we know we can connect, save the metadata and slots + self.save_executor_metadata(metadata).await?; + self.save_executor_heartbeat(ExecutorHeartbeat { + executor_id: executor_id.clone(), + timestamp: current_ts, + metrics: vec![], + status: Some(protobuf::ExecutorStatus { + status: Some(protobuf::executor_status::Status::Active("".to_string())), + }), + }) + .await?; + + if !reserve { + let available_slots = AvailableTaskSlots { + executor_id, + slots: spec.available_task_slots, + }; + + let lock = self.store.lock(Keyspace::Slots, "all").await?; + + with_lock(lock, async { + let current_slots = self.store.get(Keyspace::Slots, "all").await?; + + let mut current_slots: ExecutorTaskSlots = + decode_protobuf(current_slots.as_slice())?; + + if let Some((idx, _)) = + current_slots.task_slots.iter().find_position(|slots| { + slots.executor_id == available_slots.executor_id + }) + { + current_slots.task_slots[idx] = available_slots; + } else { + current_slots.task_slots.push(available_slots); + } + + self.store + .put( + Keyspace::Slots, + "all".to_string(), + current_slots.encode_to_vec(), + ) + .await + }) + .await?; + + Ok(vec![]) + } else { + let num_slots = spec.available_task_slots as usize; + let mut reservations: Vec = vec![]; + for _ in 0..num_slots { + reservations.push(ExecutorReservation::new_free(executor_id.clone())); + } + + let available_slots = AvailableTaskSlots { + executor_id, + slots: 0, + }; + + let lock = self.store.lock(Keyspace::Slots, "all").await?; + + with_lock(lock, async { + let current_slots = self.store.get(Keyspace::Slots, "all").await?; + + let mut current_slots: ExecutorTaskSlots = + decode_protobuf(current_slots.as_slice())?; + + if let Some((idx, _)) = + current_slots.task_slots.iter().find_position(|slots| { + slots.executor_id == available_slots.executor_id + }) + { + current_slots.task_slots[idx] = available_slots; + } else { + current_slots.task_slots.push(available_slots); + } + + self.store + .put( + Keyspace::Slots, + "all".to_string(), + current_slots.encode_to_vec(), + ) + .await + }) + .await?; + + Ok(reservations) + } + } + + async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> { + let executor_id = metadata.id.clone(); + let proto: protobuf::ExecutorMetadata = metadata.into(); + + self.store + .put(Keyspace::Executors, executor_id, proto.encode_to_vec()) + .await + } + + async fn get_executor_metadata(&self, executor_id: &str) -> Result { + let value = self.store.get(Keyspace::Executors, executor_id).await?; + + let decoded = + decode_into::(&value)?; + Ok(decoded) + } + + async fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) -> Result<()> { + let executor_id = 
heartbeat.executor_id.clone(); + self.store + .put(Keyspace::Heartbeats, executor_id, heartbeat.encode_to_vec()) + .await + } + + async fn remove_executor(&self, executor_id: &str) -> Result<()> { + let current_ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| { + BallistaError::Internal(format!("Error getting current timestamp: {e:?}")) + })? + .as_secs(); + + let value = ExecutorHeartbeat { + executor_id: executor_id.to_owned(), + timestamp: current_ts, + metrics: vec![], + status: Some(protobuf::ExecutorStatus { + status: Some(protobuf::executor_status::Status::Dead("".to_string())), + }), + } + .encode_to_vec(); + + self.store + .put(Keyspace::Heartbeats, executor_id.to_owned(), value) + .await?; + + // TODO Check the Executor reservation logic for push-based scheduling + + Ok(()) + } + + async fn executor_heartbeat_stream(&self) -> Result { + let events = self + .store + .watch(Keyspace::Heartbeats, String::default()) + .await?; + + Ok(events + .filter_map(|event| { + futures::future::ready(match event { + WatchEvent::Put(_, value) => { + if let Ok(heartbeat) = + decode_protobuf::(&value) + { + Some(heartbeat) + } else { + None + } + } + WatchEvent::Delete(_) => None, + }) + }) + .boxed()) + } + + async fn executor_heartbeats(&self) -> Result> { + let heartbeats = self.store.scan(Keyspace::Heartbeats, None).await?; + + let mut heartbeat_map = HashMap::with_capacity(heartbeats.len()); + + for (_, value) in heartbeats { + let data: ExecutorHeartbeat = decode_protobuf(&value)?; + if let Some(protobuf::ExecutorStatus { + status: Some(protobuf::executor_status::Status::Active(_)), + }) = &data.status + { + heartbeat_map.insert(data.executor_id.clone(), data); + } + } + + Ok(heartbeat_map) + } +} + +#[async_trait] +impl JobState + for KeyValueState +{ + async fn accept_job( + &self, + job_id: &str, + job_name: &str, + queued_at: u64, + ) -> Result<()> { + self.queued_jobs + .insert(job_id.to_string(), (job_name.to_string(), queued_at)); + + Ok(()) + } + + async fn submit_job(&self, job_id: String, graph: &ExecutionGraph) -> Result<()> { + if self.queued_jobs.get(&job_id).is_some() { + let status = graph.status(); + let encoded_graph = + ExecutionGraph::encode_execution_graph(graph.clone(), &self.codec)?; + + self.store + .apply_txn(vec![ + ( + Operation::Put(status.encode_to_vec()), + Keyspace::JobStatus, + job_id.clone(), + ), + ( + Operation::Put(encoded_graph.encode_to_vec()), + Keyspace::ExecutionGraph, + job_id.clone(), + ), + ]) + .await?; + + self.queued_jobs.remove(&job_id); + + Ok(()) + } else { + Err(BallistaError::Internal(format!( + "Failed to submit job {job_id}, job was not in queueud jobs" + ))) + } + } + + async fn get_jobs(&self) -> Result> { + self.store.scan_keys(Keyspace::JobStatus).await + } + + async fn get_job_status(&self, job_id: &str) -> Result> { + if let Some((job_name, queued_at)) = self.queued_jobs.get(job_id).as_deref() { + Ok(Some(JobStatus { + job_id: job_id.to_string(), + job_name: job_name.clone(), + status: Some(Status::Queued(QueuedJob { + queued_at: *queued_at, + })), + })) + } else { + let value = self.store.get(Keyspace::JobStatus, job_id).await?; + + (!value.is_empty()) + .then(|| decode_protobuf(value.as_slice())) + .transpose() + } + } + + async fn get_execution_graph(&self, job_id: &str) -> Result> { + let value = self.store.get(Keyspace::ExecutionGraph, job_id).await?; + + if value.is_empty() { + return Ok(None); + } + + let proto: protobuf::ExecutionGraph = decode_protobuf(value.as_slice())?; + + let session = 
self.get_session(&proto.session_id).await?; + + Ok(Some( + ExecutionGraph::decode_execution_graph(proto, &self.codec, session.as_ref()) + .await?, + )) + } + + async fn save_job(&self, job_id: &str, graph: &ExecutionGraph) -> Result<()> { + let status = graph.status(); + let encoded_graph = + ExecutionGraph::encode_execution_graph(graph.clone(), &self.codec)?; + + self.store + .apply_txn(vec![ + ( + Operation::Put(status.encode_to_vec()), + Keyspace::JobStatus, + job_id.to_string(), + ), + ( + Operation::Put(encoded_graph.encode_to_vec()), + Keyspace::ExecutionGraph, + job_id.to_string(), + ), + ]) + .await + } + + async fn fail_unscheduled_job(&self, job_id: &str, reason: String) -> Result<()> { + if let Some((job_id, (job_name, queued_at))) = self.queued_jobs.remove(job_id) { + let status = JobStatus { + job_id: job_id.clone(), + job_name, + status: Some(Status::Failed(FailedJob { + error: reason, + queued_at, + started_at: 0, + ended_at: 0, + })), + }; + + self.store + .put(Keyspace::JobStatus, job_id, status.encode_to_vec()) + .await + } else { + Err(BallistaError::Internal(format!( + "Could not fail unscheduled job {job_id}, not found in queued jobs" + ))) + } + } + + async fn remove_job(&self, job_id: &str) -> Result<()> { + if self.queued_jobs.remove(job_id).is_none() { + self.store + .apply_txn(vec![ + (Operation::Delete, Keyspace::JobStatus, job_id.to_string()), + ( + Operation::Delete, + Keyspace::ExecutionGraph, + job_id.to_string(), + ), + ]) + .await + } else { + Ok(()) + } + } + + async fn try_acquire_job(&self, _job_id: &str) -> Result> { + Err(BallistaError::NotImplemented( + "Work stealing is not currently implemented".to_string(), + )) + } + + async fn job_state_events(&self) -> Result { + let watch = self + .store + .watch(Keyspace::JobStatus, String::default()) + .await?; + + let stream = watch + .filter_map(|event| { + futures::future::ready(match event { + WatchEvent::Put(key, value) => { + if let Some(job_id) = Keyspace::JobStatus.strip_prefix(&key) { + match JobStatus::decode(value.as_slice()) { + Ok(status) => Some(JobStateEvent::JobUpdated { + job_id: job_id.to_string(), + status, + }), + Err(err) => { + warn!( + "Error decoding job status from watch event: {err:?}" + ); + None + } + } + } else { + None + } + } + _ => None, + }) + }) + .boxed(); + + Ok(stream) + } + + async fn get_session(&self, session_id: &str) -> Result> { + let value = self.store.get(Keyspace::Sessions, session_id).await?; + + let settings: protobuf::SessionSettings = decode_protobuf(&value)?; + + let mut config_builder = BallistaConfig::builder(); + for kv_pair in &settings.configs { + config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); + } + let config = config_builder.build()?; + + Ok(create_datafusion_context(&config, self.session_builder)) + } + + async fn create_session( + &self, + config: &BallistaConfig, + ) -> Result> { + let mut settings: Vec = vec![]; + + for (key, value) in config.settings() { + settings.push(KeyValuePair { + key: key.clone(), + value: value.clone(), + }) + } + + let value = protobuf::SessionSettings { configs: settings }; + + let session = create_datafusion_context(config, self.session_builder); + + self.store + .put( + Keyspace::Sessions, + session.session_id(), + value.encode_to_vec(), + ) + .await?; + + Ok(session) + } + + async fn update_session( + &self, + session_id: &str, + config: &BallistaConfig, + ) -> Result> { + let mut settings: Vec = vec![]; + + for (key, value) in config.settings() { + settings.push(KeyValuePair { + key: key.clone(), 
+ value: value.clone(), + }) + } + + let value = protobuf::SessionSettings { configs: settings }; + self.store + .put( + Keyspace::Sessions, + session_id.to_owned(), + value.encode_to_vec(), + ) + .await?; + + Ok(create_datafusion_context(config, self.session_builder)) + } +} + +async fn with_lock>(mut lock: Box, op: F) -> Out { + let result = op.await; + lock.unlock().await; + result +} + +#[cfg(test)] +mod test { + use crate::cluster::kv::KeyValueState; + use crate::cluster::storage::sled::SledClient; + use crate::cluster::test::{ + test_executor_registration, test_fuzz_reservations, test_job_lifecycle, + test_job_planning_failure, test_reservation, + }; + use crate::cluster::TaskDistribution; + use crate::test_utils::{ + test_aggregation_plan, test_join_plan, test_two_aggregations_plan, + }; + use ballista_core::error::Result; + use ballista_core::serde::BallistaCodec; + use ballista_core::utils::default_session_builder; + + #[cfg(feature = "sled")] + fn make_sled_state() -> Result> { + Ok(KeyValueState::new( + "", + SledClient::try_new_temporary()?, + BallistaCodec::default(), + default_session_builder, + )) + } + + #[cfg(feature = "sled")] + #[tokio::test] + async fn test_sled_executor_reservation() -> Result<()> { + test_executor_registration(make_sled_state()?).await + } + + #[cfg(feature = "sled")] + #[tokio::test] + async fn test_sled_reserve() -> Result<()> { + test_reservation(make_sled_state()?, TaskDistribution::Bias).await?; + test_reservation(make_sled_state()?, TaskDistribution::RoundRobin).await?; + + Ok(()) + } + + #[cfg(feature = "sled")] + #[tokio::test] + async fn test_sled_fuzz_reserve() -> Result<()> { + test_fuzz_reservations(make_sled_state()?, 10, TaskDistribution::Bias, 10, 10) + .await?; + test_fuzz_reservations( + make_sled_state()?, + 10, + TaskDistribution::RoundRobin, + 10, + 10, + ) + .await?; + + Ok(()) + } + + #[cfg(feature = "sled")] + #[tokio::test] + async fn test_sled_job_lifecycle() -> Result<()> { + test_job_lifecycle(make_sled_state()?, test_aggregation_plan(4).await).await?; + test_job_lifecycle(make_sled_state()?, test_two_aggregations_plan(4).await) + .await?; + test_job_lifecycle(make_sled_state()?, test_join_plan(4).await).await?; + Ok(()) + } + + #[cfg(feature = "sled")] + #[tokio::test] + async fn test_in_memory_job_planning_failure() -> Result<()> { + test_job_planning_failure(make_sled_state()?, test_aggregation_plan(4).await) + .await?; + test_job_planning_failure( + make_sled_state()?, + test_two_aggregations_plan(4).await, + ) + .await?; + test_job_planning_failure(make_sled_state()?, test_join_plan(4).await).await?; + + Ok(()) + } +} diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs new file mode 100644 index 000000000..1a852ea17 --- /dev/null +++ b/ballista/scheduler/src/cluster/memory.rs @@ -0,0 +1,588 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cluster::{ + reserve_slots_bias, reserve_slots_round_robin, ClusterState, ExecutorHeartbeatStream, + JobState, JobStateEvent, JobStateEventStream, JobStatus, TaskDistribution, +}; +use crate::state::execution_graph::ExecutionGraph; +use crate::state::executor_manager::ExecutorReservation; +use async_trait::async_trait; +use ballista_core::config::BallistaConfig; +use ballista_core::error::{BallistaError, Result}; +use ballista_core::serde::protobuf::{ + executor_status, AvailableTaskSlots, ExecutorHeartbeat, ExecutorStatus, + ExecutorTaskSlots, FailedJob, QueuedJob, +}; +use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; +use dashmap::DashMap; +use datafusion::prelude::SessionContext; + +use crate::cluster::event::ClusterEventSender; +use crate::scheduler_server::{timestamp_millis, timestamp_secs, SessionBuilder}; +use crate::state::session_manager::create_datafusion_context; +use ballista_core::serde::protobuf::job_status::Status; +use itertools::Itertools; +use log::warn; +use parking_lot::Mutex; +use std::collections::{HashMap, HashSet}; +use std::ops::DerefMut; + +use std::sync::Arc; +use tracing::debug; + +#[derive(Default)] +pub struct InMemoryClusterState { + /// Current available task slots for each executor + task_slots: Mutex, + /// Current executors + executors: DashMap, + /// Last heartbeat received for each executor + heartbeats: DashMap, + /// Broadcast channel sender for heartbeats, If `None` there are not + /// subscribers + heartbeat_sender: ClusterEventSender, +} + +#[async_trait] +impl ClusterState for InMemoryClusterState { + async fn reserve_slots( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option>, + ) -> Result> { + let mut guard = self.task_slots.lock(); + + let mut available_slots: Vec<&mut AvailableTaskSlots> = guard + .task_slots + .iter_mut() + .filter_map(|data| { + (data.slots > 0 + && executors + .as_ref() + .map(|executors| executors.contains(&data.executor_id)) + .unwrap_or(true)) + .then_some(data) + }) + .collect(); + + available_slots.sort_by(|a, b| Ord::cmp(&b.slots, &a.slots)); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(available_slots, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(available_slots, num_slots) + } + }; + + Ok(reservations) + } + + async fn reserve_slots_exact( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option>, + ) -> Result> { + let mut guard = self.task_slots.lock(); + + let rollback = guard.clone(); + + let mut available_slots: Vec<&mut AvailableTaskSlots> = guard + .task_slots + .iter_mut() + .filter_map(|data| { + (data.slots > 0 + && executors + .as_ref() + .map(|executors| executors.contains(&data.executor_id)) + .unwrap_or(true)) + .then_some(data) + }) + .collect(); + + available_slots.sort_by(|a, b| Ord::cmp(&b.slots, &a.slots)); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(available_slots, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(available_slots, num_slots) + 
} + }; + + if reservations.len() as u32 != num_slots { + *guard = rollback; + Ok(vec![]) + } else { + Ok(reservations) + } + } + + async fn cancel_reservations( + &self, + reservations: Vec, + ) -> Result<()> { + let mut increments = HashMap::new(); + for ExecutorReservation { executor_id, .. } in reservations { + if let Some(inc) = increments.get_mut(&executor_id) { + *inc += 1; + } else { + increments.insert(executor_id, 1usize); + } + } + + let mut guard = self.task_slots.lock(); + + for executor_slots in guard.task_slots.iter_mut() { + if let Some(slots) = increments.get(&executor_slots.executor_id) { + executor_slots.slots += *slots as u32; + } + } + + Ok(()) + } + + async fn register_executor( + &self, + metadata: ExecutorMetadata, + mut spec: ExecutorData, + reserve: bool, + ) -> Result> { + let heartbeat = ExecutorHeartbeat { + executor_id: metadata.id.clone(), + timestamp: timestamp_secs(), + metrics: vec![], + status: Some(ExecutorStatus { + status: Some(executor_status::Status::Active(String::default())), + }), + }; + + let mut guard = self.task_slots.lock(); + + // Check to see if we already have task slots for executor. If so, remove them. + if let Some((idx, _)) = guard + .task_slots + .iter() + .find_position(|slots| slots.executor_id == metadata.id) + { + guard.task_slots.swap_remove(idx); + } + + if reserve { + let slots = std::mem::take(&mut spec.available_task_slots) as usize; + + let reservations = (0..slots) + .into_iter() + .map(|_| ExecutorReservation::new_free(metadata.id.clone())) + .collect(); + + self.executors.insert(metadata.id.clone(), metadata.clone()); + + guard.task_slots.push(AvailableTaskSlots { + executor_id: metadata.id, + slots: 0, + }); + + self.heartbeat_sender.send(&heartbeat); + + Ok(reservations) + } else { + self.executors.insert(metadata.id.clone(), metadata.clone()); + + guard.task_slots.push(AvailableTaskSlots { + executor_id: metadata.id, + slots: spec.available_task_slots, + }); + + self.heartbeat_sender.send(&heartbeat); + + Ok(vec![]) + } + } + + async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> { + self.executors.insert(metadata.id.clone(), metadata); + Ok(()) + } + + async fn get_executor_metadata(&self, executor_id: &str) -> Result { + self.executors + .get(executor_id) + .map(|pair| pair.value().clone()) + .ok_or_else(|| { + BallistaError::Internal(format!( + "Not executor with ID {executor_id} found" + )) + }) + } + + async fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) -> Result<()> { + if let Some(mut last) = self.heartbeats.get_mut(&heartbeat.executor_id) { + let _ = std::mem::replace(last.deref_mut(), heartbeat.clone()); + } else { + self.heartbeats + .insert(heartbeat.executor_id.clone(), heartbeat.clone()); + } + + self.heartbeat_sender.send(&heartbeat); + + Ok(()) + } + + async fn remove_executor(&self, executor_id: &str) -> Result<()> { + { + let mut guard = self.task_slots.lock(); + + if let Some((idx, _)) = guard + .task_slots + .iter() + .find_position(|slots| slots.executor_id == executor_id) + { + guard.task_slots.swap_remove(idx); + } + } + + if let Some(heartbeat) = self.heartbeats.get_mut(executor_id).as_deref_mut() { + let new_heartbeat = ExecutorHeartbeat { + executor_id: executor_id.to_string(), + timestamp: timestamp_secs(), + metrics: vec![], + status: Some(ExecutorStatus { + status: Some(executor_status::Status::Dead(String::default())), + }), + }; + + *heartbeat = new_heartbeat; + + self.heartbeat_sender.send(heartbeat); + } + + Ok(()) + } + + async fn 
executor_heartbeat_stream(&self) -> Result { + Ok(Box::pin(self.heartbeat_sender.subscribe())) + } + + async fn executor_heartbeats(&self) -> Result> { + Ok(self + .heartbeats + .iter() + .map(|r| (r.key().clone(), r.value().clone())) + .collect()) + } +} + +/// Implementation of `JobState` which keeps all state in memory. If using `InMemoryJobState` +/// no job state will be shared between schedulers +pub struct InMemoryJobState { + scheduler: String, + /// Jobs which have either completed successfully or failed + completed_jobs: DashMap)>, + /// In-memory store of queued jobs. Map from Job ID -> (Job Name, queued_at timestamp) + queued_jobs: DashMap, + /// In-memory store of running job statuses. Map from Job ID -> JobStatus + running_jobs: DashMap, + /// Active ballista sessions + sessions: DashMap>, + /// `SessionBuilder` for building DataFusion `SessionContext` from `BallistaConfig` + session_builder: SessionBuilder, + /// Sender of job events + job_event_sender: ClusterEventSender, +} + +impl InMemoryJobState { + pub fn new(scheduler: impl Into, session_builder: SessionBuilder) -> Self { + Self { + scheduler: scheduler.into(), + completed_jobs: Default::default(), + queued_jobs: Default::default(), + running_jobs: Default::default(), + sessions: Default::default(), + session_builder, + job_event_sender: ClusterEventSender::new(100), + } + } +} + +#[async_trait] +impl JobState for InMemoryJobState { + async fn submit_job(&self, job_id: String, graph: &ExecutionGraph) -> Result<()> { + if self.queued_jobs.get(&job_id).is_some() { + self.running_jobs.insert(job_id.clone(), graph.status()); + self.queued_jobs.remove(&job_id); + + self.job_event_sender.send(&JobStateEvent::JobAcquired { + job_id, + owner: self.scheduler.clone(), + }); + + Ok(()) + } else { + Err(BallistaError::Internal(format!( + "Failed to submit job {job_id}, not found in queued jobs" + ))) + } + } + + async fn get_job_status(&self, job_id: &str) -> Result> { + if let Some((job_name, queued_at)) = self.queued_jobs.get(job_id).as_deref() { + return Ok(Some(JobStatus { + job_id: job_id.to_string(), + job_name: job_name.clone(), + status: Some(Status::Queued(QueuedJob { + queued_at: *queued_at, + })), + })); + } + + if let Some(status) = self.running_jobs.get(job_id).as_deref().cloned() { + return Ok(Some(status)); + } + + if let Some((status, _)) = self.completed_jobs.get(job_id).as_deref() { + return Ok(Some(status.clone())); + } + + Ok(None) + } + + async fn get_execution_graph(&self, job_id: &str) -> Result> { + Ok(self + .completed_jobs + .get(job_id) + .as_deref() + .and_then(|(_, graph)| graph.clone())) + } + + async fn try_acquire_job(&self, _job_id: &str) -> Result> { + // Always return None. 
The only state stored here are for completed jobs + // which cannot be acquired + Ok(None) + } + + async fn save_job(&self, job_id: &str, graph: &ExecutionGraph) -> Result<()> { + let status = graph.status(); + + debug!("saving state for job {job_id} with status {:?}", status); + + // If job is either successful or failed, save to completed jobs + if matches!( + status.status, + Some(Status::Successful(_)) | Some(Status::Failed(_)) + ) { + self.completed_jobs + .insert(job_id.to_string(), (status, Some(graph.clone()))); + self.running_jobs.remove(job_id); + } else if let Some(old_status) = + self.running_jobs.insert(job_id.to_string(), graph.status()) + { + self.job_event_sender.send(&JobStateEvent::JobUpdated { + job_id: job_id.to_string(), + status: old_status, + }) + } + + Ok(()) + } + + async fn get_session(&self, session_id: &str) -> Result> { + self.sessions + .get(session_id) + .map(|sess| sess.clone()) + .ok_or_else(|| { + BallistaError::General(format!("No session for {session_id} found")) + }) + } + + async fn create_session( + &self, + config: &BallistaConfig, + ) -> Result> { + let session = create_datafusion_context(config, self.session_builder); + self.sessions.insert(session.session_id(), session.clone()); + + Ok(session) + } + + async fn update_session( + &self, + session_id: &str, + config: &BallistaConfig, + ) -> Result> { + let session = create_datafusion_context(config, self.session_builder); + self.sessions + .insert(session_id.to_string(), session.clone()); + + Ok(session) + } + + async fn job_state_events(&self) -> Result { + Ok(Box::pin(self.job_event_sender.subscribe())) + } + + async fn remove_job(&self, job_id: &str) -> Result<()> { + if self.completed_jobs.remove(job_id).is_none() { + warn!("Tried to delete non-existent job {job_id} from state"); + } + Ok(()) + } + + async fn get_jobs(&self) -> Result> { + Ok(self + .completed_jobs + .iter() + .map(|pair| pair.key().clone()) + .collect()) + } + + async fn accept_job( + &self, + job_id: &str, + job_name: &str, + queued_at: u64, + ) -> Result<()> { + self.queued_jobs + .insert(job_id.to_string(), (job_name.to_string(), queued_at)); + + Ok(()) + } + + async fn fail_unscheduled_job(&self, job_id: &str, reason: String) -> Result<()> { + if let Some((job_id, (job_name, queued_at))) = self.queued_jobs.remove(job_id) { + self.completed_jobs.insert( + job_id.clone(), + ( + JobStatus { + job_id, + job_name, + status: Some(Status::Failed(FailedJob { + error: reason, + queued_at, + started_at: 0, + ended_at: timestamp_millis(), + })), + }, + None, + ), + ); + + Ok(()) + } else { + Err(BallistaError::Internal(format!( + "Could not fail unscheduler job {job_id}, job not found in queued jobs" + ))) + } + } +} + +#[cfg(test)] +mod test { + use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState}; + use crate::cluster::test::{ + test_executor_registration, test_fuzz_reservations, test_job_lifecycle, + test_job_planning_failure, test_reservation, + }; + use crate::cluster::TaskDistribution; + use crate::test_utils::{ + test_aggregation_plan, test_join_plan, test_two_aggregations_plan, + }; + use ballista_core::error::Result; + use ballista_core::utils::default_session_builder; + + #[tokio::test] + async fn test_in_memory_registration() -> Result<()> { + test_executor_registration(InMemoryClusterState::default()).await + } + + #[tokio::test] + async fn test_in_memory_reserve() -> Result<()> { + test_reservation(InMemoryClusterState::default(), TaskDistribution::Bias).await?; + test_reservation( + 
InMemoryClusterState::default(), + TaskDistribution::RoundRobin, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_in_memory_fuzz_reserve() -> Result<()> { + test_fuzz_reservations( + InMemoryClusterState::default(), + 10, + TaskDistribution::Bias, + 10, + 10, + ) + .await?; + test_fuzz_reservations( + InMemoryClusterState::default(), + 10, + TaskDistribution::RoundRobin, + 10, + 10, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_in_memory_job_lifecycle() -> Result<()> { + test_job_lifecycle( + InMemoryJobState::new("", default_session_builder), + test_aggregation_plan(4).await, + ) + .await?; + test_job_lifecycle( + InMemoryJobState::new("", default_session_builder), + test_two_aggregations_plan(4).await, + ) + .await?; + test_job_lifecycle( + InMemoryJobState::new("", default_session_builder), + test_join_plan(4).await, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_in_memory_job_planning_failure() -> Result<()> { + test_job_planning_failure( + InMemoryJobState::new("", default_session_builder), + test_aggregation_plan(4).await, + ) + .await?; + test_job_planning_failure( + InMemoryJobState::new("", default_session_builder), + test_two_aggregations_plan(4).await, + ) + .await?; + test_job_planning_failure( + InMemoryJobState::new("", default_session_builder), + test_join_plan(4).await, + ) + .await?; + + Ok(()) + } +} diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs new file mode 100644 index 000000000..35a5052a8 --- /dev/null +++ b/ballista/scheduler/src/cluster/mod.rs @@ -0,0 +1,443 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
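// Illustrative, self-contained sketch (not part of the diff itself): this module declares
// `TaskDistribution::RoundRobin`, described further down as iterating through available
// executors and assigning one task to each until all tasks are assigned. The function
// below is one plausible reading of that policy with plain stand-in types and a simplified
// return shape; the real `reserve_slots_round_robin` helper is referenced but not shown in
// this diff, so treat the names and behaviour here as assumptions.
use std::collections::HashMap;

#[derive(Clone, Debug)]
struct AvailableTaskSlots {
    executor_id: String,
    slots: u32,
}

/// Round-robin sketch: repeatedly sweep the executors, taking one slot from each executor
/// that still has capacity, until `num_slots` are reserved or no capacity remains.
fn reserve_round_robin(
    executors: &mut [AvailableTaskSlots],
    num_slots: u32,
) -> HashMap<String, u32> {
    let mut reserved: HashMap<String, u32> = HashMap::new();
    let mut remaining = num_slots;

    loop {
        let mut progressed = false;
        for entry in executors.iter_mut() {
            if remaining == 0 {
                return reserved;
            }
            if entry.slots > 0 {
                entry.slots -= 1;
                remaining -= 1;
                *reserved.entry(entry.executor_id.clone()).or_default() += 1;
                progressed = true;
            }
        }
        if !progressed {
            // No executor had capacity on this pass; return a partial reservation.
            return reserved;
        }
    }
}

fn main() {
    let mut executors = vec![
        AvailableTaskSlots { executor_id: "exec-1".into(), slots: 3 },
        AvailableTaskSlots { executor_id: "exec-2".into(), slots: 1 },
    ];
    // Requesting 3 slots: each executor gives one on the first pass,
    // then exec-1 gives the third on the second pass.
    println!("{:?}", reserve_round_robin(&mut executors, 3));
}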
+ +pub mod event; +pub mod kv; +pub mod memory; +pub mod storage; + +#[cfg(test)] +#[allow(clippy::uninlined_format_args)] +pub mod test; + +use crate::cluster::kv::KeyValueState; +use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState}; +use crate::cluster::storage::etcd::EtcdClient; +use crate::cluster::storage::sled::SledClient; +use crate::cluster::storage::KeyValueStore; +use crate::config::{ClusterStorageConfig, SchedulerConfig}; +use crate::scheduler_server::SessionBuilder; +use crate::state::execution_graph::ExecutionGraph; +use crate::state::executor_manager::ExecutorReservation; +use ballista_core::config::BallistaConfig; +use ballista_core::error::{BallistaError, Result}; +use ballista_core::serde::protobuf::{AvailableTaskSlots, ExecutorHeartbeat, JobStatus}; +use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; +use ballista_core::serde::BallistaCodec; +use ballista_core::utils::default_session_builder; +use clap::ArgEnum; +use datafusion::prelude::SessionContext; +use datafusion_proto::logical_plan::AsLogicalPlan; +use datafusion_proto::physical_plan::AsExecutionPlan; +use futures::Stream; +use log::info; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; + +// an enum used to configure the backend +// needs to be visible to code generated by configure_me +#[derive(Debug, Clone, ArgEnum, serde::Deserialize, PartialEq, Eq)] +pub enum ClusterStorage { + Etcd, + Memory, + Sled, +} + +impl std::str::FromStr for ClusterStorage { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + ArgEnum::from_str(s, true) + } +} + +impl parse_arg::ParseArgFromStr for ClusterStorage { + fn describe_type(mut writer: W) -> fmt::Result { + write!(writer, "The cluster storage backend for the scheduler") + } +} + +#[derive(Clone)] +pub struct BallistaCluster { + cluster_state: Arc, + job_state: Arc, +} + +impl BallistaCluster { + pub fn new( + cluster_state: Arc, + job_state: Arc, + ) -> Self { + Self { + cluster_state, + job_state, + } + } + + pub fn new_memory( + scheduler: impl Into, + session_builder: SessionBuilder, + ) -> Self { + Self { + cluster_state: Arc::new(InMemoryClusterState::default()), + job_state: Arc::new(InMemoryJobState::new(scheduler, session_builder)), + } + } + + pub fn new_kv< + S: KeyValueStore, + T: 'static + AsLogicalPlan, + U: 'static + AsExecutionPlan, + >( + store: S, + scheduler: impl Into, + session_builder: SessionBuilder, + codec: BallistaCodec, + ) -> Self { + let kv_state = + Arc::new(KeyValueState::new(scheduler, store, codec, session_builder)); + Self { + cluster_state: kv_state.clone(), + job_state: kv_state, + } + } + + pub async fn new_from_config(config: &SchedulerConfig) -> Result { + let scheduler = config.scheduler_name(); + + match &config.cluster_storage { + #[cfg(feature = "etcd")] + ClusterStorageConfig::Etcd(urls) => { + let etcd = etcd_client::Client::connect(urls.as_slice(), None) + .await + .map_err(|err| { + BallistaError::Internal(format!( + "Could not connect to etcd: {err:?}" + )) + })?; + + Ok(Self::new_kv( + EtcdClient::new(config.namespace.clone(), etcd), + scheduler, + default_session_builder, + BallistaCodec::default(), + )) + } + #[cfg(not(feature = "etcd"))] + StateBackend::Etcd => { + unimplemented!( + "build the scheduler with the `etcd` feature to use the etcd config backend" + ) + } + #[cfg(feature = "sled")] + ClusterStorageConfig::Sled(dir) => { + if let Some(dir) = dir.as_ref() { + info!("Initializing Sled database in directory {}", 
dir); + let sled = SledClient::try_new(dir)?; + + Ok(Self::new_kv( + sled, + scheduler, + default_session_builder, + BallistaCodec::default(), + )) + } else { + info!("Initializing Sled database in temp directory"); + let sled = SledClient::try_new_temporary()?; + + Ok(Self::new_kv( + sled, + scheduler, + default_session_builder, + BallistaCodec::default(), + )) + } + } + #[cfg(not(feature = "sled"))] + StateBackend::Sled => { + unimplemented!( + "build the scheduler with the `sled` feature to use the sled config backend" + ) + } + ClusterStorageConfig::Memory => Ok(BallistaCluster::new_memory( + scheduler, + default_session_builder, + )), + } + } + + pub fn cluster_state(&self) -> Arc<dyn ClusterState> { + self.cluster_state.clone() + } + + pub fn job_state(&self) -> Arc<dyn JobState> { + self.job_state.clone() + } +} + +/// Stream of `ExecutorHeartbeat`. This stream should contain all `ExecutorHeartbeats` received +/// by any schedulers with a shared `ClusterState` +pub type ExecutorHeartbeatStream = Pin<Box<dyn Stream<Item = ExecutorHeartbeat> + Send>>; + +/// Method of distributing tasks to available executor slots +#[derive(Debug, Clone, Copy)] +pub enum TaskDistribution { + /// Eagerly assign tasks to executor slots. This will assign as many task slots per executor + /// as are currently available + Bias, + /// Distribute tasks evenly across executors. This will try to iterate through available executors + /// and assign one task to each executor until all tasks are assigned. + RoundRobin, +} + +/// A trait that contains the necessary methods to maintain a globally consistent view of cluster resources +#[tonic::async_trait] +pub trait ClusterState: Send + Sync + 'static { + /// Reserve up to `num_slots` executor task slots. If not enough task slots are available, reserve + /// as many as possible. + /// + /// If `executors` is provided, only reserve slots of the specified executor IDs + async fn reserve_slots( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>>; + + /// Reserve exactly `num_slots` executor task slots. If not enough task slots are available, + /// returns an empty vec + /// + /// If `executors` is provided, only reserve slots of the specified executor IDs + async fn reserve_slots_exact( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>>; + + /// Cancel the specified reservations. This will make reserved executor slots available to other + /// tasks. + /// This operation should be atomic. Either all reservations are cancelled or none are + async fn cancel_reservations( + &self, + reservations: Vec<ExecutorReservation>, + ) -> Result<()>; + + /// Register a new executor in the cluster. If `reserve` is true, then the executor's task slots + /// will be reserved and returned in the response and none of the new executor's task slots will be + /// available to other tasks. + async fn register_executor( + &self, + metadata: ExecutorMetadata, + spec: ExecutorData, + reserve: bool, + ) -> Result<Vec<ExecutorReservation>>; + + /// Save the executor metadata. This will overwrite existing metadata for the executor ID + async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()>; + + /// Get executor metadata for the provided executor ID.
Returns an error if the executor does not exist + async fn get_executor_metadata(&self, executor_id: &str) -> Result; + + /// Save the executor heartbeat + async fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) -> Result<()>; + + /// Remove the executor from the cluster + async fn remove_executor(&self, executor_id: &str) -> Result<()>; + + /// Return the stream of executor heartbeats observed by all schedulers in the cluster. + /// This can be aggregated to provide an eventually consistent view of all executors within the cluster + async fn executor_heartbeat_stream(&self) -> Result; + + /// Return a map of the last seen heartbeat for all active executors + async fn executor_heartbeats(&self) -> Result>; +} + +/// Events related to the state of jobs. Implementations may or may not support all event types. +#[derive(Debug, Clone, PartialEq)] +pub enum JobStateEvent { + /// Event when a job status has been updated + JobUpdated { + /// Job ID of updated job + job_id: String, + /// New job status + status: JobStatus, + }, + /// Event when a scheduler acquires ownership of the job. This happens + /// either when a scheduler submits a job (in which case ownership is implied) + /// or when a scheduler acquires ownership of a running job release by a + /// different scheduler + JobAcquired { + /// Job ID of the acquired job + job_id: String, + /// The scheduler which acquired ownership of the job + owner: String, + }, + /// Event when a scheduler releases ownership of a still active job + JobReleased { + /// Job ID of the released job + job_id: String, + }, + /// Event when a new session has been created + SessionCreated { + session_id: String, + config: BallistaConfig, + }, + /// Event when a session configuration has been updated + SessionUpdated { + session_id: String, + config: BallistaConfig, + }, +} + +/// Stream of `JobStateEvent`. This stream should contain all `JobStateEvent`s received +/// by any schedulers with a shared `ClusterState` +pub type JobStateEventStream = Pin + Send>>; + +/// A trait that contains the necessary methods for persisting state related to executing jobs +#[tonic::async_trait] +pub trait JobState: Send + Sync { + /// Accept job into a scheduler's job queue. This should be called when a job is + /// received by the scheduler but before it is planned and may or may not be saved + /// in global state + async fn accept_job( + &self, + job_id: &str, + job_name: &str, + queued_at: u64, + ) -> Result<()>; + + /// Submit a new job to the `JobState`. It is assumed that the submitter owns the job. + /// In local state the job should be save as `JobStatus::Active` and in shared state + /// it should be saved as `JobStatus::Running` with `scheduler` set to the current scheduler + async fn submit_job(&self, job_id: String, graph: &ExecutionGraph) -> Result<()>; + + /// Return a `Vec` of all active job IDs in the `JobState` + async fn get_jobs(&self) -> Result>; + + /// Fetch the job status + async fn get_job_status(&self, job_id: &str) -> Result>; + + /// Get the `ExecutionGraph` for job. The job may or may not belong to the caller + /// and should return the `ExecutionGraph` for the given job (if it exists) at the + /// time this method is called with no guarantees that the graph has not been + /// subsequently updated by another scheduler. + async fn get_execution_graph(&self, job_id: &str) -> Result>; + + /// Persist the current state of an owned job to global state. This should fail + /// if the job is not owned by the caller. 
+ async fn save_job(&self, job_id: &str, graph: &ExecutionGraph) -> Result<()>; + + /// Mark a job which has not been submitted as failed. This should be called if a job fails + /// during planning (and does not yet have an `ExecutionGraph`) + async fn fail_unscheduled_job(&self, job_id: &str, reason: String) -> Result<()>; + + /// Delete a job from the global state + async fn remove_job(&self, job_id: &str) -> Result<()>; + + /// Attempt to acquire ownership of the given job. If the job is still in a running state + /// and is successfully acquired by the caller, return the current `ExecutionGraph`, + /// otherwise return `None` + async fn try_acquire_job(&self, job_id: &str) -> Result>; + + /// Get a stream of all `JobState` events. An event should be published any time that status + /// of a job changes in state + async fn job_state_events(&self) -> Result; + + /// Get the `SessionContext` associated with `session_id`. Returns an error if the + /// session does not exist + async fn get_session(&self, session_id: &str) -> Result>; + + /// Create a new saved session + async fn create_session( + &self, + config: &BallistaConfig, + ) -> Result>; + + // Update a new saved session. If the session does not exist, a new one will be created + async fn update_session( + &self, + session_id: &str, + config: &BallistaConfig, + ) -> Result>; +} + +pub(crate) fn reserve_slots_bias( + mut slots: Vec<&mut AvailableTaskSlots>, + mut n: u32, +) -> Vec { + let mut reservations = Vec::with_capacity(n as usize); + + let mut iter = slots.iter_mut(); + + while n > 0 { + if let Some(executor) = iter.next() { + let take = executor.slots.min(n); + for _ in 0..take { + reservations + .push(ExecutorReservation::new_free(executor.executor_id.clone())); + } + + executor.slots -= take; + n -= take; + } else { + break; + } + } + + reservations +} + +pub(crate) fn reserve_slots_round_robin( + mut slots: Vec<&mut AvailableTaskSlots>, + mut n: u32, +) -> Vec { + let mut reservations = Vec::with_capacity(n as usize); + + let mut last_updated_idx = 0usize; + + loop { + let n_before = n; + for (idx, data) in slots.iter_mut().enumerate() { + if n == 0 { + break; + } + + // Since the vector is sorted in descending order, + // if finding one executor has not enough slots, the following will have not enough, either + if data.slots == 0 { + break; + } + + reservations.push(ExecutorReservation::new_free(data.executor_id.clone())); + data.slots -= 1; + n -= 1; + + if idx >= last_updated_idx { + last_updated_idx = idx + 1; + } + } + + if n_before == n { + break; + } + } + + reservations +} diff --git a/ballista/scheduler/src/state/backend/etcd.rs b/ballista/scheduler/src/cluster/storage/etcd.rs similarity index 98% rename from ballista/scheduler/src/state/backend/etcd.rs rename to ballista/scheduler/src/cluster/storage/etcd.rs index 631acb2c2..528a9116b 100644 --- a/ballista/scheduler/src/state/backend/etcd.rs +++ b/ballista/scheduler/src/cluster/storage/etcd.rs @@ -15,24 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! Etcd config backend. 
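The `reserve_slots_bias` and `reserve_slots_round_robin` helpers at the end of the new cluster/mod.rs implement the two `TaskDistribution` strategies over a shared set of `AvailableTaskSlots`. As a quick illustration of the difference, here is a small test-style sketch (mine, not part of the patch; the module name and assertions are illustrative): with three executors holding 3 free slots each, reserving 5 slots drains executor "1" first under Bias, while RoundRobin hands out one slot per executor per pass.

#[cfg(test)]
mod distribution_sketch {
    use super::*;

    // Three executors, each advertising 3 free slots.
    fn slots() -> Vec<AvailableTaskSlots> {
        ["1", "2", "3"]
            .iter()
            .map(|id| AvailableTaskSlots {
                executor_id: id.to_string(),
                slots: 3,
            })
            .collect()
    }

    #[test]
    fn bias_vs_round_robin() {
        // Bias: take as many slots as possible from each executor in turn -> 3 from "1", 2 from "2".
        let mut a = slots();
        let bias = reserve_slots_bias(a.iter_mut().collect(), 5);
        assert_eq!(bias.iter().filter(|r| r.executor_id == "1").count(), 3);
        assert_eq!(bias.iter().filter(|r| r.executor_id == "3").count(), 0);

        // RoundRobin: one slot per executor per pass -> "1", "2", "3", "1", "2".
        let mut b = slots();
        let rr = reserve_slots_round_robin(b.iter_mut().collect(), 5);
        assert_eq!(rr.iter().filter(|r| r.executor_id == "1").count(), 2);
        assert_eq!(rr.iter().filter(|r| r.executor_id == "3").count(), 1);
    }
}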
- use std::collections::HashSet; use std::task::Poll; +use async_trait::async_trait; use ballista_core::error::{ballista_error, Result}; use std::time::Instant; +use crate::cluster::storage::KeyValueStore; use etcd_client::{ GetOptions, LockOptions, LockResponse, Txn, TxnOp, WatchOptions, WatchStream, Watcher, }; use futures::{Stream, StreamExt}; use log::{debug, error, warn}; -use crate::state::backend::{ - Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent, -}; +use crate::cluster::storage::{Keyspace, Lock, Operation, Watch, WatchEvent}; /// A [`StateBackendClient`] implementation that uses etcd to save cluster state. #[derive(Clone)] @@ -47,8 +45,8 @@ impl EtcdClient { } } -#[tonic::async_trait] -impl StateBackendClient for EtcdClient { +#[async_trait] +impl KeyValueStore for EtcdClient { async fn get(&self, keyspace: Keyspace, key: &str) -> Result> { let key = format!("/{}/{:?}/{}", self.namespace, keyspace, key); diff --git a/ballista/scheduler/src/state/backend/mod.rs b/ballista/scheduler/src/cluster/storage/mod.rs similarity index 81% rename from ballista/scheduler/src/state/backend/mod.rs rename to ballista/scheduler/src/cluster/storage/mod.rs index 5c859937e..f40804689 100644 --- a/ballista/scheduler/src/state/backend/mod.rs +++ b/ballista/scheduler/src/cluster/storage/mod.rs @@ -15,64 +15,42 @@ // specific language governing permissions and limitations // under the License. -use ballista_core::error::Result; -use clap::ArgEnum; -use futures::{future, Stream}; -use std::collections::HashSet; -use std::fmt; -use tokio::sync::OwnedMutexGuard; - -pub mod cluster; #[cfg(feature = "etcd")] pub mod etcd; -pub mod memory; #[cfg(feature = "sled")] pub mod sled; -mod utils; - -// an enum used to configure the backend -// needs to be visible to code generated by configure_me -#[derive(Debug, Clone, ArgEnum, serde::Deserialize, PartialEq, Eq)] -pub enum StateBackend { - Etcd, - Memory, - Sled, -} -impl std::str::FromStr for StateBackend { - type Err = String; - - fn from_str(s: &str) -> std::result::Result { - ArgEnum::from_str(s, true) - } -} - -impl parse_arg::ParseArgFromStr for StateBackend { - fn describe_type(mut writer: W) -> fmt::Result { - write!(writer, "The configuration backend for the scheduler") - } -} +use async_trait::async_trait; +use ballista_core::error::Result; +use futures::{future, Stream}; +use std::collections::HashSet; +use tokio::sync::OwnedMutexGuard; #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum Keyspace { Executors, - ActiveJobs, - CompletedJobs, - FailedJobs, + JobStatus, + ExecutionGraph, Slots, Sessions, Heartbeats, } +impl Keyspace { + pub fn strip_prefix<'a>(&'a self, key: &'a str) -> Option<&'a str> { + key.strip_prefix(&format!("{self:?}/")) + } +} + #[derive(Debug, Eq, PartialEq, Hash)] pub enum Operation { Put(Vec), Delete, } -/// A trait that contains the necessary methods to save and retrieve the state and configuration of a cluster. -#[tonic::async_trait] -pub trait StateBackendClient: Send + Sync { +/// A trait that defines a KeyValue interface with basic locking primitives for persisting Ballista cluster state +#[async_trait] +pub trait KeyValueStore: Send + Sync + Clone + 'static { /// Retrieve the data associated with a specific key in a given keyspace. /// /// An empty vec is returned if the key does not exist. @@ -126,7 +104,11 @@ pub trait StateBackendClient: Send + Sync { async fn lock(&self, keyspace: Keyspace, key: &str) -> Result>; /// Watch all events that happen on a specific prefix. 
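The new `Keyspace::strip_prefix` helper trims the Debug-formatted keyspace prefix from a key returned by the store. A minimal sketch of the expected behaviour (illustrative only, not part of the patch):

// Keys are prefixed with the Debug name of the keyspace followed by "/".
assert_eq!(
    Keyspace::Heartbeats.strip_prefix("Heartbeats/executor-1"),
    Some("executor-1")
);
// A key from a different keyspace does not match and yields None.
assert_eq!(Keyspace::Slots.strip_prefix("Sessions/some-session"), None);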
- async fn watch(&self, keyspace: Keyspace, prefix: String) -> Result>; + async fn watch( + &self, + keyspace: Keyspace, + prefix: String, + ) -> Result>>; /// Permanently delete a key from state async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()>; @@ -144,7 +126,7 @@ pub enum TaskDistribution { } /// A Watch is a cancelable stream of put or delete events in the [StateBackendClient] -#[tonic::async_trait] +#[async_trait] pub trait Watch: Stream + Send + Unpin { async fn cancel(&mut self) -> Result<()>; } @@ -158,12 +140,12 @@ pub enum WatchEvent { Delete(String), } -#[tonic::async_trait] +#[async_trait] pub trait Lock: Send + Sync { async fn unlock(&mut self); } -#[tonic::async_trait] +#[async_trait] impl Lock for OwnedMutexGuard { async fn unlock(&mut self) {} } diff --git a/ballista/scheduler/src/state/backend/sled.rs b/ballista/scheduler/src/cluster/storage/sled.rs similarity index 93% rename from ballista/scheduler/src/state/backend/sled.rs rename to ballista/scheduler/src/cluster/storage/sled.rs index 66e896bcd..6be4c504e 100644 --- a/ballista/scheduler/src/state/backend/sled.rs +++ b/ballista/scheduler/src/cluster/storage/sled.rs @@ -20,14 +20,14 @@ use std::{sync::Arc, task::Poll}; use ballista_core::error::{ballista_error, BallistaError, Result}; +use crate::cluster::storage::KeyValueStore; +use async_trait::async_trait; use futures::{FutureExt, Stream}; use log::warn; use sled_package as sled; use tokio::sync::Mutex; -use crate::state::backend::{ - Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent, -}; +use crate::cluster::storage::{Keyspace, Lock, Operation, Watch, WatchEvent}; /// A [`StateBackendClient`] implementation that uses file-based storage to save cluster state. #[derive(Clone)] @@ -64,8 +64,8 @@ fn sled_to_ballista_error(e: sled::Error) -> BallistaError { } } -#[tonic::async_trait] -impl StateBackendClient for SledClient { +#[async_trait] +impl KeyValueStore for SledClient { async fn get(&self, keyspace: Keyspace, key: &str) -> Result> { let key = format!("/{keyspace:?}/{key}"); Ok(self @@ -260,7 +260,7 @@ impl Stream for SledWatch { fn poll_next( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { match self.get_mut().subscriber.poll_unpin(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), @@ -282,10 +282,10 @@ impl Stream for SledWatch { #[cfg(test)] mod tests { - use super::{SledClient, StateBackendClient, Watch, WatchEvent}; + use super::{KeyValueStore, SledClient, Watch, WatchEvent}; + + use crate::cluster::storage::{Keyspace, Operation}; - use crate::state::backend::{Keyspace, Operation}; - use crate::state::with_locks; use futures::StreamExt; use std::result::Result; @@ -310,26 +310,24 @@ mod tests { let client = create_instance()?; let key = "key".to_string(); let value = "value".as_bytes().to_vec(); - let locks = client - .acquire_locks(vec![(Keyspace::ActiveJobs, ""), (Keyspace::Slots, "")]) - .await?; + { + let _locks = client + .acquire_locks(vec![(Keyspace::JobStatus, ""), (Keyspace::Slots, "")]) + .await?; - let _r: ballista_core::error::Result<()> = with_locks(locks, async { let txn_ops = vec![ (Operation::Put(value.clone()), Keyspace::Slots, key.clone()), ( Operation::Put(value.clone()), - Keyspace::ActiveJobs, + Keyspace::JobStatus, key.clone(), ), ]; client.apply_txn(txn_ops).await?; - Ok(()) - }) - .await; + } assert_eq!(client.get(Keyspace::Slots, key.as_str()).await?, value); - assert_eq!(client.get(Keyspace::ActiveJobs, 
key.as_str()).await?, value); + assert_eq!(client.get(Keyspace::JobStatus, key.as_str()).await?, value); Ok(()) } @@ -368,7 +366,7 @@ mod tests { let client = create_instance()?; let key = "key"; let value = "value".as_bytes(); - let mut watch: Box<dyn Watch> = + let mut watch: Box<dyn Watch<Item = WatchEvent>> = client.watch(Keyspace::Slots, key.to_owned()).await?; client .put(Keyspace::Slots, key.to_owned(), value.to_vec()) diff --git a/ballista/scheduler/src/cluster/test/mod.rs b/ballista/scheduler/src/cluster/test/mod.rs new file mode 100644 index 000000000..4d4001c63 --- /dev/null +++ b/ballista/scheduler/src/cluster/test/mod.rs @@ -0,0 +1,597 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cluster::{ClusterState, JobState, JobStateEvent, TaskDistribution}; +use crate::scheduler_server::timestamp_millis; +use crate::state::execution_graph::ExecutionGraph; +use crate::state::executor_manager::ExecutorReservation; +use crate::test_utils::{await_condition, mock_completed_task, mock_executor}; +use ballista_core::error::{BallistaError, Result}; +use ballista_core::serde::protobuf::job_status::Status; +use ballista_core::serde::protobuf::{executor_status, ExecutorHeartbeat, JobStatus}; +use ballista_core::serde::scheduler::{ + ExecutorData, ExecutorMetadata, ExecutorSpecification, +}; +use dashmap::DashMap; +use futures::StreamExt; +use itertools::Itertools; +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::RwLock; + +pub struct ClusterStateTest<S: ClusterState> { + state: Arc<S>, + received_heartbeats: Arc<DashMap<String, ExecutorHeartbeat>>, + reservations: Vec<ExecutorReservation>, + total_task_slots: u32, +} + +impl<S: ClusterState> ClusterStateTest<S> { + pub async fn new(state: S) -> Result<Self> { + let received_heartbeats = Arc::new(DashMap::new()); + + let mut heartbeat_stream = state.executor_heartbeat_stream().await?; + let received_heartbeat_clone = received_heartbeats.clone(); + + tokio::spawn(async move { + while let Some(heartbeat) = heartbeat_stream.next().await { + received_heartbeat_clone.insert(heartbeat.executor_id.clone(), heartbeat); + } + }); + + Ok(Self { + state: Arc::new(state), + received_heartbeats, + reservations: vec![], + total_task_slots: 0, + }) + } + + pub async fn register_executor( + mut self, + executor_id: &str, + task_slots: u32, + ) -> Result<Self> { + self.state + .register_executor( + ExecutorMetadata { + id: executor_id.to_string(), + host: executor_id.to_string(), + port: 0, + grpc_port: 0, + specification: ExecutorSpecification { task_slots }, + }, + ExecutorData { + executor_id: executor_id.to_string(), + total_task_slots: task_slots, + available_task_slots: task_slots, + }, + false, + ) + .await?; + + self.total_task_slots += task_slots; + + Ok(self) + } + + pub async fn remove_executor(self, executor_id: &str) -> Result<Self> { + self.state.remove_executor(executor_id).await?; +
Ok(self) + } + + pub async fn assert_live_executor( + self, + executor_id: &str, + task_slots: u32, + ) -> Result<Self> { + let executor = self.state.get_executor_metadata(executor_id).await; + assert!( + executor.is_ok(), + "Metadata for executor {} not found in state", + executor_id + ); + assert_eq!( + executor.unwrap().specification.task_slots, + task_slots, + "Unexpected number of task slots for executor" + ); + + // Heartbeat stream is async so wait up to 500ms for it to show up + await_condition(Duration::from_millis(50), 10, || { + let found_heartbeat = + self.received_heartbeats.get(executor_id).map(|heartbeat| { + matches!( + heartbeat.status, + Some(ballista_core::serde::generated::ballista::ExecutorStatus { + status: Some(executor_status::Status::Active(_)) + }) + ) + }); + + futures::future::ready(Ok(found_heartbeat.unwrap_or_default())) + }) + .await?; + + Ok(self) + } + + pub async fn assert_dead_executor(self, executor_id: &str) -> Result<Self> { + // Heartbeat stream is async so wait up to 500ms for it to show up + await_condition(Duration::from_millis(50), 10, || { + let found_heartbeat = + self.received_heartbeats.get(executor_id).map(|heartbeat| { + matches!( + heartbeat.status, + Some(ballista_core::serde::generated::ballista::ExecutorStatus { + status: Some(executor_status::Status::Dead(_)) + }) + ) + }); + + futures::future::ready(Ok(found_heartbeat.unwrap_or_default())) + }) + .await?; + + Ok(self) + } + + pub async fn try_reserve_slots( + mut self, + num_slots: u32, + distribution: TaskDistribution, + filter: Option<Vec<String>>, + exact: bool, + ) -> Result<Self> { + let filter = filter.map(|f| f.into_iter().collect::<HashSet<String>>()); + let reservations = if exact { + self.state + .reserve_slots_exact(num_slots, distribution, filter) + .await? + } else { + self.state + .reserve_slots(num_slots, distribution, filter) + .await?
+ }; + + self.reservations.extend(reservations); + + Ok(self) + } + + pub async fn cancel_reservations(mut self, num_slots: usize) -> Result<Self> { + if self.reservations.len() < num_slots { + return Err(BallistaError::General(format!( + "Not enough reservations to cancel, expected {} but found {}", + num_slots, + self.reservations.len() + ))); + } + + let to_keep = self.reservations.split_off(num_slots); + + self.state + .cancel_reservations(std::mem::take(&mut self.reservations)) + .await?; + + self.reservations = to_keep; + + Ok(self) + } + + pub fn assert_open_reservations(self, n: usize) -> Self { + assert_eq!( + self.reservations.len(), + n, + "Expected {} open reservations but found {}", + n, + self.reservations.len() + ); + self + } + + pub fn assert_open_reservations_with<F: Fn(&ExecutorReservation) -> bool>( + self, + n: usize, + predicate: F, + ) -> Self { + assert_eq!( + self.reservations.len(), + n, + "Expected {} open reservations but found {}", + n, + self.reservations.len() + ); + + for res in &self.reservations { + assert!(predicate(res), "Predicate failed on reservation {:?}", res); + } + self + } + + pub async fn fuzz_reservation( + mut self, + concurrency: usize, + distribution: TaskDistribution, + ) -> Result<()> { + let (sender, mut receiver) = tokio::sync::mpsc::channel(1_000); + + let total_slots = self.total_task_slots; + for _ in 0..concurrency { + let state = self.state.clone(); + let sender_clone = sender.clone(); + tokio::spawn(async move { + let mut open_reservations = vec![]; + for i in 0..10 { + if i % 2 == 0 { + let to_reserve = rand::random::<u32>() % total_slots; + + let reservations = state + .reserve_slots(to_reserve, distribution, None) + .await + .unwrap(); + + sender_clone + .send(FuzzEvent::Reserved(reservations.clone())) + .await + .unwrap(); + + open_reservations = reservations; + } else { + state + .cancel_reservations(open_reservations.clone()) + .await + .unwrap(); + sender_clone + .send(FuzzEvent::Cancelled(std::mem::take( + &mut open_reservations, + ))) + .await + .unwrap(); + } + } + }); + } + + drop(sender); + + while let Some(event) = receiver.recv().await { + match event { + FuzzEvent::Reserved(reservations) => { + self.reservations.extend(reservations); + assert!( + self.reservations.len() <= total_slots as usize, + "More than total number of slots was reserved" + ); + } + FuzzEvent::Cancelled(reservations) => { + for res in reservations { + let idx = self + .reservations + .iter() + .find_position(|r| r.executor_id == res.executor_id); + assert!(idx.is_some(), "Received invalid cancellation, no existing reservation for executor ID {}", res.executor_id); + + self.reservations.swap_remove(idx.unwrap().0); + } + } + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone)] +enum FuzzEvent { + Reserved(Vec<ExecutorReservation>), + Cancelled(Vec<ExecutorReservation>), +} + +pub async fn test_fuzz_reservations<S: ClusterState>( + state: S, + concurrency: usize, + distribution: TaskDistribution, + num_executors: usize, + task_slots_per_executor: usize, +) -> Result<()> { + let mut test = ClusterStateTest::new(state).await?; + + for idx in 0..num_executors { + test = test + .register_executor(idx.to_string().as_str(), task_slots_per_executor as u32) + .await?; + } + + test.fuzz_reservation(concurrency, distribution).await +} + +pub async fn test_executor_registration<S: ClusterState>(state: S) -> Result<()> { + let test = ClusterStateTest::new(state).await?; + + test.register_executor("1", 10) + .await? + .register_executor("2", 10) + .await? + .register_executor("3", 10) + .await? + .assert_live_executor("1", 10) + .await?
+ .assert_live_executor("2", 10) + .await? + .assert_live_executor("3", 10) + .await? + .remove_executor("1") + .await? + .assert_dead_executor("1") + .await? + .remove_executor("2") + .await? + .assert_dead_executor("2") + .await? + .remove_executor("3") + .await? + .assert_dead_executor("3") + .await?; + + Ok(()) +} + +pub async fn test_reservation<S: ClusterState>( + state: S, + distribution: TaskDistribution, +) -> Result<()> { + let test = ClusterStateTest::new(state).await?; + + test.register_executor("1", 10) + .await? + .register_executor("2", 10) + .await? + .register_executor("3", 10) + .await? + .try_reserve_slots(10, distribution, None, false) + .await? + .assert_open_reservations(10) + .cancel_reservations(10) + .await? + .try_reserve_slots(30, distribution, None, false) + .await? + .assert_open_reservations(30) + .cancel_reservations(15) + .await? + .assert_open_reservations(15) + .try_reserve_slots(30, distribution, None, false) + .await? + .assert_open_reservations(30) + .cancel_reservations(30) + .await? + .assert_open_reservations(0) + .try_reserve_slots(50, distribution, None, false) + .await? + .assert_open_reservations(30) + .cancel_reservations(30) + .await? + .try_reserve_slots(20, distribution, Some(vec!["1".to_string()]), false) + .await? + .assert_open_reservations_with(10, |res| res.executor_id == "1") + .cancel_reservations(10) + .await? + .try_reserve_slots( + 20, + distribution, + Some(vec!["2".to_string(), "3".to_string()]), + false, + ) + .await? + .assert_open_reservations_with(20, |res| { + res.executor_id == "2" || res.executor_id == "3" + }); + + Ok(()) +} + +pub struct JobStateTest<S: JobState> { + state: Arc<S>, + events: Arc<RwLock<Vec<JobStateEvent>>>, +} + +impl<S: JobState> JobStateTest<S> { + pub async fn new(state: S) -> Result<Self> { + let events = Arc::new(RwLock::new(vec![])); + + let mut event_stream = state.job_state_events().await?; + let events_clone = events.clone(); + tokio::spawn(async move { + while let Some(event) = event_stream.next().await { + let mut guard = events_clone.write().await; + + guard.push(event); + } + }); + + Ok(Self { + state: Arc::new(state), + events, + }) + } + + pub async fn queue_job(self, job_id: &str) -> Result<Self> { + self.state + .accept_job(job_id, "", timestamp_millis()) + .await?; + Ok(self) + } + + pub async fn fail_planning(self, job_id: &str) -> Result<Self> { + self.state + .fail_unscheduled_job(job_id, "failed planning".to_string()) + .await?; + Ok(self) + } + + pub async fn assert_queued(self, job_id: &str) -> Result<Self> { + let status = self.state.get_job_status(job_id).await?; + + assert!(status.is_some(), "Queued job {} not found", job_id); + + let status = status.unwrap(); + assert!( + matches!(&status, JobStatus { + job_id: status_job_id, status: Some(Status::Queued(_)), .. + } if status_job_id.as_str() == job_id), + "Expected queued status but found {:?}", + status + ); + + Ok(self) + } + + pub async fn submit_job(self, graph: &ExecutionGraph) -> Result<Self> { + self.state + .submit_job(graph.job_id().to_string(), graph) + .await?; + Ok(self) + } + + pub async fn assert_job_running(self, job_id: &str) -> Result<Self> { + let status = self.state.get_job_status(job_id).await?; + + assert!(status.is_some(), "Job status not found for {}", job_id); + + let status = status.unwrap(); + assert!( + matches!(&status, JobStatus { + job_id: status_job_id, status: Some(Status::Running(_)), ..
} if status_job_id.as_str() == job_id), + "Expected running status but found {:?}", + status + ); + + Ok(self) + } + + pub async fn update_job(self, graph: &ExecutionGraph) -> Result<Self> { + self.state.save_job(graph.job_id(), graph).await?; + Ok(self) + } + + pub async fn assert_job_failed(self, job_id: &str) -> Result<Self> { + let status = self.state.get_job_status(job_id).await?; + + assert!(status.is_some(), "Job status not found for {}", job_id); + + let status = status.unwrap(); + assert!( + matches!(&status, JobStatus { + job_id: status_job_id, status: Some(Status::Failed(_)), .. + } if status_job_id.as_str() == job_id), + "Expected failed status but found {:?}", + status + ); + + Ok(self) + } + + pub async fn assert_job_successful(self, job_id: &str) -> Result<Self> { + let status = self.state.get_job_status(job_id).await?; + + assert!(status.is_some(), "Job status not found for {}", job_id); + let status = status.unwrap(); + assert!( + matches!(&status, JobStatus { + job_id: status_job_id, status: Some(Status::Successful(_)), .. + } if status_job_id.as_str() == job_id), + "Expected success status but found {:?}", + status + ); + + Ok(self) + } + + pub async fn assert_event(self, event: JobStateEvent) -> Result<Self> { + let events = self.events.clone(); + let found = await_condition(Duration::from_millis(50), 10, || async { + let guard = events.read().await; + + Ok(guard.iter().any(|ev| ev == &event)) + }) + .await?; + + assert!(found, "Expected event {:?}", event); + + Ok(self) + } +} + +pub async fn test_job_lifecycle<S: JobState>( + state: S, + mut graph: ExecutionGraph, +) -> Result<()> { + let test = JobStateTest::new(state).await?; + + let job_id = graph.job_id().to_string(); + + let test = test + .queue_job(&job_id) + .await? + .assert_queued(&job_id) + .await? + .submit_job(&graph) + .await? + .assert_job_running(&job_id) + .await?; + + drain_tasks(&mut graph)?; + graph.succeed_job()?; + + test.update_job(&graph) + .await? + .assert_job_successful(&job_id) + .await?; + + Ok(()) +} + +pub async fn test_job_planning_failure<S: JobState>( + state: S, + graph: ExecutionGraph, +) -> Result<()> { + let test = JobStateTest::new(state).await?; + + let job_id = graph.job_id().to_string(); + + test.queue_job(&job_id) + .await? + .fail_planning(&job_id) + .await? + .assert_job_failed(&job_id) + .await?; + + Ok(()) +} + +fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> { + let executor = mock_executor("executor-id1".to_string()); + while let Some(task) = graph.pop_next_task(&executor.id)? { + let task_status = mock_completed_task(task, &executor.id); + graph.update_task_status(&executor, vec![task_status], 1, 1)?; + } + + Ok(()) +} diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs index 020d1c83b..ddd05b786 100644 --- a/ballista/scheduler/src/config.rs +++ b/ballista/scheduler/src/config.rs @@ -25,6 +25,13 @@ use std::fmt; /// Configurations for the ballista scheduler of scheduling jobs and tasks #[derive(Debug, Clone)] pub struct SchedulerConfig { + /// Namespace of this scheduler. Schedulers using the same cluster storage and namespace + /// will share global cluster state. + pub namespace: String, + /// The external hostname of the scheduler + pub external_host: String, + /// The bind port for the scheduler's gRPC service + pub bind_port: u16, /// The task scheduling policy for the scheduler pub scheduling_policy: TaskSchedulingPolicy, /// The event loop buffer size.
for a system of high throughput, a larger value like 1000000 is recommended @@ -40,27 +47,52 @@ pub struct SchedulerConfig { /// If provided, submitted jobs which do not have tasks scheduled will be resubmitted after `job_resubmit_interval_ms` /// milliseconds pub job_resubmit_interval_ms: Option, + /// Configuration for ballista cluster storage + pub cluster_storage: ClusterStorageConfig, } impl Default for SchedulerConfig { fn default() -> Self { Self { + namespace: String::default(), + external_host: "localhost".to_string(), + bind_port: 50050, scheduling_policy: TaskSchedulingPolicy::PullStaged, event_loop_buffer_size: 10000, executor_slots_policy: SlotsPolicy::Bias, finished_job_data_clean_up_interval_seconds: 300, finished_job_state_clean_up_interval_seconds: 3600, advertise_flight_sql_endpoint: None, + cluster_storage: ClusterStorageConfig::Memory, job_resubmit_interval_ms: None, } } } impl SchedulerConfig { + pub fn scheduler_name(&self) -> String { + format!("{}:{}", self.external_host, self.bind_port) + } + pub fn is_push_staged_scheduling(&self) -> bool { matches!(self.scheduling_policy, TaskSchedulingPolicy::PushStaged) } + pub fn with_namespace(mut self, namespace: impl Into) -> Self { + self.namespace = namespace.into(); + self + } + + pub fn with_hostname(mut self, hostname: impl Into) -> Self { + self.external_host = hostname.into(); + self + } + + pub fn with_port(mut self, port: u16) -> Self { + self.bind_port = port; + self + } + pub fn with_scheduler_policy(mut self, policy: TaskSchedulingPolicy) -> Self { self.scheduling_policy = policy; self @@ -100,12 +132,26 @@ impl SchedulerConfig { self } + pub fn with_cluster_storage(mut self, config: ClusterStorageConfig) -> Self { + self.cluster_storage = config; + self + } + pub fn with_job_resubmit_interval_ms(mut self, interval_ms: u64) -> Self { self.job_resubmit_interval_ms = Some(interval_ms); self } } +#[derive(Clone, Debug)] +pub enum ClusterStorageConfig { + Memory, + #[cfg(feature = "etcd")] + Etcd(Vec), + #[cfg(feature = "sled")] + Sled(Option), +} + // an enum used to configure the executor slots policy // needs to be visible to code generated by configure_me #[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)] diff --git a/ballista/scheduler/src/lib.rs b/ballista/scheduler/src/lib.rs index da21296f5..cd6f047b2 100644 --- a/ballista/scheduler/src/lib.rs +++ b/ballista/scheduler/src/lib.rs @@ -18,6 +18,7 @@ #![doc = include_str ! 
("../README.md")] pub mod api; +pub mod cluster; pub mod config; pub mod display; pub mod metrics; diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index a44568671..e20eed8a9 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -22,7 +22,7 @@ use futures::future::{self, Either, TryFutureExt}; use hyper::{server::conn::AddrStream, service::make_service_fn, Server}; use log::info; use std::convert::Infallible; -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; use tonic::transport::server::Connected; use tower::Service; @@ -34,18 +34,15 @@ use ballista_core::utils::create_grpc_server; use ballista_core::BALLISTA_VERSION; use crate::api::{get_routes, EitherBody, Error}; +use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::flight_sql::FlightSqlServiceImpl; use crate::metrics::default_metrics_collector; use crate::scheduler_server::externalscaler::external_scaler_server::ExternalScalerServer; use crate::scheduler_server::SchedulerServer; -use crate::state::backend::cluster::ClusterState; -use crate::state::backend::StateBackendClient; pub async fn start_server( - scheduler_name: String, - config_backend: Arc, - cluster_state: Arc, + cluster: BallistaCluster, addr: SocketAddr, config: SchedulerConfig, ) -> Result<()> { @@ -63,9 +60,8 @@ pub async fn start_server( let mut scheduler_server: SchedulerServer = SchedulerServer::new( - scheduler_name, - config_backend.clone(), - cluster_state, + config.scheduler_name(), + cluster, BallistaCodec::default(), config, metrics_collector, diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index cdfa9cc6c..bff078d01 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -583,7 +583,7 @@ impl SchedulerGrpc #[cfg(all(test, feature = "sled"))] mod test { - use std::sync::Arc; + use std::time::Duration; use datafusion_proto::protobuf::LogicalPlanNode; @@ -600,23 +600,21 @@ mod test { }; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::BallistaCodec; - use ballista_core::utils::default_session_builder; - use crate::state::backend::cluster::DefaultClusterState; use crate::state::executor_manager::DEFAULT_EXECUTOR_TIMEOUT_SECONDS; - use crate::state::{backend::sled::SledClient, SchedulerState}; + use crate::state::SchedulerState; + use crate::test_utils::test_cluster_context; use super::{SchedulerGrpc, SchedulerServer}; #[tokio::test] async fn test_poll_work() -> Result<(), BallistaError> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage.clone(), - cluster_state.clone(), + cluster.clone(), BallistaCodec::default(), SchedulerConfig::default(), default_metrics_collector().unwrap(), @@ -643,9 +641,7 @@ mod test { assert!(response.tasks.is_empty()); let state: SchedulerState = SchedulerState::new_with_default_scheduler_name( - state_storage.clone(), - cluster_state.clone(), - default_session_builder, + cluster.clone(), BallistaCodec::default(), ); state.init().await?; @@ -677,9 +673,7 @@ mod test { assert!(response.tasks.is_empty()); let state: SchedulerState = SchedulerState::new_with_default_scheduler_name( - 
state_storage.clone(), - cluster_state, - default_session_builder, + cluster.clone(), BallistaCodec::default(), ); state.init().await?; @@ -701,13 +695,12 @@ mod test { #[tokio::test] async fn test_stop_executor() -> Result<(), BallistaError> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster.clone(), BallistaCodec::default(), SchedulerConfig::default(), default_metrics_collector().unwrap(), @@ -783,13 +776,12 @@ mod test { #[tokio::test] async fn test_register_executor_in_heartbeat_service() -> Result<(), BallistaError> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster, BallistaCodec::default(), SchedulerConfig::default(), default_metrics_collector().unwrap(), @@ -836,13 +828,12 @@ mod test { #[tokio::test] #[ignore] async fn test_expired_executor() -> Result<(), BallistaError> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster.clone(), BallistaCodec::default(), SchedulerConfig::default(), default_metrics_collector().unwrap(), diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index ee4f86e78..de37365b9 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -22,7 +22,6 @@ use ballista_core::error::Result; use ballista_core::event_loop::{EventLoop, EventSender}; use ballista_core::serde::protobuf::{StopExecutorParams, TaskStatus}; use ballista_core::serde::BallistaCodec; -use ballista_core::utils::default_session_builder; use datafusion::execution::context::SessionState; use datafusion::logical_expr::LogicalPlan; @@ -30,6 +29,7 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; +use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::metrics::SchedulerMetricsCollector; use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; @@ -37,8 +37,7 @@ use log::{error, warn}; use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::scheduler_server::query_stage_scheduler::QueryStageScheduler; -use crate::state::backend::cluster::ClusterState; -use crate::state::backend::StateBackendClient; + use crate::state::executor_manager::{ ExecutorManager, ExecutorReservation, DEFAULT_EXECUTOR_TIMEOUT_SECONDS, }; @@ -71,53 +70,13 @@ pub struct SchedulerServer SchedulerServer { pub fn new( scheduler_name: String, - config_backend: Arc, - cluster_state: Arc, + cluster: BallistaCluster, codec: BallistaCodec, config: SchedulerConfig, metrics_collector: Arc, ) -> Self { let state = Arc::new(SchedulerState::new( - config_backend, - cluster_state, - default_session_builder, - codec, - scheduler_name.clone(), - config.clone(), 
- )); - let query_stage_scheduler = Arc::new(QueryStageScheduler::new( - state.clone(), - metrics_collector, - config.job_resubmit_interval_ms, - )); - let query_stage_event_loop = EventLoop::new( - "query_stage".to_owned(), - config.event_loop_buffer_size as usize, - query_stage_scheduler.clone(), - ); - - Self { - scheduler_name, - start_time: timestamp_millis() as u128, - state, - query_stage_event_loop, - query_stage_scheduler, - } - } - - pub fn with_session_builder( - scheduler_name: String, - config_backend: Arc, - cluster_backend: Arc, - codec: BallistaCodec, - config: SchedulerConfig, - session_builder: SessionBuilder, - metrics_collector: Arc, - ) -> Self { - let state = Arc::new(SchedulerState::new( - config_backend, - cluster_backend, - session_builder, + cluster, codec, scheduler_name.clone(), config.clone(), @@ -143,19 +102,16 @@ impl SchedulerServer, - cluster_backend: Arc, + cluster: BallistaCluster, codec: BallistaCodec, config: SchedulerConfig, metrics_collector: Arc, task_launcher: Arc, ) -> Self { - let state = Arc::new(SchedulerState::with_task_launcher( - config_backend, - cluster_backend, - default_session_builder, + let state = Arc::new(SchedulerState::new_with_task_launcher( + cluster, codec, scheduler_name.clone(), config.clone(), @@ -406,13 +362,11 @@ mod test { use ballista_core::serde::BallistaCodec; use crate::scheduler_server::{timestamp_millis, SchedulerServer}; - use crate::state::backend::cluster::DefaultClusterState; - use crate::state::backend::sled::SledClient; use crate::test_utils::{ assert_completed_event, assert_failed_event, assert_no_submitted_event, - assert_submitted_event, ExplodingTableProvider, SchedulerTest, TaskRunnerFn, - TestMetricsCollector, + assert_submitted_event, test_cluster_context, ExplodingTableProvider, + SchedulerTest, TaskRunnerFn, TestMetricsCollector, }; #[tokio::test] @@ -441,6 +395,13 @@ mod test { let job_id = "job"; + // Enqueue job + scheduler + .state + .task_manager + .queue_job(job_id, "", timestamp_millis()) + .await?; + // Submit job scheduler .state @@ -453,7 +414,6 @@ mod test { .state .task_manager .get_active_execution_graph(job_id) - .await { let task = { let mut graph = graph.write().await; @@ -507,7 +467,6 @@ mod test { .state .task_manager .get_active_execution_graph(job_id) - .await .expect("Fail to find graph in the cache"); let final_graph = final_graph.read().await; @@ -543,6 +502,7 @@ mod test { match status.status { Some(job_status::Status::Successful(SuccessfulJob { partition_location, + .. })) => { assert_eq!(partition_location.len(), 4); } @@ -618,7 +578,8 @@ mod test { matches!( status, JobStatus { - status: Some(job_status::Status::Failed(_)) + status: Some(job_status::Status::Failed(_)), + .. } ), "{}", @@ -662,7 +623,8 @@ mod test { matches!( status, JobStatus { - status: Some(job_status::Status::Failed(_)) + status: Some(job_status::Status::Failed(_)), + .. 
} ), "{}", @@ -678,13 +640,12 @@ mod test { async fn test_scheduler( scheduling_policy: TaskSchedulingPolicy, ) -> Result> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster, BallistaCodec::default(), SchedulerConfig::default().with_scheduler_policy(scheduling_policy), Arc::new(TestMetricsCollector::default()), diff --git a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs index 25454188b..3b0354702 100644 --- a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs +++ b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs @@ -104,6 +104,11 @@ impl } => { info!("Job {} queued with name {:?}", job_id, job_name); + self.state + .task_manager + .queue_job(&job_id, &job_name, queued_at) + .await?; + let state = self.state.clone(); tokio::spawn(async move { let event = if let Err(e) = state @@ -272,6 +277,11 @@ impl .await?; } QueryStageSchedulerEvent::TaskUpdating(executor_id, tasks_status) => { + debug!( + "processing task status updates from {executor_id}: {:?}", + tasks_status + ); + let num_status = tasks_status.len(); match self .state @@ -363,6 +373,7 @@ mod tests { use datafusion::test_util::scan_empty_with_partitions; use std::sync::Arc; use std::time::Duration; + use tracing_subscriber::EnvFilter; #[tokio::test] async fn test_job_resubmit() -> Result<()> { @@ -409,6 +420,10 @@ mod tests { #[tokio::test] async fn test_pending_task_metric() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + let plan = test_plan(10); let metrics_collector = Arc::new(TestMetricsCollector::default()); @@ -450,7 +465,7 @@ mod tests { test.tick().await?; // Job should be finished now - let _ = test.await_completion("job-1").await?; + let _ = test.await_completion_timeout("job-1", 5_000).await?; Ok(()) } diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs index 4438c44a9..5334ae1d9 100644 --- a/ballista/scheduler/src/standalone.rs +++ b/ballista/scheduler/src/standalone.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. 
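Taken together, the new `SchedulerConfig` builders and `BallistaCluster::new_from_config` replace the old pattern of handing a `StateBackendClient` and `ClusterState` to `start_server` separately. A hedged sketch of how a caller might wire them up (the `scheduler_process` module path, the `ballista_core::error::Result` alias, and the literal namespace/host/port values are my assumptions, not taken from the patch):

use std::net::SocketAddr;

use ballista_scheduler::cluster::BallistaCluster;
use ballista_scheduler::config::{ClusterStorageConfig, SchedulerConfig};
use ballista_scheduler::scheduler_process::start_server;

async fn run_scheduler() -> ballista_core::error::Result<()> {
    // Builder-style configuration; the scheduler name becomes "external_host:bind_port".
    let config = SchedulerConfig::default()
        .with_namespace("demo")
        .with_hostname("localhost")
        .with_port(50050)
        .with_cluster_storage(ClusterStorageConfig::Memory);

    // Picks the cluster/job state implementation from `cluster_storage`
    // (in-memory here; etcd or sled when those features are enabled).
    let cluster = BallistaCluster::new_from_config(&config).await?;

    let addr: SocketAddr = format!("0.0.0.0:{}", config.bind_port).parse().unwrap();
    start_server(cluster, addr, config).await
}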
+use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::metrics::default_metrics_collector; -use crate::state::backend::cluster::DefaultClusterState; -use crate::{scheduler_server::SchedulerServer, state::backend::sled::SledClient}; +use crate::{cluster::storage::sled::SledClient, scheduler_server::SchedulerServer}; use ballista_core::serde::BallistaCodec; -use ballista_core::utils::create_grpc_server; +use ballista_core::utils::{create_grpc_server, default_session_builder}; use ballista_core::{ error::Result, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer, BALLISTA_VERSION, @@ -28,23 +28,28 @@ use ballista_core::{ use datafusion_proto::protobuf::LogicalPlanNode; use datafusion_proto::protobuf::PhysicalPlanNode; use log::info; -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; use tokio::net::TcpListener; pub async fn new_standalone_scheduler() -> Result { - let backend = Arc::new(SledClient::try_new_temporary()?); - let metrics_collector = default_metrics_collector()?; + let cluster = BallistaCluster::new_kv( + SledClient::try_new_temporary()?, + "localhost:50050", + default_session_builder, + BallistaCodec::default(), + ); + let mut scheduler_server: SchedulerServer = SchedulerServer::new( "localhost:50050".to_owned(), - backend.clone(), - Arc::new(DefaultClusterState::new(backend)), + cluster, BallistaCodec::default(), SchedulerConfig::default(), metrics_collector, ); + scheduler_server.init().await?; let server = SchedulerGrpcServer::new(scheduler_server.clone()); // Let the OS assign a random, free port diff --git a/ballista/scheduler/src/state/backend/cluster.rs b/ballista/scheduler/src/state/backend/cluster.rs deleted file mode 100644 index 7e02697d7..000000000 --- a/ballista/scheduler/src/state/backend/cluster.rs +++ /dev/null @@ -1,700 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::state::backend::{ - Keyspace, Operation, StateBackendClient, TaskDistribution, WatchEvent, -}; -use crate::state::executor_manager::ExecutorReservation; -use crate::state::{decode_into, decode_protobuf, encode_protobuf, with_lock}; -use ballista_core::error; -use ballista_core::error::BallistaError; -use ballista_core::serde::protobuf; -use ballista_core::serde::protobuf::{ - executor_status, ExecutorHeartbeat, ExecutorStatus, -}; -use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; -use futures::{Stream, StreamExt}; -use log::{debug, info}; -use std::collections::{HashMap, HashSet}; -use std::pin::Pin; -use std::sync::Arc; -use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; - -pub type ExecutorHeartbeatStream = Pin + Send>>; - -/// A trait that contains the necessary method to maintain a globally consistent view of cluster resources -#[tonic::async_trait] -pub trait ClusterState: Send + Sync { - /// Reserve up to `num_slots` executor task slots. If not enough task slots are available, reserve - /// as many as possible. - /// - /// If `executors` is provided, only reserve slots of the specified executor IDs - async fn reserve_slots( - &self, - num_slots: u32, - distribution: TaskDistribution, - executors: Option>, - ) -> error::Result>; - - /// Reserve exactly `num_slots` executor task slots. If not enough task slots are available, - /// returns an empty vec - /// - /// If `executors` is provided, only reserve slots of the specified executor IDs - async fn reserve_slots_exact( - &self, - num_slots: u32, - distribution: TaskDistribution, - executors: Option>, - ) -> error::Result>; - - /// Cancel the specified reservations. This will make reserved executor slots available to other - /// tasks. - /// This operations should be atomic. Either all reservations are cancelled or none are - async fn cancel_reservations( - &self, - reservations: Vec, - ) -> error::Result<()>; - - /// Register a new executor in the cluster. If `reserve` is true, then the executors task slots - /// will be reserved and returned in the response and none of the new executors task slots will be - /// available to other tasks. - async fn register_executor( - &self, - metadata: ExecutorMetadata, - spec: ExecutorData, - reserve: bool, - ) -> error::Result>; - - /// Save the executor metadata. This will overwrite existing metadata for the executor ID - async fn save_executor_metadata( - &self, - metadata: ExecutorMetadata, - ) -> error::Result<()>; - - /// Get executor metadata for the provided executor ID. Returns an error if the executor does not exist - async fn get_executor_metadata( - &self, - executor_id: &str, - ) -> error::Result; - - /// Save the executor heartbeat - async fn save_executor_heartbeat( - &self, - heartbeat: ExecutorHeartbeat, - ) -> error::Result<()>; - - /// Remove the executor from the cluster - async fn remove_executor(&self, executor_id: &str) -> error::Result<()>; - - /// Return the stream of executor heartbeats observed by all schedulers in the cluster. 
- /// This can be aggregated to provide an eventually consistent view of all executors within the cluster - async fn executor_heartbeat_stream(&self) -> error::Result; - - /// Return a map of the last seen heartbeat for all active executors - async fn executor_heartbeats( - &self, - ) -> error::Result>; -} - -/// Default implementation of `ClusterState` that can use the key-value interface defined in -/// `StateBackendClient -pub struct DefaultClusterState { - kv_store: Arc, -} - -impl DefaultClusterState { - pub fn new(kv_store: Arc) -> Self { - Self { kv_store } - } -} - -#[tonic::async_trait] -impl ClusterState for DefaultClusterState { - async fn reserve_slots( - &self, - num_slots: u32, - distribution: TaskDistribution, - executors: Option>, - ) -> error::Result> { - let lock = self.kv_store.lock(Keyspace::Slots, "global").await?; - - with_lock(lock, async { - debug!("Attempting to reserve {} executor slots", num_slots); - let start = Instant::now(); - - let executors = match executors { - Some(executors) => executors, - None => { - let heartbeats = self.executor_heartbeats().await?; - - get_alive_executors(60, heartbeats)? - } - }; - - let (reservations, txn_ops) = match distribution { - TaskDistribution::Bias => { - reserve_slots_bias(self.kv_store.as_ref(), num_slots, executors) - .await? - } - TaskDistribution::RoundRobin => { - reserve_slots_round_robin( - self.kv_store.as_ref(), - num_slots, - executors, - ) - .await? - } - }; - - self.kv_store.apply_txn(txn_ops).await?; - - let elapsed = start.elapsed(); - info!( - "Reserved {} executor slots in {:?}", - reservations.len(), - elapsed - ); - - Ok(reservations) - }) - .await - } - - async fn reserve_slots_exact( - &self, - num_slots: u32, - distribution: TaskDistribution, - executors: Option>, - ) -> error::Result> { - let lock = self.kv_store.lock(Keyspace::Slots, "global").await?; - - with_lock(lock, async { - debug!("Attempting to reserve {} executor slots", num_slots); - let start = Instant::now(); - - let executors = match executors { - Some(executors) => executors, - None => { - let heartbeats = self.executor_heartbeats().await?; - - get_alive_executors(60, heartbeats)? - } - }; - - let (reservations, txn_ops) = match distribution { - TaskDistribution::Bias => { - reserve_slots_bias(self.kv_store.as_ref(), num_slots, executors) - .await? - } - TaskDistribution::RoundRobin => { - reserve_slots_round_robin( - self.kv_store.as_ref(), - num_slots, - executors, - ) - .await? 
- } - }; - - let elapsed = start.elapsed(); - if reservations.len() as u32 == num_slots { - self.kv_store.apply_txn(txn_ops).await?; - - info!( - "Reserved {} executor slots in {:?}", - reservations.len(), - elapsed - ); - - Ok(reservations) - } else { - info!( - "Failed to reserve exactly {} executor slots in {:?}", - reservations.len(), - elapsed - ); - - Ok(vec![]) - } - }) - .await - } - - async fn cancel_reservations( - &self, - reservations: Vec, - ) -> error::Result<()> { - let lock = self.kv_store.lock(Keyspace::Slots, "global").await?; - - with_lock(lock, async { - let num_reservations = reservations.len(); - debug!("Cancelling {} reservations", num_reservations); - let start = Instant::now(); - - let mut executor_slots: HashMap = HashMap::new(); - - for reservation in reservations { - let executor_id = &reservation.executor_id; - if let Some(data) = executor_slots.get_mut(executor_id) { - data.available_task_slots += 1; - } else { - let value = self.kv_store.get(Keyspace::Slots, executor_id).await?; - let mut data = - decode_into::(&value)?; - data.available_task_slots += 1; - executor_slots.insert(executor_id.clone(), data); - } - } - - let txn_ops: Vec<(Operation, Keyspace, String)> = executor_slots - .into_iter() - .map(|(executor_id, data)| { - let proto: protobuf::ExecutorData = data.into(); - let new_data = encode_protobuf(&proto)?; - Ok((Operation::Put(new_data), Keyspace::Slots, executor_id)) - }) - .collect::>>()?; - - self.kv_store.apply_txn(txn_ops).await?; - - let elapsed = start.elapsed(); - info!( - "Cancelled {} reservations in {:?}", - num_reservations, elapsed - ); - - Ok(()) - }) - .await - } - - async fn register_executor( - &self, - metadata: ExecutorMetadata, - spec: ExecutorData, - reserve: bool, - ) -> error::Result> { - let executor_id = metadata.id.clone(); - - let current_ts = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|e| { - BallistaError::Internal(format!("Error getting current timestamp: {e:?}")) - })? 
- .as_secs(); - - //TODO this should be in a transaction - // Now that we know we can connect, save the metadata and slots - self.save_executor_metadata(metadata).await?; - self.save_executor_heartbeat(ExecutorHeartbeat { - executor_id: executor_id.clone(), - timestamp: current_ts, - metrics: vec![], - status: Some(ExecutorStatus { - status: Some(executor_status::Status::Active("".to_string())), - }), - }) - .await?; - - if !reserve { - let proto: protobuf::ExecutorData = spec.into(); - let value = encode_protobuf(&proto)?; - self.kv_store - .put(Keyspace::Slots, executor_id, value) - .await?; - Ok(vec![]) - } else { - let mut specification = spec; - let num_slots = specification.available_task_slots as usize; - let mut reservations: Vec = vec![]; - for _ in 0..num_slots { - reservations.push(ExecutorReservation::new_free(executor_id.clone())); - } - - specification.available_task_slots = 0; - - let proto: protobuf::ExecutorData = specification.into(); - let value = encode_protobuf(&proto)?; - self.kv_store - .put(Keyspace::Slots, executor_id, value) - .await?; - Ok(reservations) - } - } - - async fn save_executor_metadata( - &self, - metadata: ExecutorMetadata, - ) -> error::Result<()> { - let executor_id = metadata.id.clone(); - let proto: protobuf::ExecutorMetadata = metadata.into(); - let value = encode_protobuf(&proto)?; - - self.kv_store - .put(Keyspace::Executors, executor_id, value) - .await - } - - async fn get_executor_metadata( - &self, - executor_id: &str, - ) -> error::Result { - let value = self.kv_store.get(Keyspace::Executors, executor_id).await?; - - // Throw error rather than panic if the executor metadata does not exist - if value.is_empty() { - Err(BallistaError::General(format!( - "The metadata of executor {executor_id} does not exist" - ))) - } else { - let decoded = - decode_into::(&value)?; - Ok(decoded) - } - } - - async fn save_executor_heartbeat( - &self, - heartbeat: ExecutorHeartbeat, - ) -> error::Result<()> { - let executor_id = heartbeat.executor_id.clone(); - let value = encode_protobuf(&heartbeat)?; - self.kv_store - .put(Keyspace::Heartbeats, executor_id, value) - .await - } - - async fn remove_executor(&self, executor_id: &str) -> error::Result<()> { - let current_ts = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|e| { - BallistaError::Internal(format!("Error getting current timestamp: {e:?}")) - })? 
- .as_secs(); - - let value = encode_protobuf(&ExecutorHeartbeat { - executor_id: executor_id.to_owned(), - timestamp: current_ts, - metrics: vec![], - status: Some(ExecutorStatus { - status: Some(executor_status::Status::Dead("".to_string())), - }), - })?; - self.kv_store - .put(Keyspace::Heartbeats, executor_id.to_owned(), value) - .await?; - - // TODO Check the Executor reservation logic for push-based scheduling - - Ok(()) - } - - async fn executor_heartbeat_stream(&self) -> error::Result { - let events = self - .kv_store - .watch(Keyspace::Heartbeats, String::default()) - .await?; - - Ok(events - .filter_map(|event| { - futures::future::ready(match event { - WatchEvent::Put(_, value) => { - if let Ok(heartbeat) = - decode_protobuf::(&value) - { - Some(heartbeat) - } else { - None - } - } - WatchEvent::Delete(_) => None, - }) - }) - .boxed()) - } - - async fn executor_heartbeats( - &self, - ) -> error::Result> { - let heartbeats = self.kv_store.scan(Keyspace::Heartbeats, None).await?; - - let mut heartbeat_map = HashMap::with_capacity(heartbeats.len()); - - for (_, value) in heartbeats { - let data: ExecutorHeartbeat = decode_protobuf(&value)?; - if let Some(ExecutorStatus { - status: Some(executor_status::Status::Active(_)), - }) = &data.status - { - heartbeat_map.insert(data.executor_id.clone(), data); - } - } - - Ok(heartbeat_map) - } -} - -/// Return the set of executor IDs which have heartbeated within `last_seen_threshold` seconds -fn get_alive_executors( - last_seen_threshold: u64, - heartbeats: HashMap, -) -> error::Result> { - let now_epoch_ts = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards"); - - let last_seen_threshold = now_epoch_ts - .checked_sub(Duration::from_secs(last_seen_threshold)) - .ok_or_else(|| { - BallistaError::Internal(format!( - "Error getting alive executors, invalid last_seen_threshold of {last_seen_threshold}" - )) - })? - .as_secs(); - - Ok(heartbeats - .iter() - .filter_map(|(exec, heartbeat)| { - (heartbeat.timestamp > last_seen_threshold).then(|| exec.clone()) - }) - .collect()) -} - -/// It will get ExecutorReservation from one executor as many as possible. -/// By this way, it can reduce the chance of decoding and encoding ExecutorData. -/// However, it may make the whole cluster unbalanced, -/// which means some executors may be very busy while other executors may be idle. 
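// Editor's note: the comment above describes the "bias" strategy implemented by the
// deleted `reserve_slots_bias` below, while `reserve_slots_round_robin` spreads
// reservations evenly instead. The two std-only functions here are a minimal sketch of
// both loops over a plain map of executor_id -> free slots; the names and types are
// illustrative assumptions, not the Ballista API, and the round-robin sketch skips the
// sort-by-remaining-capacity step the real code performs.
use std::collections::HashMap;

/// Bias: drain as many slots as possible from each executor before moving on.
fn reserve_bias(slots: &mut HashMap<String, u32>, mut n: u32) -> Vec<String> {
    let mut reservations = Vec::new();
    for (executor_id, available) in slots.iter_mut() {
        if n == 0 {
            break;
        }
        let take = (*available).min(n);
        *available -= take;
        n -= take;
        // One reservation per slot taken from this executor.
        reservations.extend(std::iter::repeat(executor_id.clone()).take(take as usize));
    }
    reservations
}

/// Round robin: take one slot per executor per pass until the request is filled.
fn reserve_round_robin(slots: &mut HashMap<String, u32>, mut n: u32) -> Vec<String> {
    let mut reservations = Vec::new();
    loop {
        let before = n;
        for (executor_id, available) in slots.iter_mut() {
            if n == 0 {
                break;
            }
            if *available > 0 {
                *available -= 1;
                n -= 1;
                reservations.push(executor_id.clone());
            }
        }
        if before == n {
            break; // no progress was made: the cluster is out of free slots
        }
    }
    reservations
}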
-async fn reserve_slots_bias( - state: &dyn StateBackendClient, - mut n: u32, - executors: HashSet, -) -> error::Result<(Vec, Vec<(Operation, Keyspace, String)>)> { - let mut reservations: Vec = vec![]; - let mut txn_ops: Vec<(Operation, Keyspace, String)> = vec![]; - - for executor_id in executors { - if n == 0 { - break; - } - - let value = state.get(Keyspace::Slots, &executor_id).await?; - let mut data = decode_into::(&value)?; - let take = std::cmp::min(data.available_task_slots, n); - - for _ in 0..take { - reservations.push(ExecutorReservation::new_free(executor_id.clone())); - data.available_task_slots -= 1; - n -= 1; - } - - let proto: protobuf::ExecutorData = data.into(); - let new_data = encode_protobuf(&proto)?; - txn_ops.push((Operation::Put(new_data), Keyspace::Slots, executor_id)); - } - - Ok((reservations, txn_ops)) -} - -/// Create ExecutorReservation in a round robin way to evenly assign tasks to executors -async fn reserve_slots_round_robin( - state: &dyn StateBackendClient, - mut n: u32, - executors: HashSet, -) -> error::Result<(Vec, Vec<(Operation, Keyspace, String)>)> { - let mut reservations: Vec = vec![]; - let mut txn_ops: Vec<(Operation, Keyspace, String)> = vec![]; - - let all_executor_data = state - .scan(Keyspace::Slots, None) - .await? - .into_iter() - .map(|(_, data)| decode_into::(&data)) - .collect::>>()?; - - let mut available_executor_data: Vec = all_executor_data - .into_iter() - .filter_map(|data| { - (data.available_task_slots > 0 && executors.contains(&data.executor_id)) - .then_some(data) - }) - .collect(); - available_executor_data - .sort_by(|a, b| Ord::cmp(&b.available_task_slots, &a.available_task_slots)); - - // Exclusive - let mut last_updated_idx = 0usize; - loop { - let n_before = n; - for (idx, data) in available_executor_data.iter_mut().enumerate() { - if n == 0 { - break; - } - - // Since the vector is sorted in descending order, - // if finding one executor has not enough slots, the following will have not enough, either - if data.available_task_slots == 0 { - break; - } - - reservations.push(ExecutorReservation::new_free(data.executor_id.clone())); - data.available_task_slots -= 1; - n -= 1; - - if idx >= last_updated_idx { - last_updated_idx = idx + 1; - } - } - - if n_before == n { - break; - } - } - - for (idx, data) in available_executor_data.into_iter().enumerate() { - if idx >= last_updated_idx { - break; - } - let executor_id = data.executor_id.clone(); - let proto: protobuf::ExecutorData = data.into(); - let new_data = encode_protobuf(&proto)?; - txn_ops.push((Operation::Put(new_data), Keyspace::Slots, executor_id)); - } - - Ok((reservations, txn_ops)) -} - -#[cfg(test)] -mod tests { - use crate::state::backend::cluster::{ClusterState, DefaultClusterState}; - use crate::state::backend::sled::SledClient; - - use ballista_core::error::Result; - use ballista_core::serde::protobuf::{ - executor_status, ExecutorHeartbeat, ExecutorStatus, - }; - - use futures::StreamExt; - - use std::sync::Arc; - - #[tokio::test] - async fn test_heartbeat_stream() -> Result<()> { - let sled = Arc::new(SledClient::try_new_temporary()?); - - let cluster_state: Arc = - Arc::new(DefaultClusterState::new(sled)); - - for i in 0..10 { - let mut heartbeat_stream = cluster_state.executor_heartbeat_stream().await?; - - cluster_state - .save_executor_heartbeat(ExecutorHeartbeat { - executor_id: i.to_string(), - timestamp: 0, - metrics: vec![], - status: Some(ExecutorStatus { - status: Some(executor_status::Status::Active(String::default())), - }), - }) - .await?; 
- - let received = if let Some(event) = heartbeat_stream.next().await { - event.executor_id == i.to_string() - } else { - false - }; - - assert!(received, "{}", "Did not receive heartbeat for executor {i}"); - } - - Ok(()) - } - - #[tokio::test] - async fn test_heartbeats() -> Result<()> { - let sled = Arc::new(SledClient::try_new_temporary()?); - - let cluster_state: Arc = - Arc::new(DefaultClusterState::new(sled)); - - // Add 10 executor heartbeats - for i in 0..10 { - cluster_state - .save_executor_heartbeat(ExecutorHeartbeat { - executor_id: i.to_string(), - timestamp: i as u64, - metrics: vec![], - status: Some(ExecutorStatus { - status: Some(executor_status::Status::Active(String::default())), - }), - }) - .await?; - } - - let heartbeats = cluster_state.executor_heartbeats().await?; - - // Check that all 10 are present in the global view - for i in 0..10 { - let id = i.to_string(); - if let Some(hb) = heartbeats.get(&id) { - assert_eq!( - hb.executor_id, - i.to_string(), - "Expected heartbeat in map for {i}" - ); - assert_eq!(hb.timestamp, i, "Expected timestamp to be correct for {i}"); - } else { - panic!("Expected heartbeat for executor {}", i); - } - } - - // Send new heartbeat with updated timestamp - cluster_state - .save_executor_heartbeat(ExecutorHeartbeat { - executor_id: "0".to_string(), - timestamp: 100, - metrics: vec![], - status: Some(ExecutorStatus { - status: Some(executor_status::Status::Active(String::default())), - }), - }) - .await?; - - let heartbeats = cluster_state.executor_heartbeats().await?; - - if let Some(hb) = heartbeats.get("0") { - assert_eq!(hb.executor_id, "0", "Expected heartbeat in map for 0"); - assert_eq!(hb.timestamp, 100, "Expected timestamp to be updated for 0"); - } - - for i in 1..10 { - let id = i.to_string(); - if let Some(hb) = heartbeats.get(&id) { - assert_eq!( - hb.executor_id, - i.to_string(), - "Expected heartbeat in map for {i}" - ); - assert_eq!(hb.timestamp, i, "Expected timestamp to be correct for {i}"); - } else { - panic!("Expected heartbeat for executor {}", i); - } - } - - Ok(()) - } -} diff --git a/ballista/scheduler/src/state/backend/memory.rs b/ballista/scheduler/src/state/backend/memory.rs deleted file mode 100644 index ddb2a450d..000000000 --- a/ballista/scheduler/src/state/backend/memory.rs +++ /dev/null @@ -1,411 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
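// Editor's note: the `get_alive_executors` helper removed above filters the heartbeat
// map down to executors whose last heartbeat is newer than `now - last_seen_threshold`.
// A std-only sketch of that filter over (executor_id, last_seen_epoch_seconds) pairs;
// the function name and the flattened map type are illustrative assumptions.
use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};

fn alive_executors(
    last_seen_threshold_secs: u64,
    heartbeats: &HashMap<String, u64>, // executor_id -> last heartbeat, epoch seconds
) -> HashSet<String> {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("time went backwards")
        .as_secs();
    // Executors last heard from before this cutoff are treated as dead.
    let cutoff = now.saturating_sub(last_seen_threshold_secs);
    heartbeats
        .iter()
        .filter(|(_, last_seen)| **last_seen > cutoff)
        .map(|(id, _)| id.clone())
        .collect()
}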
- -use crate::state::backend::utils::subscriber::{Subscriber, Subscribers}; -use crate::state::backend::{ - Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent, -}; -use ballista_core::error::Result; -use dashmap::DashMap; -use futures::{FutureExt, Stream}; -use log::warn; -use std::collections::{BTreeMap, HashSet}; -use std::sync::Arc; -use tokio::sync::Mutex; - -type KeySpaceState = BTreeMap>; -type KeyLock = Arc>; - -/// A [`StateBackendClient`] implementation that uses in memory map to save cluster state. -#[derive(Clone, Default)] -pub struct MemoryBackendClient { - /// The key is the KeySpace. For every KeySpace, there will be a tree map which is better for prefix filtering - states: DashMap, - /// The key is the full key formatted like "/KeySpace/key". It's a flatted map - locks: DashMap, - subscribers: Arc, -} - -impl MemoryBackendClient { - pub fn new() -> Self { - Self::default() - } - - fn get_space_key(keyspace: &Keyspace) -> String { - format!("/{keyspace:?}") - } - - fn get_flat_key(keyspace: &Keyspace, key: &str) -> String { - format!("/{keyspace:?}/{key}") - } -} - -#[tonic::async_trait] -impl StateBackendClient for MemoryBackendClient { - async fn get(&self, keyspace: Keyspace, key: &str) -> Result> { - let space_key = Self::get_space_key(&keyspace); - Ok(self - .states - .get(&space_key) - .map(|space_state| space_state.value().get(key).cloned().unwrap_or_default()) - .unwrap_or_default()) - } - - async fn get_from_prefix( - &self, - keyspace: Keyspace, - prefix: &str, - ) -> Result)>> { - let space_key = Self::get_space_key(&keyspace); - Ok(self - .states - .get(&space_key) - .map(|space_state| { - space_state - .value() - .range(prefix.to_owned()..) - .take_while(|(k, _)| k.starts_with(prefix)) - .map(|e| (format!("{}/{}", space_key, e.0), e.1.clone())) - .collect() - }) - .unwrap_or_default()) - } - - async fn scan( - &self, - keyspace: Keyspace, - limit: Option, - ) -> Result)>> { - let space_key = Self::get_space_key(&keyspace); - Ok(self - .states - .get(&space_key) - .map(|space_state| { - if let Some(limit) = limit { - space_state - .value() - .iter() - .take(limit) - .map(|e| (format!("{}/{}", space_key, e.0), e.1.clone())) - .collect::)>>() - } else { - space_state - .value() - .iter() - .map(|e| (format!("{}/{}", space_key, e.0), e.1.clone())) - .collect::)>>() - } - }) - .unwrap_or_default()) - } - - async fn scan_keys(&self, keyspace: Keyspace) -> Result> { - let space_key = Self::get_space_key(&keyspace); - Ok(self - .states - .get(&space_key) - .map(|space_state| { - space_state - .value() - .iter() - .map(|e| format!("{}/{}", space_key, e.0)) - .collect::>() - }) - .unwrap_or_default()) - } - - async fn put(&self, keyspace: Keyspace, key: String, value: Vec) -> Result<()> { - let space_key = Self::get_space_key(&keyspace); - if !self.states.contains_key(&space_key) { - self.states.insert(space_key.clone(), BTreeMap::default()); - } - self.states - .get_mut(&space_key) - .unwrap() - .value_mut() - .insert(key.clone(), value.clone()); - - // Notify subscribers - let full_key = format!("{space_key}/{key}"); - if let Some(res) = self.subscribers.reserve(&full_key) { - let event = WatchEvent::Put(full_key, value); - res.complete(&event); - } - - Ok(()) - } - - /// Currently the locks should be acquired before invoking this method. 
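// Editor's note: the struct doc above keeps one `BTreeMap` per keyspace because an
// ordered map makes prefix filtering cheap: `get_from_prefix` starts a range scan at
// the prefix and stops at the first key that no longer matches. A std-only sketch of
// that trick; the function name is an illustrative assumption.
use std::collections::BTreeMap;

fn scan_prefix<'a>(
    state: &'a BTreeMap<String, Vec<u8>>,
    prefix: &str,
) -> Vec<(&'a String, &'a Vec<u8>)> {
    state
        .range(prefix.to_owned()..)                 // jump to the first key >= prefix
        .take_while(|(k, _)| k.starts_with(prefix)) // stop once keys leave the prefix
        .collect()
}

fn main() {
    let mut state = BTreeMap::new();
    state.insert("key/1".to_string(), b"a".to_vec());
    state.insert("key/2".to_string(), b"b".to_vec());
    state.insert("other".to_string(), b"c".to_vec());
    assert_eq!(scan_prefix(&state, "key/").len(), 2);
}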
- /// Later need to be refined by acquiring all of the related locks inside this method - async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()> { - for (op, keyspace, key) in ops.into_iter() { - match op { - Operation::Delete => { - self.delete(keyspace, &key).await?; - } - Operation::Put(value) => { - self.put(keyspace, key, value).await?; - } - }; - } - - Ok(()) - } - - /// Currently it's not used. Later will refine the caller side by leveraging this method - async fn mv( - &self, - from_keyspace: Keyspace, - to_keyspace: Keyspace, - key: &str, - ) -> Result<()> { - let from_space_key = Self::get_space_key(&from_keyspace); - - let ops = if let Some(from_space_state) = self.states.get(&from_space_key) { - if let Some(state) = from_space_state.value().get(key) { - Some(vec![ - (Operation::Delete, from_keyspace, key.to_owned()), - (Operation::Put(state.clone()), to_keyspace, key.to_owned()), - ]) - } else { - // TODO should this return an error? - warn!( - "Cannot move value at {}/{}, does not exist", - from_space_key, key - ); - None - } - } else { - // TODO should this return an error? - warn!( - "Cannot move value at {}/{}, does not exist", - from_space_key, key - ); - None - }; - - if let Some(ops) = ops { - self.apply_txn(ops).await?; - } - - Ok(()) - } - - async fn lock(&self, keyspace: Keyspace, key: &str) -> Result> { - let flat_key = Self::get_flat_key(&keyspace, key); - let lock = self - .locks - .entry(flat_key) - .or_insert_with(|| Arc::new(Mutex::new(()))); - Ok(Box::new(lock.value().clone().lock_owned().await)) - } - - async fn watch(&self, keyspace: Keyspace, prefix: String) -> Result> { - let prefix = format!("/{keyspace:?}/{prefix}"); - - Ok(Box::new(MemoryWatch { - subscriber: self.subscribers.register(prefix.as_bytes()), - })) - } - - async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()> { - let space_key = Self::get_space_key(&keyspace); - if let Some(mut space_state) = self.states.get_mut(&space_key) { - if space_state.value_mut().remove(key).is_some() { - // Notify subscribers - let full_key = format!("{space_key}/{key}"); - if let Some(res) = self.subscribers.reserve(&full_key) { - let event = WatchEvent::Delete(full_key); - res.complete(&event); - } - } - } - - Ok(()) - } -} - -struct MemoryWatch { - subscriber: Subscriber, -} - -#[tonic::async_trait] -impl Watch for MemoryWatch { - async fn cancel(&mut self) -> Result<()> { - Ok(()) - } -} - -impl Stream for MemoryWatch { - type Item = WatchEvent; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.get_mut().subscriber.poll_unpin(cx) - } - - fn size_hint(&self) -> (usize, Option) { - self.subscriber.size_hint() - } -} - -#[cfg(test)] -mod tests { - use super::{StateBackendClient, Watch, WatchEvent}; - - use crate::state::backend::memory::MemoryBackendClient; - use crate::state::backend::{Keyspace, Operation}; - use crate::state::with_locks; - use futures::StreamExt; - use std::result::Result; - - #[tokio::test] - async fn put_read() -> Result<(), Box> { - let client = MemoryBackendClient::new(); - let key = "key"; - let value = "value".as_bytes(); - client - .put(Keyspace::Slots, key.to_owned(), value.to_vec()) - .await?; - assert_eq!(client.get(Keyspace::Slots, key).await?, value); - Ok(()) - } - - #[tokio::test] - async fn put_move() -> Result<(), Box> { - let client = MemoryBackendClient::new(); - let key = "key"; - let value = "value".as_bytes(); - client - .put(Keyspace::ActiveJobs, key.to_owned(), 
value.to_vec()) - .await?; - client - .mv(Keyspace::ActiveJobs, Keyspace::FailedJobs, key) - .await?; - assert_eq!(client.get(Keyspace::FailedJobs, key).await?, value); - Ok(()) - } - - #[tokio::test] - async fn multiple_operation() -> Result<(), Box> { - let client = MemoryBackendClient::new(); - let key = "key".to_string(); - let value = "value".as_bytes().to_vec(); - let locks = client - .acquire_locks(vec![(Keyspace::ActiveJobs, ""), (Keyspace::Slots, "")]) - .await?; - - let _r: ballista_core::error::Result<()> = with_locks(locks, async { - let txn_ops = vec![ - (Operation::Put(value.clone()), Keyspace::Slots, key.clone()), - ( - Operation::Put(value.clone()), - Keyspace::ActiveJobs, - key.clone(), - ), - ]; - client.apply_txn(txn_ops).await?; - Ok(()) - }) - .await; - - assert_eq!(client.get(Keyspace::Slots, key.as_str()).await?, value); - assert_eq!(client.get(Keyspace::ActiveJobs, key.as_str()).await?, value); - Ok(()) - } - - #[tokio::test] - async fn read_empty() -> Result<(), Box> { - let client = MemoryBackendClient::new(); - let key = "key"; - let empty: &[u8] = &[]; - assert_eq!(client.get(Keyspace::Slots, key).await?, empty); - Ok(()) - } - - #[tokio::test] - async fn read_prefix() -> Result<(), Box> { - let client = MemoryBackendClient::new(); - let key = "key"; - let value = "value".as_bytes(); - client - .put(Keyspace::Slots, format!("{key}/1"), value.to_vec()) - .await?; - client - .put(Keyspace::Slots, format!("{key}/2"), value.to_vec()) - .await?; - assert_eq!( - client.get_from_prefix(Keyspace::Slots, key).await?, - vec![ - ("/Slots/key/1".to_owned(), value.to_vec()), - ("/Slots/key/2".to_owned(), value.to_vec()) - ] - ); - Ok(()) - } - - #[tokio::test] - async fn read_watch() -> Result<(), Box> { - let client = MemoryBackendClient::new(); - let key = "key"; - let value = "value".as_bytes(); - let mut watch_keyspace: Box = - client.watch(Keyspace::Slots, "".to_owned()).await?; - let mut watch_key: Box = - client.watch(Keyspace::Slots, key.to_owned()).await?; - client - .put(Keyspace::Slots, key.to_owned(), value.to_vec()) - .await?; - assert_eq!( - watch_keyspace.next().await, - Some(WatchEvent::Put( - format!("/{:?}/{}", Keyspace::Slots, key.to_owned()), - value.to_owned() - )) - ); - assert_eq!( - watch_key.next().await, - Some(WatchEvent::Put( - format!("/{:?}/{}", Keyspace::Slots, key.to_owned()), - value.to_owned() - )) - ); - let value2 = "value2".as_bytes(); - client - .put(Keyspace::Slots, key.to_owned(), value2.to_vec()) - .await?; - assert_eq!( - watch_keyspace.next().await, - Some(WatchEvent::Put( - format!("/{:?}/{}", Keyspace::Slots, key.to_owned()), - value2.to_owned() - )) - ); - assert_eq!( - watch_key.next().await, - Some(WatchEvent::Put( - format!("/{:?}/{}", Keyspace::Slots, key.to_owned()), - value2.to_owned() - )) - ); - watch_keyspace.cancel().await?; - watch_key.cancel().await?; - Ok(()) - } -} diff --git a/ballista/scheduler/src/state/backend/utils/mod.rs b/ballista/scheduler/src/state/backend/utils/mod.rs deleted file mode 100644 index de95dd6e0..000000000 --- a/ballista/scheduler/src/state/backend/utils/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#[allow(dead_code)] -mod oneshot; -#[allow(dead_code)] -pub(crate) mod subscriber; diff --git a/ballista/scheduler/src/state/backend/utils/oneshot.rs b/ballista/scheduler/src/state/backend/utils/oneshot.rs deleted file mode 100644 index a0d146996..000000000 --- a/ballista/scheduler/src/state/backend/utils/oneshot.rs +++ /dev/null @@ -1,179 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! It's mainly a modified version of sled::oneshot - -use std::{ - future::Future, - pin::Pin, - sync::Arc, - task::{Context, Poll, Waker}, - time::{Duration, Instant}, -}; - -use parking_lot::{Condvar, Mutex}; - -#[derive(Debug)] -struct OneShotState { - filled: bool, - fused: bool, - item: Option, - waker: Option, -} - -impl Default for OneShotState { - fn default() -> OneShotState { - OneShotState { - filled: false, - fused: false, - item: None, - waker: None, - } - } -} - -/// A Future value which may or may not be filled -#[derive(Debug)] -pub struct OneShot { - mu: Arc>>, - cv: Arc, -} - -/// The completer side of the Future -pub struct OneShotFiller { - mu: Arc>>, - cv: Arc, -} - -impl OneShot { - /// Create a new `OneShotFiller` and the `OneShot` - /// that will be filled by its completion. - pub fn pair() -> (OneShotFiller, Self) { - let mu = Arc::new(Mutex::new(OneShotState::default())); - let cv = Arc::new(Condvar::new()); - let future = Self { - mu: mu.clone(), - cv: cv.clone(), - }; - let filler = OneShotFiller { mu, cv }; - - (filler, future) - } - - /// Block on the `OneShot`'s completion - /// or dropping of the `OneShotFiller` - pub fn wait(self) -> Option { - let mut inner = self.mu.lock(); - while !inner.filled { - self.cv.wait(&mut inner); - } - inner.item.take() - } - - /// Block on the `OneShot`'s completion - /// or dropping of the `OneShotFiller`, - /// returning an error if not filled - /// before a given timeout or if the - /// system shuts down before then. - /// - /// Upon a successful receive, the - /// oneshot should be dropped, as it - /// will never yield that value again. 
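// Editor's note: the `OneShot`/`OneShotFiller` pair deleted here is a hand-rolled
// one-shot channel built on a mutex, a condvar, and an optional waker. The blocking
// half of that idea is sketched below in std-only Rust; it omits the `Future`/waker
// integration and the timeout path, and every name is an illustrative assumption.
use std::sync::{Arc, Condvar, Mutex};

struct Cell<T> {
    slot: Mutex<(bool, Option<T>)>, // (filled, item)
    cv: Condvar,
}

struct Filler<T> { cell: Arc<Cell<T>> }
struct Waiter<T> { cell: Arc<Cell<T>> }

fn pair<T>() -> (Filler<T>, Waiter<T>) {
    let cell = Arc::new(Cell { slot: Mutex::new((false, None)), cv: Condvar::new() });
    (Filler { cell: cell.clone() }, Waiter { cell })
}

impl<T> Filler<T> {
    fn fill(self, item: T) {
        let mut guard = self.cell.slot.lock().unwrap();
        *guard = (true, Some(item));
        drop(guard); // release the mutex before notifying, mirroring the deleted code
        self.cell.cv.notify_all();
    }
}

impl<T> Waiter<T> {
    fn wait(self) -> Option<T> {
        let mut guard = self.cell.slot.lock().unwrap();
        while !guard.0 {
            guard = self.cell.cv.wait(guard).unwrap();
        }
        guard.1.take()
    }
}

fn main() {
    let (tx, rx) = pair();
    std::thread::spawn(move || tx.fill(42));
    assert_eq!(rx.wait(), Some(42));
}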
- pub fn wait_timeout( - &mut self, - mut timeout: Duration, - ) -> Result { - let mut inner = self.mu.lock(); - while !inner.filled { - let start = Instant::now(); - let res = self.cv.wait_for(&mut inner, timeout); - if res.timed_out() { - return Err(std::sync::mpsc::RecvTimeoutError::Disconnected); - } - timeout = if let Some(timeout) = timeout.checked_sub(start.elapsed()) { - timeout - } else { - Duration::from_nanos(0) - }; - } - if let Some(item) = inner.item.take() { - Ok(item) - } else { - Err(std::sync::mpsc::RecvTimeoutError::Disconnected) - } - } -} - -impl Future for OneShot { - type Output = Option; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut state = self.mu.lock(); - if state.fused { - return Poll::Pending; - } - if state.filled { - state.fused = true; - Poll::Ready(state.item.take()) - } else { - state.waker = Some(cx.waker().clone()); - Poll::Pending - } - } -} - -impl OneShotFiller { - /// Complete the `OneShot` - pub fn fill(self, inner: T) { - let mut state = self.mu.lock(); - - if let Some(waker) = state.waker.take() { - waker.wake(); - } - - state.filled = true; - state.item = Some(inner); - - // having held the mutex makes this linearized - // with the notify below. - drop(state); - - let _notified = self.cv.notify_all(); - } -} - -impl Drop for OneShotFiller { - fn drop(&mut self) { - let mut state = self.mu.lock(); - - if state.filled { - return; - } - - if let Some(waker) = state.waker.take() { - waker.wake(); - } - - state.filled = true; - - // having held the mutex makes this linearized - // with the notify below. - drop(state); - - let _notified = self.cv.notify_all(); - } -} diff --git a/ballista/scheduler/src/state/backend/utils/subscriber.rs b/ballista/scheduler/src/state/backend/utils/subscriber.rs deleted file mode 100644 index dd74b6642..000000000 --- a/ballista/scheduler/src/state/backend/utils/subscriber.rs +++ /dev/null @@ -1,248 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! It's mainly a modified version of sled::subscriber - -use crate::state::backend::utils::oneshot::{OneShot, OneShotFiller}; -use crate::state::backend::WatchEvent; - -use parking_lot::RwLock; -use std::collections::{BTreeMap, HashMap}; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering::Relaxed; -use std::sync::atomic::{AtomicBool, AtomicUsize}; -use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError}; -use std::sync::Arc; -use std::task::{Context, Poll, Waker}; -use std::time::{Duration, Instant}; - -static ID_GEN: AtomicUsize = AtomicUsize::new(0); - -type Senders = HashMap, SyncSender>>)>; - -/// Aynchronous, non-blocking subscriber: -/// -/// `Subscription` implements `Future>`. 
-/// -/// `while let Some(event) = (&mut subscriber).await { /* use it */ }` -pub struct Subscriber { - id: usize, - rx: Receiver>>, - existing: Option>>, - home: Arc>, -} - -impl Drop for Subscriber { - fn drop(&mut self) { - let mut w_senders = self.home.write(); - w_senders.remove(&self.id); - } -} - -impl Subscriber { - /// Attempts to wait for a value on this `Subscriber`, returning - /// an error if no event arrives within the provided `Duration` - /// or if the backing `Db` shuts down. - pub fn next_timeout( - &mut self, - mut timeout: Duration, - ) -> std::result::Result { - loop { - let start = Instant::now(); - let mut future_rx = if let Some(future_rx) = self.existing.take() { - future_rx - } else { - self.rx.recv_timeout(timeout)? - }; - timeout = if let Some(timeout) = timeout.checked_sub(start.elapsed()) { - timeout - } else { - Duration::from_nanos(0) - }; - - let start = Instant::now(); - match future_rx.wait_timeout(timeout) { - Ok(Some(event)) => return Ok(event), - Ok(None) => (), - Err(timeout_error) => { - self.existing = Some(future_rx); - return Err(timeout_error); - } - } - timeout = if let Some(timeout) = timeout.checked_sub(start.elapsed()) { - timeout - } else { - Duration::from_nanos(0) - }; - } - } -} - -impl Future for Subscriber { - type Output = Option; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let mut future_rx = if let Some(future_rx) = self.existing.take() { - future_rx - } else { - match self.rx.try_recv() { - Ok(future_rx) => future_rx, - Err(TryRecvError::Empty) => break, - Err(TryRecvError::Disconnected) => return Poll::Ready(None), - } - }; - - match Future::poll(Pin::new(&mut future_rx), cx) { - Poll::Ready(Some(event)) => return Poll::Ready(event), - Poll::Ready(None) => continue, - Poll::Pending => { - self.existing = Some(future_rx); - return Poll::Pending; - } - } - } - let mut home = self.home.write(); - let entry = home.get_mut(&self.id).unwrap(); - entry.0 = Some(cx.waker().clone()); - Poll::Pending - } -} - -impl Iterator for Subscriber { - type Item = WatchEvent; - - fn next(&mut self) -> Option { - loop { - let future_rx = self.rx.recv().ok()?; - match future_rx.wait() { - Some(Some(event)) => return Some(event), - Some(None) => return None, - None => continue, - } - } - } -} - -#[derive(Debug, Default)] -pub(crate) struct Subscribers { - watched: RwLock, Arc>>>, - ever_used: AtomicBool, -} - -impl Drop for Subscribers { - fn drop(&mut self) { - let watched = self.watched.read(); - - for senders in watched.values() { - let senders = std::mem::take(&mut *senders.write()); - for (_, (waker, sender)) in senders { - drop(sender); - if let Some(waker) = waker { - waker.wake(); - } - } - } - } -} - -impl Subscribers { - pub(crate) fn register(&self, prefix: &[u8]) -> Subscriber { - self.ever_used.store(true, Relaxed); - let r_mu = { - let r_mu = self.watched.read(); - if r_mu.contains_key(prefix) { - r_mu - } else { - drop(r_mu); - let mut w_mu = self.watched.write(); - if !w_mu.contains_key(prefix) { - let old = w_mu.insert( - prefix.to_vec(), - Arc::new(RwLock::new(HashMap::default())), - ); - assert!(old.is_none()); - } - drop(w_mu); - self.watched.read() - } - }; - - let (tx, rx) = sync_channel(1024); - - let arc_senders = &r_mu[prefix]; - let mut w_senders = arc_senders.write(); - - let id = ID_GEN.fetch_add(1, Relaxed); - - w_senders.insert(id, (None, tx)); - - Subscriber { - id, - rx, - existing: None, - home: arc_senders.clone(), - } - } - - pub(crate) fn reserve>(&self, key: R) -> Option { - if 
!self.ever_used.load(Relaxed) { - return None; - } - - let r_mu = self.watched.read(); - let prefixes = r_mu.iter().filter(|(k, _)| key.as_ref().starts_with(k)); - - let mut subscribers = vec![]; - - for (_, subs_rwl) in prefixes { - let subs = subs_rwl.read(); - - for (_id, (waker, sender)) in subs.iter() { - let (tx, rx) = OneShot::pair(); - if sender.send(rx).is_err() { - continue; - } - subscribers.push((waker.clone(), tx)); - } - } - - if subscribers.is_empty() { - None - } else { - Some(ReservedBroadcast { subscribers }) - } - } -} - -pub(crate) struct ReservedBroadcast { - subscribers: Vec<(Option, OneShotFiller>)>, -} - -impl ReservedBroadcast { - pub fn complete(self, event: &WatchEvent) { - let iter = self.subscribers.into_iter(); - - for (waker, tx) in iter { - tx.fill(Some(event.clone())); - if let Some(waker) = waker { - waker.wake(); - } - } - } -} diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index f4d8477d8..4f98e6b94 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -33,9 +33,10 @@ use log::{error, info, warn}; use ballista_core::error::{BallistaError, Result}; use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec}; use ballista_core::serde::protobuf::failed_task::FailedReason; +use ballista_core::serde::protobuf::job_status::Status; use ballista_core::serde::protobuf::{ - self, execution_graph_stage::StageType, FailedTask, JobStatus, QueuedJob, ResultLost, - SuccessfulJob, TaskStatus, + self, execution_graph_stage::StageType, FailedTask, JobStatus, ResultLost, + RunningJob, SuccessfulJob, TaskStatus, }; use ballista_core::serde::protobuf::{job_status, FailedJob, ShuffleWritePartition}; use ballista_core::serde::protobuf::{task_status, RunningTask}; @@ -102,8 +103,8 @@ mod execution_stage; /// publish its outputs to the `ExecutionGraph`s `output_locations` representing the final query results. #[derive(Clone)] pub struct ExecutionGraph { - /// Curator scheduler name - scheduler_id: String, + /// Curator scheduler name. Can be `None` is `ExecutionGraph` is not currently curated by any scheduler + scheduler_id: Option, /// ID for this job job_id: String, /// Job name, can be empty string @@ -158,16 +159,24 @@ impl ExecutionGraph { let builder = ExecutionStageBuilder::new(); let stages = builder.build(shuffle_stages)?; + let started_at = timestamp_millis(); + Ok(Self { - scheduler_id: scheduler_id.to_string(), + scheduler_id: Some(scheduler_id.to_string()), job_id: job_id.to_string(), job_name: job_name.to_string(), session_id: session_id.to_string(), status: JobStatus { - status: Some(job_status::Status::Queued(QueuedJob {})), + job_id: job_id.to_string(), + job_name: job_name.to_string(), + status: Some(Status::Running(RunningJob { + queued_at, + started_at, + scheduler: scheduler_id.to_string(), + })), }, queued_at, - start_time: timestamp_millis(), + start_time: started_at, end_time: 0, stages, output_partitions, @@ -222,6 +231,12 @@ impl ExecutionGraph { .all(|s| matches!(s, ExecutionStage::Successful(_))) } + pub fn is_complete(&self) -> bool { + self.stages + .values() + .all(|s| matches!(s, ExecutionStage::Successful(_))) + } + /// Revive the execution graph by converting the resolved stages to running stages /// If any stages are converted, return true; else false. 
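// Editor's note: the hunk above changes `ExecutionGraph::scheduler_id` from `String`
// to `Option<String>` while the serialized field stays a plain string, and the
// decode/encode hunks further down map the empty string to `None` and back. A minimal
// round trip of that convention; the function names are illustrative assumptions.
fn decode_scheduler_id(proto_value: String) -> Option<String> {
    // The protobuf default ("") means "not currently curated by any scheduler".
    (!proto_value.is_empty()).then_some(proto_value)
}

fn encode_scheduler_id(value: Option<String>) -> String {
    // `None` is written back as the protobuf default.
    value.unwrap_or_default()
}

fn main() {
    assert_eq!(decode_scheduler_id(String::new()), None);
    assert_eq!(
        decode_scheduler_id("localhost:50050".into()),
        Some("localhost:50050".to_string())
    );
    assert_eq!(encode_scheduler_id(None), String::new());
}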
pub fn revive(&mut self) -> bool { @@ -833,6 +848,7 @@ impl ExecutionGraph { self.status, JobStatus { status: Some(job_status::Status::Failed(_)), + .. } ) { warn!("Call pop_next_task on failed Job"); @@ -1206,7 +1222,14 @@ impl ExecutionGraph { /// fail job with error message pub fn fail_job(&mut self, error: String) { self.status = JobStatus { - status: Some(job_status::Status::Failed(FailedJob { error })), + job_id: self.job_id.clone(), + job_name: self.job_name.clone(), + status: Some(Status::Failed(FailedJob { + error, + queued_at: self.queued_at, + started_at: self.start_time, + ended_at: self.end_time, + })), }; } @@ -1226,8 +1249,14 @@ impl ExecutionGraph { .collect::>>()?; self.status = JobStatus { + job_id: self.job_id.clone(), + job_name: self.job_name.clone(), status: Some(job_status::Status::Successful(SuccessfulJob { partition_location, + + queued_at: self.queued_at, + started_at: self.start_time, + ended_at: self.end_time, })), }; self.end_time = SystemTime::now() @@ -1304,7 +1333,7 @@ impl ExecutionGraph { .collect(); Ok(ExecutionGraph { - scheduler_id: proto.scheduler_id, + scheduler_id: (!proto.scheduler_id.is_empty()).then_some(proto.scheduler_id), job_id: proto.job_id, job_name: proto.job_name, session_id: proto.session_id, @@ -1394,7 +1423,7 @@ impl ExecutionGraph { stages, output_partitions: graph.output_partitions as u64, output_locations, - scheduler_id: graph.scheduler_id, + scheduler_id: graph.scheduler_id.unwrap_or_default(), task_id_gen: graph.task_id_gen as u32, failed_attempts, }) @@ -1581,24 +1610,20 @@ fn partition_to_location( #[cfg(test)] mod test { use std::collections::HashSet; - use std::sync::Arc; - - use datafusion::arrow::datatypes::{DataType, Field, Schema}; - use datafusion::logical_expr::expr::Sort; - use datafusion::logical_expr::{col, count, sum, Expr, JoinType}; - use datafusion::physical_plan::display::DisplayableExecutionPlan; - use datafusion::prelude::{SessionConfig, SessionContext}; - use datafusion::test_util::scan_empty; use crate::scheduler_server::event::QueryStageSchedulerEvent; use ballista_core::error::Result; use ballista_core::serde::protobuf::{ - self, failed_task, job_status, task_status, ExecutionError, FailedTask, - FetchPartitionError, IoError, JobStatus, TaskKilled, TaskStatus, + self, failed_task, job_status, ExecutionError, FailedTask, FetchPartitionError, + IoError, JobStatus, TaskKilled, }; - use ballista_core::serde::scheduler::{ExecutorMetadata, ExecutorSpecification}; - use crate::state::execution_graph::{ExecutionGraph, TaskDescription}; + use crate::state::execution_graph::ExecutionGraph; + use crate::test_utils::{ + mock_completed_task, mock_executor, mock_failed_task, test_aggregation_plan, + test_coalesce_plan, test_join_plan, test_two_aggregations_plan, + test_union_all_plan, test_union_plan, + }; #[tokio::test] async fn test_drain_tasks() -> Result<()> { @@ -1663,7 +1688,8 @@ mod test { assert!(matches!( status, protobuf::JobStatus { - status: Some(job_status::Status::Successful(_)) + status: Some(job_status::Status::Successful(_)), + .. } )); @@ -1870,17 +1896,6 @@ mod test { 4, )?; - // TODO the JobStatus is not 'Running' here, no place to mark it to 'Running' in current code base. 
- assert!( - matches!( - agg_graph.status, - JobStatus { - status: Some(job_status::Status::Queued(_)) - } - ), - "Expected job status to be running" - ); - assert_eq!(agg_graph.available_tasks(), 2); drain_tasks(&mut agg_graph)?; assert_eq!(agg_graph.available_tasks(), 0); @@ -1960,7 +1975,8 @@ mod test { matches!( agg_graph.status, JobStatus { - status: Some(job_status::Status::Failed(_)) + status: Some(job_status::Status::Failed(_)), + .. } ), "Expected job status to be Failed" @@ -2751,289 +2767,4 @@ mod test { Ok(()) } - - async fn test_aggregation_plan(partition: usize) -> ExecutionGraph { - let config = SessionConfig::new().with_target_partitions(partition); - let ctx = Arc::new(SessionContext::with_config(config)); - let session_state = ctx.state(); - - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("gmv", DataType::UInt64, false), - ]); - - let logical_plan = scan_empty(None, &schema, Some(vec![0, 1])) - .unwrap() - .aggregate(vec![col("id")], vec![sum(col("gmv"))]) - .unwrap() - .build() - .unwrap(); - - let optimized_plan = session_state.optimize(&logical_plan).unwrap(); - - let plan = session_state - .create_physical_plan(&optimized_plan) - .await - .unwrap(); - - println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); - - ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap() - } - - async fn test_two_aggregations_plan(partition: usize) -> ExecutionGraph { - let config = SessionConfig::new().with_target_partitions(partition); - let ctx = Arc::new(SessionContext::with_config(config)); - let session_state = ctx.state(); - - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("name", DataType::Utf8, false), - Field::new("gmv", DataType::UInt64, false), - ]); - - let logical_plan = scan_empty(None, &schema, Some(vec![0, 1, 2])) - .unwrap() - .aggregate(vec![col("id"), col("name")], vec![sum(col("gmv"))]) - .unwrap() - .aggregate(vec![col("id")], vec![count(col("id"))]) - .unwrap() - .build() - .unwrap(); - - let optimized_plan = session_state.optimize(&logical_plan).unwrap(); - - let plan = session_state - .create_physical_plan(&optimized_plan) - .await - .unwrap(); - - println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); - - ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap() - } - - async fn test_coalesce_plan(partition: usize) -> ExecutionGraph { - let config = SessionConfig::new().with_target_partitions(partition); - let ctx = Arc::new(SessionContext::with_config(config)); - let session_state = ctx.state(); - - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("gmv", DataType::UInt64, false), - ]); - - let logical_plan = scan_empty(None, &schema, Some(vec![0, 1])) - .unwrap() - .limit(0, Some(1)) - .unwrap() - .build() - .unwrap(); - - let optimized_plan = session_state.optimize(&logical_plan).unwrap(); - - let plan = session_state - .create_physical_plan(&optimized_plan) - .await - .unwrap(); - - ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap() - } - - async fn test_join_plan(partition: usize) -> ExecutionGraph { - let mut config = SessionConfig::new().with_target_partitions(partition); - config - .config_options_mut() - .optimizer - .enable_round_robin_repartition = false; - let ctx = Arc::new(SessionContext::with_config(config)); - let session_state = ctx.state(); - - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - 
Field::new("gmv", DataType::UInt64, false), - ]); - - let left_plan = scan_empty(Some("left"), &schema, None).unwrap(); - - let right_plan = scan_empty(Some("right"), &schema, None) - .unwrap() - .build() - .unwrap(); - - let sort_expr = Expr::Sort(Sort::new(Box::new(col("id")), false, false)); - - let logical_plan = left_plan - .join(right_plan, JoinType::Inner, (vec!["id"], vec!["id"]), None) - .unwrap() - .aggregate(vec![col("id")], vec![sum(col("gmv"))]) - .unwrap() - .sort(vec![sort_expr]) - .unwrap() - .build() - .unwrap(); - - let optimized_plan = session_state.optimize(&logical_plan).unwrap(); - - let plan = session_state - .create_physical_plan(&optimized_plan) - .await - .unwrap(); - - println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); - - let graph = ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0) - .unwrap(); - - println!("{graph:?}"); - - graph - } - - async fn test_union_all_plan(partition: usize) -> ExecutionGraph { - let config = SessionConfig::new().with_target_partitions(partition); - let ctx = Arc::new(SessionContext::with_config(config)); - let session_state = ctx.state(); - - let logical_plan = ctx - .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;") - .await - .unwrap() - .into_optimized_plan() - .unwrap(); - - let optimized_plan = session_state.optimize(&logical_plan).unwrap(); - - let plan = session_state - .create_physical_plan(&optimized_plan) - .await - .unwrap(); - - println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); - - let graph = ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0) - .unwrap(); - - println!("{graph:?}"); - - graph - } - - async fn test_union_plan(partition: usize) -> ExecutionGraph { - let config = SessionConfig::new().with_target_partitions(partition); - let ctx = Arc::new(SessionContext::with_config(config)); - let session_state = ctx.state(); - - let logical_plan = ctx - .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;") - .await - .unwrap() - .into_optimized_plan() - .unwrap(); - - let optimized_plan = session_state.optimize(&logical_plan).unwrap(); - - let plan = session_state - .create_physical_plan(&optimized_plan) - .await - .unwrap(); - - println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); - - let graph = ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0) - .unwrap(); - - println!("{graph:?}"); - - graph - } - - fn mock_executor(executor_id: String) -> ExecutorMetadata { - ExecutorMetadata { - id: executor_id, - host: "localhost2".to_string(), - port: 8080, - grpc_port: 9090, - specification: ExecutorSpecification { task_slots: 1 }, - } - } - - fn mock_completed_task(task: TaskDescription, executor_id: &str) -> TaskStatus { - let mut partitions: Vec = vec![]; - - let num_partitions = task - .output_partitioning - .map(|p| p.partition_count()) - .unwrap_or(1); - - for partition_id in 0..num_partitions { - partitions.push(protobuf::ShuffleWritePartition { - partition_id: partition_id as u64, - path: format!( - "/{}/{}/{}", - task.partition.job_id, - task.partition.stage_id, - task.partition.partition_id - ), - num_batches: 1, - num_rows: 1, - num_bytes: 1, - }) - } - - // Complete the task - protobuf::TaskStatus { - task_id: task.task_id as u32, - job_id: task.partition.job_id.clone(), - stage_id: task.partition.stage_id as u32, - stage_attempt_num: task.stage_attempt_num as u32, - partition_id: task.partition.partition_id as u32, - launch_time: 0, - start_exec_time: 0, - end_exec_time: 0, - metrics: vec![], - 
status: Some(task_status::Status::Successful(protobuf::SuccessfulTask { - executor_id: executor_id.to_owned(), - partitions, - })), - } - } - - fn mock_failed_task(task: TaskDescription, failed_task: FailedTask) -> TaskStatus { - let mut partitions: Vec = vec![]; - - let num_partitions = task - .output_partitioning - .map(|p| p.partition_count()) - .unwrap_or(1); - - for partition_id in 0..num_partitions { - partitions.push(protobuf::ShuffleWritePartition { - partition_id: partition_id as u64, - path: format!( - "/{}/{}/{}", - task.partition.job_id, - task.partition.stage_id, - task.partition.partition_id - ), - num_batches: 1, - num_rows: 1, - num_bytes: 1, - }) - } - - // Fail the task - protobuf::TaskStatus { - task_id: task.task_id as u32, - job_id: task.partition.job_id.clone(), - stage_id: task.partition.stage_id as u32, - stage_attempt_num: task.stage_attempt_num as u32, - partition_id: task.partition.partition_id as u32, - launch_time: 0, - start_exec_time: 0, - end_exec_time: 0, - metrics: vec![], - status: Some(task_status::Status::Failed(failed_task)), - } - } } diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs index 4edb717a6..853fa21dd 100644 --- a/ballista/scheduler/src/state/execution_graph_dot.rs +++ b/ballista/scheduler/src/state/execution_graph_dot.rs @@ -45,20 +45,20 @@ use std::fmt::{self, Write}; use std::sync::Arc; /// Utility for producing dot diagrams from execution graphs -pub struct ExecutionGraphDot { - graph: Arc, +pub struct ExecutionGraphDot<'a> { + graph: &'a ExecutionGraph, } -impl ExecutionGraphDot { +impl<'a> ExecutionGraphDot<'a> { /// Create a DOT graph from the provided ExecutionGraph - pub fn generate(graph: Arc) -> Result { + pub fn generate(graph: &'a ExecutionGraph) -> Result { let mut dot = Self { graph }; dot._generate() } /// Create a DOT graph for one query stage from the provided ExecutionGraph pub fn generate_for_query_stage( - graph: Arc, + graph: &ExecutionGraph, stage_id: usize, ) -> Result { if let Some(stage) = graph.stages().get(&stage_id) { @@ -426,7 +426,7 @@ mod tests { #[tokio::test] async fn dot() -> Result<()> { let graph = test_graph().await?; - let dot = ExecutionGraphDot::generate(Arc::new(graph)) + let dot = ExecutionGraphDot::generate(&graph) .map_err(|e| BallistaError::Internal(format!("{e:?}")))?; let expected = r#"digraph G { @@ -499,7 +499,7 @@ filter_expr="] #[tokio::test] async fn query_stage() -> Result<()> { let graph = test_graph().await?; - let dot = ExecutionGraphDot::generate_for_query_stage(Arc::new(graph), 3) + let dot = ExecutionGraphDot::generate_for_query_stage(&graph, 3) .map_err(|e| BallistaError::Internal(format!("{e:?}")))?; let expected = r#"digraph G { @@ -527,7 +527,7 @@ filter_expr="] #[tokio::test] async fn dot_optimized() -> Result<()> { let graph = test_graph_optimized().await?; - let dot = ExecutionGraphDot::generate(Arc::new(graph)) + let dot = ExecutionGraphDot::generate(&graph) .map_err(|e| BallistaError::Internal(format!("{e:?}")))?; let expected = r#"digraph G { @@ -591,7 +591,7 @@ filter_expr="] #[tokio::test] async fn query_stage_optimized() -> Result<()> { let graph = test_graph_optimized().await?; - let dot = ExecutionGraphDot::generate_for_query_stage(Arc::new(graph), 4) + let dot = ExecutionGraphDot::generate_for_query_stage(&graph, 4) .map_err(|e| BallistaError::Internal(format!("{e:?}")))?; let expected = r#"digraph G { diff --git a/ballista/scheduler/src/state/executor_manager.rs 
b/ballista/scheduler/src/state/executor_manager.rs index 38c8da689..9a51aa77e 100644 --- a/ballista/scheduler/src/state/executor_manager.rs +++ b/ballista/scheduler/src/state/executor_manager.rs @@ -17,13 +17,14 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use crate::state::backend::TaskDistribution; +use crate::cluster::TaskDistribution; use ballista_core::error::{BallistaError, Result}; use ballista_core::serde::protobuf; +use crate::cluster::ClusterState; use crate::config::SlotsPolicy; -use crate::state::backend::cluster::ClusterState; + use crate::state::execution_graph::RunningTaskInfo; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::protobuf::{ @@ -454,6 +455,11 @@ impl ExecutorManager { specification: ExecutorData, reserve: bool, ) -> Result> { + debug!( + "registering executor {} with {} task slots", + metadata.id, specification.total_task_slots + ); + self.test_scheduler_connectivity(&metadata).await?; let current_ts = SystemTime::now() @@ -647,15 +653,15 @@ impl ExecutorManager { #[cfg(test)] mod test { + use crate::config::SlotsPolicy; - use crate::state::backend::cluster::DefaultClusterState; - use crate::state::backend::sled::SledClient; + use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; + use crate::test_utils::test_cluster_context; use ballista_core::error::Result; use ballista_core::serde::scheduler::{ ExecutorData, ExecutorMetadata, ExecutorSpecification, }; - use std::sync::Arc; #[tokio::test] async fn test_reserve_and_cancel() -> Result<()> { @@ -667,11 +673,10 @@ mod test { } async fn test_reserve_and_cancel_inner(slots_policy: SlotsPolicy) -> Result<()> { - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); + let cluster = test_cluster_context(); - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); let executors = test_executors(10, 4); @@ -715,11 +720,10 @@ mod test { } async fn test_reserve_partial_inner(slots_policy: SlotsPolicy) -> Result<()> { - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); + let cluster = test_cluster_context(); - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); let executors = test_executors(10, 4); @@ -772,11 +776,9 @@ mod test { let executors = test_executors(10, 4); - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); - - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let cluster = test_cluster_context(); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); for (executor_metadata, executor_data) in executors { executor_manager @@ -819,11 +821,10 @@ mod test { } async fn test_register_reserve_inner(slots_policy: SlotsPolicy) -> Result<()> { - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); + let cluster = test_cluster_context(); - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); let executors = test_executors(10, 4); diff --git a/ballista/scheduler/src/state/mod.rs b/ballista/scheduler/src/state/mod.rs index 0ce591453..17b1a4fb3 
100644 --- a/ballista/scheduler/src/state/mod.rs +++ b/ballista/scheduler/src/state/mod.rs @@ -20,20 +20,18 @@ use datafusion::datasource::source_as_provider; use datafusion::logical_expr::PlanVisitor; use std::any::type_name; use std::collections::HashMap; -use std::future::Future; use std::sync::Arc; use std::time::Instant; use crate::scheduler_server::event::QueryStageSchedulerEvent; -use crate::scheduler_server::SessionBuilder; -use crate::state::backend::{Lock, StateBackendClient}; + use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; use crate::state::session_manager::SessionManager; use crate::state::task_manager::{TaskLauncher, TaskManager}; +use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::state::execution_graph::TaskDescription; -use backend::cluster::ClusterState; use ballista_core::error::{BallistaError, Result}; use ballista_core::serde::protobuf::TaskStatus; use ballista_core::serde::BallistaCodec; @@ -45,7 +43,6 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, info}; use prost::Message; -pub mod backend; pub mod execution_graph; pub mod execution_graph_dot; pub mod executor_manager; @@ -100,15 +97,11 @@ pub(super) struct SchedulerState SchedulerState { #[cfg(test)] pub fn new_with_default_scheduler_name( - config_client: Arc, - cluster_state: Arc, - session_builder: SessionBuilder, + cluster: BallistaCluster, codec: BallistaCodec, ) -> Self { SchedulerState::new( - config_client, - cluster_state, - session_builder, + cluster, codec, "localhost:50050".to_owned(), SchedulerConfig::default(), @@ -116,35 +109,30 @@ impl SchedulerState, - cluster_state: Arc, - session_builder: SessionBuilder, + cluster: BallistaCluster, codec: BallistaCodec, scheduler_name: String, config: SchedulerConfig, ) -> Self { Self { executor_manager: ExecutorManager::new( - cluster_state, + cluster.cluster_state(), config.executor_slots_policy, ), task_manager: TaskManager::new( - config_client.clone(), - session_builder, + cluster.job_state(), codec.clone(), scheduler_name, ), - session_manager: SessionManager::new(config_client, session_builder), + session_manager: SessionManager::new(cluster.job_state()), codec, config, } } #[allow(dead_code)] - pub(crate) fn with_task_launcher( - config_client: Arc, - cluster_state: Arc, - session_builder: SessionBuilder, + pub(crate) fn new_with_task_launcher( + cluster: BallistaCluster, codec: BallistaCodec, scheduler_name: String, config: SchedulerConfig, @@ -152,17 +140,16 @@ impl SchedulerState Self { Self { executor_manager: ExecutorManager::new( - cluster_state, + cluster.cluster_state(), config.executor_slots_policy, ), task_manager: TaskManager::with_launcher( - config_client.clone(), - session_builder, + cluster.job_state(), codec.clone(), scheduler_name, dispatcher, ), - session_manager: SessionManager::new(config_client, session_builder), + session_manager: SessionManager::new(cluster.job_state()), codec, config, } @@ -414,7 +401,7 @@ impl SchedulerState SchedulerState>( - mut lock: Box, - op: F, -) -> Out { - let result = op.await; - lock.unlock().await; - result -} -/// It takes multiple locks and reverse the order for releasing them to prevent a race condition. 
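// Editor's note: the `with_lock` helper removed above, and `with_locks` whose body
// follows, run an operation while holding one or more state-backend locks and then
// release them in reverse acquisition order. A synchronous, std-only sketch of that
// discipline; the real helpers are async and operate on boxed `Lock` handles, and the
// names and signature here are illustrative assumptions.
use std::sync::Mutex;

fn with_locks_sync<T>(locks: &[Mutex<()>], op: impl FnOnce() -> T) -> T {
    // Acquire in the given order...
    let mut guards: Vec<_> = locks.iter().map(|l| l.lock().unwrap()).collect();
    let result = op();
    // ...and release in reverse order to avoid lock-order inversions.
    while let Some(guard) = guards.pop() {
        drop(guard);
    }
    result
}

fn main() {
    let locks = [Mutex::new(()), Mutex::new(())];
    let sum = with_locks_sync(&locks, || 1 + 1);
    assert_eq!(sum, 2);
}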
-pub async fn with_locks>( - locks: Vec>, - op: F, -) -> Out { - let result = op.await; - for mut lock in locks.into_iter().rev() { - lock.unlock().await; - } - result -} - #[cfg(test)] mod test { - use crate::state::backend::sled::SledClient; + use crate::state::SchedulerState; use ballista_core::config::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS}; use ballista_core::error::Result; @@ -463,11 +430,11 @@ mod test { ExecutorData, ExecutorMetadata, ExecutorSpecification, }; use ballista_core::serde::BallistaCodec; - use ballista_core::utils::default_session_builder; use crate::config::SchedulerConfig; - use crate::state::backend::cluster::DefaultClusterState; - use crate::test_utils::BlackholeTaskLauncher; + + use crate::scheduler_server::timestamp_millis; + use crate::test_utils::{test_cluster_context, BlackholeTaskLauncher}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::logical_expr::{col, sum}; use datafusion::physical_plan::ExecutionPlan; @@ -477,16 +444,14 @@ mod test { use datafusion_proto::protobuf::PhysicalPlanNode; use std::sync::Arc; + const TEST_SCHEDULER_NAME: &str = "localhost:50050"; + // We should free any reservations which are not assigned #[tokio::test] async fn test_offer_free_reservations() -> Result<()> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); let state: Arc> = Arc::new(SchedulerState::new_with_default_scheduler_name( - state_storage, - cluster_state, - default_session_builder, + test_cluster_context(), BallistaCodec::default(), )); @@ -518,15 +483,12 @@ mod test { let config = BallistaConfig::builder() .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "4") .build()?; - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let state: Arc> = - Arc::new(SchedulerState::with_task_launcher( - state_storage, - cluster_state, - default_session_builder, + Arc::new(SchedulerState::new_with_task_launcher( + test_cluster_context(), BallistaCodec::default(), - String::default(), + TEST_SCHEDULER_NAME.into(), SchedulerConfig::default(), Arc::new(BlackholeTaskLauncher::default()), )); @@ -536,6 +498,10 @@ mod test { let plan = test_graph(session_ctx.clone()).await; // Create 4 jobs so we have four pending tasks + state + .task_manager + .queue_job("job-1", "", timestamp_millis()) + .await?; state .task_manager .submit_job( @@ -546,6 +512,10 @@ mod test { 0, ) .await?; + state + .task_manager + .queue_job("job-2", "", timestamp_millis()) + .await?; state .task_manager .submit_job( @@ -556,6 +526,10 @@ mod test { 0, ) .await?; + state + .task_manager + .queue_job("job-3", "", timestamp_millis()) + .await?; state .task_manager .submit_job( @@ -566,6 +540,10 @@ mod test { 0, ) .await?; + state + .task_manager + .queue_job("job-4", "", timestamp_millis()) + .await?; state .task_manager .submit_job( @@ -605,15 +583,12 @@ mod test { let config = BallistaConfig::builder() .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "4") .build()?; - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let state: Arc> = - Arc::new(SchedulerState::with_task_launcher( - state_storage, - cluster_state, - default_session_builder, + Arc::new(SchedulerState::new_with_task_launcher( + test_cluster_context(), BallistaCodec::default(), - String::default(), + TEST_SCHEDULER_NAME.into(), 
SchedulerConfig::default(), Arc::new(BlackholeTaskLauncher::default()), )); @@ -623,6 +598,10 @@ mod test { let plan = test_graph(session_ctx.clone()).await; // Create a job + state + .task_manager + .queue_job("job-1", "", timestamp_millis()) + .await?; state .task_manager .submit_job( @@ -643,7 +622,6 @@ mod test { let plan_graph = state .task_manager .get_active_execution_graph("job-1") - .await .unwrap(); let task_def = plan_graph .write() diff --git a/ballista/scheduler/src/state/session_manager.rs b/ballista/scheduler/src/state/session_manager.rs index 540a5aad2..e07dbe900 100644 --- a/ballista/scheduler/src/state/session_manager.rs +++ b/ballista/scheduler/src/state/session_manager.rs @@ -16,32 +16,23 @@ // under the License. use crate::scheduler_server::SessionBuilder; -use crate::state::backend::{Keyspace, StateBackendClient}; -use crate::state::{decode_protobuf, encode_protobuf}; use ballista_core::config::BallistaConfig; use ballista_core::error::Result; -use ballista_core::serde::protobuf::{self, KeyValuePair}; use datafusion::prelude::{SessionConfig, SessionContext}; +use crate::cluster::JobState; use datafusion::common::ScalarValue; use log::warn; use std::sync::Arc; #[derive(Clone)] pub struct SessionManager { - state: Arc, - session_builder: SessionBuilder, + state: Arc, } impl SessionManager { - pub fn new( - state: Arc, - session_builder: SessionBuilder, - ) -> Self { - Self { - state, - session_builder, - } + pub fn new(state: Arc) -> Self { + Self { state } } pub async fn update_session( @@ -49,64 +40,18 @@ impl SessionManager { session_id: &str, config: &BallistaConfig, ) -> Result> { - let mut settings: Vec = vec![]; - - for (key, value) in config.settings() { - settings.push(KeyValuePair { - key: key.clone(), - value: value.clone(), - }) - } - - let value = encode_protobuf(&protobuf::SessionSettings { configs: settings })?; - self.state - .put(Keyspace::Sessions, session_id.to_owned(), value) - .await?; - - Ok(create_datafusion_context(config, self.session_builder)) + self.state.update_session(session_id, config).await } pub async fn create_session( &self, config: &BallistaConfig, ) -> Result> { - let mut settings: Vec = vec![]; - - for (key, value) in config.settings() { - settings.push(KeyValuePair { - key: key.clone(), - value: value.clone(), - }) - } - - let mut config_builder = BallistaConfig::builder(); - for kv_pair in &settings { - config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); - } - let config = config_builder.build()?; - - let ctx = create_datafusion_context(&config, self.session_builder); - - let value = encode_protobuf(&protobuf::SessionSettings { configs: settings })?; - self.state - .put(Keyspace::Sessions, ctx.session_id(), value) - .await?; - - Ok(ctx) + self.state.create_session(config).await } pub async fn get_session(&self, session_id: &str) -> Result> { - let value = self.state.get(Keyspace::Sessions, session_id).await?; - - let settings: protobuf::SessionSettings = decode_protobuf(&value)?; - - let mut config_builder = BallistaConfig::builder(); - for kv_pair in &settings.configs { - config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); - } - let config = config_builder.build()?; - - Ok(create_datafusion_context(&config, self.session_builder)) + self.state.get_session(session_id).await } } diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index 2db2bd83a..e1ed21e45 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ 
b/ballista/scheduler/src/state/task_manager.rs @@ -16,40 +16,37 @@ // under the License. use crate::scheduler_server::event::QueryStageSchedulerEvent; -use crate::scheduler_server::SessionBuilder; -use crate::state::backend::{Keyspace, Operation, StateBackendClient}; + use crate::state::execution_graph::{ ExecutionGraph, ExecutionStage, RunningTaskInfo, TaskDescription, }; use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; -use crate::state::{decode_protobuf, encode_protobuf, with_lock, with_locks}; -use ballista_core::config::BallistaConfig; + use ballista_core::error::BallistaError; use ballista_core::error::Result; -use crate::state::backend::Keyspace::{CompletedJobs, FailedJobs}; -use crate::state::session_manager::create_datafusion_context; - +use crate::cluster::JobState; use ballista_core::serde::protobuf::{ - self, job_status, FailedJob, JobStatus, MultiTaskDefinition, TaskDefinition, TaskId, - TaskStatus, + self, JobStatus, MultiTaskDefinition, TaskDefinition, TaskId, TaskStatus, }; use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto; use ballista_core::serde::scheduler::ExecutorMetadata; use ballista_core::serde::BallistaCodec; use dashmap::DashMap; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; + use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, info, warn}; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::sync::Arc; use std::time::Duration; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; + use tracing::trace; type ActiveJobCache = Arc>; @@ -108,8 +105,7 @@ impl TaskLauncher for DefaultTaskLauncher { #[derive(Clone)] pub struct TaskManager { - state: Arc, - session_builder: SessionBuilder, + state: Arc, codec: BallistaCodec, scheduler_id: String, // Cache for active jobs curated by this scheduler @@ -145,14 +141,12 @@ pub struct UpdatedStages { impl TaskManager { pub fn new( - state: Arc, - session_builder: SessionBuilder, + state: Arc, codec: BallistaCodec, scheduler_id: String, ) -> Self { Self { state, - session_builder, codec, scheduler_id: scheduler_id.clone(), active_job_cache: Arc::new(DashMap::new()), @@ -162,15 +156,13 @@ impl TaskManager #[allow(dead_code)] pub(crate) fn with_launcher( - state: Arc, - session_builder: SessionBuilder, + state: Arc, codec: BallistaCodec, scheduler_id: String, launcher: Arc, ) -> Self { Self { state, - session_builder, codec, scheduler_id, active_job_cache: Arc::new(DashMap::new()), @@ -178,6 +170,16 @@ impl TaskManager } } + /// Enqueue a job for scheduling + pub async fn queue_job( + &self, + job_id: &str, + job_name: &str, + queued_at: u64, + ) -> Result<()> { + self.state.accept_job(job_id, job_name, queued_at).await + } + /// Generate an ExecutionGraph for the job and save it to the persistent state. /// By default, this job will be curated by the scheduler which receives it. 
/// Then we will also save it to the active execution graph @@ -198,13 +200,8 @@ impl TaskManager queued_at, )?; info!("Submitting execution graph: {:?}", graph); - self.state - .put( - Keyspace::ActiveJobs, - job_id.to_owned(), - self.encode_execution_graph(graph.clone())?, - ) - .await?; + + self.state.submit_job(job_id.to_string(), &graph).await?; graph.revive(); self.active_job_cache @@ -215,36 +212,20 @@ impl TaskManager /// Get a list of active job ids pub async fn get_jobs(&self) -> Result> { - let mut job_ids = vec![]; - for job_id in self.state.scan_keys(Keyspace::ActiveJobs).await? { - job_ids.push(job_id); - } - for job_id in self.state.scan_keys(Keyspace::CompletedJobs).await? { - job_ids.push(job_id); - } - for job_id in self.state.scan_keys(Keyspace::FailedJobs).await? { - job_ids.push(job_id); - } + let job_ids = self.state.get_jobs().await?; let mut jobs = vec![]; for job_id in &job_ids { - let graph = self.get_execution_graph(job_id).await?; - - let mut completed_stages = 0; - for stage in graph.stages().values() { - if let ExecutionStage::Successful(_) = stage { - completed_stages += 1; - } + if let Some(cached) = self.get_active_execution_graph(job_id) { + let graph = cached.read().await; + jobs.push(graph.deref().into()); + } else { + let graph = self.state + .get_execution_graph(job_id) + .await? + .ok_or_else(|| BallistaError::Internal(format!("Error getting job overview, no execution graph found for job {job_id}")))?; + jobs.push((&graph).into()); } - jobs.push(JobOverview { - job_id: job_id.clone(), - job_name: graph.job_name().to_string(), - status: graph.status(), - start_time: graph.start_time(), - end_time: graph.end_time(), - num_stages: graph.stage_count(), - completed_stages, - }); } Ok(jobs) } @@ -252,36 +233,29 @@ impl TaskManager /// Get the status of of a job. First look in the active cache. /// If no one found, then in the Active/Completed jobs, and then in Failed jobs pub async fn get_job_status(&self, job_id: &str) -> Result> { - if let Some(graph) = self.get_active_execution_graph(job_id).await { - let status = graph.read().await.status(); - Ok(Some(status)) - } else if let Ok(graph) = self.get_execution_graph(job_id).await { - Ok(Some(graph.status())) - } else { - let value = self.state.get(Keyspace::FailedJobs, job_id).await?; + if let Some(graph) = self.get_active_execution_graph(job_id) { + let guard = graph.read().await; - if !value.is_empty() { - let status = decode_protobuf(&value)?; - Ok(Some(status)) - } else { - Ok(None) - } + Ok(Some(guard.status())) + } else { + self.state.get_job_status(job_id).await } } /// Get the execution graph of of a job. First look in the active cache. /// If no one found, then in the Active/Completed jobs. 
- pub async fn get_job_execution_graph( + pub(crate) async fn get_job_execution_graph( &self, job_id: &str, ) -> Result>> { - if let Some(graph) = self.get_active_execution_graph(job_id).await { - Ok(Some(Arc::new(graph.read().await.clone()))) - } else if let Ok(graph) = self.get_execution_graph(job_id).await { - Ok(Some(Arc::new(graph))) + if let Some(cached) = self.get_active_execution_graph(job_id) { + let guard = cached.read().await; + + Ok(Some(Arc::new(guard.deref().clone()))) } else { - // if the job failed then we return no graph for now - Ok(None) + let graph = self.state.get_execution_graph(job_id).await?; + + Ok(graph.map(Arc::new)) } } @@ -306,9 +280,11 @@ impl TaskManager let num_tasks = statuses.len(); debug!("Updating {} tasks in job {}", num_tasks, job_id); - let graph = self.get_active_execution_graph(&job_id).await; - let job_events = if let Some(graph) = graph { - let mut graph = graph.write().await; + // let graph = self.get_active_execution_graph(&job_id).await; + let job_events = if let Some(cached) = + self.get_active_execution_graph(&job_id) + { + let mut graph = cached.write().await; graph.update_task_status( executor, statuses, @@ -388,16 +364,11 @@ impl TaskManager /// and remove the job from ActiveJobs pub(crate) async fn succeed_job(&self, job_id: &str) -> Result<()> { debug!("Moving job {} from Active to Success", job_id); - let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; - with_lock(lock, self.state.delete(Keyspace::ActiveJobs, job_id)).await?; - if let Some(graph) = self.remove_active_execution_graph(job_id).await { + if let Some(graph) = self.remove_active_execution_graph(job_id) { let graph = graph.read().await.clone(); if graph.is_successful() { - let value = self.encode_execution_graph(graph)?; - self.state - .put(Keyspace::CompletedJobs, job_id.to_owned(), value) - .await?; + self.state.save_job(job_id, &graph).await?; } else { error!("Job {} has not finished and cannot be completed", job_id); return Ok(()); @@ -423,35 +394,28 @@ impl TaskManager job_id: &str, failure_reason: String, ) -> Result<(Vec, usize)> { - let locks = self - .state - .acquire_locks(vec![ - (Keyspace::ActiveJobs, job_id), - (Keyspace::FailedJobs, job_id), - ]) - .await?; let (tasks_to_cancel, pending_tasks) = if let Some(graph) = - self.get_active_execution_graph(job_id).await + self.get_active_execution_graph(job_id) { - let (pending_tasks, running_tasks) = { - let guard = graph.read().await; - (guard.available_tasks(), guard.running_tasks()) - }; + let mut guard = graph.write().await; + + let pending_tasks = guard.available_tasks(); + let running_tasks = guard.running_tasks(); info!( "Cancelling {} running tasks for job {}", running_tasks.len(), job_id ); - with_locks(locks, self.fail_job_state(job_id, failure_reason)) - .await - .unwrap(); + + guard.fail_job(failure_reason); + + self.state.save_job(job_id, &guard).await?; (running_tasks, pending_tasks) } else { // TODO listen the job state update event and fix task cancelling warn!("Fail to find job {} in the cache, unable to cancel tasks for job, fail the job state only.", job_id); - with_locks(locks, self.fail_job_state(job_id, failure_reason)).await?; (vec![], 0) }; @@ -465,75 +429,25 @@ impl TaskManager job_id: &str, failure_reason: String, ) -> Result<()> { - debug!("Moving job {} from Active or Queue to Failed", job_id); - let locks = self - .state - .acquire_locks(vec![ - (Keyspace::ActiveJobs, job_id), - (Keyspace::FailedJobs, job_id), - ]) - .await?; - with_locks(locks, self.fail_job_state(job_id, 
failure_reason)).await?; - - Ok(()) - } - - async fn fail_job_state(&self, job_id: &str, failure_reason: String) -> Result<()> { - let txn_operations = |value: Vec| -> Vec<(Operation, Keyspace, String)> { - vec![ - (Operation::Delete, Keyspace::ActiveJobs, job_id.to_string()), - ( - Operation::Put(value), - Keyspace::FailedJobs, - job_id.to_string(), - ), - ] - }; - - if let Some(graph) = self.remove_active_execution_graph(job_id).await { - let mut graph = graph.write().await; - let previous_status = graph.status(); - graph.fail_job(failure_reason); - - let value = encode_protobuf(&graph.status())?; - let txn_ops = txn_operations(value); - let result = self.state.apply_txn(txn_ops).await; - if result.is_err() { - // Rollback - graph.update_status(previous_status); - warn!("Rollback Execution Graph state change since it did not persisted due to a possible connection error.") - }; - } else { - info!("Fail to find job {} in the cache", job_id); - let status = JobStatus { - status: Some(job_status::Status::Failed(FailedJob { - error: failure_reason.clone(), - })), - }; - let value = encode_protobuf(&status)?; - let txn_ops = txn_operations(value); - self.state.apply_txn(txn_ops).await?; - }; - - Ok(()) + self.state + .fail_unscheduled_job(job_id, failure_reason) + .await } pub async fn update_job(&self, job_id: &str) -> Result { - debug!("Update job {} in Active", job_id); - if let Some(graph) = self.get_active_execution_graph(job_id).await { + debug!("Update active job {job_id}"); + if let Some(graph) = self.get_active_execution_graph(job_id) { let mut graph = graph.write().await; let curr_available_tasks = graph.available_tasks(); graph.revive(); - let graph = graph.clone(); - let new_tasks = graph.available_tasks() - curr_available_tasks; + println!("Saving job with status {:?}", graph.status()); - let value = self.encode_execution_graph(graph)?; - self.state - .put(Keyspace::ActiveJobs, job_id.to_owned(), value) - .await?; + self.state.save_job(job_id, &graph).await?; + + let new_tasks = graph.available_tasks() - curr_available_tasks; Ok(new_tasks) } else { @@ -561,26 +475,13 @@ impl TaskManager } } - let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; - with_lock(lock, async { - // Transactional update graphs - let txn_ops: Vec<(Operation, Keyspace, String)> = updated_graphs - .into_iter() - .map(|(job_id, graph)| { - let value = self.encode_execution_graph(graph)?; - Ok((Operation::Put(value), Keyspace::ActiveJobs, job_id)) - }) - .collect::>>()?; - self.state.apply_txn(txn_ops).await?; - Ok(running_tasks_to_cancel) - }) - .await + Ok(running_tasks_to_cancel) } /// Retrieve the number of available tasks for the given job. 
The value returned /// is strictly a point-in-time snapshot pub async fn get_available_task_count(&self, job_id: &str) -> Result { - if let Some(graph) = self.get_active_execution_graph(job_id).await { + if let Some(graph) = self.get_active_execution_graph(job_id) { let available_tasks = graph.read().await.available_tasks(); Ok(available_tasks) } else { @@ -738,17 +639,18 @@ impl TaskManager } /// Get the `ExecutionGraph` for the given job ID from cache - pub(crate) async fn get_active_execution_graph( + pub(crate) fn get_active_execution_graph( &self, job_id: &str, ) -> Option>> { self.active_job_cache .get(job_id) - .map(|value| value.execution_graph.clone()) + .as_deref() + .map(|cached| cached.execution_graph.clone()) } /// Remove the `ExecutionGraph` for the given job ID from cache - pub(crate) async fn remove_active_execution_graph( + pub(crate) fn remove_active_execution_graph( &self, job_id: &str, ) -> Option>> { @@ -757,52 +659,6 @@ impl TaskManager .map(|value| value.1.execution_graph) } - /// Get the `ExecutionGraph` for the given job ID. This will search fist in the `ActiveJobs` - /// keyspace and then, if it doesn't find anything, search the `CompletedJobs` keyspace. - pub(crate) async fn get_execution_graph( - &self, - job_id: &str, - ) -> Result { - let value = self.state.get(Keyspace::ActiveJobs, job_id).await?; - - if value.is_empty() { - let value = self.state.get(Keyspace::CompletedJobs, job_id).await?; - self.decode_execution_graph(value).await - } else { - self.decode_execution_graph(value).await - } - } - - async fn get_session(&self, session_id: &str) -> Result> { - let value = self.state.get(Keyspace::Sessions, session_id).await?; - - let settings: protobuf::SessionSettings = decode_protobuf(&value)?; - - let mut config_builder = BallistaConfig::builder(); - for kv_pair in &settings.configs { - config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); - } - let config = config_builder.build()?; - - Ok(create_datafusion_context(&config, self.session_builder)) - } - - async fn decode_execution_graph(&self, value: Vec) -> Result { - let proto: protobuf::ExecutionGraph = decode_protobuf(&value)?; - - let session_id = &proto.session_id; - - let session_ctx = self.get_session(session_id).await?; - - ExecutionGraph::decode_execution_graph(proto, &self.codec, &session_ctx).await - } - - fn encode_execution_graph(&self, graph: ExecutionGraph) -> Result> { - let proto = ExecutionGraph::encode_execution_graph(graph, &self.codec)?; - - encode_protobuf(&proto) - } - /// Generate a new random Job ID pub fn generate_job_id(&self) -> String { let mut rng = thread_rng(); @@ -814,55 +670,21 @@ impl TaskManager } /// Clean up a failed job in FailedJobs Keyspace by delayed clean_up_interval seconds - pub(crate) fn clean_up_failed_job_delayed( - &self, - job_id: String, - clean_up_interval: u64, - ) { + pub(crate) fn clean_up_job_delayed(&self, job_id: String, clean_up_interval: u64) { if clean_up_interval == 0 { info!("The interval is 0 and the clean up for the failed job state {} will not triggered", job_id); return; } - self.delete_from_state_backend_delayed(FailedJobs, job_id, clean_up_interval) - } - - /// Clean up a successful job in CompletedJobs Keyspace by delayed clean_up_interval seconds - pub(crate) fn delete_successful_job_delayed( - &self, - job_id: String, - clean_up_interval: u64, - ) { - if clean_up_interval == 0 { - info!("The interval is 0 and the clean up for the successful job state {} will not triggered", job_id); - return; - } - 
self.delete_from_state_backend_delayed(CompletedJobs, job_id, clean_up_interval) - } - /// Clean up entries in some keyspace by delayed clean_up_interval seconds - fn delete_from_state_backend_delayed( - &self, - keyspace: Keyspace, - key: String, - clean_up_interval: u64, - ) { let state = self.state.clone(); tokio::spawn(async move { + let job_id = job_id; tokio::time::sleep(Duration::from_secs(clean_up_interval)).await; - Self::delete_from_state_backend(state, keyspace, &key).await + if let Err(err) = state.remove_job(&job_id).await { + error!("Failed to delete job {job_id}: {err:?}"); + } }); } - - async fn delete_from_state_backend( - state: Arc, - keyspace: Keyspace, - key: &str, - ) -> Result<()> { - let lock = state.lock(keyspace.clone(), "").await?; - with_lock(lock, state.delete(keyspace, key)).await?; - - Ok(()) - } } pub struct JobOverview { @@ -874,3 +696,24 @@ pub struct JobOverview { pub num_stages: usize, pub completed_stages: usize, } + +impl From<&ExecutionGraph> for JobOverview { + fn from(value: &ExecutionGraph) -> Self { + let mut completed_stages = 0; + for stage in value.stages().values() { + if let ExecutionStage::Successful(_) = stage { + completed_stages += 1; + } + } + + Self { + job_id: value.job_id().to_string(), + job_name: value.job_name().to_string(), + status: value.status(), + start_time: value.start_time(), + end_time: value.end_time(), + num_stages: value.stage_count(), + completed_stages, + } + } +} diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index d5465f006..e0971dc61 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -27,7 +27,6 @@ use async_trait::async_trait; use crate::config::SchedulerConfig; use crate::metrics::SchedulerMetricsCollector; use crate::scheduler_server::{timestamp_millis, SchedulerServer}; -use crate::state::backend::sled::SledClient; use crate::state::executor_manager::ExecutorManager; use crate::state::task_manager::TaskLauncher; @@ -35,24 +34,30 @@ use crate::state::task_manager::TaskLauncher; use ballista_core::config::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS}; use ballista_core::serde::protobuf::job_status::Status; use ballista_core::serde::protobuf::{ - task_status, JobStatus, MultiTaskDefinition, ShuffleWritePartition, SuccessfulTask, - TaskId, TaskStatus, + task_status, FailedTask, JobStatus, MultiTaskDefinition, ShuffleWritePartition, + SuccessfulTask, TaskId, TaskStatus, }; use ballista_core::serde::scheduler::{ ExecutorData, ExecutorMetadata, ExecutorSpecification, }; -use ballista_core::serde::BallistaCodec; +use ballista_core::serde::{protobuf, BallistaCodec}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::common::DataFusionError; use datafusion::datasource::{TableProvider, TableType}; use datafusion::execution::context::{SessionConfig, SessionContext, SessionState}; +use datafusion::logical_expr::expr::Sort; use datafusion::logical_expr::{Expr, LogicalPlan}; +use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::CsvReadOptions; +use datafusion::prelude::{col, count, sum, CsvReadOptions, JoinType}; +use datafusion::test_util::scan_empty; +use crate::cluster::BallistaCluster; use crate::scheduler_server::event::QueryStageSchedulerEvent; + use crate::scheduler_server::query_stage_scheduler::QueryStageScheduler; -use crate::state::backend::cluster::DefaultClusterState; +use 
crate::state::execution_graph::{ExecutionGraph, TaskDescription}; +use ballista_core::utils::default_session_builder; use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use parking_lot::Mutex; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -61,6 +66,8 @@ pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; +const TEST_SCHEDULER_NAME: &str = "localhost:50050"; + /// Sometimes we need to construct logical plans that will produce errors /// when we try and create physical plan. A scan using `ExplodingTableProvider` /// will do the trick @@ -116,6 +123,10 @@ pub async fn await_condition>, F: Fn() -> Fut> Ok(false) } +pub fn test_cluster_context() -> BallistaCluster { + BallistaCluster::new_memory(TEST_SCHEDULER_NAME, default_session_builder) +} + pub async fn datafusion_test_context(path: &str) -> Result { let default_shuffle_partitions = 2; let config = SessionConfig::new().with_target_partitions(default_shuffle_partitions); @@ -382,8 +393,7 @@ impl SchedulerTest { task_slots_per_executor: usize, runner: Option>, ) -> Result { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = BallistaCluster::new_from_config(&config).await?; let ballista_config = if num_executors > 0 && task_slots_per_executor > 0 { BallistaConfig::builder() @@ -419,10 +429,9 @@ impl SchedulerTest { }; let mut scheduler: SchedulerServer = - SchedulerServer::with_task_launcher( + SchedulerServer::new_with_task_launcher( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster, BallistaCodec::default(), config, metrics_collector, @@ -531,6 +540,46 @@ impl SchedulerTest { .await } + pub async fn await_completion_timeout( + &self, + job_id: &str, + timeout_ms: u64, + ) -> Result<JobStatus> { + let mut time = 0; + let final_status: Result<JobStatus> = loop { + let status = self + .scheduler + .state + .task_manager + .get_job_status(job_id) + .await?; + + if let Some(JobStatus { + status: Some(inner), + .. + }) = status.as_ref() + { + match inner { + Status::Failed(_) | Status::Successful(_) => { + break Ok(status.unwrap()) + } + _ => { + if time >= timeout_ms { + break Ok(status.unwrap()); + } + } + } + } + + tokio::time::sleep(Duration::from_millis(100)).await; + time += 100; + }; + + final_status + } + pub async fn await_completion(&self, job_id: &str) -> Result<JobStatus> { let final_status: Result<JobStatus> = loop { let status = self @@ -542,6 +591,7 @@ impl SchedulerTest { if let Some(JobStatus { status: Some(inner), + .. }) = status.as_ref() { match inner { @@ -597,6 +647,7 @@ impl SchedulerTest { if let Some(JobStatus { status: Some(inner), + ..
}) = status.as_ref() { match inner { @@ -741,3 +792,288 @@ pub fn assert_failed_event(job_id: &str, collector: &TestMetricsCollector) { assert!(found, "{}", "Expected failed event for job {job_id}"); } + +pub async fn test_aggregation_plan(partition: usize) -> ExecutionGraph { + let config = SessionConfig::new().with_target_partitions(partition); + let ctx = Arc::new(SessionContext::with_config(config)); + let session_state = ctx.state(); + + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("gmv", DataType::UInt64, false), + ]); + + let logical_plan = scan_empty(None, &schema, Some(vec![0, 1])) + .unwrap() + .aggregate(vec![col("id")], vec![sum(col("gmv"))]) + .unwrap() + .build() + .unwrap(); + + let optimized_plan = session_state.optimize(&logical_plan).unwrap(); + + let plan = session_state + .create_physical_plan(&optimized_plan) + .await + .unwrap(); + + println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); + + ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap() +} + +pub async fn test_two_aggregations_plan(partition: usize) -> ExecutionGraph { + let config = SessionConfig::new().with_target_partitions(partition); + let ctx = Arc::new(SessionContext::with_config(config)); + let session_state = ctx.state(); + + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + Field::new("gmv", DataType::UInt64, false), + ]); + + let logical_plan = scan_empty(None, &schema, Some(vec![0, 1, 2])) + .unwrap() + .aggregate(vec![col("id"), col("name")], vec![sum(col("gmv"))]) + .unwrap() + .aggregate(vec![col("id")], vec![count(col("id"))]) + .unwrap() + .build() + .unwrap(); + + let optimized_plan = session_state.optimize(&logical_plan).unwrap(); + + let plan = session_state + .create_physical_plan(&optimized_plan) + .await + .unwrap(); + + println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); + + ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap() +} + +pub async fn test_coalesce_plan(partition: usize) -> ExecutionGraph { + let config = SessionConfig::new().with_target_partitions(partition); + let ctx = Arc::new(SessionContext::with_config(config)); + let session_state = ctx.state(); + + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("gmv", DataType::UInt64, false), + ]); + + let logical_plan = scan_empty(None, &schema, Some(vec![0, 1])) + .unwrap() + .limit(0, Some(1)) + .unwrap() + .build() + .unwrap(); + + let optimized_plan = session_state.optimize(&logical_plan).unwrap(); + + let plan = session_state + .create_physical_plan(&optimized_plan) + .await + .unwrap(); + + ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap() +} + +pub async fn test_join_plan(partition: usize) -> ExecutionGraph { + let mut config = SessionConfig::new().with_target_partitions(partition); + config + .config_options_mut() + .optimizer + .enable_round_robin_repartition = false; + let ctx = Arc::new(SessionContext::with_config(config)); + let session_state = ctx.state(); + + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("gmv", DataType::UInt64, false), + ]); + + let left_plan = scan_empty(Some("left"), &schema, None).unwrap(); + + let right_plan = scan_empty(Some("right"), &schema, None) + .unwrap() + .build() + .unwrap(); + + let sort_expr = Expr::Sort(Sort::new(Box::new(col("id")), false, false)); + + let logical_plan = left_plan 
+ .join(right_plan, JoinType::Inner, (vec!["id"], vec!["id"]), None) + .unwrap() + .aggregate(vec![col("id")], vec![sum(col("gmv"))]) + .unwrap() + .sort(vec![sort_expr]) + .unwrap() + .build() + .unwrap(); + + let optimized_plan = session_state.optimize(&logical_plan).unwrap(); + + let plan = session_state + .create_physical_plan(&optimized_plan) + .await + .unwrap(); + + println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); + + let graph = + ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap(); + + println!("{graph:?}"); + + graph +} + +pub async fn test_union_all_plan(partition: usize) -> ExecutionGraph { + let config = SessionConfig::new().with_target_partitions(partition); + let ctx = Arc::new(SessionContext::with_config(config)); + let session_state = ctx.state(); + + let logical_plan = ctx + .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;") + .await + .unwrap() + .into_optimized_plan() + .unwrap(); + + let optimized_plan = session_state.optimize(&logical_plan).unwrap(); + + let plan = session_state + .create_physical_plan(&optimized_plan) + .await + .unwrap(); + + println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); + + let graph = + ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap(); + + println!("{graph:?}"); + + graph +} + +pub async fn test_union_plan(partition: usize) -> ExecutionGraph { + let config = SessionConfig::new().with_target_partitions(partition); + let ctx = Arc::new(SessionContext::with_config(config)); + let session_state = ctx.state(); + + let logical_plan = ctx + .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;") + .await + .unwrap() + .into_optimized_plan() + .unwrap(); + + let optimized_plan = session_state.optimize(&logical_plan).unwrap(); + + let plan = session_state + .create_physical_plan(&optimized_plan) + .await + .unwrap(); + + println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent()); + + let graph = + ExecutionGraph::new("localhost:50050", "job", "", "session", plan, 0).unwrap(); + + println!("{graph:?}"); + + graph +} + +pub fn mock_executor(executor_id: String) -> ExecutorMetadata { + ExecutorMetadata { + id: executor_id, + host: "localhost2".to_string(), + port: 8080, + grpc_port: 9090, + specification: ExecutorSpecification { task_slots: 1 }, + } +} + +pub fn mock_completed_task(task: TaskDescription, executor_id: &str) -> TaskStatus { + let mut partitions: Vec = vec![]; + + let num_partitions = task + .output_partitioning + .map(|p| p.partition_count()) + .unwrap_or(1); + + for partition_id in 0..num_partitions { + partitions.push(protobuf::ShuffleWritePartition { + partition_id: partition_id as u64, + path: format!( + "/{}/{}/{}", + task.partition.job_id, + task.partition.stage_id, + task.partition.partition_id + ), + num_batches: 1, + num_rows: 1, + num_bytes: 1, + }) + } + + // Complete the task + protobuf::TaskStatus { + task_id: task.task_id as u32, + job_id: task.partition.job_id.clone(), + stage_id: task.partition.stage_id as u32, + stage_attempt_num: task.stage_attempt_num as u32, + partition_id: task.partition.partition_id as u32, + launch_time: 0, + start_exec_time: 0, + end_exec_time: 0, + metrics: vec![], + status: Some(task_status::Status::Successful(protobuf::SuccessfulTask { + executor_id: executor_id.to_owned(), + partitions, + })), + } +} + +pub fn mock_failed_task(task: TaskDescription, failed_task: FailedTask) -> TaskStatus { + let mut partitions: Vec = vec![]; + + let num_partitions = task + .output_partitioning + 
.map(|p| p.partition_count()) + .unwrap_or(1); + + for partition_id in 0..num_partitions { + partitions.push(protobuf::ShuffleWritePartition { + partition_id: partition_id as u64, + path: format!( + "/{}/{}/{}", + task.partition.job_id, + task.partition.stage_id, + task.partition.partition_id + ), + num_batches: 1, + num_rows: 1, + num_bytes: 1, + }) + } + + // Fail the task + protobuf::TaskStatus { + task_id: task.task_id as u32, + job_id: task.partition.job_id.clone(), + stage_id: task.partition.stage_id as u32, + stage_attempt_num: task.stage_attempt_num as u32, + partition_id: task.partition.partition_id as u32, + launch_time: 0, + start_exec_time: 0, + end_exec_time: 0, + metrics: vec![], + status: Some(task_status::Status::Failed(failed_task)), + } +} diff --git a/ci/scripts/rust_toml_fmt.sh b/ci/scripts/rust_toml_fmt.sh index e297ef001..3ce50ace0 100755 --- a/ci/scripts/rust_toml_fmt.sh +++ b/ci/scripts/rust_toml_fmt.sh @@ -18,4 +18,4 @@ # under the License. set -ex -find . -mindepth 2 -name 'Cargo.toml' -exec cargo tomlfmt -p {} \; +find . -mindepth 2 -name 'Cargo.toml' -exec cargo tomlfmt -k -p {} \;
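
The test changes above converge on one pattern: build an in-memory BallistaCluster via test_cluster_context(), wrap it in a SchedulerState, and queue a job before its execution graph is submitted. Below is a minimal crate-internal sketch of that flow, not part of the patch: the LogicalPlanNode/PhysicalPlanNode codec parameters and the empty job name are assumptions carried over from the surrounding tests, and graph submission is elided.

// Sketch only: assumes the helpers introduced in this diff
// (test_cluster_context, BlackholeTaskLauncher, timestamp_millis).
use std::sync::Arc;

use ballista_core::error::Result;
use ballista_core::serde::BallistaCodec;
use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};

use crate::config::SchedulerConfig;
use crate::scheduler_server::timestamp_millis;
use crate::state::SchedulerState;
use crate::test_utils::{test_cluster_context, BlackholeTaskLauncher};

#[tokio::test]
async fn queue_job_before_submit() -> Result<()> {
    // The in-memory cluster supplies both the ClusterState and JobState used by
    // the executor, task, and session managers.
    let state: Arc<SchedulerState<LogicalPlanNode, PhysicalPlanNode>> =
        Arc::new(SchedulerState::new_with_task_launcher(
            test_cluster_context(),
            BallistaCodec::default(),
            "localhost:50050".to_owned(),
            SchedulerConfig::default(),
            Arc::new(BlackholeTaskLauncher::default()),
        ));

    // Jobs are now accepted into the queued state with a millisecond timestamp
    // before submit_job persists their execution graph.
    state
        .task_manager
        .queue_job("job-1", "", timestamp_millis())
        .await?;

    Ok(())
}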
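
The new From<&ExecutionGraph> for JobOverview impl replaces the stage-counting loop that get_jobs previously open-coded. A small sketch of how a caller inside the scheduler crate might use it; the graph argument is assumed to come from the active-job cache or from JobState::get_execution_graph.

use crate::state::execution_graph::ExecutionGraph;
use crate::state::task_manager::JobOverview;

// The conversion copies the job id/name, status, start/end times, and stage
// count, and treats only ExecutionStage::Successful stages as completed.
fn summarize(graph: &ExecutionGraph) -> JobOverview {
    graph.into()
}

This lets get_jobs return either the cached graph (graph.deref().into()) or a freshly decoded one ((&graph).into()) without duplicating the summary logic.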