Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify headers #170

Merged
merged 1 commit into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,28 @@ def shutdown_handler():
logger.log(logging.INFO, "Shutting down")


@app.get("/redirect")
async def redirect(request):
return {"status_code": "307", "body": "", "type": "text"}


@app.get("/redirect_route")
async def redirect_route(request):
return "This is the redirected route"


@app.before_request("/redirect")
async def redirect_before_request(request):
request["headers"]["Location"] = "redirect_route"
return ""


@app.after_request("/redirect")
async def redirect_after_request(request):
request["headers"]["Location"] = "redirect_route"
return ""


if __name__ == "__main__":
ROBYN_URL = os.getenv("ROBYN_URL", "0.0.0.0")
app.add_header("server", "robyn")
Expand Down
5 changes: 5 additions & 0 deletions integration_tests/test_status_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ def test_404_status_code(session):
def test_404_post_request_status_code(session):
r = requests.post(f"{BASE_URL}/404")
assert r.status_code == 404

def test_307_get_request(session):
r = requests.get(f"{BASE_URL}/redirect")
assert r.text == "This is the redirected route"

6 changes: 2 additions & 4 deletions robyn/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def _format_response(self, res):
"body": res["body"],
**res
}
print("Setting the response", response)
else:
response = {"status_code": "200", "body": res, "type": "text"}

Expand All @@ -37,7 +36,6 @@ def _format_response(self, res):
def add_route(self, route_type, endpoint, handler):
async def async_inner_handler(*args):
response = self._format_response(await handler(*args))
print(f"This is the response in python: {response}")
return response

def inner_handler(*args):
Expand Down Expand Up @@ -95,7 +93,7 @@ def add_route(self, route_type, endpoint, handler):
def add_after_request(self, endpoint):
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
await handler(*args)
return args

def inner_handler(*args):
Expand All @@ -112,7 +110,7 @@ def inner_handler(*args):
def add_before_request(self, endpoint):
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
await handler(*args)
return args

def inner_handler(*args):
Expand Down
132 changes: 70 additions & 62 deletions src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use anyhow::{bail, Result};
use crate::types::{Headers, PyFunction};
use futures_util::stream::StreamExt;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use pyo3::types::PyDict;

use std::fs::File;
use std::io::Read;
Expand All @@ -17,9 +17,9 @@ use std::io::Read;
const MAX_SIZE: usize = 10_000;

#[inline]
pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc<Headers>) {
for a in headers.iter() {
response.insert_header((a.key().clone(), a.value().clone()));
pub fn apply_headers(response: &mut HttpResponseBuilder, headers: HashMap<String, String>) {
for (key, val) in (headers).iter() {
response.insert_header((key.clone(), val.clone()));
}
}

Expand All @@ -37,7 +37,7 @@ pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc<Headers>)
pub async fn handle_request(
function: PyFunction,
number_of_params: u8,
headers: &Arc<Headers>,
headers: HashMap<String, String>,
payload: &mut web::Payload,
req: &HttpRequest,
route_params: HashMap<String, String>,
Expand All @@ -46,7 +46,7 @@ pub async fn handle_request(
let contents = match execute_http_function(
function,
payload,
headers,
headers.clone(),
req,
route_params,
queries,
Expand All @@ -58,17 +58,30 @@ pub async fn handle_request(
Err(err) => {
println!("Error: {:?}", err);
let mut response = HttpResponse::InternalServerError();
apply_headers(&mut response, headers);
apply_headers(&mut response, headers.clone());
return response.finish();
}
};

let mut response = HttpResponse::Ok();
let body = contents.get("body").unwrap().to_owned();
let status_code =
actix_http::StatusCode::from_str(contents.get("status_code").unwrap()).unwrap();
apply_headers(&mut response, headers);
response.status(status_code);
response.body(contents.get("body").unwrap().to_owned())

let mut response = HttpResponse::build(status_code);
apply_headers(&mut response, headers.clone());
let final_response = if body != "" {
response.body(body)
} else {
response.finish()
};

println!(
"The status code is {} and the headers are {:?}",
final_response.status(),
final_response.headers()
);
// response.body(contents.get("body").unwrap().to_owned())
final_response
}

pub async fn handle_middleware_request(
Expand All @@ -79,7 +92,7 @@ pub async fn handle_middleware_request(
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
) -> Py<PyTuple> {
) -> HashMap<String, HashMap<String, String>> {
let contents = match execute_middleware_function(
function,
payload,
Expand All @@ -92,12 +105,10 @@ pub async fn handle_middleware_request(
.await
{
Ok(res) => res,
Err(err) => Python::with_gil(|py| {
println!("{:?}", err);
PyTuple::empty(py).into_py(py)
}),
Err(_err) => HashMap::new(),
};

println!("These are the middleware response {:?}", contents);
contents
}

Expand All @@ -123,12 +134,12 @@ async fn execute_middleware_function<'a>(
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
number_of_params: u8,
) -> Result<Py<PyTuple>> {
) -> Result<HashMap<String, HashMap<String, String>>> {
// TODO:
// try executing the first version of middleware(s) here
// with just headers as params

let mut data: Option<Vec<u8>> = None;
let mut data: Vec<u8> = Vec::new();

if req.method() == Method::POST
|| req.method() == Method::PUT
Expand All @@ -145,13 +156,13 @@ async fn execute_middleware_function<'a>(
body.extend_from_slice(&chunk);
}

data = Some(body.to_vec())
data = body.to_vec()
}

// request object accessible while creating routes
let mut request = HashMap::new();
let mut headers_python = HashMap::new();
for elem in headers.into_iter() {
for elem in (*headers).iter() {
headers_python.insert(elem.key().clone(), elem.value().clone());
}

Expand All @@ -162,7 +173,7 @@ async fn execute_middleware_function<'a>(
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));
request.insert("body", data.into_py(py));
// request.insert("body", data.into_py(py));

// this makes the request object to be accessible across every route
let coro: PyResult<&PyAny> = match number_of_params {
Expand All @@ -176,10 +187,13 @@ async fn execute_middleware_function<'a>(

let output = output.await?;

let res = Python::with_gil(|py| -> PyResult<Py<PyTuple>> {
let output: Py<PyTuple> = output.extract(py).unwrap();
Ok(output)
})?;
let res =
Python::with_gil(|py| -> PyResult<HashMap<String, HashMap<String, String>>> {
let output: Vec<HashMap<String, HashMap<String, String>>> =
output.extract(py).unwrap();
let responses = output[0].clone();
Ok(responses)
})?;

Ok(res)
}
Expand All @@ -200,9 +214,10 @@ async fn execute_middleware_function<'a>(
2_u8..=u8::MAX => handler.call1((request,)),
};

let output: Py<PyTuple> = output?.extract().unwrap();
let output: Vec<HashMap<String, HashMap<String, String>>> =
output?.extract().unwrap();

Ok(output)
Ok(output[0].clone())
})
})
.await?
Expand All @@ -215,15 +230,15 @@ async fn execute_middleware_function<'a>(
async fn execute_http_function(
function: PyFunction,
payload: &mut web::Payload,
headers: &Headers,
headers: HashMap<String, String>,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
number_of_params: u8,
// need to change this to return a response struct
// create a custom struct for this
) -> Result<HashMap<String, String>> {
let mut data: Option<Vec<u8>> = None;
let mut data: Vec<u8> = Vec::new();

if req.method() == Method::POST
|| req.method() == Method::PUT
Expand All @@ -240,28 +255,21 @@ async fn execute_http_function(
body.extend_from_slice(&chunk);
}

data = Some(body.to_vec())
data = body.to_vec()
}

// request object accessible while creating routes
let mut request = HashMap::new();
let mut headers_python = HashMap::new();
for elem in headers.into_iter() {
headers_python.insert(elem.key().clone(), elem.value().clone());
}

match function {
PyFunction::CoRoutine(handler) => {
let output = Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));

if let Some(res) = data {
let data = res.into_py(py);
request.insert("body", data);
};
request.insert("headers", headers.into_py(py));
let data = data.into_py(py);
request.insert("body", data);

// this makes the request object to be accessible across every route
let coro: PyResult<&PyAny> = match number_of_params {
Expand Down Expand Up @@ -298,11 +306,9 @@ async fn execute_http_function(
Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("headers", headers_python.into_py(py));
if let Some(res) = data {
let data = res.into_py(py);
request.insert("body", data);
};
request.insert("headers", headers.into_py(py));
let data = data.into_py(py);
request.insert("body", data);

let output: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
Expand All @@ -325,22 +331,24 @@ pub async fn execute_event_handler(
event_handler: Option<Arc<PyFunction>>,
event_loop: Arc<Py<PyAny>>,
) {
if let Some(handler) = event_handler { match &(*handler) {
PyFunction::SyncFunction(function) => {
println!("Startup event handler");
Python::with_gil(|py| {
function.call0(py).unwrap();
});
}
PyFunction::CoRoutine(function) => {
let future = Python::with_gil(|py| {
println!("Startup event handler async");

let coroutine = function.as_ref(py).call0().unwrap();
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
.unwrap()
});
future.await.unwrap();
if let Some(handler) = event_handler {
match &(*handler) {
PyFunction::SyncFunction(function) => {
println!("Startup event handler");
Python::with_gil(|py| {
function.call0(py).unwrap();
});
}
PyFunction::CoRoutine(function) => {
let future = Python::with_gil(|py| {
println!("Startup event handler async");

let coroutine = function.as_ref(py).call0().unwrap();
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
.unwrap()
});
future.await.unwrap();
}
}
} }
}
}
Loading