diff --git a/src/middlewared/middlewared/api/base/server/app.py b/src/middlewared/middlewared/api/base/server/app.py index 16472cd188270..d3367bcd7cf97 100644 --- a/src/middlewared/middlewared/api/base/server/app.py +++ b/src/middlewared/middlewared/api/base/server/app.py @@ -2,13 +2,13 @@ import uuid from middlewared.auth import SessionManagerCredentials -from middlewared.utils.origin import Origin +from middlewared.utils.origin import ConnectionOrigin logger = logging.getLogger(__name__) class App: - def __init__(self, origin: Origin): + def __init__(self, origin: ConnectionOrigin): self.origin = origin self.session_id = str(uuid.uuid4()) self.authenticated = False diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/base.py b/src/middlewared/middlewared/api/base/server/ws_handler/base.py index 975369abe0b7f..a6a83565fd3b2 100644 --- a/src/middlewared/middlewared/api/base/server/ws_handler/base.py +++ b/src/middlewared/middlewared/api/base/server/ws_handler/base.py @@ -1,12 +1,7 @@ -import socket -import struct - from aiohttp.http_websocket import WSCloseCode from aiohttp.web import Request, WebSocketResponse -from middlewared.auth import is_ha_connection -from middlewared.utils.nginx import get_remote_addr_port -from middlewared.utils.origin import Origin, UnixSocketOrigin, TCPIPOrigin +from middlewared.utils.origin import ConnectionOrigin from middlewared.webui_auth import addr_in_allowlist @@ -37,36 +32,23 @@ async def __call__(self, request: Request): await self.process(origin, ws) return ws - async def get_origin(self, request: Request) -> Origin | None: - try: - sock = request.transport.get_extra_info("socket") - except AttributeError: - # request.transport can be None by the time this is called on HA systems because remote node could have been - # rebooted - return - - if sock.family == socket.AF_UNIX: - peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize("3i")) - pid, uid, gid = struct.unpack("3i", peercred) - return UnixSocketOrigin(pid, uid, gid) - - remote_addr, remote_port = await self.middleware.run_in_thread(get_remote_addr_port, request) - return TCPIPOrigin(remote_addr, remote_port) + async def get_origin(self, request: Request) -> ConnectionOrigin | None: + return await self.middleware.run_in_thread(ConnectionOrigin.create, request) - async def can_access(self, origin: Origin | None) -> bool: - if not isinstance(origin, TCPIPOrigin): - return True + async def can_access(self, origin: ConnectionOrigin | None) -> bool: + if origin is None: + return False - if not (ui_allowlist := await self.middleware.call("system.general.get_ui_allowlist")): + if origin.is_unix_family or origin.is_ha_connection: return True - if is_ha_connection(origin.addr, origin.port): + ui_allowlist = await self.middleware.call("system.general.get_ui_allowlist") + if not ui_allowlist: return True - - if addr_in_allowlist(origin.addr, ui_allowlist): + elif addr_in_allowlist(origin.rem_addr, ui_allowlist): return True return False - async def process(self, origin: Origin, ws: WebSocketResponse): + async def process(self, origin: ConnectionOrigin, ws: WebSocketResponse): raise NotImplementedError diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py b/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py index d5a63f842e31c..ebf2b939ddade 100644 --- a/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py +++ b/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py @@ -21,7 +21,7 @@ from middlewared.utils.debug import get_frame_details from middlewared.utils.limits import MsgSizeError, MsgSizeLimit, parse_message from middlewared.utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit -from middlewared.utils.origin import Origin +from middlewared.utils.origin import ConnectionOrigin from .base import BaseWebSocketHandler from ..app import App from ..method import Method @@ -45,7 +45,7 @@ class RpcWebSocketAppEvent(enum.Enum): class RpcWebSocketApp(App): - def __init__(self, middleware: "Middleware", origin: Origin, ws: WebSocketResponse): + def __init__(self, middleware: "Middleware", origin: ConnectionOrigin, ws: WebSocketResponse): super().__init__(origin) self.websocket = True @@ -194,7 +194,7 @@ def __init__(self, middleware: "Middleware", methods: dict[str, Method]): super().__init__(middleware) self.methods = methods - async def process(self, origin: Origin, ws: WebSocketResponse): + async def process(self, origin: ConnectionOrigin, ws: WebSocketResponse): app = RpcWebSocketApp(self.middleware, origin, ws) self.middleware.register_wsclient(app) @@ -222,7 +222,7 @@ async def process(self, origin: Origin, ws: WebSocketResponse): except MsgSizeError as err: if err.limit is not MsgSizeLimit.UNAUTHENTICATED: creds = app.authenticated_credentials.dump() if app.authenticated_credentials else None - origin = app.origin.repr() if app.origin else None + origin = app.origin.repr if app.origin else None self.middleware.logger.error( 'Client using credentials [%s] at [%s] sent message with payload size [%d bytes] ' diff --git a/src/middlewared/middlewared/auth.py b/src/middlewared/middlewared/auth.py index 89edce8e9d5fb..3999382f2e21b 100644 --- a/src/middlewared/middlewared/auth.py +++ b/src/middlewared/middlewared/auth.py @@ -121,10 +121,6 @@ def authorize(self, method, resource): return True -def is_ha_connection(remote_addr, remote_port): - return remote_port <= 1024 and remote_addr in ('169.254.10.1', '169.254.10.2') - - class FakeApplication: authenticated_credentials = SessionManagerCredentials() diff --git a/src/middlewared/middlewared/main.py b/src/middlewared/middlewared/main.py index 0db2b11deb85c..2a74ca587704e 100644 --- a/src/middlewared/middlewared/main.py +++ b/src/middlewared/middlewared/main.py @@ -22,7 +22,7 @@ from .utils.debug import get_frame_details, get_threads_stacks from .utils.limits import MsgSizeError, MsgSizeLimit, parse_message from .utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit -from .utils.origin import Origin, TCPIPOrigin +from .utils.origin import ConnectionOrigin from .utils.os import close_fds from .utils.plugins import LoadPluginsMixin from .utils.privilege import credential_has_full_admin @@ -107,20 +107,24 @@ def real_crud_method(method): class Application(RpcWebSocketApp): - def __init__(self, middleware: 'Middleware', origin: Origin, loop: asyncio.AbstractEventLoop, request, response): + def __init__( + self, + middleware, + origin: ConnectionOrigin, + loop: asyncio.AbstractEventLoop, + request, + response + ): super().__init__(middleware, origin, response) self.websocket = True - self.loop = loop self.request = request self.response = response self.handshake = False self.logger = logger.Logger('application').getLogger() - # Allow at most 10 concurrent calls and only queue up until 20 self._softhardsemaphore = SoftHardSemaphore(10, 20) self._py_exceptions = False - self.__subscribed = {} def _send(self, data: typing.Dict[str, typing.Any]): @@ -129,13 +133,11 @@ def _send(self, data: typing.Dict[str, typing.Any]): def _tb_error(self, exc_info: ExcInfoType) -> typing.Dict[str, typing.Union[str, list[dict]]]: klass, exc, trace = exc_info - frames = [] cur_tb = trace while cur_tb: tb_frame = cur_tb.tb_frame cur_tb = cur_tb.tb_next - cur_frame = get_frame_details(tb_frame, self.logger) if cur_frame: frames.append(cur_frame) @@ -476,7 +478,7 @@ async def upload(self, request): raise web.HTTPUnauthorized() except web.HTTPException as e: return web.Response(status=e.status_code, body=e.text) - app = create_application(request) + app = await create_application(request) try: authenticated_credentials = await authenticate(self.middleware, request, credentials, 'CALL', data['method']) @@ -489,7 +491,7 @@ async def upload(self, request): 'error': e.text, }, False) return web.Response(status=e.status_code, body=e.text) - app = create_application(request, authenticated_credentials) + app = await create_application(request, authenticated_credentials) credentials['credentials_data'].pop('password', None) await self.middleware.log_audit_message(app, 'AUTHENTICATION', { 'credentials': credentials, @@ -1126,9 +1128,7 @@ def _console_write(self, text, fill_blank=True, append=False): blank = ' ' * (maxlen - (len(prefix) + len(text))) else: blank = '' - writes = self.__console_io.write( - f'\r{prefix}{text}{blank}{newline}' - ) + self.__console_io.write(f'\r{prefix}{text}{blank}{newline}') self.__console_io.flush() # be sure and reset error counter after we successfully log # to the console @@ -1312,7 +1312,7 @@ def pipe(self, buffered=False): def _call_prepare( self, name, serviceobj, methodobj, params, app=None, audit_callback=None, job_on_progress_cb=None, pipes=None, - in_event_loop: bool=True, + in_event_loop: bool = True, ): """ :param in_event_loop: Whether we are in the event loop thread. @@ -1552,6 +1552,12 @@ async def log_audit_message_for_method(self, method, methodobj, params, app, aut }, success) async def log_audit_message(self, app, event, event_data, success): + remote_addr, origin = "127.0.0.1", None + if app is not None and app.origin is not None: + origin = app.origin.repr + if app.origin.is_tcp_ip_family: + remote_addr = origin + message = "@cee:" + json.dumps({ "TNAUDIT": { "aid": str(uuid.uuid4()), @@ -1559,7 +1565,7 @@ async def log_audit_message(self, app, event, event_data, success): "major": 0, "minor": 1 }, - "addr": app.origin.repr() if isinstance(app.origin, TCPIPOrigin) else "127.0.0.1", + "addr": remote_addr, "user": audit_username_from_session(app.authenticated_credentials), "sess": app.session_id, "time": utc_now().strftime('%Y-%m-%d %H:%M:%S.%f'), @@ -1569,7 +1575,7 @@ async def log_audit_message(self, app, event, event_data, success): "major": 0, "minor": 1, }, - "origin": app.origin.repr() if app.origin else None, + "origin": origin, "protocol": "WEBSOCKET" if app.websocket else "REST", "credentials": { "credentials": app.authenticated_credentials.class_name(), @@ -1874,15 +1880,13 @@ async def ws_handler(self, request): ) break - datalen = len(msg.data) - try: message = parse_message(connection.authenticated, msg.data) except MsgSizeError as err: if err.limit is not MsgSizeLimit.UNAUTHENTICATED: - origin = connection.origin.repr() if connection.origin else None + origin = connection.origin.repr if connection.origin else None if connection.authenticated_credentials: - creds = connection.authenticated_credentials.dump() + creds = connection.authenticated_credentials.dump() else: creds = None diff --git a/src/middlewared/middlewared/plugins/auth.py b/src/middlewared/middlewared/plugins/auth.py index a60c3903407fc..717541f5913ec 100644 --- a/src/middlewared/middlewared/plugins/auth.py +++ b/src/middlewared/middlewared/plugins/auth.py @@ -17,7 +17,6 @@ ) from middlewared.service_exception import MatchNotFound import middlewared.sqlalchemy as sa -from middlewared.utils.origin import UnixSocketOrigin from middlewared.utils.crypto import generate_token from middlewared.utils.time_utils import utc_now @@ -160,9 +159,14 @@ def dump(self): } -def is_internal_session(session): - if isinstance(session.app.origin, UnixSocketOrigin) and session.app.origin.uid == 0: - return True +def is_internal_session(session) -> bool: + try: + is_root_sock = session.app.origin.is_unix_family and session.app.origin.uid == 0 + if is_root_sock: + return True + except AttributeError: + # session.app.origin can be NoneType + pass if isinstance(session.app.authenticated_credentials, TrueNasNodeSessionManagerCredentials): return True @@ -601,7 +605,7 @@ async def check_permission(middleware, app): if origin is None: return - if isinstance(origin, UnixSocketOrigin): + if origin.is_unix_family: if origin.uid == 0: user = await middleware.call('auth.authenticate_root') else: @@ -622,7 +626,6 @@ async def check_permission(middleware, app): return await AuthService.session_manager.login(app, UnixSocketSessionManagerCredentials(user)) - return def setup(middleware): diff --git a/src/middlewared/middlewared/plugins/failover.py b/src/middlewared/middlewared/plugins/failover.py index da648fe97f336..f532508cd57e2 100644 --- a/src/middlewared/middlewared/plugins/failover.py +++ b/src/middlewared/middlewared/plugins/failover.py @@ -11,7 +11,7 @@ import time from functools import partial -from middlewared.auth import is_ha_connection, TrueNasNodeSessionManagerCredentials +from middlewared.auth import TrueNasNodeSessionManagerCredentials from middlewared.schema import accepts, Bool, Dict, Int, List, NOT_PROVIDED, Str, returns, Patch from middlewared.service import ( job, no_auth_required, no_authz_required, pass_app, private, CallError, ConfigService, @@ -27,7 +27,6 @@ from middlewared.plugins.update_.utils import DOWNLOAD_UPDATE_FILE, can_update from middlewared.plugins.update_.utils_linux import mount_update from middlewared.utils.contextlib import asyncnullcontext -from middlewared.utils.origin import TCPIPOrigin ENCRYPTION_CACHE_LOCK = asyncio.Lock() @@ -1119,14 +1118,11 @@ async def sync_keys_to_remote_node(self, lock=True): async def ha_permission(middleware, app): - # Skip if session was already authenticated - if app is not None and app.authenticated is True: - return - - # We only care for remote connections (IPv4), in the interlink - if isinstance(app.origin, TCPIPOrigin): - if is_ha_connection(app.origin.addr, app.origin.port): + try: + if not app.authenticated and app.origin.is_ha_connection: await AuthService.session_manager.login(app, TrueNasNodeSessionManagerCredentials()) + except AttributeError: + pass async def interface_pre_sync_hook(middleware): diff --git a/src/middlewared/middlewared/plugins/network.py b/src/middlewared/middlewared/plugins/network.py index f9327ceed6372..e00039f080f2d 100644 --- a/src/middlewared/middlewared/plugins/network.py +++ b/src/middlewared/middlewared/plugins/network.py @@ -9,8 +9,6 @@ from middlewared.service import CallError, CRUDService, filterable, pass_app, private from middlewared.utils import filter_list from middlewared.schema import accepts, Bool, Dict, Int, IPAddr, List, Patch, returns, Str, ValidationErrors -from middlewared.utils.network_.procfs import read_proc_net -from middlewared.utils.origin import TCPIPOrigin from middlewared.validators import Range from .interface.netif import netif from .interface.interface_types import InterfaceType @@ -1453,15 +1451,10 @@ async def delete_network_interface(self, oid): @pass_app() async def websocket_local_ip(self, app): """Returns the local ip address for this websocket session.""" - if app is None or isinstance(app.origin, TCPIPOrigin) is False: - return - try: - if info := await self.middleware.run_in_thread(read_proc_net, None, app.origin.port): - return info.local_ip - except Exception: - self.logger.error("Unexpected failure determining local websocket ip", exc_info=True) - return + return app.origin.loc_addr + except AttributeError: + pass @accepts() @returns(Str(null=True)) diff --git a/src/middlewared/middlewared/plugins/pool_/dataset.py b/src/middlewared/middlewared/plugins/pool_/dataset.py index 6d56a007818a2..5551995bd1276 100644 --- a/src/middlewared/middlewared/plugins/pool_/dataset.py +++ b/src/middlewared/middlewared/plugins/pool_/dataset.py @@ -10,10 +10,9 @@ accepts, Any, Attribute, EnumMixin, Bool, Dict, Int, List, NOT_PROVIDED, Patch, Ref, returns, Str ) from middlewared.service import ( - CallError, CRUDService, filterable, InstanceNotFound, item_method, job, pass_app, private, ValidationErrors + CallError, CRUDService, filterable, InstanceNotFound, item_method, job, private, ValidationErrors ) from middlewared.utils import filter_list -from middlewared.utils.origin import TCPIPOrigin from middlewared.validators import Exact, Match, Or, Range from .utils import ( @@ -489,8 +488,7 @@ async def __common_validation(self, verrors, schema, data, mode, parent=None, cu Bool('create_ancestors', default=False), register=True, ), audit='Pool dataset create', audit_extended=lambda data: data['name']) - @pass_app(rest=True) - async def do_create(self, app, data): + async def do_create(self, data): """ Creates a dataset/zvol. @@ -719,15 +717,6 @@ async def do_create(self, app, data): ) or encryption_dict verrors.check() - if app: - uri = None - if isinstance(app.origin, TCPIPOrigin): - uri = app.origin.addr - if uri and uri not in [ - '::1', '127.0.0.1', *[d['address'] for d in await self.middleware.call('interface.ip_in_use')] - ]: - data['managedby'] = uri if data['managedby'] == 'INHERIT' else f'{data["managedby"]}@{uri}' - props = {} for i, real_name, transform, inheritable in ( ('aclinherit', None, str.lower, True), diff --git a/src/middlewared/middlewared/restful.py b/src/middlewared/middlewared/restful.py index 4b8bd5f80f190..d6a8db87e7091 100644 --- a/src/middlewared/middlewared/restful.py +++ b/src/middlewared/middlewared/restful.py @@ -18,8 +18,7 @@ from .pipe import Pipes from .schema import Error as SchemaError from .service_exception import adapt_exception, CallError, MatchNotFound, ValidationError, ValidationErrors -from .utils.nginx import get_remote_addr_port -from .utils.origin import TCPIPOrigin +from .utils.origin import ConnectionOrigin def parse_credentials(request): @@ -72,7 +71,7 @@ def parse_credentials(request): async def authenticate(middleware, request, credentials, method, resource): if credentials['credentials'] == 'TOKEN': - origin = TCPIPOrigin(*await middleware.run_in_thread(get_remote_addr_port, request)) + origin = await middleware.run_in_thread(ConnectionOrigin.create, request) token = await middleware.call('auth.get_token_for_action', credentials['credentials_data']['token'], origin, method, resource) if token is None: @@ -101,13 +100,12 @@ async def authenticate(middleware, request, credentials, method, resource): raise web.HTTPUnauthorized() -def create_application(request, credentials=None): - try: - origin = TCPIPOrigin(request.headers['X-Real-Remote-Addr'], int(request.headers['X-Real-Remote-Port'])) - except (KeyError, ValueError): - origin = TCPIPOrigin(*request.transport.get_extra_info('peername')) +def create_application_impl(request, credentials=None): + return Application(ConnectionOrigin.create(request), credentials) + - return Application(origin, credentials) +async def create_application(request, credentials=None): + return await asyncio.to_thread(create_application_impl, request, credentials) def normalize_query_parameter(value): @@ -550,7 +548,7 @@ async def on_method(req, *args, **kwargs): else: resource = None - app = create_application(req) + app = await create_application(req) auth_required = not self.rest._methods[getattr(self, method)]['no_auth_required'] credentials = parse_credentials(req) if credentials is None: @@ -569,7 +567,7 @@ async def on_method(req, *args, **kwargs): 'error': e.text, }, False) raise - app = create_application(req, authenticated_credentials) + app = await create_application(req, authenticated_credentials) credentials['credentials_data'].pop('password', None) await self.middleware.log_audit_message(app, 'AUTHENTICATION', { 'credentials': credentials, diff --git a/src/middlewared/middlewared/utils/nginx.py b/src/middlewared/middlewared/utils/nginx.py deleted file mode 100644 index 7063428fd6c54..0000000000000 --- a/src/middlewared/middlewared/utils/nginx.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding=utf-8 -*- -import psutil -from psutil._common import addr - - -def get_remote_addr_port(request): - try: - remote_addr, remote_port = request.transport.get_extra_info("peername") - except Exception: - # request can be NoneType or request.transport could be NoneType as well - return "", "" - - if remote_addr in ["127.0.0.1", "::1"]: - try: - x_real_remote_addr = request.headers["X-Real-Remote-Addr"] - x_real_remote_port = int(request.headers["X-Real-Remote-Port"]) - except (KeyError, ValueError): - pass - else: - try: - with open("/var/run/nginx.pid") as f: - nginx_pid = int(f.read().strip()) - except Exception: - pass - else: - try: - process = psutil.Process(nginx_pid) - except psutil.ProcessNotFound: - pass - else: - if process.name() == "nginx": - for worker in process.children(): - for connection in worker.connections(kind="tcp"): - if connection.laddr == addr(remote_addr, remote_port): - remote_addr = x_real_remote_addr - remote_port = x_real_remote_port - - return remote_addr, remote_port diff --git a/src/middlewared/middlewared/utils/origin.py b/src/middlewared/middlewared/utils/origin.py index ee990841ebdc3..dbe039318d0a2 100644 --- a/src/middlewared/middlewared/utils/origin.py +++ b/src/middlewared/middlewared/utils/origin.py @@ -1,43 +1,133 @@ -class Origin: - def match(self, origin): - raise NotImplementedError +from dataclasses import dataclass +from socket import AF_INET, AF_INET6, AF_UNIX, SO_PEERCRED, SOL_SOCKET +from struct import calcsize, unpack - def repr(self): - raise NotImplementedError +from pyroute2 import DiagSocket - def __str__(self): - raise NotImplementedError +__all__ = ('ConnectionOrigin',) +HA_HEARTBEAT_IPS = ('169.254.10.1', '169.254.10.2') +UIDS_TO_CHECK = (33, 0) -class UnixSocketOrigin(Origin): - def __init__(self, pid, uid, gid): - self.pid = pid - self.uid = uid - self.gid = gid - def match(self, origin): - return self.uid == origin.uid and self.gid == origin.gid +@dataclass(slots=True, frozen=True, kw_only=True) +class ConnectionOrigin: + family: AF_INET | AF_INET6 | AF_UNIX + """The address family associated to the API connection""" + loc_addr: str | None = None + """If `family` is not of type AF_UNIX, this represents + the local IP address associated to the TCP/IP connection""" + loc_port: int | None = None + """If `family` is not of type AF_UNIX, this represents + the local port associated to the TCP/IP connection""" + rem_addr: str | None = None + """If `family` is not of type AF_UNIX, this represents + the remote IP address associated to the TCP/IP connection""" + rem_port: int | None = None + """If `family` is not of type AF_UNIX, this represents + the remote port associated to the TCP/IP connection""" + pid: int | None = None + """If `family` is of type AF_UNIX, this represents + the process id associated to the unix datagram connection""" + uid: int | None = None + """If `family` is of type AF_UNIX, this represents + the user id associated to the unix datagram connection""" + gid: int | None = None + """If `family` is of type AF_UNIX, this represents + the group id associated to the unix datagram connection""" - def repr(self): - return f"pid:{self.pid}" + @classmethod + def create(cls, request): + try: + sock = request.transport.get_extra_info("socket") + if sock.family == AF_UNIX: + pid, uid, gid = unpack("3i", sock.getsockopt(SOL_SOCKET, SO_PEERCRED, calcsize("3i"))) + return cls( + family=sock.family, + pid=pid, + uid=uid, + gid=gid + ) + elif sock.family in (AF_INET, AF_INET6): + la, lp, ra, rp = get_tcp_ip_info(sock, request) + return cls( + family=sock.family, + loc_addr=la, + loc_port=lp, + rem_addr=ra, + rem_port=rp, + ) + except AttributeError: + # request.transport can be None by the time this is + # called on HA systems because remote node could + # have been rebooted + return - def __str__(self): - return f"UNIX socket (pid={self.pid} uid={self.uid} gid={self.gid})" + def __str__(self) -> str: + if self.is_unix_family: + return f"UNIX socket (pid={self.pid} uid={self.uid} gid={self.gid})" + elif self.family == AF_INET: + return f"{self.rem_addr}:{self.rem_port}" + elif self.family == AF_INET6: + return f"[{self.rem_addr}]:{self.rem_port}" + def match(self, origin) -> bool: + if self.is_unix_family: + return self.uid == origin.uid and self.gid == origin.gid + else: + return self.rem_addr == origin.rem_addr -class TCPIPOrigin(Origin): - def __init__(self, addr, port): - self.addr = addr - self.port = port + @property + def repr(self) -> str: + return f"pid:{self.pid}" if self.is_unix_family else self.rem_addr - def match(self, origin): - return self.addr == origin.addr + @property + def is_tcp_ip_family(self) -> bool: + return self.family in (AF_INET, AF_INET6) - def repr(self): - return self.addr + @property + def is_unix_family(self) -> bool: + return self.family == AF_UNIX - def __str__(self): - if ":" in self.addr: - return f"[{self.addr}]:{self.port}" - else: - return f"{self.addr}:{self.port}" + @property + def is_ha_connection(self) -> bool: + return ( + self.family in (AF_INET, AF_INET6) and + self.rem_port and self.rem_port <= 1024 and + self.rem_addr and self.rem_addr in HA_HEARTBEAT_IPS + ) + + +def get_tcp_ip_info(sock, request) -> tuple: + # All API connections are terminated by nginx reverse + # proxy so the remote address is always 127.0.0.1. The + # only exceptions to this are: + # 1. Someone connects directly to 127.0.0.1 via a local + # shell session + # 2. Someone connects directly to heartbeat IP port 6000 + # via a local shell session on a TrueNAS HA system + # 3. We connect directly to the other controller on an HA + # machine via heartbeat IP for intra-node communication. + # (this is done by us) + try: + # These headers are set by nginx or a user trying to do + # (potentially) nefarious things. If these are set then + # we need to check if the UID of the socket is owned by + # 0 (root) or 33 (www-data (nginx forks workers)) + ra = request.headers["X-Real-Remote-Addr"] + rp = int(request.headers["X-Real-Remote-Port"]) + check_uids = True + except (KeyError, ValueError): + ra, rp = sock.getpeername() + check_uids = False + + with DiagSocket() as ds: + ds.bind() + for i in ds.get_sock_stats(family=sock.family): + if i['idiag_dst'] == ra and i['idiag_dport'] == rp: + if check_uids: + if i['idiag_uid'] in UIDS_TO_CHECK: + return i['idiag_src'], i['idiag_sport'], i['idiag_dst'], i['idiag_dport'] + else: + return i['idiag_src'], i['idiag_sport'], i['idiag_dst'], i['idiag_dport'] + return (None, None, None, None) diff --git a/src/middlewared/middlewared/utils/rate_limit/cache.py b/src/middlewared/middlewared/utils/rate_limit/cache.py index 0593af1a9fb8c..b04ecc1ec1343 100644 --- a/src/middlewared/middlewared/utils/rate_limit/cache.py +++ b/src/middlewared/middlewared/utils/rate_limit/cache.py @@ -4,8 +4,7 @@ from time import monotonic from typing import TypedDict -from middlewared.auth import is_ha_connection -from middlewared.utils.origin import TCPIPOrigin +from middlewared.utils.origin import ConnectionOrigin __all__ = ['RateLimitCache'] @@ -65,25 +64,23 @@ def rate_limit_exceeded(self, method_name: str, ip: str) -> bool: return False - async def add(self, method_name: str, origin: TCPIPOrigin) -> str | None: + async def add(self, method_name: str, origin: ConnectionOrigin) -> str | None: """Add an entry to the cache. Returns the IP address of origin of the request if it has been cached, returns None otherwise""" - if not isinstance(origin, TCPIPOrigin): - return None - - ip, port = origin.addr, origin.port - if any((ip is None, port is None)) or is_ha_connection(ip, port): - # Short-circuit if: - # 1. if the IP address is None - # 2. OR the port is None - # 3. OR the origin of the request is from our HA P2P heartbeat - # connection + try: + if ( + origin.is_ha_connection or origin.is_unix_family or + origin.rem_addr is None or origin.rem_port is None + ): + return None + else: + key = self.cache_key(method_name, origin.rem_addr) + if key not in RL_CACHE: + RL_CACHE[key] = RateLimitObject(num_times_called=0, last_reset=monotonic()) + return origin.rem_addr + except AttributeError: + # origin is NoneType return None - else: - key = self.cache_key(method_name, ip) - if key not in RL_CACHE: - RL_CACHE[key] = RateLimitObject(num_times_called=0, last_reset=monotonic()) - return ip async def cache_pop(self, method_name: str, ip: str) -> None: """Pop (remove) an entry from the cache."""