Skip to content

Commit

Permalink
fix: fix request headers not being propagated (#232)
Browse files Browse the repository at this point in the history
* fix: fix request headers not being propagated

* Update the headers

* fix circle ci

* chore: fix clippy
  • Loading branch information
sansyrox authored Jul 19, 2022
1 parent 500e261 commit 870f3f6
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 38 deletions.
6 changes: 3 additions & 3 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ orbs:
# Orb commands and jobs help you with common scripting around a language/tool
# so you dont have to copy and paste it everywhere.
# See the orb documentation here: https://circleci.com/developer/orbs/orb/circleci/python
python: circleci/python@2.0.3
rust: circleci/rust@1.6.0
python: circleci/python@1.2
rust: circleci/rust@1.5.0

commands:
test:
Expand Down Expand Up @@ -152,6 +152,6 @@ jobs:
command: |
curl https://sh.rustup.rs -sSf | sh -s -- -y
source $HOME/.cargo/env
maturin build -i python --release --universal2 --out dist --no-sdist
maturin build -i python --release --universal2 --out dist
pip install --force-reinstall dist/robyn*.whl
pytest ~/project/integration_tests
12 changes: 11 additions & 1 deletion integration_tests/base_routes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from robyn import Robyn, static_file, jsonify, WS

from robyn.log_colors import Colors
import asyncio
import os
import pathlib
import logging

app = Robyn(__file__)
websocket = WS(app, "/web_socket")
i = -1

logger = logging.getLogger(__name__)


@websocket.on("message")
async def connect(websocket_id):
Expand Down Expand Up @@ -102,7 +107,6 @@ async def query_get(request):

@app.post("/jsonify/:id")
async def json(request):
print(request["params"]["id"])
return jsonify({"hello": "world"})


Expand All @@ -127,6 +131,12 @@ async def putreq_with_body(request):
return bytearray(request["body"]).decode("utf-8")


@app.post("/headers")
async def postreq_with_headers(request):
logger.info(f"{Colors.OKGREEN} {request['headers']} \n{Colors.ENDC}")
return jsonify(request["headers"])


@app.delete("/delete")
async def delete():
return "DELETE Request"
Expand Down
8 changes: 7 additions & 1 deletion integration_tests/test_post_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@ def test_post_with_param(session):
def test_jsonify_request(session):
res = requests.post(f"{BASE_URL}/jsonify/123")
assert(res.status_code == 200)
assert res.json()=={"hello":"world"}
assert res.json() == {"hello": "world"}


def test_post_request_headers(session):
res = requests.post(f"{BASE_URL}/headers", headers={"hello": "world"})
assert(res.status_code == 200)
assert res.json()["hello"] == "world"

18 changes: 8 additions & 10 deletions src/executors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use actix_web::{http::Method, web, HttpRequest};
use anyhow::{bail, Result};
use log::debug;
// pyO3 module
use crate::types::{Headers, PyFunction};
use crate::types::PyFunction;
use futures_util::stream::StreamExt;
use pyo3::prelude::*;
use pyo3::types::PyDict;
Expand All @@ -22,15 +22,14 @@ const MAX_SIZE: usize = 10_000;
pub async fn execute_middleware_function<'a>(
function: PyFunction,
payload: &mut web::Payload,
headers: &Headers,
headers: &HashMap<String, String>,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: Rc<RefCell<HashMap<String, String>>>,
number_of_params: u8,
) -> Result<HashMap<String, HashMap<String, String>>> {
// TODO:
// try executing the first version of middleware(s) here
// with just headers as params
// add body in middlewares too

let mut data: Vec<u8> = Vec::new();

Expand All @@ -54,10 +53,7 @@ pub async fn execute_middleware_function<'a>(

// request object accessible while creating routes
let mut request = HashMap::new();
let mut headers_python = HashMap::new();
for elem in (*headers).iter() {
headers_python.insert(elem.key().clone(), elem.value().clone());
}

let mut queries_clone: HashMap<String, String> = HashMap::new();

for (key, value) in (*queries).borrow().clone() {
Expand All @@ -70,7 +66,8 @@ pub async fn execute_middleware_function<'a>(
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries_clone.into_py(py));
request.insert("headers", headers_python.into_py(py));
// is this a bottleneck again?
request.insert("headers", headers.clone().into_py(py));
// request.insert("body", data.into_py(py));

// this makes the request object to be accessible across every route
Expand Down Expand Up @@ -104,7 +101,8 @@ pub async fn execute_middleware_function<'a>(
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries_clone.into_py(py));
request.insert("headers", headers_python.into_py(py));
// is this a bottleneck again?
request.insert("headers", headers.clone().into_py(py));
request.insert("body", data.into_py(py));

let output: PyResult<&PyAny> = match number_of_params {
Expand Down
17 changes: 9 additions & 8 deletions src/request_handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use crate::executors::{execute_http_function, execute_middleware_function};
use log::debug;
use std::rc::Rc;
use std::str::FromStr;
use std::sync::Arc;
use std::{cell::RefCell, collections::HashMap};

use actix_web::{web, HttpRequest, HttpResponse, HttpResponseBuilder};
// pyO3 module
use crate::types::{Headers, PyFunction};
use crate::types::PyFunction;

#[inline]
pub fn apply_headers(response: &mut HttpResponseBuilder, headers: HashMap<String, String>) {
Expand Down Expand Up @@ -61,37 +60,39 @@ pub async fn handle_http_request(
let status_code =
actix_http::StatusCode::from_str(contents.get("status_code").unwrap()).unwrap();

let headers: HashMap<String, String> = match contents.get("headers") {
let response_headers: HashMap<String, String> = match contents.get("headers") {
Some(headers) => {
let h: HashMap<String, String> = serde_json::from_str(headers).unwrap();
h
}
None => HashMap::new(),
};

debug!("These are the headers from serde {:?}", headers);
debug!(
"These are the request headers from serde {:?}",
response_headers
);

let mut response = HttpResponse::build(status_code);
apply_headers(&mut response, headers);
apply_headers(&mut response, response_headers);
let final_response = if !body.is_empty() {
response.body(body)
} else {
response.finish()
};

debug!(
"The status code is {} and the headers are {:?}",
"The response status code is {} and the headers are {:?}",
final_response.status(),
final_response.headers()
);
// response.body(contents.get("body").unwrap().to_owned())
final_response
}

pub async fn handle_http_middleware_request(
function: PyFunction,
number_of_params: u8,
headers: &Arc<Headers>,
headers: &HashMap<String, String>,
payload: &mut web::Payload,
req: &HttpRequest,
route_params: HashMap<String, String>,
Expand Down
58 changes: 43 additions & 15 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::sync::{Arc, RwLock};
use std::thread;

use actix_files::Files;
use actix_http::header::HeaderMap;
use actix_http::KeepAlive;
use actix_web::*;
use dashmap::DashMap;
Expand All @@ -44,7 +45,7 @@ pub struct Server {
const_router: Arc<ConstRouter>,
websocket_router: Arc<WebSocketRouter>,
middleware_router: Arc<MiddlewareRouter>,
headers: Arc<DashMap<String, String>>,
global_headers: Arc<DashMap<String, String>>,
directories: Arc<RwLock<Vec<Directory>>>,
startup_handler: Option<Arc<PyFunction>>,
shutdown_handler: Option<Arc<PyFunction>>,
Expand All @@ -59,7 +60,7 @@ impl Server {
const_router: Arc::new(ConstRouter::new()),
websocket_router: Arc::new(WebSocketRouter::new()),
middleware_router: Arc::new(MiddlewareRouter::new()),
headers: Arc::new(DashMap::new()),
global_headers: Arc::new(DashMap::new()),
directories: Arc::new(RwLock::new(Vec::new())),
startup_handler: None,
shutdown_handler: None,
Expand Down Expand Up @@ -91,7 +92,7 @@ impl Server {
let const_router = self.const_router.clone();
let middleware_router = self.middleware_router.clone();
let web_socket_router = self.websocket_router.clone();
let headers = self.headers.clone();
let global_headers = self.global_headers.clone();
let directories = self.directories.clone();
let workers = Arc::new(workers);

Expand Down Expand Up @@ -147,7 +148,7 @@ impl Server {
.app_data(web::Data::new(router.clone()))
.app_data(web::Data::new(const_router.clone()))
.app_data(web::Data::new(middleware_router.clone()))
.app_data(web::Data::new(headers.clone()));
.app_data(web::Data::new(global_headers.clone()));

let web_socket_map = web_socket_router.get_web_socket_map();
for (elem, value) in (web_socket_map.read().unwrap()).iter() {
Expand All @@ -158,7 +159,7 @@ impl Server {
&route.clone(),
web::get().to(
move |_router: web::Data<Arc<Router>>,
_headers: web::Data<Arc<Headers>>,
_global_headers: web::Data<Arc<Headers>>,
stream: web::Payload,
req: HttpRequest| {
start_web_socket(
Expand All @@ -176,7 +177,7 @@ impl Server {
move |router,
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
headers,
global_headers,
payload,
req| {
pyo3_asyncio::tokio::scope_local(
Expand All @@ -186,7 +187,7 @@ impl Server {
router,
const_router,
middleware_router,
headers,
global_headers,
payload,
req,
)
Expand Down Expand Up @@ -242,13 +243,14 @@ impl Server {
/// Adds a new header to our concurrent hashmap
/// this can be called after the server has started.
pub fn add_header(&self, key: &str, value: &str) {
self.headers.insert(key.to_string(), value.to_string());
self.global_headers
.insert(key.to_string(), value.to_string());
}

/// Removes a new header to our concurrent hashmap
/// this can be called after the server has started.
pub fn remove_header(&self, key: &str) {
self.headers.remove(key);
self.global_headers.remove(key);
}

/// Add a new route to the routing tables
Expand Down Expand Up @@ -345,13 +347,34 @@ impl Default for Server {
}
}

async fn merge_headers(
global_headers: &Arc<Headers>,
request_headers: &HeaderMap,
) -> HashMap<String, String> {
let mut headers = HashMap::new();

for elem in (global_headers).iter() {
headers.insert(elem.key().clone(), elem.value().clone());
}

for (key, value) in (request_headers).iter() {
headers.insert(
key.to_string().clone(),
// test if this crashes or not
value.to_str().unwrap().to_string().clone(),
);
}

headers
}

/// This is our service handler. It receives a Request, routes on it
/// path, and returns a Future of a Response.
async fn index(
router: web::Data<Arc<Router>>,
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
headers: web::Data<Arc<Headers>>,
global_headers: web::Data<Arc<Headers>>,
mut payload: web::Payload,
req: HttpRequest,
) -> impl Responder {
Expand All @@ -367,6 +390,9 @@ async fn index(
}
}

let headers = merge_headers(&global_headers, req.headers()).await;

// need a better name for this
let tuple_params = match middleware_router.get_route("BEFORE_REQUEST", req.uri().path()) {
Some(((handler_function, number_of_params), route_params)) => {
let x = handle_http_middleware_request(
Expand All @@ -387,11 +413,13 @@ async fn index(

debug!("These are the tuple params {:?}", tuple_params);

let mut headers_dup = HashMap::new();
let headers_dup = if !tuple_params.is_empty() {
tuple_params.get("headers").unwrap().clone()
} else {
headers
};

if !tuple_params.is_empty() {
headers_dup = tuple_params.get("headers").unwrap().clone();
}
debug!("These are the request headers {:?}", headers_dup);

let response = if const_router
.get_route(req.method().clone(), req.uri().path())
Expand Down Expand Up @@ -431,7 +459,7 @@ async fn index(
let x = handle_http_middleware_request(
handler_function,
number_of_params,
&headers,
&headers_dup,
&mut payload,
&req,
route_params,
Expand Down

0 comments on commit 870f3f6

Please sign in to comment.