From cb44c28c3d61d22d06fa166dc027980337c8f6f9 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Mon, 7 Mar 2022 22:19:00 +0000 Subject: [PATCH] Add ability to modify headers --- integration_tests/base_routes.py | 22 +++++ integration_tests/test_status_code.py | 5 + robyn/router.py | 6 +- src/processor.rs | 132 ++++++++++++++------------ src/server.rs | 72 ++++++++------ 5 files changed, 141 insertions(+), 96 deletions(-) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index c20762038..c2ecda18d 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -163,6 +163,28 @@ def shutdown_handler(): logger.log(logging.INFO, "Shutting down") +@app.get("/redirect") +async def redirect(request): + return {"status_code": "307", "body": "", "type": "text"} + + +@app.get("/redirect_route") +async def redirect_route(request): + return "This is the redirected route" + + +@app.before_request("/redirect") +async def redirect_before_request(request): + request["headers"]["Location"] = "redirect_route" + return "" + + +@app.after_request("/redirect") +async def redirect_after_request(request): + request["headers"]["Location"] = "redirect_route" + return "" + + if __name__ == "__main__": ROBYN_URL = os.getenv("ROBYN_URL", "0.0.0.0") app.add_header("server", "robyn") diff --git a/integration_tests/test_status_code.py b/integration_tests/test_status_code.py index 5f94eeb03..f355333da 100644 --- a/integration_tests/test_status_code.py +++ b/integration_tests/test_status_code.py @@ -11,3 +11,8 @@ def test_404_status_code(session): def test_404_post_request_status_code(session): r = requests.post(f"{BASE_URL}/404") assert r.status_code == 404 + +def test_307_get_request(session): + r = requests.get(f"{BASE_URL}/redirect") + assert r.text == "This is the redirected route" + diff --git a/robyn/router.py b/robyn/router.py index 5420a95bc..d80aae11d 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -28,7 +28,6 @@ def _format_response(self, res): "body": res["body"], **res } - print("Setting the response", response) else: response = {"status_code": "200", "body": res, "type": "text"} @@ -37,7 +36,6 @@ def _format_response(self, res): def add_route(self, route_type, endpoint, handler): async def async_inner_handler(*args): response = self._format_response(await handler(*args)) - print(f"This is the response in python: {response}") return response def inner_handler(*args): @@ -95,7 +93,7 @@ def add_route(self, route_type, endpoint, handler): def add_after_request(self, endpoint): def inner(handler): async def async_inner_handler(*args): - await handler(args) + await handler(*args) return args def inner_handler(*args): @@ -112,7 +110,7 @@ def inner_handler(*args): def add_before_request(self, endpoint): def inner(handler): async def async_inner_handler(*args): - await handler(args) + await handler(*args) return args def inner_handler(*args): diff --git a/src/processor.rs b/src/processor.rs index 5c2fde649..755361d44 100644 --- a/src/processor.rs +++ b/src/processor.rs @@ -8,7 +8,7 @@ use anyhow::{bail, Result}; use crate::types::{Headers, PyFunction}; use futures_util::stream::StreamExt; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyTuple}; +use pyo3::types::PyDict; use std::fs::File; use std::io::Read; @@ -17,9 +17,9 @@ use std::io::Read; const MAX_SIZE: usize = 10_000; #[inline] -pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc) { - for a in headers.iter() { - response.insert_header((a.key().clone(), a.value().clone())); +pub fn apply_headers(response: &mut HttpResponseBuilder, headers: HashMap) { + for (key, val) in (headers).iter() { + response.insert_header((key.clone(), val.clone())); } } @@ -37,7 +37,7 @@ pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc) pub async fn handle_request( function: PyFunction, number_of_params: u8, - headers: &Arc, + headers: HashMap, payload: &mut web::Payload, req: &HttpRequest, route_params: HashMap, @@ -46,7 +46,7 @@ pub async fn handle_request( let contents = match execute_http_function( function, payload, - headers, + headers.clone(), req, route_params, queries, @@ -58,17 +58,30 @@ pub async fn handle_request( Err(err) => { println!("Error: {:?}", err); let mut response = HttpResponse::InternalServerError(); - apply_headers(&mut response, headers); + apply_headers(&mut response, headers.clone()); return response.finish(); } }; - let mut response = HttpResponse::Ok(); + let body = contents.get("body").unwrap().to_owned(); let status_code = actix_http::StatusCode::from_str(contents.get("status_code").unwrap()).unwrap(); - apply_headers(&mut response, headers); - response.status(status_code); - response.body(contents.get("body").unwrap().to_owned()) + + let mut response = HttpResponse::build(status_code); + apply_headers(&mut response, headers.clone()); + let final_response = if body != "" { + response.body(body) + } else { + response.finish() + }; + + println!( + "The status code is {} and the headers are {:?}", + final_response.status(), + final_response.headers() + ); + // response.body(contents.get("body").unwrap().to_owned()) + final_response } pub async fn handle_middleware_request( @@ -79,7 +92,7 @@ pub async fn handle_middleware_request( req: &HttpRequest, route_params: HashMap, queries: HashMap, -) -> Py { +) -> HashMap> { let contents = match execute_middleware_function( function, payload, @@ -92,12 +105,10 @@ pub async fn handle_middleware_request( .await { Ok(res) => res, - Err(err) => Python::with_gil(|py| { - println!("{:?}", err); - PyTuple::empty(py).into_py(py) - }), + Err(_err) => HashMap::new(), }; + println!("These are the middleware response {:?}", contents); contents } @@ -123,12 +134,12 @@ async fn execute_middleware_function<'a>( route_params: HashMap, queries: HashMap, number_of_params: u8, -) -> Result> { +) -> Result>> { // TODO: // try executing the first version of middleware(s) here // with just headers as params - let mut data: Option> = None; + let mut data: Vec = Vec::new(); if req.method() == Method::POST || req.method() == Method::PUT @@ -145,13 +156,13 @@ async fn execute_middleware_function<'a>( body.extend_from_slice(&chunk); } - data = Some(body.to_vec()) + data = body.to_vec() } // request object accessible while creating routes let mut request = HashMap::new(); let mut headers_python = HashMap::new(); - for elem in headers.into_iter() { + for elem in (*headers).iter() { headers_python.insert(elem.key().clone(), elem.value().clone()); } @@ -162,7 +173,7 @@ async fn execute_middleware_function<'a>( request.insert("params", route_params.into_py(py)); request.insert("queries", queries.into_py(py)); request.insert("headers", headers_python.into_py(py)); - request.insert("body", data.into_py(py)); + // request.insert("body", data.into_py(py)); // this makes the request object to be accessible across every route let coro: PyResult<&PyAny> = match number_of_params { @@ -176,10 +187,13 @@ async fn execute_middleware_function<'a>( let output = output.await?; - let res = Python::with_gil(|py| -> PyResult> { - let output: Py = output.extract(py).unwrap(); - Ok(output) - })?; + let res = + Python::with_gil(|py| -> PyResult>> { + let output: Vec>> = + output.extract(py).unwrap(); + let responses = output[0].clone(); + Ok(responses) + })?; Ok(res) } @@ -200,9 +214,10 @@ async fn execute_middleware_function<'a>( 2_u8..=u8::MAX => handler.call1((request,)), }; - let output: Py = output?.extract().unwrap(); + let output: Vec>> = + output?.extract().unwrap(); - Ok(output) + Ok(output[0].clone()) }) }) .await? @@ -215,7 +230,7 @@ async fn execute_middleware_function<'a>( async fn execute_http_function( function: PyFunction, payload: &mut web::Payload, - headers: &Headers, + headers: HashMap, req: &HttpRequest, route_params: HashMap, queries: HashMap, @@ -223,7 +238,7 @@ async fn execute_http_function( // need to change this to return a response struct // create a custom struct for this ) -> Result> { - let mut data: Option> = None; + let mut data: Vec = Vec::new(); if req.method() == Method::POST || req.method() == Method::PUT @@ -240,15 +255,11 @@ async fn execute_http_function( body.extend_from_slice(&chunk); } - data = Some(body.to_vec()) + data = body.to_vec() } // request object accessible while creating routes let mut request = HashMap::new(); - let mut headers_python = HashMap::new(); - for elem in headers.into_iter() { - headers_python.insert(elem.key().clone(), elem.value().clone()); - } match function { PyFunction::CoRoutine(handler) => { @@ -256,12 +267,9 @@ async fn execute_http_function( let handler = handler.as_ref(py); request.insert("params", route_params.into_py(py)); request.insert("queries", queries.into_py(py)); - request.insert("headers", headers_python.into_py(py)); - - if let Some(res) = data { - let data = res.into_py(py); - request.insert("body", data); - }; + request.insert("headers", headers.into_py(py)); + let data = data.into_py(py); + request.insert("body", data); // this makes the request object to be accessible across every route let coro: PyResult<&PyAny> = match number_of_params { @@ -298,11 +306,9 @@ async fn execute_http_function( Python::with_gil(|py| { let handler = handler.as_ref(py); request.insert("params", route_params.into_py(py)); - request.insert("headers", headers_python.into_py(py)); - if let Some(res) = data { - let data = res.into_py(py); - request.insert("body", data); - }; + request.insert("headers", headers.into_py(py)); + let data = data.into_py(py); + request.insert("body", data); let output: PyResult<&PyAny> = match number_of_params { 0 => handler.call0(), @@ -325,22 +331,24 @@ pub async fn execute_event_handler( event_handler: Option>, event_loop: Arc>, ) { - if let Some(handler) = event_handler { match &(*handler) { - PyFunction::SyncFunction(function) => { - println!("Startup event handler"); - Python::with_gil(|py| { - function.call0(py).unwrap(); - }); - } - PyFunction::CoRoutine(function) => { - let future = Python::with_gil(|py| { - println!("Startup event handler async"); - - let coroutine = function.as_ref(py).call0().unwrap(); - pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine) - .unwrap() - }); - future.await.unwrap(); + if let Some(handler) = event_handler { + match &(*handler) { + PyFunction::SyncFunction(function) => { + println!("Startup event handler"); + Python::with_gil(|py| { + function.call0(py).unwrap(); + }); + } + PyFunction::CoRoutine(function) => { + let future = Python::with_gil(|py| { + println!("Startup event handler async"); + + let coroutine = function.as_ref(py).call0().unwrap(); + pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine) + .unwrap() + }); + future.await.unwrap(); + } } - } } + } } diff --git a/src/server.rs b/src/server.rs index 7b1925a75..26338a243 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,6 +12,7 @@ use std::convert::TryInto; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; use std::sync::{Arc, RwLock}; + use std::thread; use actix_files::Files; @@ -319,28 +320,38 @@ async fn index( } } - let _ = if let Some(((handler_function, number_of_params), route_params)) = - middleware_router.get_route("BEFORE_REQUEST", req.uri().path()) - { - let x = handle_middleware_request( - handler_function, - number_of_params, - &headers, - &mut payload, - &req, - route_params, - queries.clone(), - ) - .await; - println!("{:?}", x.to_string()); + let tuple_params = match middleware_router.get_route("BEFORE_REQUEST", req.uri().path()) { + Some(((handler_function, number_of_params), route_params)) => { + let x = handle_middleware_request( + handler_function, + number_of_params, + &headers, + &mut payload, + &req, + route_params, + queries.clone(), + ) + .await; + println!("Middleware contents {:?}", x); + x + } + None => HashMap::new(), }; + println!("These are the tuple params {:?}", tuple_params); + + let mut headers_dup = HashMap::new(); + + if tuple_params.len() != 0 { + headers_dup = tuple_params.get("headers").unwrap().clone(); + } + let response = match router.get_route(req.method().clone(), req.uri().path()) { Some(((handler_function, number_of_params), route_params)) => { handle_request( handler_function, number_of_params, - &headers, + headers_dup.clone(), &mut payload, &req, route_params, @@ -350,25 +361,26 @@ async fn index( } None => { let mut response = HttpResponse::Ok(); - apply_headers(&mut response, &headers); + apply_headers(&mut response, headers_dup.clone()); response.finish() } }; - let _ = if let Some(((handler_function, number_of_params), route_params)) = - middleware_router.get_route("AFTER_REQUEST", req.uri().path()) - { - let x = handle_middleware_request( - handler_function, - number_of_params, - &headers, - &mut payload, - &req, - route_params, - queries.clone(), - ) - .await; - println!("{:?}", x.to_string()); + let _ = match middleware_router.get_route("AFTER_REQUEST", req.uri().path()) { + Some(((handler_function, number_of_params), route_params)) => { + let x = handle_middleware_request( + handler_function, + number_of_params, + &headers, + &mut payload, + &req, + route_params, + queries.clone(), + ) + .await; + println!("{:?}", x); + } + None => {} }; response