Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NAS-131223 / 25.04 / Add basic utilities for manipulating pam_tdb data #14367

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
@pytest.fixture(scope='module')
def groupmap_dir():
os.makedirs('/var/db/system/samba4', exist_ok=True)
try:
# pre-emptively delete in case we're running on a TrueNAS VM
os.unlink('/var/db/system/samba4/group_mapping.tdb')
except FileNotFoundError:
pass


@pytest.fixture(scope='module')
Expand Down
15 changes: 15 additions & 0 deletions src/middlewared/middlewared/utils/crypto.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from base64 import b64encode
from hashlib import pbkdf2_hmac
from secrets import choice, compare_digest, token_urlsafe, token_hex
from string import ascii_letters, digits, punctuation

Expand Down Expand Up @@ -63,3 +65,16 @@ def generate_nt_hash(passwd):
"""
md4_hash_bytes = md4_hash_blob(passwd.encode('utf-16le'))
return md4_hash_bytes.hex().upper()


def generate_pbkdf2_512(passwd):
"""
Generate a pbkdf2_sha512 hash for password. This is used for
verification of API keys.
"""
prefix = 'pbkdf2-sha512'
rounds = 500000
salt_length = 16
salt = generate_string(string_size=salt_length, extra_chars='./').encode()
hash = pbkdf2_hmac('sha512', passwd.encode(), salt, rounds)
return f'${prefix}${rounds}${b64encode(salt).decode()}${b64encode(hash).decode()}'
6 changes: 6 additions & 0 deletions src/middlewared/middlewared/utils/tdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class TDBHandle:
opath_fd = FD_CLOSED
keys_null_terminated = False

def __enter__(self):
return self

def __exit__(self, tp, val, traceback):
self.close()

def close(self):
""" Close the TDB handle and O_PATH open for the file """
if self.opath_fd == FD_CLOSED and self.hdl is None:
Expand Down
113 changes: 113 additions & 0 deletions src/middlewared/middlewared/utils/user_api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os

from base64 import b64encode
from dataclasses import dataclass
from struct import pack
from uuid import uuid4
from .tdb import (
TDBDataType,
TDBHandle,
TDBOptions,
TDBPathType,
)


PAM_TDB_DIR = '/var/run/pam_tdb'
PAM_TDB_FILE = os.path.join(PAM_TDB_DIR, 'pam_tdb.tdb')
PAM_TDB_DIR_MODE = 0o700
PAM_TDB_VERSION = 1
PAM_TDB_MAX_KEYS = 10 # Max number of keys per user. Also defined in pam_tdb.c

PAM_TDB_OPTIONS = TDBOptions(TDBPathType.CUSTOM, TDBDataType.BYTES)


@dataclass(frozen=True)
class UserApiKey:
expiry: int
dbid: int
userhash: str


@dataclass(frozen=True)
class PamTdbEntry:
keys: list[UserApiKey]
username: str


def _setup_pam_tdb_dir() -> None:
os.makedirs(PAM_TDB_DIR, mode=PAM_TDB_DIR_MODE, exist_ok=True)
os.chmod(PAM_TDB_DIR, PAM_TDB_DIR_MODE)


def _pack_user_api_key(api_key: UserApiKey) -> bytes:
"""
Convert UserApiKey object to bytes for TDB insertion.
This is packed struct with expiry converted into signed 64 bit
integer, the database id (32-bit unsigned), and the userhash (pascal string)
"""
if not isinstance(api_key, UserApiKey):
raise TypeError(f'{type(api_key)}: not a UserApiKey')

userhash = api_key.userhash.encode() + b'\x00'
return pack(f'<qI{len(userhash)}p', api_key.expiry, api_key.dbid, userhash)


def write_entry(hdl: TDBHandle, entry: PamTdbEntry) -> None:
"""
Convert PamTdbEntry object into a packed struct and insert
into tdb file.
key: username
value: uint32_t (version) + uint32_t (cnt of keys)
"""
if not isinstance(entry, PamTdbEntry):
raise TypeError(f'{type(entry)}: expected PamTdbEntry')

key_cnt = len(entry.keys)
if key_cnt > PAM_TDB_MAX_KEYS:
raise ValueError(f'{key_cnt}: count of entries exceeds maximum')

entry_bytes = pack('<II', PAM_TDB_VERSION, len(entry.keys))
parsed_cnt = 0
for key in entry.keys:
entry_bytes += _pack_user_api_key(key)
parsed_cnt += 1

# since we've already packed struct with array length
# we need to rigidly ensure we don't exceed it.
assert parsed_cnt == key_cnt
hdl.store(entry.username, b64encode(entry_bytes))


def flush_user_api_keys(pam_entries: list[PamTdbEntry]) -> None:
"""
Write a PamTdbEntry object to the pam_tdb file for user
authentication. This method first writes to temporary file
and then renames over pam_tdb file to ensure flush is atomic
and reduce risk of lock contention while under a transaction
lock.
raises:
TypeError - not PamTdbEntry
AssertionError - count of entries changed while generating
tdb payload
RuntimeError - TDB library error
"""
_setup_pam_tdb_dir()

if not isinstance(pam_entries, list):
raise TypeError('Expected list of PamTdbEntry objects')

tmp_path = os.path.join(PAM_TDB_DIR, f'tmp_{uuid4()}.tdb')

with TDBHandle(tmp_path, PAM_TDB_OPTIONS) as hdl:
hdl.keys_null_terminated = False

try:
for entry in pam_entries:
write_entry(hdl, entry)
except Exception:
os.remove(tmp_path)
raise

os.rename(tmp_path, PAM_TDB_FILE)
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
boto3
dnspython
junitparser
pytest
pytest-dependency
pytest-rerunfailures
Expand Down
162 changes: 162 additions & 0 deletions tests/run_unit_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# This script should be run locally from a TrueNAS VM. It runs all tests
# contained within the tests/unit directory as well as middleware specific unit
# tests contained within src/middlewared/middlewared/pytest/unit.
#
# NOTE: this requires `make install_tests` to have been run on the TrueNAS VM.

import argparse
import middlewared
import os
import pytest
import sys

from contextlib import contextmanager
from collections.abc import Generator
from dataclasses import dataclass
from junitparser import JUnitXml
from shutil import copytree, rmtree
from truenas_api_client import Client
from uuid import uuid4

DESCRIPTION = (
'Run unit tests from the specified middleware git repository on the '
'current TrueNAS server (version 25.04 or later). Exit code is one of '
'pytest exit codes with zero indicating success.'
)

UNIT_TESTS = 'tests/unit'
MIDDLEWARE_MODULE_PATH = os.path.dirname(os.path.abspath(middlewared.__file__))
MIDDLEWARE_PYTEST = 'src/middlewared/middlewared/pytest'
MIDDLEWARE_UNIT_TESTS = os.path.join(MIDDLEWARE_PYTEST, 'unit')
MIDDLEWARE_PYTEST_MODULE = os.path.join(MIDDLEWARE_MODULE_PATH, 'pytest')
RESULT_FILE = 'unit_tests_result.xml'
PYTEST_CONFTEST_FILE = 'tests/conftest.py'


@dataclass()
class UnitTestRun:
tests_dir: str
exit_code: pytest.ExitCode = pytest.ExitCode.NO_TESTS_COLLECTED
junit_file: str | None = None


def run_tests(data: UnitTestRun) -> UnitTestRun:
junit_file = f'unit_tests_result_{uuid4()}.xml'

data.exit_code = pytest.main([
'--disable-warnings', '-vv',
'-o', 'junit_family=xunit2',
'--junitxml', junit_file,
data.tests_dir
])

if data.exit_code is not pytest.ExitCode.OK:
print(
f'{data.tests_dir}: tests failed with code: {data.exit_code}',
file=sys.stderr
)

data.junit_file = junit_file
return data


def run_unit_tests(repo_dir: str) -> pytest.ExitCode:
"""
Iterate through our unit test sources and create a unified junit xml file
for the overall test results.
"""
xml_out = JUnitXml()
exit_code = pytest.ExitCode.NO_TESTS_COLLECTED
for test_dir in (
os.path.join(repo_dir, UNIT_TESTS),
os.path.join(repo_dir, MIDDLEWARE_UNIT_TESTS),
):
if not os.path.exists(test_dir):
raise FileNotFoundError(f'{test_dir}: unit test directory does not exist')

data = run_tests(UnitTestRun(tests_dir=test_dir))
xml_out += JUnitXml.fromfile(data.junit_file)
try:
os.remove(data.junit_file)
except Exception:
pass

match data.exit_code:
case pytest.ExitCode.NO_TESTS_COLLECTED:
# We'll treat this as a partial failure because we still want our
# test results from other runs, but don't want an overall misleading
# result.
print(
f'{test_dir}: not tests collected. Treating as partial failure.',
file=sys.stderr
)
if exit_code is pytest.ExitCode.OK:
exit_code = pytest.ExitCode.TESTS_FAILED

case pytest.ExitCode.OK:
# If this is our first OK test, set exit code
# otherwise preserve existing
if exit_code is pytest.ExitCode.NO_TESTS_COLLECTED:
exit_code = data.exit_code

case _:
# exit codes are an IntEnum. Preserve worst case
if exit_code < data.exit_code:
exit_code = data.exit_code

xml_out.write(RESULT_FILE)
return exit_code


@contextmanager
def disable_api_test_config(path: str) -> Generator[None, None, None]:
""" prevent API tests conftest from being applied """
os.rename(
os.path.join(path, PYTEST_CONFTEST_FILE),
os.path.join(path, f'{PYTEST_CONFTEST_FILE}.tmp')
)

try:
yield
finally:
os.rename(
os.path.join(path, f'{PYTEST_CONFTEST_FILE}.tmp'),
os.path.join(path, PYTEST_CONFTEST_FILE)
)


@contextmanager
def setup_middleware_tests(path: str) -> Generator[None, None, None]:
""" temporarily setup our pytest tests in the python dir """
try:
copytree(
os.path.join(path, MIDDLEWARE_PYTEST),
os.path.join(MIDDLEWARE_PYTEST_MODULE)
)
yield
finally:
rmtree(MIDDLEWARE_PYTEST_MODULE)


def main() -> None:
parser = argparse.ArgumentParser(description=DESCRIPTION)
parser.add_argument(
'-p', '--path',
help='Path to local copy of middleware git repository',
default='./middleware'
)

# lazy check to verify we're on a TrueNAS server
with Client() as c:
assert c.call('system.ready')

args = parser.parse_args()
with disable_api_test_config(args.path):
with setup_middleware_tests(args.path):
exit_code = run_unit_tests(args.path)

sys.exit(exit_code)


if __name__ == '__main__':
main()
Loading
Loading