diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 873f18c7d..e06f6ca15 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -2,27 +2,34 @@ import asyncio import os import pathlib +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) app = Robyn(__file__) websocket = WS(app, "/web_socket") i = -1 + @websocket.on("message") async def connect(): global i - i+=1 - if i==0: + i += 1 + if i == 0: return "Whaaat??" - elif i==1: + elif i == 1: return "Whooo??" - elif i==2: + elif i == 2: i = -1 return "*chika* *chika* Slim Shady." + @websocket.on("close") def close(): return "GoodBye world, from ws" + @websocket.on("connect") def message(): return "Hello world, from ws" @@ -35,7 +42,7 @@ def message(): async def hello(request): global callCount callCount += 1 - message = "Called " + str(callCount) + " times" + _message = "Called " + str(callCount) + " times" return jsonify(request) @@ -47,10 +54,12 @@ async def test(request): return static_file(html_file) + @app.get("/jsonify") async def json_get(): return jsonify({"hello": "world"}) + @app.get("/query") async def query_get(request): query_data = request["queries"] @@ -62,18 +71,22 @@ async def json(request): print(request["params"]["id"]) return jsonify({"hello": "world"}) + @app.post("/post") async def post(): return "POST Request" + @app.post("/post_with_body") async def postreq_with_body(request): return bytearray(request["body"]).decode("utf-8") + @app.put("/put") async def put(request): return "PUT Request" + @app.put("/put_with_body") async def putreq_with_body(request): print(request) @@ -84,6 +97,7 @@ async def putreq_with_body(request): async def delete(): return "DELETE Request" + @app.delete("/delete_with_body") async def deletereq_with_body(request): return bytearray(request["body"]).decode("utf-8") @@ -93,6 +107,7 @@ async def deletereq_with_body(request): async def patch(): return "PATCH Request" + @app.patch("/patch_with_body") async def patchreq_with_body(request): return bytearray(request["body"]).decode("utf-8") @@ -107,14 +122,29 @@ async def sleeper(): @app.get("/blocker") def blocker(): import time + time.sleep(10) return "blocker function" +async def startup_handler(): + logger.log(logging.INFO, "Starting up") + + +@app.shutdown_handler +def shutdown_handler(): + logger.log(logging.INFO, "Shutting down") + + if __name__ == "__main__": - ROBYN_URL = os.getenv("ROBYN_URL", '0.0.0.0') + ROBYN_URL = os.getenv("ROBYN_URL", "0.0.0.0") app.add_header("server", "robyn") current_file_path = pathlib.Path(__file__).parent.resolve() os.path.join(current_file_path, "build") - app.add_directory(route="/test_dir",directory_path=os.path.join(current_file_path, "build/"), index_file="index.html") + app.add_directory( + route="/test_dir", + directory_path=os.path.join(current_file_path, "build/"), + index_file="index.html", + ) + app.startup_handler(startup_handler) app.start(port=5000, url=ROBYN_URL) diff --git a/integration_tests/test_get_requests.py b/integration_tests/test_get_requests.py index 629d25d45..016704532 100644 --- a/integration_tests/test_get_requests.py +++ b/integration_tests/test_get_requests.py @@ -2,22 +2,26 @@ BASE_URL = "http://127.0.0.1:5000" + def test_index_request(session): res = requests.get(f"{BASE_URL}") - assert(res.status_code == 200) + assert res.status_code == 200 + def test_jsonify(session): r = requests.get(f"{BASE_URL}/jsonify") - assert r.json()=={"hello":"world"} - assert r.status_code==200 + assert r.json() == {"hello": "world"} + assert r.status_code == 200 + def test_html(session): r = requests.get(f"{BASE_URL}/test/123") assert "Hello world. How are you?" in r.text + def test_queries(session): r = requests.get(f"{BASE_URL}/query?hello=robyn") - assert r.json()=={"hello":"robyn"} + assert r.json() == {"hello": "robyn"} r = requests.get(f"{BASE_URL}/query") - assert r.json()=={} + assert r.json() == {} diff --git a/robyn/__init__.py b/robyn/__init__.py index 3cf7bad98..450633f8d 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -3,6 +3,7 @@ import asyncio from inspect import signature import multiprocessing as mp +from robyn.events import Events # custom imports and exports from .robyn import Server, SocketHeld @@ -21,13 +22,12 @@ class Robyn: - """This is the python wrapper for the Robyn binaries. - """ + """This is the python wrapper for the Robyn binaries.""" + def __init__(self, file_object): directory_path = os.path.dirname(os.path.abspath(file_object)) self.file_path = file_object self.directory_path = directory_path - self.server = Server(directory_path) self.parser = ArgumentParser() self.dev = self.parser.is_dev() self.processes = self.parser.num_processes() @@ -37,6 +37,7 @@ def __init__(self, file_object): self.routes = [] self.directories = [] self.web_sockets = {} + self.event_handlers = {} def add_route(self, route_type, endpoint, handler): """ @@ -51,25 +52,41 @@ def add_route(self, route_type, endpoint, handler): """ number_of_params = len(signature(handler).parameters) self.routes.append( - (route_type, - endpoint, - handler, - asyncio.iscoroutinefunction(handler), number_of_params) + ( + route_type, + endpoint, + handler, + asyncio.iscoroutinefunction(handler), + number_of_params, + ) ) - def add_directory(self, route, directory_path, index_file=None, show_files_listing=False): + def add_directory( + self, route, directory_path, index_file=None, show_files_listing=False + ): self.directories.append((route, directory_path, index_file, show_files_listing)) def add_header(self, key, value): self.headers.append((key, value)) - def remove_header(self, key): - self.server.remove_header(key) - def add_web_socket(self, endpoint, ws): self.web_sockets[endpoint] = ws - def start(self, url="127.0.0.1", port=5000): + def _add_event_handler(self, event_type: str, handler): + print(f"Add event {event_type} handler") + if event_type not in {Events.STARTUP, Events.SHUTDOWN}: + return + + is_async = asyncio.iscoroutinefunction(handler) + self.event_handlers[event_type] = (handler, is_async) + + def startup_handler(self, handler): + self._add_event_handler(Events.STARTUP, handler) + + def shutdown_handler(self, handler): + self._add_event_handler(Events.SHUTDOWN, handler) + + def start(self, url="128.0.0.1", port=5000): """ [Starts the server] @@ -78,13 +95,19 @@ def start(self, url="127.0.0.1", port=5000): if not self.dev: workers = self.workers socket = SocketHeld(url, port) - for process_number in range(self.processes): - copied = socket.try_clone() + for _ in range(self.processes): + copied_socket = socket.try_clone() p = Process( target=spawn_process, - args=(url, port, self.directories, self.headers, - self.routes, self.web_sockets, copied, - f"Process {process_number}", workers), + args=( + self.directories, + self.headers, + self.routes, + self.web_sockets, + self.event_handlers, + copied_socket, + workers, + ), ) p.start() @@ -92,11 +115,11 @@ def start(self, url="127.0.0.1", port=5000): else: event_handler = EventHandler(self.file_path) event_handler.start_server_first_time() - print(f"{Colors.OKBLUE}Dev server initialised with the directory_path : {self.directory_path}{Colors.ENDC}") + print( + f"{Colors.OKBLUE}Dev server initialised with the directory_path : {self.directory_path}{Colors.ENDC}" + ) observer = Observer() - observer.schedule(event_handler, - path=self.directory_path, - recursive=True) + observer.schedule(event_handler, path=self.directory_path, recursive=True) observer.start() try: while True: @@ -111,6 +134,7 @@ def get(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("GET", endpoint, handler) @@ -122,6 +146,7 @@ def post(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("POST", endpoint, handler) @@ -133,6 +158,7 @@ def put(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("PUT", endpoint, handler) @@ -144,6 +170,7 @@ def delete(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("DELETE", endpoint, handler) @@ -155,6 +182,7 @@ def patch(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("PATCH", endpoint, handler) @@ -166,6 +194,7 @@ def head(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("HEAD", endpoint, handler) @@ -177,6 +206,7 @@ def options(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("OPTIONS", endpoint, handler) @@ -188,6 +218,7 @@ def connect(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("CONNECT", endpoint, handler) @@ -199,8 +230,8 @@ def trace(self, endpoint): :param endpoint [str]: [endpoint to server the route] """ + def inner(handler): self.add_route("TRACE", endpoint, handler) return inner - diff --git a/robyn/events.py b/robyn/events.py new file mode 100644 index 000000000..522bcb702 --- /dev/null +++ b/robyn/events.py @@ -0,0 +1,3 @@ +class Events: + STARTUP = "startup" + SHUTDOWN = "shutdown" diff --git a/robyn/processpool.py b/robyn/processpool.py index bb3cf2211..fec6d82eb 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -1,24 +1,28 @@ from .robyn import Server +from .events import Events import sys import multiprocessing as mp import asyncio + # import platform mp.allow_connection_pickling() -def spawn_process(url, port, directories, headers, routes, web_sockets, socket, process_name, workers): +def spawn_process( + directories, headers, routes, web_sockets, event_handlers, socket, workers +): """ This function is called by the main process handler to create a server runtime. This functions allows one runtime per process. - :param url string: the base url at which the server will listen - :param port string: the port at which the url will listen to :param directories tuple: the list of all the directories and related data in a tuple :param headers tuple: All the global headers in a tuple :param routes tuple: The routes touple, containing the description about every route. + :param web_sockets list: This is a list of all the web socket routes + :param event_handlers Dict: This is an event dict that contains the event handlers :param socket Socket: This is the main tcp socket, which is being shared across multiple processes. :param process_name string: This is the name given to the process to identify the process :param workers number: This is the name given to the process to identify the process @@ -31,14 +35,13 @@ def spawn_process(url, port, directories, headers, routes, web_sockets, socket, # uv loop doesn't support windows or arm machines at the moment # but uv loop is much faster than native asyncio import uvloop + uvloop.install() loop = uvloop.new_event_loop() asyncio.set_event_loop(loop) server = Server() - print(directories) - for directory in directories: route, directory_path, index_file, show_files_listing = directory server.add_directory(route, directory_path, index_file, show_files_listing) @@ -50,10 +53,21 @@ def spawn_process(url, port, directories, headers, routes, web_sockets, socket, route_type, endpoint, handler, is_async, number_of_params = route server.add_route(route_type, endpoint, handler, is_async, number_of_params) + if "startup" in event_handlers: + server.add_startup_handler(event_handlers[Events.STARTUP][0], event_handlers[Events.STARTUP][1]) + + if "shutdown" in event_handlers: + server.add_shutdown_handler(event_handlers[Events.SHUTDOWN][0], event_handlers[Events.SHUTDOWN][1]) + for endpoint in web_sockets: web_socket = web_sockets[endpoint] print(web_socket.methods) - server.add_web_socket_route(endpoint, web_socket.methods["connect"], web_socket.methods["close"], web_socket.methods["message"]) + server.add_web_socket_route( + endpoint, + web_socket.methods["connect"], + web_socket.methods["close"], + web_socket.methods["message"], + ) - server.start(url, port, socket, process_name, workers) + server.start(socket, workers) asyncio.get_event_loop().run_forever() diff --git a/src/processor.rs b/src/processor.rs index 640416999..52f47e54b 100644 --- a/src/processor.rs +++ b/src/processor.rs @@ -210,3 +210,26 @@ async fn execute_http_function( } } } + +pub async fn execute_event_handler(event_handler: Option, event_loop: Py) { + match event_handler { + Some(handler) => match handler { + PyFunction::SyncFunction(function) => { + println!("Startup event handler"); + Python::with_gil(|py| { + function.call0(py).unwrap(); + }); + } + PyFunction::CoRoutine(function) => { + let future = Python::with_gil(|py| { + println!("Startup event handler async"); + + let coroutine = function.as_ref(py).call0().unwrap(); + pyo3_asyncio::into_future_with_loop(event_loop.as_ref(py), coroutine).unwrap() + }); + future.await.unwrap(); + } + }, + None => {} + } +} diff --git a/src/server.rs b/src/server.rs index e252d343c..d538c79bd 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,15 +1,15 @@ -use crate::processor::{apply_headers, handle_request}; +use crate::processor::{apply_headers, execute_event_handler, handle_request}; use crate::router::Router; use crate::shared_socket::SocketHeld; -use crate::types::Headers; +use crate::types::{Headers, PyFunction}; use crate::web_socket_connection::start_web_socket; +use std::collections::HashMap; use std::convert::TryInto; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; use std::sync::{Arc, RwLock}; use std::thread; -use std::collections::HashMap; use actix_files::Files; use actix_http::KeepAlive; @@ -34,6 +34,8 @@ pub struct Server { router: Arc, headers: Arc>, directories: Arc>>, + startup_handler: Option, + shutdown_handler: Option, } #[pymethods] @@ -44,16 +46,15 @@ impl Server { router: Arc::new(Router::new()), headers: Arc::new(DashMap::new()), directories: Arc::new(RwLock::new(Vec::new())), + startup_handler: None, + shutdown_handler: None, } } pub fn start( &mut self, py: Python, - _url: String, - _port: u16, socket: &PyCell, - _name: String, workers: usize, ) -> PyResult<()> { if STARTED @@ -81,15 +82,20 @@ impl Server { .call_method1("set_event_loop", (event_loop,)) .unwrap(); let event_loop_hdl = PyObject::from(event_loop); + let event_loop_cleanup = PyObject::from(event_loop); + let startup_handler = self.startup_handler.clone(); + let shutdown_handler = self.shutdown_handler.clone(); thread::spawn(move || { //init_current_thread_once(); + let copied_event_loop = event_loop_hdl.clone(); actix_web::rt::System::new().block_on(async move { println!("The number of workers are {}", workers.clone()); + execute_event_handler(startup_handler, copied_event_loop.clone()).await; HttpServer::new(move || { let mut app = App::new(); - let event_loop_hdl = event_loop_hdl.clone(); + let event_loop_hdl = copied_event_loop.clone(); let directories = directories.read().unwrap(); let router_copy = router.clone(); @@ -160,7 +166,18 @@ impl Server { }); }); - event_loop.call_method0("run_forever").unwrap(); + let event_loop = event_loop.call_method0("run_forever"); + if event_loop.is_err() { + println!("Ctrl c handler"); + Python::with_gil(|py| { + let event_loop_hdl = event_loop_cleanup.clone(); + pyo3_asyncio::tokio::run(py, async move { + execute_event_handler(shutdown_handler, event_loop_hdl.clone()).await; + Ok(()) + }) + .unwrap(); + }) + } Ok(()) } @@ -219,6 +236,27 @@ impl Server { self.router .add_websocket_route(route, connect_route, close_route, message_route); } + + /// Add a new startup handler + pub fn add_startup_handler(&mut self, handler: Py, is_async: bool) { + println!("Adding startup handler"); + match is_async { + true => self.startup_handler = Some(PyFunction::CoRoutine(handler)), + false => self.startup_handler = Some(PyFunction::SyncFunction(handler)), + }; + println!("{:?}", self.startup_handler); + } + + /// Add a new shutdown handler + pub fn add_shutdown_handler(&mut self, handler: Py, is_async: bool) { + println!("Adding shutdown handler"); + match is_async { + true => self.shutdown_handler = Some(PyFunction::CoRoutine(handler)), + false => self.shutdown_handler = Some(PyFunction::SyncFunction(handler)), + }; + println!("{:?}", self.startup_handler); + println!("{:?}", self.shutdown_handler); + } } impl Default for Server { @@ -236,7 +274,7 @@ async fn index( req: HttpRequest, ) -> impl Responder { let mut queries = HashMap::new(); - + if req.query_string().len() > 0 { let split = req.query_string().split("&"); for s in split {