diff --git a/common/Cargo.toml b/common/Cargo.toml index f57c7ecab..4ef89bcc2 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -57,11 +57,12 @@ backend = [ "axum/matched-path", "axum/json", "claims", - "hyper/client", + "hyper", "opentelemetry_sdk", "opentelemetry-appender-tracing", "opentelemetry-otlp", "models", + "reqwest/json", "rustrict", # only ProjectName model uses it "thiserror", "tokio", diff --git a/common/src/backends/client/gateway.rs b/common/src/backends/client/gateway.rs index 0c76e3245..ec1e25bfa 100644 --- a/common/src/backends/client/gateway.rs +++ b/common/src/backends/client/gateway.rs @@ -1,37 +1,9 @@ -use headers::Authorization; -use http::{Method, Uri}; +use http::Method; use tracing::instrument; use crate::models; -use super::{Error, ServicesApiClient}; - -/// Wrapper struct to make API calls to gateway easier -#[derive(Clone)] -pub struct Client { - public_client: ServicesApiClient, - private_client: ServicesApiClient, -} - -impl Client { - /// Make a gateway client that is able to call the public and private APIs on gateway - pub fn new(public_uri: Uri, private_uri: Uri) -> Self { - Self { - public_client: ServicesApiClient::new(public_uri), - private_client: ServicesApiClient::new(private_uri), - } - } - - /// Get the client of public API calls - pub fn public_client(&self) -> &ServicesApiClient { - &self.public_client - } - - /// Get the client of private API calls - pub fn private_client(&self) -> &ServicesApiClient { - &self.private_client - } -} +use super::{header_map_with_bearer, Error, ServicesApiClient}; /// Interact with all the data relating to projects #[allow(async_fn_in_trait)] @@ -65,33 +37,29 @@ pub trait ProjectsDal { } } -impl ProjectsDal for Client { +impl ProjectsDal for ServicesApiClient { #[instrument(skip_all)] async fn get_user_project( &self, user_token: &str, project_name: &str, ) -> Result { - self.public_client - .request( - Method::GET, - format!("projects/{}", project_name).as_str(), - None::<()>, - Some(Authorization::bearer(user_token).expect("to build an authorization bearer")), - ) - .await + self.get( + format!("projects/{}", project_name).as_str(), + Some(header_map_with_bearer(user_token)), + ) + .await } #[instrument(skip_all)] async fn head_user_project(&self, user_token: &str, project_name: &str) -> Result { - self.public_client - .request_raw( - Method::HEAD, - format!("projects/{}", project_name).as_str(), - None::<()>, - Some(Authorization::bearer(user_token).expect("to build an authorization bearer")), - ) - .await?; + self.request_raw( + Method::HEAD, + format!("projects/{}", project_name).as_str(), + None::<()>, + Some(header_map_with_bearer(user_token)), + ) + .await?; Ok(true) } @@ -101,13 +69,7 @@ impl ProjectsDal for Client { &self, user_token: &str, ) -> Result, Error> { - self.public_client - .request( - Method::GET, - "projects", - None::<()>, - Some(Authorization::bearer(user_token).expect("to build an authorization bearer")), - ) + self.get("projects", Some(header_map_with_bearer(user_token))) .await } } @@ -116,24 +78,25 @@ impl ProjectsDal for Client { mod tests { use test_context::{test_context, AsyncTestContext}; + use crate::backends::client::ServicesApiClient; use crate::models::project::{Response, State}; use crate::test_utils::get_mocked_gateway_server; - use super::{Client, ProjectsDal}; + use super::ProjectsDal; - impl AsyncTestContext for Client { + impl AsyncTestContext for ServicesApiClient { async fn setup() -> Self { let server = get_mocked_gateway_server().await; - Client::new(server.uri().parse().unwrap(), server.uri().parse().unwrap()) + ServicesApiClient::new(server.uri().parse().unwrap()) } async fn teardown(self) {} } - #[test_context(Client)] + #[test_context(ServicesApiClient)] #[tokio::test] - async fn get_user_projects(client: &mut Client) { + async fn get_user_projects(client: &mut ServicesApiClient) { let res = client.get_user_projects("user-1").await.unwrap(); assert_eq!( @@ -155,9 +118,9 @@ mod tests { ) } - #[test_context(Client)] + #[test_context(ServicesApiClient)] #[tokio::test] - async fn get_user_project_ids(client: &mut Client) { + async fn get_user_project_ids(client: &mut ServicesApiClient) { let res = client.get_user_project_ids("user-2").await.unwrap(); assert_eq!(res, vec!["00000000000000000000000003"]) diff --git a/common/src/backends/client/mod.rs b/common/src/backends/client/mod.rs index 5b7ca299a..f65d08d1c 100644 --- a/common/src/backends/client/mod.rs +++ b/common/src/backends/client/mod.rs @@ -1,9 +1,11 @@ +use std::time::Duration; + use bytes::Bytes; -use headers::{ContentType, Header, HeaderMapExt}; -use http::{Method, Request, StatusCode, Uri}; -use hyper::{body, client::HttpConnector, Body, Client}; +use headers::{Authorization, HeaderMapExt}; +use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri}; use opentelemetry::global; use opentelemetry_http::HeaderInjector; +use reqwest::{Client, ClientBuilder, Response}; use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; use tracing::{trace, Span}; @@ -17,95 +19,143 @@ pub use resource_recorder::ResourceDal; #[derive(Error, Debug)] pub enum Error { - #[error("Hyper error: {0}")] - Hyper(#[from] hyper::Error), + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), #[error("Serde JSON error: {0}")] SerdeJson(#[from] serde_json::Error), - #[error("Hyper error: {0}")] - Http(#[from] hyper::http::Error), #[error("Request did not return correctly. Got status code: {0}")] RequestError(StatusCode), #[error("GRpc request did not return correctly. Got status code: {0}")] GrpcError(#[from] tonic::Status), } -/// `Hyper` wrapper to make request to RESTful Shuttle services easy +/// `reqwest` wrapper to make requests to other services easy #[derive(Clone)] pub struct ServicesApiClient { - client: Client, + client: Client, base: Uri, } impl ServicesApiClient { - fn new(base: Uri) -> Self { + pub fn builder() -> ClientBuilder { + Client::builder().timeout(Duration::from_secs(60)) + } + + pub fn new(base: Uri) -> Self { + Self { + client: Self::builder().build().unwrap(), + base, + } + } + + pub fn new_with_bearer(base: Uri, token: &str) -> Self { Self { - client: Client::new(), + client: Self::builder() + .default_headers(header_map_with_bearer(token)) + .build() + .unwrap(), base, } } - pub async fn request( + pub async fn get( + &self, + path: &str, + headers: Option>, + ) -> Result { + self.request(Method::GET, path, None::<()>, headers).await + } + + pub async fn post( + &self, + path: &str, + body: B, + headers: Option>, + ) -> Result { + self.request(Method::POST, path, Some(body), headers).await + } + + pub async fn delete( + &self, + path: &str, + body: B, + headers: Option>, + ) -> Result { + self.request(Method::DELETE, path, Some(body), headers) + .await + } + + pub async fn request( &self, method: Method, path: &str, body: Option, - extra_header: Option, + headers: Option>, ) -> Result { - let bytes = self.request_raw(method, path, body, extra_header).await?; - let json = serde_json::from_slice(&bytes)?; - - Ok(json) + Ok(self + .request_raw(method, path, body, headers) + .await? + .json() + .await?) } - pub async fn request_raw( + pub async fn request_bytes( &self, method: Method, path: &str, body: Option, - extra_header: Option, + headers: Option>, ) -> Result { + Ok(self + .request_raw(method, path, body, headers) + .await? + .bytes() + .await?) + } + + // can be used for explicit HEAD requests (ignores body) + pub async fn request_raw( + &self, + method: Method, + path: &str, + body: Option, + headers: Option>, + ) -> Result { let uri = format!("{}{path}", self.base); trace!(uri, "calling inner service"); - let mut req = Request::builder().method(method).uri(uri); - let headers = req - .headers_mut() - .expect("new request to have mutable headers"); - if let Some(extra_header) = extra_header { - headers.typed_insert(extra_header); - } - if body.is_some() { - headers.typed_insert(ContentType::json()); - } - + let mut h = headers.unwrap_or_default(); let cx = Span::current().context(); global::get_text_map_propagator(|propagator| { - propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut().unwrap())) + propagator.inject_context(&cx, &mut HeaderInjector(&mut h)) }); - + let req = self.client.request(method, uri).headers(h); let req = if let Some(body) = body { - req.body(Body::from(serde_json::to_vec(&body)?)) + req.json(&body) } else { - req.body(Body::empty()) + req }; - let resp = self.client.request(req?).await?; - trace!(response = ?resp, "Load response"); + let resp = req.send().await?; + trace!(response = ?resp, "service response"); - if resp.status() != StatusCode::OK { + if !resp.status().is_success() { return Err(Error::RequestError(resp.status())); } - let bytes = body::to_bytes(resp.into_body()).await?; - - Ok(bytes) + Ok(resp) } } +pub fn header_map_with_bearer(token: &str) -> HeaderMap { + let mut h = HeaderMap::new(); + h.typed_insert(Authorization::bearer(token).expect("valid token")); + h +} + #[cfg(test)] mod tests { - use headers::{authorization::Bearer, Authorization}; - use http::{Method, StatusCode}; + use http::StatusCode; use crate::models; use crate::test_utils::get_mocked_gateway_server; @@ -120,12 +170,7 @@ mod tests { let client = ServicesApiClient::new(server.uri().parse().unwrap()); let err = client - .request::<_, Vec, _>( - Method::GET, - "projects", - None::<()>, - None::>, - ) + .get::>("projects", None) .await .unwrap_err(); diff --git a/deployer/src/deployment/gateway_client.rs b/deployer/src/deployment/gateway_client.rs index fbc1258df..f988714ff 100644 --- a/deployer/src/deployment/gateway_client.rs +++ b/deployer/src/deployment/gateway_client.rs @@ -1,8 +1,6 @@ -use axum::headers::{authorization::Bearer, Authorization}; -use hyper::Method; use shuttle_common::{ - backends::client::{gateway, Error}, - models::{self}, + backends::client::{Error, ServicesApiClient}, + models, }; use uuid::Uuid; @@ -17,16 +15,13 @@ pub trait BuildQueueClient: Clone + Send + Sync + 'static { } #[async_trait::async_trait] -impl BuildQueueClient for gateway::Client { +impl BuildQueueClient for ServicesApiClient { async fn get_slot(&self, deployment_id: Uuid) -> Result { - let body = models::stats::LoadRequest { id: deployment_id }; let load: models::stats::LoadResponse = self - .public_client() - .request( - Method::POST, + .post( "stats/load", - Some(body), - None::>, + models::stats::LoadRequest { id: deployment_id }, + None, ) .await?; @@ -34,14 +29,11 @@ impl BuildQueueClient for gateway::Client { } async fn release_slot(&self, deployment_id: Uuid) -> Result<(), Error> { - let body = models::stats::LoadRequest { id: deployment_id }; let _load: models::stats::LoadResponse = self - .public_client() - .request( - Method::DELETE, + .delete( "stats/load", - Some(body), - None::>, + models::stats::LoadRequest { id: deployment_id }, + None, ) .await?; diff --git a/deployer/src/lib.rs b/deployer/src/lib.rs index 087ae1d10..da61c8829 100644 --- a/deployer/src/lib.rs +++ b/deployer/src/lib.rs @@ -2,7 +2,7 @@ use std::sync::Arc; pub use persistence::Persistence; pub use runtime_manager::RuntimeManager; -use shuttle_common::log::LogRecorder; +use shuttle_common::{backends::client::ServicesApiClient, log::LogRecorder}; use shuttle_proto::{logger, provisioner}; use tokio::sync::Mutex; use tracing::info; @@ -18,7 +18,6 @@ mod runtime_manager; pub use crate::args::Args; pub use crate::deployment::state_change_layer::StateChangeLayer; use crate::deployment::{Built, DeploymentManager}; -use shuttle_common::backends::client::gateway; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -40,10 +39,7 @@ pub async fn start( .runtime(runtime_manager) .resource_manager(persistence.clone()) .provisioner_client(provisioner::get_client(args.provisioner_address).await) - .queue_client(gateway::Client::new( - args.gateway_uri.clone(), - args.gateway_uri, - )) + .queue_client(ServicesApiClient::new(args.gateway_uri)) .log_fetcher(log_fetcher) .build(); diff --git a/examples b/examples index d0872d676..b1ae18580 160000 --- a/examples +++ b/examples @@ -1 +1 @@ -Subproject commit d0872d6761b0d50cfdbb6c3a5bad9ffb65a8699a +Subproject commit b1ae18580f6ba12af2e5d88d38d6bc74729f0e05 diff --git a/provisioner/src/lib.rs b/provisioner/src/lib.rs index 70f0483d9..6a6253f92 100644 --- a/provisioner/src/lib.rs +++ b/provisioner/src/lib.rs @@ -12,7 +12,7 @@ pub use error::Error; use mongodb::{bson::doc, options::ClientOptions}; use rand::Rng; use shuttle_common::backends::auth::VerifyClaim; -use shuttle_common::backends::client::gateway; +use shuttle_common::backends::client::ServicesApiClient; use shuttle_common::backends::ClaimExt; use shuttle_common::claims::{Claim, Scope}; use shuttle_common::models::project::ProjectName; @@ -44,7 +44,7 @@ pub struct ShuttleProvisioner { internal_pg_address: String, internal_mongodb_address: String, rr_client: Arc>, - gateway_client: gateway::Client, + gateway_client: ServicesApiClient, } impl ShuttleProvisioner { @@ -81,7 +81,7 @@ impl ShuttleProvisioner { let rr_client = resource_recorder::get_client(resource_recorder_uri).await; - let gateway_client = gateway::Client::new(gateway_uri.clone(), gateway_uri); + let gateway_client = ServicesApiClient::new(gateway_uri); Ok(Self { pool, diff --git a/resource-recorder/src/lib.rs b/resource-recorder/src/lib.rs index 072f897f8..35b91dba7 100644 --- a/resource-recorder/src/lib.rs +++ b/resource-recorder/src/lib.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use dal::{Dal, DalError, Resource}; use prost_types::TimestampError; use shuttle_common::{ - backends::{auth::VerifyClaim, client::gateway, ClaimExt}, + backends::{auth::VerifyClaim, client::ServicesApiClient, ClaimExt}, claims::{Claim, Scope}, }; use shuttle_proto::resource_recorder::{ @@ -45,14 +45,14 @@ impl From for Error { pub struct Service { dal: D, - gateway_client: gateway::Client, + gateway_client: ServicesApiClient, } impl Service where D: Dal + Send + Sync + 'static, { - pub fn new(dal: D, gateway_client: gateway::Client) -> Self { + pub fn new(dal: D, gateway_client: ServicesApiClient) -> Self { Self { dal, gateway_client, diff --git a/resource-recorder/src/main.rs b/resource-recorder/src/main.rs index 066c66060..f5a7e1e98 100644 --- a/resource-recorder/src/main.rs +++ b/resource-recorder/src/main.rs @@ -4,7 +4,7 @@ use clap::Parser; use shuttle_common::{ backends::{ auth::{AuthPublicKey, JwtAuthenticationLayer}, - client::gateway, + client::ServicesApiClient, trace::setup_tracing, }, extract_propagation::ExtractPropagationLayer, @@ -30,7 +30,7 @@ async fn main() { .layer(JwtAuthenticationLayer::new(AuthPublicKey::new(auth_uri))) .layer(ExtractPropagationLayer); - let gateway_client = gateway::Client::new(gateway_uri.clone(), gateway_uri); + let gateway_client = ServicesApiClient::new(gateway_uri); let db_path = state.join("resource-recorder.sqlite"); let svc = Service::new( diff --git a/resource-recorder/tests/integration.rs b/resource-recorder/tests/integration.rs index ad3700b38..7c57a5d5a 100644 --- a/resource-recorder/tests/integration.rs +++ b/resource-recorder/tests/integration.rs @@ -4,7 +4,7 @@ use portpicker::pick_unused_port; use pretty_assertions::{assert_eq, assert_ne}; use serde_json::json; use shuttle_common::{ - backends::client::gateway::Client, claims::Scope, test_utils::get_mocked_gateway_server, + backends::client::ServicesApiClient, claims::Scope, test_utils::get_mocked_gateway_server, }; use shuttle_common_tests::JwtScopesLayer; use shuttle_proto::resource_recorder::{ @@ -22,7 +22,7 @@ async fn manage_resources() { let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port); let server = get_mocked_gateway_server().await; - let client = Client::new(server.uri().parse().unwrap(), server.uri().parse().unwrap()); + let client = ServicesApiClient::new(server.uri().parse().unwrap()); let server_future = async { Server::builder()