Skip to content

Commit

Permalink
feat: add a users layer (#633)
Browse files Browse the repository at this point in the history
* feat: initial commit of user manager

* refactor: remove debug handlers

* fix: clippy

* fix: users table

* refactor: comment

* feat: move permissions to auth, add todos

* feat: setup integration testing of the api

* refactor: workspace dependencies

* refactor: use api key, not secret

* feat: replace super_user with admin account tier

* refactor: make serve a method on apibuilder

* feat: init command, admin guard on handlers, update tests

* misc: remove todo, update cargo.lock

* feat: UserResponse model, remove Permissions struct

* feat: refactor tests to use router, add init-db fn

* refactor: remove redundant method on user

* feat: add account tier to create user

* misc: cleanup, remove todos

* misc: remove await, add to CI

* refactor: cargo sort

* misc: workspace deps

* tests: clippy

* refactor: remove old comment

* feat: use strum, add tier to span
  • Loading branch information
oddgrd authored Feb 17, 2023
1 parent a9ab3e6 commit 0865c3b
Show file tree
Hide file tree
Showing 25 changed files with 761 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ workflows:
- workspace-clippy
matrix:
parameters:
crate: ["shuttle-deployer", "cargo-shuttle", "shuttle-codegen", "shuttle-common", "shuttle-proto", "shuttle-provisioner"]
crate: ["shuttle-auth", "shuttle-deployer", "cargo-shuttle", "shuttle-codegen", "shuttle-common", "shuttle-proto", "shuttle-provisioner"]
- e2e-test:
requires:
- service-test
Expand Down
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ once_cell = "1.16.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-datadog = { version = "0.6.0", features = ["reqwest-client"] }
opentelemetry-http = "0.7.0"
rand = "0.8.5"
serde = "1.0.148"
serde_json = "1.0.89"
strum = { version = "0.24.1", features = ["derive"] }
portpicker = "0.1.1"
thiserror = "1.0.37"
tower = "0.4.13"
tower-http = { version = "0.3.4", features = ["trace"] }
Expand Down
17 changes: 14 additions & 3 deletions auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,30 @@
name = "shuttle-auth"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = { workspace = true }
anyhow = { workspace = true }
async-trait = { workspace = true }
axum = { workspace = true, features = ["headers"] }
clap = { workspace = true }
opentelemetry = { workspace = true }
opentelemetry-datadog = { workspace = true }
rand = { workspace = true }
serde = { workspace = true, features = [ "derive" ] }
sqlx = { version = "0.6.2", features = [ "sqlite", "json", "runtime-tokio-native-tls", "migrate" ] }
strum = { workspace = true }
thiserror = { workspace = true }
tokio = { version = "1.22.0", features = [ "full" ] }
tracing = { workspace = true }
tracing-opentelemetry = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }

[dependencies.shuttle-common]
workspace = true
features = ["backend"]
features = ["backend", "models"]

[dev-dependencies]
hyper = { workspace = true }
serde_json = { workspace = true }
tower = { workspace = true, features = ["util"] }
5 changes: 5 additions & 0 deletions auth/migrations/0000_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE IF NOT EXISTS users (
account_name TEXT PRIMARY KEY,
key TEXT UNIQUE,
account_tier TEXT DEFAULT "basic" NOT NULL
);
82 changes: 82 additions & 0 deletions auth/src/api/builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use std::net::SocketAddr;

use axum::{
middleware::from_extractor,
routing::{get, post},
Router, Server,
};
use shuttle_common::{
backends::metrics::{Metrics, TraceLayer},
request_span,
};
use sqlx::SqlitePool;
use tracing::field;

use crate::user::UserManager;

use super::handlers::{
convert_cookie, convert_key, get_public_key, get_user, login, logout, post_user, refresh_token,
};

#[derive(Clone)]
pub struct RouterState {
pub user_manager: UserManager,
}

pub struct ApiBuilder {
router: Router<RouterState>,
pool: Option<SqlitePool>,
}

impl Default for ApiBuilder {
fn default() -> Self {
Self::new()
}
}

impl ApiBuilder {
pub fn new() -> Self {
let router = Router::new()
.route("/login", post(login))
.route("/logout", post(logout))
.route("/auth/session", post(convert_cookie))
.route("/auth/key", post(convert_key))
.route("/auth/refresh", post(refresh_token))
.route("/public-key", get(get_public_key))
.route("/user/:account_name", get(get_user))
.route("/user/:account_name/:account_tier", post(post_user))
.route_layer(from_extractor::<Metrics>())
.layer(
TraceLayer::new(|request| {
request_span!(
request,
request.params.account_name = field::Empty,
request.params.account_tier = field::Empty
)
})
.with_propagation()
.build(),
);

Self { router, pool: None }
}

pub fn with_sqlite_pool(mut self, pool: SqlitePool) -> Self {
self.pool = Some(pool);
self
}

pub fn into_router(self) -> Router {
let pool = self.pool.expect("an sqlite pool is required");

let user_manager = UserManager { pool };
self.router.with_state(RouterState { user_manager })
}
}

pub async fn serve(router: Router, address: SocketAddr) {
Server::bind(&address)
.serve(router.into_make_service())
.await
.unwrap_or_else(|_| panic!("Failed to bind to address: {}", address));
}
45 changes: 45 additions & 0 deletions auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::{
api::builder::RouterState,
error::Error,
user::{AccountName, AccountTier, Admin, UserManagement},
};
use axum::{
extract::{Path, State},
Json,
};
use shuttle_common::models::auth;
use tracing::instrument;

#[instrument(skip(user_manager))]
pub(crate) async fn get_user(
_: Admin,
State(RouterState { user_manager }): State<RouterState>,
Path(account_name): Path<AccountName>,
) -> Result<Json<auth::UserResponse>, Error> {
let user = user_manager.get_user(account_name).await?;

Ok(Json(user.into()))
}

#[instrument(skip(user_manager))]
pub(crate) async fn post_user(
_: Admin,
State(RouterState { user_manager }): State<RouterState>,
Path((account_name, account_tier)): Path<(AccountName, AccountTier)>,
) -> Result<Json<auth::UserResponse>, Error> {
let user = user_manager.create_user(account_name, account_tier).await?;

Ok(Json(user.into()))
}

pub(crate) async fn login() {}

pub(crate) async fn logout() {}

pub(crate) async fn convert_cookie() {}

pub(crate) async fn convert_key() {}

pub(crate) async fn refresh_token() {}

pub(crate) async fn get_public_key() {}
4 changes: 4 additions & 0 deletions auth/src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod builder;
mod handlers;

pub use builder::{serve, ApiBuilder, RouterState};
24 changes: 23 additions & 1 deletion auth/src/args.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
use std::{net::SocketAddr, path::PathBuf};

use clap::Parser;
use clap::{Parser, Subcommand};

#[derive(Parser, Debug)]
pub struct Args {
/// Where to store auth state (such as users)
#[arg(long, default_value = "./")]
pub state: PathBuf,

#[command(subcommand)]
pub command: Commands,
}

#[derive(Subcommand, Debug)]
pub enum Commands {
Start(StartArgs),
Init(InitArgs),
}

#[derive(clap::Args, Debug, Clone)]
pub struct StartArgs {
/// Address to bind to
#[arg(long, default_value = "127.0.0.1:8000")]
pub address: SocketAddr,
}

#[derive(clap::Args, Debug, Clone)]
pub struct InitArgs {
/// Name of initial account to create
#[arg(long)]
pub name: String,
/// Key to assign to initial account
#[arg(long)]
pub key: Option<String>,
}
62 changes: 62 additions & 0 deletions auth/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::error::Error as StdError;

use axum::http::{header, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;

use serde::{ser::SerializeMap, Serialize};
use shuttle_common::models::error::ApiError;

#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("User could not be found")]
UserNotFound,
#[error("API key is missing.")]
KeyMissing,
#[error("Unauthorized.")]
Unauthorized,
#[error("Forbidden.")]
Forbidden,
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error(transparent)]
UnexpectedError(#[from] anyhow::Error),
}

impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("type", &format!("{:?}", self))?;
// use the error source if available, if not use display implementation
map.serialize_entry("msg", &self.source().unwrap_or(self).to_string())?;
map.end()
}
}

impl IntoResponse for Error {
fn into_response(self) -> Response {
let code = match self {
Error::Forbidden => StatusCode::FORBIDDEN,
Error::Unauthorized => StatusCode::UNAUTHORIZED,
Error::KeyMissing => StatusCode::BAD_REQUEST,
Error::Database(_) | Error::UserNotFound => StatusCode::NOT_FOUND,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};

(
code,
[(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
)],
Json(ApiError {
message: self.to_string(),
status_code: code.as_u16(),
}),
)
.into_response()
}
}
69 changes: 62 additions & 7 deletions auth/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,71 @@
mod api;
mod args;
mod router;
mod error;
mod user;

pub use args::Args;
use std::{io, str::FromStr};

use args::StartArgs;
use sqlx::{
migrate::Migrator,
query,
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqliteSynchronous},
SqlitePool,
};
use tracing::info;

pub async fn start(args: Args) {
let router = router::new();
use crate::{
api::serve,
user::{AccountTier, Key},
};
pub use api::ApiBuilder;
pub use args::{Args, Commands, InitArgs};

pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations");

pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> {
let router = api::ApiBuilder::new().with_sqlite_pool(pool).into_router();

info!(address=%args.address, "Binding to and listening at address");

axum::Server::bind(&args.address)
.serve(router.into_make_service())
serve(router, args.address).await;

Ok(())
}

pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
let key = match args.key {
Some(ref key) => Key::from_str(key).unwrap(),
None => Key::new_random(),
};

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
.bind(&args.name)
.bind(&key)
.bind(AccountTier::Admin)
.execute(&pool)
.await
.unwrap_or_else(|_| panic!("Failed to bind to address: {}", args.address));
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

println!("`{}` created as super user with key: {key}", args.name);
Ok(())
}

/// Initialize an SQLite database at the given URI, creating it if it does not
/// already exist. To create an in-memory database for tests, simply pass in
/// `sqlite::memory:` for the `db_uri`.
pub async fn sqlite_init(db_uri: &str) -> SqlitePool {
let sqlite_options = SqliteConnectOptions::from_str(db_uri)
.unwrap()
.create_if_missing(true)
// To see the sources for choosing these settings, see:
// https://github.com/shuttle-hq/shuttle/pull/623
.journal_mode(SqliteJournalMode::Wal)
.synchronous(SqliteSynchronous::Normal);

let pool = SqlitePool::connect_with(sqlite_options).await.unwrap();

MIGRATIONS.run(&pool).await.unwrap();

pool
}
Loading

0 comments on commit 0865c3b

Please sign in to comment.