diff --git a/Cargo.lock b/Cargo.lock index 602b0b04..04e74061 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -204,6 +204,21 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "auth" +version = "0.0.0" +dependencies = [ + "anyhow", + "api", + "axum", + "headers", + "http 1.1.0", + "http-body", + "jsonwebtoken", + "serde", + "tower-http", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -276,6 +291,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", + "headers", "http 1.1.0", "http-body", "http-body-util", @@ -874,8 +890,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -911,6 +929,30 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http 1.1.0", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http 1.1.0", +] + [[package]] name = "heck" version = "0.5.0" @@ -1176,6 +1218,21 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "9.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f" +dependencies = [ + "base64 0.21.7", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1221,9 +1278,9 @@ dependencies = [ "anyhow", "api", "async-trait", + "auth", "axum", "axum-extra", - "base64 0.22.1", "chrono", "clap", "http 1.1.0", @@ -1254,8 +1311,9 @@ version = "0.5.1" dependencies = [ "anyhow", "api", + "auth", "axum", - "base64 0.22.1", + "axum-extra", "chrono", "clap", "http 1.1.0", @@ -2348,6 +2406,18 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + [[package]] name = "slab" version = "0.4.9" diff --git a/Cargo.toml b/Cargo.toml index e16e4c0b..c0b600d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,8 @@ webrtc = { git = "https://github.com/webrtc-rs/webrtc", rev = "ae93e81" } anyhow = "1.0" clap = "4.5" +http = "1.1" +http-body = "1.0" serde = "1" tokio = "1.36" tracing = "0.1" diff --git a/conf/live777.toml b/conf/live777.toml index 6e927d33..a0f08133 100644 --- a/conf/live777.toml +++ b/conf/live777.toml @@ -21,15 +21,11 @@ urls = [ # WHIP/WHEP auth token # Headers["Authorization"] = "Bearer {token}" # [auth] +# JSON WEB TOKEN secret +# secret = "" +# static JWT token, superadmin, debuggger can use this token # tokens = ["live777"] -# Not WHIP/WHEP standard -# https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication#basic -# Headers["Authorization"] = "Basic {Base64.encode({username}:{password})}" -# [[auth.accounts]] -# username = "live777" -# password = "live777" - [log] # Env: `LOG_LEVEL` # Default: info diff --git a/conf/liveman.toml b/conf/liveman.toml index ffe19ee1..356ca794 100644 --- a/conf/liveman.toml +++ b/conf/liveman.toml @@ -8,11 +8,12 @@ # WHIP/WHEP auth token # Headers["Authorization"] = "Bearer {token}" # [auth] +# JSON WEB TOKEN secret +# secret = "" +# static JWT token, superadmin, debuggger can use this token # tokens = ["live777"] -# Not WHIP/WHEP standard -# https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication#basic -# Headers["Authorization"] = "Basic {Base64.encode({username}:{password})}" +# Admin Dashboard Accounts # [[auth.accounts]] # username = "live777" # password = "live777" diff --git a/libs/auth/Cargo.toml b/libs/auth/Cargo.toml new file mode 100644 index 00000000..54108798 --- /dev/null +++ b/libs/auth/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "auth" +edition.workspace = true + +[lib] +crate-type = ["lib"] + +[dependencies] +api = { path = "../api" } +anyhow = { workspace = true, features = ["backtrace"] } +http = { workspace = true } +http-body = { workspace = true } +axum = { version = "0.7" } +jsonwebtoken = "9.3" +serde = { workspace = true, features = ["serde_derive"] } + +headers = "0.4.0" +tower-http = { version = "0.5.2", features = ["validate-request"] } diff --git a/libs/auth/src/access.rs b/libs/auth/src/access.rs new file mode 100644 index 00000000..9a2c748d --- /dev/null +++ b/libs/auth/src/access.rs @@ -0,0 +1,44 @@ +use axum::{extract::Request, http, middleware::Next, response::Response}; +use http::method::Method; + +use crate::{ + claims::{Access, Claims}, + ANY_ID, +}; + +pub async fn access_middleware(request: Request, next: Next) -> Response { + let ok = match request.extensions().get::() { + Some(claims) => match (claims.id.clone(), request.method(), request.uri().path()) { + (id, &Method::GET, path) if path == api::path::streams(&id) => true, + (id, &Method::DELETE, path) if path == api::path::streams(&id) => { + Access::from(claims.mode).x + } + (id, &Method::POST, path) if path == api::path::whip(&id) => { + Access::from(claims.mode).w + } + (id, &Method::POST, path) if path == api::path::whep(&id) => { + Access::from(claims.mode).r + } + (id, &Method::POST, path) if path == api::path::cascade(&id) => { + Access::from(claims.mode).x + } + (id, _, _) if id == ANY_ID => true, + (id, &Method::POST, path) if path == "/token" && id == ANY_ID => { + Access::from(claims.mode).r + && Access::from(claims.mode).w + && Access::from(claims.mode).x + } + _ => false, + }, + None => false, + }; + + if !ok { + return Response::builder() + .status(http::StatusCode::FORBIDDEN) + .body("Don't permission".into()) + .unwrap(); + } + + next.run(request).await +} diff --git a/libs/auth/src/claims.rs b/libs/auth/src/claims.rs new file mode 100644 index 00000000..c87a0844 --- /dev/null +++ b/libs/auth/src/claims.rs @@ -0,0 +1,109 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Claims { + pub id: String, + pub exp: u64, + pub mode: Mode, +} + +impl Display for Claims { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "id: {}\nexpire: {}, mode: {}", + self.id, + self.exp, + Access::from(self.mode) + ) + } +} + +/// Look like Linux File-system permissions +/// 4: read, allow use whep subscribe +/// 2: write, allow use whip publish +/// 1: execute, allow use manager this, example: destroy +type Mode = u8; + +impl From for Access { + fn from(mask: Mode) -> Access { + Access { + r: mask & 4 != 0, + w: mask & 2 != 0, + x: mask & 1 != 0, + } + } +} + +pub struct Access { + pub r: bool, + pub w: bool, + pub x: bool, +} + +impl From for Mode { + fn from(v: Access) -> Mode { + let r = if v.r { 4 } else { 0 }; + let w = if v.w { 2 } else { 0 }; + let x = if v.x { 1 } else { 0 }; + r + w + x + } +} + +impl Display for Access { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}{}{}", + if self.r { "r" } else { "-" }, + if self.w { "w" } else { "-" }, + if self.x { "x" } else { "-" }, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mode() { + let mut access = Access::from(7); + assert!(access.r); + assert!(access.w); + assert!(access.x); + assert_eq!(format!("{}", access), "rwx"); + + access = Access::from(6); + assert!(access.r); + assert!(access.w); + assert!(!access.x); + assert_eq!(format!("{}", access), "rw-"); + + access = Access::from(5); + assert!(access.r); + assert!(!access.w); + assert!(access.x); + assert_eq!(format!("{}", access), "r-x"); + + access = Access::from(4); + assert!(access.r); + assert!(!access.w); + assert!(!access.x); + assert_eq!(format!("{}", access), "r--"); + + access = Access::from(1); + assert!(!access.r); + assert!(!access.w); + assert!(access.x); + assert_eq!(format!("{}", access), "--x"); + + access = Access::from(0); + assert!(!access.r); + assert!(!access.w); + assert!(!access.x); + assert_eq!(format!("{}", access), "---"); + } +} diff --git a/libs/auth/src/lib.rs b/libs/auth/src/lib.rs new file mode 100644 index 00000000..6eefc2bb --- /dev/null +++ b/libs/auth/src/lib.rs @@ -0,0 +1,108 @@ +use std::{collections::HashSet, marker::PhantomData}; + +use anyhow::{anyhow, Error}; +use headers::authorization::{Bearer, Credentials}; +use http::{header, Request, Response, StatusCode}; +use http_body::Body; +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use tower_http::validate_request::ValidateRequest; + +use crate::claims::Claims; + +pub mod access; +pub mod claims; + +pub const ANY_ID: &str = "*"; + +pub struct Keys { + encoding: EncodingKey, +} + +impl Keys { + pub fn new(secret: &[u8]) -> Self { + Self { + encoding: EncodingKey::from_secret(secret), + } + } + + pub fn token(self, claims: Claims) -> Result { + encode(&Header::default(), &claims, &self.encoding).map_err(|e| anyhow!(e)) + } +} + +pub struct ManyValidate { + tokens: HashSet, + decoding: DecodingKey, + _ty: PhantomData ResBody>, +} + +impl ManyValidate { + pub fn new(secret: String, tokens: Vec) -> Self + where + ResBody: Body + Default, + { + Self { + tokens: tokens.into_iter().collect(), + decoding: DecodingKey::from_secret(secret.as_bytes()), + _ty: PhantomData, + } + } +} + +impl Clone for ManyValidate { + fn clone(&self) -> Self { + Self { + tokens: self.tokens.clone(), + decoding: self.decoding.clone(), + _ty: PhantomData, + } + } +} + +impl ValidateRequest for ManyValidate +where + ResBody: Body + Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, request: &mut Request) -> Result<(), Response> { + if self.tokens.is_empty() { + request.extensions_mut().insert(Claims { + id: ANY_ID.to_string(), + exp: 0, + mode: 7, + }); + return Ok(()); + } + (match request.headers().get(header::AUTHORIZATION) { + Some(auth_header) => match Bearer::decode(auth_header) { + Some(bearer) if self.tokens.contains(bearer.token()) => { + // Static token is max permissions + request.extensions_mut().insert(Claims { + id: ANY_ID.to_string(), + exp: 0, + mode: 7, + }); + Ok(()) + } + Some(bearer) => { + match decode::(bearer.token(), &self.decoding, &Validation::default()) { + Ok(token_data) => { + request.extensions_mut().insert(token_data.claims); + Ok(()) + } + _ => Err(()), + } + } + _ => Err(()), + }, + _ => Err(()), + }) + .map_err(|_| { + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(ResBody::default()) + .unwrap() + }) + } +} diff --git a/libs/libwish/Cargo.toml b/libs/libwish/Cargo.toml index f66e0430..1abbee4e 100644 --- a/libs/libwish/Cargo.toml +++ b/libs/libwish/Cargo.toml @@ -7,9 +7,9 @@ crate-type = ["lib"] [dependencies] anyhow = { workspace = true } +http = { workspace = true } webrtc = { workspace = true } -http = "1.0.0" reqwest = { version = "0.12", features = [ "rustls-tls", ], default-features = false } diff --git a/libs/rtsp/Cargo.toml b/libs/rtsp/Cargo.toml index 13b017c8..385119bf 100644 --- a/libs/rtsp/Cargo.toml +++ b/libs/rtsp/Cargo.toml @@ -3,14 +3,14 @@ name = "rtsp" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] crate-type = ["lib"] [dependencies] -anyhow = "1.0" +anyhow = { workspace = true } + rtsp-types = "0.1.1" sdp = "0.6" tokio = "1.37" sdp-types = "0.1.6" -portpicker = "0.1.1" \ No newline at end of file +portpicker = "0.1.1" diff --git a/liveion/Cargo.toml b/liveion/Cargo.toml index 07b085e9..2acd7afe 100644 --- a/liveion/Cargo.toml +++ b/liveion/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["lib"] [dependencies] api = { path = "../libs/api" } +auth = { path = "../libs/auth" } http-log = { path = "../libs/http-log" } libwish = { path = "../libs/libwish" } signal = { path = "../libs/signal" } @@ -18,6 +19,8 @@ utils = { path = "../libs/utils" } anyhow = { workspace = true, features = ["backtrace"] } clap = { workspace = true, features = ["derive"] } +http = { workspace = true } +http-body = { workspace = true } serde = { workspace = true, features = ["serde_derive"] } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } @@ -26,10 +29,7 @@ webrtc = { workspace = true } async-trait = "0.1" axum = { version = "0.7", features = ["multipart", "tracing"] } axum-extra = { version = "0.9.3", features = ["query"] } -base64 = "0.22.1" chrono = "0.4" -http = "1.0.0" -http-body = "1.0.0" hyper = "1.2.0" lazy_static = "1.4.0" md5 = "0.7.0" diff --git a/liveion/src/auth.rs b/liveion/src/auth.rs deleted file mode 100644 index 630c0f8b..00000000 --- a/liveion/src/auth.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::{collections::HashSet, marker::PhantomData}; - -use http::{header, Request, Response, StatusCode}; -use http_body::Body; -use tower_http::validate_request::ValidateRequest; - -use crate::config::Auth; - -#[derive(Debug)] -pub struct ManyValidate { - header_values: HashSet, - _ty: PhantomData ResBody>, -} - -impl ManyValidate { - pub fn new(auths: Vec) -> Self - where - ResBody: Body + Default, - { - let mut header_values = HashSet::new(); - for auth in auths { - for authorization in auth.to_authorizations().into_iter() { - header_values.insert(authorization.parse().unwrap()); - } - } - Self { - header_values, - _ty: PhantomData, - } - } -} - -impl Clone for ManyValidate { - fn clone(&self) -> Self { - Self { - header_values: self.header_values.clone(), - _ty: PhantomData, - } - } -} - -impl ValidateRequest for ManyValidate -where - ResBody: Body + Default, -{ - type ResponseBody = ResBody; - - fn validate(&mut self, request: &mut Request) -> Result<(), Response> { - if self.header_values.is_empty() { - return Ok(()); - } - match request.headers().get(header::AUTHORIZATION) { - Some(actual) if self.header_values.contains(actual.to_str().unwrap()) => Ok(()), - _ => { - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::UNAUTHORIZED; - Err(res) - } - } - } -} diff --git a/liveion/src/config.rs b/liveion/src/config.rs index 7dbb44c3..69cdfa16 100644 --- a/liveion/src/config.rs +++ b/liveion/src/config.rs @@ -1,7 +1,6 @@ -use base64::engine::general_purpose::STANDARD; -use base64::Engine; -use serde::{Deserialize, Serialize}; use std::{env, fs, net::SocketAddr, str::FromStr}; + +use serde::{Deserialize, Serialize}; use webrtc::{ ice, ice_transport::{ice_credential_type::RTCIceCredentialType, ice_server::RTCIceServer}, @@ -43,39 +42,11 @@ pub struct Http { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Auth { #[serde(default)] - pub accounts: Vec, + pub secret: String, #[serde(default)] pub tokens: Vec, } -impl Auth { - pub fn to_authorizations(&self) -> Vec { - let mut authorizations = vec![]; - for account in self.accounts.iter() { - authorizations.push(account.to_authorization()); - } - for token in self.tokens.iter() { - authorizations.push(format!("Bearer {}", token)); - } - authorizations - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Account { - #[serde(default)] - pub username: String, - #[serde(default)] - pub password: String, -} - -impl Account { - pub fn to_authorization(&self) -> String { - let encoded = STANDARD.encode(format!("{}:{}", self.username, self.password)); - format!("Basic {}", encoded) - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Log { #[serde(default = "default_log_level")] diff --git a/liveion/src/convert.rs b/liveion/src/convert.rs index a1691efc..fd9aba41 100644 --- a/liveion/src/convert.rs +++ b/liveion/src/convert.rs @@ -59,7 +59,7 @@ impl From for api::response::CascadeInfo { impl From for NodeMetaData { fn from(value: Config) -> Self { Self { - authorization: value.auth.to_authorizations().first().cloned(), + authorization: value.auth.tokens.first().cloned(), } } } diff --git a/liveion/src/lib.rs b/liveion/src/lib.rs index e7f99480..448b7cb2 100644 --- a/liveion/src/lib.rs +++ b/liveion/src/lib.rs @@ -1,4 +1,5 @@ use axum::extract::Request; +use axum::middleware; use axum::response::IntoResponse; use axum::routing::get; use axum::Router; @@ -14,7 +15,8 @@ use tower_http::trace::TraceLayer; use tower_http::validate_request::ValidateRequestHeaderLayer; use tracing::{error, info_span}; -use crate::auth::ManyValidate; +use auth::{access::access_middleware, ManyValidate}; + use crate::config::Config; use crate::route::{admin, session, whep, whip, AppState}; @@ -26,7 +28,6 @@ struct Assets; pub mod config; -mod auth; mod constant; mod convert; mod error; @@ -46,7 +47,8 @@ where stream_manager: Arc::new(Manager::new(cfg.clone()).await), config: cfg.clone(), }; - let auth_layer = ValidateRequestHeaderLayer::custom(ManyValidate::new(vec![cfg.auth])); + let auth_layer = + ValidateRequestHeaderLayer::custom(ManyValidate::new(cfg.auth.secret, cfg.auth.tokens)); let mut app = Router::new() .merge( whip::route() @@ -54,6 +56,7 @@ where .merge(session::route()) .merge(admin::route()) .merge(crate::route::stream::route()) + .layer(middleware::from_fn(access_middleware)) .layer(auth_layer), ) .route(api::path::METRICS, get(metrics)) diff --git a/liveman/Cargo.toml b/liveman/Cargo.toml index 25d58108..44b51c65 100644 --- a/liveman/Cargo.toml +++ b/liveman/Cargo.toml @@ -10,21 +10,22 @@ repository.workspace = true liveion = { path = "../liveion", optional = true } api = { path = "../libs/api" } +auth = { path = "../libs/auth" } http-log = { path = "../libs/http-log" } signal = { path = "../libs/signal" } utils = { path = "../libs/utils" } anyhow = { workspace = true, features = ["backtrace"] } clap = { workspace = true, features = ["derive"] } +http = { workspace = true } +http-body = { workspace = true } serde = { workspace = true, features = ["serde_derive"] } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } axum = { version = "0.7", features = ["multipart", "tracing"] } -base64 = "0.22.1" +axum-extra = { version = "0.9.3", features = ["typed-header"] } chrono = { version = "0.4", features = ["serde"] } -http = "1.0.0" -http-body = "1.0.0" hyper-util = { version = "0.1", features = ["client-legacy"] } mime_guess = "2.0.4" reqwest = { version = "0.12", features = [ diff --git a/liveman/src/admin.rs b/liveman/src/admin.rs new file mode 100644 index 00000000..6a21398e --- /dev/null +++ b/liveman/src/admin.rs @@ -0,0 +1,144 @@ +use std::time::{Duration, SystemTime}; + +use axum::{ + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use tracing::{debug, error}; + +use auth::{ + claims::{Access, Claims}, + Keys, ANY_ID, +}; + +use crate::{config::Account, AppState}; + +const JWT_TOKEN_EXPIRES: Duration = Duration::from_secs(60 * 60 * 24); + +pub async fn authorize( + State(state): State, + Json(payload): Json, +) -> Result, AuthError> { + // Check if the user sent the credentials + if payload.username.is_empty() || payload.password.is_empty() { + return Err(AuthError::MissingCredentials); + } + // Here you can check the user credentials from a database + let mut user: Option<&Account> = None; + for account in state.config.auth.accounts.iter() { + if payload.username == account.username && payload.password == account.password { + user = Some(account); + } + } + + if user.is_none() { + return Err(AuthError::WrongCredentials); + } + + debug!("User UID: {:?}", user); + + let keys = Keys::new(state.config.auth.secret.as_bytes()); + let token = keys + .token(Claims { + id: ANY_ID.to_string(), + exp: (SystemTime::now() + JWT_TOKEN_EXPIRES) + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + mode: 7, + }) + .map_err(|err| { + error!("Error while encoding: {err}"); + AuthError::TokenCreation + })?; + + // Send the authorized token + Ok(Json(AuthBody::new(token))) +} + +pub async fn token( + State(state): State, + Json(payload): Json, +) -> Result, AuthError> { + let keys = Keys::new(state.config.auth.secret.as_bytes()); + let token = keys.token(payload.into()).map_err(|err| { + error!("Error while encoding: {err}"); + AuthError::TokenCreation + })?; + + // Send the authorized token + Ok(Json(AuthBody::new(token))) +} + +impl IntoResponse for AuthError { + fn into_response(self) -> Response { + let (status, error_message) = match self { + AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), + AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"), + AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), + }; + let body = Json(json!({ + "error": error_message, + })); + (status, body).into_response() + } +} + +#[derive(Debug, Deserialize)] +pub struct TokenPayload { + id: String, + duration: u64, + subscribe: bool, + publish: bool, + admin: bool, +} + +impl From for Claims { + fn from(v: TokenPayload) -> Self { + Self { + id: v.id, + exp: (SystemTime::now() + Duration::from_secs(v.duration)) + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + mode: (Access { + r: v.subscribe, + w: v.publish, + x: v.admin, + }) + .into(), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct AuthPayload { + username: String, + password: String, +} + +#[derive(Debug, Serialize)] +pub struct AuthBody { + access_token: String, + token_type: String, +} + +impl AuthBody { + fn new(access_token: String) -> Self { + Self { + access_token, + token_type: "Bearer".to_string(), + } + } +} + +#[derive(Debug)] +pub enum AuthError { + WrongCredentials, + MissingCredentials, + TokenCreation, +} diff --git a/liveman/src/auth.rs b/liveman/src/auth.rs deleted file mode 100644 index 630c0f8b..00000000 --- a/liveman/src/auth.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::{collections::HashSet, marker::PhantomData}; - -use http::{header, Request, Response, StatusCode}; -use http_body::Body; -use tower_http::validate_request::ValidateRequest; - -use crate::config::Auth; - -#[derive(Debug)] -pub struct ManyValidate { - header_values: HashSet, - _ty: PhantomData ResBody>, -} - -impl ManyValidate { - pub fn new(auths: Vec) -> Self - where - ResBody: Body + Default, - { - let mut header_values = HashSet::new(); - for auth in auths { - for authorization in auth.to_authorizations().into_iter() { - header_values.insert(authorization.parse().unwrap()); - } - } - Self { - header_values, - _ty: PhantomData, - } - } -} - -impl Clone for ManyValidate { - fn clone(&self) -> Self { - Self { - header_values: self.header_values.clone(), - _ty: PhantomData, - } - } -} - -impl ValidateRequest for ManyValidate -where - ResBody: Body + Default, -{ - type ResponseBody = ResBody; - - fn validate(&mut self, request: &mut Request) -> Result<(), Response> { - if self.header_values.is_empty() { - return Ok(()); - } - match request.headers().get(header::AUTHORIZATION) { - Some(actual) if self.header_values.contains(actual.to_str().unwrap()) => Ok(()), - _ => { - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::UNAUTHORIZED; - Err(res) - } - } - } -} diff --git a/liveman/src/config.rs b/liveman/src/config.rs index 02772265..d2a72c62 100644 --- a/liveman/src/config.rs +++ b/liveman/src/config.rs @@ -1,5 +1,3 @@ -use base64::engine::general_purpose::STANDARD; -use base64::Engine; use serde::{Deserialize, Serialize}; use std::{env, fs, net::SocketAddr, str::FromStr}; @@ -30,22 +28,11 @@ pub struct Http { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Auth { #[serde(default)] - pub accounts: Vec, + pub secret: String, #[serde(default)] pub tokens: Vec, -} - -impl Auth { - pub fn to_authorizations(&self) -> Vec { - let mut authorizations = vec![]; - for account in self.accounts.iter() { - authorizations.push(account.to_authorization()); - } - for token in self.tokens.iter() { - authorizations.push(format!("Bearer {}", token)); - } - authorizations - } + #[serde(default)] + pub accounts: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -56,13 +43,6 @@ pub struct Account { pub password: String, } -impl Account { - pub fn to_authorization(&self) -> String { - let encoded = STANDARD.encode(format!("{}:{}", self.username, self.password)); - format!("Basic {}", encoded) - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Log { #[serde(default = "default_log_level")] diff --git a/liveman/src/main.rs b/liveman/src/main.rs index 43337c73..5acb7985 100644 --- a/liveman/src/main.rs +++ b/liveman/src/main.rs @@ -1,7 +1,10 @@ use axum::body::Body; use axum::extract::Request; +use axum::middleware; use axum::response::IntoResponse; +use axum::routing::post; use axum::Router; + use clap::Parser; use http::{header, StatusCode, Uri}; use hyper_util::client::legacy::connect::HttpConnector; @@ -14,7 +17,9 @@ use tower_http::trace::TraceLayer; use tower_http::validate_request::ValidateRequestHeaderLayer; use tracing::{debug, error, info, info_span, warn}; -use crate::auth::ManyValidate; +use auth::{access::access_middleware, ManyValidate}; + +use crate::admin::{authorize, token}; use crate::config::Config; use crate::mem::{MemStorage, Server}; @@ -22,7 +27,7 @@ use crate::mem::{MemStorage, Server}; #[folder = "../assets/liveman/"] struct Assets; -mod auth; +mod admin; mod config; mod error; mod mem; @@ -134,15 +139,23 @@ where client, storage: MemStorage::new(cfg.nodes), }; - let auth_layer = ValidateRequestHeaderLayer::custom(ManyValidate::new(vec![cfg.auth])); + + let auth_layer = + ValidateRequestHeaderLayer::custom(ManyValidate::new(cfg.auth.secret, cfg.auth.tokens)); let mut app = Router::new() - .merge(route::proxy::route().layer(auth_layer)) - .with_state(app_state.clone()) + .merge( + route::proxy::route() + .route("/token", post(token)) + .layer(middleware::from_fn(access_middleware)) + .layer(auth_layer), + ) .layer(if cfg.http.cors { CorsLayer::permissive() } else { CorsLayer::new() }) + .route("/login", post(authorize)) + .with_state(app_state.clone()) .layer(axum::middleware::from_fn(http_log::print_request_response)) .layer( TraceLayer::new_for_http().make_span_with(|request: &Request<_>| {