diff --git a/ballista/scheduler/scheduler_config_spec.toml b/ballista/scheduler/scheduler_config_spec.toml index daf7bf604..148e977c4 100644 --- a/ballista/scheduler/scheduler_config_spec.toml +++ b/ballista/scheduler/scheduler_config_spec.toml @@ -24,6 +24,11 @@ conf_file_param = "config_file" name = "version" doc = "Print version of this executable" +[[param]] +name = "advertise_host" +type = "String" +doc = "Route for proxying flight results via scheduler. Should be of the form 'IP:PORT'" + [[param]] abbr = "b" name = "config_backend" diff --git a/ballista/scheduler/src/flight_sql.rs b/ballista/scheduler/src/flight_sql.rs index 7caca105b..cf6e2d7d4 100644 --- a/ballista/scheduler/src/flight_sql.rs +++ b/ballista/scheduler/src/flight_sql.rs @@ -196,26 +196,41 @@ impl FlightSqlServiceImpl { ) -> Result, Status> { let mut fieps: Vec<_> = vec![]; for loc in completed.partition_location.iter() { - let (host, port) = if let Some(ref md) = loc.executor_meta { + + let (exec_host, exec_port) = if let Some(ref md) = loc.executor_meta { (md.host.clone(), md.port) } else { Err(Status::internal( - "Invalid partition location, missing executor metadata".to_string(), + "Invalid partition location, missing executor metadata and advertise_host flag is undefined.".to_string(), ))? }; + + let (host, port) = match self.server.advertise_host { + Some(_) => { + let advertise_host_flag: Vec<&str> = self + .server + .advertise_host + .as_ref() + .unwrap() + .split(":") + .collect(); + (advertise_host_flag[0].to_string(), advertise_host_flag[1].parse().unwrap()) + } + None => (exec_host.clone(), exec_port.clone()) + }; + let fetch = if let Some(ref id) = loc.partition_id { let fetch = protobuf::FetchPartition { job_id: id.job_id.clone(), stage_id: id.stage_id, partition_id: id.partition_id, path: loc.path.clone(), - host: host.clone(), - port, + // Use executor ip:port for routing to flight result + host: exec_host.clone(), + port: exec_port, }; protobuf::Action { - action_type: Some(protobuf::action::ActionType::FetchPartition( - fetch, - )), + action_type: Some(FetchPartition(fetch)), settings: vec![], } } else { @@ -227,7 +242,7 @@ impl FlightSqlServiceImpl { } else { Err(Status::internal("Error getting stats".to_string()))? } - let authority = format!("{}:{}", &host, &port); // TODO: my host & port + let authority = format!("{}:{}", &host, &port); let loc = Location { uri: format!("grpc+tcp://{}", authority), }; @@ -458,7 +473,7 @@ impl FlightSqlService for FlightSqlServiceImpl { let stream = flight_client .do_get(request) .await - .map_err(|e| Status::internal(format!("{:?}", e)))? + .map_err(|e| Status::internal(format!("Error from within flight_client.do_get(): {:?}\n", e)))? .into_inner(); return Ok(Response::new(Box::pin(stream))); } diff --git a/ballista/scheduler/src/main.rs b/ballista/scheduler/src/main.rs index 0a0c4fafa..9b991bcdd 100644 --- a/ballista/scheduler/src/main.rs +++ b/ballista/scheduler/src/main.rs @@ -75,6 +75,7 @@ async fn start_server( scheduling_policy: TaskSchedulingPolicy, slots_policy: SlotsPolicy, event_loop_buffer_size: usize, + advertise_host: Option, ) -> Result<()> { info!( "Ballista v{} Scheduler listening on {:?}", @@ -95,12 +96,14 @@ async fn start_server( BallistaCodec::default(), default_session_builder, event_loop_buffer_size, + advertise_host, ), _ => SchedulerServer::new( scheduler_name, config_backend.clone(), BallistaCodec::default(), event_loop_buffer_size, + advertise_host, ), }; @@ -255,6 +258,7 @@ async fn main() -> Result<()> { scheduling_policy, slots_policy, event_loop_buffer_size, + opt.advertise_host, ) .await?; Ok(()) diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 7b5c96bdd..d92f8c34e 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -577,6 +577,7 @@ mod test { state_storage.clone(), BallistaCodec::default(), 10000, + None, ); scheduler.init().await?; let exec_meta = ExecutorRegistration { @@ -663,6 +664,7 @@ mod test { state_storage.clone(), BallistaCodec::default(), 10000, + None, ); scheduler.init().await?; @@ -743,6 +745,7 @@ mod test { state_storage.clone(), BallistaCodec::default(), 10000, + None, ); scheduler.init().await?; diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index 44fed7db7..e5ecb58ca 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -57,6 +57,7 @@ pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState; #[derive(Clone)] pub struct SchedulerServer { pub scheduler_name: String, + pub advertise_host: Option, pub(crate) state: Arc>, pub start_time: u128, policy: TaskSchedulingPolicy, @@ -69,6 +70,7 @@ impl SchedulerServer, codec: BallistaCodec, event_loop_buffer_size: usize, + advertise_host: Option, ) -> Self { SchedulerServer::new_with_policy( scheduler_name, @@ -78,6 +80,7 @@ impl SchedulerServer SchedulerServer, session_builder: SessionBuilder, event_loop_buffer_size: usize, + advertise_host: Option, ) -> Self { SchedulerServer::new_with_policy( scheduler_name, @@ -96,6 +100,7 @@ impl SchedulerServer SchedulerServer, session_builder: SessionBuilder, event_loop_buffer_size: usize, + advertise_host: Option, ) -> Self { let state = Arc::new(SchedulerState::new( config, @@ -121,6 +127,7 @@ impl SchedulerServer SchedulerServer>, event_loop_buffer_size: usize, + advertise_host: Option, ) -> Self { let query_stage_scheduler = Arc::new(QueryStageScheduler::new(state.clone(), policy)); @@ -146,6 +154,7 @@ impl SchedulerServer Result { Arc::new(client), BallistaCodec::default(), 10000, + None, ); scheduler_server.init().await?; let server = SchedulerGrpcServer::new(scheduler_server.clone());