Skip to content

Commit

Permalink
Add ability to modify headers (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
sansyrox authored Mar 14, 2022
1 parent 0d38b16 commit 6bbe2b1
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 96 deletions.
22 changes: 22 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,28 @@ def shutdown_handler():
logger.log(logging.INFO, "Shutting down")


@app.get("/redirect")
async def redirect(request):
return {"status_code": "307", "body": "", "type": "text"}


@app.get("/redirect_route")
async def redirect_route(request):
return "This is the redirected route"


@app.before_request("/redirect")
async def redirect_before_request(request):
request["headers"]["Location"] = "redirect_route"
return ""


@app.after_request("/redirect")
async def redirect_after_request(request):
request["headers"]["Location"] = "redirect_route"
return ""


if __name__ == "__main__":
ROBYN_URL = os.getenv("ROBYN_URL", "0.0.0.0")
app.add_header("server", "robyn")
Expand Down
5 changes: 5 additions & 0 deletions integration_tests/test_status_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ def test_404_status_code(session):
def test_404_post_request_status_code(session):
r = requests.post(f"{BASE_URL}/404")
assert r.status_code == 404

def test_307_get_request(session):
r = requests.get(f"{BASE_URL}/redirect")
assert r.text == "This is the redirected route"

6 changes: 2 additions & 4 deletions robyn/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def _format_response(self, res):
"body": res["body"],
**res
}
print("Setting the response", response)
else:
response = {"status_code": "200", "body": res, "type": "text"}

Expand All @@ -37,7 +36,6 @@ def _format_response(self, res):
def add_route(self, route_type, endpoint, handler):
async def async_inner_handler(*args):
response = self._format_response(await handler(*args))
print(f"This is the response in python: {response}")
return response

def inner_handler(*args):
Expand Down Expand Up @@ -95,7 +93,7 @@ def add_route(self, route_type, endpoint, handler):
def add_after_request(self, endpoint):
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
await handler(*args)
return args

def inner_handler(*args):
Expand All @@ -112,7 +110,7 @@ def inner_handler(*args):
def add_before_request(self, endpoint):
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
await handler(*args)
return args

def inner_handler(*args):
Expand Down
132 changes: 70 additions & 62 deletions src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use anyhow::{bail, Result};
use crate::types::{Headers, PyFunction};
use futures_util::stream::StreamExt;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use pyo3::types::PyDict;

use std::fs::File;
use std::io::Read;
Expand All @@ -17,9 +17,9 @@ use std::io::Read;
const MAX_SIZE: usize = 10_000;

#[inline]
pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc<Headers>) {
for a in headers.iter() {
response.insert_header((a.key().clone(), a.value().clone()));
pub fn apply_headers(response: &mut HttpResponseBuilder, headers: HashMap<String, String>) {
for (key, val) in (headers).iter() {
response.insert_header((key.clone(), val.clone()));
}
}

Expand All @@ -37,7 +37,7 @@ pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc<Headers>)
pub async fn handle_request(
function: PyFunction,
number_of_params: u8,
headers: &Arc<Headers>,
headers: HashMap<String, String>,
payload: &mut web::Payload,
req: &HttpRequest,
route_params: HashMap<String, String>,
Expand All @@ -46,7 +46,7 @@ pub async fn handle_request(
let contents = match execute_http_function(
function,
payload,
headers,
headers.clone(),
req,
route_params,
queries,
Expand All @@ -58,17 +58,30 @@ pub async fn handle_request(
Err(err) => {
println!("Error: {:?}", err);
let mut response = HttpResponse::InternalServerError();
apply_headers(&mut response, headers);
apply_headers(&mut response, headers.clone());
return response.finish();
}
};

let mut response = HttpResponse::Ok();
let body = contents.get("body").unwrap().to_owned();
let status_code =
actix_http::StatusCode::from_str(contents.get("status_code").unwrap()).unwrap();
apply_headers(&mut response, headers);
response.status(status_code);
response.body(contents.get("body").unwrap().to_owned())

let mut response = HttpResponse::build(status_code);
apply_headers(&mut response, headers.clone());
let final_response = if body != "" {
response.body(body)
} else {
response.finish()
};

println!(
"The status code is {} and the headers are {:?}",
final_response.status(),
final_response.headers()
);
// response.body(contents.get("body").unwrap().to_owned())
final_response
}

pub async fn handle_middleware_request(
Expand All @@ -79,7 +92,7 @@ pub async fn handle_middleware_request(
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
) -> Py<PyTuple> {
) -> HashMap<String, HashMap<String, String>> {
let contents = match execute_middleware_function(
function,
payload,
Expand All @@ -92,12 +105,10 @@ pub async fn handle_middleware_request(
.await
{
Ok(res) => res,
Err(err) => Python::with_gil(|py| {
println!("{:?}", err);
PyTuple::empty(py).into_py(py)
}),
Err(_err) => HashMap::new(),
};

println!("These are the middleware response {:?}", contents);
contents
}

Expand All @@ -123,12 +134,12 @@ async fn execute_middleware_function<'a>(
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
number_of_params: u8,
) -> Result<Py<PyTuple>> {
) -> Result<HashMap<String, HashMap<String, String>>> {
// TODO:
// try executing the first version of middleware(s) here
// with just headers as params

let mut data: Option<Vec<u8>> = None;
let mut data: Vec<u8> = Vec::new();

if req.method() == Method::POST
|| req.method() == Method::PUT
Expand All @@ -145,13 +156,13 @@ async fn execute_middleware_function<'a>(
body.extend_from_slice(&chunk);
}

data = Some(body.to_vec())
data = body.to_vec()
}

// request object accessible while creating routes
let mut request = HashMap::new();
let mut headers_python = HashMap::new();
for elem in headers.into_iter() {
for elem in (*headers).iter() {
headers_python.insert(elem.key().clone(), elem.value().clone());
}

Expand All @@ -162,7 +173,7 @@ async fn execute_middleware_function<'a>(
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));
request.insert("body", data.into_py(py));
// request.insert("body", data.into_py(py));

// this makes the request object to be accessible across every route
let coro: PyResult<&PyAny> = match number_of_params {
Expand All @@ -176,10 +187,13 @@ async fn execute_middleware_function<'a>(

let output = output.await?;

let res = Python::with_gil(|py| -> PyResult<Py<PyTuple>> {
let output: Py<PyTuple> = output.extract(py).unwrap();
Ok(output)
})?;
let res =
Python::with_gil(|py| -> PyResult<HashMap<String, HashMap<String, String>>> {
let output: Vec<HashMap<String, HashMap<String, String>>> =
output.extract(py).unwrap();
let responses = output[0].clone();
Ok(responses)
})?;

Ok(res)
}
Expand All @@ -200,9 +214,10 @@ async fn execute_middleware_function<'a>(
2_u8..=u8::MAX => handler.call1((request,)),
};

let output: Py<PyTuple> = output?.extract().unwrap();
let output: Vec<HashMap<String, HashMap<String, String>>> =
output?.extract().unwrap();

Ok(output)
Ok(output[0].clone())
})
})
.await?
Expand All @@ -215,15 +230,15 @@ async fn execute_middleware_function<'a>(
async fn execute_http_function(
function: PyFunction,
payload: &mut web::Payload,
headers: &Headers,
headers: HashMap<String, String>,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
number_of_params: u8,
// need to change this to return a response struct
// create a custom struct for this
) -> Result<HashMap<String, String>> {
let mut data: Option<Vec<u8>> = None;
let mut data: Vec<u8> = Vec::new();

if req.method() == Method::POST
|| req.method() == Method::PUT
Expand All @@ -240,28 +255,21 @@ async fn execute_http_function(
body.extend_from_slice(&chunk);
}

data = Some(body.to_vec())
data = body.to_vec()
}

// request object accessible while creating routes
let mut request = HashMap::new();
let mut headers_python = HashMap::new();
for elem in headers.into_iter() {
headers_python.insert(elem.key().clone(), elem.value().clone());
}

match function {
PyFunction::CoRoutine(handler) => {
let output = Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));

if let Some(res) = data {
let data = res.into_py(py);
request.insert("body", data);
};
request.insert("headers", headers.into_py(py));
let data = data.into_py(py);
request.insert("body", data);

// this makes the request object to be accessible across every route
let coro: PyResult<&PyAny> = match number_of_params {
Expand Down Expand Up @@ -298,11 +306,9 @@ async fn execute_http_function(
Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("headers", headers_python.into_py(py));
if let Some(res) = data {
let data = res.into_py(py);
request.insert("body", data);
};
request.insert("headers", headers.into_py(py));
let data = data.into_py(py);
request.insert("body", data);

let output: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
Expand All @@ -325,22 +331,24 @@ pub async fn execute_event_handler(
event_handler: Option<Arc<PyFunction>>,
event_loop: Arc<Py<PyAny>>,
) {
if let Some(handler) = event_handler { match &(*handler) {
PyFunction::SyncFunction(function) => {
println!("Startup event handler");
Python::with_gil(|py| {
function.call0(py).unwrap();
});
}
PyFunction::CoRoutine(function) => {
let future = Python::with_gil(|py| {
println!("Startup event handler async");

let coroutine = function.as_ref(py).call0().unwrap();
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
.unwrap()
});
future.await.unwrap();
if let Some(handler) = event_handler {
match &(*handler) {
PyFunction::SyncFunction(function) => {
println!("Startup event handler");
Python::with_gil(|py| {
function.call0(py).unwrap();
});
}
PyFunction::CoRoutine(function) => {
let future = Python::with_gil(|py| {
println!("Startup event handler async");

let coroutine = function.as_ref(py).call0().unwrap();
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
.unwrap()
});
future.await.unwrap();
}
}
} }
}
}
Loading

0 comments on commit 6bbe2b1

Please sign in to comment.