Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NAS-130867 / 25.04 / Simplify {TCPIP/UnixSock}Origin to ConnectionOrigin #14372

Merged
merged 19 commits into from
Aug 30, 2024
Merged
4 changes: 2 additions & 2 deletions src/middlewared/middlewared/api/base/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import uuid

from middlewared.auth import SessionManagerCredentials
from middlewared.utils.origin import Origin
from middlewared.utils.origin import ConnectionOrigin

logger = logging.getLogger(__name__)


class App:
def __init__(self, origin: Origin):
def __init__(self, origin: ConnectionOrigin):
self.origin = origin
self.session_id = str(uuid.uuid4())
self.authenticated = False
Expand Down
40 changes: 11 additions & 29 deletions src/middlewared/middlewared/api/base/server/ws_handler/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import socket
import struct

from aiohttp.http_websocket import WSCloseCode
from aiohttp.web import Request, WebSocketResponse

from middlewared.auth import is_ha_connection
from middlewared.utils.nginx import get_remote_addr_port
from middlewared.utils.origin import Origin, UnixSocketOrigin, TCPIPOrigin
from middlewared.utils.origin import ConnectionOrigin
from middlewared.webui_auth import addr_in_allowlist


Expand Down Expand Up @@ -37,36 +32,23 @@ async def __call__(self, request: Request):
await self.process(origin, ws)
return ws

async def get_origin(self, request: Request) -> Origin | None:
try:
sock = request.transport.get_extra_info("socket")
except AttributeError:
# request.transport can be None by the time this is called on HA systems because remote node could have been
# rebooted
return

if sock.family == socket.AF_UNIX:
peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize("3i"))
pid, uid, gid = struct.unpack("3i", peercred)
return UnixSocketOrigin(pid, uid, gid)

remote_addr, remote_port = await self.middleware.run_in_thread(get_remote_addr_port, request)
return TCPIPOrigin(remote_addr, remote_port)
async def get_origin(self, request: Request) -> ConnectionOrigin | None:
return await self.middleware.run_in_thread(ConnectionOrigin.create, request)

async def can_access(self, origin: Origin | None) -> bool:
if not isinstance(origin, TCPIPOrigin):
return True
async def can_access(self, origin: ConnectionOrigin | None) -> bool:
if origin is None:
return False

if not (ui_allowlist := await self.middleware.call("system.general.get_ui_allowlist")):
if origin.is_unix_family or origin.is_ha_connection:
return True

if is_ha_connection(origin.addr, origin.port):
ui_allowlist = await self.middleware.call("system.general.get_ui_allowlist")
if not ui_allowlist:
return True

if addr_in_allowlist(origin.addr, ui_allowlist):
elif addr_in_allowlist(origin.rem_addr, ui_allowlist):
return True

return False

async def process(self, origin: Origin, ws: WebSocketResponse):
async def process(self, origin: ConnectionOrigin, ws: WebSocketResponse):
raise NotImplementedError
8 changes: 4 additions & 4 deletions src/middlewared/middlewared/api/base/server/ws_handler/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from middlewared.utils.debug import get_frame_details
from middlewared.utils.limits import MsgSizeError, MsgSizeLimit, parse_message
from middlewared.utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit
from middlewared.utils.origin import Origin
from middlewared.utils.origin import ConnectionOrigin
from .base import BaseWebSocketHandler
from ..app import App
from ..method import Method
Expand All @@ -45,7 +45,7 @@ class RpcWebSocketAppEvent(enum.Enum):


class RpcWebSocketApp(App):
def __init__(self, middleware: "Middleware", origin: Origin, ws: WebSocketResponse):
def __init__(self, middleware: "Middleware", origin: ConnectionOrigin, ws: WebSocketResponse):
super().__init__(origin)

self.websocket = True
Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(self, middleware: "Middleware", methods: dict[str, Method]):
super().__init__(middleware)
self.methods = methods

async def process(self, origin: Origin, ws: WebSocketResponse):
async def process(self, origin: ConnectionOrigin, ws: WebSocketResponse):
app = RpcWebSocketApp(self.middleware, origin, ws)

self.middleware.register_wsclient(app)
Expand Down Expand Up @@ -222,7 +222,7 @@ async def process(self, origin: Origin, ws: WebSocketResponse):
except MsgSizeError as err:
if err.limit is not MsgSizeLimit.UNAUTHENTICATED:
creds = app.authenticated_credentials.dump() if app.authenticated_credentials else None
origin = app.origin.repr() if app.origin else None
origin = app.origin.repr if app.origin else None

self.middleware.logger.error(
'Client using credentials [%s] at [%s] sent message with payload size [%d bytes] '
Expand Down
4 changes: 0 additions & 4 deletions src/middlewared/middlewared/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ def authorize(self, method, resource):
return True


def is_ha_connection(remote_addr, remote_port):
return remote_port <= 1024 and remote_addr in ('169.254.10.1', '169.254.10.2')


class FakeApplication:
authenticated_credentials = SessionManagerCredentials()

Expand Down
42 changes: 23 additions & 19 deletions src/middlewared/middlewared/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .utils.debug import get_frame_details, get_threads_stacks
from .utils.limits import MsgSizeError, MsgSizeLimit, parse_message
from .utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit
from .utils.origin import Origin, TCPIPOrigin
from .utils.origin import ConnectionOrigin
from .utils.os import close_fds
from .utils.plugins import LoadPluginsMixin
from .utils.privilege import credential_has_full_admin
Expand Down Expand Up @@ -107,20 +107,24 @@ def real_crud_method(method):


class Application(RpcWebSocketApp):
def __init__(self, middleware: 'Middleware', origin: Origin, loop: asyncio.AbstractEventLoop, request, response):
def __init__(
self,
middleware,
origin: ConnectionOrigin,
loop: asyncio.AbstractEventLoop,
request,
response
):
super().__init__(middleware, origin, response)
self.websocket = True

self.loop = loop
self.request = request
self.response = response
self.handshake = False
self.logger = logger.Logger('application').getLogger()

# Allow at most 10 concurrent calls and only queue up until 20
self._softhardsemaphore = SoftHardSemaphore(10, 20)
self._py_exceptions = False

self.__subscribed = {}

def _send(self, data: typing.Dict[str, typing.Any]):
Expand All @@ -129,13 +133,11 @@ def _send(self, data: typing.Dict[str, typing.Any]):

def _tb_error(self, exc_info: ExcInfoType) -> typing.Dict[str, typing.Union[str, list[dict]]]:
klass, exc, trace = exc_info

frames = []
cur_tb = trace
while cur_tb:
tb_frame = cur_tb.tb_frame
cur_tb = cur_tb.tb_next

cur_frame = get_frame_details(tb_frame, self.logger)
if cur_frame:
frames.append(cur_frame)
Expand Down Expand Up @@ -476,7 +478,7 @@ async def upload(self, request):
raise web.HTTPUnauthorized()
except web.HTTPException as e:
return web.Response(status=e.status_code, body=e.text)
app = create_application(request)
app = await create_application(request)
try:
authenticated_credentials = await authenticate(self.middleware, request, credentials, 'CALL',
data['method'])
Expand All @@ -489,7 +491,7 @@ async def upload(self, request):
'error': e.text,
}, False)
return web.Response(status=e.status_code, body=e.text)
app = create_application(request, authenticated_credentials)
app = await create_application(request, authenticated_credentials)
credentials['credentials_data'].pop('password', None)
await self.middleware.log_audit_message(app, 'AUTHENTICATION', {
'credentials': credentials,
Expand Down Expand Up @@ -1126,9 +1128,7 @@ def _console_write(self, text, fill_blank=True, append=False):
blank = ' ' * (maxlen - (len(prefix) + len(text)))
else:
blank = ''
writes = self.__console_io.write(
f'\r{prefix}{text}{blank}{newline}'
)
self.__console_io.write(f'\r{prefix}{text}{blank}{newline}')
self.__console_io.flush()
# be sure and reset error counter after we successfully log
# to the console
Expand Down Expand Up @@ -1312,7 +1312,7 @@ def pipe(self, buffered=False):

def _call_prepare(
self, name, serviceobj, methodobj, params, app=None, audit_callback=None, job_on_progress_cb=None, pipes=None,
in_event_loop: bool=True,
in_event_loop: bool = True,
):
"""
:param in_event_loop: Whether we are in the event loop thread.
Expand Down Expand Up @@ -1552,14 +1552,20 @@ async def log_audit_message_for_method(self, method, methodobj, params, app, aut
}, success)

async def log_audit_message(self, app, event, event_data, success):
remote_addr, origin = "127.0.0.1", None
if app is not None and app.origin is not None:
origin = app.origin.repr
if app.origin.is_tcp_ip_family:
remote_addr = origin

message = "@cee:" + json.dumps({
"TNAUDIT": {
"aid": str(uuid.uuid4()),
"vers": {
"major": 0,
"minor": 1
},
"addr": app.origin.repr() if isinstance(app.origin, TCPIPOrigin) else "127.0.0.1",
"addr": remote_addr,
"user": audit_username_from_session(app.authenticated_credentials),
"sess": app.session_id,
"time": utc_now().strftime('%Y-%m-%d %H:%M:%S.%f'),
Expand All @@ -1569,7 +1575,7 @@ async def log_audit_message(self, app, event, event_data, success):
"major": 0,
"minor": 1,
},
"origin": app.origin.repr() if app.origin else None,
"origin": origin,
"protocol": "WEBSOCKET" if app.websocket else "REST",
"credentials": {
"credentials": app.authenticated_credentials.class_name(),
Expand Down Expand Up @@ -1874,15 +1880,13 @@ async def ws_handler(self, request):
)
break

datalen = len(msg.data)

try:
message = parse_message(connection.authenticated, msg.data)
except MsgSizeError as err:
if err.limit is not MsgSizeLimit.UNAUTHENTICATED:
origin = connection.origin.repr() if connection.origin else None
origin = connection.origin.repr if connection.origin else None
if connection.authenticated_credentials:
creds = connection.authenticated_credentials.dump()
creds = connection.authenticated_credentials.dump()
else:
creds = None

Expand Down
15 changes: 9 additions & 6 deletions src/middlewared/middlewared/plugins/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)
from middlewared.service_exception import MatchNotFound
import middlewared.sqlalchemy as sa
from middlewared.utils.origin import UnixSocketOrigin
from middlewared.utils.crypto import generate_token
from middlewared.utils.time_utils import utc_now

Expand Down Expand Up @@ -160,9 +159,14 @@ def dump(self):
}


def is_internal_session(session):
if isinstance(session.app.origin, UnixSocketOrigin) and session.app.origin.uid == 0:
return True
def is_internal_session(session) -> bool:
try:
is_root_sock = session.app.origin.is_unix_family and session.app.origin.uid == 0
if is_root_sock:
return True
except AttributeError:
# session.app.origin can be NoneType
pass

if isinstance(session.app.authenticated_credentials, TrueNasNodeSessionManagerCredentials):
return True
Expand Down Expand Up @@ -601,7 +605,7 @@ async def check_permission(middleware, app):
if origin is None:
return

if isinstance(origin, UnixSocketOrigin):
if origin.is_unix_family:
if origin.uid == 0:
user = await middleware.call('auth.authenticate_root')
else:
Expand All @@ -622,7 +626,6 @@ async def check_permission(middleware, app):
return

await AuthService.session_manager.login(app, UnixSocketSessionManagerCredentials(user))
return


def setup(middleware):
Expand Down
14 changes: 5 additions & 9 deletions src/middlewared/middlewared/plugins/failover.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import time
from functools import partial

from middlewared.auth import is_ha_connection, TrueNasNodeSessionManagerCredentials
from middlewared.auth import TrueNasNodeSessionManagerCredentials
from middlewared.schema import accepts, Bool, Dict, Int, List, NOT_PROVIDED, Str, returns, Patch
from middlewared.service import (
job, no_auth_required, no_authz_required, pass_app, private, CallError, ConfigService,
Expand All @@ -27,7 +27,6 @@
from middlewared.plugins.update_.utils import DOWNLOAD_UPDATE_FILE, can_update
from middlewared.plugins.update_.utils_linux import mount_update
from middlewared.utils.contextlib import asyncnullcontext
from middlewared.utils.origin import TCPIPOrigin

ENCRYPTION_CACHE_LOCK = asyncio.Lock()

Expand Down Expand Up @@ -1119,14 +1118,11 @@ async def sync_keys_to_remote_node(self, lock=True):


async def ha_permission(middleware, app):
# Skip if session was already authenticated
if app is not None and app.authenticated is True:
return

# We only care for remote connections (IPv4), in the interlink
if isinstance(app.origin, TCPIPOrigin):
if is_ha_connection(app.origin.addr, app.origin.port):
try:
if not app.authenticated and app.origin.is_ha_connection:
await AuthService.session_manager.login(app, TrueNasNodeSessionManagerCredentials())
except AttributeError:
pass


async def interface_pre_sync_hook(middleware):
Expand Down
13 changes: 3 additions & 10 deletions src/middlewared/middlewared/plugins/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from middlewared.service import CallError, CRUDService, filterable, pass_app, private
from middlewared.utils import filter_list
from middlewared.schema import accepts, Bool, Dict, Int, IPAddr, List, Patch, returns, Str, ValidationErrors
from middlewared.utils.network_.procfs import read_proc_net
from middlewared.utils.origin import TCPIPOrigin
from middlewared.validators import Range
from .interface.netif import netif
from .interface.interface_types import InterfaceType
Expand Down Expand Up @@ -1453,15 +1451,10 @@ async def delete_network_interface(self, oid):
@pass_app()
async def websocket_local_ip(self, app):
"""Returns the local ip address for this websocket session."""
if app is None or isinstance(app.origin, TCPIPOrigin) is False:
return

try:
if info := await self.middleware.run_in_thread(read_proc_net, None, app.origin.port):
return info.local_ip
except Exception:
self.logger.error("Unexpected failure determining local websocket ip", exc_info=True)
return
return app.origin.loc_addr
except AttributeError:
pass

@accepts()
@returns(Str(null=True))
Expand Down
Loading
Loading