diff --git a/README.md b/README.md index 921d088..016c7af 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ While micro is covered by the [LGPL](https://www.gnu.org/licenses/lgpl.html), th are released into the public domain: * [jsonredis](https://github.com/noyainrain/micro/blob/master/micro/jsonredis.py) +* [ratelimit](https://github.com/noyainrain/micro/blob/master/micro/ratelimit.py) * [webapi](https://github.com/noyainrain/micro/blob/master/micro/webapi.py) * [bind.js](https://github.com/noyainrain/micro/blob/master/client/bind.js) * [keyboard.js](https://github.com/noyainrain/micro/blob/master/client/keyboard.js) diff --git a/boilerplate/client/package.json b/boilerplate/client/package.json index 89677e4..4bd293e 100644 --- a/boilerplate/client/package.json +++ b/boilerplate/client/package.json @@ -5,7 +5,7 @@ "clean": "rm -rf node_modules" }, "dependencies": { - "@noyainrain/micro": "^0.52" + "@noyainrain/micro": "^0.53" }, "devDependencies": { "eslint": "~6.8", diff --git a/boilerplate/requirements.txt b/boilerplate/requirements.txt index 90f0fc1..2b462aa 100644 --- a/boilerplate/requirements.txt +++ b/boilerplate/requirements.txt @@ -1 +1 @@ -noyainrain.micro ~= 0.52.0 +noyainrain.micro ~= 0.53.0 diff --git a/client/index.js b/client/index.js index c19d87d..b827b8c 100644 --- a/client/index.js +++ b/client/index.js @@ -402,14 +402,10 @@ micro.UI = class extends HTMLBodyElement { } /** - * Handle a common call error *e* with a default reaction: + * Handle a common call error *e* with a default reaction. * - * - `NetworkError`: Notify the user that they seem to be offline - * - `NotFoundError`: Notify the user that the current page has been deleted - * - `PermissionError`: Notify the user that their permissions for the current page have been - * revoked - * - * Other errors are not handled and re-thrown. + * :class:`NetworkError`, ``NotFoundError``, ``PermissionError`` and `RateLimitError` are + * handled. Other errors are re-thrown. */ handleCallError(e) { if (e instanceof micro.NetworkError) { @@ -419,6 +415,8 @@ micro.UI = class extends HTMLBodyElement { this.notify("Oops, someone has just deleted this page!"); } else if (e instanceof micro.APIError && e.error.__type__ === "PermissionError") { this.notify("Oops, someone has just revoked your permissions for this page!"); + } else if (e instanceof micro.APIError && e.error.__type__ === "RateLimitError") { + this.notify("Oops, you are a bit too fast! Please try again later."); } else { throw e; } diff --git a/client/package.json b/client/package.json index 6996554..ac5f1e8 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@noyainrain/micro", - "version": "0.52.0", + "version": "0.53.0", "description": "Toolkit for social micro web apps.", "repository": "noyainrain/micro", "license": "LGPL-3.0", diff --git a/micro/core.py b/micro/core.py index 208d7c7..b0de86a 100644 --- a/micro/core.py +++ b/micro/core.py @@ -19,6 +19,27 @@ Function of the form ``rewrite(url)`` which rewrites the given *url*. """ -from typing import Callable +from contextvars import ContextVar +import typing +from typing import Callable, Optional + +if typing.TYPE_CHECKING: + from micro import User RewriteFunc = Callable[[str], str] + +class context: + """Application context. + + .. attribute:: user + + Current user. Defaults to ``None``, meaning anonymous access. + + .. attribute:: client + + Identifier of the current client, e.g. a network address. Defaults to ``local``. + """ + # pylint: disable=invalid-name; namespace + + user: ContextVar[Optional['User']] = ContextVar('user', default=None) + client = ContextVar('client', default='local') diff --git a/micro/doc/general.inc b/micro/doc/general.inc index dea6cf1..3043667 100644 --- a/micro/doc/general.inc +++ b/micro/doc/general.inc @@ -7,7 +7,8 @@ Arguments are passed to an endpoint simply as JSON object and the result is retu *Objects* contain a ``__type__`` attribute that holds the name of the object type. If a requested endpoint doesn't exist, a :ref:`NotFoundError` is returned. For any endpoint, an -:ref:`InputError` is returned if the input contains invalid arguments. +:ref:`InputError` is returned if the input contains invalid arguments. A :ref:`RateLimitError` is +returned if the current client exceeds the allowed rate limit for an operation. The URL that uniquely identifies an object is referred to as *object-url*, e.g. ``users/abc`` for a :ref:`User` with the *id* ``abc``. @@ -374,6 +375,13 @@ PermissionError Returned if the current user is not allowed to perform an action. +.. _RateLimitError: + +RateLimitError +^^^^^^^^^^^^^^ + +Returned if the current client exceeds the allowed rate limit for an operation. + .. _CommunicationError: CommunicationError diff --git a/micro/micro.py b/micro/micro.py index 3a932d4..04b115d 100644 --- a/micro/micro.py +++ b/micro/micro.py @@ -40,10 +40,11 @@ from tornado.ioloop import IOLoop from typing_extensions import Protocol -from .core import RewriteFunc +from .core import RewriteFunc, context from .error import CommunicationError, ValueError from .jsonredis import (ExpectFunc, JSONRedis, JSONRedisSequence, JSONRedisMapping, RedisList, RedisSequence, bzpoptimed) +from .ratelimit import RateLimit, RateLimiter from .resource import ( # pylint: disable=unused-import; typing Analyzer, Files, HandleResourceFunc, Image, Resource, Video, handle_image, handle_webpage, handle_youtube) @@ -133,6 +134,10 @@ class Application: .. attribute:: analyzer Web resource analyzer. + + .. attribute:: rate_limiter + + Subclass API: Mechanism to limit the rate of operations per client. """ def __init__( @@ -184,6 +189,7 @@ def __init__( if 'youtube' in self.video_service_keys: handlers.insert(0, handle_youtube(self.video_service_keys['youtube'])) self.analyzer = Analyzer(handlers=handlers, files=self.files) + self.rate_limiter = RateLimiter() @property def settings(self) -> 'Settings': @@ -906,14 +912,14 @@ def __init__(self, *, app: Application, **data: Dict[str, object]) -> None: self.device_notification_status = cast(str, data['device_notification_status']) self.push_subscription = cast(Optional[str], data['push_subscription']) - def store_email(self, email): + def store_email(self, email: str) -> None: """Update the user's *email* address. If *email* is already associated with another user, a :exc:`ValueError` (``email_duplicate``) is raised. """ check_email(email) - id = self.app.r.hget('user_email_map', email) + id = self.app.r.r.hget('user_email_map', email.encode()) if id and id.decode() != self.id: raise ValueError('email_duplicate') @@ -938,14 +944,12 @@ def set_email(self, email): self._send_email(email, self.app.render_email_auth_message(email, auth_request, code)) return auth_request - def finish_set_email(self, auth_request, auth): + def finish_set_email(self, auth_request: 'AuthRequest', auth: str) -> None: """See :http:post:`/api/users/(id)/finish-set-email`.""" # pylint: disable=protected-access; auth_request is a friend if self.app.user != self: raise PermissionError() - if auth != auth_request._code: - raise ValueError('auth_invalid') - + auth_request.verify(auth) self.app.r.delete(auth_request.id) self.store_email(auth_request._email) @@ -1430,6 +1434,13 @@ def __init__(self, id: str, app: Application, email: str, code: str) -> None: self._email = email self._code = code + def verify(self, code: str) -> None: + """Verify the secret *code*.""" + self.app.rate_limiter.count(RateLimit(f'{self.id}.verify', 10, timedelta(minutes=10)), + context.client.get()) + if code != self._code: + raise ValueError('Invalid code') + def json(self, restricted: bool = False, include: bool = False, *, rewrite: RewriteFunc = None) -> Dict[str, object]: return { diff --git a/micro/ratelimit.py b/micro/ratelimit.py new file mode 100644 index 0000000..a9234e9 --- /dev/null +++ b/micro/ratelimit.py @@ -0,0 +1,62 @@ +# ratelimit +# Released into the public domain +# https://github.com/noyainrain/micro/blob/master/micro/ratelimit.py + +"""Mechanism to limit the rate of operations per client.""" + +from asyncio import get_event_loop +from dataclasses import dataclass +from datetime import timedelta +from typing import Dict, Tuple + +@dataclass(frozen=True) +class RateLimit: + """Rate limit rule for an operation. + + .. attribute:: id + + Unique ID of the rule. + + .. attribute:: n + + Maximum number of operations. + + .. attribute:: time_frame + + Reference time frame. + """ + id: str + n: int + time_frame: timedelta + + def __post_init__(self) -> None: + if not self.id: + raise ValueError('Empty id') + if self.n <= 0: + raise ValueError('Out-of-range n') + if self.time_frame <= timedelta(): + raise ValueError('Out-of-range time_frame') + +class RateLimiter: + """Mechanism to limit the rate of operations per client.""" + + def __init__(self) -> None: + self._counters: Dict[Tuple[RateLimit, str], int] = {} + + def count(self, limit: RateLimit, client: str) -> None: + """Count an operation by *client*. + + The operation is defined by *limit*. *client* is an identifier, e.g. a network address. If + the client exceeds the allowed rate limit, a :exc:`RateLimitError` is raised. + """ + key = (limit, client) + if key not in self._counters: + self._counters[key] = 0 + get_event_loop().call_later(limit.time_frame.total_seconds(), + lambda: self._counters.pop(key)) + self._counters[key] += 1 + if self._counters[key] > limit.n: + raise RateLimitError(client) + +class RateLimitError(Exception): + """Raised if a client exceeds the allowed rate limit for an operation.""" diff --git a/micro/server.py b/micro/server.py index 88b4fc8..da8e7b7 100644 --- a/micro/server.py +++ b/micro/server.py @@ -43,9 +43,11 @@ from tornado.web import Application, HTTPError, RequestHandler, StaticFileHandler from . import micro, templates, error +from .core import context from .micro import ( # pylint: disable=unused-import; typing Activity, AuthRequest, Collection, JSONifiable, Object, User, InputError, AuthenticationError, CommunicationError, PermissionError, Trashable) +from .ratelimit import RateLimitError from .resource import NoResourceError, ForbiddenResourceError, BrokenResourceError from .util import (Expect, ExpectFunc, cancel, look_up_files, str_or_none, parse_slice, check_polyglot) @@ -282,7 +284,7 @@ def get_activity(*args: str) -> Activity: template_path=self.client_path, debug=self.debug, server=self) # Install static file handler manually to allow pre-processing cast(_ApplicationSettings, application.settings).update({'static_path': self.client_path}) - self._server = HTTPServer(application) + self._server = HTTPServer(application, xheaders=True) self._garbage_collect_files_task = None # type: Optional[Task[None]] self._empty_trash_task = None # type: Optional[Task[None]] @@ -368,11 +370,14 @@ def initialize(self, **args: object) -> None: self.app = self.server.app self.args = {} # type: Dict[str, object] - def prepare(self): + def prepare(self) -> None: + context.client.set(self.request.remote_ip) # type: ignore + self.app.user = None auth_secret = self.get_cookie('auth_secret') if auth_secret: self.current_user = self.app.authenticate(auth_secret) + context.user.set(self.current_user) if self.request.body: try: @@ -406,6 +411,10 @@ def write_error(self, status_code: int, **kwargs: object) -> None: elif isinstance(e, PermissionError): self.set_status(http.client.FORBIDDEN) self.write({'__type__': type(e).__name__}) # type: ignore + elif isinstance(e, RateLimitError): + self.set_status(http.client.TOO_MANY_REQUESTS) + data = {'__type__': type(e).__name__, 'message': str(e)} + self.write(data) elif isinstance(e, InputError): self.set_status(http.client.BAD_REQUEST) self.write({ # type: ignore @@ -432,7 +441,8 @@ def log_exception(self, typ, value, tb): # These errors are handled specially and there is no need to log them as exceptions if issubclass( typ, - (KeyError, AuthenticationError, PermissionError, CommunicationError, error.Error)): + (KeyError, AuthenticationError, PermissionError, RateLimitError, CommunicationError, + error.Error)): return super().log_exception(typ, value, tb) diff --git a/micro/tests/ext_test_email.py b/micro/tests/ext_test_email.py index 084a8b0..6d3dd46 100644 --- a/micro/tests/ext_test_email.py +++ b/micro/tests/ext_test_email.py @@ -82,7 +82,7 @@ def _render_email_auth_message(email, auth_request, auth): def test_user_set_email_auth_invalid(self): auth_request = self.user.set_email('happy@example.org') - with self.assertRaisesRegex(ValueError, 'auth_invalid'): + with self.assertRaisesRegex(ValueError, 'code'): self.user.finish_set_email(auth_request, 'foo') def test_user_remove_email(self): diff --git a/micro/tests/test_ratelimit.py b/micro/tests/test_ratelimit.py new file mode 100644 index 0000000..5246c02 --- /dev/null +++ b/micro/tests/test_ratelimit.py @@ -0,0 +1,31 @@ +# ratelimit +# Released into the public domain +# https://github.com/noyainrain/micro/blob/master/micro/ratelimit.py + +# pylint: disable=missing-docstring; test module + +from asyncio import sleep +from datetime import timedelta +from tornado.testing import AsyncTestCase, gen_test + +from micro.ratelimit import RateLimit, RateLimiter, RateLimitError + +class RateLimiterTest(AsyncTestCase): + LIMIT = RateLimit('meow', 2, timedelta(seconds=0.1)) + + def setUp(self) -> None: + super().setUp() + self.rate_limiter = RateLimiter() + + def test_count(self) -> None: + self.rate_limiter.count(self.LIMIT, 'local') + self.rate_limiter.count(self.LIMIT, 'local') + with self.assertRaises(RateLimitError): + self.rate_limiter.count(self.LIMIT, 'local') + + @gen_test # type: ignore + async def test_count_after_time_frame(self) -> None: + self.rate_limiter.count(self.LIMIT, 'local') + self.rate_limiter.count(self.LIMIT, 'local') + await sleep(0.2) + self.rate_limiter.count(self.LIMIT, 'local') diff --git a/setup.py b/setup.py index 496a855..8185a80 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name='noyainrain.micro', - version='0.52.0', + version='0.53.0', url='https://github.com/noyainrain/micro', maintainer='Sven James', maintainer_email='sven@inrain.org',