Skip to content

Commit

Permalink
Respond to PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
khk-globus committed Nov 21, 2024
1 parent 3a45057 commit 3f83405
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import sys
import threading
import time
import types
import typing as t
import uuid
from concurrent.futures import Future
Expand Down Expand Up @@ -75,11 +76,11 @@ def _import_pyprctl():
return pyprctl


def _import_pamhandle():
def _import_pam() -> types.ModuleType:
# Enable conditional import, and create a hook-point for testing to mock
from globus_compute_endpoint.pam import PamHandle
from globus_compute_endpoint import pam

return PamHandle
return pam


class UserEndpointRecord(BaseModel):
Expand Down Expand Up @@ -117,6 +118,14 @@ def __init__(

self.conf_dir = conf_dir
self._config = config

# UX - test conditional imports *now*, rather than when a request comes in;
# this gives immediate feedback to an implementing admin if something is awry
if config.pam.enable:
_import_pam()
else:
_import_pyprctl()

self._reload_requested = False
self._time_to_stop = False
self._kill_event = threading.Event()
Expand Down Expand Up @@ -775,7 +784,7 @@ def send_failure_notice(
sys.exit()

@contextmanager
def do_host_auth(self, username):
def do_host_auth(self, username: str):
if not self._config.pam.enable:
pyprctl = _import_pyprctl()
yield
Expand All @@ -795,9 +804,9 @@ def do_host_auth(self, username):
sname = self._config.pam.service_name
log.debug("PAM: Creating handle (%s, %s)", sname, username)
try:
PamHandle = _import_pamhandle()
pam = _import_pam()

with PamHandle(sname, username=username) as pamh:
with pam.PamHandle(sname, username=username) as pamh:
log.debug("PAM: Invoking account stage")
pamh.pam_acct_mgmt()
log.debug("PAM: Creating credentials")
Expand All @@ -816,12 +825,19 @@ def do_host_auth(self, username):
pamh.credentials_delete()

log.debug("PAM: Closing handle")
except Exception as e:
log.error(str(e)) # Share (very likely) pamlib error with admin ...

except pam.PamError as e:
log.error(str(e)) # Share pamlib error with admin ...

# ... but be opaque with user.
raise PermissionError("see your system administrator") from None

except Exception:
log.exception(f"Unhandled error during PAM session for {username}")

# Regardless, be opaque with user.
raise PermissionError("see your system administrator") from None

def cmd_start_endpoint(
self,
user_record: pwd.struct_passwd,
Expand Down
149 changes: 94 additions & 55 deletions compute_endpoint/tests/unit/test_endpointmanager_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
ResultPublisher,
)
from globus_compute_endpoint.endpoint.utils import _redact_url_creds
from globus_compute_endpoint.pam import PamHandle
from globus_sdk import GlobusAPIError, NetworkError

try:
Expand Down Expand Up @@ -69,6 +68,11 @@
)


class MockPamError(Exception):
def __init__(self, *a, **k):
pass


def mock_ensure_compute_dir():
return pathlib.Path(_mock_localuser_rec.pw_dir) / ".globus_compute"

Expand Down Expand Up @@ -333,6 +337,38 @@ def create_response(
return create_response


@pytest.fixture
def mock_pamh():
try:
# attempt to play nice with systems that do not have PAM installed, and
# rely on those that do to test with spec=PamHandle
from globus_compute_endpoint.pam import PamHandle

m = mock.MagicMock(spec=PamHandle)
except ImportError:
m = mock.MagicMock()

m.return_value = m
m.__enter__.return_value = m
yield m


@pytest.fixture
def mock_pam(mock_pamh):
with mock.patch(f"{_MOCK_BASE}_import_pam") as m:
m.return_value = m
m.PamHandle = mock_pamh
m.PamError = MockPamError
yield m


@pytest.fixture
def mock_ctl():
with mock.patch(f"{_MOCK_BASE}_import_pyprctl") as m:
m.return_value = m
yield m


@pytest.mark.parametrize("env", [None, "blar", "local", "production"])
def test_sets_process_title(
randomstring, conf_dir, mock_conf, mock_client, mock_setproctitle, env
Expand Down Expand Up @@ -2058,24 +2094,35 @@ def test_port_is_respected(mocker, mock_client, mock_conf, conf_dir, port):
assert mock_update_url_port.call_args[0][1] == port


def test_pam_disabled(conf_dir, mock_conf, mock_ep_uuid, mock_reg_info):
em = EndpointManager(conf_dir, mock_ep_uuid, mock_conf, mock_reg_info)
@pytest.mark.parametrize(
"fn_name,pam_enable",
(
("_import_pam", True),
("_import_pyprctl", False),
),
)
def test_conditional_imports_verified_at_init_for_ux(
conf_dir, mock_conf, ep_uuid, mock_reg_info, mock_ctl, fn_name, pam_enable
):
mock_conf.pam.enable = pam_enable
with mock.patch(f"{_MOCK_BASE}{fn_name}") as m:
m.side_effect = MemoryError("test induced")
with pytest.raises(MemoryError):
EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info)


def test_pam_disabled(conf_dir, mock_conf, ep_uuid, mock_reg_info, mock_ctl, mock_pam):
em = EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info)

mock_conf.pam.enable = False
with mock.patch(f"{_MOCK_BASE}_import_pyprctl") as mock_ctl:
mock_ctl.return_value = mock_ctl
with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam:
mock_pam.return_value = mock_pam
with em.do_host_auth("some user name"):
pass
with em.do_host_auth("some user name"):
pass
assert not mock_pam.called, "PAM was disable; should *not* attempt PAM"
assert mock_ctl.CapState.called, "No PAM? No privileges."
assert mock_ctl.set_no_new_privs.called, "No PAM? No privileges."


def test_pam_enabled(conf_dir, mock_conf, mock_ep_uuid, mock_reg_info):
em = EndpointManager(conf_dir, mock_ep_uuid, mock_conf, mock_reg_info)

def test_pam_enabled(conf_dir, mock_conf, ep_uuid, mock_reg_info, mock_ctl, mock_pam):
def install_next_pamf():
# ensure PAM functions called in appropriate order
fns = [ # reversed because we pop() to get each fn
Expand All @@ -2094,21 +2141,18 @@ def _install_next_test_func():
return _install_next_test_func

mock_conf.pam.enable = True
pamh = mock.Mock(spec=PamHandle)
with mock.patch(f"{_MOCK_BASE}_import_pyprctl") as mock_ctl:
mock_ctl.return_value = mock_ctl
with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam:
mock_pam.return_value = mock_pam
mock_pam.return_value.__enter__.return_value = pamh
pamh.pam_acct_mgmt.side_effect = install_next_pamf()
pamh.credentials_establish.side_effect = AssertionError("Out of order")
pamh.pam_open_session.side_effect = AssertionError("Out of order")
pamh.pam_close_session.side_effect = AssertionError("Out of order")
pamh.credentials_delete.side_effect = AssertionError("Out of order")
with em.do_host_auth("some user name"):
assert pamh.pam_open_session.called, "Complete authentication"
assert not pamh.credentials_delete.called, "PAM session *not* over yet"
assert pamh.credentials_delete.called, "PAM session completes"
pamh = mock_pam.PamHandle
pamh.pam_acct_mgmt.side_effect = install_next_pamf()
pamh.credentials_establish.side_effect = AssertionError("Out of order")
pamh.pam_open_session.side_effect = AssertionError("Out of order")
pamh.pam_close_session.side_effect = AssertionError("Out of order")
pamh.credentials_delete.side_effect = AssertionError("Out of order")

em = EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info)
with em.do_host_auth("some user name"):
assert pamh.pam_open_session.called, "Complete authentication"
assert not pamh.credentials_delete.called, "PAM session *not* over yet"
assert pamh.credentials_delete.called, "PAM session completes"

assert not mock_ctl.CapState.called, "Using PAM; admin manages privs"
assert not mock_ctl.set_no_new_privs.called, "Using PAM; admin manages privs"
Expand All @@ -2124,35 +2168,33 @@ def _install_next_test_func():
"credentials_delete",
),
)
@pytest.mark.parametrize("exc", (MockPamError("test err"), MemoryError("test err")))
def test_pam_error(
mock_log, conf_dir, mock_conf, mock_ep_uuid, mock_reg_info, fn_name, randomstring
mock_log, conf_dir, mock_conf, ep_uuid, mock_reg_info, fn_name, mock_pam, exc
):
em = EndpointManager(conf_dir, mock_ep_uuid, mock_conf, mock_reg_info)

exc_text = randomstring()
exc = MemoryError(exc_text)
em = EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info)

mock_conf.pam.enable = True
pamh = mock.Mock(spec=PamHandle)
with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam:
mock_pam.return_value = mock_pam
mock_pam.return_value.__enter__.return_value = pamh
getattr(pamh, fn_name).side_effect = exc
with pytest.raises(PermissionError) as pyt_e:
with em.do_host_auth("some user name"):
pass
pamh = mock_pam.PamHandle
username = "some username"
getattr(pamh, fn_name).side_effect = exc
with pytest.raises(PermissionError) as pyt_e:
with em.do_host_auth(username):
pass

e_str = str(pyt_e.value)
assert "PAM" not in e_str, "User-visible exception should be opaque"
assert "see your system administrator" in e_str
assert "see your system administrator" in e_str, "User-visible should have action"

a, _k = mock_log.error.call_args
if not isinstance(exc, MockPamError):
assert mock_log.exception.called, "Admin log should contain entire exception"
a, _k = mock_log.exception.call_args

assert exc_text in a[0], "Admin logs should specific error msg"
assert username in a[0], "Admin log should contain related username"


def test_do_auth_change_uid_then_close(
mock_conf_root, successful_exec_from_mocked_root
mock_conf_root, successful_exec_from_mocked_root, mock_pam
):
mock_os, *_, em = successful_exec_from_mocked_root

Expand Down Expand Up @@ -2181,20 +2223,17 @@ def _called(fn_name):
return _called

mock_conf_root.pam.enable = True
pamh = mock.Mock(spec=PamHandle)
fn_opener = set_called()

with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam:
mock_pam.return_value = mock_pam
mock_pam.return_value.__enter__.return_value = pamh
pamh.pam_open_session.side_effect = this_func(fn_opener, "pam_open_session")
pamh.pam_close_session.side_effect = AssertionError("Out of order")
mock_os.setresuid.side_effect = AssertionError("Out of order")
mock_os.setresgid.side_effect = AssertionError("Out of order")
mock_os.initgroups.side_effect = AssertionError("Out of order")
pamh = mock_pam.PamHandle
pamh.pam_open_session.side_effect = this_func(fn_opener, "pam_open_session")
pamh.pam_close_session.side_effect = AssertionError("Out of order")
mock_os.setresuid.side_effect = AssertionError("Out of order")
mock_os.setresgid.side_effect = AssertionError("Out of order")
mock_os.initgroups.side_effect = AssertionError("Out of order")

with pytest.raises(SystemExit) as pyexc:
em._event_loop()
with pytest.raises(SystemExit) as pyexc:
em._event_loop()

assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'"
assert pamh.pam_close_session.called

0 comments on commit 3f83405

Please sign in to comment.