From 5b5eb69c7754285b348b7be4163f38bee2464547 Mon Sep 17 00:00:00 2001 From: Pieter Date: Tue, 6 Feb 2024 10:33:52 +0000 Subject: [PATCH] refactor: uniform client wrappers (#1614) * refactor: better client wrapper for provisioner * refactor: better client wrapper for logger * refactor: better client wrapper for builder * refactor: better client wrapper for runtime * refactor: trimmed unused dependencies * refactor: fix optional deps * refactor: more optional dependency fixes --- Cargo.lock | 2 - cargo-shuttle/src/lib.rs | 36 ++---- common-tests/src/builder.rs | 27 +--- common-tests/src/logger.rs | 31 +---- deployer/src/args.rs | 5 +- deployer/src/deployment/mod.rs | 52 ++------ deployer/src/deployment/queue.rs | 11 +- deployer/src/deployment/run.rs | 11 +- deployer/src/deployment/state_change_layer.rs | 12 +- deployer/src/lib.rs | 16 +-- deployer/src/main.rs | 17 +-- deployer/src/persistence/mod.rs | 43 ++----- deployer/src/runtime_manager.rs | 40 ++---- proto/Cargo.toml | 7 +- proto/src/lib.rs | 119 ++++++++++++++++++ runtime/Cargo.toml | 1 - runtime/src/alpha/mod.rs | 20 +-- runtime/src/provisioner_factory.rs | 16 +-- runtime/tests/integration/helpers.rs | 10 +- service/Cargo.toml | 3 +- service/src/runner.rs | 45 +------ 21 files changed, 197 insertions(+), 327 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fc2111a13..89e15ab89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6102,7 +6102,6 @@ dependencies = [ "tokio", "tokio-stream", "tonic 0.10.2", - "tower", "tracing-subscriber", "uuid", "wasi-common", @@ -6125,7 +6124,6 @@ dependencies = [ "thiserror", "tokio", "toml 0.8.8", - "tower", "tracing", ] diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs index 7a9ae039f..56db9c408 100644 --- a/cargo-shuttle/src/lib.rs +++ b/cargo-shuttle/src/lib.rs @@ -19,7 +19,6 @@ use args::{ConfirmationArgs, GenerateCommand}; use clap_mangen::Man; use shuttle_common::{ - claims::{ClaimService, InjectPropagation}, constants::{ API_URL_DEFAULT, DEFAULT_IDLE_MINUTES, EXECUTABLE_DIRNAME, SHUTTLE_CLI_DOCS_URL, SHUTTLE_GH_ISSUE_URL, SHUTTLE_IDLE_DOCS_URL, SHUTTLE_INSTALL_DOCS_URL, SHUTTLE_LOGIN_URL, @@ -37,9 +36,8 @@ use shuttle_common::{ }, resource, semvers_are_compatible, ApiKey, LogItem, VersionInfo, }; -use shuttle_proto::runtime::{ - runtime_client::RuntimeClient, LoadRequest, StartRequest, StopRequest, -}; +use shuttle_proto::runtime; +use shuttle_proto::runtime::{LoadRequest, StartRequest, StopRequest}; use shuttle_service::runner; use shuttle_service::{ builder::{build_workspace, BuiltService}, @@ -67,7 +65,6 @@ use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Child; use tokio::task::JoinHandle; use tokio::time::{sleep, Duration}; -use tonic::transport::Channel; use tonic::Status; use tracing::{debug, error, trace, warn}; use uuid::Uuid; @@ -938,12 +935,7 @@ impl Shuttle { provisioner_server: &JoinHandle>, idx: u16, provisioner_port: u16, - ) -> Result< - Option<( - Child, - RuntimeClient>>, - )>, - > { + ) -> Result> { let crate_directory = service.crate_directory(); let secrets_path = if crate_directory.join("Secrets.dev.toml").exists() { crate_directory.join("Secrets.dev.toml") @@ -1160,7 +1152,7 @@ impl Shuttle { async fn stop_runtime( runtime: &mut Child, - runtime_client: &mut RuntimeClient>>, + runtime_client: &mut runtime::Client, ) -> Result<(), Status> { let stop_request = StopRequest {}; trace!(?stop_request, "stopping service"); @@ -1178,14 +1170,8 @@ impl Shuttle { } async fn add_runtime_info( - runtime: Option<( - Child, - RuntimeClient>>, - )>, - existing_runtimes: &mut Vec<( - Child, - RuntimeClient>>, - )>, + runtime: Option<(Child, runtime::Client)>, + existing_runtimes: &mut Vec<(Child, runtime::Client)>, extra_servers: &[&JoinHandle>], ) -> Result<(), Status> { match runtime { @@ -1269,10 +1255,7 @@ impl Shuttle { .expect("Can not get the SIGINT signal receptor"); // Start all the services. - let mut runtimes: Vec<( - Child, - RuntimeClient>>, - )> = Vec::new(); + let mut runtimes: Vec<(Child, runtime::Client)> = Vec::new(); Shuttle::find_available_port(&mut run_args, services.len()); @@ -1423,10 +1406,7 @@ impl Shuttle { let (provisioner_server, provisioner_port) = Shuttle::setup_local_provisioner().await?; // Start all the services. - let mut runtimes: Vec<( - Child, - RuntimeClient>>, - )> = Vec::new(); + let mut runtimes: Vec<(Child, runtime::Client)> = Vec::new(); Shuttle::find_available_port(&mut run_args, services.len()); diff --git a/common-tests/src/builder.rs b/common-tests/src/builder.rs index 963eeafed..60725faa0 100644 --- a/common-tests/src/builder.rs +++ b/common-tests/src/builder.rs @@ -4,21 +4,13 @@ use std::{ }; use portpicker::pick_unused_port; -use shuttle_common::claims::{ClaimLayer, InjectPropagationLayer}; use shuttle_proto::builder::{ - builder_client::BuilderClient, + self, builder_server::{Builder, BuilderServer}, }; -use tonic::transport::{Endpoint, Server}; -use tower::ServiceBuilder; +use tonic::transport::Server; -pub async fn get_mocked_builder_client( - builder: impl Builder, -) -> BuilderClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, -> { +pub async fn get_mocked_builder_client(builder: impl Builder) -> builder::Client { let builder_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), pick_unused_port().unwrap()); let builder_uri = format!("http://{}", builder_addr); tokio::spawn(async move { @@ -31,16 +23,5 @@ pub async fn get_mocked_builder_client( // Wait for the builder server to start before creating a client. tokio::time::sleep(Duration::from_millis(200)).await; - let channel = Endpoint::try_from(builder_uri.to_string()) - .unwrap() - .connect() - .await - .expect("failed to connect to builder"); - - let channel = ServiceBuilder::new() - .layer(ClaimLayer) - .layer(InjectPropagationLayer) - .service(channel); - - BuilderClient::new(channel) + builder::get_client(builder_uri.parse().unwrap()).await } diff --git a/common-tests/src/logger.rs b/common-tests/src/logger.rs index 5e7c7453d..e8f8a65ba 100644 --- a/common-tests/src/logger.rs +++ b/common-tests/src/logger.rs @@ -4,20 +4,14 @@ use std::{ }; use portpicker::pick_unused_port; -use shuttle_common::claims::{ClaimLayer, InjectPropagationLayer}; use shuttle_proto::logger::{ - logger_client::LoggerClient, + self, logger_server::{Logger, LoggerServer}, LogLine, LogsRequest, LogsResponse, StoreLogsRequest, StoreLogsResponse, }; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; -use tonic::{ - async_trait, - transport::{Endpoint, Server}, - Request, Response, Status, -}; -use tower::ServiceBuilder; +use tonic::{async_trait, transport::Server, Request, Response, Status}; pub struct MockedLogger; @@ -47,13 +41,7 @@ impl Logger for MockedLogger { } } -pub async fn get_mocked_logger_client( - logger: impl Logger, -) -> LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, -> { +pub async fn get_mocked_logger_client(logger: impl Logger) -> logger::Client { let logger_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), pick_unused_port().unwrap()); let logger_uri = format!("http://{}", logger_addr); tokio::spawn(async move { @@ -66,16 +54,5 @@ pub async fn get_mocked_logger_client( // Wait for the logger server to start before creating a client. tokio::time::sleep(Duration::from_millis(200)).await; - let channel = Endpoint::try_from(logger_uri.to_string()) - .unwrap() - .connect() - .await - .expect("failed to connect to logger"); - - let channel = ServiceBuilder::new() - .layer(ClaimLayer) - .layer(InjectPropagationLayer) - .service(channel); - - LoggerClient::new(channel) + logger::get_client(logger_uri.parse().unwrap()).await } diff --git a/deployer/src/args.rs b/deployer/src/args.rs index 9a345495f..6dd82aa70 100644 --- a/deployer/src/args.rs +++ b/deployer/src/args.rs @@ -4,7 +4,6 @@ use clap::Parser; use fqdn::FQDN; use hyper::Uri; use shuttle_common::models::project::ProjectName; -use tonic::transport::Endpoint; /// Program to handle the deploys for a single project /// Handling includes, building, testing, and running each service @@ -21,7 +20,7 @@ pub struct Args { /// Address to connect to the logger service #[clap(long, default_value = "http://logger:8000")] - pub logger_uri: Endpoint, + pub logger_uri: Uri, /// FQDN where the proxy can be reached at #[clap(long)] @@ -61,7 +60,7 @@ pub struct Args { /// Address to reach the builder service at #[clap(long, default_value = "http://builder:8000")] - pub builder_uri: Endpoint, + pub builder_uri: Uri, /// Uri to folder to store all artifacts #[clap(long, default_value = "/tmp")] diff --git a/deployer/src/deployment/mod.rs b/deployer/src/deployment/mod.rs index 3ad83b2f6..f8ad2b388 100644 --- a/deployer/src/deployment/mod.rs +++ b/deployer/src/deployment/mod.rs @@ -4,7 +4,7 @@ use std::{ }; use shuttle_common::log::LogRecorder; -use shuttle_proto::{builder::builder_client::BuilderClient, logger::logger_client::LoggerClient}; +use shuttle_proto::{builder, logger}; use tokio::{ sync::{mpsc, Mutex}, task::JoinSet, @@ -31,26 +31,14 @@ const RUN_BUFFER_SIZE: usize = 100; pub struct DeploymentManagerBuilder { build_log_recorder: Option, - logs_fetcher: Option< - LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + logs_fetcher: Option, active_deployment_getter: Option, artifacts_path: Option, runtime_manager: Option>>, deployment_updater: Option, resource_manager: Option, queue_client: Option, - builder_client: Option< - BuilderClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + builder_client: Option, } impl DeploymentManagerBuilder @@ -67,29 +55,13 @@ where self } - pub fn log_fetcher( - mut self, - logs_fetcher: LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - ) -> Self { + pub fn log_fetcher(mut self, logs_fetcher: logger::Client) -> Self { self.logs_fetcher = Some(logs_fetcher); self } - pub fn builder_client( - mut self, - builder_client: Option< - BuilderClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, - ) -> Self { + pub fn builder_client(mut self, builder_client: Option) -> Self { self.builder_client = builder_client; self @@ -195,11 +167,7 @@ pub struct DeploymentManager { queue_send: QueueSender, run_send: RunSender, runtime_manager: Arc>, - logs_fetcher: LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, + logs_fetcher: logger::Client, _join_set: Arc>>, builds_path: PathBuf, } @@ -259,13 +227,7 @@ impl DeploymentManager { self.builds_path.as_path() } - pub fn logs_fetcher( - &self, - ) -> &LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - > { + pub fn logs_fetcher(&self) -> &logger::Client { &self.logs_fetcher } } diff --git a/deployer/src/deployment/queue.rs b/deployer/src/deployment/queue.rs index 117225fec..71588fc8b 100644 --- a/deployer/src/deployment/queue.rs +++ b/deployer/src/deployment/queue.rs @@ -14,8 +14,7 @@ use shuttle_common::{ log::LogRecorder, LogItem, }; -use shuttle_proto::builder::builder_client::BuilderClient; -use shuttle_proto::builder::BuildRequest; +use shuttle_proto::builder::{self, BuildRequest}; use shuttle_service::builder::{build_workspace, BuiltService}; use tar::Archive; use tokio::{ @@ -42,13 +41,7 @@ pub async fn task( deployment_updater: impl DeploymentUpdater, log_recorder: impl LogRecorder, queue_client: impl BuildQueueClient, - builder_client: Option< - BuilderClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + builder_client: Option, builds_path: PathBuf, ) { info!("Queue task started"); diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs index 3b6ac2902..0c9918fbd 100644 --- a/deployer/src/deployment/run.rs +++ b/deployer/src/deployment/run.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use opentelemetry::global; use portpicker::pick_unused_port; use shuttle_common::{ - claims::{Claim, ClaimService, InjectPropagation}, + claims::Claim, constants::EXECUTABLE_DIRNAME, deployment::{ DEPLOYER_END_MSG_COMPLETED, DEPLOYER_END_MSG_CRASHED, DEPLOYER_END_MSG_STARTUP_ERR, @@ -20,15 +20,14 @@ use shuttle_common::{ use shuttle_proto::{ resource_recorder::record_request, runtime::{ - runtime_client::RuntimeClient, LoadRequest, StartRequest, StopReason, SubscribeStopRequest, - SubscribeStopResponse, + self, LoadRequest, StartRequest, StopReason, SubscribeStopRequest, SubscribeStopResponse, }, }; use tokio::{ sync::Mutex, task::{JoinHandle, JoinSet}, }; -use tonic::{transport::Channel, Code}; +use tonic::Code; use tracing::{debug, debug_span, error, info, instrument, warn, Instrument}; use tracing_opentelemetry::OpenTelemetrySpanExt; use ulid::Ulid; @@ -313,7 +312,7 @@ async fn load( service_id: Ulid, executable_path: PathBuf, mut resource_manager: impl ResourceManager, - mut runtime_client: RuntimeClient>>, + mut runtime_client: runtime::Client, claim: Claim, mut secrets: HashMap, ) -> Result<()> { @@ -414,7 +413,7 @@ async fn load( async fn run( id: Uuid, service_name: String, - mut runtime_client: RuntimeClient>>, + mut runtime_client: runtime::Client, address: SocketAddr, deployment_updater: impl DeploymentUpdater, cleanup: impl FnOnce(Option) + Send + 'static, diff --git a/deployer/src/deployment/state_change_layer.rs b/deployer/src/deployment/state_change_layer.rs index e156a35d6..697b14a4e 100644 --- a/deployer/src/deployment/state_change_layer.rs +++ b/deployer/src/deployment/state_change_layer.rs @@ -153,8 +153,8 @@ mod tests { use shuttle_proto::{ builder::{builder_server::Builder, BuildRequest, BuildResponse}, logger::{ - logger_client::LoggerClient, logger_server::Logger, Batcher, LogLine, LogsRequest, - LogsResponse, StoreLogsRequest, StoreLogsResponse, + self, logger_server::Logger, Batcher, LogLine, LogsRequest, LogsResponse, + StoreLogsRequest, StoreLogsResponse, }, provisioner::{ provisioner_server::{Provisioner, ProvisionerServer}, @@ -360,13 +360,7 @@ mod tests { } async fn get_runtime_manager( - logger_client: Batcher< - LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + logger_client: Batcher, ) -> Arc> { let provisioner_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), pick_unused_port().unwrap()); diff --git a/deployer/src/lib.rs b/deployer/src/lib.rs index 1b0d8070c..91cf10a55 100644 --- a/deployer/src/lib.rs +++ b/deployer/src/lib.rs @@ -9,7 +9,7 @@ pub use persistence::Persistence; use proxy::AddressGetter; pub use runtime_manager::RuntimeManager; use shuttle_common::log::LogRecorder; -use shuttle_proto::{builder::builder_client::BuilderClient, logger::logger_client::LoggerClient}; +use shuttle_proto::{builder, logger}; use tokio::sync::Mutex; use tracing::{error, info}; use ulid::Ulid; @@ -33,18 +33,8 @@ pub async fn start( persistence: Persistence, runtime_manager: Arc>, log_recorder: impl LogRecorder, - log_fetcher: LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - builder_client: Option< - BuilderClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + log_fetcher: logger::Client, + builder_client: Option, args: Args, ) { // when _set is dropped once axum exits, the deployment tasks will be aborted. diff --git a/deployer/src/main.rs b/deployer/src/main.rs index bbfdb5e53..a282609d0 100644 --- a/deployer/src/main.rs +++ b/deployer/src/main.rs @@ -3,16 +3,14 @@ use std::process::exit; use clap::Parser; use shuttle_common::{ backends::trace::setup_tracing, - claims::{ClaimLayer, InjectPropagationLayer}, log::{Backend, DeploymentLogLayer}, }; use shuttle_deployer::{start, start_proxy, Args, Persistence, RuntimeManager, StateChangeLayer}; use shuttle_proto::{ // builder::builder_client::BuilderClient, - logger::{logger_client::LoggerClient, Batcher}, + logger::{self, Batcher}, }; use tokio::select; -use tower::ServiceBuilder; use tracing::{error, trace}; use tracing_subscriber::prelude::*; use ulid::Ulid; @@ -28,22 +26,13 @@ async fn main() { let (persistence, _) = Persistence::new( &args.state, args.resource_recorder.clone(), - &args.provisioner_address, + args.provisioner_address.clone(), Ulid::from_string(args.project_id.as_str()) .expect("to get a valid ULID for project_id arg"), ) .await; - let channel = ServiceBuilder::new() - .layer(ClaimLayer) - .layer(InjectPropagationLayer) - .service( - args.logger_uri - .connect() - .await - .expect("failed to connect to logger"), - ); - let logger_client = LoggerClient::new(channel); + let logger_client = logger::get_client(args.logger_uri.clone()).await; let logger_batcher = Batcher::wrap(logger_client.clone()); let builder_client = None; diff --git a/deployer/src/persistence/mod.rs b/deployer/src/persistence/mod.rs index a46edb784..65709e897 100644 --- a/deployer/src/persistence/mod.rs +++ b/deployer/src/persistence/mod.rs @@ -5,12 +5,9 @@ use std::str::FromStr; use chrono::Utc; use error::{Error, Result}; use hyper::Uri; -use shuttle_common::{ - claims::{Claim, ClaimLayer, InjectPropagationLayer}, - resource::Type, -}; +use shuttle_common::{claims::Claim, resource::Type}; use shuttle_proto::{ - provisioner::{provisioner_client::ProvisionerClient, DatabaseRequest}, + provisioner::{self, DatabaseRequest}, resource_recorder::{ self, record_request, RecordRequest, ResourceIds, ResourceResponse, ResourcesResponse, ResultResponse, ServiceResourcesRequest, @@ -22,8 +19,7 @@ use sqlx::{ QueryBuilder, }; use tokio::task::JoinHandle; -use tonic::{transport::Endpoint, Request}; -use tower::ServiceBuilder; +use tonic::Request; use tracing::{error, info, instrument, trace}; use ulid::Ulid; use uuid::Uuid; @@ -55,13 +51,7 @@ pub struct Persistence { pool: SqlitePool, state_send: tokio::sync::mpsc::UnboundedSender, resource_recorder_client: Option, - provisioner_client: Option< - ProvisionerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + provisioner_client: Option, project_id: Ulid, } @@ -73,7 +63,7 @@ impl Persistence { pub async fn new( path: &str, resource_recorder_uri: Uri, - provisioner_address: &Uri, + provisioner_uri: Uri, project_id: Ulid, ) -> (Self, JoinHandle<()>) { if !Path::new(path).exists() { @@ -103,13 +93,7 @@ impl Persistence { let pool = SqlitePool::connect_with(sqlite_options).await.unwrap(); - Self::configure( - pool, - resource_recorder_uri, - provisioner_address.to_string(), - project_id, - ) - .await + Self::configure(pool, resource_recorder_uri, provisioner_uri, project_id).await } #[cfg(test)] @@ -139,22 +123,11 @@ impl Persistence { async fn configure( pool: SqlitePool, resource_recorder_uri: Uri, - provisioner_address: String, + provisioner_uri: Uri, project_id: Ulid, ) -> (Self, JoinHandle<()>) { - let channel = Endpoint::from_shared(provisioner_address) - .expect("to have a valid string endpoint for the provisioner") - .connect() - .await - .expect("failed to connect to provisioner"); - - let provisioner_service = ServiceBuilder::new() - .layer(ClaimLayer) - .layer(InjectPropagationLayer) - .service(channel); - let resource_recorder_client = resource_recorder::get_client(resource_recorder_uri).await; - let provisioner_client = ProvisionerClient::new(provisioner_service); + let provisioner_client = provisioner::get_client(provisioner_uri).await; let (state_send, handle) = Self::from_pool(pool.clone()).await; diff --git a/deployer/src/runtime_manager.rs b/deployer/src/runtime_manager.rs index 8dcd5bb63..e9e6ff108 100644 --- a/deployer/src/runtime_manager.rs +++ b/deployer/src/runtime_manager.rs @@ -7,33 +7,19 @@ use std::{ use anyhow::Context; use chrono::Utc; use prost_types::Timestamp; -use shuttle_common::{ - claims::{ClaimService, InjectPropagation}, - log::Backend, -}; +use shuttle_common::log::Backend; use shuttle_proto::{ - logger::{logger_client::LoggerClient, Batcher, LogItem, LogLine}, - runtime::{runtime_client::RuntimeClient, StopRequest}, + logger::{self, Batcher, LogItem, LogLine}, + runtime::{self, StopRequest}, }; use shuttle_service::{runner, Environment}; use tokio::{io::AsyncBufReadExt, io::BufReader, process, sync::Mutex}; -use tonic::transport::Channel; use tracing::{debug, error, info, trace, warn}; use uuid::Uuid; const MANIFEST_DIR: &str = env!("CARGO_MANIFEST_DIR"); -type Runtimes = Arc< - std::sync::Mutex< - HashMap< - Uuid, - ( - process::Child, - RuntimeClient>>, - ), - >, - >, ->; +type Runtimes = Arc>>; /// Manager that can start up multiple runtimes. This is needed so that two runtimes can be up when a new deployment is made: /// One runtime for the new deployment being loaded; another for the currently active deployment @@ -41,26 +27,14 @@ type Runtimes = Arc< pub struct RuntimeManager { runtimes: Runtimes, provisioner_address: String, - logger_client: Batcher< - LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + logger_client: Batcher, auth_uri: Option, } impl RuntimeManager { pub fn new( provisioner_address: String, - logger_client: Batcher< - LoggerClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, - >, - >, + logger_client: Batcher, auth_uri: Option, ) -> Arc> { Arc::new(Mutex::new(Self { @@ -77,7 +51,7 @@ impl RuntimeManager { project_path: &Path, service_name: String, alpha_runtime_path: Option, - ) -> anyhow::Result>>> { + ) -> anyhow::Result { trace!("making new client"); let port = portpicker::pick_unused_port().context("failed to find available port")?; diff --git a/proto/Cargo.toml b/proto/Cargo.toml index 3a12addd7..a1f1d6ded 100644 --- a/proto/Cargo.toml +++ b/proto/Cargo.toml @@ -30,14 +30,15 @@ portpicker = { workspace = true } default = [] test-utils = ["portpicker"] -builder = [] +builder = ["http", "tower"] logger = [ "shuttle-common/service", "chrono", + "http", "tracing", "tokio/macros", "tokio/time", ] -provisioner = [] +provisioner = ["http", "tower"] resource-recorder = ["anyhow", "async-trait", "http", "serde_json", "shuttle-common/backend", "tower"] -runtime = [] +runtime = ["anyhow", "tokio", "tower", "tracing"] diff --git a/proto/src/lib.rs b/proto/src/lib.rs index ad6287204..879dc6f32 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -12,13 +12,37 @@ pub mod test_utils; pub mod provisioner { use std::fmt::Display; + use http::Uri; use shuttle_common::{ database::{self, AwsRdsEngine, SharedEngine}, DatabaseInfo, }; + use self::provisioner_client::ProvisionerClient; + pub use super::generated::provisioner::*; + pub type Client = ProvisionerClient< + shuttle_common::claims::ClaimService< + shuttle_common::claims::InjectPropagation, + >, + >; + + /// Get a provisioner client that is correctly configured for all services + pub async fn get_client(provisioner_uri: Uri) -> Client { + let channel = tonic::transport::Endpoint::from(provisioner_uri) + .connect() + .await + .expect("failed to connect to provisioner"); + + let provisioner_service = tower::ServiceBuilder::new() + .layer(shuttle_common::claims::ClaimLayer) + .layer(shuttle_common::claims::InjectPropagationLayer) + .service(channel); + + ProvisionerClient::new(provisioner_service) + } + impl From for DatabaseInfo { fn from(response: DatabaseResponse) -> Self { DatabaseInfo::new( @@ -103,7 +127,53 @@ pub mod provisioner { #[cfg(feature = "runtime")] pub mod runtime { + use std::time::Duration; + + use anyhow::Context; + use tonic::transport::Endpoint; + use tracing::{info, trace}; + + use self::runtime_client::RuntimeClient; + pub use super::generated::runtime::*; + + pub type Client = RuntimeClient< + shuttle_common::claims::ClaimService< + shuttle_common::claims::InjectPropagation, + >, + >; + + /// Get a runtime client that is correctly configured + pub async fn get_client(port: &str) -> anyhow::Result { + info!("connecting runtime client"); + let conn = Endpoint::new(format!("http://127.0.0.1:{port}")) + .context("creating runtime client endpoint")? + .connect_timeout(Duration::from_secs(5)); + + // Wait for the spawned process to open the control port. + // Connecting instantly does not give it enough time. + let channel = tokio::time::timeout(Duration::from_millis(7000), async move { + let mut ms = 5; + loop { + if let Ok(channel) = conn.connect().await { + break channel; + } + trace!("waiting for runtime control port to open"); + // exponential backoff + tokio::time::sleep(Duration::from_millis(ms)).await; + ms *= 2; + } + }) + .await + .context("runtime control port did not open in time")?; + + let runtime_service = tower::ServiceBuilder::new() + .layer(shuttle_common::claims::ClaimLayer) + .layer(shuttle_common::claims::InjectPropagationLayer) + .service(channel); + + Ok(RuntimeClient::new(runtime_service)) + } } #[cfg(feature = "resource-recorder")] @@ -308,7 +378,32 @@ pub mod resource_recorder { #[cfg(feature = "builder")] pub mod builder { + use http::Uri; + + use self::builder_client::BuilderClient; + pub use super::generated::builder::*; + + pub type Client = BuilderClient< + shuttle_common::claims::ClaimService< + shuttle_common::claims::InjectPropagation, + >, + >; + + /// Get a builder client that is correctly configured for all services + pub async fn get_client(builder_uri: Uri) -> Client { + let channel = tonic::transport::Endpoint::from(builder_uri) + .connect() + .await + .expect("failed to connect to builder"); + + let builder_service = tower::ServiceBuilder::new() + .layer(shuttle_common::claims::ClaimLayer) + .layer(shuttle_common::claims::InjectPropagationLayer) + .service(channel); + + BuilderClient::new(builder_service) + } } #[cfg(feature = "logger")] @@ -317,6 +412,7 @@ pub mod logger { use std::time::Duration; use chrono::{NaiveDateTime, TimeZone, Utc}; + use http::Uri; use prost::bytes::Bytes; use tokio::{select, sync::mpsc, time::interval}; use tonic::{ @@ -331,8 +427,31 @@ pub mod logger { DeploymentId, }; + use self::logger_client::LoggerClient; + pub use super::generated::logger::*; + pub type Client = LoggerClient< + shuttle_common::claims::ClaimService< + shuttle_common::claims::InjectPropagation, + >, + >; + + /// Get a logger client that is correctly configured for all services + pub async fn get_client(logger_uri: Uri) -> Client { + let channel = tonic::transport::Endpoint::from(logger_uri) + .connect() + .await + .expect("failed to connect to logger"); + + let logger_service = tower::ServiceBuilder::new() + .layer(shuttle_common::claims::ClaimLayer) + .layer(shuttle_common::claims::InjectPropagationLayer) + .service(channel); + + LoggerClient::new(logger_service) + } + impl From for LogItem { fn from(value: LogItemCommon) -> Self { Self { diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index a5a4bc58a..11d86bc50 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -34,7 +34,6 @@ thiserror = { workspace = true } tokio = { workspace = true, features = ["full"] } tokio-stream = { workspace = true } tonic = { workspace = true } -tower = { workspace = true } tracing-subscriber = { workspace = true, optional = true } wasi-common = { version = "13.0.0", optional = true } wasmtime = { version = "13.0.0", optional = true } diff --git a/runtime/src/alpha/mod.rs b/runtime/src/alpha/mod.rs index 70df821bd..616b0ea30 100644 --- a/runtime/src/alpha/mod.rs +++ b/runtime/src/alpha/mod.rs @@ -16,12 +16,12 @@ use shuttle_common::{ auth::{AuthPublicKey, JwtAuthenticationLayer}, trace::ExtractPropagationLayer, }, - claims::{Claim, ClaimLayer, InjectPropagationLayer}, + claims::Claim, resource, secrets::Secret, }; use shuttle_proto::{ - provisioner::provisioner_client::ProvisionerClient, + provisioner, runtime::{ runtime_server::{Runtime, RuntimeServer}, LoadRequest, LoadResponse, StartRequest, StartResponse, StopReason, StopRequest, @@ -38,7 +38,6 @@ use tonic::{ transport::{Endpoint, Server}, Request, Response, Status, }; -use tower::ServiceBuilder; use crate::__internals::{print_version, ProvisionerFactory, ResourceTracker}; @@ -204,19 +203,8 @@ where } = request.into_inner(); println!("loading alpha service at {path}"); - let channel = self - .provisioner_address - .clone() - .connect() - .await - .context("failed to connect to provisioner") - .map_err(|err| Status::internal(err.to_string()))?; - let channel = ServiceBuilder::new() - .layer(ClaimLayer) - .layer(InjectPropagationLayer) - .service(channel); - - let provisioner_client = ProvisionerClient::new(channel); + let provisioner_client = + provisioner::get_client(self.provisioner_address.uri().clone()).await; // TODO: merge new & old secrets diff --git a/runtime/src/provisioner_factory.rs b/runtime/src/provisioner_factory.rs index 89ec96779..460295c75 100644 --- a/runtime/src/provisioner_factory.rs +++ b/runtime/src/provisioner_factory.rs @@ -2,22 +2,16 @@ use std::{collections::BTreeMap, path::PathBuf}; use async_trait::async_trait; use shuttle_common::{ - claims::{Claim, ClaimService, InjectPropagation}, - constants::STORAGE_DIRNAME, - database, - secrets::Secret, - DatabaseInfo, -}; -use shuttle_proto::provisioner::{ - provisioner_client::ProvisionerClient, ContainerRequest, ContainerResponse, DatabaseRequest, + claims::Claim, constants::STORAGE_DIRNAME, database, secrets::Secret, DatabaseInfo, }; +use shuttle_proto::provisioner::{self, ContainerRequest, ContainerResponse, DatabaseRequest}; use shuttle_service::{DeploymentMetadata, Environment, Factory}; -use tonic::{transport::Channel, Request}; +use tonic::Request; /// A factory (service locator) which goes through the provisioner crate pub struct ProvisionerFactory { service_name: String, - provisioner_client: ProvisionerClient>>, + provisioner_client: provisioner::Client, secrets: BTreeMap>, env: Environment, claim: Option, @@ -25,7 +19,7 @@ pub struct ProvisionerFactory { impl ProvisionerFactory { pub(crate) fn new( - provisioner_client: ProvisionerClient>>, + provisioner_client: provisioner::Client, service_name: String, secrets: BTreeMap>, env: Environment, diff --git a/runtime/tests/integration/helpers.rs b/runtime/tests/integration/helpers.rs index e0d741910..a03191719 100644 --- a/runtime/tests/integration/helpers.rs +++ b/runtime/tests/integration/helpers.rs @@ -6,24 +6,20 @@ use std::{ use anyhow::Result; use async_trait::async_trait; -use shuttle_common::claims::{ClaimService, InjectPropagation}; use shuttle_proto::{ provisioner::{ provisioner_server::{Provisioner, ProvisionerServer}, ContainerRequest, ContainerResponse, DatabaseDeletionResponse, DatabaseRequest, DatabaseResponse, Ping, Pong, }, - runtime::runtime_client::RuntimeClient, + runtime, }; use shuttle_service::{builder::build_workspace, runner, Environment}; use tokio::process::Child; -use tonic::{ - transport::{Channel, Server}, - Request, Response, Status, -}; +use tonic::{transport::Server, Request, Response, Status}; pub struct TestRuntime { - pub runtime_client: RuntimeClient>>, + pub runtime_client: runtime::Client, pub bin_path: String, pub service_name: String, pub runtime_address: SocketAddr, diff --git a/service/Cargo.toml b/service/Cargo.toml index 387224141..9c415a6a0 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -29,7 +29,6 @@ serde = { workspace = true, features = ["derive"] } strfmt = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, optional = true } -tower = { workspace = true, optional = true } toml = { workspace = true, optional = true } tracing = { workspace = true, optional = true } @@ -48,4 +47,4 @@ builder = [ "toml", "tracing", ] -runner = ["shuttle-proto/runtime", "tokio/process", "dunce", "tower"] +runner = ["shuttle-proto/runtime", "tokio/process", "dunce"] diff --git a/service/src/runner.rs b/service/src/runner.rs index 431bdfdf9..e357c0730 100644 --- a/service/src/runner.rs +++ b/service/src/runner.rs @@ -1,19 +1,13 @@ use std::{ path::{Path, PathBuf}, process::Stdio, - time::Duration, }; use anyhow::Context; -use shuttle_common::{ - claims::{ClaimLayer, ClaimService, InjectPropagation, InjectPropagationLayer}, - deployment::Environment, -}; -use shuttle_proto::runtime::runtime_client; -use shuttle_proto::tonic::transport::{Channel, Endpoint}; +use shuttle_common::deployment::Environment; +use shuttle_proto::runtime; use tokio::process; -use tower::ServiceBuilder; -use tracing::{info, trace}; +use tracing::info; pub async fn start( wasm: bool, @@ -23,10 +17,7 @@ pub async fn start( port: u16, runtime_executable: PathBuf, project_path: &Path, -) -> anyhow::Result<( - process::Child, - runtime_client::RuntimeClient>>, -)> { +) -> anyhow::Result<(process::Child, runtime::Client)> { let port = &port.to_string(); let environment = &environment.to_string(); @@ -63,33 +54,7 @@ pub async fn start( .spawn() .context("spawning runtime process")?; - info!("connecting runtime client"); - let conn = Endpoint::new(format!("http://127.0.0.1:{port}")) - .context("creating runtime client endpoint")? - .connect_timeout(Duration::from_secs(5)); - - // Wait for the spawned process to open the control port. - // Connecting instantly does not give it enough time. - let channel = tokio::time::timeout(Duration::from_millis(7000), async move { - let mut ms = 5; - loop { - if let Ok(channel) = conn.connect().await { - break channel; - } - trace!("waiting for runtime control port to open"); - // exponential backoff - tokio::time::sleep(Duration::from_millis(ms)).await; - ms *= 2; - } - }) - .await - .context("runtime control port did not open in time")?; - - let channel = ServiceBuilder::new() - .layer(ClaimLayer) - .layer(InjectPropagationLayer) - .service(channel); - let runtime_client = runtime_client::RuntimeClient::new(channel); + let runtime_client = runtime::get_client(port).await?; Ok((runtime, runtime_client)) }