diff --git a/docs/features.md b/docs/features.md index 49fd5b81a..eec46ec5f 100644 --- a/docs/features.md +++ b/docs/features.md @@ -171,7 +171,7 @@ async def hello(request): return "Hello World" ``` -## Global Headers +## Global Request Headers You can also add global headers for every request. @@ -179,6 +179,14 @@ You can also add global headers for every request. app.add_request_header("server", "robyn") ``` +## Global Response Headers + +You can also add global response headers for every request. + +```python +app.add_response_header("content-type", "application/json") +``` + ## Per route headers You can also add headers for every route. diff --git a/integration_tests/test_app.py b/integration_tests/test_app.py index 5081cc7fe..d8ffbb19f 100644 --- a/integration_tests/test_app.py +++ b/integration_tests/test_app.py @@ -9,6 +9,12 @@ def test_add_request_header(): assert app.request_headers == [Header(key="server", val="robyn")] +def test_add_response_header(): + app = Robyn(__file__) + app.add_response_header("content-type", "application/json") + assert app.response_headers == [Header(key="content-type", val="application/json")] + + def test_lifecycle_handlers(): def mock_startup_handler(): pass diff --git a/robyn/__init__.py b/robyn/__init__.py index bfc24987c..de324fffa 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -32,6 +32,7 @@ def __init__(self, file_object: str) -> None: self.middleware_router = MiddlewareRouter() self.web_socket_router = WebSocketRouter() self.request_headers: List[Header] = [] # This needs a better type + self.response_headers: List[Header] = [] # This needs a better type self.directories: List[Directory] = [] self.event_handlers = {} load_vars(project_root=directory_path) @@ -82,6 +83,9 @@ def add_directory( def add_request_header(self, key: str, value: str) -> None: self.request_headers.append(Header(key, value)) + def add_response_header(self, key: str, value: str) -> None: + self.response_headers.append(Header(key, value)) + def add_web_socket(self, endpoint: str, ws: WS) -> None: self.web_socket_router.add_route(endpoint, ws) diff --git a/robyn/processpool.py b/robyn/processpool.py index 9a9666d32..66c3b72ba 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -122,6 +122,7 @@ def initialize_event_loop(): def spawn_process( directories: List[Directory], request_headers: List[Header], + response_headers: List[Header], routes: List[Route], middlewares: List[MiddlewareRoute], web_sockets: Dict[str, WS], @@ -156,6 +157,9 @@ def spawn_process( for header in request_headers: server.add_request_header(*header.as_list()) + for header in response_headers: + server.add_response_header(*header.as_list()) + for route in routes: route_type, endpoint, function, is_const = route server.add_route(route_type, endpoint, function, is_const) diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index aee6248fb..aafff93ce 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -37,6 +37,8 @@ class Server: pass def add_request_header(self, key: str, value: str) -> None: pass + def add_response_header(self, key: str, value: str) -> None: + pass def add_route( self, route_type: str, diff --git a/src/server.rs b/src/server.rs index cf8137e30..adc085dcb 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,6 +48,7 @@ pub struct Server { websocket_router: Arc, middleware_router: Arc, global_request_headers: Arc>, + global_response_headers: Arc>, directories: Arc>>, startup_handler: Option>, shutdown_handler: Option>, @@ -63,6 +64,7 @@ impl Server { websocket_router: Arc::new(WebSocketRouter::new()), middleware_router: Arc::new(MiddlewareRouter::new()), global_request_headers: Arc::new(DashMap::new()), + global_response_headers: Arc::new(DashMap::new()), directories: Arc::new(RwLock::new(Vec::new())), startup_handler: None, shutdown_handler: None, @@ -92,6 +94,7 @@ impl Server { let middleware_router = self.middleware_router.clone(); let web_socket_router = self.websocket_router.clone(); let global_request_headers = self.global_request_headers.clone(); + let global_response_headers = self.global_response_headers.clone(); let directories = self.directories.clone(); let workers = Arc::new(workers); @@ -145,7 +148,8 @@ impl Server { .app_data(web::Data::new(router.clone())) .app_data(web::Data::new(const_router.clone())) .app_data(web::Data::new(middleware_router.clone())) - .app_data(web::Data::new(global_request_headers.clone())); + .app_data(web::Data::new(global_request_headers.clone())) + .app_data(web::Data::new(global_response_headers.clone())); let web_socket_map = web_socket_router.get_web_socket_map(); for (elem, value) in (web_socket_map.read().unwrap()).iter() { @@ -165,6 +169,7 @@ impl Server { const_router: web::Data>, middleware_router: web::Data>, global_request_headers, + global_response_headers, body, req| { pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move { @@ -173,6 +178,7 @@ impl Server { const_router, middleware_router, global_request_headers, + global_response_headers, body, req, ) @@ -223,19 +229,32 @@ impl Server { }); } - /// Adds a new header to our concurrent hashmap + /// Adds a new request header to our concurrent hashmap /// this can be called after the server has started. pub fn add_request_header(&self, key: &str, value: &str) { self.global_request_headers .insert(key.to_string(), value.to_string()); } - /// Removes a new header to our concurrent hashmap + /// Adds a new response header to our concurrent hashmap + /// this can be called after the server has started. + pub fn add_response_header(&self, key: &str, value: &str) { + self.global_response_headers + .insert(key.to_string(), value.to_string()); + } + + /// Removes a new request header to our concurrent hashmap /// this can be called after the server has started. pub fn remove_header(&self, key: &str) { self.global_request_headers.remove(key); } + /// Removes a new response header to our concurrent hashmap + /// this can be called after the server has started. + pub fn remove_response_header(&self, key: &str) { + self.global_response_headers.remove(key); + } + /// Add a new route to the routing tables /// can be called after the server has been started pub fn add_route( @@ -345,6 +364,7 @@ async fn index( const_router: web::Data>, middleware_router: web::Data>, global_request_headers: web::Data>, + global_response_headers: web::Data>, body: Bytes, req: HttpRequest, ) -> impl Responder { @@ -360,6 +380,7 @@ async fn index( let mut response_builder = HttpResponse::Ok(); apply_dashmap_headers(&mut response_builder, &global_request_headers); + apply_dashmap_headers(&mut response_builder, &global_response_headers); apply_hashmap_headers(&mut response_builder, &request.headers); let response = if let Some(r) = const_router.get_route(req.method(), req.uri().path()) {