diff --git a/src/middlewared/middlewared/plugins/account.py b/src/middlewared/middlewared/plugins/account.py index e38185e3511c7..0c5270044fb28 100644 --- a/src/middlewared/middlewared/plugins/account.py +++ b/src/middlewared/middlewared/plugins/account.py @@ -1049,6 +1049,7 @@ def shell_choices(self, group_ids): Int('pw_gid'), List('grouplist'), Dict('sid_info'), + Bool('local'), register=True, )) async def get_user_obj(self, data): diff --git a/src/middlewared/middlewared/plugins/activedirectory_/cache.py b/src/middlewared/middlewared/plugins/activedirectory_/cache.py index be2c63c372fe4..de2ec26051557 100644 --- a/src/middlewared/middlewared/plugins/activedirectory_/cache.py +++ b/src/middlewared/middlewared/plugins/activedirectory_/cache.py @@ -1,6 +1,3 @@ -import grp -import pwd - from middlewared.plugins.idmap_.utils import ( IDType, SID_LOCAL_USER_PREFIX, @@ -9,6 +6,8 @@ ) from middlewared.service import Service, private, job from middlewared.service_exception import CallError +from middlewared.utils.nss import pwd, grp +from middlewared.utils.nss.nss_common import NssModule from time import sleep @@ -31,27 +30,13 @@ def get_entries(self, data): dom_by_sid = {x['domain_info']['sid']: x for x in domain_info} if entry_type == 'USER': - entries = WBClient().users() + entries = pwd.getpwall(module=NssModule.WINBIND.name)[NssModule.WINBIND.name] + for i in entries: + ret.append({"id": i.pw_uid, "sid": None, "nss": i, "id_type": entry_type}) else: - entries = WBClient().groups() - - for i in entries: - entry = {"id": -1, "sid": None, "nss": None, "id_type": entry_type} - if entry_type == 'USER': - try: - entry["nss"] = pwd.getpwnam(i) - except KeyError: - continue - entry["id"] = entry["nss"].pw_uid - - else: - try: - entry["nss"] = grp.getgrnam(i) - except KeyError: - continue - entry["id"] = entry["nss"].gr_gid - - ret.append(entry) + entries = grp.getgrall(module=NssModule.WINBIND.name)[NssModule.WINBIND.name] + for i in entries: + ret.append({"id": i.gr_gid, "sid": None, "nss": i, "id_type": entry_type}) idmaps = self.middleware.call_sync('idmap.convert_unixids', ret) to_remove = [] diff --git a/src/middlewared/middlewared/plugins/cache.py b/src/middlewared/middlewared/plugins/cache.py index 3a29f1cfb30fc..ac70d609fcbf9 100644 --- a/src/middlewared/middlewared/plugins/cache.py +++ b/src/middlewared/middlewared/plugins/cache.py @@ -1,6 +1,7 @@ from middlewared.schema import Any, Str, Ref, Int, Dict, Bool, accepts from middlewared.service import Service, private, job, filterable from middlewared.utils import filter_list +from middlewared.utils.nss import pwd, grp from middlewared.service_exception import CallError, MatchNotFound from middlewared.plugins.idmap_.utils import SID_LOCAL_USER_PREFIX, SID_LOCAL_GROUP_PREFIX @@ -8,8 +9,6 @@ import errno import os import time -import pwd -import grp class CacheService(Service): @@ -234,42 +233,37 @@ def get_uncached_user(self, username=None, uid=None, getgroups=False, sid_info=F for user validation. """ if username: - u = pwd.getpwnam(username) + user_obj = pwd.getpwnam(username, module='ALL', as_dict=True) elif uid is not None: - u = pwd.getpwuid(uid) + user_obj = pwd.getpwuid(uid, module='ALL', as_dict=True) else: return {} - user_obj = { - 'pw_name': u.pw_name, - 'pw_uid': u.pw_uid, - 'pw_gid': u.pw_gid, - 'pw_gecos': u.pw_gecos, - 'pw_dir': u.pw_dir, - 'pw_shell': u.pw_shell, - } + source = user_obj.pop('source') + user_obj['local'] = source == 'FILES' + if getgroups: - user_obj['grouplist'] = os.getgrouplist(u.pw_name, u.pw_gid) + user_obj['grouplist'] = os.getgrouplist(user_obj['pw_name'], user_obj['pw_gid']) if sid_info: try: if (idmap := self.middleware.call_sync('idmap.convert_unixids', [{ 'id_type': 'USER', - 'id': u.pw_uid, + 'id': user_obj['pw_uid'], }])['mapped']): - sid = idmap[f'UID:{u.pw_uid}']['sid'] + sid = idmap[f'UID:{user_obj["pw_uid"]}']['sid'] else: - sid = SID_LOCAL_USER_PREFIX + str(u.pw_uid) + sid = SID_LOCAL_USER_PREFIX + str(user_obj['pw_uid']) except CallError as e: # ENOENT means no winbindd entry for user # ENOTCONN means winbindd is stopped / can't be started # EAGAIN means the system dataset is hosed and needs to be fixed, # but we need to let it through so that it's very clear in logs if e.errno not in (errno.ENOENT, errno.ENOTCONN): - self.logger.error('Failed to retrieve SID for uid: %d', u.pw_uid, exc_info=True) + self.logger.error('Failed to retrieve SID for uid: %d', user_obj['pw_uid'], exc_info=True) sid = None except Exception: - self.logger.error('Failed to retrieve SID for uid: %d', u.pw_uid, exc_info=True) + self.logger.error('Failed to retrieve SID for uid: %d', user_obj['pw_uid'], exc_info=True) sid = None if sid: @@ -290,33 +284,30 @@ def get_uncached_group(self, groupname=None, gid=None, sid_info=False): for group validation. """ if groupname: - g = grp.getgrnam(groupname) + grp_obj = grp.getgrnam(groupname, module='ALL', as_dict=True) elif gid is not None: - g = grp.getgrgid(gid) + grp_obj = grp.getgrgid(gid, module='ALL', as_dict=True) else: return {} - grp_obj = { - 'gr_name': g.gr_name, - 'gr_gid': g.gr_gid, - 'gr_mem': g.gr_mem - } + source = grp_obj.pop('source') + grp_obj['local'] = source == 'FILES' if sid_info: try: if (idmap := self.middleware.call_sync('idmap.convert_unixids', [{ 'id_type': 'GROUP', - 'id': g.gr_gid, + 'id': grp_obj['gr_gid'], }])['mapped']): - sid = idmap[f'GID:{g.gr_gid}']['sid'] + sid = idmap[f'GID:{grp_obj["gr_gid"]}']['sid'] else: - sid = SID_LOCAL_GROUP_PREFIX + str(g.gr_gid) + sid = SID_LOCAL_GROUP_PREFIX + str(grp_obj['gr_gid']) except CallError as e: if e.errno not in (errno.ENOENT, errno.ENOTCONN): - self.logger.error('Failed to retrieve SID for gid: %d', grp.gr_gid, exc_info=True) + self.logger.error('Failed to retrieve SID for gid: %d', grp_obj['gr_gid'], exc_info=True) sid = None except Exception: - self.logger.error('Failed to retrieve SID for gid: %d', grp.gr_gid, exc_info=True) + self.logger.error('Failed to retrieve SID for gid: %d', grp['gr_gid'], exc_info=True) sid = None if sid: diff --git a/src/middlewared/middlewared/plugins/filesystem.py b/src/middlewared/middlewared/plugins/filesystem.py index e4263f72f7cf0..1ad948a62cd65 100644 --- a/src/middlewared/middlewared/plugins/filesystem.py +++ b/src/middlewared/middlewared/plugins/filesystem.py @@ -1,10 +1,8 @@ import binascii import errno import functools -import grp import os import pathlib -import pwd import shutil import stat as statlib import time @@ -19,6 +17,7 @@ from middlewared.service import private, CallError, filterable_returns, filterable, Service, job from middlewared.utils import filter_list from middlewared.utils.mount import getmntinfo +from middlewared.utils.nss import pwd, grp from middlewared.utils.path import FSLocation, path_location, strip_location_prefix, is_child_realpath from middlewared.plugins.filesystem_.utils import ACLType from middlewared.plugins.zfs_.utils import ZFSCTL diff --git a/src/middlewared/middlewared/plugins/filesystem_/perm_check.py b/src/middlewared/middlewared/plugins/filesystem_/perm_check.py index b2607703c853b..c85bee2d6482a 100644 --- a/src/middlewared/middlewared/plugins/filesystem_/perm_check.py +++ b/src/middlewared/middlewared/plugins/filesystem_/perm_check.py @@ -1,12 +1,11 @@ import errno -import grp import os import pathlib -import pwd from middlewared.schema import accepts, Bool, Dict, returns, Str from middlewared.service import CallError, Service, private +from middlewared.utils.nss import pwd, grp from middlewared.utils.user_context import run_with_user_context, set_user_context # This should be a sufficiently high UID to never be used explicitly diff --git a/src/middlewared/middlewared/plugins/ldap.py b/src/middlewared/middlewared/plugins/ldap.py index 989711e8e27a3..74c1619bef227 100644 --- a/src/middlewared/middlewared/plugins/ldap.py +++ b/src/middlewared/middlewared/plugins/ldap.py @@ -14,6 +14,8 @@ from middlewared.plugins.idmap import DSType from middlewared.plugins.ldap_.ldap_client import LdapClient from middlewared.plugins.ldap_ import constants +from middlewared.utils.nss import pwd, grp +from middlewared.utils.nss.nss_common import NssModule from middlewared.validators import Range LDAP_SMBCONF_PARAMS = { @@ -1047,15 +1049,12 @@ async def __stop(self, job): @job(lock='fill_ldap_cache') def fill_cache(self, job, force=False): user_next_index = group_next_index = 100000000 - if self.middleware.call_sync('cache.has_key', 'LDAP_cache') and not force: - raise CallError('LDAP cache already exists. Refusing to generate cache.') - if (self.middleware.call_sync('ldap.config'))['disable_freenas_cache']: self.logger.debug('LDAP cache is disabled. Bypassing cache fill.') return - pwd_list = [] - grp_list = [] + pwd_list = pwd.getpwall(module=NssModule.SSS.name, as_dict=True)[NssModule.SSS.name] + grp_list = grp.getgrall(module=NssModule.SSS.name, as_dict=True)[NssModule.SSS.name] for u in pwd_list: entry = { diff --git a/src/middlewared/middlewared/utils/nss/__init__.py b/src/middlewared/middlewared/utils/nss/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/middlewared/middlewared/utils/nss/grp.py b/src/middlewared/middlewared/utils/nss/grp.py new file mode 100644 index 0000000000000..7c36901bf6b44 --- /dev/null +++ b/src/middlewared/middlewared/utils/nss/grp.py @@ -0,0 +1,303 @@ +import ctypes +import errno + +from collections import namedtuple +from .nss_common import get_nss_func, NssError, NssModule, NssOperation, NssReturnCode + +GROUP_INIT_BUFLEN = 1024 + + +class Group(ctypes.Structure): + _fields_ = [ + ("gr_name", ctypes.c_char_p), + ("gr_passwd", ctypes.c_char_p), + ("gr_gid", ctypes.c_int), + ("gr_mem", ctypes.POINTER(ctypes.c_char_p)) + ] + + +group_struct = namedtuple('struct_group', ['gr_name', 'gr_gid', 'gr_mem', 'source']) + + +def __parse_nss_result(result, as_dict, module_name): + name = result.gr_name.decode() + members = list() + + i = 0 + while result.gr_mem[i]: + members.append(result.gr_mem[i].decode()) + i += 1 + + if as_dict: + return { + 'gr_name': name, + 'gr_gid': result.gr_gid, + 'gr_mem': members, + 'source': module_name + } + + return group_struct(name, result.gr_gid, members, module_name) + + +def __getgrnam_r(name, result_p, buffer_p, buflen, nss_module): + """ + enum nss_status _nss_#module#_getgrnam_r(const char *name, + struct group *result, + char *buffer, + size_t buflen, + int *error) + """ + func = get_nss_func(NssOperation.GETGRNAM, nss_module) + func.restype = ctypes.c_int + func.argtypes = [ + ctypes.c_char_p, + ctypes.POINTER(Group), + ctypes.c_char_p, + ctypes.c_ulong, + ctypes.POINTER(ctypes.c_int) + ] + + err = ctypes.c_int() + name = name.encode('utf-8') + res = func(ctypes.c_char_p(name), result_p, buffer_p, buflen, ctypes.byref(err)) + + return (int(res), err.value, result_p) + + +def __getgrgid_r(gid, result_p, buffer_p, buflen, nss_module): + """ + enum nss_status _nss_#module#_getgrgid_r(gid_t gid, + struct group *result, + char *buffer, + size_t buflen, + int *error) + """ + func = get_nss_func(NssOperation.GETGRGID, nss_module) + func.restype = ctypes.c_int + func.argtypes = [ + ctypes.c_ulong, + ctypes.POINTER(Group), + ctypes.c_char_p, + ctypes.c_ulong, + ctypes.POINTER(ctypes.c_int) + ] + err = ctypes.c_int() + res = func(gid, result_p, buffer_p, buflen, ctypes.byref(err)) + + return (int(res), err.value, result_p) + + +def __getgrent_r(result_p, buffer_p, buflen, nss_module): + """ + enum nss_status _nss_#module#_getgrent_r(struct group *result, + char *buffer, + size_t buflen, + int *error) + """ + func = get_nss_func(NssOperation.GETGRENT, nss_module) + func.restype = ctypes.c_int + func.argtypes = [ + ctypes.POINTER(Group), + ctypes.c_char_p, + ctypes.c_ulong, + ctypes.POINTER(ctypes.c_int) + ] + + err = ctypes.c_int() + res = func(result_p, buffer_p, buflen, ctypes.byref(err)) + + return (int(res), err.value, result_p) + + +def __setgrent(nss_module): + """ + enum nss_status _nss_#module#_setgrent(void) + """ + func = get_nss_func(NssOperation.SETGRENT, nss_module) + func.argtypes = [] + + res = func() + + if res != NssReturnCode.SUCCESS: + raise NssError(ctypes.get_errno(), NssOperation.SETGRENT, res, nss_module) + + +def __endgrent(nss_module): + """ + enum nss_status _nss_#module#_endgrent(void) + """ + func = get_nss_func(NssOperation.ENDGRENT, nss_module) + func.argtypes = [] + + res = func() + + if res != NssReturnCode.SUCCESS: + raise NssError(ctypes.get_errno(), NssOperation.ENDGRENT, res, nss_module) + + +def __getgrent_impl(mod, as_dict, buffer_len=GROUP_INIT_BUFLEN): + result = Group() + + buf = ctypes.create_string_buffer(buffer_len) + + res, error, result_p = __getgrent_r(ctypes.byref(result), buf, + buffer_len, mod) + + match error: + case 0: + pass + case errno.ERANGE: + # Our buffer was too small, increment + return __getgrent_impl(mod, as_dict, buffer_len * 2) + case _: + raise NssError(error, NssOperation.GETGRENT, res, mod) + + if res != NssReturnCode.SUCCESS: + return None + + return __parse_nss_result(result, as_dict, mod.name) + + +def __getgrall_impl(module, as_dict): + mod = NssModule[module] + __setgrent(mod) + group_list = [] + + group = __getgrent_impl(mod, as_dict) + while group is not None: + if (group := __getgrent_impl(mod, as_dict)): + group_list.append(group) + + __endgrent(mod) + return group_list + + +def __getgrnam_impl(name, module, as_dict, buffer_len=GROUP_INIT_BUFLEN): + mod = NssModule[module] + result = Group() + + buf = ctypes.create_string_buffer(buffer_len) + + res, error, result_p = __getgrnam_r(name, ctypes.byref(result), + buf, buffer_len, mod) + match error: + case 0: + pass + case errno.ERANGE: + # Our buffer was too small, increment + return __getgrnam_impl(name, module, as_dict, buffer_len * 2) + case _: + raise NssError(error, NssOperation.GETGRNAM, res, mod) + + if res == NssReturnCode.NOTFOUND: + return None + + return __parse_nss_result(result, as_dict, mod.name) + + +def __getgrgid_impl(gid, module, as_dict, buffer_len=GROUP_INIT_BUFLEN): + mod = NssModule[module] + result = Group() + buf = ctypes.create_string_buffer(buffer_len) + + res, error, result_p = __getgrgid_r(gid, ctypes.byref(result), + buf, buffer_len, mod) + match error: + case 0: + pass + case errno.ERANGE: + # Our buffer was too small, increment + return __getgrgid_impl(gid, module, as_dict, buffer_len * 2) + case _: + raise NssError(error, NssOperation.GETGRGID, res, mod) + + if res == NssReturnCode.NOTFOUND: + return None + + return __parse_nss_result(result, as_dict, mod.name) + + +def getgrgid(gid, module=NssModule.ALL.name, as_dict=False): + """ + Return the group database entry for the given group by gid. + + `module` - NSS module from which to retrieve the group + `as_dict` - return output as a dictionary rather than `struct_group`. + """ + if module != NssModule.ALL.name: + if (result := __getgrgid_impl(gid, module, as_dict)): + return result + + raise KeyError(f"getgrgid(): gid not found: '{gid}'") + + # We're querying all modules + for mod in NssModule: + if mod == NssModule.ALL: + continue + + try: + if (result := __getgrgid_impl(gid, mod.name, as_dict)): + return result + except NssError as e: + if e.return_code != NssReturnCode.UNAVAIL: + raise e from None + + raise KeyError(f"getgrgid(): gid not found: '{gid}'") + + +def getgrnam(name, module=NssModule.ALL.name, as_dict=False): + """ + Return the group database entry for the given group by name. + + `module` - NSS module from which to retrieve the group + `as_dict` - return output as a dictionary rather than `struct_group`. + """ + if module != NssModule.ALL.name: + if (result := __getgrnam_impl(name, module, as_dict)): + return result + + raise KeyError(f"getgrnam(): name not found: '{name}'") + + # We're querying all modules + for mod in NssModule: + if mod == NssModule.ALL: + continue + + try: + if (result := __getgrnam_impl(name, mod.name, as_dict)): + return result + except NssError as e: + if e.return_code != NssReturnCode.UNAVAIL: + raise e from None + + raise KeyError(f"getgrnam(): name not found: '{name}'") + + +def getgrall(module=NssModule.ALL.name, as_dict=False): + """ + Returns all group entries on server (similar to grp.getgrall()). + + `module` - NSS module from which to retrieve the entries + `as_dict` - return password database entries as dictionaries + + This module returns a dictionary keyed by NSS module, e.g. + {'FILES': [, ], 'WINBIND': [], 'SSS': []} + """ + if module != NssModule.ALL.name: + return {module: __getgrall_impl(module, as_dict)} + + results = {} + for mod in NssModule: + if mod == NssModule.ALL: + continue + + entries = [] + try: + entries = __getgrall_impl(mod.name, as_dict) + except NssError as e: + if e.return_code != NssReturnCode.UNAVAIL: + raise e from None + + results[mod.name] = entries + + return results diff --git a/src/middlewared/middlewared/utils/nss/nss_common.py b/src/middlewared/middlewared/utils/nss/nss_common.py new file mode 100644 index 0000000000000..102ac85d3a95b --- /dev/null +++ b/src/middlewared/middlewared/utils/nss/nss_common.py @@ -0,0 +1,62 @@ +import enum +import ctypes +import os + +NSS_MODULES_DIR = '/usr/lib/x86_64-linux-gnu' +FILES_NSS_PATH = os.path.join(NSS_MODULES_DIR, 'libnss_files.so.2') +SSS_NSS_PATH = os.path.join(NSS_MODULES_DIR, 'libnss_sss.so.2') +WINBIND_NSS_PATH = os.path.join(NSS_MODULES_DIR, 'libnss_winbind.so.2') + + +class NssReturnCode(enum.IntEnum): + """ Possible NSS return codes, see /usr/include/nss.h """ + TRYAGAIN = -2 + UNAVAIL = -1 + NOTFOUND = 0 + SUCCESS = 1 + RETURN = 2 + + +class NssModule(enum.Enum): + """ Currently supported NSS modules """ + ALL = enum.auto() + FILES = FILES_NSS_PATH + SSS = SSS_NSS_PATH + WINBIND = WINBIND_NSS_PATH + + +class NssOperation(enum.Enum): + """ Currently supported NSS operations """ + GETGRNAM = 'getgrnam_r' + GETGRGID = 'getgrgid_r' + SETGRENT = 'setgrent' + ENDGRENT = 'endgrent' + GETGRENT = 'getgrent_r' + GETPWNAM = 'getpwnam_r' + GETPWUID = 'getpwuid_r' + GETPWENT = 'getpwent_r' + SETPWENT = 'setpwent' + ENDPWENT = 'endpwent' + + +class NssError(Exception): + def __init__(self, errno, nssop, return_code, module): + self.errno = errno + self.nssop = nssop.value + self.return_code = return_code + self.mod_name = module.name + + def __str__(self): + errmsg = f'NSS operation {self.nssop} failed with errno {self.errno}: {self.return_code}' + if self.mod_name != 'ALL': + errmsg += f' on module [{self.mod_name.lower()}].' + + return errmsg + + +def get_nss_func(nss_op, nss_module): + if nss_module == NssModule.ALL: + raise ValueError('ALL module may not be explicitly used') + + lib = ctypes.CDLL(nss_module.value, use_errno=True) + return getattr(lib, f'_nss_{nss_module.name.lower()}_{nss_op.value}') diff --git a/src/middlewared/middlewared/utils/nss/pwd.py b/src/middlewared/middlewared/utils/nss/pwd.py new file mode 100644 index 0000000000000..9cb937257ff77 --- /dev/null +++ b/src/middlewared/middlewared/utils/nss/pwd.py @@ -0,0 +1,305 @@ +import ctypes +import errno + +from collections import namedtuple +from .nss_common import get_nss_func, NssError, NssModule, NssOperation, NssReturnCode + +PASSWD_INIT_BUFLEN = 1024 + + +class Passwd(ctypes.Structure): + _fields_ = [ + ("pw_name", ctypes.c_char_p), + ("pw_passwd", ctypes.c_char_p), + ("pw_uid", ctypes.c_int), + ("pw_gid", ctypes.c_int), + ("pw_gecos", ctypes.c_char_p), + ("pw_dir", ctypes.c_char_p), + ("pw_shell", ctypes.c_char_p) + ] + + +pwd_struct = namedtuple('struct_passwd', [ + 'pw_name', 'pw_uid', 'pw_gid', 'pw_gecos', 'pw_dir', 'pw_shell', 'source' +]) + + +def __parse_nss_result(result, as_dict, module_name): + name = result.pw_name.decode() + gecos = result.pw_gecos.decode() + homedir = result.pw_dir.decode() + shell = result.pw_shell.decode() + + if as_dict: + return { + 'pw_name': name, + 'pw_uid': result.pw_uid, + 'pw_gid': result.pw_gid, + 'pw_gecos': gecos, + 'pw_dir': homedir, + 'pw_shell': shell, + 'source': module_name + } + + return pwd_struct(name, result.pw_uid, result.pw_gid, homedir, gecos, shell, module_name) + + +def __getpwnam_r(name, result_p, buffer_p, buflen, nss_module): + """ + enum nss_status _nss_#module#_getpwnam_r(const char *name, + struct passwd *result, + char *buffer, + size_t buflen, + int *errnop) + """ + func = get_nss_func(NssOperation.GETPWNAM, nss_module) + func.restype = ctypes.c_int + func.argtypes = [ + ctypes.c_char_p, + ctypes.POINTER(Passwd), + ctypes.c_char_p, + ctypes.c_ulong, + ctypes.POINTER(ctypes.c_int) + ] + + err = ctypes.c_int() + name = name.encode('utf-8') + res = func(ctypes.c_char_p(name), result_p, buffer_p, buflen, ctypes.byref(err)) + + return (int(res), err.value, result_p) + + +def __getpwuid_r(uid, result_p, buffer_p, buflen, nss_module): + """ + enum nss_status _nss_#module#_getpwuid_r(uid_t uid, + struct passwd *result, + char *buffer, + size_t buflen, + int *errnop) + """ + func = get_nss_func(NssOperation.GETPWUID, nss_module) + func.restype = ctypes.c_int + func.argtypes = [ + ctypes.c_ulong, + ctypes.POINTER(Passwd), + ctypes.c_char_p, + ctypes.c_ulong, + ctypes.POINTER(ctypes.c_int) + ] + err = ctypes.c_int() + res = func(uid, result_p, buffer_p, buflen, ctypes.byref(err)) + + return (int(res), err.value, result_p) + + +def __getpwent_r(result_p, buffer_p, buflen, nss_module): + """ + enum nss_status _nss_#module#_getpwent_r(struct passwd *result, + char *buffer, size_t buflen, + int *errnop) + """ + func = get_nss_func(NssOperation.GETPWENT, nss_module) + func.restype = ctypes.c_int + func.argtypes = [ + ctypes.POINTER(Passwd), + ctypes.c_char_p, + ctypes.c_ulong, + ctypes.POINTER(ctypes.c_int) + ] + + err = ctypes.c_int() + res = func(result_p, buffer_p, buflen, ctypes.byref(err)) + + return (int(res), err.value, result_p) + + +def __setpwent(nss_module): + """ + enum nss_status _nss_#module#_setpwent(void) + """ + func = get_nss_func(NssOperation.SETPWENT, nss_module) + func.argtypes = [] + + res = func() + + if res != NssReturnCode.SUCCESS: + raise NssError(ctypes.get_errno(), NssOperation.SETPWENT, res, nss_module) + + +def __endpwent(nss_module): + """ + enum nss_status _nss_#module#_endpwent(void) + """ + func = get_nss_func(NssOperation.ENDPWENT, nss_module) + func.argtypes = [] + + res = func() + + if res != NssReturnCode.SUCCESS: + raise NssError(ctypes.get_errno(), NssOperation.ENDPWENT, res, nss_module) + + +def __getpwent_impl(mod, as_dict, buffer_len=PASSWD_INIT_BUFLEN): + result = Passwd() + buf = ctypes.create_string_buffer(buffer_len) + + res, error, result_p = __getpwent_r(ctypes.byref(result), buf, + buffer_len, mod) + match error: + case 0: + pass + case errno.ERANGE: + # Our buffer was too small, increment + return __getpwent_impl(mod, as_dict, buffer_len * 2) + case _: + raise NssError(error, NssOperation.GETPWENT, res, mod) + + if res != NssReturnCode.SUCCESS: + return None + + return __parse_nss_result(result, as_dict, mod.name) + + +def __getpwall_impl(module, as_dict): + mod = NssModule[module] + __setpwent(mod) + pwd_list = [] + + user = __getpwent_impl(mod, as_dict) + while user is not None: + if (user := __getpwent_impl(mod, as_dict)): + pwd_list.append(user) + + __endpwent(mod) + return pwd_list + + +def __getpwnam_impl(name, module, as_dict, buffer_len=PASSWD_INIT_BUFLEN): + mod = NssModule[module] + result = Passwd() + buf = ctypes.create_string_buffer(buffer_len) + + res, error, result_p = __getpwnam_r(name, ctypes.byref(result), + buf, buffer_len, mod) + match error: + case 0: + pass + case errno.ERANGE: + # Our buffer was too small, increment + return __getpwnam_impl(name, module, as_dict, buffer_len * 2) + case _: + raise NssError(error, NssOperation.GETPWNAM, res, mod) + + if res == NssReturnCode.NOTFOUND: + return None + + return __parse_nss_result(result, as_dict, mod.name) + + +def __getpwuid_impl(uid, module, as_dict, buffer_len=PASSWD_INIT_BUFLEN): + mod = NssModule[module] + result = Passwd() + buf = ctypes.create_string_buffer(buffer_len) + + res, error, result_p = __getpwuid_r(uid, ctypes.byref(result), + buf, buffer_len, mod) + match error: + case 0: + pass + case errno.ERANGE: + # Our buffer was too small, increment + return __getpwuid_impl(uid, module, as_dict, buffer_len * 2) + case _: + raise NssError(error, NssOperation.GETPWUID, res, mod) + + if res == NssReturnCode.NOTFOUND: + return None + + return __parse_nss_result(result, as_dict, mod.name) + + +def getpwuid(uid, module=NssModule.ALL.name, as_dict=False): + """ + Return the password database entry for the given user by uid. + + `module` - NSS module from which to retrieve the user + `as_dict` - return output as a dictionary rather than `struct_passwd`. + """ + if module != NssModule.ALL.name: + if (result := __getpwuid_impl(uid, module, as_dict)): + return result + + raise KeyError(f"getpwuid(): uid not found: '{uid}'") + + # We're querying all modules + for mod in NssModule: + if mod == NssModule.ALL: + continue + + try: + if (result := __getpwuid_impl(uid, mod.name, as_dict)): + return result + except NssError as e: + if e.return_code != NssReturnCode.UNAVAIL: + raise e from None + + raise KeyError(f"getpwuid(): uid not found: '{uid}'") + + +def getpwnam(name, module=NssModule.ALL.name, as_dict=False): + """ + Return the password database entry for the given user by name. + + `module` - NSS module from which to retrieve the user + `as_dict` - return output as a dictionary rather than `struct_passwd`. + """ + if module != NssModule.ALL.name: + if (result := __getpwnam_impl(name, module, as_dict)): + return result + + raise KeyError(f"getpwnam(): name not found: '{name}'") + + # We're querying all modules + for mod in NssModule: + if mod == NssModule.ALL: + continue + + try: + if (result := __getpwnam_impl(name, mod.name, as_dict)): + return result + except NssError as e: + if e.return_code != NssReturnCode.UNAVAIL: + raise e from None + + + raise KeyError(f"getpwnam(): name not found: '{name}'") + + +def getpwall(module=NssModule.ALL.name, as_dict=False): + """ + Returns all password entries on server (similar to pwd.getpwall()). + + `module` - NSS module from which to retrieve the entries + `as_dict` - return password database entries as dictionaries + + This module returns a dictionary keyed by NSS module, e.g. + {'FILES': [, ], 'WINBIND': [], 'SSS': []} + """ + if module != NssModule.ALL.name: + return {module: __getpwall_impl(module, as_dict)} + + results = {} + for mod in NssModule: + if mod == NssModule.ALL: + continue + + entries = [] + try: + entries = __getpwall_impl(mod.name, as_dict) + except NssError as e: + if e.return_code != NssReturnCode.UNAVAIL: + raise e from None + + results[mod.name] = entries + + return results