Skip to content

Commit

Permalink
feat: merge runtime updates in main ecs branch (#1709)
Browse files Browse the repository at this point in the history
* feat: runtime healthcheck, start runtime on 0.0.0.0

running on unspecified ip was necessary for the runner to be able to reach the runtime when they are running in separate containers

* feat(proto): update runtime::get_client to work with

* misc(proto): get client takes u16 port

* feat: add health toggle to runtime

* feat: set runtime to unhealthy if it doesn't start within 60s

* feat: change runtime::get_client to take address

* feat: kill runtime if it doesn't become healthy in time

* feat: increase provisioning timeout duration
  • Loading branch information
oddgrd authored and iulianbarbu committed Apr 11, 2024
1 parent 595c7a5 commit 197a344
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 22 deletions.
5 changes: 5 additions & 0 deletions proto/runtime.proto
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ service Runtime {

// Channel to notify a service has been stopped
rpc SubscribeStop(SubscribeStopRequest) returns (stream SubscribeStopResponse);

rpc HealthCheck(Ping) returns (Pong);
}

message LoadRequest {
Expand Down Expand Up @@ -78,3 +80,6 @@ enum StopReason {
// Service crashed
Crash = 2;
}

message Ping {}
message Pong {}
63 changes: 63 additions & 0 deletions proto/src/generated/runtime.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 3 additions & 6 deletions proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,14 @@ mod _runtime_client {
use tracing::{info, trace};

pub type Client = runtime_client::RuntimeClient<
shuttle_common::claims::ClaimService<
shuttle_common::claims::InjectPropagation<tonic::transport::Channel>,
>,
shuttle_common::claims::InjectPropagation<tonic::transport::Channel>,
>;

/// Get a runtime client that is correctly configured
#[cfg(feature = "client")]
pub async fn get_client(port: &str) -> anyhow::Result<Client> {
pub async fn get_client(address: String) -> anyhow::Result<Client> {
info!("connecting runtime client");
let conn = Endpoint::new(format!("http://127.0.0.1:{port}"))
let conn = Endpoint::new(address)
.context("creating runtime client endpoint")?
.connect_timeout(Duration::from_secs(5));

Expand All @@ -177,7 +175,6 @@ mod _runtime_client {
.context("runtime control port did not open in time")?;

let runtime_service = tower::ServiceBuilder::new()
.layer(shuttle_common::claims::ClaimLayer)
.layer(shuttle_common::claims::InjectPropagationLayer)
.service(channel);

Expand Down
92 changes: 79 additions & 13 deletions runtime/src/alpha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@ use std::{
collections::BTreeMap,
iter::FromIterator,
net::{Ipv4Addr, SocketAddr},
ops::DerefMut,
ops::{Deref, DerefMut},
str::FromStr,
sync::Mutex,
sync::{Arc, Mutex},
time::Duration,
};

use anyhow::Context;
use async_trait::async_trait;
use core::future::Future;
use shuttle_common::{extract_propagation::ExtractPropagationLayer, secrets::Secret};
use shuttle_proto::runtime::{
runtime_server::{Runtime, RuntimeServer},
LoadRequest, LoadResponse, StartRequest, StartResponse, StopReason, StopRequest, StopResponse,
SubscribeStopRequest, SubscribeStopResponse,
use shuttle_proto::{
runtime::{
runtime_server::{Runtime, RuntimeServer},
LoadRequest, LoadResponse, StartRequest, StartResponse, StopReason, StopRequest,
StopResponse, SubscribeStopRequest, SubscribeStopResponse,
},
runtime::{Ping, Pong},
};
use shuttle_service::{ResourceFactory, Service};
use tokio::sync::{
Expand Down Expand Up @@ -84,23 +87,42 @@ pub async fn start(loader: impl Loader + Send + 'static, runner: impl Runner + S
}

// where to serve the gRPC control layer
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), args.port);
let addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), args.port);

let mut server_builder = Server::builder()
.http2_keepalive_interval(Some(Duration::from_secs(60)))
.layer(ExtractPropagationLayer);

// A channel we can use to kill the runtime if it does not become healthy in time.
let (tx, rx) = tokio::sync::oneshot::channel::<()>();

let router = {
let alpha = Alpha::new(loader, runner);
let alpha = Alpha::new(loader, runner, tx);

let svc = RuntimeServer::new(alpha);
server_builder.add_service(svc)
};

match router.serve(addr).await {
Ok(_) => {}
Err(e) => panic!("Error while serving address {addr}: {e}"),
};
tokio::select! {
res = router.serve(addr) => {
match res{
Ok(_) => {}
Err(e) => panic!("Error while serving address {addr}: {e}")
}
}
res = rx => {
match res{
Ok(_) => panic!("Received runtime kill signal"),
Err(e) => panic!("Receiver error: {e}")
}
}
}
}

pub enum State {
Unhealthy,
Loading,
Running,
}

pub struct Alpha<L, R> {
Expand All @@ -109,17 +131,23 @@ pub struct Alpha<L, R> {
kill_tx: Mutex<Option<oneshot::Sender<String>>>,
loader: Mutex<Option<L>>,
runner: Mutex<Option<R>>,
/// The current state of the runtime, which is used by the ECS task to determine if the runtime
/// is healthy.
state: Arc<Mutex<State>>,
runtime_kill_tx: Mutex<Option<tokio::sync::oneshot::Sender<()>>>,
}

impl<L, R> Alpha<L, R> {
pub fn new(loader: L, runner: R) -> Self {
pub fn new(loader: L, runner: R, runtime_kill_tx: tokio::sync::oneshot::Sender<()>) -> Self {
let (stopped_tx, _stopped_rx) = broadcast::channel(10);

Self {
stopped_tx,
kill_tx: Mutex::new(None),
loader: Mutex::new(Some(loader)),
runner: Mutex::new(Some(runner)),
state: Arc::new(Mutex::new(State::Unhealthy)),
runtime_kill_tx: Mutex::new(Some(runtime_kill_tx)),
}
}
}
Expand Down Expand Up @@ -223,6 +251,31 @@ where
}
};

println!("setting current state to healthy");
*self.state.lock().unwrap() = State::Loading;

let state = self.state.clone();
let runtime_kill_tx = self
.runtime_kill_tx
.lock()
.unwrap()
.deref_mut()
.take()
.unwrap();

// Ensure that the runtime is set to unhealthy if it doesn't reach the running state after
// it has sent a load response, so that the ECS task will fail.
tokio::spawn(async move {
// Note: The timeout is quite low as we are not actually provisioning resources after
// sending the load response.
tokio::time::sleep(Duration::from_secs(180)).await;
if !matches!(state.lock().unwrap().deref(), State::Running) {
println!("the runtime failed to enter the running state before timing out");

runtime_kill_tx.send(()).unwrap();
}
});

Ok(Response::new(LoadResponse {
success: true,
message: String::new(),
Expand Down Expand Up @@ -355,6 +408,8 @@ where
..Default::default()
};

*self.state.lock().unwrap() = State::Running;

Ok(Response::new(message))
}

Expand Down Expand Up @@ -398,4 +453,15 @@ where

Ok(Response::new(ReceiverStream::new(rx)))
}

async fn health_check(&self, _request: Request<Ping>) -> Result<Response<Pong>, Status> {
if matches!(self.state.lock().unwrap().deref(), State::Unhealthy) {
println!("runtime health check failed");
return Err(Status::unavailable(
"runtime has not reached a healthy state",
));
}

Ok(Response::new(Pong {}))
}
}
6 changes: 3 additions & 3 deletions service/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ pub async fn start(
runtime_executable: PathBuf,
project_path: &Path,
) -> anyhow::Result<(process::Child, runtime::Client)> {
let port = &port.to_string();
let args = vec!["--port", port];
let port_str = port.to_string();
let args = vec!["--port", &port_str];

info!(
args = %format!("{} {}", runtime_executable.display(), args.join(" ")),
Expand All @@ -30,7 +30,7 @@ pub async fn start(
.spawn()
.context("spawning runtime process")?;

let runtime_client = runtime::get_client(port).await?;
let runtime_client = runtime::get_client(format!("http://0.0.0.0:{port}")).await?;

Ok((runtime, runtime_client))
}

0 comments on commit 197a344

Please sign in to comment.