From 90ad0022d44cea7eb52cbb8dbe5f3e49d1760590 Mon Sep 17 00:00:00 2001 From: Andrew Walker Date: Tue, 27 Aug 2024 13:50:59 -0600 Subject: [PATCH] Add utilities for manipulating pam_tdb data This commit adds required middleware utils for writing pam_tdb contents for user-linked API keys along with a local unit test framework to be used from a jenkins pipeline. --- .../pytest/unit/utils/test_groupmap.py | 5 + src/middlewared/middlewared/utils/crypto.py | 15 + src/middlewared/middlewared/utils/tdb.py | 6 + .../middlewared/utils/user_api_key.py | 113 +++++++ tests/requirements.txt | 1 + tests/run_unit_tests.py | 162 +++++++++ tests/unit/test_pam_tdb.py | 315 ++++++++++++++++++ 7 files changed, 617 insertions(+) create mode 100644 src/middlewared/middlewared/utils/user_api_key.py create mode 100644 tests/run_unit_tests.py create mode 100644 tests/unit/test_pam_tdb.py diff --git a/src/middlewared/middlewared/pytest/unit/utils/test_groupmap.py b/src/middlewared/middlewared/pytest/unit/utils/test_groupmap.py index 9bc6ea1c94b47..0d8a144ecda5f 100644 --- a/src/middlewared/middlewared/pytest/unit/utils/test_groupmap.py +++ b/src/middlewared/middlewared/pytest/unit/utils/test_groupmap.py @@ -22,6 +22,11 @@ @pytest.fixture(scope='module') def groupmap_dir(): os.makedirs('/var/db/system/samba4', exist_ok=True) + try: + # pre-emptively delete in case we're running on a TrueNAS VM + os.unlink('/var/db/system/samba4/group_mapping.tdb') + except FileNotFoundError: + pass @pytest.fixture(scope='module') diff --git a/src/middlewared/middlewared/utils/crypto.py b/src/middlewared/middlewared/utils/crypto.py index 25880fad0ce34..5f5d826f29991 100644 --- a/src/middlewared/middlewared/utils/crypto.py +++ b/src/middlewared/middlewared/utils/crypto.py @@ -1,3 +1,5 @@ +from base64 import b64encode +from hashlib import pbkdf2_hmac from secrets import choice, compare_digest, token_urlsafe, token_hex from string import ascii_letters, digits, punctuation @@ -63,3 +65,16 @@ def generate_nt_hash(passwd): """ md4_hash_bytes = md4_hash_blob(passwd.encode('utf-16le')) return md4_hash_bytes.hex().upper() + + +def generate_pbkdf2_512(passwd): + """ + Generate a pbkdf2_sha512 hash for password. This is used for + verification of API keys. + """ + prefix = 'pbkdf2-sha512' + rounds = 500000 + salt_length = 16 + salt = generate_string(string_size=salt_length, extra_chars='./').encode() + hash = pbkdf2_hmac('sha512', passwd.encode(), salt, rounds) + return f'${prefix}${rounds}${b64encode(salt).decode()}${b64encode(hash).decode()}' diff --git a/src/middlewared/middlewared/utils/tdb.py b/src/middlewared/middlewared/utils/tdb.py index d67ef99317006..974990d8b6169 100644 --- a/src/middlewared/middlewared/utils/tdb.py +++ b/src/middlewared/middlewared/utils/tdb.py @@ -100,6 +100,12 @@ class TDBHandle: opath_fd = FD_CLOSED keys_null_terminated = False + def __enter__(self): + return self + + def __exit__(self, tp, val, traceback): + self.close() + def close(self): """ Close the TDB handle and O_PATH open for the file """ if self.opath_fd == FD_CLOSED and self.hdl is None: diff --git a/src/middlewared/middlewared/utils/user_api_key.py b/src/middlewared/middlewared/utils/user_api_key.py new file mode 100644 index 0000000000000..40eb94207f2dd --- /dev/null +++ b/src/middlewared/middlewared/utils/user_api_key.py @@ -0,0 +1,113 @@ +import os + +from base64 import b64encode +from dataclasses import dataclass +from struct import pack +from uuid import uuid4 +from .tdb import ( + TDBDataType, + TDBHandle, + TDBOptions, + TDBPathType, +) + + +PAM_TDB_DIR = '/var/run/pam_tdb' +PAM_TDB_FILE = os.path.join(PAM_TDB_DIR, 'pam_tdb.tdb') +PAM_TDB_DIR_MODE = 0o700 +PAM_TDB_VERSION = 1 +PAM_TDB_MAX_KEYS = 10 # Max number of keys per user. Also defined in pam_tdb.c + +PAM_TDB_OPTIONS = TDBOptions(TDBPathType.CUSTOM, TDBDataType.BYTES) + + +@dataclass(frozen=True) +class UserApiKey: + expiry: int + dbid: int + userhash: str + + +@dataclass(frozen=True) +class PamTdbEntry: + keys: list[UserApiKey] + username: str + + +def _setup_pam_tdb_dir() -> None: + os.makedirs(PAM_TDB_DIR, mode=PAM_TDB_DIR_MODE, exist_ok=True) + os.chmod(PAM_TDB_DIR, PAM_TDB_DIR_MODE) + + +def _pack_user_api_key(api_key: UserApiKey) -> bytes: + """ + Convert UserApiKey object to bytes for TDB insertion. + This is packed struct with expiry converted into signed 64 bit + integer, the database id (32-bit unsigned), and the userhash (pascal string) + """ + if not isinstance(api_key, UserApiKey): + raise TypeError(f'{type(api_key)}: not a UserApiKey') + + userhash = api_key.userhash.encode() + b'\x00' + return pack(f' None: + """ + Convert PamTdbEntry object into a packed struct and insert + into tdb file. + + key: username + value: uint32_t (version) + uint32_t (cnt of keys) + """ + if not isinstance(entry, PamTdbEntry): + raise TypeError(f'{type(entry)}: expected PamTdbEntry') + + key_cnt = len(entry.keys) + if key_cnt > PAM_TDB_MAX_KEYS: + raise ValueError(f'{key_cnt}: count of entries exceeds maximum') + + entry_bytes = pack(' None: + """ + Write a PamTdbEntry object to the pam_tdb file for user + authentication. This method first writes to temporary file + and then renames over pam_tdb file to ensure flush is atomic + and reduce risk of lock contention while under a transaction + lock. + + raises: + TypeError - not PamTdbEntry + AssertionError - count of entries changed while generating + tdb payload + RuntimeError - TDB library error + """ + _setup_pam_tdb_dir() + + if not isinstance(pam_entries, list): + raise TypeError('Expected list of PamTdbEntry objects') + + tmp_path = os.path.join(PAM_TDB_DIR, f'tmp_{uuid4()}.tdb') + + with TDBHandle(tmp_path, PAM_TDB_OPTIONS) as hdl: + hdl.keys_null_terminated = False + + try: + for entry in pam_entries: + write_entry(hdl, entry) + except Exception: + os.remove(tmp_path) + raise + + os.rename(tmp_path, PAM_TDB_FILE) diff --git a/tests/requirements.txt b/tests/requirements.txt index ed4bfea7700a6..df6fa412bc82a 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,5 +1,6 @@ boto3 dnspython +junitparser pytest pytest-dependency pytest-rerunfailures diff --git a/tests/run_unit_tests.py b/tests/run_unit_tests.py new file mode 100644 index 0000000000000..ef46186928387 --- /dev/null +++ b/tests/run_unit_tests.py @@ -0,0 +1,162 @@ +# This script should be run locally from a TrueNAS VM. It runs all tests +# contained within the tests/unit directory as well as middleware specific unit +# tests contained within src/middlewared/middlewared/pytest/unit. +# +# NOTE: this requires `make install_tests` to have been run on the TrueNAS VM. + +import argparse +import middlewared +import os +import pytest +import sys + +from contextlib import contextmanager +from collections.abc import Generator +from dataclasses import dataclass +from junitparser import JUnitXml +from shutil import copytree, rmtree +from truenas_api_client import Client +from uuid import uuid4 + +DESCRIPTION = ( + 'Run unit tests from the specified middleware git repository on the ' + 'current TrueNAS server (version 25.04 or later). Exit code is one of ' + 'pytest exit codes with zero indicating success.' +) + +UNIT_TESTS = 'tests/unit' +MIDDLEWARE_MODULE_PATH = os.path.dirname(os.path.abspath(middlewared.__file__)) +MIDDLEWARE_PYTEST = 'src/middlewared/middlewared/pytest' +MIDDLEWARE_UNIT_TESTS = os.path.join(MIDDLEWARE_PYTEST, 'unit') +MIDDLEWARE_PYTEST_MODULE = os.path.join(MIDDLEWARE_MODULE_PATH, 'pytest') +RESULT_FILE = 'unit_tests_result.xml' +PYTEST_CONFTEST_FILE = 'tests/conftest.py' + + +@dataclass() +class UnitTestRun: + tests_dir: str + exit_code: pytest.ExitCode = pytest.ExitCode.NO_TESTS_COLLECTED + junit_file: str | None = None + + +def run_tests(data: UnitTestRun) -> UnitTestRun: + junit_file = f'unit_tests_result_{uuid4()}.xml' + + data.exit_code = pytest.main([ + '--disable-warnings', '-vv', + '-o', 'junit_family=xunit2', + '--junitxml', junit_file, + data.tests_dir + ]) + + if data.exit_code is not pytest.ExitCode.OK: + print( + f'{data.tests_dir}: tests failed with code: {data.exit_code}', + file=sys.stderr + ) + + data.junit_file = junit_file + return data + + +def run_unit_tests(repo_dir: str) -> pytest.ExitCode: + """ + Iterate through our unit test sources and create a unified junit xml file + for the overall test results. + """ + xml_out = JUnitXml() + exit_code = pytest.ExitCode.NO_TESTS_COLLECTED + for test_dir in ( + os.path.join(repo_dir, UNIT_TESTS), + os.path.join(repo_dir, MIDDLEWARE_UNIT_TESTS), + ): + if not os.path.exists(test_dir): + raise FileNotFoundError(f'{test_dir}: unit test directory does not exist') + + data = run_tests(UnitTestRun(tests_dir=test_dir)) + xml_out += JUnitXml.fromfile(data.junit_file) + try: + os.remove(data.junit_file) + except Exception: + pass + + match data.exit_code: + case pytest.ExitCode.NO_TESTS_COLLECTED: + # We'll treat this as a partial failure because we still want our + # test results from other runs, but don't want an overall misleading + # result. + print( + f'{test_dir}: not tests collected. Treating as partial failure.', + file=sys.stderr + ) + if exit_code is pytest.ExitCode.OK: + exit_code = pytest.ExitCode.TESTS_FAILED + + case pytest.ExitCode.OK: + # If this is our first OK test, set exit code + # otherwise preserve existing + if exit_code is pytest.ExitCode.NO_TESTS_COLLECTED: + exit_code = data.exit_code + + case _: + # exit codes are an IntEnum. Preserve worst case + if exit_code < data.exit_code: + exit_code = data.exit_code + + xml_out.write(RESULT_FILE) + return exit_code + + +@contextmanager +def disable_api_test_config(path: str) -> Generator[None, None, None]: + """ prevent API tests conftest from being applied """ + os.rename( + os.path.join(path, PYTEST_CONFTEST_FILE), + os.path.join(path, f'{PYTEST_CONFTEST_FILE}.tmp') + ) + + try: + yield + finally: + os.rename( + os.path.join(path, f'{PYTEST_CONFTEST_FILE}.tmp'), + os.path.join(path, PYTEST_CONFTEST_FILE) + ) + + +@contextmanager +def setup_middleware_tests(path: str) -> Generator[None, None, None]: + """ temporarily setup our pytest tests in the python dir """ + try: + copytree( + os.path.join(path, MIDDLEWARE_PYTEST), + os.path.join(MIDDLEWARE_PYTEST_MODULE) + ) + yield + finally: + rmtree(MIDDLEWARE_PYTEST_MODULE) + + +def main() -> None: + parser = argparse.ArgumentParser(description=DESCRIPTION) + parser.add_argument( + '-p', '--path', + help='Path to local copy of middleware git repository', + default='./middleware' + ) + + # lazy check to verify we're on a TrueNAS server + with Client() as c: + assert c.call('system.ready') + + args = parser.parse_args() + with disable_api_test_config(args.path): + with setup_middleware_tests(args.path): + exit_code = run_unit_tests(args.path) + + sys.exit(exit_code) + + +if __name__ == '__main__': + main() diff --git a/tests/unit/test_pam_tdb.py b/tests/unit/test_pam_tdb.py new file mode 100644 index 0000000000000..ac58233ba67e5 --- /dev/null +++ b/tests/unit/test_pam_tdb.py @@ -0,0 +1,315 @@ +import os +import pam +import pwd +import pytest +import tdb + +from collections.abc import Generator +from contextlib import contextmanager +from middlewared.utils import crypto +from middlewared.utils import user_api_key +from time import monotonic + +EXPIRED_TS = 1 +BASE_ID = 1325 +LEGACY_ENTRY_KEY = 'rtpz6u16l42XJJGy5KMJOVfkiQH7CyitaoplXy7TqFTmY7zHqaPXuA1ob07B9bcB' +LEGACY_ENTRY_HASH = '$pbkdf2-sha256$29000$CyGktHYOwXgvBYDQOqc05g$nK1MMvVuPGHMvUENyR01qNsaZjgGmlt3k08CRuC4aTI' +INVALID_HASH_TYPE = '$pbkdf2-canary$29000$CyGktHYOwXgvBYDQOqc05g$nK1MMvVuPGHMvUENyR01qNsaZjgGmlt3k08CRuC4aTI' +INVALID_SALT = '$pbkdf2-sha256$29000$CyGktHYOwXgvBYDQOqc0*g$nK1MMvVuPGHMvUENyR01qNsaZjgGmlt3k08CRuC4aTI' +INVALID_HASH = '$pbkdf2-sha256$29000$CyGktHYOwXgvBYDQOqc05g$nK1MMvVuPGHMvUENyR01qNsaZjgGmlt3k08CRuC4a*I' +MISSING_SALT = '$pbkdf2-sha256$29000$$nK1MMvVuPGHMvUENyR01qNsaZjgGmlt3k08CRuC4aTI' +MISSING_HASH = '$pbkdf2-sha256$29000$CyGktHYOwXgvBYDQOqc05g$' +PAM_DIR = '/etc/pam.d' +PAM_FILE = 'middleware-api-key' +PAM_AUTH_LINE = 'auth [success=1 default=die] pam_tdb.so debug ' +PAM_FAIL_DELAY = 2 # This is minimum that pam_tdb will delay failed auth attempts + +PAM_FILE_REMAINING_CONTENTS = """ +@include common-auth-unix +@include common-account +password required pam_deny.so +session required pam_deny.so +""" + + +def write_tdb_file( + username: str, + hashlist: list[str], + expired: bool = False +) -> int: + """ + Generate a tdb file based on the specified parameters + The resulting TDB will have one entry for `username` and + a varying amount of hashes. + + Although each hash supports a separate expiry, we are only + concerned in these tests in that works overall. + """ + + keys = [] + idx = 0 + + for idx, thehash in enumerate(hashlist): + keys.append(user_api_key.UserApiKey( + userhash=thehash, + dbid=BASE_ID + idx, + expiry=EXPIRED_TS if expired else 0 + )) + + entry = user_api_key.PamTdbEntry(username=username, keys=keys) + + user_api_key.flush_user_api_keys([entry]) + + return BASE_ID + idx + + +def truncate_tdb_file(username: str) -> None: + """ + Truncate tdb entry to make pascal string point off end of buffer + If this sets PAM_AUTH_ERR then we need to look closely to make + sure we don't have parser issues in pam_tdb.c + """ + hdl = tdb.open(user_api_key.PAM_TDB_FILE) + try: + hdl.get(username.encode()) + hdl.store(username.encode(), data[0:len(data) - 5]) + finally: + hdl.close() + + +def make_tdb_garbage(username: str) -> None: + """ fill entry with non-api-key data """ + hdl = tdb.open(user_api_key.PAM_TDB_FILE) + try: + hdl.get(username.encode()) + hdl.store(username.encode(), b'meow') + finally: + hdl.close() + + +def make_null_tdb_entry(username: str) -> None: + """ throw some nulls into the mix for fun """ + hdl = tdb.open(user_api_key.PAM_TDB_FILE) + try: + hdl.get(username.encode()) + hdl.store(username.encode(), b'\x00' * 128) + finally: + hdl.close() + + +@contextmanager +def pam_service( + file_name: str = PAM_FILE, + admin_user: str | None = None, +) -> Generator[str, None, None]: + """ Create a pam service file for pam_tdb.so """ + auth_entry = PAM_AUTH_LINE + if admin_user: + auth_entry += f'truenas_admin={admin_user}' + + pam_service_path = os.path.join(PAM_DIR, file_name) + + with open(pam_service_path, 'w') as f: + f.write(auth_entry) + f.write(PAM_FILE_REMAINING_CONTENTS) + f.flush() + + try: + yield file_name + finally: + os.remove(pam_service_path) + + +@contextmanager +def fail_delay() -> Generator[None, None, None]: + """ assert if failure case finishes faster than our expected fail delay """ + now = monotonic() + yield + elapsed = monotonic() - now + assert elapsed > PAM_FAIL_DELAY + + +@pytest.fixture(scope='module') +def current_username(): + """ for simplicity sake we'll test against current user """ + return pwd.getpwuid(os.geteuid()).pw_name + + +def test_unknown_user(current_username): + """ + A user without an entry in the file should fail with appropriate error + and generate a fail delay + """ + db_id = write_tdb_file(current_username, [LEGACY_ENTRY_HASH]) + with pam_service(admin_user=current_username) as svc: + p = pam.pam() + with fail_delay(): + authd = p.authenticate('canary', f'{db_id}-{LEGACY_ENTRY_KEY}', service=svc) + assert authd is False + assert p.code == pam.PAM_USER_UNKNOWN + + +def test_legacy_auth_admin(current_username): + """ This should succeed for specified admin user """ + db_id = write_tdb_file(current_username, [LEGACY_ENTRY_HASH]) + with pam_service(admin_user=current_username) as svc: + p = pam.pam() + authd = p.authenticate(current_username, f'{db_id}-{LEGACY_ENTRY_KEY}', service=svc) + assert authd is True + assert p.code == pam.PAM_SUCCESS + + with fail_delay(): + # attempt to authenticate with invalid key should trigger a fail delay + authd = p.authenticate(current_username, f'{db_id}-{LEGACY_ENTRY_KEY[0:-1]}', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + + +def test_legacy_auth_admin_expired_key(current_username): + """ Verify that an expired key results in PAM_AUTH_ERR """ + db_id = write_tdb_file(current_username, [LEGACY_ENTRY_HASH], True) + with pam_service(admin_user=current_username) as svc: + p = pam.pam() + authd = p.authenticate(current_username, f'{db_id}-{LEGACY_ENTRY_KEY}', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + + +def test_legacy_auth_non_admin(current_username): + """ Test that legacy hash doesn't work for non-admin user + We really want to deprecate these legacy keys. + """ + write_tdb_file(current_username, [LEGACY_ENTRY_HASH]) + with pam_service() as svc: + with fail_delay(): + p = pam.pam() + authd = p.authenticate(current_username, LEGACY_ENTRY_KEY, service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + + +def test_legacy_auth_multiple_entries(current_username): + """ verify last entry in hash list can be used to auth + We allow multiple keys per user. Ensure that we can use more than the + first key. + """ + hashes = [crypto.generate_pbkdf2_512('canary') for i in range(0, 5)] + hashes.append(LEGACY_ENTRY_HASH) + + db_id = write_tdb_file(current_username, hashes) + with pam_service(admin_user=current_username) as svc: + p = pam.pam() + authd = p.authenticate(current_username, f'{db_id}-{LEGACY_ENTRY_KEY}', service=svc) + assert authd is True + assert p.code == pam.PAM_SUCCESS + + +def test_new_auth(current_username): + """ verify that that new hash works as expected """ + key = crypto.generate_string(string_size=64) + db_id = write_tdb_file(current_username, [crypto.generate_pbkdf2_512(key)]) + + with pam_service() as svc: + p = pam.pam() + # verify that using correct key succeeds + authd = p.authenticate(current_username, f'{db_id}-{key}', service=svc) + assert authd is True + assert p.code == pam.PAM_SUCCESS + + # verify that using incorrect key fails + with fail_delay(): + authd = p.authenticate(current_username, f'{db_id}-{key[0:-1]}', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + +def test_new_auth_truncated_password(current_username): + """ Verify that truncated password generates auth error """ + key = crypto.generate_string(string_size=64) + db_id = write_tdb_file(current_username, [crypto.generate_pbkdf2_512(key)]) + + with pam_service() as svc: + p = pam.pam() + with fail_delay(): + authd = p.authenticate(current_username, f'{db_id}-', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + + +def test_new_auth_multi(current_username): + """ verify that second key works with newer hash """ + key = crypto.generate_string(string_size=64) + db_id = write_tdb_file(current_username, [ + LEGACY_ENTRY_HASH, + crypto.generate_pbkdf2_512(key) + ]) + with pam_service() as svc: + p = pam.pam() + # verify that using correct key succeeds + authd = p.authenticate(current_username, f'{db_id}-{key}', service=svc) + assert authd is True + assert p.code == pam.PAM_SUCCESS + + # verify that using incorrect key fails + with fail_delay(): + authd = p.authenticate(current_username, f'{db_id}-{key[0:-1]}', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + + +def test_new_auth_timeout(current_username): + """ verify that valid but expired key denies auth with expected error code """ + key = crypto.generate_string(string_size=64) + db_id = write_tdb_file(current_username, [crypto.generate_pbkdf2_512(key)], True) + with pam_service() as svc: + p = pam.pam() + with fail_delay(): + authd = p.authenticate(current_username, f'{db_id}-{key}', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + + +def test_unsupported_service_file_name(current_username): + """ pam_tdb has strict check that it can't be used for other services """ + key = crypto.generate_string(string_size=64) + db_id = write_tdb_file(current_username, [crypto.generate_pbkdf2_512(key)]) + with pam_service(file_name='canary') as svc: + p = pam.pam() + # verify that using correct key succeeds + authd = p.authenticate(current_username, f'{db_id}-{key}', service=svc) + assert authd is False + assert p.code == pam.PAM_SYSTEM_ERR + + +@pytest.mark.parametrize('thehash', [ + INVALID_HASH_TYPE, + INVALID_SALT, + INVALID_HASH, + MISSING_SALT, + MISSING_HASH, +]) +def test_invalid_hash(current_username, thehash): + """ Check that variations of broken hash entries generate PAM_AUTH_ERR """ + db_id = write_tdb_file(current_username, [thehash]) + with pam_service(admin_user=current_username) as svc: + p = pam.pam() + # verify that using correct key succeeds + authd = p.authenticate(current_username, f'{db_id}-{LEGACY_ENTRY_KEY}', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTH_ERR + + +@pytest.mark.parametrize('fuzz_fn', [ + truncate_tdb_file, + make_tdb_garbage, + make_null_tdb_entry, +]) +def test_invalid_tdb_data(current_username, fuzz_fn): + """ verify we detect garbage tdb entry and flag for reinit""" + key = crypto.generate_string(string_size=64) + db_id = write_tdb_file(current_username, [crypto.generate_pbkdf2_512(key)], True) + fuzz_fn(current_username) + with pam_service() as svc: + p = pam.pam() + authd = p.authenticate(current_username, f'{db_id}-{key}', service=svc) + assert authd is False + assert p.code == pam.PAM_AUTHINFO_UNAVAIL