diff --git a/Cargo.lock b/Cargo.lock index 16dd5809..a63eb4d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,6 +612,22 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "commonlib" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "env_logger", + "failure", + "indexmap 1.9.3", + "job_scheduler_ng", + "log", + "md5", + "serde", + "serde_derive", +] + [[package]] name = "const-oid" version = "0.9.5" @@ -1399,6 +1415,7 @@ dependencies = [ "axum 0.7.4", "byteorder", "bytes", + "commonlib", "failure", "log", "rtmp 0.4.1", @@ -1522,6 +1539,7 @@ dependencies = [ "axum 0.7.4", "byteorder", "bytes", + "commonlib", "failure", "futures", "log", @@ -1833,6 +1851,12 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.5.0" @@ -2484,6 +2508,7 @@ dependencies = [ "bytes", "bytesio 0.3.1", "chrono", + "commonlib", "failure", "h264-decoder 0.2.0", "hex", @@ -3930,6 +3955,7 @@ dependencies = [ "anyhow", "axum 0.6.10", "clap", + "commonlib", "env_logger_extend", "failure", "hls", @@ -3968,6 +3994,7 @@ dependencies = [ "bytes", "bytesio 0.3.1", "chrono", + "commonlib", "failure", "hex", "http 0.2.9", @@ -3987,6 +4014,7 @@ dependencies = [ "byteorder", "bytes", "bytesio 0.3.1", + "commonlib", "failure", "fdk-aac", "http 0.2.9", diff --git a/Cargo.toml b/Cargo.toml index 622e62f8..04e3953b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,4 +15,5 @@ members = [ "library/codec/h264", "library/logger", "library/streamhub", + "library/common", ] diff --git a/application/xiu/Cargo.toml b/application/xiu/Cargo.toml index 52f4e828..4cf275f3 100644 --- a/application/xiu/Cargo.toml +++ b/application/xiu/Cargo.toml @@ -28,6 +28,7 @@ tokio-metrics = { version = "0.2.0", default-features = false } env_logger_extend = { path = "../../library/logger/" } streamhub = { path = "../../library/streamhub/" } +commonlib = { path = "../../library/common/" } rtmp = { path = "../../protocol/rtmp/" } xrtsp = { path = "../../protocol/rtsp/" } xwebrtc = { path = "../../protocol/webrtc/" } diff --git a/application/xiu/src/config/config.toml b/application/xiu/src/config/config.toml index 8b1b193b..89aa903d 100644 --- a/application/xiu/src/config/config.toml +++ b/application/xiu/src/config/config.toml @@ -6,6 +6,11 @@ enabled = true port = 1935 gop_num = 0 +[rtmp.auth] +pull_enabled = true +push_enabled = true +# simple or md5 +algorithm = "simple" # pull streams from other server node. [rtmp.pull] enabled = false @@ -28,6 +33,12 @@ on_unpublish = "http://localhost:3001/on_unpuslish" on_play = "http://localhost:3001/on_play" on_stop = "http://localhost:3001/on_stop" +[authsecret] +# used for md5 authentication +key = "" +# used for simple authentication +password = "" + ########################## # RTSP configurations # @@ -35,6 +46,11 @@ on_stop = "http://localhost:3001/on_stop" [rtsp] enabled = false port = 445 +[rtsp.auth] +pull_enabled = true +push_enabled = true +# simple or md5 +algorithm = "simple" ########################## # WebRTC configurations # @@ -42,6 +58,11 @@ port = 445 [webrtc] enabled = false port = 8083 +[webrtc.auth] +pull_enabled = true +push_enabled = true +# simple or md5 +algorithm = "simple" ########################## # HTTPFLV configurations # @@ -49,6 +70,11 @@ port = 8083 [httpflv] enabled = false port = 8081 +[httpflv.auth] +pull_enabled = true +# simple or md5 +algorithm = "simple" + ########################## # HLS configurations # @@ -56,7 +82,12 @@ port = 8081 [hls] enabled = false port = 8080 -need_record = true +need_record = false +[hls.auth] +pull_enabled = true +# simple or md5 +algorithm = "simple" + ########################## # LOG configurations # diff --git a/application/xiu/src/config/mod.rs b/application/xiu/src/config/mod.rs index a2757f54..c51160e7 100644 --- a/application/xiu/src/config/mod.rs +++ b/application/xiu/src/config/mod.rs @@ -1,5 +1,6 @@ pub mod errors; +use commonlib::auth::AuthAlgorithm; use errors::ConfigError; use serde_derive::Deserialize; use std::fs; @@ -14,6 +15,7 @@ pub struct Config { pub hls: Option, pub httpapi: Option, pub httpnotify: Option, + pub authsecret: AuthSecretConfig, pub log: Option, } @@ -34,6 +36,7 @@ impl Config { port: rtmp_port, pull: None, push: None, + auth: None, }); } @@ -42,6 +45,7 @@ impl Config { rtsp_config = Some(RtspConfig { enabled: true, port: rtsp_port, + auth: None, }); } @@ -50,6 +54,7 @@ impl Config { webrtc_config = Some(WebRTCConfig { enabled: true, port: webrtc_port, + auth: None, }); } @@ -58,6 +63,7 @@ impl Config { httpflv_config = Some(HttpFlvConfig { enabled: true, port: httpflv_port, + auth: None, }); } @@ -67,6 +73,7 @@ impl Config { enabled: true, port: hls_port, need_record: false, + auth: None, }); } @@ -83,6 +90,7 @@ impl Config { hls: hls_config, httpapi: None, httpnotify: None, + authsecret: AuthSecretConfig::default(), log: log_config, } } @@ -95,6 +103,7 @@ pub struct RtmpConfig { pub gop_num: Option, pub pull: Option, pub push: Option>, + pub auth: Option, } #[derive(Debug, Deserialize, Clone)] pub struct RtmpPullConfig { @@ -113,18 +122,21 @@ pub struct RtmpPushConfig { pub struct RtspConfig { pub enabled: bool, pub port: usize, + pub auth: Option, } #[derive(Debug, Deserialize, Clone)] pub struct WebRTCConfig { pub enabled: bool, pub port: usize, + pub auth: Option, } #[derive(Debug, Deserialize, Clone)] pub struct HttpFlvConfig { pub enabled: bool, pub port: usize, + pub auth: Option, } #[derive(Debug, Deserialize, Clone)] @@ -133,6 +145,7 @@ pub struct HlsConfig { pub port: usize, //record or not pub need_record: bool, + pub auth: Option, } pub enum LogLevel { @@ -170,6 +183,19 @@ pub struct HttpNotifierConfig { pub on_stop: Option, } +#[derive(Debug, Deserialize, Clone, Default)] +pub struct AuthSecretConfig { + pub key: String, + pub password: String, +} + +#[derive(Debug, Deserialize, Clone, Default)] +pub struct AuthConfig { + pub pull_enabled: bool, + pub push_enabled: Option, + pub algorithm: AuthAlgorithm, +} + pub fn load(cfg_path: &String) -> Result { let content = fs::read_to_string(cfg_path)?; let decoded_config = toml::from_str(&content[..]).unwrap(); @@ -184,15 +210,13 @@ fn test_toml_parse() { Err(err) => println!("{}", err), } - let str = fs::read_to_string( - "./src/config/config.toml", - ); + let str = fs::read_to_string("./src/config/config.toml"); match str { Ok(val) => { println!("++++++{val}\n"); let decoded: Config = toml::from_str(&val[..]).unwrap(); - + println!("whole config: {:?}", decoded); let rtmp = decoded.httpnotify; if let Some(val) = rtmp { diff --git a/application/xiu/src/service.rs b/application/xiu/src/service.rs index 774d1aa4..047e1e40 100644 --- a/application/xiu/src/service.rs +++ b/application/xiu/src/service.rs @@ -1,10 +1,14 @@ +use commonlib::auth::AuthType; use rtmp::remuxer::RtmpRemuxer; +use crate::config::{AuthConfig, AuthSecretConfig}; + use { super::api, super::config::Config, //https://rustcc.cn/article?id=6dcbf032-0483-4980-8bfe-c64a7dfb33c7 anyhow::Result, + commonlib::auth::Auth, hls::remuxer::HlsRemuxer, hls::server as hls_server, httpflv::server as httpflv_server, @@ -27,6 +31,35 @@ impl Service { Service { cfg } } + fn gen_auth(auth_config: &Option, authsecret: &AuthSecretConfig) -> Option { + if let Some(cfg) = auth_config { + let auth_type = if let Some(push_enabled) = cfg.push_enabled { + if push_enabled && cfg.pull_enabled { + AuthType::Both + } else if !push_enabled && !cfg.pull_enabled { + AuthType::None + } else if push_enabled && !cfg.pull_enabled { + AuthType::Push + } else { + AuthType::Pull + } + } else { + match cfg.pull_enabled { + true => AuthType::Pull, + false => AuthType::None, + } + }; + Some(Auth::new( + authsecret.key.clone(), + authsecret.password.clone(), + cfg.algorithm.clone(), + auth_type, + )) + } else { + None + } + } + pub async fn run(&mut self) -> Result<()> { let notifier = if let Some(httpnotifier) = &self.cfg.httpnotify { if !httpnotifier.enabled { @@ -146,7 +179,8 @@ impl Service { let listen_port = rtmp_cfg_value.port; let address = format!("0.0.0.0:{listen_port}"); - let mut rtmp_server = RtmpServer::new(address, producer, gop_num); + let auth = Self::gen_auth(&rtmp_cfg_value.auth, &self.cfg.authsecret); + let mut rtmp_server = RtmpServer::new(address, producer, gop_num, auth); tokio::spawn(async move { if let Err(err) = rtmp_server.run().await { log::error!("rtmp server error: {}", err); @@ -213,7 +247,8 @@ impl Service { let listen_port = rtsp_cfg_value.port; let address = format!("0.0.0.0:{listen_port}"); - let mut rtsp_server = RtspServer::new(address, producer); + let auth = Self::gen_auth(&rtsp_cfg_value.auth, &self.cfg.authsecret); + let mut rtsp_server = RtspServer::new(address, producer, auth); tokio::spawn(async move { if let Err(err) = rtsp_server.run().await { log::error!("rtsp server error: {}", err); @@ -237,7 +272,8 @@ impl Service { let listen_port = webrtc_cfg_value.port; let address = format!("0.0.0.0:{listen_port}"); - let mut webrtc_server = WebRTCServer::new(address, producer); + let auth = Self::gen_auth(&webrtc_cfg_value.auth, &self.cfg.authsecret); + let mut webrtc_server = WebRTCServer::new(address, producer, auth); tokio::spawn(async move { if let Err(err) = webrtc_server.run().await { log::error!("webrtc server error: {}", err); @@ -258,8 +294,9 @@ impl Service { let port = httpflv_cfg_value.port; let event_producer = stream_hub.get_hub_event_sender(); + let auth = Self::gen_auth(&httpflv_cfg_value.auth, &self.cfg.authsecret); tokio::spawn(async move { - if let Err(err) = httpflv_server::run(event_producer, port).await { + if let Err(err) = httpflv_server::run(event_producer, port, auth).await { log::error!("httpflv server error: {}", err); } }); @@ -291,9 +328,9 @@ impl Service { }); let port = hls_cfg_value.port; - + let auth = Self::gen_auth(&hls_cfg_value.auth, &self.cfg.authsecret); tokio::spawn(async move { - if let Err(err) = hls_server::run(port).await { + if let Err(err) = hls_server::run(port, auth).await { log::error!("hls server error: {}", err); } }); diff --git a/confs/local/hls.Cargo.toml b/confs/local/hls.Cargo.toml index 0649d4b4..de28e0bc 100644 --- a/confs/local/hls.Cargo.toml +++ b/confs/local/hls.Cargo.toml @@ -22,6 +22,7 @@ streamhub = { path = "../../library/streamhub/" } xmpegts = { path = "../../library/container/mpegts/" } xflv = { path = "../../library/container/flv/" } rtmp = { path = "../rtmp/" } +commonlib = { path = "../../library/common/" } [dependencies.tokio] version = "1.4.0" diff --git a/confs/local/httpflv.Cargo.toml b/confs/local/httpflv.Cargo.toml index ca4e1792..c90c2dbb 100644 --- a/confs/local/httpflv.Cargo.toml +++ b/confs/local/httpflv.Cargo.toml @@ -21,7 +21,7 @@ futures = "0.3" streamhub = { path = "../../library/streamhub/" } xflv = { path = "../../library/container/flv/" } rtmp = { path = "../rtmp/" } #"0.0.4" - +commonlib = { path = "../../library/common/" } [dependencies.tokio] version = "1.4.0" default-features = false diff --git a/confs/local/rtmp.Cargo.toml b/confs/local/rtmp.Cargo.toml index 0fcbab86..8992a5f7 100644 --- a/confs/local/rtmp.Cargo.toml +++ b/confs/local/rtmp.Cargo.toml @@ -35,6 +35,7 @@ bytesio = { path = "../../library/bytesio/" } streamhub = { path = "../../library/streamhub/" } h264-decoder = { path = "../../library/codec/h264/" } xflv = { path = "../../library/container/flv/" } +commonlib = { path = "../../library/common/" } [dependencies.tokio] version = "1.4.0" diff --git a/confs/local/rtsp.Cargo.toml b/confs/local/rtsp.Cargo.toml index 33cc60df..f52a73d6 100644 --- a/confs/local/rtsp.Cargo.toml +++ b/confs/local/rtsp.Cargo.toml @@ -22,3 +22,4 @@ hex = "0.4.3" bytesio = { path = "../../library/bytesio/" } streamhub = { path = "../../library/streamhub/" } +commonlib = { path = "../../library/common/" } \ No newline at end of file diff --git a/confs/local/webrtc.Cargo.toml b/confs/local/webrtc.Cargo.toml index 1b9d6eeb..7718dec8 100644 --- a/confs/local/webrtc.Cargo.toml +++ b/confs/local/webrtc.Cargo.toml @@ -25,3 +25,4 @@ opus = "0.3.0" bytesio = { path = "../../library/bytesio/" } streamhub = { path = "../../library/streamhub/" } xflv = { path = "../../library/container/flv/" } +commonlib = { path = "../../library/common/" } diff --git a/confs/local/xiu.Cargo.toml b/confs/local/xiu.Cargo.toml index 52f4e828..a80459a0 100644 --- a/confs/local/xiu.Cargo.toml +++ b/confs/local/xiu.Cargo.toml @@ -33,6 +33,7 @@ xrtsp = { path = "../../protocol/rtsp/" } xwebrtc = { path = "../../protocol/webrtc/" } httpflv = { path = "../../protocol/httpflv/" } hls = { path = "../../protocol/hls/" } +commonlib = { path = "../../library/common/" } [features] default = ["std"] diff --git a/confs/online/common.Cargo.toml b/confs/online/common.Cargo.toml new file mode 100644 index 00000000..c940b211 --- /dev/null +++ b/confs/online/common.Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "commonlib" +version = "0.1.0" +authors = ["HarlanC "] +edition = "2018" +description = "a common library for xiu project." +license = "MIT" +repository = "https://github.com/harlanc/xiu" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "^1.0" +env_logger = "0.10.0" +job_scheduler_ng = "2.0.4" +chrono = "0.4" +failure = "0.1.1" +log = "0.4.0" +indexmap = "1.9.3" +md5 = "0.7.0" +serde = { version = "1.0", features = ["derive", "rc"] } +serde_derive = "1.0" diff --git a/confs/online/hls.Cargo.toml b/confs/online/hls.Cargo.toml index 97e210ae..fb057a53 100644 --- a/confs/online/hls.Cargo.toml +++ b/confs/online/hls.Cargo.toml @@ -1,7 +1,7 @@ [package] name = "hls" description = "hls library." -version = "0.4.3" +version = "0.5.1" authors = ["HarlanC "] @@ -26,3 +26,4 @@ hex = "0.4.3" bytesio = "0.3.1" streamhub = "0.2.0" +commonlib = "0.1.0" diff --git a/confs/online/webrtc.Cargo.toml b/confs/online/webrtc.Cargo.toml index 9a77723e..d7d8ebf4 100644 --- a/confs/online/webrtc.Cargo.toml +++ b/confs/online/webrtc.Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xwebrtc" -version = "0.2.0" +version = "0.3.0" description = "A whip/whep library." edition = "2021" authors = ["HarlanC "] @@ -25,3 +25,4 @@ opus = "0.3.0" bytesio = "0.3.1" streamhub = "0.2.0" xflv = "0.4.0" +commonlib = "0.1.0" diff --git a/confs/online/xiu.Cargo.toml b/confs/online/xiu.Cargo.toml index 322c0c90..fffc9655 100644 --- a/confs/online/xiu.Cargo.toml +++ b/confs/online/xiu.Cargo.toml @@ -1,7 +1,7 @@ [package] name = "xiu" description = "A powerful live server by Rust ." -version = "0.10.0" +version = "0.11.0" authors = ["HarlanC "] +edition = "2018" +description = "a common library for xiu project." +license = "MIT" +repository = "https://github.com/harlanc/xiu" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "^1.0" +env_logger = "0.10.0" +job_scheduler_ng = "2.0.4" +chrono = "0.4" +failure = "0.1.1" +log = "0.4.0" +indexmap = "1.9.3" +md5 = "0.7.0" +serde = { version = "1.0", features = ["derive", "rc"] } +serde_derive = "1.0" diff --git a/library/common/README.md b/library/common/README.md new file mode 100644 index 00000000..bfd5bc80 --- /dev/null +++ b/library/common/README.md @@ -0,0 +1,4 @@ +A common library. + +## v0.1.0 +- The first version including HTTP/auth logics. \ No newline at end of file diff --git a/library/common/src/auth.rs b/library/common/src/auth.rs new file mode 100644 index 00000000..eb0cb738 --- /dev/null +++ b/library/common/src/auth.rs @@ -0,0 +1,105 @@ +use indexmap::IndexMap; +use md5; +use serde_derive::Deserialize; + +use crate::errors::{AuthError, AuthErrorValue}; +use crate::scanf; + +#[derive(Debug, Deserialize, Clone, Default)] +pub enum AuthAlgorithm { + #[default] + #[serde(rename = "simple")] + Simple, + #[serde(rename = "md5")] + Md5, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum AuthType { + Pull, + Push, + Both, + None, +} +#[derive(Debug, Clone)] +pub struct Auth { + algorithm: AuthAlgorithm, + key: String, + password: String, + pub auth_type: AuthType, +} + +impl Auth { + pub fn new( + key: String, + password: String, + algorithm: AuthAlgorithm, + auth_type: AuthType, + ) -> Self { + Self { + algorithm, + key, + password, + auth_type, + } + } + + pub fn authenticate( + &self, + stream_name: &String, + query: &Option, + is_pull: bool, + ) -> Result<(), AuthError> { + if self.auth_type == AuthType::Both + || is_pull && (self.auth_type == AuthType::Pull) + || !is_pull && (self.auth_type == AuthType::Push) + { + let mut auth_err_reason: String = String::from("there is no token str found."); + let mut err: AuthErrorValue = AuthErrorValue::NoTokenFound; + + /*Here we should do auth and it must be successful. */ + if let Some(query_val) = query { + let mut query_pairs = IndexMap::new(); + let pars_array: Vec<&str> = query_val.split('&').collect(); + for ele in pars_array { + let (k, v) = scanf!(ele, '=', String, String); + if k.is_none() || v.is_none() { + continue; + } + query_pairs.insert(k.unwrap(), v.unwrap()); + } + + if let Some(token) = query_pairs.get("token") { + if self.check(stream_name, token) { + return Ok(()); + } + auth_err_reason = format!("token is not correct: {}", token); + err = AuthErrorValue::TokenIsNotCorrect; + } + } + + log::error!( + "Auth error stream_name: {} auth type: {:?} pull: {} reason: {}", + stream_name, + self.auth_type, + is_pull, + auth_err_reason, + ); + return Err(AuthError { value: err }); + } + Ok(()) + } + + fn check(&self, stream_name: &String, auth_str: &str) -> bool { + match self.algorithm { + AuthAlgorithm::Simple => { + self.password == auth_str + } + AuthAlgorithm::Md5 => { + let raw_data = format!("{}{}", self.key, stream_name); + let digest_str = format!("{:x}", md5::compute(raw_data)); + auth_str == digest_str + } + } + } +} diff --git a/protocol/webrtc/src/http/define.rs b/library/common/src/define.rs similarity index 100% rename from protocol/webrtc/src/http/define.rs rename to library/common/src/define.rs diff --git a/library/common/src/errors.rs b/library/common/src/errors.rs new file mode 100644 index 00000000..56b21a37 --- /dev/null +++ b/library/common/src/errors.rs @@ -0,0 +1,31 @@ +use failure::{Backtrace, Fail}; +use std::fmt; + +#[derive(Debug)] +pub struct AuthError { + pub value: AuthErrorValue, +} + +#[derive(Debug, Fail)] +pub enum AuthErrorValue { + #[fail(display = "token is not correct.")] + TokenIsNotCorrect, + #[fail(display = "no token found.")] + NoTokenFound, +} + +impl fmt::Display for AuthError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.value, f) + } +} + +impl Fail for AuthError { + fn cause(&self) -> Option<&dyn Fail> { + self.value.cause() + } + + fn backtrace(&self) -> Option<&Backtrace> { + self.value.backtrace() + } +} diff --git a/protocol/webrtc/src/http/mod.rs b/library/common/src/http.rs similarity index 63% rename from protocol/webrtc/src/http/mod.rs rename to library/common/src/http.rs index 6c3521e8..33ac07f9 100644 --- a/protocol/webrtc/src/http/mod.rs +++ b/library/common/src/http.rs @@ -1,12 +1,6 @@ -pub mod define; +use crate::scanf; use indexmap::IndexMap; - -macro_rules! scanf { - ( $string:expr, $sep:expr, $( $x:ty ),+ ) => {{ - let mut iter = $string.split($sep); - ($(iter.next().and_then(|word| word.parse::<$x>().ok()),)*) - }} -} +use std::fmt; pub trait Unmarshal { fn unmarshal(request_data: &str) -> Option @@ -18,15 +12,142 @@ pub trait Marshal { fn marshal(&self) -> String; } +#[derive(Debug, Clone, Default)] +pub enum Schema { + //used for webrtc(WHIP/WHEP) + WEBRTC, + RTSP, + #[default] + UNKNOWN, +} + +impl fmt::Display for Schema { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Schema::RTSP => { + write!(f, "rtsp") + } + //Because webrtc request uri does not contain the schema name, so here write empty string. + Schema::WEBRTC => { + write!(f, "") + } + Schema::UNKNOWN => { + write!(f, "unknown") + } + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct Uri { + pub schema: Schema, + pub host: String, + pub port: Option, + pub path: String, + pub query: Option, +} + +impl Unmarshal for Uri { + /* + RTSP HTTP header : "ANNOUNCE rtsp://127.0.0.1:5544/stream RTSP/1.0\r\n\ + the uri rtsp://127.0.0.1:5544/stream is standard with schema://host:port/path?query. + + WEBRTC is special: + "POST /whep?app=live&stream=test HTTP/1.1\r\n\ + Host: localhost:3000\r\n\ + Accept: \r\n\" + It only contains path?query after HTTP method, host:port is saved in the Host parameter. + In this function, for Webrtc we only parse path?query, host:port will be parsed in the HTTPRequest + unmarshal method. + */ + fn unmarshal(url: &str) -> Option { + let mut uri = Uri::default(); + + /*first judge the correct schema */ + if url.starts_with("rtsp://") { + uri.schema = Schema::RTSP; + } else if url.starts_with("/whip") || url.starts_with("/whep") { + uri.schema = Schema::WEBRTC; + } else { + log::warn!("cannot judge the schema: {}", url); + uri.schema = Schema::UNKNOWN; + } + + let path_with_query = match uri.schema { + Schema::RTSP => { + let rtsp_path_with_query = if let Some(rtsp_url_without_prefix) = + url.strip_prefix("rtsp://") + { + /*split host:port and path?query*/ + + if let Some(index) = rtsp_url_without_prefix.find('/') { + let path_with_query = &rtsp_url_without_prefix[index + 1..]; + /*parse host and port*/ + let host_with_port = &rtsp_url_without_prefix[..index]; + let (host_val, port_val) = scanf!(host_with_port, ':', String, u16); + if let Some(host) = host_val { + uri.host = host; + } + if let Some(port) = port_val { + uri.port = Some(port); + } + + path_with_query + } else { + log::error!("cannot find split '/' for host:port and path?query."); + return None; + } + } else { + log::error!("cannot find RTSP prefix."); + return None; + }; + rtsp_path_with_query + } + Schema::WEBRTC => url, + Schema::UNKNOWN => url, + }; + + let path_data: Vec<&str> = path_with_query.splitn(2, '?').collect(); + uri.path = path_data[0].to_string(); + + if path_data.len() > 1 { + uri.query = Some(path_data[1].to_string()); + } + + Some(uri) + } +} + +impl Marshal for Uri { + fn marshal(&self) -> String { + /*first pice path and query together*/ + let path_with_query = if let Some(query) = &self.query { + format!("{}?{}", self.path, query) + } else { + self.path.clone() + }; + + match self.schema { + Schema::RTSP => { + let host_with_port = if let Some(port) = &self.port { + format!("{}:{}", self.host, port) + } else { + self.host.clone() + }; + format!("{}://{}/{}", self.schema, host_with_port, path_with_query) + } + Schema::WEBRTC => path_with_query, + Schema::UNKNOWN => path_with_query, + } + } +} + #[derive(Debug, Clone, Default)] pub struct HttpRequest { pub method: String, - pub address: String, - pub port: u16, - pub path: String, - pub path_parameters: Option, - //parse path_parameters and save the results - pub path_parameters_map: IndexMap, + pub uri: Uri, + /*parse the query and save the results*/ + pub query_pairs: IndexMap, pub version: String, pub headers: IndexMap, pub body: Option, @@ -51,55 +172,57 @@ pub fn parse_content_length(request_data: &str) -> Option { impl Unmarshal for HttpRequest { fn unmarshal(request_data: &str) -> Option { - log::trace!("len: {} content: {}", request_data.len(), request_data); - let mut http_request = HttpRequest::default(); let header_end_idx = if let Some(idx) = request_data.find("\r\n\r\n") { let data_except_body = &request_data[..idx]; let mut lines = data_except_body.lines(); - //parse the first line - //POST /whip?app=live&stream=test HTTP/1.1 + /*parse the first line + POST /whip?app=live&stream=test HTTP/1.1*/ if let Some(request_first_line) = lines.next() { let mut fields = request_first_line.split_ascii_whitespace(); + /* method */ if let Some(method) = fields.next() { http_request.method = method.to_string(); } - if let Some(path) = fields.next() { - let path_data: Vec<&str> = path.splitn(2, '?').collect(); - http_request.path = path_data[0].to_string(); - - if path_data.len() > 1 { - let pars = path_data[1].to_string(); - let pars_array: Vec<&str> = pars.split('&').collect(); - - for ele in pars_array { - let (k, v) = scanf!(ele, '=', String, String); - if k.is_none() || v.is_none() { - continue; + /* url */ + if let Some(url) = fields.next() { + if let Some(uri) = Uri::unmarshal(url) { + http_request.uri = uri; + + if let Some(query) = &http_request.uri.query { + let pars_array: Vec<&str> = query.split('&').collect(); + + for ele in pars_array { + let (k, v) = scanf!(ele, '=', String, String); + if k.is_none() || v.is_none() { + continue; + } + http_request.query_pairs.insert(k.unwrap(), v.unwrap()); } - http_request - .path_parameters_map - .insert(k.unwrap(), v.unwrap()); } - http_request.path_parameters = Some(pars); + } else { + log::error!("cannot get a Uri."); + return None; } } + /* version */ if let Some(version) = fields.next() { http_request.version = version.to_string(); } } - //parse headers + /*parse headers*/ for line in lines { if let Some(index) = line.find(": ") { let name = line[..index].to_string(); let value = line[index + 2..].to_string(); + /*for schema: webrtc*/ if name == "Host" { let (address_val, port_val) = scanf!(value, ':', String, u16); if let Some(address) = address_val { - http_request.address = address; + http_request.uri.host = address; } if let Some(port) = port_val { - http_request.port = port; + http_request.uri.port = Some(port); } } http_request.headers.insert(name, value); @@ -116,7 +239,7 @@ impl Unmarshal for HttpRequest { ); if request_data.len() > header_end_idx { - //parse body + /*parse body*/ http_request.body = Some(request_data[header_end_idx..].to_string()); } @@ -126,12 +249,12 @@ impl Unmarshal for HttpRequest { impl Marshal for HttpRequest { fn marshal(&self) -> String { - let full_path = if let Some(parameters) = &self.path_parameters { - format!("{}?{}", self.path, parameters) - } else { - self.path.clone() - }; - let mut request_str = format!("{} {} {}\r\n", self.method, full_path, self.version); + let mut request_str = format!( + "{} {} {}\r\n", + self.method, + self.uri.marshal(), + self.version + ); for (header_name, header_value) in &self.headers { if header_name == &"Content-Length".to_string() { if let Some(body) = &self.body { @@ -532,4 +655,99 @@ mod tests { assert_eq!(response, marshal_result); } } + + // #[test] + // fn test_parse_rtsp_request_chatgpt() { + // let data1 = "ANNOUNCE rtsp://127.0.0.1:5544/stream RTSP/1.0\r\n\ + // Content-Type: application/sdp\r\n\ + // CSeq: 2\r\n\ + // User-Agent: Lavf58.76.100\r\n\ + // Content-Length: 500\r\n\ + // \r\n\ + // v=0\r\n\ + // o=- 0 0 IN IP4 127.0.0.1\r\n\ + // s=No Name\r\n\ + // c=IN IP4 127.0.0.1\r\n\ + // t=0 0\r\n\ + // a=tool:libavformat 58.76.100\r\n\ + // m=video 0 RTP/AVP 96\r\n\ + // b=AS:284\r\n\ + // a=rtpmap:96 H264/90000 + // a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAHqzZQKAv+XARAAADAAEAAAMAMg8WLZY=,aOvjyyLA; profile-level-id=64001E\r\n\ + // a=control:streamid=0\r\n\ + // m=audio 0 RTP/AVP 97\r\n\ + // b=AS:128\r\n\ + // a=rtpmap:97 MPEG4-GENERIC/48000/2\r\n\ + // a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=119056E500\r\n\ + // a=control:streamid=1\r\n"; + + // // v=0:SDP版本号,通常为0。 + // // o=- 0 0 IN IP4 127.0.0.1:会话的所有者和会话ID,以及会话开始时间和会话结束时间的信息。 + // // s=No Name:会话名称或标题。 + // // c=IN IP4 127.0.0.1:表示会话数据传输的地址类型(IPv4)和地址(127.0.0.1)。 + // // t=0 0:会话时间,包括会话开始时间和结束时间,这里的值都是0,表示会话没有预定义的结束时间。 + // // a=tool:libavformat 58.76.100:会话所使用的工具或软件名称和版本号。 + + // // m=video 0 RTP/AVP 96:媒体类型(video或audio)、媒体格式(RTP/AVP)、媒体格式编号(96)和媒体流的传输地址。 + // // b=AS:284:视频流所使用的带宽大小。 + // // a=rtpmap:96 H264/90000:视频流所使用的编码方式(H.264)和时钟频率(90000)。 + // // a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAHqzZQKAv+XARAAADAAEAAAMAMg8WLZY=,aOvjyyLA; profile-level-id=64001E:视频流的格式参数,如分片方式、SPS和PPS等。 + // // a=control:streamid=0:指定视频流的流ID。 + + // // m=audio 0 RTP/AVP 97:媒体类型(audio)、媒体格式(RTP/AVP)、媒体格式编号(97)和媒体流的传输地址。 + // // b=AS:128:音频流所使用的带宽大小。 + // // a=rtpmap:97 MPEG4-GENERIC/48000/2:音频流所使用的编码方式(MPEG4-GENERIC)、采样率(48000Hz)、和通道数(2)。 + // // a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=119056E500:音频流的格式参数,如编码方式、采样长度、索引长度等。 + // // a=control:streamid=1:指定音频流的流ID。 + + // if let Some(request) = parse_request_bytes(&data1.as_bytes()) { + // println!(" parser: {:?}", request); + // } + // } + + #[test] + fn test_parse_rtsp_request() { + let data1 = "SETUP rtsp://127.0.0.1/stream/streamid=0 RTSP/1.0\r\n\ + Transport: RTP/AVP/TCP;unicast;interleaved=0-1;mode=record\r\n\ + CSeq: 3\r\n\ + User-Agent: Lavf58.76.100\r\n\ + \r\n"; + + if let Some(parser) = HttpRequest::unmarshal(data1) { + println!(" parser: {parser:?}"); + let marshal_result = parser.marshal(); + print!("marshal result: =={marshal_result}=="); + assert_eq!(data1, marshal_result); + } + + let data2 = "ANNOUNCE rtsp://127.0.0.1/stream RTSP/1.0\r\n\ + Content-Type: application/sdp\r\n\ + CSeq: 2\r\n\ + User-Agent: Lavf58.76.100\r\n\ + Content-Length: 500\r\n\ + \r\n\ + v=0\r\n\ + o=- 0 0 IN IP4 127.0.0.1\r\n\ + s=No Name\r\n\ + c=IN IP4 127.0.0.1\r\n\ + t=0 0\r\n\ + a=tool:libavformat 58.76.100\r\n\ + m=video 0 RTP/AVP 96\r\n\ + b=AS:284\r\n\ + a=rtpmap:96 H264/90000\r\n\ + a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAHqzZQKAv+XARAAADAAEAAAMAMg8WLZY=,aOvjyyLA; profile-level-id=64001E\r\n\ + a=control:streamid=0\r\n\ + m=audio 0 RTP/AVP 97\r\n\ + b=AS:128\r\n\ + a=rtpmap:97 MPEG4-GENERIC/48000/2\r\n\ + a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=119056E500\r\n\ + a=control:streamid=1\r\n"; + + if let Some(parser) = HttpRequest::unmarshal(data2) { + println!(" parser: {parser:?}"); + let marshal_result = parser.marshal(); + print!("marshal result: =={marshal_result}=="); + assert_eq!(data2, marshal_result); + } + } } diff --git a/library/common/src/lib.rs b/library/common/src/lib.rs new file mode 100644 index 00000000..fa5d561b --- /dev/null +++ b/library/common/src/lib.rs @@ -0,0 +1,5 @@ +pub mod auth; +pub mod define; +pub mod errors; +pub mod http; +pub mod utils; diff --git a/library/common/src/utils.rs b/library/common/src/utils.rs new file mode 100644 index 00000000..2057a4ca --- /dev/null +++ b/library/common/src/utils.rs @@ -0,0 +1,7 @@ +#[macro_export] +macro_rules! scanf { + ( $string:expr, $sep:expr, $( $x:ty ),+ ) => {{ + let mut iter = $string.split($sep); + ($(iter.next().and_then(|word| word.parse::<$x>().ok()),)*) + }} +} diff --git a/protocol/hls/Cargo.toml b/protocol/hls/Cargo.toml index 0649d4b4..01449bbe 100644 --- a/protocol/hls/Cargo.toml +++ b/protocol/hls/Cargo.toml @@ -19,6 +19,7 @@ axum = { version = "0.7.4" } tokio-util = { version = "0.6.5", features = ["codec"] } streamhub = { path = "../../library/streamhub/" } +commonlib = { path = "../../library/common/" } xmpegts = { path = "../../library/container/mpegts/" } xflv = { path = "../../library/container/flv/" } rtmp = { path = "../rtmp/" } diff --git a/protocol/hls/README.md b/protocol/hls/README.md index 9aa4c949..3a1ee1a0 100644 --- a/protocol/hls/README.md +++ b/protocol/hls/README.md @@ -33,6 +33,10 @@ Change the listening port type. - Remove no used "\n" for error message. - Replace hyper HTTP library with axum. - Receive and process sub event result. +## 0.5.0 +- Support auth. +## 0.5.1 +- Fix RTMP build error. diff --git a/protocol/hls/src/server.rs b/protocol/hls/src/server.rs index 1aed35c2..41eaa20e 100644 --- a/protocol/hls/src/server.rs +++ b/protocol/hls/src/server.rs @@ -1,8 +1,12 @@ use { axum::{ - body::Body, extract::Request, handler::HandlerWithoutStateExt, http::StatusCode, + body::Body, + extract::{Request, State}, + handler::Handler, + http::StatusCode, response::Response, }, + commonlib::auth::Auth, std::net::SocketAddr, tokio::{fs::File, net::TcpListener}, tokio_util::codec::{BytesCodec, FramedRead}, @@ -11,10 +15,12 @@ use { type GenericError = Box; type Result = std::result::Result; static NOTFOUND: &[u8] = b"Not Found"; +static UNAUTHORIZED: &[u8] = b"Unauthorized"; -async fn handle_connection(req: Request) -> Response { +async fn handle_connection(State(auth): State>, req: Request) -> Response { let path = req.uri().path(); + let query_string: Option = req.uri().query().map(|s| s.to_string()); let mut file_path: String = String::from(""); if path.ends_with(".m3u8") { @@ -28,6 +34,18 @@ async fn handle_connection(req: Request) -> Response { let app_name = String::from(rv[1]); let stream_name = String::from(rv[2]); + if let Some(auth_val) = auth { + if auth_val + .authenticate(&stream_name, &query_string, true) + .is_err() + { + return Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(UNAUTHORIZED.into()) + .unwrap(); + } + } + file_path = format!("./{app_name}/{stream_name}/{stream_name}.m3u8"); } } else if path.ends_with(".ts") { @@ -69,7 +87,7 @@ async fn simple_file_send(filename: &str) -> Response { not_found() } -pub async fn run(port: usize) -> Result<()> { +pub async fn run(port: usize, auth: Option) -> Result<()> { let listen_address = format!("0.0.0.0:{port}"); let sock_addr: SocketAddr = listen_address.parse().unwrap(); @@ -77,7 +95,9 @@ pub async fn run(port: usize) -> Result<()> { log::info!("Hls server listening on http://{}", sock_addr); - axum::serve(listener, handle_connection.into_service()).await?; + let handle_connection = handle_connection.with_state(auth); + + axum::serve(listener, handle_connection.into_make_service()).await?; Ok(()) } diff --git a/protocol/httpflv/Cargo.toml b/protocol/httpflv/Cargo.toml index ca4e1792..1920e12f 100644 --- a/protocol/httpflv/Cargo.toml +++ b/protocol/httpflv/Cargo.toml @@ -19,6 +19,7 @@ axum = { version = "0.7.4" } futures = "0.3" streamhub = { path = "../../library/streamhub/" } +commonlib = { path = "../../library/common/" } xflv = { path = "../../library/container/flv/" } rtmp = { path = "../rtmp/" } #"0.0.4" diff --git a/protocol/httpflv/README.md b/protocol/httpflv/README.md index ced7cfd5..b88875e7 100644 --- a/protocol/httpflv/README.md +++ b/protocol/httpflv/README.md @@ -23,5 +23,8 @@ A httpflv library. - Replace hyper HTTP library with axum. - Fix bug: when http connection is disconnected , don't need to retry. - Receive and process sub event result. - +## 0.4.0 +- Support auth. +## 0.4.1 +- Fix RTMP build error. diff --git a/protocol/httpflv/src/server.rs b/protocol/httpflv/src/server.rs index 58a4aa95..943fec75 100644 --- a/protocol/httpflv/src/server.rs +++ b/protocol/httpflv/src/server.rs @@ -7,6 +7,7 @@ use { http::StatusCode, response::Response, }, + commonlib::auth::Auth, futures::channel::mpsc::unbounded, std::net::SocketAddr, streamhub::define::StreamHubEventSender, @@ -16,13 +17,15 @@ use { type GenericError = Box; type Result = std::result::Result; static NOTFOUND: &[u8] = b"Not Found"; +static UNAUTHORIZED: &[u8] = b"Unauthorized"; async fn handle_connection( - State(event_producer): State, // event_producer: ChannelEventProducer + State((event_producer, auth)): State<(StreamHubEventSender, Option)>, // event_producer: ChannelEventProducer ConnectInfo(remote_addr): ConnectInfo, req: Request, ) -> Response { let path = req.uri().path(); + let query_string: Option = req.uri().query().map(|s| s.to_string()); match path.find(".flv") { Some(index) if index > 0 => { @@ -32,6 +35,18 @@ async fn handle_connection( let app_name = String::from(rv[1]); let stream_name = String::from(rv[2]); + if let Some(auth_val) = auth { + if auth_val + .authenticate(&stream_name, &query_string, true) + .is_err() + { + return Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(UNAUTHORIZED.into()) + .unwrap(); + } + } + let (http_response_data_producer, http_response_data_consumer) = unbounded(); let mut flv_hanlder = HttpFlv::new( @@ -63,7 +78,11 @@ async fn handle_connection( } } -pub async fn run(event_producer: StreamHubEventSender, port: usize) -> Result<()> { +pub async fn run( + event_producer: StreamHubEventSender, + port: usize, + auth: Option, +) -> Result<()> { let listen_address = format!("0.0.0.0:{port}"); let sock_addr: SocketAddr = listen_address.parse().unwrap(); @@ -71,7 +90,7 @@ pub async fn run(event_producer: StreamHubEventSender, port: usize) -> Result<() log::info!("Httpflv server listening on http://{}", sock_addr); - let handle_connection = handle_connection.with_state(event_producer.clone()); + let handle_connection = handle_connection.with_state((event_producer.clone(), auth)); axum::serve( listener, diff --git a/protocol/rtmp/Cargo.toml b/protocol/rtmp/Cargo.toml index 0fcbab86..31a4428c 100644 --- a/protocol/rtmp/Cargo.toml +++ b/protocol/rtmp/Cargo.toml @@ -33,6 +33,7 @@ serde = { version = "1.0", features = ["derive", "rc"] } bytesio = { path = "../../library/bytesio/" } streamhub = { path = "../../library/streamhub/" } +commonlib = { path = "../../library/common/" } h264-decoder = { path = "../../library/codec/h264/" } xflv = { path = "../../library/container/flv/" } diff --git a/protocol/rtmp/README.md b/protocol/rtmp/README.md index 91314e24..80b03627 100644 --- a/protocol/rtmp/README.md +++ b/protocol/rtmp/README.md @@ -171,6 +171,10 @@ async fn main() -> Result<()> { - Fix RTMP chunks are uncompressed in packetizer mod. - Fix err: when encountering an unknown RTMP message type, it should be skipped rather than returning an error. - Support remuxing from WHIP to rtmp. +## 0.6.0 +- Support auth. +## 0.6.1 +- Fix RTMP build error. diff --git a/protocol/rtmp/src/chunk/unpacketizer.rs b/protocol/rtmp/src/chunk/unpacketizer.rs index 76eac89a..65695fd5 100644 --- a/protocol/rtmp/src/chunk/unpacketizer.rs +++ b/protocol/rtmp/src/chunk/unpacketizer.rs @@ -8,7 +8,6 @@ use { byteorder::{BigEndian, LittleEndian}, bytes::{BufMut, BytesMut}, bytesio::bytes_reader::BytesReader, - chrono::prelude::*, std::{cmp::min, collections::HashMap, fmt, vec::Vec}, }; @@ -127,12 +126,6 @@ impl ChunkUnpacketizer { } pub fn read_chunks(&mut self) -> Result { - log::trace!( - "read chunks begin, current time: {}, and read state: {}", - Local::now().timestamp_nanos(), - self.chunk_read_state - ); - // log::trace!( // "read chunks, reader remaining data: {}", // self.reader.get_remaining_bytes() @@ -165,12 +158,6 @@ impl ChunkUnpacketizer { } } - log::trace!( - "read chunks end, current time: {}, and read state: {}", - Local::now().timestamp_nanos(), - self.chunk_read_state - ); - if !chunks.is_empty() { Ok(UnpackResult::Chunks(chunks)) } else { @@ -191,13 +178,6 @@ impl ChunkUnpacketizer { pub fn read_chunk(&mut self) -> Result { let mut result: UnpackResult = UnpackResult::Empty; - log::trace!( - "read chunk begin, current time: {}, and read state: {}, and chunk index: {}", - Local::now().timestamp_nanos(), - self.chunk_read_state, - self.chunk_index, - ); - self.chunk_index += 1; loop { @@ -213,12 +193,6 @@ impl ChunkUnpacketizer { }; } - log::trace!( - "read chunk end, current time: {}, and read state: {}, and chunk index: {}", - Local::now().timestamp_nanos(), - self.chunk_read_state, - self.chunk_index, - ); Ok(result) // Ok(UnpackResult::Success) diff --git a/protocol/rtmp/src/rtmp.rs b/protocol/rtmp/src/rtmp.rs index 3246bb03..487b8e78 100644 --- a/protocol/rtmp/src/rtmp.rs +++ b/protocol/rtmp/src/rtmp.rs @@ -1,6 +1,7 @@ use streamhub::define::StreamHubEventSender; use super::session::server_session; +use commonlib::auth::Auth; use std::net::SocketAddr; use tokio::io::Error; use tokio::net::TcpListener; @@ -9,14 +10,21 @@ pub struct RtmpServer { address: String, event_producer: StreamHubEventSender, gop_num: usize, + auth: Option, } impl RtmpServer { - pub fn new(address: String, event_producer: StreamHubEventSender, gop_num: usize) -> Self { + pub fn new( + address: String, + event_producer: StreamHubEventSender, + gop_num: usize, + auth: Option, + ) -> Self { Self { address, event_producer, gop_num, + auth, } } @@ -33,6 +41,7 @@ impl RtmpServer { tcp_stream, self.event_producer.clone(), self.gop_num, + self.auth.clone(), ); tokio::spawn(async move { if let Err(err) = session.run().await { diff --git a/protocol/rtmp/src/session/client_session.rs b/protocol/rtmp/src/session/client_session.rs index 3ef1bcca..958ff678 100644 --- a/protocol/rtmp/src/session/client_session.rs +++ b/protocol/rtmp/src/session/client_session.rs @@ -119,10 +119,7 @@ impl ClientSession { }; let common = Common::new(packetizer, event_producer, SessionType::Client, remote_addr); - - let (stream_name, _) = RtmpUrlParser::default() - .set_raw_stream_name(raw_stream_name.clone()) - .parse_raw_stream_name(); + let (stream_name, _) = RtmpUrlParser::parse_stream_name_with_query(&raw_stream_name); Self { io: Arc::clone(&net_io), diff --git a/protocol/rtmp/src/session/errors.rs b/protocol/rtmp/src/session/errors.rs index c8129fa1..d4f2bfec 100644 --- a/protocol/rtmp/src/session/errors.rs +++ b/protocol/rtmp/src/session/errors.rs @@ -11,6 +11,7 @@ use { user_control_messages::errors::EventMessagesError, }, bytesio::{bytes_errors::BytesWriteError, bytesio_errors::BytesIOError}, + commonlib::errors::AuthError, failure::{Backtrace, Fail}, std::fmt, streamhub::errors::ChannelError, @@ -79,6 +80,8 @@ pub enum SessionErrorValue { #[fail(display = "session is finished.")] Finish, + #[fail(display = "Auth err: {}", _0)] + AuthError(#[cause] AuthError), } impl From for SessionError { @@ -200,6 +203,13 @@ impl From for SessionError { } } } +impl From for SessionError { + fn from(error: AuthError) -> Self { + SessionError { + value: SessionErrorValue::AuthError(error), + } + } +} impl fmt::Display for SessionError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/protocol/rtmp/src/session/server_session.rs b/protocol/rtmp/src/session/server_session.rs index ce3fcfbd..2ddb6ff4 100644 --- a/protocol/rtmp/src/session/server_session.rs +++ b/protocol/rtmp/src/session/server_session.rs @@ -27,6 +27,7 @@ use { bytes_writer::AsyncBytesWriter, bytesio::{TNetIO, TcpIO}, }, + commonlib::auth::Auth, indexmap::IndexMap, std::{sync::Arc, time::Duration}, streamhub::{ @@ -49,7 +50,7 @@ enum ServerSessionState { pub struct ServerSession { pub app_name: String, pub stream_name: String, - pub url_parameters: String, + pub query: Option, io: Arc>>, handshaker: HandshakeServer, unpacketizer: ChunkUnpacketizer, @@ -64,10 +65,16 @@ pub struct ServerSession { pub common: Common, /*configure how many gops will be cached.*/ gop_num: usize, + auth: Option, } impl ServerSession { - pub fn new(stream: TcpStream, event_producer: StreamHubEventSender, gop_num: usize) -> Self { + pub fn new( + stream: TcpStream, + event_producer: StreamHubEventSender, + gop_num: usize, + auth: Option, + ) -> Self { let remote_addr = if let Ok(addr) = stream.peer_addr() { log::info!("server session: {}", addr.to_string()); Some(addr) @@ -81,7 +88,7 @@ impl ServerSession { Self { app_name: String::from(""), stream_name: String::from(""), - url_parameters: String::from(""), + query: None, io: Arc::clone(&net_io), handshaker: HandshakeServer::new(Arc::clone(&net_io)), unpacketizer: ChunkUnpacketizer::new(), @@ -97,6 +104,7 @@ impl ServerSession { has_remaing_data: false, connect_properties: ConnectProperties::default(), gop_num, + auth, } } @@ -634,15 +642,23 @@ impl ServerSession { let raw_stream_name = stream_name.unwrap(); - (self.stream_name, self.url_parameters) = RtmpUrlParser::default() - .set_raw_stream_name(raw_stream_name.clone()) - .parse_raw_stream_name(); + (self.stream_name, self.query) = + RtmpUrlParser::parse_stream_name_with_query(&raw_stream_name); + if let Some(auth) = &self.auth { + auth.authenticate(&self.stream_name, &self.query, true)? + } + + let query = if let Some(query_val) = &self.query { + query_val.clone() + } else { + String::from("none") + }; log::info!( - "[ S->C ] [stream is record] app_name: {}, stream_name: {}, url parameters: {}", + "[ S->C ] [stream is record] app_name: {}, stream_name: {}, query: {}", self.app_name, self.stream_name, - self.url_parameters + query ); /*Now it can update the request url*/ @@ -674,7 +690,7 @@ impl ServerSession { }); } - let raw_stream_name = match other_values.remove(0) { + let stream_name_with_query = match other_values.remove(0) { Amf0ValueType::UTF8String(val) => val, _ => { return Err(SessionError { @@ -683,12 +699,15 @@ impl ServerSession { } }; - (self.stream_name, self.url_parameters) = RtmpUrlParser::default() - .set_raw_stream_name(raw_stream_name.clone()) - .parse_raw_stream_name(); + (self.stream_name, self.query) = + RtmpUrlParser::parse_stream_name_with_query(&stream_name_with_query); + + if let Some(auth) = &self.auth { + auth.authenticate(&self.stream_name, &self.query, false)? + } /*Now it can update the request url*/ - self.common.request_url = self.get_request_url(raw_stream_name); + self.common.request_url = self.get_request_url(stream_name_with_query); let _ = match other_values.remove(0) { Amf0ValueType::UTF8String(val) => val, @@ -699,18 +718,24 @@ impl ServerSession { } }; + let query = if let Some(query_val) = &self.query { + query_val.clone() + } else { + String::from("none") + }; + log::info!( - "[ S<-C ] [publish] app_name: {}, stream_name: {}, url parameters: {}", + "[ S<-C ] [publish] app_name: {}, stream_name: {}, query: {}", self.app_name, self.stream_name, - self.url_parameters + query ); log::info!( - "[ S->C ] [stream begin] app_name: {}, stream_name: {}, url parameters: {}", + "[ S->C ] [stream begin] app_name: {}, stream_name: {}, query: {}", self.app_name, self.stream_name, - self.url_parameters + query ); let mut event_messages = EventMessagesWriter::new(AsyncBytesWriter::new(self.io.clone())); diff --git a/protocol/rtmp/src/utils/mod.rs b/protocol/rtmp/src/utils/mod.rs index fc160074..006f9a40 100644 --- a/protocol/rtmp/src/utils/mod.rs +++ b/protocol/rtmp/src/utils/mod.rs @@ -1,48 +1,46 @@ pub mod errors; pub mod print; +use commonlib::scanf; use errors::RtmpUrlParseError; use errors::RtmpUrlParseErrorValue; +use indexmap::IndexMap; #[derive(Debug, Clone, Default)] pub struct RtmpUrlParser { - pub raw_url: String, - // raw_domain_name = format!("{}:{}",domain_name,port) - pub raw_domain_name: String, - pub domain_name: String, - pub port: String, + pub url: String, + // host_with_port = format!("{}:{}",host,port) + pub host_with_port: String, + pub host: String, + pub port: Option, pub app_name: String, - // raw_stream_name = format!("{}?{}",stream_name,url_parameters) - pub raw_stream_name: String, + // / = format!("{}?{}",stream_name,query) + pub stream_name_with_query: String, pub stream_name: String, - pub parameters: String, + pub query: Option, } impl RtmpUrlParser { pub fn new(url: String) -> Self { Self { - raw_url: url, + url, ..Default::default() } } - pub fn set_raw_stream_name(&mut self, raw_stream_name: String) -> &mut Self { - self.raw_stream_name = raw_stream_name; - self - } - /* example;rtmp://domain.name.cn:1935/app_name/stream_name?auth_key=test_Key - raw_domain_name: domain.name.cn:1935 - domain_name: domain.name.cn + host_with_port: domain.name.cn:1935 + host: domain.name.cn + port: 1935 app_name: app_name - raw_stream_name: stream_name?auth_key=test_Key + stream_name_with_query: stream_name?auth_key=test_Key stream_name: stream_name - url_parameters: auth_key=test_Key + query: auth_key=test_Key */ pub fn parse_url(&mut self) -> Result<(), RtmpUrlParseError> { - if let Some(idx) = self.raw_url.find("rtmp://") { - let remove_header_left = &self.raw_url[idx + 7..]; + if let Some(idx) = self.url.find("rtmp://") { + let remove_header_left = &self.url[idx + 7..]; let url_parts: Vec<&str> = remove_header_left.split('/').collect(); if url_parts.len() != 3 { return Err(RtmpUrlParseError { @@ -50,12 +48,13 @@ impl RtmpUrlParser { }); } - self.raw_domain_name = url_parts[0].to_string(); + self.host_with_port = url_parts[0].to_string(); self.app_name = url_parts[1].to_string(); - self.raw_stream_name = url_parts[2].to_string(); + self.stream_name_with_query = url_parts[2].to_string(); - self.parse_raw_domain_name()?; - self.parse_raw_stream_name(); + self.parse_host_with_port()?; + (self.stream_name, self.query) = + Self::parse_stream_name_with_query(&self.stream_name_with_query); } else { return Err(RtmpUrlParseError { value: RtmpUrlParseErrorValue::Notvalid, @@ -65,28 +64,41 @@ impl RtmpUrlParser { Ok(()) } - pub fn parse_raw_domain_name(&mut self) -> Result<(), RtmpUrlParseError> { - let data: Vec<&str> = self.raw_domain_name.split(':').collect(); - self.domain_name = data[0].to_string(); + pub fn parse_host_with_port(&mut self) -> Result<(), RtmpUrlParseError> { + let data: Vec<&str> = self.host_with_port.split(':').collect(); + self.host = data[0].to_string(); if data.len() > 1 { - self.port = data[1].to_string(); + self.port = Some(data[1].to_string()); } Ok(()) } - /*parse the raw stream name to get real stream name and the URL parameters*/ - pub fn parse_raw_stream_name(&mut self) -> (String, String) { - let data: Vec<&str> = self.raw_stream_name.split('?').collect(); - self.stream_name = data[0].to_string(); - if data.len() > 1 { - self.parameters = data[1].to_string(); - } - (self.stream_name.clone(), self.parameters.clone()) + /*parse the stream name and query to get real stream name and query*/ + pub fn parse_stream_name_with_query(stream_name_with_query: &str) -> (String, Option) { + let data: Vec<&str> = stream_name_with_query.split('?').collect(); + let stream_name = data[0].to_string(); + let query = if data.len() > 1 { + let query_val = data[1].to_string(); + + let mut query_pairs = IndexMap::new(); + let pars_array: Vec<&str> = query_val.split('&').collect(); + for ele in pars_array { + let (k, v) = scanf!(ele, '=', String, String); + if k.is_none() || v.is_none() { + continue; + } + query_pairs.insert(k.unwrap(), v.unwrap()); + } + Some(data[1].to_string()) + } else { + None + }; + (stream_name, query) } pub fn append_port(&mut self, port: String) { - if !self.raw_domain_name.contains(':') { - self.raw_domain_name = format!("{}:{}", self.raw_domain_name, port); - self.port = port; + if !self.host_with_port.contains(':') { + self.host_with_port = format!("{}:{}", self.host_with_port, port); + self.port = Some(port); } } } @@ -103,13 +115,17 @@ mod tests { parser.parse_url().unwrap(); - println!(" raw_domain_name: {}", parser.raw_domain_name); - println!(" port: {}", parser.port); - println!(" domain_name: {}", parser.domain_name); + println!(" raw_domain_name: {}", parser.host_with_port); + if parser.port.is_some() { + println!(" port: {}", parser.port.unwrap()); + } + println!(" domain_name: {}", parser.host); println!(" app_name: {}", parser.app_name); - println!(" raw_stream_name: {}", parser.raw_stream_name); + println!(" stream_name_with_query: {}", parser.stream_name_with_query); println!(" stream_name: {}", parser.stream_name); - println!(" url_parameters: {}", parser.parameters); + if parser.query.is_some() { + println!(" query: {}", parser.query.unwrap()); + } } #[test] fn test_rtmp_url_parser2() { @@ -118,12 +134,16 @@ mod tests { parser.parse_url().unwrap(); - println!(" raw_domain_name: {}", parser.raw_domain_name); - println!(" port: {}", parser.port); - println!(" domain_name: {}", parser.domain_name); + println!(" raw_domain_name: {}", parser.host_with_port); + if parser.port.is_some() { + println!(" port: {}", parser.port.unwrap()); + } + println!(" domain_name: {}", parser.host); println!(" app_name: {}", parser.app_name); - println!(" raw_stream_name: {}", parser.raw_stream_name); + println!(" stream_name_with_query: {}", parser.stream_name_with_query); println!(" stream_name: {}", parser.stream_name); - println!(" url_parameters: {}", parser.parameters); + if parser.query.is_some() { + println!(" query: {}", parser.query.unwrap()); + } } } diff --git a/protocol/rtsp/Cargo.toml b/protocol/rtsp/Cargo.toml index 33cc60df..3ada035c 100644 --- a/protocol/rtsp/Cargo.toml +++ b/protocol/rtsp/Cargo.toml @@ -22,3 +22,4 @@ hex = "0.4.3" bytesio = { path = "../../library/bytesio/" } streamhub = { path = "../../library/streamhub/" } +commonlib = { path = "../../library/common/" } diff --git a/protocol/rtsp/README.md b/protocol/rtsp/README.md index b48d7251..11f16e24 100644 --- a/protocol/rtsp/README.md +++ b/protocol/rtsp/README.md @@ -7,5 +7,6 @@ A rtsp library. ## v0.1.3 - Remove no used "\n" for error message. - Receive and process sub event result. - +## 0.2.0 +- Support auth. diff --git a/protocol/rtsp/src/http/mod.rs b/protocol/rtsp/src/http/mod.rs deleted file mode 100644 index aa956278..00000000 --- a/protocol/rtsp/src/http/mod.rs +++ /dev/null @@ -1,314 +0,0 @@ -use crate::global_trait::Marshal; -use crate::global_trait::Unmarshal; -use crate::rtsp_utils; -use indexmap::IndexMap; - -#[derive(Debug, Clone, Default)] -pub struct RtspRequest { - pub method: String, - pub url: String, - //url = "rtsp://{}:{}/{}", address, port, path - pub address: String, - pub port: u16, - pub path: String, - pub version: String, - pub headers: IndexMap, - pub body: Option, -} - -impl RtspRequest { - pub fn get_header(&self, header_name: &String) -> Option<&String> { - self.headers.get(header_name) - } -} - -impl Unmarshal for RtspRequest { - fn unmarshal(request_data: &str) -> Option { - let mut rtsp_request = RtspRequest::default(); - let header_end_idx = if let Some(idx) = request_data.find("\r\n\r\n") { - let data_except_body = &request_data[..idx]; - let mut lines = data_except_body.lines(); - //parse the first line - if let Some(request_first_line) = lines.next() { - let mut fields = request_first_line.split_ascii_whitespace(); - if let Some(method) = fields.next() { - rtsp_request.method = method.to_string(); - } - if let Some(url) = fields.next() { - rtsp_request.url = url.to_string(); - - if let Some(val) = url.strip_prefix("rtsp://") { - if let Some(index) = val.find('/') { - let path = &url[7 + index + 1..]; - rtsp_request.path = String::from(path); - let address_with_port = &url[7..7 + index]; - - let (address_val, port_val) = - rtsp_utils::scanf!(address_with_port, ':', String, u16); - - if let Some(address) = address_val { - rtsp_request.address = address; - } - if let Some(port) = port_val { - rtsp_request.port = port; - } - - print!("address_with_port: {address_with_port}"); - } - } - } - if let Some(version) = fields.next() { - rtsp_request.version = version.to_string(); - } - } - //parse headers - for line in lines { - if let Some(index) = line.find(": ") { - let name = line[..index].to_string(); - let value = line[index + 2..].to_string(); - rtsp_request.headers.insert(name, value); - } - } - idx + 4 - } else { - return None; - }; - - if request_data.len() > header_end_idx { - //parse body - rtsp_request.body = Some(request_data[header_end_idx..].to_string()); - } - - Some(rtsp_request) - } -} - -impl Marshal for RtspRequest { - fn marshal(&self) -> String { - let mut request_str = format!("{} {} {}\r\n", self.method, self.url, self.version); - for (header_name, header_value) in &self.headers { - if header_name != &"Content-Length".to_string() { - request_str += &format!("{header_name}: {header_value}\r\n"); - } - } - if let Some(body) = &self.body { - request_str += &format!("Content-Length: {}\r\n", body.len()); - } - request_str += "\r\n"; - if let Some(body) = &self.body { - request_str += body; - } - request_str - } -} - -#[derive(Debug, Clone, Default)] -pub struct RtspResponse { - pub version: String, - pub status_code: u16, - pub reason_phrase: String, - pub headers: IndexMap, - pub body: Option, -} - -impl Unmarshal for RtspResponse { - fn unmarshal(request_data: &str) -> Option { - let mut rtsp_response = RtspResponse::default(); - let header_end_idx = if let Some(idx) = request_data.find("\r\n\r\n") { - let data_except_body = &request_data[..idx]; - let mut lines = data_except_body.lines(); - //parse the first line - if let Some(request_first_line) = lines.next() { - let mut fields = request_first_line.split_ascii_whitespace(); - - if let Some(version) = fields.next() { - rtsp_response.version = version.to_string(); - } - if let Some(status) = fields.next() { - if let Ok(status) = status.parse::() { - rtsp_response.status_code = status; - } - } - if let Some(reason_phrase) = fields.next() { - rtsp_response.reason_phrase = reason_phrase.to_string(); - } - } - //parse headers - for line in lines { - if let Some(index) = line.find(": ") { - let name = line[..index].to_string(); - let value = line[index + 2..].to_string(); - rtsp_response.headers.insert(name, value); - } - } - idx + 4 - } else { - return None; - }; - - if request_data.len() > header_end_idx { - //parse body - rtsp_response.body = Some(request_data[header_end_idx..].to_string()); - } - - Some(rtsp_response) - } -} - -impl Marshal for RtspResponse { - fn marshal(&self) -> String { - let mut response_str = format!( - "{} {} {}\r\n", - self.version, self.status_code, self.reason_phrase - ); - for (header_name, header_value) in &self.headers { - if header_name != &"Content-Length".to_string() { - response_str += &format!("{header_name}: {header_value}\r\n"); - } - } - if let Some(body) = &self.body { - response_str += &format!("Content-Length: {}\r\n", body.len()); - } - response_str += "\r\n"; - if let Some(body) = &self.body { - response_str += body; - } - response_str - } -} - -#[cfg(test)] -mod tests { - - use crate::{global_trait::Marshal, http::Unmarshal}; - - use super::RtspRequest; - - use indexmap::IndexMap; - use std::io::BufRead; - #[allow(dead_code)] - fn read_headers(reader: &mut dyn BufRead) -> Option> { - let mut headers = IndexMap::new(); - loop { - let mut line = String::new(); - match reader.read_line(&mut line) { - Ok(0) => break, - Ok(_) => { - if let Some(index) = line.find(": ") { - let name = line[..index].to_string(); - let value = line[index + 2..].trim().to_string(); - headers.insert(name, value); - } - } - Err(_) => return None, - } - } - Some(headers) - } - - // #[test] - // fn test_parse_rtsp_request_chatgpt() { - // let data1 = "ANNOUNCE rtsp://127.0.0.1:5544/stream RTSP/1.0\r\n\ - // Content-Type: application/sdp\r\n\ - // CSeq: 2\r\n\ - // User-Agent: Lavf58.76.100\r\n\ - // Content-Length: 500\r\n\ - // \r\n\ - // v=0\r\n\ - // o=- 0 0 IN IP4 127.0.0.1\r\n\ - // s=No Name\r\n\ - // c=IN IP4 127.0.0.1\r\n\ - // t=0 0\r\n\ - // a=tool:libavformat 58.76.100\r\n\ - // m=video 0 RTP/AVP 96\r\n\ - // b=AS:284\r\n\ - // a=rtpmap:96 H264/90000 - // a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAHqzZQKAv+XARAAADAAEAAAMAMg8WLZY=,aOvjyyLA; profile-level-id=64001E\r\n\ - // a=control:streamid=0\r\n\ - // m=audio 0 RTP/AVP 97\r\n\ - // b=AS:128\r\n\ - // a=rtpmap:97 MPEG4-GENERIC/48000/2\r\n\ - // a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=119056E500\r\n\ - // a=control:streamid=1\r\n"; - - // // v=0:SDP版本号,通常为0。 - // // o=- 0 0 IN IP4 127.0.0.1:会话的所有者和会话ID,以及会话开始时间和会话结束时间的信息。 - // // s=No Name:会话名称或标题。 - // // c=IN IP4 127.0.0.1:表示会话数据传输的地址类型(IPv4)和地址(127.0.0.1)。 - // // t=0 0:会话时间,包括会话开始时间和结束时间,这里的值都是0,表示会话没有预定义的结束时间。 - // // a=tool:libavformat 58.76.100:会话所使用的工具或软件名称和版本号。 - - // // m=video 0 RTP/AVP 96:媒体类型(video或audio)、媒体格式(RTP/AVP)、媒体格式编号(96)和媒体流的传输地址。 - // // b=AS:284:视频流所使用的带宽大小。 - // // a=rtpmap:96 H264/90000:视频流所使用的编码方式(H.264)和时钟频率(90000)。 - // // a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAHqzZQKAv+XARAAADAAEAAAMAMg8WLZY=,aOvjyyLA; profile-level-id=64001E:视频流的格式参数,如分片方式、SPS和PPS等。 - // // a=control:streamid=0:指定视频流的流ID。 - - // // m=audio 0 RTP/AVP 97:媒体类型(audio)、媒体格式(RTP/AVP)、媒体格式编号(97)和媒体流的传输地址。 - // // b=AS:128:音频流所使用的带宽大小。 - // // a=rtpmap:97 MPEG4-GENERIC/48000/2:音频流所使用的编码方式(MPEG4-GENERIC)、采样率(48000Hz)、和通道数(2)。 - // // a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=119056E500:音频流的格式参数,如编码方式、采样长度、索引长度等。 - // // a=control:streamid=1:指定音频流的流ID。 - - // if let Some(request) = parse_request_bytes(&data1.as_bytes()) { - // println!(" parser: {:?}", request); - // } - // } - - #[test] - fn test_parse_rtsp_request() { - let data1 = "SETUP rtsp://127.0.0.1/stream/streamid=0 RTSP/1.0\r\n\ - Transport: RTP/AVP/TCP;unicast;interleaved=0-1;mode=record\r\n\ - CSeq: 3\r\n\ - User-Agent: Lavf58.76.100\r\n\ - \r\n"; - - if let Some(parser) = RtspRequest::unmarshal(data1) { - println!(" parser: {parser:?}"); - let marshal_result = parser.marshal(); - print!("marshal result: =={marshal_result}=="); - assert_eq!(data1, marshal_result); - } - - let data2 = "ANNOUNCE rtsp://127.0.0.1/stream RTSP/1.0\r\n\ - Content-Type: application/sdp\r\n\ - CSeq: 2\r\n\ - User-Agent: Lavf58.76.100\r\n\ - Content-Length: 500\r\n\ - \r\n\ - v=0\r\n\ - o=- 0 0 IN IP4 127.0.0.1\r\n\ - s=No Name\r\n\ - c=IN IP4 127.0.0.1\r\n\ - t=0 0\r\n\ - a=tool:libavformat 58.76.100\r\n\ - m=video 0 RTP/AVP 96\r\n\ - b=AS:284\r\n\ - a=rtpmap:96 H264/90000\r\n\ - a=fmtp:96 packetization-mode=1; sprop-parameter-sets=Z2QAHqzZQKAv+XARAAADAAEAAAMAMg8WLZY=,aOvjyyLA; profile-level-id=64001E\r\n\ - a=control:streamid=0\r\n\ - m=audio 0 RTP/AVP 97\r\n\ - b=AS:128\r\n\ - a=rtpmap:97 MPEG4-GENERIC/48000/2\r\n\ - a=fmtp:97 profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3; config=119056E500\r\n\ - a=control:streamid=1\r\n"; - - if let Some(parser) = RtspRequest::unmarshal(data2) { - println!(" parser: {parser:?}"); - let marshal_result = parser.marshal(); - print!("marshal result: =={marshal_result}=="); - assert_eq!(data2, marshal_result); - } - } - - #[test] - fn test_http_status_code() { - let stats_code = http::StatusCode::OK; - - println!( - "stats_code: {}, {}", - stats_code.canonical_reason().unwrap(), - stats_code.as_u16() - ) - } -} diff --git a/protocol/rtsp/src/lib.rs b/protocol/rtsp/src/lib.rs index fdedd898..8560801b 100644 --- a/protocol/rtsp/src/lib.rs +++ b/protocol/rtsp/src/lib.rs @@ -1,5 +1,5 @@ pub mod global_trait; -pub mod http; +// pub mod http; pub mod rtp; pub mod rtsp; pub mod rtsp_codec; diff --git a/protocol/rtsp/src/rtsp.rs b/protocol/rtsp/src/rtsp.rs index 39f70009..2c9020e9 100644 --- a/protocol/rtsp/src/rtsp.rs +++ b/protocol/rtsp/src/rtsp.rs @@ -1,6 +1,7 @@ use streamhub::define::StreamHubEventSender; use super::session::RtspServerSession; +use commonlib::auth::Auth; use std::net::SocketAddr; use tokio::io::Error; use tokio::net::TcpListener; @@ -8,13 +9,15 @@ use tokio::net::TcpListener; pub struct RtspServer { address: String, event_producer: StreamHubEventSender, + auth: Option, } impl RtspServer { - pub fn new(address: String, event_producer: StreamHubEventSender) -> Self { + pub fn new(address: String, event_producer: StreamHubEventSender, auth: Option) -> Self { Self { address, event_producer, + auth, } } @@ -25,7 +28,8 @@ impl RtspServer { log::info!("Rtsp server listening on tcp://{}", socket_addr); loop { let (tcp_stream, _) = listener.accept().await?; - let mut session = RtspServerSession::new(tcp_stream, self.event_producer.clone()); + let mut session = + RtspServerSession::new(tcp_stream, self.event_producer.clone(), self.auth.clone()); tokio::spawn(async move { if let Err(err) = session.run().await { log::error!("session run error, err: {}", err); diff --git a/protocol/rtsp/src/rtsp_track.rs b/protocol/rtsp/src/rtsp_track.rs index 1b68705c..49b6a478 100644 --- a/protocol/rtsp/src/rtsp_track.rs +++ b/protocol/rtsp/src/rtsp_track.rs @@ -10,10 +10,6 @@ use bytesio::bytesio::TNetIO; use std::sync::Arc; use tokio::sync::Mutex; -pub trait Track { - fn create_packer(&mut self, writer: Arc>>); - fn create_unpacker(&mut self); -} #[derive(Debug, Clone, Default, Hash, Eq, PartialEq)] pub enum TrackType { #[default] diff --git a/protocol/rtsp/src/session/errors.rs b/protocol/rtsp/src/session/errors.rs index ed163c69..9c6c25f9 100644 --- a/protocol/rtsp/src/session/errors.rs +++ b/protocol/rtsp/src/session/errors.rs @@ -2,6 +2,7 @@ use { crate::rtp::errors::{PackerError, UnPackerError}, bytesio::bytes_errors::BytesReadError, bytesio::{bytes_errors::BytesWriteError, bytesio_errors::BytesIOError}, + commonlib::errors::AuthError, failure::{Backtrace, Fail}, std::fmt, std::str::Utf8Error, @@ -36,6 +37,8 @@ pub enum SessionErrorValue { ChannelError(#[cause] ChannelError), #[fail(display = "tokio: oneshot receiver err: {}", _0)] RecvError(#[cause] RecvError), + #[fail(display = "auth err: {}", _0)] + AuthError(#[cause] AuthError), #[fail(display = "Channel receive error")] ChannelRecvError, } @@ -104,6 +107,14 @@ impl From for SessionError { } } +impl From for SessionError { + fn from(error: AuthError) -> Self { + SessionError { + value: SessionErrorValue::AuthError(error), + } + } +} + impl fmt::Display for SessionError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self.value, f) diff --git a/protocol/rtsp/src/session/mod.rs b/protocol/rtsp/src/session/mod.rs index 7c835ff4..e35c0d1b 100644 --- a/protocol/rtsp/src/session/mod.rs +++ b/protocol/rtsp/src/session/mod.rs @@ -3,9 +3,13 @@ pub mod errors; use super::rtsp_codec; use crate::global_trait::Marshal; use crate::global_trait::Unmarshal; -use crate::http::RtspResponse; + use crate::rtp::define::ANNEXB_NALU_START_CODE; use crate::rtp::utils::Marshal as RtpMarshal; +use commonlib::http::HttpRequest as RtspRequest; +use commonlib::http::HttpResponse as RtspResponse; +use commonlib::http::Marshal as RtspMarshal; +use commonlib::http::Unmarshal as RtspUnmarshal; use crate::rtp::RtpPacket; use crate::rtsp_range::RtspRange; @@ -33,7 +37,6 @@ use streamhub::define::MediaInfo; use streamhub::define::VideoCodecType; use tokio::sync::oneshot; -use super::http::RtspRequest; use super::rtp::errors::UnPackerError; use super::sdp::Sdp; @@ -46,6 +49,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::mpsc; +use commonlib::auth::Auth; use streamhub::{ define::{ FrameData, Information, InformationSender, NotifyInfo, PublishType, PublisherInfo, @@ -70,6 +74,8 @@ pub struct RtspServerSession { stream_handler: Arc, event_producer: StreamHubEventSender, + + auth: Option, } pub struct InterleavedBinaryData { @@ -102,7 +108,11 @@ impl InterleavedBinaryData { } impl RtspServerSession { - pub fn new(stream: TcpStream, event_producer: StreamHubEventSender) -> Self { + pub fn new( + stream: TcpStream, + event_producer: StreamHubEventSender, + auth: Option, + ) -> Self { // let remote_addr = if let Ok(addr) = stream.peer_addr() { // log::info!("server session: {}", addr.to_string()); // Some(addr) @@ -122,6 +132,7 @@ impl RtspServerSession { session_id: None, event_producer, stream_handler: Arc::new(RtspStreamHandler::new()), + auth, } } @@ -193,7 +204,7 @@ impl RtspServerSession { } rtsp_method_name::PLAY => { if self.handle_play(&rtsp_request).await.is_err() { - self.unsubscribe_from_stream_hub(rtsp_request.path)?; + self.unsubscribe_from_stream_hub(rtsp_request.uri.path)?; } } rtsp_method_name::RECORD => { @@ -233,7 +244,7 @@ impl RtspServerSession { let request_event = StreamHubEvent::Request { identifier: StreamIdentifier::Rtsp { - stream_path: rtsp_request.path.clone(), + stream_path: rtsp_request.uri.path.clone(), }, sender, }; @@ -265,6 +276,11 @@ impl RtspServerSession { } async fn handle_announce(&mut self, rtsp_request: &RtspRequest) -> Result<(), SessionError> { + if let Some(auth) = &self.auth { + let stream_name = rtsp_request.uri.path.clone(); + auth.authenticate(&stream_name, &rtsp_request.uri.query, false)?; + } + if let Some(request_body) = &rtsp_request.body { if let Some(sdp) = Sdp::unmarshal(request_body) { self.sdp = sdp.clone(); @@ -279,7 +295,7 @@ impl RtspServerSession { let publish_event = StreamHubEvent::Publish { identifier: StreamIdentifier::Rtsp { - stream_path: rtsp_request.path.clone(), + stream_path: rtsp_request.uri.path.clone(), }, result_sender: event_result_sender, info: self.get_publisher_info(), @@ -328,7 +344,7 @@ impl RtspServerSession { let mut response = Self::gen_response(status_code, rtsp_request); for track in self.tracks.values_mut() { - if !rtsp_request.url.contains(&track.media_control) { + if !rtsp_request.uri.marshal().contains(&track.media_control) { continue; } @@ -356,7 +372,7 @@ impl RtspServerSession { (0, 0) }; - let address = rtsp_request.address.clone(); + let address = rtsp_request.uri.host.clone(); if let Some(rtp_io) = UdpIO::new(address.clone(), rtp_port, 0).await { rtp_server_port = rtp_io.get_local_port(); @@ -411,6 +427,11 @@ impl RtspServerSession { } async fn handle_play(&mut self, rtsp_request: &RtspRequest) -> Result<(), SessionError> { + if let Some(auth) = &self.auth { + let stream_name = rtsp_request.uri.path.clone(); + auth.authenticate(&stream_name, &rtsp_request.uri.query, true)?; + } + for track in self.tracks.values_mut() { let protocol_type = track.transport.protocol_type.clone(); @@ -465,7 +486,7 @@ impl RtspServerSession { let publish_event = StreamHubEvent::Subscribe { identifier: StreamIdentifier::Rtsp { - stream_path: rtsp_request.path.clone(), + stream_path: rtsp_request.uri.path.clone(), }, info: self.get_subscriber_info(), result_sender: event_result_sender, @@ -561,7 +582,7 @@ impl RtspServerSession { } fn handle_teardown(&mut self, rtsp_request: &RtspRequest) -> Result<(), SessionError> { - let stream_path = &rtsp_request.path; + let stream_path = &rtsp_request.uri.path; let unpublish_event = StreamHubEvent::UnPublish { identifier: StreamIdentifier::Rtsp { stream_path: stream_path.clone(), diff --git a/protocol/webrtc/Cargo.toml b/protocol/webrtc/Cargo.toml index 1b9d6eeb..7718dec8 100644 --- a/protocol/webrtc/Cargo.toml +++ b/protocol/webrtc/Cargo.toml @@ -25,3 +25,4 @@ opus = "0.3.0" bytesio = { path = "../../library/bytesio/" } streamhub = { path = "../../library/streamhub/" } xflv = { path = "../../library/container/flv/" } +commonlib = { path = "../../library/common/" } diff --git a/protocol/webrtc/README.md b/protocol/webrtc/README.md index 41e86e1a..a9ee5d64 100644 --- a/protocol/webrtc/README.md +++ b/protocol/webrtc/README.md @@ -6,6 +6,7 @@ A webrtc library. - Remove no used "\n" for error message. - Receive and process sub event result. - Support remux from WHIP to RTMP. - +## 0.3.0 +- Support auth. diff --git a/protocol/webrtc/src/clients/index.html b/protocol/webrtc/src/clients/index.html index e33978b9..1313b235 100644 --- a/protocol/webrtc/src/clients/index.html +++ b/protocol/webrtc/src/clients/index.html @@ -89,6 +89,8 @@

WHEP Example

+ +

@@ -103,6 +105,7 @@

WHEP Example

startWhepBtn.addEventListener("click", () => { const appName = document.getElementById("app-name").value; const streamName = document.getElementById("stream-name").value; + const token = document.getElementById("token").value; //Create peerconnection const pc = window.pc = new RTCPeerConnection(); @@ -123,8 +126,8 @@

WHEP Example

//Create whep client const whep = new WHEPClient(); - const url = location.origin + "/whep?app=" + appName + "&stream=" + streamName; - const token = "" + const url = location.origin + "/whep?app=" + appName + "&stream=" + streamName + "&token=" + token; + //const token = "" //Start viewing whep.view(pc, url, token); diff --git a/protocol/webrtc/src/lib.rs b/protocol/webrtc/src/lib.rs index 82e566ea..6e3801d5 100644 --- a/protocol/webrtc/src/lib.rs +++ b/protocol/webrtc/src/lib.rs @@ -1,5 +1,5 @@ pub mod errors; -pub mod http; +// pub mod http; pub mod session; pub mod webrtc; pub mod whep; diff --git a/protocol/webrtc/src/session/errors.rs b/protocol/webrtc/src/session/errors.rs index 79ee29a7..ddf30dbc 100644 --- a/protocol/webrtc/src/session/errors.rs +++ b/protocol/webrtc/src/session/errors.rs @@ -2,6 +2,7 @@ use streamhub::errors::ChannelError; use { bytesio::bytes_errors::BytesReadError, bytesio::{bytes_errors::BytesWriteError, bytesio_errors::BytesIOError}, + commonlib::errors::AuthError, failure::{Backtrace, Fail}, std::fmt, std::str::Utf8Error, @@ -30,6 +31,8 @@ pub enum SessionErrorValue { RTCError(#[cause] RTCError), #[fail(display = "tokio: oneshot receiver err: {}", _0)] RecvError(#[cause] RecvError), + #[fail(display = "Auth err: {}", _0)] + AuthError(#[cause] AuthError), #[fail(display = "stream hub event send error")] StreamHubEventSendErr, #[fail(display = "cannot receive frame data from stream hub")] @@ -102,6 +105,14 @@ impl From for SessionError { } } +impl From for SessionError { + fn from(error: AuthError) -> Self { + SessionError { + value: SessionErrorValue::AuthError(error), + } + } +} + impl fmt::Display for SessionError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self.value, f) diff --git a/protocol/webrtc/src/session/mod.rs b/protocol/webrtc/src/session/mod.rs index 7579e97e..1c89c0ff 100644 --- a/protocol/webrtc/src/session/mod.rs +++ b/protocol/webrtc/src/session/mod.rs @@ -18,9 +18,13 @@ use std::io::Read; use std::{collections::HashMap, fs::File, sync::Arc}; use tokio::net::TcpStream; -use super::http::define::http_method_name; -use super::http::parse_content_length; -use super::http::{HttpRequest, HttpResponse, Marshal, Unmarshal}; +use commonlib::define::http_method_name; +use commonlib::http::{parse_content_length, HttpRequest, HttpResponse}; + +use commonlib::http::Marshal as HttpMarshal; +use commonlib::http::Unmarshal as HttpUnmarshal; + +use commonlib::auth::Auth; use super::whep::handle_whep; use super::whip::handle_whip; @@ -46,10 +50,16 @@ pub struct WebRTCServerSession { pub session_id: Option, pub http_request_data: Option, pub peer_connection: Option>, + + auth: Option, } impl WebRTCServerSession { - pub fn new(stream: TcpStream, event_producer: StreamHubEventSender) -> Self { + pub fn new( + stream: TcpStream, + event_producer: StreamHubEventSender, + auth: Option, + ) -> Self { let net_io: Box = Box::new(TcpIO::new(stream)); let io = Arc::new(Mutex::new(net_io)); @@ -62,6 +72,7 @@ impl WebRTCServerSession { session_id: None, http_request_data: None, peer_connection: None, + auth, } } @@ -100,17 +111,20 @@ impl WebRTCServerSession { if let Some(http_request) = HttpRequest::unmarshal(std::str::from_utf8(&request_data)?) { //POST /whip?app=live&stream=test HTTP/1.1 - let eles: Vec<&str> = http_request.path.splitn(2, '/').collect(); - let pars_map = &http_request.path_parameters_map; + let eles: Vec<&str> = http_request.uri.path.splitn(2, '/').collect(); + let pars_map = &http_request.query_pairs; let request_method = http_request.method.as_str(); if request_method == http_method_name::GET { - let response = match http_request.path.as_str() { + let response = match http_request.uri.path.as_str() { "/" => Self::gen_file_response("./index.html"), "/whip.js" => Self::gen_file_response("./whip.js"), "/whep.js" => Self::gen_file_response("./whep.js"), _ => { - log::warn!("the http get path: {} is not supported.", http_request.path); + log::warn!( + "the http get path: {} is not supported.", + http_request.uri.path + ); return Ok(()); } }; @@ -122,7 +136,7 @@ impl WebRTCServerSession { if eles.len() < 2 || pars_map.get("app").is_none() || pars_map.get("stream").is_none() { log::error!( "WebRTCServerSession::run the http path is not correct: {}", - http_request.path + http_request.uri.path ); return Err(SessionError { @@ -149,25 +163,31 @@ impl WebRTCServerSession { let path = format!( "{}?{}&session_id={}", - http_request.path, - http_request.path_parameters.as_ref().unwrap(), + http_request.uri.path, + http_request.uri.query.as_ref().unwrap(), self.session_id.unwrap() ); let offer = RTCSessionDescription::offer(sdp_data.clone())?; match t.to_lowercase().as_str() { "whip" => { + if let Some(auth) = &self.auth { + auth.authenticate(&stream_name, &http_request.uri.query, false)?; + } self.publish_whip(app_name, stream_name, path, offer) .await?; } "whep" => { + if let Some(auth) = &self.auth { + auth.authenticate(&stream_name, &http_request.uri.query, true)?; + } self.subscribe_whep(app_name, stream_name, path, offer) .await?; } _ => { log::error!( "current path: {}, method: {}", - http_request.path, + http_request.uri.path, t.to_lowercase() ); return Err(SessionError { @@ -201,8 +221,8 @@ impl WebRTCServerSession { } else { log::error!( "the delete path does not contain session id: {}?{}", - http_request.path, - http_request.path_parameters.as_ref().unwrap() + http_request.uri.path, + http_request.uri.query.as_ref().unwrap() ); } @@ -219,7 +239,7 @@ impl WebRTCServerSession { _ => { log::error!( "current path: {}, method: {}", - http_request.path, + http_request.uri.path, t.to_lowercase() ); return Err(SessionError { diff --git a/protocol/webrtc/src/webrtc.rs b/protocol/webrtc/src/webrtc.rs index b1d34086..6ff7f243 100644 --- a/protocol/webrtc/src/webrtc.rs +++ b/protocol/webrtc/src/webrtc.rs @@ -2,7 +2,8 @@ use streamhub::define::StreamHubEventSender; use super::session::WebRTCServerSession; -use super::http::define::http_method_name; +use commonlib::auth::Auth; +use commonlib::define::http_method_name; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; @@ -15,14 +16,16 @@ pub struct WebRTCServer { address: String, event_producer: StreamHubEventSender, uuid_2_sessions: Arc>>>>, + auth: Option, } impl WebRTCServer { - pub fn new(address: String, event_producer: StreamHubEventSender) -> Self { + pub fn new(address: String, event_producer: StreamHubEventSender, auth: Option) -> Self { Self { address, event_producer, uuid_2_sessions: Arc::new(Mutex::new(HashMap::new())), + auth, } } @@ -36,6 +39,7 @@ impl WebRTCServer { let session = Arc::new(Mutex::new(WebRTCServerSession::new( tcp_stream, self.event_producer.clone(), + self.auth.clone(), ))); let uuid_2_sessions = self.uuid_2_sessions.clone(); tokio::spawn(async move {