Skip to content

Commit

Permalink
Added typing to examples (#5256)
Browse files Browse the repository at this point in the history
  • Loading branch information
WisdomPill authored Nov 18, 2020
1 parent 5357858 commit 6f73339
Show file tree
Hide file tree
Showing 16 changed files with 82 additions and 72 deletions.
17 changes: 9 additions & 8 deletions examples/background_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
"""Example of aiohttp.web.Application.on_startup signal handler"""
import asyncio

import aioredis
import aioredis # type: ignore

from aiohttp import web


async def websocket_handler(request):
async def websocket_handler(request: web.Request) -> web.StreamResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
request.app["websockets"].append(ws)
Expand All @@ -20,14 +20,15 @@ async def websocket_handler(request):
return ws


async def on_shutdown(app):
async def on_shutdown(app: web.Application) -> None:
for ws in app["websockets"]:
await ws.close(code=999, message="Server shutdown")


async def listen_to_redis(app):
async def listen_to_redis(app: web.Application) -> None:
try:
sub = await aioredis.create_redis(("localhost", 6379), loop=app.loop)
loop = asyncio.get_event_loop()
sub = await aioredis.create_redis(("localhost", 6379), loop=loop)
ch, *_ = await sub.subscribe("news")
async for msg in ch.iter(encoding="utf-8"):
# Forward message to all connected websockets:
Expand All @@ -43,17 +44,17 @@ async def listen_to_redis(app):
print("Redis connection closed.")


async def start_background_tasks(app):
async def start_background_tasks(app: web.Application) -> None:
app["redis_listener"] = app.loop.create_task(listen_to_redis(app))


async def cleanup_background_tasks(app):
async def cleanup_background_tasks(app: web.Application) -> None:
print("cleanup background tasks...")
app["redis_listener"].cancel()
await app["redis_listener"]


def init():
def init() -> web.Application:
app = web.Application()
app["websockets"] = []
app.router.add_get("/news", websocket_handler)
Expand Down
5 changes: 3 additions & 2 deletions examples/cli_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@
"""

from argparse import ArgumentParser
from typing import Optional, Sequence

from aiohttp import web


def display_message(req):
async def display_message(req: web.Request) -> web.StreamResponse:
args = req.app["args"]
text = "\n".join([args.message] * args.repeat)
return web.Response(text=text)


def init(argv):
def init(argv: Optional[Sequence[str]]) -> web.Application:
arg_parser = ArgumentParser(
prog="aiohttp.web ...", description="Application CLI", add_help=False
)
Expand Down
4 changes: 2 additions & 2 deletions examples/client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import aiohttp


async def fetch(session):
async def fetch(session: aiohttp.ClientSession) -> None:
print("Query http://httpbin.org/basic-auth/andrew/password")
async with session.get("http://httpbin.org/basic-auth/andrew/password") as resp:
print(resp.status)
body = await resp.text()
print(body)


async def go():
async def go() -> None:
async with aiohttp.ClientSession(
auth=aiohttp.BasicAuth("andrew", "password")
) as session:
Expand Down
4 changes: 2 additions & 2 deletions examples/client_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import aiohttp


async def fetch(session):
async def fetch(session: aiohttp.ClientSession):
print("Query http://httpbin.org/get")
async with session.get("http://httpbin.org/get") as resp:
print(resp.status)
data = await resp.json()
print(data)


async def go():
async def go() -> None:
async with aiohttp.ClientSession() as session:
await fetch(session)

Expand Down
8 changes: 4 additions & 4 deletions examples/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import aiohttp


async def start_client(loop, url):
async def start_client(loop: asyncio.AbstractEventLoop, url: str) -> None:
name = input("Please enter your name: ")

# input reader
def stdin_callback():
def stdin_callback() -> None:
line = sys.stdin.buffer.readline().decode("utf-8")
if not line:
loop.stop()
Expand All @@ -21,7 +21,7 @@ def stdin_callback():

loop.add_reader(sys.stdin.fileno(), stdin_callback)

async def dispatch():
async def dispatch() -> None:
while True:
msg = await ws.receive()

Expand All @@ -30,7 +30,7 @@ async def dispatch():
elif msg.type == aiohttp.WSMsgType.BINARY:
print("Binary: ", msg.data)
elif msg.type == aiohttp.WSMsgType.PING:
ws.pong()
await ws.pong()
elif msg.type == aiohttp.WSMsgType.PONG:
print("Pong received")
else:
Expand Down
2 changes: 1 addition & 1 deletion examples/curl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import aiohttp


async def curl(url):
async def curl(url: str) -> None:
async with aiohttp.ClientSession() as session:
async with session.request("GET", url) as response:
print(repr(response))
Expand Down
41 changes: 23 additions & 18 deletions examples/fake_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
import pathlib
import socket
import ssl
from typing import Any, Dict, List, Union

import aiohttp
from aiohttp import web
from aiohttp.resolver import DefaultResolver
from aiohttp.test_utils import unused_port
from aiohttp import ClientSession, TCPConnector, resolver, test_utils, web


class FakeResolver:
_LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1", socket.AF_INET6: "::1"}

def __init__(self, fakes):
def __init__(self, fakes: Dict[str, int]) -> None:
"""fakes -- dns -> port dict"""
self._fakes = fakes
self._resolver = DefaultResolver()

async def resolve(self, host, port=0, family=socket.AF_INET):
self._resolver = resolver.DefaultResolver()

async def resolve(
self,
host: str,
port: int = 0,
family: Union[socket.AddressFamily, int] = socket.AF_INET,
) -> List[Dict[str, Any]]:
fake_port = self._fakes.get(host)
if fake_port is not None:
return [
Expand All @@ -34,9 +37,12 @@ async def resolve(self, host, port=0, family=socket.AF_INET):
else:
return await self._resolver.resolve(host, port, family)

async def close(self) -> None:
self._resolver.close()


class FakeFacebook:
def __init__(self):
def __init__(self) -> None:
self.app = web.Application()
self.app.router.add_routes(
[
Expand All @@ -51,21 +57,20 @@ def __init__(self):
self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
self.ssl_context.load_cert_chain(str(ssl_cert), str(ssl_key))

async def start(self):
port = unused_port()
self.runner = web.AppRunner(self.app)
async def start(self) -> Dict[str, int]:
port = test_utils.unused_port()
await self.runner.setup()
site = web.TCPSite(self.runner, "127.0.0.1", port, ssl_context=self.ssl_context)
await site.start()
return {"graph.facebook.com": port}

async def stop(self):
async def stop(self) -> None:
await self.runner.cleanup()

async def on_me(self, request):
async def on_me(self, request: web.Request) -> web.StreamResponse:
return web.json_response({"name": "John Doe", "id": "12345678901234567"})

async def on_my_friends(self, request):
async def on_my_friends(self, request: web.Request) -> web.StreamResponse:
return web.json_response(
{
"data": [
Expand All @@ -88,15 +93,15 @@ async def on_my_friends(self, request):
)


async def main():
async def main() -> None:
token = "ER34gsSGGS34XCBKd7u"

fake_facebook = FakeFacebook()
info = await fake_facebook.start()
resolver = FakeResolver(info)
connector = aiohttp.TCPConnector(resolver=resolver, ssl=False)
connector = TCPConnector(resolver=resolver, ssl=False)

async with aiohttp.ClientSession(connector=connector) as session:
async with ClientSession(connector=connector) as session:
async with session.get(
"https://graph.facebook.com/v2.7/me", params={"access_token": token}
) as resp:
Expand Down
6 changes: 3 additions & 3 deletions examples/lowlevel_srv.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import asyncio

from aiohttp import web
from aiohttp import web, web_request


async def handler(request):
async def handler(request: web_request.BaseRequest) -> web.StreamResponse:
return web.Response(text="OK")


async def main(loop):
async def main(loop: asyncio.AbstractEventLoop) -> None:
server = web.Server(handler)
await loop.create_server(server, "127.0.0.1", 8080)
print("======= Serving on http://127.0.0.1:8080/ ======")
Expand Down
4 changes: 2 additions & 2 deletions examples/server_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from aiohttp import web


async def handle(request):
async def handle(request: web.Request) -> web.StreamResponse:
name = request.match_info.get("name", "Anonymous")
text = "Hello, " + name
return web.Response(text=text)


async def wshandle(request):
async def wshandle(request: web.Request) -> web.StreamResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)

Expand Down
9 changes: 4 additions & 5 deletions examples/web_classview.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
"""Example for aiohttp.web class based views
"""


import functools
import json

from aiohttp import web


class MyView(web.View):
async def get(self):
async def get(self) -> web.StreamResponse:
return web.json_response(
{
"method": self.request.method,
Expand All @@ -20,7 +19,7 @@ async def get(self):
dumps=functools.partial(json.dumps, indent=4),
)

async def post(self):
async def post(self) -> web.StreamResponse:
data = await self.request.post()
return web.json_response(
{
Expand All @@ -32,7 +31,7 @@ async def post(self):
)


async def index(request):
async def index(request: web.Request) -> web.StreamResponse:
txt = """
<html>
<head>
Expand All @@ -51,7 +50,7 @@ async def index(request):
return web.Response(text=txt, content_type="text/html")


def init():
def init() -> web.Application:
app = web.Application()
app.router.add_get("/", index)
app.router.add_get("/get", MyView)
Expand Down
8 changes: 4 additions & 4 deletions examples/web_cookies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,25 @@
</html>"""


async def root(request):
async def root(request: web.Request) -> web.StreamResponse:
resp = web.Response(content_type="text/html")
resp.text = tmpl.format(pformat(request.cookies))
return resp


async def login(request):
async def login(request: web.Request) -> None:
exc = web.HTTPFound(location="/")
exc.set_cookie("AUTH", "secret")
raise exc


async def logout(request):
async def logout(request: web.Request) -> None:
exc = web.HTTPFound(location="/")
exc.del_cookie("AUTH")
raise exc


def init():
def init() -> web.Application:
app = web.Application()
app.router.add_get("/", root)
app.router.add_get("/login", login)
Expand Down
9 changes: 6 additions & 3 deletions examples/web_rewrite_headers_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
"""
Example for rewriting response headers by middleware.
"""
from typing import Awaitable, Callable

from aiohttp import web

_WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]]

async def handler(request):

async def handler(request: web.Request) -> web.StreamResponse:
return web.Response(text="Everything is fine")


async def middleware(request, handler):
async def middleware(request: web.Request, handler: _WebHandler) -> web.StreamResponse:
try:
response = await handler(request)
except web.HTTPException as exc:
Expand All @@ -20,7 +23,7 @@ async def middleware(request, handler):
return response


def init():
def init() -> web.Application:
app = web.Application(middlewares=[middleware])
app.router.add_get("/", handler)
return app
Expand Down
Loading

0 comments on commit 6f73339

Please sign in to comment.