Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update services api client #1695

Merged
merged 6 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ backend = [
"axum/matched-path",
"axum/json",
"claims",
"hyper/client",
"hyper",
jonaro00 marked this conversation as resolved.
Show resolved Hide resolved
"opentelemetry_sdk",
"opentelemetry-appender-tracing",
"opentelemetry-otlp",
"models",
"reqwest",
"rustrict", # only ProjectName model uses it
"thiserror",
"tokio",
Expand Down
94 changes: 32 additions & 62 deletions common/src/backends/client/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,9 @@
use headers::Authorization;
use http::{Method, Uri};
use http::Method;
use tracing::instrument;

use crate::models;

use super::{Error, ServicesApiClient};

/// Wrapper struct to make API calls to gateway easier
#[derive(Clone)]
pub struct Client {
public_client: ServicesApiClient,
private_client: ServicesApiClient,
}

impl Client {
/// Make a gateway client that is able to call the public and private APIs on gateway
pub fn new(public_uri: Uri, private_uri: Uri) -> Self {
Self {
public_client: ServicesApiClient::new(public_uri),
private_client: ServicesApiClient::new(private_uri),
}
}

/// Get the client of public API calls
pub fn public_client(&self) -> &ServicesApiClient {
&self.public_client
}

/// Get the client of private API calls
pub fn private_client(&self) -> &ServicesApiClient {
&self.private_client
}
}
use super::{header_map_with_bearer, Error, ServicesApiClient};

/// Interact with all the data relating to projects
#[allow(async_fn_in_trait)]
Expand Down Expand Up @@ -65,33 +37,31 @@ pub trait ProjectsDal {
}
}

impl ProjectsDal for Client {
impl ProjectsDal for ServicesApiClient {
#[instrument(skip_all)]
async fn get_user_project(
&self,
user_token: &str,
project_name: &str,
) -> Result<models::project::Response, Error> {
self.public_client
.request(
Method::GET,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await
self.request(
Method::GET,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(header_map_with_bearer(user_token)),
)
.await
}

#[instrument(skip_all)]
async fn head_user_project(&self, user_token: &str, project_name: &str) -> Result<bool, Error> {
self.public_client
.request_raw(
Method::HEAD,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await?;
self.request_raw(
Method::HEAD,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(header_map_with_bearer(user_token)),
)
.await?;

Ok(true)
}
Expand All @@ -101,39 +71,39 @@ impl ProjectsDal for Client {
&self,
user_token: &str,
) -> Result<Vec<models::project::Response>, Error> {
self.public_client
.request(
Method::GET,
"projects",
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await
self.request(
Method::GET,
"projects",
None::<()>,
Some(header_map_with_bearer(user_token)),
)
.await
}
}

#[cfg(test)]
mod tests {
use test_context::{test_context, AsyncTestContext};

use crate::backends::client::ServicesApiClient;
use crate::models::project::{Response, State};
use crate::test_utils::get_mocked_gateway_server;

use super::{Client, ProjectsDal};
use super::ProjectsDal;

impl AsyncTestContext for Client {
impl AsyncTestContext for ServicesApiClient {
async fn setup() -> Self {
let server = get_mocked_gateway_server().await;

Client::new(server.uri().parse().unwrap(), server.uri().parse().unwrap())
ServicesApiClient::new(server.uri().parse().unwrap())
}

async fn teardown(self) {}
}

#[test_context(Client)]
#[test_context(ServicesApiClient)]
#[tokio::test]
async fn get_user_projects(client: &mut Client) {
async fn get_user_projects(client: &mut ServicesApiClient) {
let res = client.get_user_projects("user-1").await.unwrap();

assert_eq!(
Expand All @@ -155,9 +125,9 @@ mod tests {
)
}

#[test_context(Client)]
#[test_context(ServicesApiClient)]
#[tokio::test]
async fn get_user_project_ids(client: &mut Client) {
async fn get_user_project_ids(client: &mut ServicesApiClient) {
let res = client.get_user_project_ids("user-2").await.unwrap();

assert_eq!(res, vec!["00000000000000000000000003"])
Expand Down
141 changes: 94 additions & 47 deletions common/src/backends/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::time::Duration;

use bytes::Bytes;
use headers::{ContentType, Header, HeaderMapExt};
use http::{Method, Request, StatusCode, Uri};
use hyper::{body, client::HttpConnector, Body, Client};
use http::{header::AUTHORIZATION, HeaderMap, HeaderValue, Method, StatusCode, Uri};
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use reqwest::{Client, ClientBuilder, Response};
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use tracing::{trace, Span};
Expand All @@ -17,94 +18,145 @@ pub use resource_recorder::ResourceDal;

#[derive(Error, Debug)]
pub enum Error {
#[error("Hyper error: {0}")]
Hyper(#[from] hyper::Error),
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Serde JSON error: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("Hyper error: {0}")]
Http(#[from] hyper::http::Error),
#[error("Request did not return correctly. Got status code: {0}")]
RequestError(StatusCode),
#[error("GRpc request did not return correctly. Got status code: {0}")]
GrpcError(#[from] tonic::Status),
}

/// `Hyper` wrapper to make request to RESTful Shuttle services easy
/// `reqwest` wrapper to make requests to other services easy
#[derive(Clone)]
pub struct ServicesApiClient {
client: Client<HttpConnector>,
client: Client,
base: Uri,
}

impl ServicesApiClient {
fn new(base: Uri) -> Self {
fn _builder() -> ClientBuilder {
jonaro00 marked this conversation as resolved.
Show resolved Hide resolved
Client::builder().timeout(Duration::from_secs(60))
}

pub fn new(base: Uri) -> Self {
Self {
client: Self::_builder().build().unwrap(),
base,
}
}

pub fn new_with_bearer(base: Uri, token: &str) -> Self {
Self {
client: Client::new(),
client: Self::_builder()
.default_headers(header_map_with_bearer(token))
.build()
.unwrap(),
base,
}
}

pub async fn request<B: Serialize, T: DeserializeOwned, H: Header>(
pub async fn get<T: DeserializeOwned>(
&self,
path: &str,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
self.request(Method::GET, path, None::<()>, headers).await
}

pub async fn post<B: Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: B,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
self.request(Method::POST, path, Some(body), headers).await
}

pub async fn delete<B: Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: B,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
self.request(Method::DELETE, path, Some(body), headers)
.await
}

pub async fn request<B: Serialize, T: DeserializeOwned>(
&self,
method: Method,
path: &str,
body: Option<B>,
extra_header: Option<H>,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
let bytes = self.request_raw(method, path, body, extra_header).await?;
let json = serde_json::from_slice(&bytes)?;

Ok(json)
Ok(self
.request_raw(method, path, body, headers)
.await?
.json()
.await?)
}

pub async fn request_raw<B: Serialize, H: Header>(
pub async fn request_bytes<B: Serialize>(
&self,
method: Method,
path: &str,
body: Option<B>,
extra_header: Option<H>,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<Bytes, Error> {
Ok(self
.request_raw(method, path, body, headers)
.await?
.bytes()
.await?)
}

// can be used for explicit HEAD requests (ignores body)
pub async fn request_raw<B: Serialize>(
&self,
method: Method,
path: &str,
body: Option<B>,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<Response, Error> {
let uri = format!("{}{path}", self.base);
trace!(uri, "calling inner service");

let mut req = Request::builder().method(method).uri(uri);
let headers = req
.headers_mut()
.expect("new request to have mutable headers");
if let Some(extra_header) = extra_header {
headers.typed_insert(extra_header);
}
if body.is_some() {
headers.typed_insert(ContentType::json());
}

let mut h = headers.unwrap_or_default();
let cx = Span::current().context();
global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut().unwrap()))
propagator.inject_context(&cx, &mut HeaderInjector(&mut h))
});

let req = self.client.request(method, uri).headers(h);
let req = if let Some(body) = body {
req.body(Body::from(serde_json::to_vec(&body)?))
req.json(&body)
} else {
req.body(Body::empty())
req
};

let resp = self.client.request(req?).await?;
trace!(response = ?resp, "Load response");
let resp = req.send().await?;
trace!(response = ?resp, "service response");

if resp.status() != StatusCode::OK {
if !resp.status().is_success() {
return Err(Error::RequestError(resp.status()));
}

let bytes = body::to_bytes(resp.into_body()).await?;

Ok(bytes)
Ok(resp)
}
}

pub fn header_map_with_bearer(token: &str) -> HeaderMap {
let mut h = HeaderMap::new();
h.append(
jonaro00 marked this conversation as resolved.
Show resolved Hide resolved
AUTHORIZATION,
format!("Bearer {token}").parse().expect("valid token"),
);
h
}

#[cfg(test)]
mod tests {
use headers::{authorization::Bearer, Authorization};
use http::{Method, StatusCode};

use crate::models;
Expand All @@ -120,12 +172,7 @@ mod tests {
let client = ServicesApiClient::new(server.uri().parse().unwrap());

let err = client
.request::<_, Vec<models::project::Response>, _>(
Method::GET,
"projects",
None::<()>,
None::<Authorization<Bearer>>,
)
.request::<_, Vec<models::project::Response>>(Method::GET, "projects", None::<()>, None)
.await
.unwrap_err();

Expand Down
Loading