diff --git a/src/routers/middleware_router.rs b/src/routers/middleware_router.rs index 1f609f818..f7e8462fe 100644 --- a/src/routers/middleware_router.rs +++ b/src/routers/middleware_router.rs @@ -9,6 +9,8 @@ use matchit::Node; use anyhow::{bail, Error, Result}; +use super::router::RouteType; + /// Contains the thread safe hashmaps of different routes pub struct MiddlewareRouter { @@ -25,11 +27,10 @@ impl MiddlewareRouter { } #[inline] - fn get_relevant_map(&self, route: &str) -> Option<&RwLock>> { + fn get_relevant_map(&self, route: RouteType) -> Option<&RwLock>> { match route { - "BEFORE_REQUEST" => Some(&self.before_request), - "AFTER_REQUEST" => Some(&self.after_request), - _ => None, + RouteType::BeforeRequest => Some(&self.before_request), + RouteType::AfterRequest => Some(&self.after_request), } } @@ -37,7 +38,7 @@ impl MiddlewareRouter { // Inserts them in the router according to their nature(CoRoutine/SyncFunction) pub fn add_route( &self, - route_type: &str, // we can just have route type as WS + route_type: RouteType, // we can just have route type as WS route: &str, handler: Py, is_async: bool, @@ -64,7 +65,7 @@ impl MiddlewareRouter { pub fn get_route( &self, - route_method: &str, + route_method: RouteType, route: &str, // check for the route method here ) -> Option<((PyFunction, u8), HashMap)> { // need to split this function in multiple smaller functions diff --git a/src/routers/router.rs b/src/routers/router.rs index b75aef09d..97a521d38 100644 --- a/src/routers/router.rs +++ b/src/routers/router.rs @@ -10,6 +10,22 @@ use matchit::Node; use anyhow::{bail, Error, Result}; +#[derive(Debug)] +pub enum RouteType { + BeforeRequest, + AfterRequest, +} + +impl RouteType { + pub fn from_str(input: &str) -> RouteType { + match input { + "BEFORE_REQUEST" => RouteType::BeforeRequest, + "AFTER_REQUEST" => RouteType::AfterRequest, + _ => panic!("Invalid route type enum."), + } + } +} + /// Contains the thread safe hashmaps of different routes pub struct Router { diff --git a/src/server.rs b/src/server.rs index 9eecbab16..76ad567d8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,8 @@ use crate::io_helpers::apply_headers; use crate::request_handler::{handle_http_middleware_request, handle_http_request}; use crate::routers::const_router::ConstRouter; -use crate::routers::router::Router; + +use crate::routers::router::{RouteType, Router}; use crate::routers::{middleware_router::MiddlewareRouter, web_socket_router::WebSocketRouter}; use crate::shared_socket::SocketHeld; use crate::types::{Headers, PyFunction}; @@ -300,6 +301,9 @@ impl Server { number_of_params: u8, ) { debug!("MiddleWare Route added for {} {} ", route_type, route); + + let route_type = RouteType::from_str(route_type); + self.middleware_router .add_route(route_type, route, handler, is_async, number_of_params) .unwrap(); @@ -393,7 +397,8 @@ async fn index( let headers = merge_headers(&global_headers, req.headers()).await; // need a better name for this - let tuple_params = match middleware_router.get_route("BEFORE_REQUEST", req.uri().path()) { + let tuple_params = match middleware_router.get_route(RouteType::BeforeRequest, req.uri().path()) + { Some(((handler_function, number_of_params), route_params)) => { let x = handle_http_middleware_request( handler_function, @@ -454,18 +459,20 @@ async fn index( } }; - if let Some(((handler_function, number_of_params), route_params)) = middleware_router.get_route("AFTER_REQUEST", req.uri().path()) { + if let Some(((handler_function, number_of_params), route_params)) = + middleware_router.get_route(RouteType::AfterRequest, req.uri().path()) + { let x = handle_http_middleware_request( - handler_function, - number_of_params, - &headers_dup, - &mut payload, - &req, - route_params, - queries.clone(), - ) - .await; - debug!("{:?}", x); + handler_function, + number_of_params, + &headers_dup, + &mut payload, + &req, + route_params, + queries.clone(), + ) + .await; + debug!("{:?}", x); }; response