diff --git a/docs/performance.rst b/docs/performance.rst index 1ec7130f1..2f2e933fa 100644 --- a/docs/performance.rst +++ b/docs/performance.rst @@ -14,19 +14,24 @@ Quick Tips ---------- * Use Conda-based Python, with ``tbb`` installed. -* Set the ``MKL_THREADING_LAYER`` environment variable to ``tbb``, so both MKL and LensKit - will use TBB and can coordinate their thread pools. +* When using MKL, set the ``MKL_THREADING_LAYER`` environment variable to ``tbb``, so both + MKL and LensKit will use TBB and can coordinate their thread pools. * Use ``LK_NUM_PROCS`` if you want to control LensKit's batch prediction and recommendation parallelism, and ``NUMBA_NUM_THREADS`` to control its model training parallelism. -We generally find the best performance using MKL with TBB throughout the stack. If both -LensKit's Numba-accelerated code and MKL are using TBB, they will coordinate their -thread pools to coordinate threading levels. +We generally find the best performance using MKL with TBB throughout the stack on Intel +processors. If both LensKit's Numba-accelerated code and MKL are using TBB, they will +coordinate their thread pools to coordinate threading levels. -If you are **not** using MKL with TBB, we recommend setting ``MKL_NUM_THREADS=1`` and/or -``OPENBLAS_NUM_THREADS=1`` (depending on your BLAS implementation) to turn off -BLAS threading. When LensKit starts (usually at model training time), it will -check your runtime environment and log warning messages if it detects problems. +If you are **not** using MKL (Apple Silicon, maybe also AMD processors), we recommend +controlling your BLAS parallelism. For OpenBLAS, how you control this depends on how +OpenBLAS was built, whether Numba is using OpenMP or TBB, and whether you are training +or evaluating the model. + +When LensKit starts (usually at model training time), it will check your runtime environment +and log warning messages if it detects problems. During evaluation, it also makes a +best-effort attempt, through `threadpoolctl`_, to disable nested parallelism when running +a parallel evaluation. Controlling Parallelism ----------------------- diff --git a/lenskit/util/debug.py b/lenskit/util/debug.py index dd9694559..cb7d5c82d 100644 --- a/lenskit/util/debug.py +++ b/lenskit/util/debug.py @@ -2,7 +2,6 @@ Debugging utility code. Also runnable as a Python command. Usage: - lenskit.util.debug [options] --libraries lenskit.util.debug [options] --blas-info lenskit.util.debug [options] --numba-info lenskit.util.debug [options] --check-env @@ -12,15 +11,12 @@ Turn on verbose logging """ -from pathlib import Path import sys import logging -import ctypes from typing import Optional from dataclasses import dataclass import numba -import psutil -import numpy as np +import threadpoolctl from .parallel import is_worker @@ -52,157 +48,21 @@ class NumbaInfo: threads: int -def _get_shlibs(): - proc = psutil.Process() - if hasattr(proc, 'memory_maps'): - return [mm.path for mm in proc.memory_maps()] - else: - return [] - - -def _guess_layer(): - layer = None - - _log.debug('scanning process memory maps for MKL threading layers') - for mm in _get_shlibs(): - if 'mkl_intel_thread' in mm: - _log.debug('found library %s linked', mm) - if layer: - _log.warn('multiple threading layers detected') - layer = 'intel' - elif 'mkl_tbb_thread' in mm: - _log.debug('found library %s linked', mm) - if layer: - _log.warn('multiple threading layers detected') - layer = 'tbb' - elif 'mkl_gnu_thread' in mm: - _log.debug('found library %s linked', mm) - if layer: - _log.warn('multiple threading layers detected') - layer = 'gnu' - - return layer - - -def guess_blas_unix(): - _log.info('opening self DLL') - dll = ctypes.CDLL(None) - - _log.debug('checking for MKL') - try: - mkl_vstr = dll.mkl_get_version_string - mkl_vbuf = ctypes.create_string_buffer(256) - mkl_vstr(mkl_vbuf, 256) - version = mkl_vbuf.value.decode().strip() - _log.debug('version %s', version) - - mkl_mth = dll.mkl_get_max_threads - mkl_mth.restype = ctypes.c_int - threads = mkl_mth() - - layer = _guess_layer() - - return BlasInfo('mkl', layer, threads, version) - except AttributeError as e: - _log.debug('MKL attribute error: %s', e) - pass # no MKL - - _log.debug('checking BLAS for OpenBLAS') - np_dll = ctypes.CDLL(np.core._multiarray_umath.__file__) - try: - openblas_vstr = np_dll.openblas_get_config - openblas_vstr.restype = ctypes.c_char_p - version = openblas_vstr().decode() - _log.debug('version %s', version) - - openblas_th = np_dll.openblas_get_num_threads - openblas_th.restype = ctypes.c_int - threads = openblas_th() - _log.debug('threads %d', threads) - - return BlasInfo('openblas', None, threads, version) - except AttributeError as e: - _log.info('OpenBLAS error: %s', e) - - return BlasInfo(None, None, None, 'unknown') - - -def _find_win_blas_path(): - for lib in _get_shlibs(): - path = Path(lib) - name = path.name - if not name.startswith('libopenblas'): +def blas_info(): + pools = threadpoolctl.threadpool_info() + blas = None + for pool in pools: + if pool['user_api'] != 'blas': continue - if path.parent.parent.name == 'numpy': - _log.debug('found BLAS at %s', lib) - return lib - elif path.parent.name == 'numpy.libs': - _log.debug('found BLAS at %s', lib) - return lib - - -def _find_win_blas(): - try: - blas_dll = ctypes.cdll.libblas - _log.debug('loaded BLAS dll %s', blas_dll) - return blas_dll - except (FileNotFoundError, OSError) as e: - _log.debug('no LIBBLAS, searching') - path = _find_win_blas_path() - if path is not None: - return ctypes.CDLL(path) - else: - _log.error('could not load LIBBLAS: %s', e) - return BlasInfo(None, None, None, 'unknown') - - -def guess_blas_windows(): - blas_dll = _find_win_blas() - - _log.debug('checking BLAS for MKL') - try: - mkl_vstr = blas_dll.mkl_get_version_string - mkl_vbuf = ctypes.create_string_buffer(256) - mkl_vstr(mkl_vbuf, 256) - version = mkl_vbuf.value.decode().strip() - _log.debug('version %s', version) - - mkl_mth = blas_dll.mkl_get_max_threads - mkl_mth.restype = ctypes.c_int - threads = mkl_mth() - - layer = _guess_layer() - - return BlasInfo('mkl', layer, threads, version) - except AttributeError as e: - _log.debug('MKL attribute error: %s', e) - pass # no MKL - - _log.debug('checking BLAS for OpenBLAS') - try: - openblas_vstr = blas_dll.openblas_get_config - openblas_vstr.restype = ctypes.c_char_p - version = openblas_vstr().decode() - - openblas_th = blas_dll.openblas_get_num_threads - openblas_th.restype = ctypes.c_int - threads = openblas_th() - _log.debug('threads %d', threads) - - return BlasInfo('openblas', None, threads, version) - except AttributeError as e: - _log.info('OpenBLAS error: %s', e) - - return BlasInfo(None, None, None, 'unknown') - + if blas is not None: + _log.warning("found multiple BLAS layers, using first") + _log.info("later layer is: %s", pool) + continue -def blas_info(): - if sys.platform == 'win32': - return guess_blas_windows() - else: - return guess_blas_unix() + blas = BlasInfo(pool['internal_api'], pool.get('threading_layer', None), pool.get('num_threads', None), pool['version']) + return blas def numba_info(): x = _par_test(100) @@ -240,17 +100,28 @@ def check_env(): _already_checked = True return + if blas is None: + _log.warning('threadpoolctl could not find your BLAS') + _already_checked = True + return + _log.info('Using BLAS %s', blas.impl) if numba.threading != 'tbb': - _log.warning('Numba is using threading layer %s - consider TBB', numba.threading) - _log.info('Non-TBB threading is often slower and can cause crashes') - problems += 1 + _log.info('Numba is using threading layer %s - consider TBB', numba.threading) if numba.threading == 'tbb' and blas.threading == 'tbb': _log.info('Numba and BLAS both using TBB - good') - elif blas.threads and blas.threads > 1 and numba.threads > 1: + + if numba.threading == 'tbb' and blas.impl == 'mkl' and blas.threading != 'tbb': + _log.warning('Numba using TBB but MKL is using %s', blas.threading) + _log.info('Set MKL_THREADING_LAYER=tbb for improved performance') + problems += 1 + + if blas.threads and blas.threads > 1 and numba.threads > 1: + # TODO make this be fine in OpenMP configurations _log.warning('BLAS using multiple threads - can cause oversubscription') + _log.info('See https://mde.one/lkpy-blas for information on tuning BLAS for LensKit') problems += 1 if problems: @@ -261,14 +132,6 @@ def check_env(): return problems -def print_libraries(): - p = psutil.Process() - - _log.info('printing process libraries') - for map in p.memory_maps(): - print(map.path) - - def print_blas_info(): blas = blas_info() print(blas) @@ -284,9 +147,8 @@ def main(): opts = docopt(__doc__) level = logging.DEBUG if opts['--verbose'] else logging.INFO logging.basicConfig(level=level, stream=sys.stderr, format='%(levelname)s %(name)s %(message)s') + logging.getLogger('numba').setLevel(logging.INFO) - if opts['--libraries']: - print_libraries() if opts['--blas-info']: print_blas_info() if opts['--numba-info']: diff --git a/lenskit/util/parallel.py b/lenskit/util/parallel.py index 27b00a4df..25cc7fcd4 100644 --- a/lenskit/util/parallel.py +++ b/lenskit/util/parallel.py @@ -12,6 +12,7 @@ from concurrent.futures import ProcessPoolExecutor from abc import ABC, abstractmethod import pickle +from threadpoolctl import threadpool_limits from lenskit.sharing import persist, PersistedModel from lenskit.util.log import log_queue @@ -103,17 +104,8 @@ def _initialize_mp_worker(mkey, func, threads, log_queue, seed): _initialize_worker(log_queue, seed) global __work_model, __work_func - nnt_env = os.environ.get('NUMBA_NUM_THREADS', None) - if nnt_env is None or int(nnt_env) > threads: - _log.debug('configuring Numba thread count') - import numba - numba.config.NUMBA_NUM_THREADS = threads - try: - import mkl - _log.debug('configuring MKL thread count') - mkl.set_num_threads(threads) - except ImportError: - pass + # disable BLAS threading + threadpool_limits(limits=1, user_api="blas") __work_model = mkey # deferred function unpickling to minimize imports before initialization diff --git a/pyproject.toml b/pyproject.toml index a9927b01e..79bcebd2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "scipy >= 1.8.0", "numba >= 0.56, < 0.59", "cffi >= 1.15.0", - "psutil >= 5", + "threadpoolctl >=3.0", "binpickle >= 0.3.2", "seedbank >= 0.1.0", "csr >= 0.5",