Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Event handlers #154

Merged
merged 6 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,25 @@
websocket = WS(app, "/web_socket")
i = -1


@websocket.on("message")
async def connect():
global i
i+=1
if i==0:
i += 1
if i == 0:
return "Whaaat??"
elif i==1:
elif i == 1:
return "Whooo??"
elif i==2:
elif i == 2:
i = -1
return "*chika* *chika* Slim Shady."


@websocket.on("close")
def close():
return "GoodBye world, from ws"


@websocket.on("connect")
def message():
return "Hello world, from ws"
Expand All @@ -35,7 +38,7 @@ def message():
async def hello(request):
global callCount
callCount += 1
message = "Called " + str(callCount) + " times"
_message = "Called " + str(callCount) + " times"
return jsonify(request)


Expand All @@ -47,10 +50,12 @@ async def test(request):

return static_file(html_file)


@app.get("/jsonify")
async def json_get():
return jsonify({"hello": "world"})


@app.get("/query")
async def query_get(request):
query_data = request["queries"]
Expand All @@ -62,18 +67,22 @@ async def json(request):
print(request["params"]["id"])
return jsonify({"hello": "world"})


@app.post("/post")
async def post():
return "POST Request"


@app.post("/post_with_body")
async def postreq_with_body(request):
return bytearray(request["body"]).decode("utf-8")


@app.put("/put")
async def put(request):
return "PUT Request"


@app.put("/put_with_body")
async def putreq_with_body(request):
print(request)
Expand All @@ -84,6 +93,7 @@ async def putreq_with_body(request):
async def delete():
return "DELETE Request"


@app.delete("/delete_with_body")
async def deletereq_with_body(request):
return bytearray(request["body"]).decode("utf-8")
Expand All @@ -93,6 +103,7 @@ async def deletereq_with_body(request):
async def patch():
return "PATCH Request"


@app.patch("/patch_with_body")
async def patchreq_with_body(request):
return bytearray(request["body"]).decode("utf-8")
Expand All @@ -107,14 +118,29 @@ async def sleeper():
@app.get("/blocker")
def blocker():
import time

time.sleep(10)
return "blocker function"


async def startup_handler():
print("Starting up")


@app.shutdown_handler
def shutdown_handler():
print("Shutting down")


if __name__ == "__main__":
ROBYN_URL = os.getenv("ROBYN_URL", '0.0.0.0')
ROBYN_URL = os.getenv("ROBYN_URL", "0.0.0.0")
app.add_header("server", "robyn")
current_file_path = pathlib.Path(__file__).parent.resolve()
os.path.join(current_file_path, "build")
app.add_directory(route="/test_dir",directory_path=os.path.join(current_file_path, "build/"), index_file="index.html")
app.add_directory(
route="/test_dir",
directory_path=os.path.join(current_file_path, "build/"),
index_file="index.html",
)
app.startup_handler(startup_handler)
app.start(port=5000, url=ROBYN_URL)
77 changes: 55 additions & 22 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@


class Robyn:
"""This is the python wrapper for the Robyn binaries.
"""
"""This is the python wrapper for the Robyn binaries."""

def __init__(self, file_object):
directory_path = os.path.dirname(os.path.abspath(file_object))
self.file_path = file_object
self.directory_path = directory_path
self.server = Server(directory_path)
self.parser = ArgumentParser()
self.dev = self.parser.is_dev()
self.processes = self.parser.num_processes()
Expand All @@ -37,6 +36,7 @@ def __init__(self, file_object):
self.routes = []
self.directories = []
self.web_sockets = {}
self.event_handlers = {}

def add_route(self, route_type, endpoint, handler):
"""
Expand All @@ -51,25 +51,44 @@ def add_route(self, route_type, endpoint, handler):
"""
number_of_params = len(signature(handler).parameters)
self.routes.append(
(route_type,
endpoint,
handler,
asyncio.iscoroutinefunction(handler), number_of_params)
(
route_type,
endpoint,
handler,
asyncio.iscoroutinefunction(handler),
number_of_params,
)
)

def add_directory(self, route, directory_path, index_file=None, show_files_listing=False):
def add_directory(
self, route, directory_path, index_file=None, show_files_listing=False
):
self.directories.append((route, directory_path, index_file, show_files_listing))

def add_header(self, key, value):
self.headers.append((key, value))

def remove_header(self, key):
self.server.remove_header(key)

def add_web_socket(self, endpoint, ws):
self.web_sockets[endpoint] = ws

def start(self, url="127.0.0.1", port=5000):
def _add_event_handler(self, event_type: str, handler):
print(f"Add event {event_type} handler")
if event_type.lower() not in {"startup", "shutdown"}:
return

is_async = asyncio.iscoroutinefunction(handler)
if event_type.lower() == "startup":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor, but for clarity, I would consider to:

  • Extract the "startup"/"shutdown" string either in an Enum (like Class(str, Enum) or as constants
  • DIrectly use "startup"/"shutdown" as keys for the dictionary event_handlers

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. Updated :D

self.event_handlers["startup_handler"] = (handler, is_async)
else:
self.event_handlers["shutdown_handler"] = (handler, is_async)

def startup_handler(self, handler):
self._add_event_handler("startup", handler)

def shutdown_handler(self, handler):
self._add_event_handler("shutdown", handler)

def start(self, url="128.0.0.1", port=5000):
"""
[Starts the server]

Expand All @@ -78,25 +97,31 @@ def start(self, url="127.0.0.1", port=5000):
if not self.dev:
workers = self.workers
socket = SocketHeld(url, port)
for process_number in range(self.processes):
copied = socket.try_clone()
for _ in range(self.processes):
copied_socket = socket.try_clone()
p = Process(
target=spawn_process,
args=(url, port, self.directories, self.headers,
self.routes, self.web_sockets, copied,
f"Process {process_number}", workers),
args=(
self.directories,
self.headers,
self.routes,
self.web_sockets,
self.event_handlers,
copied_socket,
workers,
),
)
p.start()

print("Press Ctrl + C to stop \n")
else:
event_handler = EventHandler(self.file_path)
event_handler.start_server_first_time()
print(f"{Colors.OKBLUE}Dev server initialised with the directory_path : {self.directory_path}{Colors.ENDC}")
print(
f"{Colors.OKBLUE}Dev server initialised with the directory_path : {self.directory_path}{Colors.ENDC}"
)
observer = Observer()
observer.schedule(event_handler,
path=self.directory_path,
recursive=True)
observer.schedule(event_handler, path=self.directory_path, recursive=True)
observer.start()
try:
while True:
Expand All @@ -111,6 +136,7 @@ def get(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("GET", endpoint, handler)

Expand All @@ -122,6 +148,7 @@ def post(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("POST", endpoint, handler)

Expand All @@ -133,6 +160,7 @@ def put(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("PUT", endpoint, handler)

Expand All @@ -144,6 +172,7 @@ def delete(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("DELETE", endpoint, handler)

Expand All @@ -155,6 +184,7 @@ def patch(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("PATCH", endpoint, handler)

Expand All @@ -166,6 +196,7 @@ def head(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("HEAD", endpoint, handler)

Expand All @@ -177,6 +208,7 @@ def options(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("OPTIONS", endpoint, handler)

Expand All @@ -188,6 +220,7 @@ def connect(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("CONNECT", endpoint, handler)

Expand All @@ -199,8 +232,8 @@ def trace(self, endpoint):

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
self.add_route("TRACE", endpoint, handler)

return inner

23 changes: 18 additions & 5 deletions robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import sys
import multiprocessing as mp
import asyncio

# import platform


mp.allow_connection_pickling()


def spawn_process(url, port, directories, headers, routes, web_sockets, socket, process_name, workers):
def spawn_process(
directories, headers, routes, web_sockets, event_handlers, socket, workers
):
"""
This function is called by the main process handler to create a server runtime.
This functions allows one runtime per process.
Expand All @@ -31,14 +34,13 @@ def spawn_process(url, port, directories, headers, routes, web_sockets, socket,
# uv loop doesn't support windows or arm machines at the moment
# but uv loop is much faster than native asyncio
import uvloop

uvloop.install()
loop = uvloop.new_event_loop()
asyncio.set_event_loop(loop)

server = Server()

print(directories)

for directory in directories:
route, directory_path, index_file, show_files_listing = directory
server.add_directory(route, directory_path, index_file, show_files_listing)
Expand All @@ -50,10 +52,21 @@ def spawn_process(url, port, directories, headers, routes, web_sockets, socket,
route_type, endpoint, handler, is_async, number_of_params = route
server.add_route(route_type, endpoint, handler, is_async, number_of_params)

if "startup_handler" in event_handlers:
server.add_startup_handler(event_handlers["startup_handler"][0], event_handlers["startup_handler"][1])

if "shutdown_handler" in event_handlers:
server.add_shutdown_handler(event_handlers["shutdown_handler"][0], event_handlers["shutdown_handler"][1])

for endpoint in web_sockets:
web_socket = web_sockets[endpoint]
print(web_socket.methods)
server.add_web_socket_route(endpoint, web_socket.methods["connect"], web_socket.methods["close"], web_socket.methods["message"])
server.add_web_socket_route(
endpoint,
web_socket.methods["connect"],
web_socket.methods["close"],
web_socket.methods["message"],
)

server.start(url, port, socket, process_name, workers)
server.start(socket, workers)
asyncio.get_event_loop().run_forever()
Loading