Skip to content

Commit

Permalink
feat: Allow global level Response headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Parth Shandilya committed Feb 12, 2023
1 parent 54040f5 commit 84af7b4
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 4 deletions.
10 changes: 9 additions & 1 deletion docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,22 @@ async def hello(request):
return "Hello World"
```

## Global Headers
## Global Request Headers

You can also add global headers for every request.

```python
app.add_request_header("server", "robyn")
```

## Global Response Headers

You can also add global response headers for every request.

```python
app.add_response_header("content-type", "application/json")
```

## Per route headers

You can also add headers for every route.
Expand Down
6 changes: 6 additions & 0 deletions integration_tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def test_add_request_header():
assert app.request_headers == [Header(key="server", val="robyn")]


def test_add_response_header():
app = Robyn(__file__)
app.add_response_header("content-type", "application/json")
assert app.response_headers == [Header(key="content-type", val="application/json")]


def test_lifecycle_handlers():
def mock_startup_handler():
pass
Expand Down
4 changes: 4 additions & 0 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, file_object: str) -> None:
self.middleware_router = MiddlewareRouter()
self.web_socket_router = WebSocketRouter()
self.request_headers: List[Header] = [] # This needs a better type
self.response_headers: List[Header] = [] # This needs a better type
self.directories: List[Directory] = []
self.event_handlers = {}
load_vars(project_root=directory_path)
Expand Down Expand Up @@ -82,6 +83,9 @@ def add_directory(
def add_request_header(self, key: str, value: str) -> None:
self.request_headers.append(Header(key, value))

def add_response_header(self, key: str, value: str) -> None:
self.response_headers.append(Header(key, value))

def add_web_socket(self, endpoint: str, ws: WS) -> None:
self.web_socket_router.add_route(endpoint, ws)

Expand Down
4 changes: 4 additions & 0 deletions robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def initialize_event_loop():
def spawn_process(
directories: List[Directory],
request_headers: List[Header],
response_headers: List[Header],
routes: List[Route],
middlewares: List[MiddlewareRoute],
web_sockets: Dict[str, WS],
Expand Down Expand Up @@ -156,6 +157,9 @@ def spawn_process(
for header in request_headers:
server.add_request_header(*header.as_list())

for header in response_headers:
server.add_response_header(*header.as_list())

for route in routes:
route_type, endpoint, function, is_const = route
server.add_route(route_type, endpoint, function, is_const)
Expand Down
2 changes: 2 additions & 0 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Server:
pass
def add_request_header(self, key: str, value: str) -> None:
pass
def add_response_header(self, key: str, value: str) -> None:
pass
def add_route(
self,
route_type: str,
Expand Down
27 changes: 24 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct Server {
websocket_router: Arc<WebSocketRouter>,
middleware_router: Arc<MiddlewareRouter>,
global_request_headers: Arc<DashMap<String, String>>,
global_response_headers: Arc<DashMap<String, String>>,
directories: Arc<RwLock<Vec<Directory>>>,
startup_handler: Option<Arc<FunctionInfo>>,
shutdown_handler: Option<Arc<FunctionInfo>>,
Expand All @@ -63,6 +64,7 @@ impl Server {
websocket_router: Arc::new(WebSocketRouter::new()),
middleware_router: Arc::new(MiddlewareRouter::new()),
global_request_headers: Arc::new(DashMap::new()),
global_response_headers: Arc::new(DashMap::new()),
directories: Arc::new(RwLock::new(Vec::new())),
startup_handler: None,
shutdown_handler: None,
Expand Down Expand Up @@ -92,6 +94,7 @@ impl Server {
let middleware_router = self.middleware_router.clone();
let web_socket_router = self.websocket_router.clone();
let global_request_headers = self.global_request_headers.clone();
let global_response_headers = self.global_response_headers.clone();
let directories = self.directories.clone();
let workers = Arc::new(workers);

Expand Down Expand Up @@ -145,7 +148,8 @@ impl Server {
.app_data(web::Data::new(router.clone()))
.app_data(web::Data::new(const_router.clone()))
.app_data(web::Data::new(middleware_router.clone()))
.app_data(web::Data::new(global_request_headers.clone()));
.app_data(web::Data::new(global_request_headers.clone()))
.app_data(web::Data::new(global_response_headers.clone()));

let web_socket_map = web_socket_router.get_web_socket_map();
for (elem, value) in (web_socket_map.read().unwrap()).iter() {
Expand All @@ -165,6 +169,7 @@ impl Server {
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers,
global_response_headers,
body,
req| {
pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move {
Expand All @@ -173,6 +178,7 @@ impl Server {
const_router,
middleware_router,
global_request_headers,
global_response_headers,
body,
req,
)
Expand Down Expand Up @@ -223,19 +229,32 @@ impl Server {
});
}

/// Adds a new header to our concurrent hashmap
/// Adds a new request header to our concurrent hashmap
/// this can be called after the server has started.
pub fn add_request_header(&self, key: &str, value: &str) {
self.global_request_headers
.insert(key.to_string(), value.to_string());
}

/// Removes a new header to our concurrent hashmap
/// Adds a new response header to our concurrent hashmap
/// this can be called after the server has started.
pub fn add_response_header(&self, key: &str, value: &str) {
self.global_response_headers
.insert(key.to_string(), value.to_string());
}

/// Removes a new request header to our concurrent hashmap
/// this can be called after the server has started.
pub fn remove_header(&self, key: &str) {
self.global_request_headers.remove(key);
}

/// Removes a new response header to our concurrent hashmap
/// this can be called after the server has started.
pub fn remove_response_header(&self, key: &str) {
self.global_response_headers.remove(key);
}

/// Add a new route to the routing tables
/// can be called after the server has been started
pub fn add_route(
Expand Down Expand Up @@ -345,6 +364,7 @@ async fn index(
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers: web::Data<Arc<Headers>>,
global_response_headers: web::Data<Arc<Headers>>,
body: Bytes,
req: HttpRequest,
) -> impl Responder {
Expand All @@ -360,6 +380,7 @@ async fn index(

let mut response_builder = HttpResponse::Ok();
apply_dashmap_headers(&mut response_builder, &global_request_headers);
apply_dashmap_headers(&mut response_builder, &global_response_headers);
apply_hashmap_headers(&mut response_builder, &request.headers);

let response = if let Some(r) = const_router.get_route(req.method(), req.uri().path()) {
Expand Down

0 comments on commit 84af7b4

Please sign in to comment.