diff --git a/common/Cargo.toml b/common/Cargo.toml index 54dfbe262..89ea63a15 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -30,7 +30,7 @@ serde = { workspace = true, features = ["derive", "std"] } serde_json = { workspace = true } sqlx = { workspace = true, optional = true } strum = { workspace = true, features = ["derive"] } -thiserror = { workspace = true, optional = true } +thiserror = { workspace = true } tonic = { workspace = true, optional = true } tower = { workspace = true, optional = true } tracing = { workspace = true, features = ["std"], optional = true } @@ -65,7 +65,7 @@ extract_propagation = [ "tower", "tracing-opentelemetry", ] -models = ["async-trait", "reqwest", "service", "thiserror"] +models = ["async-trait", "reqwest", "service"] persist = ["sqlx", "rand"] sqlx = ["dep:sqlx", "sqlx/sqlite"] service = ["chrono/serde", "display", "tracing", "tracing-subscriber", "uuid"] diff --git a/common/src/models/error.rs b/common/src/models/error.rs index 3263c9c5b..f99313c4b 100644 --- a/common/src/models/error.rs +++ b/common/src/models/error.rs @@ -30,13 +30,22 @@ impl ApiError { } /// Creates an internal error without exposing sensitive information to the user. - pub fn internal_safe(error: impl std::error::Error) -> Self { - error!(error = error.to_string(), ""); - - Self { - message: "Internal server error occured. Please create a ticket to get this fixed." - .to_string(), - status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + #[inline(always)] + pub fn internal_safe(message: &str, error: E) -> Self + where + E: std::error::Error + 'static, + { + error!(error = &error as &dyn std::error::Error, "{message}"); + + // Return the raw error during debug builds + #[cfg(debug_assertions)] + { + ApiError::internal(&error.to_string()) + } + // Return the safe message during release builds + #[cfg(not(debug_assertions))] + { + ApiError::internal(message) } } @@ -54,13 +63,6 @@ impl ApiError { } } - pub fn not_found(message: impl ToString) -> Self { - Self { - message: message.to_string(), - status_code: StatusCode::NOT_FOUND.as_u16(), - } - } - pub fn unauthorized() -> Self { Self { message: "Unauthorized".to_string(), @@ -80,6 +82,131 @@ impl ApiError { } } +pub trait ErrorContext { + /// Make a new internal server error with the given message. + #[inline(always)] + fn context_internal_error(self, message: &str) -> Result + where + Self: Sized, + { + self.with_context_internal_error(move || message.to_string()) + } + + /// Make a new internal server error using the given function to create the message. + fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result; + + /// Make a new bad request error with the given message. + #[inline(always)] + fn context_bad_request(self, message: &str) -> Result + where + Self: Sized, + { + self.with_context_bad_request(move || message.to_string()) + } + + /// Make a new bad request error using the given function to create the message. + fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result; + + /// Make a new not found error with the given message. + #[inline(always)] + fn context_not_found(self, message: &str) -> Result + where + Self: Sized, + { + self.with_context_not_found(move || message.to_string()) + } + + /// Make a new not found error using the given function to create the message. + fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result; +} + +impl ErrorContext for Result +where + E: std::error::Error + 'static, +{ + #[inline(always)] + fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result { + match self { + Ok(value) => Ok(value), + Err(error) => Err(ApiError::internal_safe(message().as_ref(), error)), + } + } + + #[inline(always)] + fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result { + match self { + Ok(value) => Ok(value), + Err(error) => Err({ + let message = message(); + warn!( + error = &error as &dyn std::error::Error, + "bad request: {message}" + ); + + ApiError { + message, + status_code: StatusCode::BAD_REQUEST.as_u16(), + } + }), + } + } + + #[inline(always)] + fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result { + match self { + Ok(value) => Ok(value), + Err(error) => Err({ + let message = message(); + warn!( + error = &error as &dyn std::error::Error, + "not found: {message}" + ); + + ApiError { + message, + status_code: StatusCode::NOT_FOUND.as_u16(), + } + }), + } + } +} + +impl ErrorContext for Option { + #[inline] + fn with_context_internal_error(self, message: impl FnOnce() -> String) -> Result { + match self { + Some(value) => Ok(value), + None => Err(ApiError::internal(message().as_ref())), + } + } + + #[inline] + fn with_context_bad_request(self, message: impl FnOnce() -> String) -> Result { + match self { + Some(value) => Ok(value), + None => Err({ + ApiError { + message: message(), + status_code: StatusCode::BAD_REQUEST.as_u16(), + } + }), + } + } + + #[inline] + fn with_context_not_found(self, message: impl FnOnce() -> String) -> Result { + match self { + Some(value) => Ok(value), + None => Err({ + ApiError { + message: message(), + status_code: StatusCode::NOT_FOUND.as_u16(), + } + }), + } + } +} + impl Display for ApiError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( diff --git a/common/src/resource.rs b/common/src/resource.rs index 53931733e..0fdb65d0c 100644 --- a/common/src/resource.rs +++ b/common/src/resource.rs @@ -2,6 +2,7 @@ use std::{fmt::Display, str::FromStr}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use thiserror::Error; use crate::{constants::RESOURCE_SCHEMA_VERSION, database}; @@ -87,21 +88,32 @@ pub enum Type { Container, } +#[derive(Debug, Error)] +pub enum InvalidResourceType { + #[error("'{0}' is an unknown database type")] + Type(String), + + #[error("{0}")] + Database(String), +} + impl FromStr for Type { - type Err = String; + type Err = InvalidResourceType; fn from_str(s: &str) -> Result { if let Some((prefix, rest)) = s.split_once("::") { match prefix { - "database" => Ok(Self::Database(database::Type::from_str(rest)?)), - _ => Err(format!("'{prefix}' is an unknown resource type")), + "database" => Ok(Self::Database( + database::Type::from_str(rest).map_err(InvalidResourceType::Database)?, + )), + _ => Err(InvalidResourceType::Type(prefix.to_string())), } } else { match s { "secrets" => Ok(Self::Secrets), "persist" => Ok(Self::Persist), "container" => Ok(Self::Container), - _ => Err(format!("'{s}' is an unknown resource type")), + _ => Err(InvalidResourceType::Type(s.to_string())), } } } diff --git a/resource-recorder/src/dal.rs b/resource-recorder/src/dal.rs index 9ad8aeb56..78f679a86 100644 --- a/resource-recorder/src/dal.rs +++ b/resource-recorder/src/dal.rs @@ -4,7 +4,7 @@ use crate::Error; use async_trait::async_trait; use chrono::{DateTime, Utc}; use prost_types::Timestamp; -use shuttle_common::resource::Type; +use shuttle_common::resource::{InvalidResourceType, Type}; use shuttle_proto::resource_recorder::{self, record_request}; use sqlx::{ migrate::{MigrateDatabase, Migrator}, @@ -244,7 +244,7 @@ impl FromRow<'_, SqliteRow> for Resource { } impl TryFrom for Resource { - type Error = String; + type Error = InvalidResourceType; fn try_from(value: record_request::Resource) -> Result { let r#type = value.r#type.parse()?; diff --git a/resource-recorder/src/lib.rs b/resource-recorder/src/lib.rs index 768289e62..1b36894dc 100644 --- a/resource-recorder/src/lib.rs +++ b/resource-recorder/src/lib.rs @@ -2,7 +2,10 @@ use async_trait::async_trait; use dal::{Dal, DalError, Resource}; use prost_types::TimestampError; use shuttle_backends::{auth::VerifyClaim, client::ServicesApiClient, ClaimExt}; -use shuttle_common::claims::{Claim, Scope}; +use shuttle_common::{ + claims::{Claim, Scope}, + resource::InvalidResourceType, +}; use shuttle_proto::resource_recorder::{ self, resource_recorder_server::ResourceRecorder, ProjectResourcesRequest, RecordRequest, ResourceIds, ResourceResponse, ResourcesResponse, ResultResponse, ServiceResourcesRequest, @@ -28,17 +31,13 @@ pub enum Error { Dal(#[from] DalError), #[error("could not parse resource type: {0}")] - String(String), + ResourceType(#[from] InvalidResourceType), #[error("could not parse timestamp: {0}")] Timestamp(#[from] TimestampError), -} -// thiserror is not happy to handle a `#[from] String` -impl From for Error { - fn from(value: String) -> Self { - Self::String(value) - } + #[error("{0}")] + String(String), } pub struct Service {