diff --git a/ballista/scheduler/scheduler_config_spec.toml b/ballista/scheduler/scheduler_config_spec.toml index daf7bf604..52d2ee5cc 100644 --- a/ballista/scheduler/scheduler_config_spec.toml +++ b/ballista/scheduler/scheduler_config_spec.toml @@ -24,6 +24,11 @@ conf_file_param = "config_file" name = "version" doc = "Print version of this executable" +[[param]] +name = "advertise_endpoint" +type = "String" +doc = "Route for proxying flight results via scheduler. Should be of the form 'IP:PORT'" + [[param]] abbr = "b" name = "config_backend" diff --git a/ballista/scheduler/src/flight_sql.rs b/ballista/scheduler/src/flight_sql.rs index 7caca105b..c9c8c6bbb 100644 --- a/ballista/scheduler/src/flight_sql.rs +++ b/ballista/scheduler/src/flight_sql.rs @@ -196,26 +196,41 @@ impl FlightSqlServiceImpl { ) -> Result, Status> { let mut fieps: Vec<_> = vec![]; for loc in completed.partition_location.iter() { - let (host, port) = if let Some(ref md) = loc.executor_meta { + let (exec_host, exec_port) = if let Some(ref md) = loc.executor_meta { (md.host.clone(), md.port) } else { Err(Status::internal( - "Invalid partition location, missing executor metadata".to_string(), + "Invalid partition location, missing executor metadata and advertise_endpoint flag is undefined.".to_string(), ))? }; + + let (host, port) = match &self.server.advertise_endpoint { + Some(endpoint) => { + let advertise_endpoint_vec: Vec<&str> = endpoint.split(":").collect(); + match advertise_endpoint_vec.as_slice() { + [host_ip, port] => { + (String::from(*host_ip), FromStr::from_str(*port).expect("Failed to parse port from advertise-endpoint.")) + } + _ => { + Err(Status::internal("advertise-endpoint flag has incorrect format. Expected IP:Port".to_string()))? + } + } + } + None => (exec_host.clone(), exec_port.clone()), + }; + let fetch = if let Some(ref id) = loc.partition_id { let fetch = protobuf::FetchPartition { job_id: id.job_id.clone(), stage_id: id.stage_id, partition_id: id.partition_id, path: loc.path.clone(), - host: host.clone(), - port, + // Use executor ip:port for routing to flight result + host: exec_host.clone(), + port: exec_port, }; protobuf::Action { - action_type: Some(protobuf::action::ActionType::FetchPartition( - fetch, - )), + action_type: Some(FetchPartition(fetch)), settings: vec![], } } else { @@ -227,7 +242,7 @@ impl FlightSqlServiceImpl { } else { Err(Status::internal("Error getting stats".to_string()))? } - let authority = format!("{}:{}", &host, &port); // TODO: my host & port + let authority = format!("{}:{}", &host, &port); let loc = Location { uri: format!("grpc+tcp://{}", authority), }; @@ -458,7 +473,12 @@ impl FlightSqlService for FlightSqlServiceImpl { let stream = flight_client .do_get(request) .await - .map_err(|e| Status::internal(format!("{:?}", e)))? + .map_err(|e| { + Status::internal(format!( + "Error from within flight_client.do_get(): {:?}\n", + e + )) + })? .into_inner(); return Ok(Response::new(Box::pin(stream))); } diff --git a/ballista/scheduler/src/main.rs b/ballista/scheduler/src/main.rs index 0a0c4fafa..b8eac568d 100644 --- a/ballista/scheduler/src/main.rs +++ b/ballista/scheduler/src/main.rs @@ -45,7 +45,7 @@ use ballista_scheduler::state::backend::{StateBackend, StateBackendClient}; use ballista_core::config::TaskSchedulingPolicy; use ballista_core::serde::BallistaCodec; -use ballista_core::utils::default_session_builder; + use log::info; #[macro_use] @@ -75,6 +75,7 @@ async fn start_server( scheduling_policy: TaskSchedulingPolicy, slots_policy: SlotsPolicy, event_loop_buffer_size: usize, + advertise_endpoint: Option, ) -> Result<()> { info!( "Ballista v{} Scheduler listening on {:?}", @@ -85,6 +86,7 @@ async fn start_server( "Starting Scheduler grpc server with task scheduling policy of {:?}", scheduling_policy ); + let mut scheduler_server: SchedulerServer = match scheduling_policy { TaskSchedulingPolicy::PushStaged => SchedulerServer::new_with_policy( @@ -93,14 +95,15 @@ async fn start_server( scheduling_policy, slots_policy, BallistaCodec::default(), - default_session_builder, event_loop_buffer_size, + advertise_endpoint, ), _ => SchedulerServer::new( scheduler_name, config_backend.clone(), BallistaCodec::default(), event_loop_buffer_size, + advertise_endpoint, ), }; @@ -255,6 +258,7 @@ async fn main() -> Result<()> { scheduling_policy, slots_policy, event_loop_buffer_size, + opt.advertise_endpoint, ) .await?; Ok(()) diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 7b5c96bdd..d92f8c34e 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -577,6 +577,7 @@ mod test { state_storage.clone(), BallistaCodec::default(), 10000, + None, ); scheduler.init().await?; let exec_meta = ExecutorRegistration { @@ -663,6 +664,7 @@ mod test { state_storage.clone(), BallistaCodec::default(), 10000, + None, ); scheduler.init().await?; @@ -743,6 +745,7 @@ mod test { state_storage.clone(), BallistaCodec::default(), 10000, + None, ); scheduler.init().await?; diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index 44fed7db7..a74e6fe6c 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -57,6 +57,7 @@ pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState; #[derive(Clone)] pub struct SchedulerServer { pub scheduler_name: String, + pub advertise_endpoint: Option, pub(crate) state: Arc>, pub start_time: u128, policy: TaskSchedulingPolicy, @@ -69,15 +70,22 @@ impl SchedulerServer, codec: BallistaCodec, event_loop_buffer_size: usize, + advertise_endpoint: Option, ) -> Self { - SchedulerServer::new_with_policy( - scheduler_name, + let state = Arc::new(SchedulerState::new( config, - TaskSchedulingPolicy::PullStaged, - SlotsPolicy::Bias, - codec, default_session_builder, + codec, + scheduler_name.clone(), + SlotsPolicy::Bias, + )); + + SchedulerServer::new_with_state( + scheduler_name, + TaskSchedulingPolicy::PullStaged, + state, event_loop_buffer_size, + advertise_endpoint, ) } @@ -87,15 +95,22 @@ impl SchedulerServer, session_builder: SessionBuilder, event_loop_buffer_size: usize, + advertise_endpoint: Option, ) -> Self { - SchedulerServer::new_with_policy( - scheduler_name, + let state = Arc::new(SchedulerState::new( config, - TaskSchedulingPolicy::PullStaged, - SlotsPolicy::Bias, - codec, session_builder, + codec, + scheduler_name.clone(), + SlotsPolicy::Bias, + )); + + SchedulerServer::new_with_state( + scheduler_name, + TaskSchedulingPolicy::PullStaged, + state, event_loop_buffer_size, + advertise_endpoint, ) } @@ -105,12 +120,12 @@ impl SchedulerServer, - session_builder: SessionBuilder, event_loop_buffer_size: usize, + advertise_endpoint: Option, ) -> Self { let state = Arc::new(SchedulerState::new( config, - session_builder, + default_session_builder, codec, scheduler_name.clone(), slots_policy, @@ -121,6 +136,7 @@ impl SchedulerServer SchedulerServer>, event_loop_buffer_size: usize, + advertise_endpoint: Option, ) -> Self { let query_stage_scheduler = Arc::new(QueryStageScheduler::new(state.clone(), policy)); @@ -146,6 +163,7 @@ impl SchedulerServer Result { Arc::new(client), BallistaCodec::default(), 10000, + None, ); scheduler_server.init().await?; let server = SchedulerGrpcServer::new(scheduler_server.clone());