diff --git a/.flake8 b/.flake8 index 39142f074..98f0dc50e 100644 --- a/.flake8 +++ b/.flake8 @@ -1,27 +1,4 @@ -[flake8] -# TODO: remove all of these ignores other than W503,W504,B008 -# `black` will handle enforcement of styling, and we will have no opinionated -# ignore rules -# any cases in which we actually need to ignore a rule (e.g. E402) we will mark -# the relevant segment with noqa comments as necessary -# -# D203: 1 blank line required before class docstring -# E124: closing bracket does not match visual indentation -# E126: continuation line over-indented for hanging indent -# This one is bad. Sometimes ordering matters, conditional imports -# setting env vars necessary etc. -# E402: module level import not at top of file -# E129: Visual indent to not match indent as next line, counter eg here: -# https://github.com/PyCQA/pycodestyle/issues/386 -# -# E203,W503,W504: conflict with black formatting sometimes -# B008: a flake8-bugbear rule which fails on idiomatic typer usage (consider -# re-enabling this once everything else is fixed and updating usage) -ignore = D203, E124, E126, E402, E129, W605, W503, W504, E203, F401, B008 +[flake8] # black-compatible +ignore = W503, W504, E203, B008 # TODO: reduce this to 88 once `black` is applied to all code -max-line-length = 160 -exclude = parsl/executors/serialize/, test_import_fail.py -# F632 is comparing constant literals with == instead of "is" -per-file-ignores = funcx_sdk/funcx/sdk/client.py:F632, - funcx_endpoint/funcx/endpoint/auth.py:F821, - funcx_endpoint/funcx/serialize/base.py:F821 +max-line-length = 88 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 391985466..b2327551d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -9,73 +9,68 @@ on: pull_request: jobs: - test: - strategy: - matrix: - python-version: [3.7] + lint: runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + - name: install pre-commit + run: | + python -m pip install -U pip setuptools wheel + python -m pip install pre-commit + - name: run pre-commit + run: pre-commit run -a + test-sdk: + runs-on: ubuntu-latest steps: - - uses: actions/checkout@master - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - name: Get latest pip version - run: | - python -m pip install --upgrade pip setuptools wheel - - name: Lint - run: | - pip install pre-commit - pre-commit run -a - - name: Install dependencies for funcx-sdk - run: | - python -m pip install -r funcx_sdk/requirements.txt - python -m pip install -r funcx_sdk/test-requirements.txt - pip list - - name: Check for vulnerabilities in libraries - run: | - pip install safety - pip freeze | safety check - - name: Test sdk by just importing - run: | - cd funcx_sdk - pip install . - python -c "from funcx.sdk.client import FuncXClient" - cd .. -# - name: Test with pytest -# run: | -# pytest - - name: Install dependencies for funcx-endpoint - run: | - python -m pip install -r funcx_endpoint/requirements.txt - python -m pip install -r funcx_endpoint/test-requirements.txt - pip list - - name: Check for vulnerabilities in libraries - run: | - pip install safety - pip freeze | safety check - - name: Test funcx-endpoint by just importing - run: | - cd funcx_endpoint - pip install . - python -c "from funcx_endpoint.version import VERSION" - funcx-endpoint -v - cd .. 
- - name: Lint with Flake8 - run: | - flake8 funcx_endpoint - - name: Test with pytest - run: | - PYTHONPATH=funcx_endpoint python -m coverage run -m pytest funcx_endpoint/tests/funcx_endpoint - - name: Report coverage with Codecov - run: | - codecov --token=${{ secrets.CODECOV_TOKEN }} + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: install requirements + run: | + python -m pip install -U pip setuptools wheel + python -m pip install './funcx_sdk[test]' + pip install safety + - name: run safety check + run: safety check + + # TODO: remove this test + # This is the weakest test which does anything, checking that the client can + # be imported. As soon as pytest is running again, remove this. + - name: check importable + run: python -c "from funcx.sdk.client import FuncXClient" + # - name: run pytest + # run: | + # cd funcx_sdk + # pytest + + test-endpoint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: install requirements + run: | + python -m pip install -U pip setuptools wheel + python -m pip install './funcx_endpoint[test]' + pip install safety + - name: run safety check + run: safety check + - name: run pytest + run: | + PYTHONPATH=funcx_endpoint python -m coverage run -m pytest funcx_endpoint/tests/funcx_endpoint publish: # only trigger on pushes to the main repo (not forks, and not PRs) if: ${{ github.repository == 'funcx-faas/funcX' && github.event_name == 'push' }} - needs: test + needs: + - lint + - test-sdk + - test-endpoint runs-on: ubuntu-latest strategy: matrix: diff --git a/.github/workflows/daily.yaml b/.github/workflows/daily.yaml new file mode 100644 index 000000000..de4e05c8e --- /dev/null +++ b/.github/workflows/daily.yaml @@ -0,0 +1,57 @@ +name: daily +on: + # build every day at 4:00 AM UTC + schedule: + - cron: '0 4 * * *' + workflow_dispatch: + +jobs: + safety-check-sdk: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + ref: main + - uses: actions/setup-python@v1 + - name: install requirements + run: | + python -m pip install --upgrade pip setuptools wheel + python -m pip install './funcx_sdk' + python -m pip install safety + - name: run safety check + run: safety check + + safety-check-endpoint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + ref: main + - uses: actions/setup-python@v1 + - name: install requirements + run: | + python -m pip install --upgrade pip setuptools wheel + python -m pip install './funcx_endpoint' + python -m pip install safety + - name: run safety check + run: safety check + + notify: + runs-on: ubuntu-latest + needs: + - safety-check-sdk + - safety-check-endpoint + if: failure() + steps: + # FIXME: make this send to a listhost or Slack + - name: Send mail + uses: dawidd6/action-send-mail@v3 + with: + server_address: smtp.gmail.com + server_port: 465 + username: ${{secrets.MAIL_USERNAME}} + password: ${{secrets.MAIL_PASSWORD}} + subject: ${{ github.repository }} - Daily Check ${{ job.status }} + to: ryan.chard@gmail.com,rchard@anl.gov,chard@uchicago.edu,yadudoc1729@gmail.com,josh@globus.org,bengal1@illinois.edu,benc@hawaga.org.uk,sirosen@globus.org,uriel@globus.org + from: funcX Tests # + body: The daily ${{ github.repository }} workflow failed! 
diff --git a/.github/workflows/hourly.yaml b/.github/workflows/hourly.yaml index 4abfb9d7c..9cbc6a087 100644 --- a/.github/workflows/hourly.yaml +++ b/.github/workflows/hourly.yaml @@ -10,7 +10,7 @@ on: description: "manual test" jobs: - tutorial_test: + smoke-test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -22,25 +22,28 @@ jobs: - name: Install dependencies for funcx-sdk and test requirements run: | python -m pip install --upgrade pip setuptools wheel - python -m pip install ./funcx_sdk - python -m pip install -r funcx_sdk/test-requirements.txt - - name: Check for vulnerabilities in libraries - run: | - pip install safety - safety check + python -m pip install './funcx_sdk[test]' + # fixme: Remove this next install line once issue #640 is fixed. + python -m pip install './funcx_endpoint' + python -m pip install safety - name: Run smoke tests to check liveness of hosted services run: | pytest -v funcx_endpoint/tests/smoke_tests --api-client-id ${{ secrets.API_CLIENT_ID }} --api-client-secret ${{ secrets.API_CLIENT_SECRET }} - # FIXME: make this send to a listhost or Slack - - name: Send mail - if: ${{ failure() }} - uses: dawidd6/action-send-mail@v3 - with: - server_address: smtp.gmail.com - server_port: 465 - username: ${{secrets.MAIL_USERNAME}} - password: ${{secrets.MAIL_PASSWORD}} - subject: ${{ github.repository }} - Tutorial test ${{ job.status }} - to: ryan.chard@gmail.com,rchard@anl.gov,chard@uchicago.edu,yadudoc1729@gmail.com,josh@globus.org,bengal1@illinois.edu,benc@hawaga.org.uk,sirosen@globus.org,uriel@globus.org - from: funcX Tests # - body: The ${{ github.repository }} test ${{ github.workflow }} exited with status - ${{ job.status }}! + + notify: + runs-on: ubuntu-latest + needs: [smoke-test] + if: failure() + steps: + # FIXME: make this send to a listhost or Slack + - name: Send mail + uses: dawidd6/action-send-mail@v3 + with: + server_address: smtp.gmail.com + server_port: 465 + username: ${{secrets.MAIL_USERNAME}} + password: ${{secrets.MAIL_PASSWORD}} + subject: ${{ github.repository }} - Hourly smoke test failed + to: ryan.chard@gmail.com,rchard@anl.gov,chard@uchicago.edu,yadudoc1729@gmail.com,josh@globus.org,bengal1@illinois.edu,benc@hawaga.org.uk,sirosen@globus.org,uriel@globus.org + from: funcX Tests # + body: The hourly ${{ github.repository }} workflow failed! diff --git a/.github/workflows/smoke_test.yaml b/.github/workflows/smoke_test.yaml index 6aa569e9a..4c12c2737 100644 --- a/.github/workflows/smoke_test.yaml +++ b/.github/workflows/smoke_test.yaml @@ -21,14 +21,12 @@ jobs: - uses: actions/setup-python@v1 with: python-version: 3.7 - - name: Install dependencies for funcx-sdk and test requirements + - name: install requirements run: | python -m pip install --upgrade pip setuptools wheel - python -m pip install ./funcx_sdk - python -m pip install -r funcx_sdk/test-requirements.txt - - name: Test sdk by just importing - run: | - python -c "from funcx.sdk.client import FuncXClient" + python -m pip install './funcx_sdk[test]' + # fixme: Remove this next install line once issue #640 is fixed. 
+ python -m pip install './funcx_endpoint' - name: Run smoke tests to check liveness of hosted services run: | pytest -v funcx_endpoint/tests/smoke_tests --api-client-id ${{ secrets.API_CLIENT_ID }} --api-client-secret ${{ secrets.API_CLIENT_SECRET }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cabc51d2a..54ad96044 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,9 +8,6 @@ repos: hooks: - id: check-merge-conflict - id: trailing-whitespace - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://github.com/sirosen/check-jsonschema rev: 0.3.1 hooks: @@ -19,9 +16,6 @@ repos: rev: 21.5b1 hooks: - id: black - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://github.com/timothycrosley/isort rev: 5.8.0 hooks: @@ -29,17 +23,11 @@ repos: # explicitly pass settings file so that isort does not try to deduce # which settings to use based on a file's directory args: ["--settings-path", ".isort.cfg"] - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://github.com/asottile/pyupgrade rev: v2.17.0 hooks: - id: pyupgrade args: ["--py36-plus"] - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 01773fa79..c701642ee 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,3 +30,18 @@ After installing `pre-commit`, run in the repo to configure hooks. > NOTE: If necessary, you can always skip hooks with `git commit --no-verify` + +## Installing Testing Requirements + +Testing requirements for each of the two packages in this repository +(funcx-sdk and funcx-endpoint) are specified as installable extras. 
+ +To install the funcx-sdk test requirements + + cd funcx_sdk + pip install '.[test]' + +To install the funcx-endpoint test requirements + + cd funcx_endpoint + pip install '.[test]' diff --git a/docs/conf.py b/docs/conf.py index e8eb9f11e..9dde0e7f3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,10 +16,8 @@ import os import sys -import requests - sys.path.insert(0, os.path.abspath("../funcx_sdk/")) -import funcx +import funcx # noqa:E402 # -- Project information ----------------------------------------------------- diff --git a/docs/configs/bluewaters.py b/docs/configs/bluewaters.py index bb1530836..eddd6588c 100644 --- a/docs/configs/bluewaters.py +++ b/docs/configs/bluewaters.py @@ -10,7 +10,7 @@ # PLEASE UPDATE user_opts BEFORE USE user_opts = { 'bluewaters': { - 'worker_init': 'module load bwpy;source anaconda3/etc/profile.d/conda.sh;conda activate funcx_testing_py3.7', + 'worker_init': 'module load bwpy;source anaconda3/etc/profile.d/conda.sh;conda activate funcx_testing_py3.7', # noqa: E501 'scheduler_options': '', } } diff --git a/docs/configs/polaris.py b/docs/configs/polaris.py new file mode 100644 index 000000000..869e7b2c4 --- /dev/null +++ b/docs/configs/polaris.py @@ -0,0 +1,44 @@ +from parsl.launchers import SingleNodeLauncher +from parsl.providers import PBSProProvider + +from funcx_endpoint.endpoint.utils.config import Config +from funcx_endpoint.executors import HighThroughputExecutor +from funcx_endpoint.strategies import SimpleStrategy + +# fmt: off + +# PLEASE UPDATE user_opts BEFORE USE +user_opts = { + 'polaris': { + # Node setup: activate necessary conda environment and such. + 'worker_init': '', + # PBS directives (header lines): for array jobs pass '-J' option + # Set ncpus=32, otherwise it defaults to 1 on Polaris + 'scheduler_options': '#PBS -l select=32:ncpus=32', + } +} + +config = Config( + executors=[ + HighThroughputExecutor( + max_workers_per_node=1, + strategy=SimpleStrategy(max_idletime=300), + address='10.230.2.72', + provider=PBSProProvider( + launcher=SingleNodeLauncher(), + queue='workq', + scheduler_options=user_opts['polaris']['scheduler_options'], + # Command to be run before starting a worker, such as: + # 'module load Anaconda; source activate parsl_env'. + worker_init=user_opts['polaris']['worker_init'], + walltime='01:00:00', + nodes_per_block=1, + init_blocks=0, + min_blocks=0, + max_blocks=1, + ), + ) + ], +) + +# fmt: on diff --git a/docs/configs/uchicago_ai_cluster.py b/docs/configs/uchicago_ai_cluster.py index c4b69af0e..ba9a06b23 100644 --- a/docs/configs/uchicago_ai_cluster.py +++ b/docs/configs/uchicago_ai_cluster.py @@ -1,6 +1,6 @@ from parsl.addresses import address_by_hostname from parsl.launchers import SrunLauncher -from parsl.providers import LocalProvider, SlurmProvider +from parsl.providers import SlurmProvider from funcx_endpoint.endpoint.utils.config import Config from funcx_endpoint.executors import HighThroughputExecutor @@ -28,12 +28,17 @@ partition='general', # Launch 4 managers per node, each bound to 1 GPU - # This is a hack. We use hostname ; to terminate the srun command, and start our own + # This is a hack. We use hostname ; to terminate the srun command, and + # start our own + # # DO NOT MODIFY unless you know what you are doing. 
- launcher=SrunLauncher(overrides=(f'hostname; srun --ntasks={TOTAL_WORKERS} ' - f'--ntasks-per-node={WORKERS_PER_NODE} ' - f'--gpus-per-task=rtx2080ti:{GPUS_PER_WORKER} ' - f'--gpu-bind=map_gpu:{GPU_MAP}') + launcher=SrunLauncher( + overrides=( + f'hostname; srun --ntasks={TOTAL_WORKERS} ' + f'--ntasks-per-node={WORKERS_PER_NODE} ' + f'--gpus-per-task=rtx2080ti:{GPUS_PER_WORKER} ' + f'--gpu-bind=map_gpu:{GPU_MAP}' + ) ), # Scale between 0-1 blocks with 2 nodes per block diff --git a/docs/configuring.rst b/docs/configuring.rst index cbefb2c0a..d2f84855b 100644 --- a/docs/configuring.rst +++ b/docs/configuring.rst @@ -93,6 +93,18 @@ using the `CobaltProvider`. This configuration assumes that the script is being .. literalinclude:: configs/cooley.py +Polaris (ALCF) +^^^^^^^^^^^^ + +.. image:: images/ALCF_Polaris.jpeg + +The following snippet shows an example configuration for executing on Argonne Leadership Computing Facility's +**Polaris** cluster. This example uses the `HighThroughputExecutor` and connects to Polaris's PBS scheduler +using the `PBSProProvider`. This configuration assumes that the script is being executed on the login node of Polaris (edtb-02). + +.. literalinclude:: configs/polaris.py + + Cori (NERSC) ^^^^^^^^^^^^ diff --git a/docs/doc-requirements.txt b/docs/doc-requirements.txt index ba67ae0ad..528f1a9a6 100644 --- a/docs/doc-requirements.txt +++ b/docs/doc-requirements.txt @@ -1,4 +1,3 @@ --r ../funcx_sdk/requirements.txt --r ../funcx_sdk/test-requirements.txt +../funcx_sdk nbsphinx sphinx_rtd_theme diff --git a/docs/endpoints.rst b/docs/endpoints.rst index d0d2fe1c0..545692276 100644 --- a/docs/endpoints.rst +++ b/docs/endpoints.rst @@ -63,7 +63,7 @@ targeting before you start the endpoint. funcX is configured using a :class:`~funcx_endpoint.endpoint.utils.config.Config` object. funcX uses `Parsl `_ to manage resources. For more information, see the :class:`~funcx_endpoint.endpoint.utils.config.Config` class documentation and the -`Parsl documentation `_ . +`Parsl documentation `_ . .. note:: If the ENDPOINT_NAME is not specified, a default endpoint named "default" is configured. 
diff --git a/docs/images/ALCF_Polaris.jpeg b/docs/images/ALCF_Polaris.jpeg new file mode 100644 index 000000000..3847f77b8 Binary files /dev/null and b/docs/images/ALCF_Polaris.jpeg differ diff --git a/funcx_endpoint/funcx_endpoint/endpoint/config.py b/funcx_endpoint/funcx_endpoint/endpoint/config.py index 432eb7bbd..22394a72c 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/config.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/config.py @@ -1,17 +1,16 @@ -import globus_sdk -import parsl import os -from parsl.config import Config +import globus_sdk +from parsl.addresses import address_by_route from parsl.channels import LocalChannel -from parsl.providers import LocalProvider, KubernetesProvider +from parsl.config import Config from parsl.executors import HighThroughputExecutor -from parsl.addresses import address_by_route +from parsl.providers import KubernetesProvider, LocalProvider # GlobusAuth-related secrets -SECRET_KEY = os.environ.get('secret_key') -GLOBUS_KEY = os.environ.get('globus_key') -GLOBUS_CLIENT = os.environ.get('globus_client') +SECRET_KEY = os.environ.get("secret_key") +GLOBUS_KEY = os.environ.get("globus_key") +GLOBUS_CLIENT = os.environ.get("globus_client") FUNCX_URL = "https://funcx.org/" FUNCX_HUB_URL = "3.88.81.131" @@ -31,10 +30,9 @@ def _load_auth_client(): _prod = True if _prod: - app = globus_sdk.ConfidentialAppAuthClient(GLOBUS_CLIENT, - GLOBUS_KEY) + app = globus_sdk.ConfidentialAppAuthClient(GLOBUS_CLIENT, GLOBUS_KEY) else: - app = globus_sdk.ConfidentialAppAuthClient('', '') + app = globus_sdk.ConfidentialAppAuthClient("", "") return app @@ -62,7 +60,7 @@ def _get_parsl_config(): ), ) ], - strategy=None + strategy=None, ) return config @@ -77,31 +75,35 @@ def _get_executor(container): """ executor = HighThroughputExecutor( - label=container['container_uuid'], - cores_per_worker=1, - max_workers=1, - poll_period=10, - # launch_cmd="ls; sleep 3600", - worker_logdir_root='runinfo', - # worker_debug=True, - address=address_by_route(), - provider=KubernetesProvider( - namespace="dlhub-privileged", - image=container['location'], - nodes_per_block=1, - init_blocks=1, - max_blocks=1, - parallelism=1, - worker_init="""pip install git+https://github.com/Parsl/parsl; + label=container["container_uuid"], + cores_per_worker=1, + max_workers=1, + poll_period=10, + # launch_cmd="ls; sleep 3600", + worker_logdir_root="runinfo", + # worker_debug=True, + address=address_by_route(), + provider=KubernetesProvider( + namespace="dlhub-privileged", + image=container["location"], + nodes_per_block=1, + init_blocks=1, + max_blocks=1, + parallelism=1, + worker_init="""pip install git+https://github.com/Parsl/parsl; pip install git+https://github.com/funcx-faas/funcX; export PYTHONPATH=$PYTHONPATH:/home/ubuntu:/app""", - # security=None, - secret="ryan-kube-secret", - pod_name=container['name'].replace('.', '-').replace("_", '-').replace('/', '-').lower(), - # secret="minikube-aws-ecr", - # user_id=32781, - # group_id=10253, - # run_as_non_root=True - ), - ) + # security=None, + secret="ryan-kube-secret", + pod_name=container["name"] + .replace(".", "-") + .replace("_", "-") + .replace("/", "-") + .lower(), + # secret="minikube-aws-ecr", + # user_id=32781, + # group_id=10253, + # run_as_non_root=True + ), + ) return [executor] diff --git a/funcx_endpoint/funcx_endpoint/endpoint/default_config.py b/funcx_endpoint/funcx_endpoint/endpoint/default_config.py index cd451ad5c..10e5553e3 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/default_config.py +++ 
b/funcx_endpoint/funcx_endpoint/endpoint/default_config.py @@ -1,16 +1,19 @@ +from parsl.providers import LocalProvider + from funcx_endpoint.endpoint.utils.config import Config from funcx_endpoint.executors import HighThroughputExecutor -from parsl.providers import LocalProvider config = Config( - executors=[HighThroughputExecutor( - provider=LocalProvider( - init_blocks=1, - min_blocks=0, - max_blocks=1, - ), - )], - funcx_service_address='https://api2.funcx.org/v2' + executors=[ + HighThroughputExecutor( + provider=LocalProvider( + init_blocks=1, + min_blocks=0, + max_blocks=1, + ), + ) + ], + funcx_service_address="https://api2.funcx.org/v2", ) # For now, visible_to must be a list of URNs for globus auth users or groups, e.g.: @@ -22,5 +25,5 @@ "organization": "", "department": "", "public": False, - "visible_to": [] + "visible_to": [], } diff --git a/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py b/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py index a3fa83fb0..83724f66a 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py @@ -1,57 +1,43 @@ import glob -from importlib.machinery import SourceFileLoader -import json import logging import os import pathlib -import random -import shutil -import signal -import sys -import time -import uuid -from string import Template - -import daemon -import daemon.pidfile -import psutil -import requests -import typer -from retry import retry +from importlib.machinery import SourceFileLoader -import funcx -import zmq +import typer -from funcx_endpoint.endpoint import default_config as endpoint_default_config -from funcx_endpoint.executors.high_throughput import global_config as funcx_default_config -from funcx_endpoint.endpoint.interchange import EndpointInterchange from funcx_endpoint.endpoint.endpoint_manager import EndpointManager -from funcx.sdk.client import FuncXClient +from funcx_endpoint.logging_config import setup_logging app = typer.Typer() -logger = None +log = logging.getLogger(__name__) def version_callback(value): if value: import funcx_endpoint - typer.echo("FuncX endpoint version: {}".format(funcx_endpoint.__version__)) + + typer.echo(f"FuncX endpoint version: {funcx_endpoint.__version__}") raise typer.Exit() def complete_endpoint_name(): # Manager context is not initialized at this point, so we assume the default # the funcx_dir path of ~/.funcx - funcx_dir = os.path.join(pathlib.Path.home(), '.funcx') - config_files = glob.glob(os.path.join(funcx_dir, '*', 'config.py')) + funcx_dir = os.path.join(pathlib.Path.home(), ".funcx") + config_files = glob.glob(os.path.join(funcx_dir, "*", "config.py")) for config_file in config_files: yield os.path.basename(os.path.dirname(config_file)) @app.command(name="configure", help="Configure an endpoint") def configure_endpoint( - name: str = typer.Argument("default", help="endpoint name", autocompletion=complete_endpoint_name), - endpoint_config: str = typer.Option(None, "--endpoint-config", help="endpoint config file") + name: str = typer.Argument( + "default", help="endpoint name", autocompletion=complete_endpoint_name + ), + endpoint_config: str = typer.Option( + None, "--endpoint-config", help="endpoint config file" + ), ): """Configure an endpoint @@ -63,8 +49,10 @@ def configure_endpoint( @app.command(name="start", help="Start an endpoint by name") def start_endpoint( - name: str = typer.Argument("default", autocompletion=complete_endpoint_name), - endpoint_uuid: str = typer.Option(None, help="The UUID for the endpoint to 
register with") + name: str = typer.Argument("default", autocompletion=complete_endpoint_name), + endpoint_uuid: str = typer.Option( + None, help="The UUID for the endpoint to register with" + ), ): """Start an endpoint @@ -94,37 +82,45 @@ def start_endpoint( endpoint_dir = os.path.join(manager.funcx_dir, name) if not os.path.exists(endpoint_dir): - msg = (f'\nEndpoint {name} is not configured!\n' - '1. Please create a configuration template with:\n' - f'\tfuncx-endpoint configure {name}\n' - '2. Update the configuration\n' - '3. Start the endpoint\n') + msg = ( + f"\nEndpoint {name} is not configured!\n" + "1. Please create a configuration template with:\n" + f"\tfuncx-endpoint configure {name}\n" + "2. Update the configuration\n" + "3. Start the endpoint\n" + ) print(msg) return try: - endpoint_config = SourceFileLoader('config', - os.path.join(endpoint_dir, manager.funcx_config_file_name)).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(endpoint_dir, manager.funcx_config_file_name) + ).load_module() except Exception: - manager.logger.exception('funcX v0.2.0 made several non-backwards compatible changes to the config. ' - 'Your config might be out of date. ' - 'Refer to https://funcx.readthedocs.io/en/latest/endpoints.html#configuring-funcx') + log.exception( + "funcX v0.2.0 made several non-backwards compatible changes to the config. " + "Your config might be out of date. " + "Refer to " + "https://funcx.readthedocs.io/en/latest/endpoints.html#configuring-funcx" + ) raise manager.start_endpoint(name, endpoint_uuid, endpoint_config) @app.command(name="stop") -def stop_endpoint(name: str = typer.Argument("default", autocompletion=complete_endpoint_name)): - """ Stops an endpoint using the pidfile - - """ +def stop_endpoint( + name: str = typer.Argument("default", autocompletion=complete_endpoint_name) +): + """Stops an endpoint using the pidfile""" manager.stop_endpoint(name) @app.command(name="restart") -def restart_endpoint(name: str = typer.Argument("default", autocompletion=complete_endpoint_name)): +def restart_endpoint( + name: str = typer.Argument("default", autocompletion=complete_endpoint_name) +): """Restarts an endpoint""" stop_endpoint(name) start_endpoint(name) @@ -132,52 +128,68 @@ def restart_endpoint(name: str = typer.Argument("default", autocompletion=comple @app.command(name="list") def list_endpoints(): - """ List all available endpoints - """ + """List all available endpoints""" manager.list_endpoints() @app.command(name="delete") def delete_endpoint( - name: str = typer.Argument(..., autocompletion=complete_endpoint_name), - autoconfirm: bool = typer.Option(False, "-y", help="Do not ask for confirmation to delete.") + name: str = typer.Argument(..., autocompletion=complete_endpoint_name), + autoconfirm: bool = typer.Option( + False, "-y", help="Do not ask for confirmation to delete." 
+ ), ): """Deletes an endpoint and its config.""" if not autoconfirm: - typer.confirm(f"Are you sure you want to delete the endpoint <{name}>?", abort=True) + typer.confirm( + f"Are you sure you want to delete the endpoint <{name}>?", abort=True + ) manager.delete_endpoint(name) @app.callback() def main( - ctx: typer.Context, - _: bool = typer.Option(None, "--version", "-v", callback=version_callback, is_eager=True), - debug: bool = typer.Option(False, "--debug", "-d"), - config_dir: str = typer.Option(os.path.join(pathlib.Path.home(), '.funcx'), "--config_dir", "-c", help="override default config dir") + ctx: typer.Context, + _: bool = typer.Option( + None, "--version", "-v", callback=version_callback, is_eager=True + ), + debug: bool = typer.Option(False, "--debug", "-d"), + config_dir: str = typer.Option( + os.path.join(pathlib.Path.home(), ".funcx"), + "--config_dir", + "-c", + help="override default config dir", + ), ): - # Note: no docstring here; the docstring for @app.callback is used as a help message for overall app. - # Sets up global variables in the State wrapper (debug flag, config dir, default config file). - # For commands other than `init`, we ensure the existence of the config directory and file. - - global logger - funcx.set_stream_logger(name='endpoint', - level=logging.DEBUG if debug else logging.INFO) - logger = logging.getLogger('endpoint') - logger.debug("Command: {}".format(ctx.invoked_subcommand)) + # Note: no docstring here; the docstring for @app.callback is used as a help + # message for overall app. + # + # Sets up global variables in the State wrapper (debug flag, config dir, default + # config file). + # + # For commands other than `init`, we ensure the existence of the config directory + # and file. + + setup_logging(debug=debug) + log.debug("Command: %s", ctx.invoked_subcommand) global manager - manager = EndpointManager(funcx_dir=config_dir, - debug=debug) + manager = EndpointManager(funcx_dir=config_dir, debug=debug) # Otherwise, we ensure that configs exist if not os.path.exists(manager.funcx_config_file): - logger.info(f"No existing configuration found at {manager.funcx_config_file}. Initializing...") + log.info( + "No existing configuration found at %s. 
Initializing...", + manager.funcx_config_file, + ) manager.init_endpoint() - logger.debug("Loading config files from {}".format(manager.funcx_dir)) + log.debug(f"Loading config files from {manager.funcx_dir}") - funcx_config = SourceFileLoader('global_config', manager.funcx_config_file).load_module() + funcx_config = SourceFileLoader( + "global_config", manager.funcx_config_file + ).load_module() manager.funcx_config = funcx_config.global_options @@ -186,5 +198,5 @@ def cli_run(): app() -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py b/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py index e4f196064..8529a1812 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py @@ -6,7 +6,6 @@ import shutil import signal import sys -import time import uuid from string import Template @@ -15,30 +14,33 @@ import psutil import texttable import typer - -import funcx import zmq from globus_sdk import GlobusAPIError, NetworkError +from funcx.sdk.client import FuncXClient from funcx.utils.response_errors import FuncxResponseError from funcx_endpoint.endpoint import default_config as endpoint_default_config -from funcx_endpoint.executors.high_throughput import global_config as funcx_default_config from funcx_endpoint.endpoint.interchange import EndpointInterchange from funcx_endpoint.endpoint.register_endpoint import register_endpoint from funcx_endpoint.endpoint.results_ack import ResultsAckHandler -from funcx.sdk.client import FuncXClient +from funcx_endpoint.executors.high_throughput import ( + global_config as funcx_default_config, +) +from funcx_endpoint.logging_config import setup_logging -logger = logging.getLogger("endpoint.endpoint_manager") +log = logging.getLogger(__name__) class EndpointManager: - """ EndpointManager is primarily responsible for configuring, launching and stopping the Endpoint. + """ + EndpointManager is primarily responsible for configuring, launching and stopping + the Endpoint. """ - def __init__(self, - funcx_dir=os.path.join(pathlib.Path.home(), '.funcx'), - debug=False): - """ Initialize the EndpointManager + def __init__( + self, funcx_dir=os.path.join(pathlib.Path.home(), ".funcx"), debug=False + ): + """Initialize the EndpointManager Parameters ---------- @@ -49,18 +51,18 @@ def __init__(self, debug: Bool Enable debug logging. 
Default: False """ - self.funcx_config_file_name = 'config.py' + self.funcx_config_file_name = "config.py" self.debug = debug self.funcx_dir = funcx_dir - self.funcx_config_file = os.path.join(self.funcx_dir, self.funcx_config_file_name) + self.funcx_config_file = os.path.join( + self.funcx_dir, self.funcx_config_file_name + ) self.funcx_default_config_template = funcx_default_config.__file__ self.funcx_config = {} - self.name = 'default' - global logger - self.logger = logger + self.name = "default" def init_endpoint_dir(self, endpoint_config=None): - """ Initialize a clean endpoint dir + """Initialize a clean endpoint dir Returns if an endpoint_dir already exists Parameters @@ -70,10 +72,12 @@ def init_endpoint_dir(self, endpoint_config=None): """ endpoint_dir = os.path.join(self.funcx_dir, self.name) - self.logger.debug(f"Creating endpoint dir {endpoint_dir}") + log.debug(f"Creating endpoint dir {endpoint_dir}") os.makedirs(endpoint_dir, exist_ok=True) - endpoint_config_target_file = os.path.join(endpoint_dir, self.funcx_config_file_name) + endpoint_config_target_file = os.path.join( + endpoint_dir, self.funcx_config_file_name + ) if endpoint_config: shutil.copyfile(endpoint_config, endpoint_config_target_file) return endpoint_dir @@ -95,54 +99,60 @@ def configure_endpoint(self, name, endpoint_config): if not os.path.exists(endpoint_dir): self.init_endpoint_dir(endpoint_config=endpoint_config) - print(f'A default profile has been create for <{self.name}> at {new_config_file}') - print('Configure this file and try restarting with:') - print(f' $ funcx-endpoint start {self.name}') + print( + f"A default profile has been create for <{self.name}> " + f"at {new_config_file}" + ) + print("Configure this file and try restarting with:") + print(f" $ funcx-endpoint start {self.name}") else: - print(f'config dir <{self.name}> already exsits') - raise Exception('ConfigExists') + print(f"config dir <{self.name}> already exsits") + raise Exception("ConfigExists") def init_endpoint(self): """Setup funcx dirs and default endpoint config files TODO : Every mechanism that will update the config file, must be using a - locking mechanism, ideally something like fcntl https://docs.python.org/3/library/fcntl.html - to ensure that multiple endpoint invocations do not mangle the funcx config files - or the lockfile module. + locking mechanism, ideally something like fcntl [1] + to ensure that multiple endpoint invocations do not mangle the funcx config + files or the lockfile module. + + [1] https://docs.python.org/3/library/fcntl.html """ _ = FuncXClient() if os.path.exists(self.funcx_config_file): typer.confirm( "Are you sure you want to initialize this directory? 
" - f"This will erase everything in {self.funcx_dir}", abort=True + f"This will erase everything in {self.funcx_dir}", + abort=True, ) - self.logger.info("Wiping all current configs in {}".format(self.funcx_dir)) + log.info(f"Wiping all current configs in {self.funcx_dir}") backup_dir = self.funcx_dir + ".bak" try: - self.logger.debug(f"Removing old backups in {backup_dir}") + log.debug(f"Removing old backups in {backup_dir}") shutil.rmtree(backup_dir) except OSError: pass os.renames(self.funcx_dir, backup_dir) if os.path.exists(self.funcx_config_file): - self.logger.debug("Config file exists at {}".format(self.funcx_config_file)) + log.debug(f"Config file exists at {self.funcx_config_file}") return try: os.makedirs(self.funcx_dir, exist_ok=True) except Exception as e: - print("[FuncX] Caught exception during registration {}".format(e)) + print(f"[FuncX] Caught exception during registration {e}") shutil.copyfile(self.funcx_default_config_template, self.funcx_config_file) def check_endpoint_json(self, endpoint_json, endpoint_uuid): if os.path.exists(endpoint_json): - with open(endpoint_json, 'r') as fp: - self.logger.debug("Connection info loaded from prior registration record") + with open(endpoint_json) as fp: + log.debug("Connection info loaded from prior registration record") reg_info = json.load(fp) - endpoint_uuid = reg_info['endpoint_id'] + endpoint_uuid = reg_info["endpoint_id"] elif not endpoint_uuid: endpoint_uuid = str(uuid.uuid4()) return endpoint_uuid @@ -151,18 +161,23 @@ def start_endpoint(self, name, endpoint_uuid, endpoint_config): self.name = name endpoint_dir = os.path.join(self.funcx_dir, self.name) - endpoint_json = os.path.join(endpoint_dir, 'endpoint.json') + endpoint_json = os.path.join(endpoint_dir, "endpoint.json") # These certs need to be recreated for every registration - keys_dir = os.path.join(endpoint_dir, 'certificates') + keys_dir = os.path.join(endpoint_dir, "certificates") os.makedirs(keys_dir, exist_ok=True) - client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "endpoint") + client_public_file, client_secret_file = zmq.auth.create_certificates( + keys_dir, "endpoint" + ) client_public_key, _ = zmq.auth.load_certificate(client_public_file) - client_public_key = client_public_key.decode('utf-8') + client_public_key = client_public_key.decode("utf-8") # This is to ensure that at least 1 executor is defined if not endpoint_config.config.executors: - raise Exception(f"Endpoint config file at {endpoint_dir} is missing executor definitions") + raise Exception( + f"Endpoint config file at {endpoint_dir} is missing " + "executor definitions" + ) funcx_client_options = { "funcx_service_address": endpoint_config.config.funcx_service_address, @@ -172,19 +187,22 @@ def start_endpoint(self, name, endpoint_uuid, endpoint_config): endpoint_uuid = self.check_endpoint_json(endpoint_json, endpoint_uuid) - self.logger.info(f"Starting endpoint with uuid: {endpoint_uuid}") + log.info(f"Starting endpoint with uuid: {endpoint_uuid}") - pid_file = os.path.join(endpoint_dir, 'daemon.pid') + pid_file = os.path.join(endpoint_dir, "daemon.pid") pid_check = self.check_pidfile(pid_file) # if the pidfile exists, we should return early because we don't # want to attempt to create a new daemon when one is already # potentially running with the existing pidfile - if pid_check['exists']: - if pid_check['active']: - self.logger.info("Endpoint is already active") + if pid_check["exists"]: + if pid_check["active"]: + log.info("Endpoint is already active") 
sys.exit(-1) else: - self.logger.info("A prior Endpoint instance appears to have been terminated without proper cleanup. Cleaning up now.") + log.info( + "A prior Endpoint instance appears to have been terminated without " + "proper cleanup. Cleaning up now." + ) self.pidfile_cleanup(pid_file) results_ack_handler = ResultsAckHandler(endpoint_dir=endpoint_dir) @@ -193,100 +211,152 @@ def start_endpoint(self, name, endpoint_uuid, endpoint_config): results_ack_handler.load() results_ack_handler.persist() except Exception: - self.logger.exception("Caught exception while attempting load and persist of outstanding results") + log.exception( + "Caught exception while attempting load and persist of outstanding " + "results" + ) sys.exit(-1) # Create a daemon context # If we are running a full detached daemon then we will send the output to # log files, otherwise we can piggy back on our stdout if endpoint_config.config.detach_endpoint: - stdout = open(os.path.join(endpoint_dir, endpoint_config.config.stdout), 'a+') - stderr = open(os.path.join(endpoint_dir, endpoint_config.config.stderr), 'a+') + stdout = open( + os.path.join(endpoint_dir, endpoint_config.config.stdout), "a+" + ) + stderr = open( + os.path.join(endpoint_dir, endpoint_config.config.stderr), "a+" + ) else: stdout = sys.stdout stderr = sys.stderr try: - context = daemon.DaemonContext(working_directory=endpoint_dir, - umask=0o002, - pidfile=daemon.pidfile.PIDLockFile(pid_file), - stdout=stdout, - stderr=stderr, - detach_process=endpoint_config.config.detach_endpoint) + context = daemon.DaemonContext( + working_directory=endpoint_dir, + umask=0o002, + pidfile=daemon.pidfile.PIDLockFile(pid_file), + stdout=stdout, + stderr=stderr, + detach_process=endpoint_config.config.detach_endpoint, + ) except Exception: - self.logger.exception("Caught exception while trying to setup endpoint context dirs") + log.exception( + "Caught exception while trying to setup endpoint context dirs" + ) sys.exit(-1) # place registration after everything else so that the endpoint will # only be registered if everything else has been set up successfully reg_info = None try: - reg_info = register_endpoint(funcx_client, endpoint_uuid, endpoint_dir, self.name, logger=self.logger) + reg_info = register_endpoint( + funcx_client, endpoint_uuid, endpoint_dir, self.name + ) # if the service sends back an error response, it will be a FuncxResponseError except FuncxResponseError as e: # an example of an error that could conceivably occur here would be # if the service could not register this endpoint with the forwarder # because the forwarder was unreachable if e.http_status_code >= 500: - self.logger.exception("Caught exception while attempting endpoint registration") - self.logger.critical("Endpoint registration will be retried in the new endpoint daemon " - "process. The endpoint will not work until it is successfully registered.") + log.exception("Caught exception while attempting endpoint registration") + log.critical( + "Endpoint registration will be retried in the new endpoint daemon " + "process. The endpoint will not work until it is successfully " + "registered." + ) else: raise e # if the service has an unexpected internal error and is unable to send # back a FuncxResponseError except GlobusAPIError as e: if e.http_status >= 500: - self.logger.exception("Caught exception while attempting endpoint registration") - self.logger.critical("Endpoint registration will be retried in the new endpoint daemon " - "process. 
The endpoint will not work until it is successfully registered.") + log.exception("Caught exception while attempting endpoint registration") + log.critical( + "Endpoint registration will be retried in the new endpoint daemon " + "process. The endpoint will not work until it is successfully " + "registered." + ) else: raise e # if the service is unreachable due to a timeout or connection error except NetworkError as e: # the output of a NetworkError exception is huge and unhelpful, so # it seems better to just stringify it here and get a concise error - self.logger.exception(f"Caught exception while attempting endpoint registration: {e}") - self.logger.critical("funcx-endpoint is unable to reach the funcX service due to a NetworkError \n" - "Please make sure that the funcX service address you provided is reachable \n" - "and then attempt restarting the endpoint") + log.exception( + f"Caught exception while attempting endpoint registration: {e}" + ) + log.critical( + "funcx-endpoint is unable to reach the funcX service due to a " + "NetworkError \n" + "Please make sure that the funcX service address you provided is " + "reachable \n" + "and then attempt restarting the endpoint" + ) exit(-1) except Exception: raise if reg_info: - self.logger.info("Launching endpoint daemon process") + log.info("Launching endpoint daemon process") else: - self.logger.critical("Launching endpoint daemon process with errors noted above") + log.critical("Launching endpoint daemon process with errors noted above") + + # NOTE + # It's important that this log is emitted before we enter the daemon context + # because daemonization closes down everything, a log message inside the + # context won't write the currently configured loggers + logfile = os.path.join(endpoint_dir, "endpoint.log") + log.info( + "Logging will be reconfigured for the daemon. 
logfile=%s , debug=%s", + logfile, + self.debug, + ) with context: - self.daemon_launch(endpoint_uuid, endpoint_dir, keys_dir, endpoint_config, reg_info, funcx_client_options, results_ack_handler) + setup_logging(logfile=logfile, debug=self.debug, console_enabled=False) + self.daemon_launch( + endpoint_uuid, + endpoint_dir, + keys_dir, + endpoint_config, + reg_info, + funcx_client_options, + results_ack_handler, + ) - def daemon_launch(self, endpoint_uuid, endpoint_dir, keys_dir, endpoint_config, reg_info, funcx_client_options, results_ack_handler): + def daemon_launch( + self, + endpoint_uuid, + endpoint_dir, + keys_dir, + endpoint_config, + reg_info, + funcx_client_options, + results_ack_handler, + ): # Configure the parameters for the interchange optionals = {} - if 'endpoint_address' in self.funcx_config: - optionals['interchange_address'] = self.funcx_config['endpoint_address'] - - optionals['logdir'] = endpoint_dir - - if self.debug: - optionals['logging_level'] = logging.DEBUG - - ic = EndpointInterchange(endpoint_config.config, - endpoint_id=endpoint_uuid, - keys_dir=keys_dir, - endpoint_dir=endpoint_dir, - endpoint_name=self.name, - reg_info=reg_info, - funcx_client_options=funcx_client_options, - results_ack_handler=results_ack_handler, - **optionals) + if "endpoint_address" in self.funcx_config: + optionals["interchange_address"] = self.funcx_config["endpoint_address"] + + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id=endpoint_uuid, + keys_dir=keys_dir, + endpoint_dir=endpoint_dir, + endpoint_name=self.name, + reg_info=reg_info, + funcx_client_options=funcx_client_options, + results_ack_handler=results_ack_handler, + logdir=endpoint_dir, + **optionals, + ) ic.start() - self.logger.critical("Interchange terminated.") + log.critical("Interchange terminated.") def stop_endpoint(self, name): self.name = name @@ -294,15 +364,16 @@ def stop_endpoint(self, name): pid_file = os.path.join(endpoint_dir, "daemon.pid") pid_check = self.check_pidfile(pid_file) - # The process is active if the PID file exists and the process it points to is a funcx-endpoint - if pid_check['active']: - self.logger.debug(f"{self.name} has a daemon.pid file") + # The process is active if the PID file exists and the process it points to is + # a funcx-endpoint + if pid_check["active"]: + log.debug(f"{self.name} has a daemon.pid file") pid = None - with open(pid_file, 'r') as f: + with open(pid_file) as f: pid = int(f.read()) # Attempt terminating try: - self.logger.debug("Signalling process: {}".format(pid)) + log.debug(f"Signalling process: {pid}") # For all the processes, including the deamon and its child process tree # Send SIGTERM to the processes # Wait for 200ms @@ -322,25 +393,25 @@ def stop_endpoint(self, name): pass # Wait to confirm that the pid file disappears if not os.path.exists(pid_file): - self.logger.info("Endpoint <{}> is now stopped".format(self.name)) + log.info(f"Endpoint <{self.name}> is now stopped") except OSError: - self.logger.warning("Endpoint <{}> could not be terminated".format(self.name)) - self.logger.warning("Attempting Endpoint <{}> cleanup".format(self.name)) + log.warning(f"Endpoint <{self.name}> could not be terminated") + log.warning(f"Attempting Endpoint <{self.name}> cleanup") os.remove(pid_file) sys.exit(-1) # The process is not active, but the PID file exists and needs to be deleted - elif pid_check['exists']: + elif pid_check["exists"]: self.pidfile_cleanup(pid_file) else: - self.logger.info("Endpoint <{}> is not active.".format(self.name)) + 
log.info(f"Endpoint <{self.name}> is not active.") def delete_endpoint(self, name): self.name = name endpoint_dir = os.path.join(self.funcx_dir, self.name) if not os.path.exists(endpoint_dir): - self.logger.warning("Endpoint <{}> does not exist".format(self.name)) + log.warning(f"Endpoint <{self.name}> does not exist") sys.exit(-1) # stopping the endpoint should handle all of the process cleanup before @@ -348,10 +419,10 @@ def delete_endpoint(self, name): self.stop_endpoint(self.name) shutil.rmtree(endpoint_dir) - self.logger.info("Endpoint <{}> has been deleted.".format(self.name)) + log.info(f"Endpoint <{self.name}> has been deleted.") def check_pidfile(self, filepath): - """ Helper function to identify possible dead endpoints + """Helper function to identify possible dead endpoints Returns a record with 'exists' and 'active' fields indicating whether the pidfile exists, and whether the process is active if it does exist @@ -363,12 +434,9 @@ def check_pidfile(self, filepath): Path to the pidfile """ if not os.path.exists(filepath): - return { - "exists": False, - "active": False - } + return {"exists": False, "active": False} - pid = int(open(filepath, 'r').read().strip()) + pid = int(open(filepath).read().strip()) active = False try: @@ -380,40 +448,37 @@ def check_pidfile(self, filepath): # it means the endpoint has been terminated without proper cleanup active = True - return { - "exists": True, - "active": active - } + return {"exists": True, "active": active} def pidfile_cleanup(self, filepath): os.remove(filepath) - self.logger.info("Endpoint <{}> has been cleaned up.".format(self.name)) + log.info(f"Endpoint <{self.name}> has been cleaned up.") def list_endpoints(self): table = texttable.Texttable() - headings = ['Endpoint Name', 'Status', 'Endpoint ID'] + headings = ["Endpoint Name", "Status", "Endpoint ID"] table.header(headings) - config_files = glob.glob(os.path.join(self.funcx_dir, '*', 'config.py')) + config_files = glob.glob(os.path.join(self.funcx_dir, "*", "config.py")) for config_file in config_files: endpoint_dir = os.path.dirname(config_file) endpoint_name = os.path.basename(endpoint_dir) - status = 'Initialized' + status = "Initialized" endpoint_id = None - endpoint_json = os.path.join(endpoint_dir, 'endpoint.json') + endpoint_json = os.path.join(endpoint_dir, "endpoint.json") if os.path.exists(endpoint_json): - with open(endpoint_json, 'r') as f: + with open(endpoint_json) as f: endpoint_info = json.load(f) - endpoint_id = endpoint_info['endpoint_id'] - pid_check = self.check_pidfile(os.path.join(endpoint_dir, 'daemon.pid')) - if pid_check['active']: - status = 'Running' - elif pid_check['exists']: - status = 'Disconnected' + endpoint_id = endpoint_info["endpoint_id"] + pid_check = self.check_pidfile(os.path.join(endpoint_dir, "daemon.pid")) + if pid_check["active"]: + status = "Running" + elif pid_check["exists"]: + status = "Disconnected" else: - status = 'Stopped' + status = "Stopped" table.add_row([endpoint_name, status, endpoint_id]) diff --git a/funcx_endpoint/funcx_endpoint/endpoint/interchange.py b/funcx_endpoint/funcx_endpoint/endpoint/interchange.py index dbb24cb51..7b0e79562 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/interchange.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/interchange.py @@ -1,38 +1,41 @@ #!/usr/bin/env python import argparse -from typing import Tuple, Dict - -import zmq +import logging import os -import sys -import platform -import random -import time import pickle -import logging +import platform import queue -import 
threading -import json -import daemon -import collections -from retry.api import retry_call import signal -from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue +import sys +import threading +import time +from queue import Queue +from typing import Tuple +import zmq from parsl.executors.errors import ScalingFailed from parsl.version import VERSION as PARSL_VERSION +from retry.api import retry_call -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task -from funcx_endpoint.executors.high_throughput.messages import EPStatusReport, Heartbeat, TaskStatusCode, ResultsAck -from funcx.sdk.client import FuncXClient -from funcx import set_file_logger from funcx import __version__ as funcx_sdk_version -from funcx_endpoint import __version__ as funcx_endpoint_version -from funcx_endpoint.executors.high_throughput.interchange_task_dispatch import naive_interchange_task_dispatch +from funcx.sdk.client import FuncXClient from funcx.serialize import FuncXSerializer -from funcx_endpoint.endpoint.taskqueue import TaskQueue +from funcx_endpoint import __version__ as funcx_endpoint_version from funcx_endpoint.endpoint.register_endpoint import register_endpoint -from queue import Queue +from funcx_endpoint.endpoint.taskqueue import TaskQueue +from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue +from funcx_endpoint.executors.high_throughput.messages import ( + COMMAND_TYPES, + Heartbeat, + Message, + MessageType, + ResultsAck, + Task, + TaskStatusCode, +) +from funcx_endpoint.logging_config import setup_logging + +log = logging.getLogger(__name__) LOOP_SLOWDOWN = 0.0 # in seconds HEARTBEAT_CODE = (2 ** 32) - 1 @@ -40,18 +43,17 @@ class ShutdownRequest(Exception): - """ Exception raised when any async component receives a ShutdownRequest - """ + """Exception raised when any async component receives a ShutdownRequest""" def __init__(self): self.tstamp = time.time() def __repr__(self): - return "Shutdown request received at {}".format(self.tstamp) + return f"Shutdown request received at {self.tstamp}" class ManagerLost(Exception): - """ Task lost due to worker loss. Worker is considered lost when multiple heartbeats + """Task lost due to worker loss. Worker is considered lost when multiple heartbeats have been missed. """ @@ -60,25 +62,11 @@ def __init__(self, worker_id): self.tstamp = time.time() def __repr__(self): - return "Task failure due to loss of worker {}".format(self.worker_id) + return f"Task failure due to loss of worker {self.worker_id}" -class BadRegistration(Exception): - ''' A new Manager tried to join the executor with a BadRegistration message - ''' - - def __init__(self, worker_id, critical=False): - self.worker_id = worker_id - self.tstamp = time.time() - self.handled = "critical" if critical else "suppressed" - - def __repr__(self): - return "Manager:{} caused a {} failure".format(self.worker_id, - self.handled) - - -class EndpointInterchange(object): - """ Interchange is a task orchestrator for distributed systems. +class EndpointInterchange: + """Interchange is a task orchestrator for distributed systems. 1. Asynchronously queue large volume of tasks (>100K) 2. 
Allow for workers to join and leave the union @@ -90,23 +78,23 @@ class EndpointInterchange(object): TODO: We most likely need a PUB channel to send out global commandzs, like shutdown """ - def __init__(self, - config, - client_address="127.0.0.1", - interchange_address="127.0.0.1", - client_ports: Tuple[int, int, int] = (50055, 50056, 50057), - launch_cmd=None, - logdir=".", - logging_level=logging.INFO, - endpoint_id=None, - keys_dir=".curve", - suppress_failure=True, - endpoint_dir=".", - endpoint_name="default", - reg_info=None, - funcx_client_options=None, - results_ack_handler=None, - ): + def __init__( + self, + config, + client_address="127.0.0.1", + interchange_address="127.0.0.1", + client_ports: Tuple[int, int, int] = (50055, 50056, 50057), + launch_cmd=None, + logdir=".", + endpoint_id=None, + keys_dir=".curve", + suppress_failure=True, + endpoint_dir=".", + endpoint_name="default", + reg_info=None, + funcx_client_options=None, + results_ack_handler=None, + ): """ Parameters ---------- @@ -114,10 +102,12 @@ def __init__(self, Funcx config object that describes how compute should be provisioned client_address : str - The ip address at which the parsl client can be reached. Default: "127.0.0.1" + The ip address at which the parsl client can be reached. + Default: "127.0.0.1" interchange_address : str - The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1" + The ip address at which the workers will be able to reach the Interchange. + Default: "127.0.0.1" client_ports : Tuple[int, int, int] The ports at which the client can be reached @@ -128,18 +118,16 @@ def __init__(self, logdir : str Parsl log directory paths. Logs and temp files go here. Default: '.' - logging_level : int - Logging level as defined in the logging module. Default: logging.INFO (20) - keys_dir : str - Directory from where keys used for communicating with the funcX service (forwarders) - are stored + Directory from where keys used for communicating with the funcX + service (forwarders) are stored endpoint_id : str Identity string that identifies the endpoint to the broker suppress_failure : Bool - When set to True, the interchange will attempt to suppress failures. Default: False + When set to True, the interchange will attempt to suppress failures. 
+ Default: False endpoint_dir : str Endpoint directory path to store registration info in @@ -148,23 +136,20 @@ def __init__(self, Name of endpoint reg_info : Dict - Registration info from initial registration on endpoint start, if it succeeded + Registration info from initial registration on endpoint start, if it + succeeded funcx_client_options : Dict FuncXClient initialization options """ self.logdir = logdir - try: - os.makedirs(self.logdir) - except FileExistsError: - pass - - global logger - - logger = set_file_logger(os.path.join(self.logdir, "endpoint.log"), name="funcx_endpoint", level=logging_level) - logger.info("Initializing EndpointInterchange process with Endpoint ID: {}".format(endpoint_id)) + log.info( + "Initializing EndpointInterchange process with Endpoint ID: {}".format( + endpoint_id + ) + ) self.config = config - logger.info("Got config : {}".format(config)) + log.info(f"Got config: {config}") self.client_address = client_address self.interchange_address = interchange_address @@ -199,28 +184,30 @@ def __init__(self, self.results_ack_handler = results_ack_handler - logger.info("Interchange address is {}".format(self.interchange_address)) + log.info(f"Interchange address is {self.interchange_address}") self.endpoint_id = endpoint_id - self.current_platform = {'parsl_v': PARSL_VERSION, - 'python_v': "{}.{}.{}".format(sys.version_info.major, - sys.version_info.minor, - sys.version_info.micro), - 'libzmq_v': zmq.zmq_version(), - 'pyzmq_v': zmq.pyzmq_version(), - 'os': platform.system(), - 'hname': platform.node(), - 'funcx_sdk_version': funcx_sdk_version, - 'funcx_endpoint_version': funcx_endpoint_version, - 'registration': self.endpoint_id, - 'dir': os.getcwd()} - - logger.info("Platform info: {}".format(self.current_platform)) + self.current_platform = { + "parsl_v": PARSL_VERSION, + "python_v": "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ), + "libzmq_v": zmq.zmq_version(), + "pyzmq_v": zmq.pyzmq_version(), + "os": platform.system(), + "hname": platform.node(), + "funcx_sdk_version": funcx_sdk_version, + "funcx_endpoint_version": funcx_endpoint_version, + "registration": self.endpoint_id, + "dir": os.getcwd(), + } + + log.info(f"Platform info: {self.current_platform}") try: self.load_config() except Exception: - logger.exception("Caught exception") + log.exception("Caught exception") raise self.tasks = set() @@ -229,36 +216,41 @@ def __init__(self, self._test_start = False def load_config(self): - """ Load the config - """ - logger.info("Loading endpoint local config") + """Load the config""" + log.info("Loading endpoint local config") self.results_passthrough = mpQueue() self.executors = {} for executor in self.config.executors: - logger.info(f"Initializing executor: {executor.label}") + log.info(f"Initializing executor: {executor.label}") executor.funcx_service_address = self.config.funcx_service_address if not executor.endpoint_id: executor.endpoint_id = self.endpoint_id else: if not executor.endpoint_id == self.endpoint_id: - raise Exception('InconsistentEndpointId') + raise Exception("InconsistentEndpointId") self.executors[executor.label] = executor if executor.run_dir is None: executor.run_dir = self.logdir def start_executors(self): - logger.info("Starting Executors") + log.info("Starting Executors") for executor in self.config.executors: - if hasattr(executor, 'passthrough') and executor.passthrough is True: + if hasattr(executor, "passthrough") and executor.passthrough is True: 
executor.start(results_passthrough=self.results_passthrough) def apply_reg_info(self, reg_info): - self.client_address = reg_info['public_ip'] - self.client_ports = reg_info['tasks_port'], reg_info['results_port'], reg_info['commands_port'], + self.client_address = reg_info["public_ip"] + self.client_ports = ( + reg_info["tasks_port"], + reg_info["results_port"], + reg_info["commands_port"], + ) def register_endpoint(self): - reg_info = register_endpoint(self.funcx_client, self.endpoint_id, self.endpoint_dir, self.endpoint_name) + reg_info = register_endpoint( + self.funcx_client, self.endpoint_id, self.endpoint_dir, self.endpoint_name + ) self.apply_reg_info(reg_info) return reg_info @@ -271,32 +263,34 @@ def migrate_tasks_to_internal(self, quiesce_event): quiesce_event : threading.Event Event to let the thread know when it is time to die. """ - logger.info("[TASK_PULL_THREAD] Starting") + log.info("[TASK_PULL_THREAD] Starting") try: self._task_puller_loop(quiesce_event) except Exception: - logger.exception("[TASK_PULL_THREAD] Unhandled exception") + log.exception("[TASK_PULL_THREAD] Unhandled exception") finally: quiesce_event.set() self.task_incoming.close() - logger.info("[TASK_PULL_THREAD] Thread loop exiting") + log.info("[TASK_PULL_THREAD] Thread loop exiting") def _task_puller_loop(self, quiesce_event): task_counter = 0 # Create the incoming queue in the thread to keep # zmq.context in the same thread. zmq.context is not thread-safe - self.task_incoming = TaskQueue(self.client_address, - port=self.client_ports[0], - identity=self.endpoint_id, - mode='client', - set_hwm=True, - keys_dir=self.keys_dir, - RCVTIMEO=1000, - linger=0) - - self.task_incoming.put('forwarder', pickle.dumps(self.current_platform)) - logger.info(f"Task incoming on tcp://{self.client_address}:{self.client_ports[0]}") + self.task_incoming = TaskQueue( + self.client_address, + port=self.client_ports[0], + identity=self.endpoint_id, + mode="client", + set_hwm=True, + keys_dir=self.keys_dir, + RCVTIMEO=1000, + linger=0, + ) + + self.task_incoming.put("forwarder", pickle.dumps(self.current_platform)) + log.info(f"Task incoming on tcp://{self.client_address}:{self.client_ports[0]}") self.last_heartbeat = time.time() @@ -304,7 +298,10 @@ def _task_puller_loop(self, quiesce_event): try: if int(time.time() - self.last_heartbeat) > self.heartbeat_threshold: - logger.critical("[TASK_PULL_THREAD] Missed too many heartbeats. Setting quiesce event.") + log.critical( + "[TASK_PULL_THREAD] Missed too many heartbeats. " + "Setting quiesce event." 
+ ) quiesce_event.set() break @@ -314,160 +311,184 @@ def _task_puller_loop(self, quiesce_event): self.last_heartbeat = time.time() except zmq.Again: # We just timed out while attempting to receive - logger.debug("[TASK_PULL_THREAD] {} tasks in internal queue".format(self.total_pending_task_count)) + log.debug( + "[TASK_PULL_THREAD] {} tasks in internal queue".format( + self.total_pending_task_count + ) + ) continue except Exception: - logger.exception("[TASK_PULL_THREAD] Unknown exception while waiting for tasks") + log.exception( + "[TASK_PULL_THREAD] Unknown exception while waiting for tasks" + ) # YADU: TODO We need to do the routing here try: msg = Message.unpack(raw_msg) except Exception: - logger.exception("[TASK_PULL_THREAD] Failed to unpack message from forwarder") + log.exception( + "[TASK_PULL_THREAD] Failed to unpack message from forwarder" + ) pass - if msg == 'STOP': + if msg == "STOP": self._kill_event.set() quiesce_event.set() break elif isinstance(msg, Heartbeat): - logger.info("[TASK_PULL_THREAD] Got heartbeat from funcx-forwarder") + log.info("[TASK_PULL_THREAD] Got heartbeat from funcx-forwarder") elif isinstance(msg, Task): - logger.info(f"[TASK_PULL_THREAD] Received task:{msg.task_id}") + log.info(f"[TASK_PULL_THREAD] Received task:{msg.task_id}") self.pending_task_queue.put(msg) self.total_pending_task_count += 1 - self.task_status_deltas[msg.task_id] = TaskStatusCode.WAITING_FOR_NODES + self.task_status_deltas[ + msg.task_id + ] = TaskStatusCode.WAITING_FOR_NODES task_counter += 1 - logger.debug(f"[TASK_PULL_THREAD] Task counter:{task_counter} Pending Tasks: {self.total_pending_task_count}") + log.debug( + "[TASK_PULL_THREAD] Task counter:%s Pending Tasks: %s", + task_counter, + self.total_pending_task_count, + ) elif isinstance(msg, ResultsAck): self.results_ack_handler.ack(msg.task_id) else: - logger.warning(f"[TASK_PULL_THREAD] Unknown message type received: {msg}") + log.warning( + f"[TASK_PULL_THREAD] Unknown message type received: {msg}" + ) except Exception: - logger.exception("[TASK_PULL_THREAD] Something really bad happened") + log.exception("[TASK_PULL_THREAD] Something really bad happened") continue def get_container(self, container_uuid): - """ Get the container image location if it is not known to the interchange""" + """Get the container image location if it is not known to the interchange""" if container_uuid not in self.containers: - if container_uuid == 'RAW' or not container_uuid: - self.containers[container_uuid] = 'RAW' + if container_uuid == "RAW" or not container_uuid: + self.containers[container_uuid] = "RAW" else: try: - container = self.funcx_client.get_container(container_uuid, self.config.container_type) + container = self.funcx_client.get_container( + container_uuid, self.config.container_type + ) except Exception: - logger.exception("[FETCH_CONTAINER] Unable to resolve container location") - self.containers[container_uuid] = 'RAW' + log.exception( + "[FETCH_CONTAINER] Unable to resolve container location" + ) + self.containers[container_uuid] = "RAW" else: - logger.info("[FETCH_CONTAINER] Got container info: {}".format(container)) - self.containers[container_uuid] = container.get('location', 'RAW') + log.info(f"[FETCH_CONTAINER] Got container info: {container}") + self.containers[container_uuid] = container.get("location", "RAW") return self.containers[container_uuid] def _command_server(self, quiesce_event): - """ Command server to run async command to the interchange + """Command server to run async command to the interchange - We want 
to be able to receive the following not yet implemented/updated commands: + We want to be able to receive the following not yet implemented/updated + commands: - OutstandingCount - ListManagers (get outstanding broken down by manager) - HoldWorker - Shutdown """ - logger.debug("[COMMAND] Command Server Starting") + log.debug("[COMMAND] Command Server Starting") try: self._command_server_loop(quiesce_event) except Exception: - logger.exception("[COMMAND] Unhandled exception") + log.exception("[COMMAND] Unhandled exception") finally: quiesce_event.set() self.command_channel.close() - logger.info("[COMMAND] Thread loop exiting") + log.info("[COMMAND] Thread loop exiting") def _command_server_loop(self, quiesce_event): - self.command_channel = TaskQueue(self.client_address, - port=self.client_ports[2], - identity=self.endpoint_id, - mode='client', - RCVTIMEO=1000, # in milliseconds - keys_dir=self.keys_dir, - set_hwm=True, - linger=0) + self.command_channel = TaskQueue( + self.client_address, + port=self.client_ports[2], + identity=self.endpoint_id, + mode="client", + RCVTIMEO=1000, # in milliseconds + keys_dir=self.keys_dir, + set_hwm=True, + linger=0, + ) # TODO :Register all channels with the authentication string. - self.command_channel.put('forwarder', pickle.dumps({"registration": self.endpoint_id})) + self.command_channel.put( + "forwarder", pickle.dumps({"registration": self.endpoint_id}) + ) while not quiesce_event.is_set(): try: # Wait for 1000 ms buffer = self.command_channel.get(timeout=1000) - logger.debug(f"[COMMAND] Received command request {buffer}") + log.debug(f"[COMMAND] Received command request {buffer}") command = Message.unpack(buffer) if command.type not in COMMAND_TYPES: - logger.error("Received incorrect message type on command channel") + log.error("Received incorrect message type on command channel") self.command_channel.put(bytes()) continue if command.type is MessageType.HEARTBEAT_REQ: - logger.info("[COMMAND] Received synchonous HEARTBEAT_REQ from hub") - logger.info(f"[COMMAND] Replying with Heartbeat({self.endpoint_id})") + log.info("[COMMAND] Received synchonous HEARTBEAT_REQ from hub") + log.info(f"[COMMAND] Replying with Heartbeat({self.endpoint_id})") reply = Heartbeat(self.endpoint_id) - logger.debug("[COMMAND] Reply: {}".format(reply)) + log.debug(f"[COMMAND] Reply: {reply}") self.command_channel.put(reply.pack()) except zmq.Again: - # logger.debug("[COMMAND] is alive") + # log.debug("[COMMAND] is alive") continue def quiesce(self): """Temporarily stop everything on the interchange in order to reach a consistent state before attempting to start again. 
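Both channels set up above (the task-pulling queue and the command channel) are TaskQueue instances in DEALER ("client") mode with RCVTIMEO=1000, so a receive that finds nothing within a second raises zmq.Again, which the loops treat as a normal idle tick rather than an error. A minimal pyzmq sketch of that pattern, using an illustrative address and port rather than real forwarder values:

import zmq

# Receive-with-timeout pattern relied on by the interchange's task and command
# loops: RCVTIMEO turns an empty socket into a zmq.Again exception.
ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.setsockopt(zmq.RCVTIMEO, 1000)    # milliseconds, matching RCVTIMEO=1000 above
sock.setsockopt(zmq.LINGER, 0)
sock.connect("tcp://127.0.0.1:50057")  # illustrative address/port only

try:
    frames = sock.recv_multipart()
    # ... unpack and dispatch the message here ...
except zmq.Again:
    # Timed out: nothing arrived this cycle; loop around and re-check
    # quiesce/heartbeat state.
    pass
finally:
    sock.close()
    ctx.term()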
This must be called on the main thread """ - logger.info("Interchange Quiesce in progress (stopping and joining all threads)") + log.info("Interchange Quiesce in progress (stopping and joining all threads)") self._quiesce_event.set() self._task_puller_thread.join() self._command_thread.join() - logger.info("Saving unacked results to disk") + log.info("Saving unacked results to disk") try: self.results_ack_handler.persist() except Exception: - logger.exception("Caught exception while saving unacked results") - logger.warning("Interchange will continue without saving unacked results") + log.exception("Caught exception while saving unacked results") + log.warning("Interchange will continue without saving unacked results") # this must be called last to ensure the next interchange run will occur self._quiesce_event.clear() def stop(self): """Prepare the interchange for shutdown""" - logger.info("Shutting down EndpointInterchange") + log.info("Shutting down EndpointInterchange") # TODO: shut down executors gracefully - # kill_event must be set before quiesce_event because we need to guarantee that once - # the quiesce is complete, the interchange will not try to start again + # kill_event must be set before quiesce_event because we need to guarantee that + # once the quiesce is complete, the interchange will not try to start again self._kill_event.set() self._quiesce_event.set() def handle_sigterm(self, sig_num, curr_stack_frame): - logger.warning("Received SIGTERM, attempting to save unacked results to disk") + log.warning("Received SIGTERM, attempting to save unacked results to disk") try: self.results_ack_handler.persist() except Exception: - logger.exception("Caught exception while saving unacked results") + log.exception("Caught exception while saving unacked results") else: - logger.info("Unacked results successfully saved to disk") + log.info("Unacked results successfully saved to disk") sys.exit(1) def start(self): - """ Start the Interchange - """ - logger.info("Starting EndpointInterchange") + """Start the Interchange""" + log.info("Starting EndpointInterchange") signal.signal(signal.SIGTERM, self.handle_sigterm) @@ -486,72 +507,94 @@ def start(self): if self._test_start: break - logger.info("EndpointInterchange shutdown complete.") + log.info("EndpointInterchange shutdown complete.") def _start_threads_and_main(self): # re-register on every loop start if not self.initial_registration_complete: # Register the endpoint - logger.info("Running endpoint registration retry loop") - reg_info = retry_call(self.register_endpoint, delay=10, max_delay=300, backoff=1.2) - logger.info("Endpoint registered with UUID: {}".format(reg_info['endpoint_id'])) + log.info("Running endpoint registration retry loop") + reg_info = retry_call( + self.register_endpoint, delay=10, max_delay=300, backoff=1.2 + ) + log.info( + "Endpoint registered with UUID: {}".format(reg_info["endpoint_id"]) + ) self.initial_registration_complete = False - logger.info("Attempting connection to client at {} on ports: {},{},{}".format( - self.client_address, self.client_ports[0], self.client_ports[1], self.client_ports[2])) + log.info( + "Attempting connection to client at {} on ports: {},{},{}".format( + self.client_address, + self.client_ports[0], + self.client_ports[1], + self.client_ports[2], + ) + ) - self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal, - args=(self._quiesce_event, )) + self._task_puller_thread = threading.Thread( + target=self.migrate_tasks_to_internal, 
args=(self._quiesce_event,) + ) self._task_puller_thread.start() - self._command_thread = threading.Thread(target=self._command_server, - args=(self._quiesce_event, )) + self._command_thread = threading.Thread( + target=self._command_server, args=(self._quiesce_event,) + ) self._command_thread.start() try: self._main_loop() except Exception: - logger.exception("[MAIN] Unhandled exception") + log.exception("[MAIN] Unhandled exception") finally: self.results_outgoing.close() - logger.info("[MAIN] Thread loop exiting") + log.info("[MAIN] Thread loop exiting") def _main_loop(self): - self.results_outgoing = TaskQueue(self.client_address, - port=self.client_ports[1], - identity=self.endpoint_id, - mode='client', - keys_dir=self.keys_dir, - # Fail immediately if results cannot be sent back - SNDTIMEO=0, - set_hwm=True, - linger=0) - self.results_outgoing.put('forwarder', pickle.dumps({"registration": self.endpoint_id})) + self.results_outgoing = TaskQueue( + self.client_address, + port=self.client_ports[1], + identity=self.endpoint_id, + mode="client", + keys_dir=self.keys_dir, + # Fail immediately if results cannot be sent back + SNDTIMEO=0, + set_hwm=True, + linger=0, + ) + self.results_outgoing.put( + "forwarder", pickle.dumps({"registration": self.endpoint_id}) + ) # TODO: this resend must happen after any endpoint re-registration to # ensure there are not unacked results left resend_results_messages = self.results_ack_handler.get_unacked_results_list() if len(resend_results_messages) > 0: - logger.info(f"[MAIN] Resending {len(resend_results_messages)} previously unacked results") + log.info( + "[MAIN] Resending %s previously unacked results", + len(resend_results_messages), + ) # TODO: this should be a multipart send rather than a loop for results in resend_results_messages: - self.results_outgoing.put('forwarder', results) + self.results_outgoing.put("forwarder", results) executor = list(self.executors.values())[0] last = time.time() while not self._quiesce_event.is_set(): if last + self.heartbeat_threshold < time.time(): - logger.debug("[MAIN] alive") + log.debug("[MAIN] alive") last = time.time() try: # Adding results heartbeat to essentially force a TCP keepalive # without meddling with OS TCP keepalive defaults - self.results_outgoing.put('forwarder', b'HEARTBEAT') + self.results_outgoing.put("forwarder", b"HEARTBEAT") except Exception: - logger.exception("[MAIN] Sending heartbeat to the forwarder over the results channel has failed") + log.exception( + "[MAIN] Sending heartbeat to the forwarder over the results " + "channel has failed" + ) raise self.results_ack_handler.check_ack_counts() @@ -562,8 +605,7 @@ def _main_loop(self): except queue.Empty: pass except Exception: - logger.exception("[MAIN] Unhandled issue while waiting for pending tasks") - pass + log.exception("[MAIN] Unhandled issue while waiting for pending tasks") try: results = self.results_passthrough.get(False, 0.01) @@ -571,21 +613,24 @@ def _main_loop(self): task_id = results["task_id"] if task_id: self.results_ack_handler.put(task_id, results["message"]) - logger.info(f"Passing result to forwarder for task {task_id}") + log.info(f"Passing result to forwarder for task {task_id}") - # results will be a pickled dict with task_id, container_id, and results/exception - self.results_outgoing.put('forwarder', results["message"]) + # results will be a pickled dict with task_id, container_id, + # and results/exception + self.results_outgoing.put("forwarder", results["message"]) except queue.Empty: pass except Exception: - 
logger.exception("[MAIN] Something broke while forwarding results from executor to forwarder queues") + log.exception( + "[MAIN] Something broke while forwarding results from executor " + "to forwarder queues" + ) continue def get_status_report(self): - """ Get utilization numbers - """ + """Get utilization numbers""" total_cores = 0 total_mem = 0 core_hrs = 0 @@ -597,36 +642,43 @@ def get_status_report(self): live_workers = self.get_total_live_workers() for manager in self._ready_manager_queue: - total_cores += self._ready_manager_queue[manager]['cores'] - total_mem += self._ready_manager_queue[manager]['mem'] - active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time']) + total_cores += self._ready_manager_queue[manager]["cores"] + total_mem += self._ready_manager_queue[manager]["mem"] + active_dur = abs( + time.time() - self._ready_manager_queue[manager]["reg_time"] + ) core_hrs += (active_dur * total_cores) / 3600 - if self._ready_manager_queue[manager]['active']: + if self._ready_manager_queue[manager]["active"]: active_managers += 1 - free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers'] - - result_package = {'task_id': -2, - 'info': {'total_cores': total_cores, - 'total_mem': total_mem, - 'new_core_hrs': core_hrs - self.last_core_hr_counter, - 'total_core_hrs': round(core_hrs, 2), - 'managers': num_managers, - 'active_managers': active_managers, - 'total_workers': live_workers, - 'idle_workers': free_capacity, - 'pending_tasks': pending_tasks, - 'outstanding_tasks': outstanding_tasks, - 'worker_mode': self.config.worker_mode, - 'scheduler_mode': self.config.scheduler_mode, - 'scaling_enabled': self.config.scaling_enabled, - 'mem_per_worker': self.config.mem_per_worker, - 'cores_per_worker': self.config.cores_per_worker, - 'prefetch_capacity': self.config.prefetch_capacity, - 'max_blocks': self.config.provider.max_blocks, - 'min_blocks': self.config.provider.min_blocks, - 'max_workers_per_node': self.config.max_workers_per_node, - 'nodes_per_block': self.config.provider.nodes_per_block - }} + free_capacity += self._ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ] + + result_package = { + "task_id": -2, + "info": { + "total_cores": total_cores, + "total_mem": total_mem, + "new_core_hrs": core_hrs - self.last_core_hr_counter, + "total_core_hrs": round(core_hrs, 2), + "managers": num_managers, + "active_managers": active_managers, + "total_workers": live_workers, + "idle_workers": free_capacity, + "pending_tasks": pending_tasks, + "outstanding_tasks": outstanding_tasks, + "worker_mode": self.config.worker_mode, + "scheduler_mode": self.config.scheduler_mode, + "scaling_enabled": self.config.scaling_enabled, + "mem_per_worker": self.config.mem_per_worker, + "cores_per_worker": self.config.cores_per_worker, + "prefetch_capacity": self.config.prefetch_capacity, + "max_blocks": self.config.provider.max_blocks, + "min_blocks": self.config.provider.min_blocks, + "max_workers_per_node": self.config.max_workers_per_node, + "nodes_per_block": self.config.provider.nodes_per_block, + }, + } self.last_core_hr_counter = core_hrs return result_package @@ -642,22 +694,32 @@ def scale_out(self, blocks=1, task_type=None): if self.config.provider: self._block_counter += 1 external_block_id = str(self._block_counter) - if not task_type and self.config.scheduler_mode == 'hard': - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW') + if not task_type and self.config.scheduler_mode == "hard": + launch_cmd = 
self.launch_cmd.format( + block_id=external_block_id, worker_type="RAW" + ) else: - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type) + launch_cmd = self.launch_cmd.format( + block_id=external_block_id, worker_type=task_type + ) if not task_type: internal_block = self.config.provider.submit(launch_cmd, 1) else: - internal_block = self.config.provider.submit(launch_cmd, 1, task_type) - logger.debug("Launched block {}->{}".format(external_block_id, internal_block)) + internal_block = self.config.provider.submit( + launch_cmd, 1, task_type + ) + log.debug(f"Launched block {external_block_id}->{internal_block}") if not internal_block: - raise(ScalingFailed(self.provider.label, - "Attempts to provision nodes via provider has failed")) + raise ( + ScalingFailed( + self.provider.label, + "Attempts to provision nodes via provider has failed", + ) + ) self.blocks[external_block_id] = internal_block self.block_id_map[internal_block] = external_block_id else: - logger.error("No execution provider available") + log.error("No execution provider available") r = None return r @@ -675,14 +737,22 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): if block_ids is None: block_ids = [] if task_type: - logger.info("Scaling in blocks of specific task type {}. Let the provider decide which to kill".format(task_type)) + log.info( + "Scaling in blocks of specific task type %s. " + "Let the provider decide which to kill", + task_type, + ) if self.config.scaling_enabled and self.config.provider: to_kill, r = self.config.provider.cancel(blocks, task_type) - logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r)) + log.info(f"Get the killed blocks: {to_kill}, and status: {r}") for job in to_kill: - logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job)) + log.info( + "[scale_in] Getting the block_id map {} for job {}".format( + self.block_id_map, job + ) + ) block_id = self.block_id_map[job] - logger.info("[scale_in] Holding block {}".format(block_id)) + log.info(f"[scale_in] Holding block {block_id}") self._hold_block(block_id) self.blocks.pop(block_id) return r @@ -706,13 +776,16 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): return r def provider_status(self): - """ Get status of all blocks from the provider - """ + """Get status of all blocks from the provider""" status = [] if self.config.provider: - logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values()))) + log.debug( + "[MAIN] Getting the status of {} blocks.".format( + list(self.blocks.values()) + ) + ) status = self.config.provider.status(list(self.blocks.values())) - logger.debug("[MAIN] The status is {}".format(status)) + log.debug(f"[MAIN] The status is {status}") return status @@ -720,7 +793,8 @@ def provider_status(self): def starter(comm_q, *args, **kwargs): """Start the interchange process - The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__ + The executor is expected to call this function. 
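For context on the scale_out() path above: the endpoint's launch command is a plain str.format template in which the block id and worker type are substituted at submit time, with "RAW" used under hard scheduling when no task type is requested. The template and worker type below are hypothetical (the real launch command is assembled elsewhere in the codebase); they only illustrate which placeholders get filled in:

# Hypothetical template; real placeholders come from the configured launch_cmd.
launch_cmd = "process_worker_pool.py --block_id={block_id} --worker_type={worker_type}"

# Hard scheduling with no task type requested -> generic "RAW" workers
print(launch_cmd.format(block_id="1", worker_type="RAW"))

# A task-typed request routes the type through to the worker pool
print(launch_cmd.format(block_id="2", worker_type="my_container_type"))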
+ The args, kwargs match that of the Interchange.__init__ """ # logger = multiprocessing.get_logger() ic = EndpointInterchange(*args, **kwargs) @@ -732,42 +806,65 @@ def starter(comm_q, *args, **kwargs): def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-c", "--client_address", required=True, - help="Client address") - parser.add_argument("--client_ports", required=True, - help="client ports as a triple of outgoing,incoming,command") - parser.add_argument("--worker_port_range", - help="Worker port range as a tuple") - parser.add_argument("-l", "--logdir", default="./parsl_worker_logs", - help="Parsl worker log directory") - parser.add_argument("--worker_ports", default=None, - help="OPTIONAL, pair of workers ports to listen on, eg --worker_ports=50001,50005") - parser.add_argument("--suppress_failure", action='store_true', - help="Enables suppression of failures") - parser.add_argument("--endpoint_id", required=True, - help="Endpoint ID, used to identify the endpoint to the remote broker") - parser.add_argument("--hb_threshold", - help="Heartbeat threshold in seconds") - parser.add_argument("--config", default=None, - help="Configuration object that describes provisioning") - parser.add_argument("-d", "--debug", action='store_true', - help="Enables debug logging") + parser.add_argument("-c", "--client_address", required=True, help="Client address") + parser.add_argument( + "--client_ports", + required=True, + help="client ports as a triple of outgoing,incoming,command", + ) + parser.add_argument("--worker_port_range", help="Worker port range as a tuple") + parser.add_argument( + "-l", + "--logdir", + default="./parsl_worker_logs", + help="Parsl worker log directory", + ) + parser.add_argument( + "--worker_ports", + default=None, + help="OPTIONAL, pair of workers ports to listen on, " + "e.g. 
--worker_ports=50001,50005", + ) + parser.add_argument( + "--suppress_failure", + action="store_true", + help="Enables suppression of failures", + ) + parser.add_argument( + "--endpoint_id", + required=True, + help="Endpoint ID, used to identify the endpoint to the remote broker", + ) + parser.add_argument("--hb_threshold", help="Heartbeat threshold in seconds") + parser.add_argument( + "--config", + default=None, + help="Configuration object that describes provisioning", + ) + parser.add_argument( + "-d", "--debug", action="store_true", help="Enables debug logging" + ) print("Starting HTEX Intechange") args = parser.parse_args() + logdir = os.path.abspath(args.logdir) + os.makedirs(logdir, exist_ok=True) + setup_logging(logfile=os.path.join(logdir, "endpoint.log"), debug=args.debug) + optionals = {} - optionals['suppress_failure'] = args.suppress_failure - optionals['logdir'] = os.path.abspath(args.logdir) - optionals['client_address'] = args.client_address - optionals['client_ports'] = [int(i) for i in args.client_ports.split(',')] - optionals['endpoint_id'] = args.endpoint_id + optionals["suppress_failure"] = args.suppress_failure + optionals["logdir"] = os.path.abspath(args.logdir) + optionals["client_address"] = args.client_address + optionals["client_ports"] = [int(i) for i in args.client_ports.split(",")] + optionals["endpoint_id"] = args.endpoint_id # DEBUG ONLY : TODO: FIX if args.config is None: - from funcx_endpoint.endpoint.utils.config import Config from parsl.providers import LocalProvider + from funcx_endpoint.endpoint.utils.config import Config + config = Config( worker_debug=True, scaling_enabled=True, @@ -777,18 +874,18 @@ def cli_run(): max_blocks=1, ), max_workers_per_node=2, - funcx_service_address='http://127.0.0.1:8080' + funcx_service_address="http://127.0.0.1:8080", ) - optionals['config'] = config + optionals["config"] = config else: - optionals['config'] = args.config + optionals["config"] = args.config - if args.debug: - optionals['logging_level'] = logging.DEBUG if args.worker_ports: - optionals['worker_ports'] = [int(i) for i in args.worker_ports.split(',')] + optionals["worker_ports"] = [int(i) for i in args.worker_ports.split(",")] if args.worker_port_range: - optionals['worker_port_range'] = [int(i) for i in args.worker_port_range.split(',')] + optionals["worker_port_range"] = [ + int(i) for i in args.worker_port_range.split(",") + ] ic = EndpointInterchange(**optionals) ic.start() diff --git a/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py b/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py index 1b36f8d34..add62ee7b 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py @@ -1,17 +1,16 @@ -import os import json import logging +import os import funcx_endpoint -namespace_logger = logging.getLogger(__name__) +log = logging.getLogger(__name__) -def register_endpoint(funcx_client, endpoint_uuid, endpoint_dir, endpoint_name, logger=None): +def register_endpoint(funcx_client, endpoint_uuid, endpoint_dir, endpoint_name): """Register the endpoint and return the registration info. This function needs - to be isolated (including the logger which is passed in) so that the function - can both be called from the endpoint start process as well as the daemon process - that it spawns. + to be isolated so that the function can both be called from the endpoint start + process as well as the daemon process that it spawns. 
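The cli_run() changes above replace the old per-process logging_level plumbing with a single setup_logging(logfile=..., debug=...) call plus module-level log = logging.getLogger(__name__) loggers. The helper's implementation is not part of this diff, so the sketch below is only an assumption about the general shape of such a function, not the actual funcx_endpoint.logging_config code:

import logging

def setup_logging(logfile=None, debug=False):
    # Illustrative central configuration: one call wires up stream and file
    # handlers; individual modules just use logging.getLogger(__name__).
    level = logging.DEBUG if debug else logging.INFO
    handlers = [logging.StreamHandler()]
    if logfile:
        handlers.append(logging.FileHandler(logfile))
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s",
        handlers=handlers,
    )

log = logging.getLogger(__name__)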
Parameters ---------- @@ -30,38 +29,41 @@ def register_endpoint(funcx_client, endpoint_uuid, endpoint_dir, endpoint_name, logger : Logger Logger to use """ - if logger is None: - logger = namespace_logger - - logger.debug("Attempting registration") - logger.debug(f"Trying with eid : {endpoint_uuid}") - reg_info = funcx_client.register_endpoint(endpoint_name, - endpoint_uuid, - endpoint_version=funcx_endpoint.__version__) + log.debug("Attempting registration") + log.debug(f"Trying with eid : {endpoint_uuid}") + reg_info = funcx_client.register_endpoint( + endpoint_name, endpoint_uuid, endpoint_version=funcx_endpoint.__version__ + ) # this is a backup error handler in case an endpoint ID is not sent back # from the service or a bad ID is sent back - if 'endpoint_id' not in reg_info: - raise Exception("Endpoint ID was not included in the service's registration response.") - elif not isinstance(reg_info['endpoint_id'], str): + if "endpoint_id" not in reg_info: + raise Exception( + "Endpoint ID was not included in the service's registration response." + ) + elif not isinstance(reg_info["endpoint_id"], str): raise Exception("Endpoint ID sent by the service was not a string.") # NOTE: While all registration info is saved to endpoint.json, only the # endpoint UUID is reused from this file. The latest forwarder URI is used # every time we fetch registration info and register - with open(os.path.join(endpoint_dir, 'endpoint.json'), 'w+') as fp: + with open(os.path.join(endpoint_dir, "endpoint.json"), "w+") as fp: json.dump(reg_info, fp) - logger.debug("Registration info written to {}".format(os.path.join(endpoint_dir, 'endpoint.json'))) + log.debug( + "Registration info written to {}".format( + os.path.join(endpoint_dir, "endpoint.json") + ) + ) - certs_dir = os.path.join(endpoint_dir, 'certificates') + certs_dir = os.path.join(endpoint_dir, "certificates") os.makedirs(certs_dir, exist_ok=True) - server_keyfile = os.path.join(certs_dir, 'server.key') - logger.debug(f"Writing server key to {server_keyfile}") + server_keyfile = os.path.join(certs_dir, "server.key") + log.debug(f"Writing server key to {server_keyfile}") try: - with open(server_keyfile, 'w') as f: - f.write(reg_info['forwarder_pubkey']) + with open(server_keyfile, "w") as f: + f.write(reg_info["forwarder_pubkey"]) os.chmod(server_keyfile, 0o600) except Exception: - logger.exception("Failed to write server certificate") + log.exception("Failed to write server certificate") return reg_info diff --git a/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py b/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py index e222445e0..c066fcbb8 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py @@ -1,23 +1,20 @@ import logging -import time import os import pickle +import time -# The logger path needs to start with endpoint. while the current path -# start with funcx_endpoint.endpoint. 
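As a usage note for the register_endpoint() changes above: after a successful registration the endpoint directory holds endpoint.json (the full registration response, of which only endpoint_id is reused on later starts) and certificates/server.key (the forwarder public key, written with 0o600 permissions). A small illustration of reading those artifacts back; the endpoint directory path is illustrative, not prescribed by this diff:

import json
import os

endpoint_dir = os.path.expanduser("~/.funcx/default")  # illustrative path

# endpoint.json: full registration response; only endpoint_id is reused later.
with open(os.path.join(endpoint_dir, "endpoint.json")) as fp:
    reg_info = json.load(fp)
print(reg_info["endpoint_id"])

# certificates/server.key: forwarder public key, expected mode 0o600.
keyfile = os.path.join(endpoint_dir, "certificates", "server.key")
print(oct(os.stat(keyfile).st_mode & 0o777))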
-logger = logging.getLogger("endpoint.results_ack") +log = logging.getLogger(__name__) -class ResultsAckHandler(): +class ResultsAckHandler: """ Tracks task results by task ID, discarding results after they have been ack'ed """ def __init__(self, endpoint_dir): - """ Initialize results storage and timing for log updates - """ + """Initialize results storage and timing for log updates""" self.endpoint_dir = endpoint_dir - self.data_path = os.path.join(self.endpoint_dir, 'unacked_results.p') + self.data_path = os.path.join(self.endpoint_dir, "unacked_results.p") self.unacked_results = {} # how frequently to log info about acked and unacked results @@ -26,7 +23,7 @@ def __init__(self, endpoint_dir): self.acked_count = 0 def put(self, task_id, message): - """ Put sent task result into Unacked Dict + """Put sent task result into Unacked Dict Parameters ---------- @@ -39,7 +36,7 @@ def put(self, task_id, message): self.unacked_results[task_id] = message def ack(self, task_id): - """ Ack a task result that was sent. Nothing happens if the task ID is not + """Ack a task result that was sent. Nothing happens if the task ID is not present in the Unacked Dict Parameters @@ -51,21 +48,25 @@ def ack(self, task_id): if acked_task: self.acked_count += 1 unacked_count = len(self.unacked_results) - logger.debug(f"Acked task {task_id}, Unacked count: {unacked_count}") + log.debug(f"Acked task {task_id}, Unacked count: {unacked_count}") def check_ack_counts(self): - """ Log the number of currently Unacked tasks and the tasks Acked since + """Log the number of currently Unacked tasks and the tasks Acked since the last check """ now = time.time() if now - self.last_log_timestamp > self.log_period: unacked_count = len(self.unacked_results) - logger.info(f"Unacked count: {unacked_count}, Acked results since last check {self.acked_count}") + log.info( + "Unacked count: %s, Acked results since last check %s", + unacked_count, + self.acked_count, + ) self.acked_count = 0 self.last_log_timestamp = now def get_unacked_results_list(self): - """ Get a list of unacked results messages that can be used for resending + """Get a list of unacked results messages that can be used for resending Returns ------- @@ -75,17 +76,19 @@ def get_unacked_results_list(self): return list(self.unacked_results.values()) def persist(self): - """ Save unacked results to disk - """ - with open(self.data_path, 'wb') as fp: + """Save unacked results to disk""" + with open(self.data_path, "wb") as fp: pickle.dump(self.unacked_results, fp) def load(self): - """ Load unacked results from disk - """ + """Load unacked results from disk""" try: if os.path.exists(self.data_path): - with open(self.data_path, 'rb') as fp: + with open(self.data_path, "rb") as fp: self.unacked_results = pickle.load(fp) except pickle.UnpicklingError: - logger.warning(f"Cached results {self.data_path} appear to be corrupt. Proceeding without loading cached results") + log.warning( + "Cached results %s appear to be corrupt. 
" + "Proceeding without loading cached results", + self.data_path, + ) diff --git a/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py b/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py index ab1ab69c2..f7a93d62a 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py @@ -1,30 +1,31 @@ +import logging import os +import uuid + import zmq import zmq.auth from zmq.auth.thread import ThreadAuthenticator -import uuid -import logging -import time - -logger = logging.getLogger(__name__) - - -class TaskQueue(object): - """ Outgoing task queue from the executor to the Interchange - """ - - def __init__(self, - address: str, - port: int = 55001, - identity: str = str(uuid.uuid4()), - zmq_context=None, - set_hwm=False, - RCVTIMEO=None, - SNDTIMEO=None, - linger=None, - ironhouse: bool = False, - keys_dir: str = os.path.abspath('.curve'), - mode: str = 'client'): + +log = logging.getLogger(__name__) + + +class TaskQueue: + """Outgoing task queue from the executor to the Interchange""" + + def __init__( + self, + address: str, + port: int = 55001, + identity: str = str(uuid.uuid4()), + zmq_context=None, + set_hwm=False, + RCVTIMEO=None, + SNDTIMEO=None, + linger=None, + ironhouse: bool = False, + keys_dir: str = os.path.abspath(".curve"), + mode: str = "client", + ): """ Parameters ---------- @@ -37,7 +38,8 @@ def __init__(self, identity : str Applies only to clients, where the identity must match the endpoint uuid. - This will be utf-8 encoded on the wire. A random uuid4 string is set by default. + This will be utf-8 encoded on the wire. A random uuid4 string is set by + default. mode: string Either 'client' or 'server' @@ -59,21 +61,26 @@ def __init__(self, self.ironhouse = ironhouse self.keys_dir = keys_dir - assert self.mode in ['client', 'server'], "Only two modes are supported: client, server" + assert self.mode in [ + "client", + "server", + ], "Only two modes are supported: client, server" - if self.mode == 'server': + if self.mode == "server": print("Configuring server") self.zmq_socket = self.context.socket(zmq.ROUTER) self.zmq_socket.set(zmq.ROUTER_MANDATORY, 1) self.zmq_socket.set(zmq.ROUTER_HANDOVER, 1) print("Setting up auth-server") self.setup_server_auth() - elif self.mode == 'client': + elif self.mode == "client": self.zmq_socket = self.context.socket(zmq.DEALER) self.setup_client_auth() - self.zmq_socket.setsockopt(zmq.IDENTITY, identity.encode('utf-8')) + self.zmq_socket.setsockopt(zmq.IDENTITY, identity.encode("utf-8")) else: - raise ValueError("TaskQueue must be initialized with mode set to 'server' or 'client'") + raise ValueError( + "TaskQueue must be initialized with mode set to 'server' or 'client'" + ) if set_hwm: self.zmq_socket.set_hwm(0) @@ -85,41 +92,41 @@ def __init__(self, self.zmq_socket.setsockopt(zmq.LINGER, linger) # all zmq setsockopt calls must be done before bind/connect is called - if self.mode == 'server': - self.zmq_socket.bind("tcp://*:{}".format(port)) - elif self.mode == 'client': - self.zmq_socket.connect("tcp://{}:{}".format(address, port)) + if self.mode == "server": + self.zmq_socket.bind(f"tcp://*:{port}") + elif self.mode == "client": + self.zmq_socket.connect(f"tcp://{address}:{port}") self.poller = zmq.Poller() self.poller.register(self.zmq_socket) os.makedirs(self.keys_dir, exist_ok=True) - logger.debug(f"Initializing Taskqueue:{self.mode} on port:{self.port}") + log.debug(f"Initializing Taskqueue:{self.mode} on port:{self.port}") def zmq_context(self): return self.context def 
add_client_key(self, endpoint_id, client_key): - logger.info("Adding client key") + log.info("Adding client key") if self.ironhouse: # Use the ironhouse ZMQ pattern: http://hintjens.com/blog:49#toc6 - with open(os.path.join(self.keys_dir, f'{endpoint_id}.key'), 'w') as f: + with open(os.path.join(self.keys_dir, f"{endpoint_id}.key"), "w") as f: f.write(client_key) try: - self.auth.configure_curve(domain='*', location=self.keys_dir) + self.auth.configure_curve(domain="*", location=self.keys_dir) except Exception: - logger.exception("Failed to load keys from {self.keys_dir}") + log.exception("Failed to load keys from {self.keys_dir}") return def setup_server_auth(self): # Start an authenticator for this context. self.auth = ThreadAuthenticator(self.context) self.auth.start() - self.auth.allow('127.0.0.1') + self.auth.allow("127.0.0.1") # Tell the authenticator how to handle CURVE requests if not self.ironhouse: # Use the stonehouse ZMQ pattern: http://hintjens.com/blog:49#toc5 - self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) + self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY) server_secret_file = os.path.join(self.keys_dir, "server.key_secret") server_public, server_secret = zmq.auth.load_certificate(server_secret_file) @@ -166,7 +173,7 @@ def register_client(self, message): return self.zmq_socket.send_multipart([message]) def put(self, dest, message, max_timeout=1000): - """ This function needs to be fast at the same time aware of the possibility of + """This function needs to be fast at the same time aware of the possibility of ZMQ pipes overflowing. The timeout increases slowly if contention is detected on ZMQ pipes. @@ -183,7 +190,8 @@ def put(self, dest, message, max_timeout=1000): Python object to send max_timeout : int - Max timeout in milliseconds that we will wait for before raising an exception + Max timeout in milliseconds that we will wait for before raising an + exception Raises ------ @@ -192,7 +200,7 @@ def put(self, dest, message, max_timeout=1000): zmq.error.ZMQError: Host unreachable (if client disconnects?) """ - if self.mode == 'client': + if self.mode == "client": return self.zmq_socket.send_multipart([message]) else: return self.zmq_socket.send_multipart([dest, message]) diff --git a/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py b/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py index 00588ba62..0f24000ac 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py @@ -1,7 +1,7 @@ -from funcx_endpoint.executors import HighThroughputExecutor -from funcx_endpoint.strategies.simple import SimpleStrategy from parsl.utils import RepresentationMixin +from funcx_endpoint.executors import HighThroughputExecutor + _DEFAULT_EXECUTORS = [HighThroughputExecutor()] @@ -12,8 +12,8 @@ class Config(RepresentationMixin): ---------- executors : list of Executors - A list of executors which serve as the backend for function execution. As of 0.2.2, - this list should contain only one executor. + A list of executors which serve as the backend for function execution. + As of 0.2.2, this list should contain only one executor. 
Default: [HighThroughtputExecutor()] funcx_service_address: str @@ -21,12 +21,13 @@ class Config(RepresentationMixin): Default: 'https://api2.funcx.org/v2' heartbeat_period: int (seconds) - The interval at which heartbeat messages are sent from the endpoint to the funcx-web-service + The interval at which heartbeat messages are sent from the endpoint to the + funcx-web-service Default: 30s heartbeat_threshold: int (seconds) - Seconds since the last hearbeat message from the funcx-web-service after which the connection - is assumed to be disconnected. + Seconds since the last hearbeat message from the funcx-web-service after which + the connection is assumed to be disconnected. Default: 120s stdout : str diff --git a/funcx_endpoint/funcx_endpoint/executors/__init__.py b/funcx_endpoint/funcx_endpoint/executors/__init__.py index 6fbfebe5a..61d8fa594 100644 --- a/funcx_endpoint/funcx_endpoint/executors/__init__.py +++ b/funcx_endpoint/funcx_endpoint/executors/__init__.py @@ -1,3 +1,3 @@ from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor -__all__ = ['HighThroughputExecutor'] +__all__ = ["HighThroughputExecutor"] diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py index 591bca56f..8fe0216c6 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py @@ -1,18 +1,23 @@ - +import logging import math import random +log = logging.getLogger(__name__) -def naive_scheduler(task_qs, outstanding_task_count, max_workers, old_worker_map, to_die_list, logger): - """ Return two items (as one tuple) dict kill_list :: KILL [(worker_type, num_kill), ...] - dict create_list :: CREATE [(worker_type, num_create), ...] - In this scheduler model, there is minimum 1 instance of each nonempty task queue. +def naive_scheduler( + task_qs, outstanding_task_count, max_workers, old_worker_map, to_die_list +): + """ + Return two items (as one tuple) + dict kill_list :: KILL [(worker_type, num_kill), ...] + dict create_list :: CREATE [(worker_type, num_create), ...] + In this scheduler model, there is minimum 1 instance of each nonempty task queue. """ - logger.debug("Entering scheduler...") - logger.debug("old_worker_map: {}".format(old_worker_map)) + log.debug("Entering scheduler...") + log.debug(f"old_worker_map: {old_worker_map}") q_sizes = {} q_types = [] new_worker_map = {} @@ -26,23 +31,25 @@ def naive_scheduler(task_qs, outstanding_task_count, max_workers, old_worker_map q_sizes[q_type] = q_size if sum_q_size > 0: - logger.info("[SCHEDULER] Total number of tasks is {}".format(sum_q_size)) + log.info(f"[SCHEDULER] Total number of tasks is {sum_q_size}") # Set proportions of workers equal to the proportion of queue size. for q_type in q_sizes: ratio = q_sizes[q_type] / sum_q_size - new_worker_map[q_type] = min(int(math.floor(ratio * max_workers)), q_sizes[q_type]) + new_worker_map[q_type] = min( + int(math.floor(ratio * max_workers)), q_sizes[q_type] + ) # CLEANUP: Assign the difference here to any random worker. Should be small. 
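A worked example of the proportional allocation above, with made-up queue depths, showing both the floor() step and the leftover "difference" that the loop below hands out to randomly chosen queue types:

import math

# Hypothetical queue depths and worker budget.
q_sizes = {"RAW": 5, "my_container": 2}
max_workers = 4
sum_q_size = sum(q_sizes.values())  # 7

new_worker_map = {
    q_type: min(int(math.floor(q_size / sum_q_size * max_workers)), q_size)
    for q_type, q_size in q_sizes.items()
}
print(new_worker_map)  # {"RAW": 2, "my_container": 1}

# Only 3 of 4 workers are assigned, so difference == 1; the code that follows
# distributes that remainder across the queue types at random.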
- # logger.debug("Temporary new worker map: {}".format(new_worker_map)) + # log.debug("Temporary new worker map: {}".format(new_worker_map)) # Check the difference tmp_sum_q_size = sum(new_worker_map.values()) difference = 0 if sum_q_size > tmp_sum_q_size: difference = min(max_workers - tmp_sum_q_size, sum_q_size - tmp_sum_q_size) - logger.debug("[SCHEDULER] Offset difference: {}".format(difference)) - logger.debug("[SCHEDULER] Queue Types: {}".format(q_types)) + log.debug(f"[SCHEDULER] Offset difference: {difference}") + log.debug(f"[SCHEDULER] Queue Types: {q_types}") if len(q_types) > 0: while difference > 0: diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py index 89a54a3f5..8e756daa1 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py @@ -1,49 +1,45 @@ -"""HighThroughputExecutor builds on the Swift/T EMEWS architecture to use MPI for fast task distribution +"""HighThroughputExecutor builds on the Swift/T EMEWS architecture to use MPI for fast +task distribution There's a slow but sure deviation from Parsl's Executor interface here, that needs to be addressed. """ import concurrent.futures -from concurrent.futures import Future -import os -import time import logging -import threading -import queue +import os import pickle -import daemon -import uuid +import queue +import threading +import time from multiprocessing import Process -from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue - -from funcx_endpoint.executors.high_throughput.messages import HeartbeatReq, EPStatusReport, Heartbeat -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, Task, TaskCancel -from funcx_endpoint.executors.high_throughput.messages import BadCommand -from funcx.serialize import FuncXSerializer -from funcx_endpoint.strategies.simple import SimpleStrategy -fx_serializer = FuncXSerializer() - -# from parsl.executors.high_throughput import interchange -from funcx_endpoint.executors.high_throughput import interchange +import daemon +from parsl.dataflow.error import ConfigurationError from parsl.executors.errors import BadMessage, ScalingFailed -# from parsl.executors.base import ParslExecutor from parsl.executors.status_handling import StatusHandlingExecutor -from parsl.dataflow.error import ConfigurationError - -from parsl.utils import RepresentationMixin from parsl.providers import LocalProvider +from parsl.utils import RepresentationMixin +from funcx.serialize import FuncXSerializer +from funcx_endpoint.executors.high_throughput import interchange, zmq_pipes +from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue +from funcx_endpoint.executors.high_throughput.messages import ( + EPStatusReport, + Heartbeat, + HeartbeatReq, + Task, + TaskCancel, +) +from funcx_endpoint.logging_config import setup_logging +from funcx_endpoint.strategies.simple import SimpleStrategy -from funcx_endpoint.executors.high_throughput import zmq_pipes -from funcx import set_file_logger +fx_serializer = FuncXSerializer() -# TODO: YADU There's a bug here which causes some of the log messages to write out to stderr +# TODO: YADU There's a bug here which causes some of the log messages to write out to +# stderr # "logging" python3 self.stream.flush() OSError: [Errno 9] Bad file descriptor -logger = logging.getLogger(__name__) -# if not logger.hasHandlers(): -# 
logger = set_file_logger("executor.log", name=__name__) +log = logging.getLogger(__name__) BUFFER_THRESHOLD = 1024 * 1024 @@ -55,10 +51,12 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): The HighThroughputExecutor system has the following components: 1. The HighThroughputExecutor instance which is run as part of the Parsl script. - 2. The Interchange which is acts as a load-balancing proxy between workers and Parsl - 3. The multiprocessing based worker pool which coordinates task execution over several - cores on a node. - 4. ZeroMQ pipes connect the HighThroughputExecutor, Interchange and the process_worker_pool + 2. The Interchange which is acts as a load-balancing proxy between workers and + Parsl + 3. The multiprocessing based worker pool which coordinates task execution over + several cores on a node. + 4. ZeroMQ pipes connect the HighThroughputExecutor, Interchange and the + process_worker_pool Here is a diagram @@ -83,7 +81,8 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): ---------- provider : :class:`~parsl.providers.provider_base.ExecutionProvider` - Provider to access computation resources. Can be one of :class:`~parsl.providers.aws.aws.EC2Provider`, + Provider to access computation resources. Can be one of + :class:`~parsl.providers.aws.aws.EC2Provider`, :class:`~parsl.providers.cobalt.cobalt.Cobalt`, :class:`~parsl.providers.condor.condor.Condor`, :class:`~parsl.providers.googlecloud.googlecloud.GoogleCloud`, @@ -98,21 +97,33 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): Label for this executor instance. launch_cmd : str - Command line string to launch the process_worker_pool from the provider. The command line string - will be formatted with appropriate values for the following values (debug, task_url, result_url, - cores_per_worker, nodes_per_block, heartbeat_period ,heartbeat_threshold, logdir). For eg: - launch_cmd="process_worker_pool.py {debug} -c {cores_per_worker} --task_url={task_url} --result_url={result_url}" + Command line string to launch the process_worker_pool from the provider. The + command line string will be formatted with appropriate values for the following + values: ( + debug, + task_url, + result_url, + cores_per_worker, + nodes_per_block, + heartbeat_period, + heartbeat_threshold, + logdir, + ). + For example: + launch_cmd="process_worker_pool.py {debug} -c {cores_per_worker} \ + --task_url={task_url} --result_url={result_url}" address : string - An address of the host on which the executor runs, which is reachable from the network in which - workers will be running. This can be either a hostname as returned by `hostname` or an - IP address. Most login nodes on clusters have several network interfaces available, only - some of which can be reached from the compute nodes. Some trial and error might be - necessary to indentify what addresses are reachable from compute nodes. + An address of the host on which the executor runs, which is reachable from the + network in which workers will be running. This can be either a hostname as + returned by `hostname` or an IP address. Most login nodes on clusters have + several network interfaces available, only some of which can be reached + from the compute nodes. Some trial and error might be necessary to + indentify what addresses are reachable from compute nodes. worker_ports : (int, int) - Specify the ports to be used by workers to connect to Parsl. If this option is specified, - worker_port_range will not be honored. 
+ Specify the ports to be used by workers to connect to Parsl. If this + option is specified, worker_port_range will not be honored. worker_port_range : (int, int) Worker ports will be chosen between the two integers provided. @@ -140,19 +151,22 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): Caps the number of workers launched by the manager. Default: infinity suppress_failure : Bool - If set, the interchange will suppress failures rather than terminate early. Default: True + If set, the interchange will suppress failures rather than terminate early. + Default: False heartbeat_threshold : int Seconds since the last message from the counterpart in the communication pair: - (interchange, manager) after which the counterpart is assumed to be un-available. Default:120s + (interchange, manager) after which the counterpart is assumed to be unavailable. + Default:120s heartbeat_period : int - Number of seconds after which a heartbeat message indicating liveness is sent to the endpoint + Number of seconds after which a heartbeat message indicating liveness is sent to + the endpoint counterpart (interchange, manager). Default:30s poll_period : int - Timeout period to be used by the executor components in milliseconds. Increasing poll_periods - trades performance for cpu efficiency. Default: 10ms + Timeout period to be used by the executor components in milliseconds. + Increasing poll_periods trades performance for cpu efficiency. Default: 10ms container_image : str Path or identfier to the container image to be used by the workers @@ -163,7 +177,8 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): 'soft' -> managers can replace unused worker's containers based on demand worker_mode : str - Select the mode of operation from no_container, singularity_reuse, singularity_single_use + Select the mode of operation from no_container, singularity_reuse, + singularity_single_use Default: singularity_reuse container_cmd_options: str @@ -177,68 +192,66 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): Specify the scaling strategy to use for this executor. launch_cmd: str - Specify the launch command as using f-string format that will be used to specify command to - launch managers. Default: None + Specify the launch command as using f-string format that will be used to specify + command to launch managers. Default: None prefetch_capacity: int - Number of tasks that can be fetched by managers in excess of available workers is a - prefetching optimization. This option can cause poor load-balancing for long running functions. + Number of tasks that can be fetched by managers in excess of available + workers is a prefetching optimization. This option can cause poor + load-balancing for long running functions. Default: 10 provider: Provider object - Provider determines how managers can be provisioned, say LocalProvider offers forked processes, - and SlurmProvider interfaces to request resources from the Slurm batch scheduler. + Provider determines how managers can be provisioned, say LocalProvider + offers forked processes, and SlurmProvider interfaces to request + resources from the Slurm batch scheduler. Default: LocalProvider funcx_service_address: str - Override funcx_service_address used by the FuncXClient. If no address is specified, - the FuncXClient's default funcx_service_address is used. + Override funcx_service_address used by the FuncXClient. 
If no address + is specified, the FuncXClient's default funcx_service_address is used. Default: None """ - def __init__(self, - label='HighThroughputExecutor', - - - # NEW - strategy=SimpleStrategy(), - max_workers_per_node=float('inf'), - mem_per_worker=None, - launch_cmd=None, - - # Container specific - worker_mode='no_container', - scheduler_mode='hard', - container_type=None, - container_cmd_options='', - cold_routing_interval=10.0, - - # Tuning info - prefetch_capacity=10, - - provider=LocalProvider(), - address="127.0.0.1", - worker_ports=None, - worker_port_range=(54000, 55000), - interchange_port_range=(55000, 56000), - storage_access=None, - working_dir=None, - worker_debug=False, - cores_per_worker=1.0, - heartbeat_threshold=120, - heartbeat_period=30, - poll_period=10, - container_image=None, - suppress_failure=True, - run_dir=None, - endpoint_id=None, - managed=True, - interchange_local=True, - passthrough=True, - funcx_service_address=None, - task_status_queue=None): - - logger.debug("Initializing HighThroughputExecutor") + def __init__( + self, + label="HighThroughputExecutor", + # NEW + strategy=SimpleStrategy(), + max_workers_per_node=float("inf"), + mem_per_worker=None, + launch_cmd=None, + # Container specific + worker_mode="no_container", + scheduler_mode="hard", + container_type=None, + container_cmd_options="", + cold_routing_interval=10.0, + # Tuning info + prefetch_capacity=10, + provider=LocalProvider(), + address="127.0.0.1", + worker_ports=None, + worker_port_range=(54000, 55000), + interchange_port_range=(55000, 56000), + storage_access=None, + working_dir=None, + worker_debug=False, + cores_per_worker=1.0, + heartbeat_threshold=120, + heartbeat_period=30, + poll_period=10, + container_image=None, + suppress_failure=True, + run_dir=None, + endpoint_id=None, + managed=True, + interchange_local=True, + passthrough=True, + funcx_service_address=None, + task_status_queue=None, + ): + log.debug("Initializing HighThroughputExecutor") StatusHandlingExecutor.__init__(self, provider) self.label = label @@ -262,7 +275,9 @@ def __init__(self, self.storage_access = storage_access if storage_access is not None else [] if len(self.storage_access) > 1: - raise ConfigurationError('Multiple storage access schemes are not supported') + raise ConfigurationError( + "Multiple storage access schemes are not supported" + ) self.working_dir = working_dir self.managed = managed self.blocks = [] @@ -290,72 +305,88 @@ def __init__(self, self.last_response_time = time.time() if not launch_cmd: - self.launch_cmd = ("process_worker_pool.py {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--mode={worker_mode} " - "--container_image={container_image} ") - - self.ix_launch_cmd = ("funcx-interchange {debug} -c={client_address} " - "--client_ports={client_ports} " - "--worker_port_range={worker_port_range} " - "--logdir={logdir} " - "{suppress_failure} " - ) + self.launch_cmd = ( + "process_worker_pool.py {debug} {max_workers} " + "-c {cores_per_worker} " + "--poll {poll_period} " + "--task_url={task_url} " + "--result_url={result_url} " + "--logdir={logdir} " + "--hb_period={heartbeat_period} " + "--hb_threshold={heartbeat_threshold} " + "--mode={worker_mode} " + "--container_image={container_image} " + ) + + self.ix_launch_cmd = ( + "funcx-interchange {debug} -c={client_address} " + 
"--client_ports={client_ports} " + "--worker_port_range={worker_port_range} " + "--logdir={logdir} " + "{suppress_failure} " + ) def initialize_scaling(self): - """ Compose the launch command and call the scale_out + """Compose the launch command and call the scale_out This should be implemented in the child classes to take care of executor specific oddities. """ debug_opts = "--debug" if self.worker_debug else "" - max_workers = "" if self.max_workers == float('inf') else "--max_workers={}".format(self.max_workers) - - l_cmd = self.launch_cmd.format(debug=debug_opts, - task_url=self.worker_task_url, - result_url=self.worker_result_url, - cores_per_worker=self.cores_per_worker, - max_workers=max_workers, - nodes_per_block=self.provider.nodes_per_block, - heartbeat_period=self.heartbeat_period, - heartbeat_threshold=self.heartbeat_threshold, - poll_period=self.poll_period, - logdir=os.path.join(self.run_dir, self.label), - worker_mode=self.worker_mode, - container_image=self.container_image) + max_workers = ( + "" + if self.max_workers == float("inf") + else f"--max_workers={self.max_workers}" + ) + + l_cmd = self.launch_cmd.format( + debug=debug_opts, + task_url=self.worker_task_url, + result_url=self.worker_result_url, + cores_per_worker=self.cores_per_worker, + max_workers=max_workers, + nodes_per_block=self.provider.nodes_per_block, + heartbeat_period=self.heartbeat_period, + heartbeat_threshold=self.heartbeat_threshold, + poll_period=self.poll_period, + logdir=os.path.join(self.run_dir, self.label), + worker_mode=self.worker_mode, + container_image=self.container_image, + ) self.launch_cmd = l_cmd - logger.debug("Launch command: {}".format(self.launch_cmd)) + log.debug(f"Launch command: {self.launch_cmd}") self._scaling_enabled = self.provider.scaling_enabled - logger.debug("Starting HighThroughputExecutor with provider:\n%s", self.provider) - if hasattr(self.provider, 'init_blocks'): + log.debug("Starting HighThroughputExecutor with provider:\n%s", self.provider) + if hasattr(self.provider, "init_blocks"): try: self.scale_out(blocks=self.provider.init_blocks) except Exception as e: - logger.error("Scaling out failed: {}".format(e)) + log.error(f"Scaling out failed: {e}") raise e def start(self, results_passthrough=None): - """Create the Interchange process and connect to it. 
- """ - self.outgoing_q = zmq_pipes.TasksOutgoing("0.0.0.0", self.interchange_port_range) - self.incoming_q = zmq_pipes.ResultsIncoming("0.0.0.0", self.interchange_port_range) - self.command_client = zmq_pipes.CommandClient("0.0.0.0", self.interchange_port_range) + """Create the Interchange process and connect to it.""" + self.outgoing_q = zmq_pipes.TasksOutgoing( + "0.0.0.0", self.interchange_port_range + ) + self.incoming_q = zmq_pipes.ResultsIncoming( + "0.0.0.0", self.interchange_port_range + ) + self.command_client = zmq_pipes.CommandClient( + "0.0.0.0", self.interchange_port_range + ) self.is_alive = True if self.passthrough is True: if results_passthrough is None: - raise Exception("Executors configured in passthrough mode, must be started with" - "a multiprocessing queue for results_passthrough") + raise Exception( + "Executors configured in passthrough mode, must be started with" + "a multiprocessing queue for results_passthrough" + ) self.results_passthrough = results_passthrough - logger.debug(f"Executor:{self.label} starting in results_passthrough mode") + log.debug(f"Executor:{self.label} starting in results_passthrough mode") self._executor_bad_state = threading.Event() self._executor_exception = None @@ -363,99 +394,120 @@ def start(self, results_passthrough=None): self._start_queue_management_thread() if self.interchange_local is True: - logger.info("Attempting local interchange start") + log.info("Attempting local interchange start") self._start_local_interchange_process() - logger.info(f"Started local interchange with ports: {self.worker_task_port}. {self.worker_result_port}") + log.info( + "Started local interchange with ports: %s. %s", + self.worker_task_port, + self.worker_result_port, + ) - logger.debug("Created management thread: {}".format(self._queue_management_thread)) + log.debug(f"Created management thread: {self._queue_management_thread}") if self.provider: # self.initialize_scaling() pass else: self._scaling_enabled = False - logger.debug("Starting HighThroughputExecutor with no provider") + log.debug("Starting HighThroughputExecutor with no provider") return (self.outgoing_q.port, self.incoming_q.port, self.command_client.port) def _start_local_interchange_process(self): - """ Starts the interchange process locally + """Starts the interchange process locally Starts the interchange process locally and uses an internal command queue to get the worker task and result ports that the interchange has bound to. 
""" comm_q = mpQueue(maxsize=10) print(f"Starting local interchange with endpoint id: {self.endpoint_id}") - self.queue_proc = Process(target=interchange.starter, - args=(comm_q,), - kwargs={"client_address": self.address, - "client_ports": (self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port), - "provider": self.provider, - "strategy": self.strategy, - "poll_period": self.poll_period, - "heartbeat_period": self.heartbeat_period, - "heartbeat_threshold": self.heartbeat_threshold, - "working_dir": self.working_dir, - "worker_debug": self.worker_debug, - "max_workers_per_node": self.max_workers_per_node, - "mem_per_worker": self.mem_per_worker, - "cores_per_worker": self.cores_per_worker, - "prefetch_capacity": self.prefetch_capacity, - # "log_max_bytes": self.log_max_bytes, - # "log_backup_count": self.log_backup_count, - "scheduler_mode": self.scheduler_mode, - "worker_mode": self.worker_mode, - "container_type": self.container_type, - "container_cmd_options": self.container_cmd_options, - "cold_routing_interval": self.cold_routing_interval, - "funcx_service_address": self.funcx_service_address, - "interchange_address": self.address, - "worker_ports": self.worker_ports, - "worker_port_range": self.worker_port_range, - "logdir": os.path.join(self.run_dir, self.label), - "suppress_failure": self.suppress_failure, - "endpoint_id": self.endpoint_id, - "logging_level": logging.DEBUG if self.worker_debug else logging.INFO - }, + self.queue_proc = Process( + target=interchange.starter, + args=(comm_q,), + kwargs={ + "client_address": self.address, + "client_ports": ( + self.outgoing_q.port, + self.incoming_q.port, + self.command_client.port, + ), + "provider": self.provider, + "strategy": self.strategy, + "poll_period": self.poll_period, + "heartbeat_period": self.heartbeat_period, + "heartbeat_threshold": self.heartbeat_threshold, + "working_dir": self.working_dir, + "worker_debug": self.worker_debug, + "max_workers_per_node": self.max_workers_per_node, + "mem_per_worker": self.mem_per_worker, + "cores_per_worker": self.cores_per_worker, + "prefetch_capacity": self.prefetch_capacity, + "scheduler_mode": self.scheduler_mode, + "worker_mode": self.worker_mode, + "container_type": self.container_type, + "container_cmd_options": self.container_cmd_options, + "cold_routing_interval": self.cold_routing_interval, + "funcx_service_address": self.funcx_service_address, + "interchange_address": self.address, + "worker_ports": self.worker_ports, + "worker_port_range": self.worker_port_range, + "logdir": os.path.join(self.run_dir, self.label), + "suppress_failure": self.suppress_failure, + "endpoint_id": self.endpoint_id, + }, ) self.queue_proc.start() try: - (self.worker_task_port, self.worker_result_port) = comm_q.get(block=True, timeout=120) + (self.worker_task_port, self.worker_result_port) = comm_q.get( + block=True, timeout=120 + ) except queue.Empty: - logger.error("Interchange has not completed initialization in 120s. Aborting") + log.error("Interchange has not completed initialization in 120s. 
Aborting") raise Exception("Interchange failed to start") - self.worker_task_url = "tcp://{}:{}".format(self.address, self.worker_task_port) - self.worker_result_url = "tcp://{}:{}".format(self.address, self.worker_result_port) + self.worker_task_url = f"tcp://{self.address}:{self.worker_task_port}" + self.worker_result_url = "tcp://{}:{}".format( + self.address, self.worker_result_port + ) def _start_remote_interchange_process(self): - """ Starts the interchange process locally + """Starts the interchange process locally - Starts the interchange process remotely via the provider.channel and uses the command channel - to request worker urls that the interchange has selected. + Starts the interchange process remotely via the provider.channel and + uses the command channel to request worker urls that the interchange + has selected. """ - logger.debug("Attempting Interchange deployment via channel: {}".format(self.provider.channel)) + log.debug( + "Attempting Interchange deployment via channel: {}".format( + self.provider.channel + ) + ) debug_opts = "--debug" if self.worker_debug else "" suppress_failure = "--suppress_failure" if self.suppress_failure else "" - logger.debug("Before : \n{}\n".format(self.ix_launch_cmd)) - launch_command = self.ix_launch_cmd.format(debug=debug_opts, - client_address=self.address, - client_ports="{},{},{}".format(self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port), - worker_port_range="{},{}".format(self.worker_port_range[0], - self.worker_port_range[1]), - logdir=os.path.join(self.provider.channel.script_dir, 'runinfo', - os.path.basename(self.run_dir), self.label), - suppress_failure=suppress_failure) + log.debug(f"Before : \n{self.ix_launch_cmd}\n") + launch_command = self.ix_launch_cmd.format( + debug=debug_opts, + client_address=self.address, + client_ports="{},{},{}".format( + self.outgoing_q.port, self.incoming_q.port, self.command_client.port + ), + worker_port_range="{},{}".format( + self.worker_port_range[0], self.worker_port_range[1] + ), + logdir=os.path.join( + self.provider.channel.script_dir, + "runinfo", + os.path.basename(self.run_dir), + self.label, + ), + suppress_failure=suppress_failure, + ) if self.provider.worker_init: - launch_command = self.provider.worker_init + '\n' + launch_command + launch_command = self.provider.worker_init + "\n" + launch_command - logger.debug("Launch command : \n{}\n".format(launch_command)) + log.debug(f"Launch command : \n{launch_command}\n") return def _queue_management_worker(self): @@ -491,7 +543,7 @@ def _queue_management_worker(self): The `None` message is a die request. """ - logger.debug("[MTHREAD] queue management worker starting") + log.debug("[MTHREAD] queue management worker starting") while not self._executor_bad_state.is_set(): try: @@ -499,105 +551,145 @@ def _queue_management_worker(self): self.last_response_time = time.time() except queue.Empty: - logger.debug("[MTHREAD] queue empty") + log.debug("[MTHREAD] queue empty") # Timed out. 
pass - except IOError as e: - logger.exception("[MTHREAD] Caught broken queue with exception code {}: {}".format(e.errno, e)) + except OSError as e: + log.exception( + "[MTHREAD] Caught broken queue with exception code {}: {}".format( + e.errno, e + ) + ) return except Exception as e: - logger.exception("[MTHREAD] Caught unknown exception: {}".format(e)) + log.exception(f"[MTHREAD] Caught unknown exception: {e}") return else: if msgs is None: - logger.debug("[MTHREAD] Got None, exiting") + log.debug("[MTHREAD] Got None, exiting") return elif isinstance(msgs, EPStatusReport): - logger.debug("[MTHREAD] Received EPStatusReport {}".format(msgs)) + log.debug(f"[MTHREAD] Received EPStatusReport {msgs}") if self.passthrough: - self.results_passthrough.put({ - "task_id": None, - "message": pickle.dumps(msgs) - }) + self.results_passthrough.put( + {"task_id": None, "message": pickle.dumps(msgs)} + ) else: - logger.debug("[MTHREAD] Unpacking results") + log.debug("[MTHREAD] Unpacking results") for serialized_msg in msgs: try: msg = pickle.loads(serialized_msg) - tid = msg['task_id'] + tid = msg["task_id"] except pickle.UnpicklingError: raise BadMessage("Message received could not be unpickled") except Exception: - raise BadMessage("Message received does not contain 'task_id' field") - - if tid == -2 and 'info' in msg: - logger.warning("[MTHREAD[ Received info response : {}".format(msg['info'])) - - if tid == -1 and 'exception' in msg: - # TODO: This could be handled better we are essentially shutting down the - # client with little indication to the user. - logger.warning("[MTHREAD] Executor shutting down due to version mismatch in interchange") - self._executor_exception = fx_serializer.deserialize(msg['exception']) - logger.exception("[MTHREAD] Exception: {}".format(self._executor_exception)) + raise BadMessage( + "Message received does not contain 'task_id' field" + ) + + if tid == -2 and "info" in msg: + log.warning( + "[MTHREAD[ Received info response : {}".format( + msg["info"] + ) + ) + + if tid == -1 and "exception" in msg: + # TODO: This could be handled better we are + # essentially shutting down the client with little + # indication to the user. + log.warning( + "[MTHREAD] Executor shutting down due to fatal " + "exception from interchange" + ) + self._executor_exception = fx_serializer.deserialize( + msg["exception"] + ) + log.exception( + "[MTHREAD] Exception: {}".format( + self._executor_exception + ) + ) # Set bad state to prevent new tasks from being submitted self._executor_bad_state.set() - # We set all current tasks to this exception to make sure that - # this is raised in the main context. + # We set all current tasks to this exception to make sure + # that this is raised in the main context. for task_id in self.tasks: try: - self.tasks[task_id].set_exception(self._executor_exception) + self.tasks[task_id].set_exception( + self._executor_exception + ) except concurrent.futures.InvalidStateError: - # Task was already cancelled, the exception can be ignored - logger.debug(f"Task:{task_id} result couldn't be set. Already in terminal state") + # Task was already cancelled, the exception can be + # ignored + log.debug( + f"Task:{task_id} result couldn't be set. 
" + "Already in terminal state" + ) break if self.passthrough is True: - logger.debug(f"[MTHREAD] Pushing results for task:{tid}") - # we are only interested in actual task ids here, not identifiers - # for other message types + log.debug(f"[MTHREAD] Pushing results for task:{tid}") + # we are only interested in actual task ids here, not + # identifiers for other message types sent_task_id = tid if isinstance(tid, str) else None - x = self.results_passthrough.put({ - "task_id": sent_task_id, - "message": serialized_msg - }) - logger.debug(f"[MTHREAD] task:{tid} ret value: {x}") - logger.debug(f"[MTHREAD] task:{tid} items in queue: {self.results_passthrough.qsize()}") + x = self.results_passthrough.put( + {"task_id": sent_task_id, "message": serialized_msg} + ) + log.debug(f"[MTHREAD] task:{tid} ret value: {x}") + log.debug( + "[MTHREAD] task:%s items in queue: %s", + tid, + self.results_passthrough.qsize(), + ) continue try: task_fut = self.tasks.pop(tid) except KeyError: - # This is triggered when the result of a cancelled task is returned + # This is triggered when the result of a cancelled task is + # returned # We should log, and proceed. - logger.warning(f"[MTHREAD] Task:{tid} not found in tasks table\n" - "Task likely was cancelled and removed.") + log.warning( + f"[MTHREAD] Task:{tid} not found in tasks table\n" + "Task likely was cancelled and removed." + ) continue - if 'result' in msg: - result = fx_serializer.deserialize(msg['result']) + if "result" in msg: + result = fx_serializer.deserialize(msg["result"]) try: task_fut.set_result(result) except concurrent.futures.InvalidStateError: - logger.debug(f"Task:{tid} result couldn't be set. Already in terminal state") - elif 'exception' in msg: - exception = fx_serializer.deserialize(msg['exception']) + log.debug( + f"Task:{tid} result couldn't be set. " + "Already in terminal state" + ) + elif "exception" in msg: + exception = fx_serializer.deserialize(msg["exception"]) try: task_fut.set_result(exception) except concurrent.futures.InvalidStateError: - logger.debug(f"Task:{tid} result couldn't be set. Already in terminal state") + log.debug( + f"Task:{tid} result couldn't be set. " + "Already in terminal state" + ) else: - raise BadMessage("[MTHREAD] Message received is neither result or exception") + raise BadMessage( + "[MTHREAD] Message received is neither result or " + "exception" + ) if not self.is_alive: break - logger.info("[MTHREAD] queue management worker finished") + log.info("[MTHREAD] queue management worker finished") # When the executor gets lost, the weakref callback will wake up # the queue management thread. @@ -612,14 +704,16 @@ def _start_queue_management_thread(self): Could be used later as a restart if the management thread dies. """ if self._queue_management_thread is None: - logger.debug("Starting queue management thread") - self._queue_management_thread = threading.Thread(target=self._queue_management_worker) + log.debug("Starting queue management thread") + self._queue_management_thread = threading.Thread( + target=self._queue_management_worker + ) self._queue_management_thread.daemon = True self._queue_management_thread.start() - logger.debug("Started queue management thread") + log.debug("Started queue management thread") else: - logger.debug("Management thread already exists, returning") + log.debug("Management thread already exists, returning") def hold_worker(self, worker_id): """Puts a worker on hold, preventing scheduling of additional tasks to it. 
@@ -633,35 +727,36 @@ def hold_worker(self, worker_id): worker_id : str Worker id to be put on hold """ - c = self.command_client.run("HOLD_WORKER;{}".format(worker_id)) - logger.debug("Sent hold request to worker: {}".format(worker_id)) + c = self.command_client.run(f"HOLD_WORKER;{worker_id}") + log.debug(f"Sent hold request to worker: {worker_id}") return c def send_heartbeat(self): - logger.warning("Sending heartbeat to interchange") + log.warning("Sending heartbeat to interchange") msg = Heartbeat(endpoint_id="") self.outgoing_q.put(msg.pack()) def wait_for_endpoint(self): heartbeat = self.command_client.run(HeartbeatReq()) - logger.debug("Attempting heartbeat to interchange") + log.debug("Attempting heartbeat to interchange") return heartbeat @property def outstanding(self): outstanding_c = self.command_client.run("OUTSTANDING_C") - logger.debug("Got outstanding count: {}".format(outstanding_c)) + log.debug(f"Got outstanding count: {outstanding_c}") return outstanding_c @property def connected_workers(self): workers = self.command_client.run("MANAGERS") - logger.debug("Got managers: {}".format(workers)) + log.debug(f"Got managers: {workers}") return workers - def submit(self, func, *args, container_id: str = 'RAW', task_id: str = None, **kwargs): - """ Submits the function and it's params for execution. - """ + def submit( + self, func, *args, container_id: str = "RAW", task_id: str = None, **kwargs + ): + """Submits the function and it's params for execution.""" self._task_counter += 1 if task_id is None: task_id = self._task_counter @@ -669,11 +764,10 @@ def submit(self, func, *args, container_id: str = 'RAW', task_id: str = None, ** fn_code = fx_serializer.serialize(func) ser_code = fx_serializer.pack_buffers([fn_code]) - ser_params = fx_serializer.pack_buffers([fx_serializer.serialize(args), - fx_serializer.serialize(kwargs)]) - payload = Task(task_id, - container_id, - ser_code + ser_params) + ser_params = fx_serializer.pack_buffers( + [fx_serializer.serialize(args), fx_serializer.serialize(kwargs)] + ) + payload = Task(task_id, container_id, ser_code + ser_params) self.submit_raw(payload.pack()) self.tasks[task_id] = HTEXFuture(self) @@ -686,16 +780,18 @@ def submit_raw(self, packed_task): The outgoing_q is an external process listens on this queue for new work. This method behaves like a - submit call as described here `Python docs: `_ + submit call as described in the `Python docs \ + `_ Parameters ---------- - Packed Task (messages.Task) - A packed Task object which contains task_id, container_id, and serialized fn, args, kwargs packages. + Packed Task (messages.Task) - A packed Task object which contains task_id, + container_id, and serialized fn, args, kwargs packages. Returns: Submit status """ - logger.debug(f"Submitting raw task : {packed_task}") + log.debug(f"Submitting raw task : {packed_task}") if self._executor_bad_state.is_set(): raise self._executor_exception @@ -714,17 +810,18 @@ def _get_block_and_job_ids(self): @property def connection_info(self): - """ All connection info necessary for the endpoint to connect back + """All connection info necessary for the endpoint to connect back Returns: Dict with connection info """ - return {'address': self.address, - # A memorial to the ungodly amount of time and effort spent, - # troubleshooting the order of these ports. 
- 'client_ports': '{},{},{}'.format(self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port) + return { + "address": self.address, + # A memorial to the ungodly amount of time and effort spent, + # troubleshooting the order of these ports. + "client_ports": "{},{},{}".format( + self.outgoing_q.port, self.incoming_q.port, self.command_client.port + ), } @property @@ -741,13 +838,17 @@ def scale_out(self, blocks=1): for i in range(blocks): if self.provider: block = self.provider.submit(self.launch_cmd, 1, 1) - logger.debug("Launched block {}:{}".format(i, block)) + log.debug(f"Launched block {i}:{block}") if not block: - raise(ScalingFailed(self.provider.label, - "Attempts to provision nodes via provider has failed")) + raise ( + ScalingFailed( + self.provider.label, + "Attempts to provision nodes via provider has failed", + ) + ) self.blocks.extend([block]) else: - logger.error("No execution provider available") + log.error("No execution provider available") r = None return r @@ -778,7 +879,7 @@ def status(self): return status - def shutdown(self, hub=True, targets='all', block=False): + def shutdown(self, hub=True, targets="all", block=False): """Shutdown the executor, including all workers and controllers. This is not implemented. @@ -792,16 +893,16 @@ def shutdown(self, hub=True, targets='all', block=False): NotImplementedError """ - logger.info("Attempting HighThroughputExecutor shutdown") + log.info("Attempting HighThroughputExecutor shutdown") # self.outgoing_q.close() # self.incoming_q.close() if self.queue_proc: self.queue_proc.terminate() - logger.info("Finished HighThroughputExecutor shutdown attempt") + log.info("Finished HighThroughputExecutor shutdown attempt") return True def _cancel(self, future): - """ Attempt cancelling a task tracked by the future by requesting + """Attempt cancelling a task tracked by the future by requesting cancellation from the interchange. Task cancellation is attempted only if the future is cancellable i.e not already in a terminal state. This relies on the executor not setting the task to a running @@ -817,30 +918,32 @@ def _cancel(self, future): """ ret_value = future._cancel() - logger.debug("Sending cancel of task_id:{future.task_id} to interchange") + log.debug("Sending cancel of task_id:{future.task_id} to interchange") if ret_value is True: self.command_client.run(TaskCancel(future.task_id)) - logger.debug("Sent TaskCancel to interchange") + log.debug("Sent TaskCancel to interchange") return ret_value -CANCELLED = 'CANCELLED' -CANCELLED_AND_NOTIFIED = 'CANCELLED_AND_NOTIFIED' -FINISHED = 'FINISHED' - +CANCELLED = "CANCELLED" +CANCELLED_AND_NOTIFIED = "CANCELLED_AND_NOTIFIED" +FINISHED = "FINISHED" -class HTEXFuture(Future): +class HTEXFuture(concurrent.futures.Future): def __init__(self, executor): super().__init__() self.executor = executor def cancel(self): - raise NotImplementedError(f"{self.__class__} does not implement cancel() try using best_effort_cancel()") + raise NotImplementedError( + f"{self.__class__} does not implement cancel() " + "try using best_effort_cancel()" + ) def _cancel(self): - """ Should be invoked only by the executor + """Should be invoked only by the executor Returns ------- Bool @@ -848,12 +951,16 @@ def _cancel(self): return super().cancel() def best_effort_cancel(self): - """ Attempt to cancel the function. If the function has finished running, the task cannot be cancelled - and the method will return False. 
If the function is yet to start or is running, cancellation will be + """Attempt to cancel the function. + + If the function has finished running, the task cannot be cancelled + and the method will return False. + If the function is yet to start or is running, cancellation will be attempted without guarantees, and the method will return True. - Please note that a return value of True does not guarantee that your function will not - execute at all, but it does guarantee that the future will be in a cancelled state. + Please note that a return value of True does not guarantee that your + function will not execute at all, but it does guarantee that the + future will be in a cancelled state. Returns ------- @@ -862,17 +969,17 @@ def best_effort_cancel(self): return self.executor._cancel(self) -def executor_starter(htex, logdir, endpoint_id, logging_level=logging.DEBUG): - - stdout = open(os.path.join(logdir, "executor.{}.stdout".format(endpoint_id)), 'w') - stderr = open(os.path.join(logdir, "executor.{}.stderr".format(endpoint_id)), 'w') +def executor_starter(htex, logdir, endpoint_id): + stdout = open(os.path.join(logdir, f"executor.{endpoint_id}.stdout"), "w") + stderr = open(os.path.join(logdir, f"executor.{endpoint_id}.stderr"), "w") logdir = os.path.abspath(logdir) with daemon.DaemonContext(stdout=stdout, stderr=stderr): - global logger print("cwd: ", os.getcwd()) - logger = set_file_logger(os.path.join(logdir, "executor.{}.log".format(endpoint_id)), - level=logging_level) + setup_logging( + logfile=os.path.join(logdir, f"executor.{endpoint_id}.log"), + console_enabled=False, + ) htex.start() stdout.close() diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py index e974a5b40..1bd15800c 100755 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py @@ -1,51 +1,47 @@ #!/usr/bin/env python3 import argparse +import json import logging +import math +import multiprocessing import os -import sys +import pickle import platform +import queue +import subprocess +import sys import threading -import pickle import time -import queue import uuid -import zmq -import math -import json -import multiprocessing + import psutil -import subprocess +import zmq +from parsl.app.errors import RemoteExceptionWrapper +from parsl.version import VERSION as PARSL_VERSION +from funcx.serialize import FuncXSerializer from funcx_endpoint.executors.high_throughput.container_sched import naive_scheduler +from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue from funcx_endpoint.executors.high_throughput.messages import ( - EPStatusReport, - Heartbeat, ManagerStatusReport, - TaskStatusCode + Message, + Task, + TaskStatusCode, ) from funcx_endpoint.executors.high_throughput.worker_map import WorkerMap -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task -from funcx.serialize import FuncXSerializer -from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue - -from parsl.version import VERSION as PARSL_VERSION -from parsl.app.errors import RemoteExceptionWrapper - -from funcx import set_file_logger - +from funcx_endpoint.logging_config import setup_logging RESULT_TAG = 10 TASK_REQUEST_TAG = 11 HEARTBEAT_CODE = (2 ** 32) - 1 -logger = None +log = logging.getLogger(__name__) class TaskCancelled(Exception): - """ Task is cancelled by 
user request. - """ + """Task is cancelled by user request.""" def __init__(self, worker_id, manager_id): self.worker_id = worker_id @@ -53,11 +49,14 @@ def __init__(self, worker_id, manager_id): self.tstamp = time.time() def __str__(self): - return f"Task cancelled based on user request on manager:{self.manager_id}, worker:{self.worker_id}" + return ( + "Task cancelled based on user request on manager: " + f"{self.manager_id}, worker: {self.worker_id}" + ) -class Manager(object): - """ Manager manages task execution by the workers +class Manager: + """Manager manages task execution by the workers | 0mq | Manager | Worker Processes | | | @@ -73,26 +72,28 @@ class Manager(object): """ - def __init__(self, - task_q_url="tcp://127.0.0.1:50097", - result_q_url="tcp://127.0.0.1:50098", - max_queue_size=10, - cores_per_worker=1, - max_workers=float('inf'), - uid=None, - heartbeat_threshold=120, - heartbeat_period=30, - logdir=None, - debug=False, - block_id=None, - internal_worker_port_range=(50000, 60000), - worker_mode="singularity_reuse", - container_cmd_options="", - scheduler_mode="hard", - worker_type=None, - worker_max_idletime=60, - # TODO : This should be 10ms - poll_period=100): + def __init__( + self, + task_q_url="tcp://127.0.0.1:50097", + result_q_url="tcp://127.0.0.1:50098", + max_queue_size=10, + cores_per_worker=1, + max_workers=float("inf"), + uid=None, + heartbeat_threshold=120, + heartbeat_period=30, + logdir=None, + debug=False, + block_id=None, + internal_worker_port_range=(50000, 60000), + worker_mode="singularity_reuse", + container_cmd_options="", + scheduler_mode="hard", + worker_type=None, + worker_max_idletime=60, + # TODO : This should be 10ms + poll_period=100, + ): """ Parameters ---------- @@ -112,23 +113,29 @@ def __init__(self, heartbeat_threshold : int Seconds since the last message from the interchange after which the - interchange is assumed to be un-available, and the manager initiates shutdown. Default:120s + interchange is assumed to be un-available, and the manager initiates + shutdown. Default:120s - Number of seconds since the last message from the interchange after which the worker - assumes that the interchange is lost and the manager shuts down. Default:120 + Number of seconds since the last message from the interchange after which + the worker assumes that the interchange is lost and the manager shuts down. + Default:120 heartbeat_period : int - Number of seconds after which a heartbeat message is sent to the interchange + Number of seconds after which a heartbeat message is sent to the + interchange internal_worker_port_range : tuple(int, int) - Port range from which the port(s) for the workers to connect to the manager is picked. + Port range from which the port(s) for the workers to connect to the manager + is picked. Default: (50000,60000) worker_mode : str Pick between 3 supported modes for the worker: 1. no_container : Worker launched without containers - 2. singularity_reuse : Worker launched inside a singularity container that will be reused - 3. singularity_single_use : Each worker and task runs inside a new container instance. + 2. singularity_reuse : Worker launched inside a singularity container that + will be reused + 3. singularity_single_use : Each worker and task runs inside a new + container instance. container_cmd_options: str Container command strings to be added to associated container command. @@ -145,19 +152,11 @@ def __init__(self, poll_period : int Timeout period used by the manager in milliseconds. 
Default: 10ms """ - - global logger - # This is expected to be used only in unit test - if logger is None: - logger = set_file_logger(os.path.join(logdir, uid, 'manager.log'), - name='funcx_manager', - level=logging.DEBUG) - - logger.info("Manager started") + log.info("Manager started") self.context = zmq.Context() self.task_incoming = self.context.socket(zmq.DEALER) - self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) + self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode("utf-8")) # Linger is set to 0, so that the manager can exit even when there might be # messages in the pipe self.task_incoming.setsockopt(zmq.LINGER, 0) @@ -167,11 +166,11 @@ def __init__(self, self.debug = debug self.block_id = block_id self.result_outgoing = self.context.socket(zmq.DEALER) - self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) + self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode("utf-8")) self.result_outgoing.setsockopt(zmq.LINGER, 0) self.result_outgoing.connect(result_q_url) - logger.info("Manager connected") + log.info("Manager connected") self.uid = uid @@ -183,22 +182,30 @@ def __init__(self, self.cores_on_node = multiprocessing.cpu_count() self.max_workers = max_workers self.cores_per_workers = cores_per_worker - self.available_mem_on_node = round(psutil.virtual_memory().available / (2**30), 1) - self.max_worker_count = min(max_workers, - math.floor(self.cores_on_node / cores_per_worker)) + self.available_mem_on_node = round( + psutil.virtual_memory().available / (2 ** 30), 1 + ) + self.max_worker_count = min( + max_workers, math.floor(self.cores_on_node / cores_per_worker) + ) self.worker_map = WorkerMap(self.max_worker_count) self.internal_worker_port_range = internal_worker_port_range self.funcx_task_socket = self.context.socket(zmq.ROUTER) self.funcx_task_socket.set_hwm(0) - self.address = '127.0.0.1' + self.address = "127.0.0.1" self.worker_port = self.funcx_task_socket.bind_to_random_port( "tcp://*", min_port=self.internal_worker_port_range[0], - max_port=self.internal_worker_port_range[1]) + max_port=self.internal_worker_port_range[1], + ) - logger.info("Manager listening on {} port for incoming worker connections".format(self.worker_port)) + log.info( + "Manager listening on {} port for incoming worker connections".format( + self.worker_port + ) + ) self.task_queues = {} if worker_type: @@ -222,12 +229,10 @@ def __init__(self, self._kill_event = threading.Event() self._result_pusher_thread = threading.Thread( - target=self.push_results, - args=(self._kill_event,) + target=self.push_results, args=(self._kill_event,) ) self._status_report_thread = threading.Thread( - target=self._status_report_loop, - args=(self._kill_event,) + target=self._status_report_loop, args=(self._kill_event,) ) self.container_switch_count = 0 @@ -240,26 +245,26 @@ def __init__(self, self.task_cancel_lock = threading.Lock() def create_reg_message(self): - """ Creates a registration message to identify the worker to the interchange - """ - msg = {'parsl_v': PARSL_VERSION, - 'python_v': "{}.{}.{}".format(sys.version_info.major, - sys.version_info.minor, - sys.version_info.micro), - 'max_worker_count': self.max_worker_count, - 'cores': self.cores_on_node, - 'mem': self.available_mem_on_node, - 'block_id': self.block_id, - 'worker_type': self.worker_type, - 'os': platform.system(), - 'hname': platform.node(), - 'dir': os.getcwd(), + """Creates a registration message to identify the worker to the interchange""" + msg = { + "parsl_v": PARSL_VERSION, + "python_v": "{}.{}.{}".format( + 
sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ), + "max_worker_count": self.max_worker_count, + "cores": self.cores_on_node, + "mem": self.available_mem_on_node, + "block_id": self.block_id, + "worker_type": self.worker_type, + "os": platform.system(), + "hname": platform.node(), + "dir": os.getcwd(), } - b_msg = json.dumps(msg).encode('utf-8') + b_msg = json.dumps(msg).encode("utf-8") return b_msg def pull_tasks(self, kill_event): - """ Pull tasks from the incoming tasks 0mq pipe onto the internal + """Pull tasks from the incoming tasks 0mq pipe onto the internal pending task queue @@ -277,11 +282,11 @@ def pull_tasks(self, kill_event): kill_event : threading.Event Event to let the thread know when it is time to die. """ - logger.info("[TASK PULL THREAD] starting") + log.info("[TASK PULL THREAD] starting") # Send a registration message msg = self.create_reg_message() - logger.debug("Sending registration message: {}".format(msg)) + log.debug(f"Sending registration message: {msg}") self.task_incoming.send(msg) last_interchange_contact = time.time() task_recv_counter = 0 @@ -291,76 +296,96 @@ def pull_tasks(self, kill_event): new_worker_map = None while not kill_event.is_set(): # Disabling the check on ready_worker_queue disables batching - logger.debug("[TASK_PULL_THREAD] Loop start") + log.debug("[TASK_PULL_THREAD] Loop start") pending_task_count = task_recv_counter - self.task_done_counter ready_worker_count = self.worker_map.ready_worker_count() - logger.debug("[TASK_PULL_THREAD pending_task_count: {}, Ready_worker_count: {}".format( - pending_task_count, ready_worker_count)) + log.debug( + "[TASK_PULL_THREAD pending_task_count: %s, Ready_worker_count: %s", + pending_task_count, + ready_worker_count, + ) if pending_task_count < self.max_queue_size and ready_worker_count > 0: ads = self.worker_map.advertisement() - logger.debug("[TASK_PULL_THREAD] Requesting tasks: {}".format(ads)) + log.debug(f"[TASK_PULL_THREAD] Requesting tasks: {ads}") msg = pickle.dumps(ads) self.task_incoming.send(msg) # Receive results from the workers, if any socks = dict(self.poller.poll(timeout=poll_timer)) - if self.funcx_task_socket in socks and socks[self.funcx_task_socket] == zmq.POLLIN: + if ( + self.funcx_task_socket in socks + and socks[self.funcx_task_socket] == zmq.POLLIN + ): self.poll_funcx_task_socket() # Receive task batches from Interchange and forward to workers if self.task_incoming in socks and socks[self.task_incoming] == zmq.POLLIN: - # If we want to wrap the task_incoming polling into a separate function, we need to - # self.poll_task_incoming(poll_timer, last_interchange_contact, kill_event, task_revc_counter) + # If we want to wrap the task_incoming polling into a separate function, + # we need to + # self.poll_task_incoming( + # poll_timer, + # last_interchange_contact, + # kill_event, + # task_revc_counter + # ) poll_timer = 0 _, pkl_msg = self.task_incoming.recv_multipart() message = pickle.loads(pkl_msg) last_interchange_contact = time.time() - if message == 'STOP': - logger.critical("[TASK_PULL_THREAD] Received stop request") + if message == "STOP": + log.critical("[TASK_PULL_THREAD] Received stop request") kill_event.set() break - elif type(message) == tuple and message[0] == 'TASK_CANCEL': + elif type(message) == tuple and message[0] == "TASK_CANCEL": self.task_cancel_lock.acquire() task_id = message[1] - logger.info(f"Received TASK_CANCEL request for task: {task_id}") + log.info(f"Received TASK_CANCEL request for task: {task_id}") if task_id not in 
self.task_worker_map: - logger.warning(f"Task:{task_id} is not in task_worker_map.") - logger.warning("Possible duplicate cancel or race-condition") + log.warning(f"Task:{task_id} is not in task_worker_map.") + log.warning("Possible duplicate cancel or race-condition") continue # Cancel task by killing the worker it is on - worker_id_raw = self.task_worker_map[task_id]['worker_id'] - worker_to_kill = self.task_worker_map[task_id]['worker_id'].decode('utf-8') - worker_type = self.task_worker_map[task_id]['task_type'] - logger.debug(f"Cancelling task running on worker:{self.task_worker_map[task_id]}") + worker_id_raw = self.task_worker_map[task_id]["worker_id"] + worker_to_kill = self.task_worker_map[task_id]["worker_id"].decode( + "utf-8" + ) + worker_type = self.task_worker_map[task_id]["task_type"] + log.debug( + "Cancelling task running on worker: %s", + self.task_worker_map[task_id], + ) try: - logger.info(f"Removing worker:{worker_id_raw} from map") + log.info(f"Removing worker:{worker_id_raw} from map") self.worker_map.start_remove_worker(worker_type) self.worker_map.remove_worker(worker_id_raw) - logger.info(f"Popping worker:{worker_to_kill} from worker_procs") + log.info(f"Popping worker:{worker_to_kill} from worker_procs") proc = self.worker_procs.pop(worker_to_kill) - logger.warning(f"Sending process:{proc.pid} terminate signal") + log.warning(f"Sending process:{proc.pid} terminate signal") proc.terminate() try: proc.wait(1) # Wait 1 second before attempting SIGKILL except subprocess.TimeoutExpired: - logger.exception("Process did not terminate in 1 second") - logger.warning(f"Sending process:{proc.pid} kill signal") + log.exception("Process did not terminate in 1 second") + log.warning(f"Sending process:{proc.pid} kill signal") proc.kill() else: - logger.debug(f"Worker process exited with : {proc.returncode}") + log.debug(f"Worker process exited with : {proc.returncode}") raise TaskCancelled(worker_to_kill, self.uid) except Exception as e: - logger.exception(f"Raise exception, handling: {e}") - result_package = {'task_id': task_id, - 'container_id': worker_type, - 'exception': self.serializer.serialize( - RemoteExceptionWrapper(*sys.exc_info()))} + log.exception(f"Raise exception, handling: {e}") + result_package = { + "task_id": task_id, + "container_id": worker_type, + "exception": self.serializer.serialize( + RemoteExceptionWrapper(*sys.exc_info()) + ), + } self.pending_result_queue.put(pickle.dumps(result_package)) worker_proc = self.worker_map.add_worker( @@ -371,24 +396,31 @@ def pull_tasks(self, kill_event): debug=self.debug, uid=self.uid, logdir=self.logdir, - worker_port=self.worker_port) + worker_port=self.worker_port, + ) self.worker_procs.update(worker_proc) self.task_worker_map.pop(task_id) self.remove_task(task_id) self.task_cancel_lock.release() elif message == HEARTBEAT_CODE: - logger.debug("Got heartbeat from interchange") + log.debug("Got heartbeat from interchange") else: - tasks = [(rt['local_container'], Message.unpack(rt['raw_buffer'])) for rt in message] + tasks = [ + (rt["local_container"], Message.unpack(rt["raw_buffer"])) + for rt in message + ] task_recv_counter += len(tasks) - logger.debug("[TASK_PULL_THREAD] Got tasks: {} of {}".format([t[1].task_id for t in tasks], - task_recv_counter)) + log.debug( + "[TASK_PULL_THREAD] Got tasks: {} of {}".format( + [t[1].task_id for t in tasks], task_recv_counter + ) + ) for task_type, task in tasks: - logger.debug("[TASK DEBUG] Task is of type: {}".format(task_type)) + log.debug(f"[TASK DEBUG] Task is of type: 
{task_type}") if task_type not in self.task_queues: self.task_queues[task_type] = queue.Queue() @@ -397,11 +429,15 @@ def pull_tasks(self, kill_event): self.task_queues[task_type].put(task) self.outstanding_task_count[task_type] += 1 self.task_type_mapping[task.task_id] = task_type - logger.debug("Got task: Outstanding task counts: {}".format(self.outstanding_task_count)) - logger.debug("Task {} pushed to a task queue {}".format(task, task_type)) + log.debug( + "Got task: Outstanding task counts: {}".format( + self.outstanding_task_count + ) + ) + log.debug(f"Task {task} pushed to a task queue {task_type}") else: - logger.debug("[TASK_PULL_THREAD] No incoming tasks") + log.debug("[TASK_PULL_THREAD] No incoming tasks") # Limit poll duration to heartbeat_period # heartbeat_period is in s vs poll_timer in ms if not poll_timer: @@ -410,114 +446,160 @@ def pull_tasks(self, kill_event): # Only check if no messages were received. if time.time() > last_interchange_contact + self.heartbeat_threshold: - logger.critical("[TASK_PULL_THREAD] Missing contact with interchange beyond heartbeat_threshold") + log.critical( + "[TASK_PULL_THREAD] Missing contact with interchange beyond " + "heartbeat_threshold" + ) kill_event.set() - logger.critical("Killing all workers") + log.critical("Killing all workers") for proc in self.worker_procs.values(): proc.kill() - logger.critical("[TASK_PULL_THREAD] Exiting") + log.critical("[TASK_PULL_THREAD] Exiting") break - logger.debug("To-Die Counts: {}".format(self.worker_map.to_die_count)) - logger.debug("Alive worker counts: {}".format(self.worker_map.total_worker_type_counts)) + log.debug(f"To-Die Counts: {self.worker_map.to_die_count}") + log.debug( + "Alive worker counts: {}".format( + self.worker_map.total_worker_type_counts + ) + ) - new_worker_map = naive_scheduler(self.task_queues, - self.outstanding_task_count, - self.max_worker_count, - new_worker_map, - self.worker_map.to_die_count, - logger=logger) - logger.debug("[SCHEDULER] New worker map: {}".format(new_worker_map)) + new_worker_map = naive_scheduler( + self.task_queues, + self.outstanding_task_count, + self.max_worker_count, + new_worker_map, + self.worker_map.to_die_count, + ) + log.debug(f"[SCHEDULER] New worker map: {new_worker_map}") - # NOTE: Wipes the queue -- previous scheduling loops don't affect what's needed now. - self.next_worker_q, need_more = self.worker_map.get_next_worker_q(new_worker_map) + # NOTE: Wipes the queue -- previous scheduling loops don't affect what's + # needed now. + self.next_worker_q, need_more = self.worker_map.get_next_worker_q( + new_worker_map + ) # Spin up any new workers according to the worker queue. # Returns the total number of containers that have spun up. 
- self.worker_procs.update(self.worker_map.spin_up_workers(self.next_worker_q, - mode=self.worker_mode, - debug=self.debug, - container_cmd_options=self.container_cmd_options, - address=self.address, - uid=self.uid, - logdir=self.logdir, - worker_port=self.worker_port)) - logger.debug(f"[SPIN UP] Worker processes: {self.worker_procs}") + self.worker_procs.update( + self.worker_map.spin_up_workers( + self.next_worker_q, + mode=self.worker_mode, + debug=self.debug, + container_cmd_options=self.container_cmd_options, + address=self.address, + uid=self.uid, + logdir=self.logdir, + worker_port=self.worker_port, + ) + ) + log.debug(f"[SPIN UP] Worker processes: {self.worker_procs}") # Count the workers of each type that need to be removed - spin_downs, container_switch_count = self.worker_map.spin_down_workers(new_worker_map, - worker_max_idletime=self.worker_max_idletime, - need_more=need_more, - scheduler_mode=self.scheduler_mode) + spin_downs, container_switch_count = self.worker_map.spin_down_workers( + new_worker_map, + worker_max_idletime=self.worker_max_idletime, + need_more=need_more, + scheduler_mode=self.scheduler_mode, + ) self.container_switch_count += container_switch_count - logger.debug("Container switch count: total {}, cur {}".format(self.container_switch_count, container_switch_count)) + log.debug( + "Container switch count: total {}, cur {}".format( + self.container_switch_count, container_switch_count + ) + ) for w_type in spin_downs: self.remove_worker_init(w_type) current_worker_map = self.worker_map.get_worker_counts() for task_type in current_worker_map: - if task_type == 'unused': + if task_type == "unused": continue # *** Match tasks to workers *** # else: available_workers = current_worker_map[task_type] - logger.debug("Available workers of type {}: {}".format(task_type, - available_workers)) + log.debug( + "Available workers of type {}: {}".format( + task_type, available_workers + ) + ) for _i in range(available_workers): - if task_type in self.task_queues and not self.task_queues[task_type].qsize() == 0 \ - and not self.worker_map.worker_queues[task_type].qsize() == 0: - - logger.debug("Task type {} has task queue size {}" - .format(task_type, self.task_queues[task_type].qsize())) - logger.debug("... and available workers: {}" - .format(self.worker_map.worker_queues[task_type].qsize())) + if ( + task_type in self.task_queues + and not self.task_queues[task_type].qsize() == 0 + and not self.worker_map.worker_queues[task_type].qsize() + == 0 + ): + + log.debug( + "Task type {} has task queue size {}".format( + task_type, self.task_queues[task_type].qsize() + ) + ) + log.debug( + "... 
and available workers: {}".format( + self.worker_map.worker_queues[task_type].qsize() + ) + ) self.send_task_to_worker(task_type) def poll_funcx_task_socket(self, test=False): try: w_id, m_type, message = self.funcx_task_socket.recv_multipart() - if m_type == b'REGISTER': + if m_type == b"REGISTER": reg_info = pickle.loads(message) - logger.debug("Registration received from worker:{} {}".format(w_id, reg_info)) - self.worker_map.register_worker(w_id, reg_info['worker_type']) + log.debug(f"Registration received from worker:{w_id} {reg_info}") + self.worker_map.register_worker(w_id, reg_info["worker_type"]) - elif m_type == b'TASK_RET': + elif m_type == b"TASK_RET": # We lock because the following steps are also shared by task_cancel self.task_cancel_lock.acquire() - logger.debug("Result received from worker: {}".format(w_id)) - task_id = pickle.loads(message)['task_id'] + log.debug(f"Result received from worker: {w_id}") + task_id = pickle.loads(message)["task_id"] try: self.remove_task(task_id) except KeyError: - logger.exception(f"Task:{task_id} missing in task structure") + log.exception(f"Task:{task_id} missing in task structure") else: self.pending_result_queue.put(message) self.worker_map.put_worker(w_id) self.task_cancel_lock.release() - elif m_type == b'WRKR_DIE': - logger.debug("[WORKER_REMOVE] Removing worker {} from worker_map...".format(w_id)) - logger.debug("Ready worker counts: {}".format(self.worker_map.ready_worker_type_counts)) - logger.debug("Total worker counts: {}".format(self.worker_map.total_worker_type_counts)) + elif m_type == b"WRKR_DIE": + log.debug(f"[WORKER_REMOVE] Removing worker {w_id} from worker_map...") + log.debug( + "Ready worker counts: {}".format( + self.worker_map.ready_worker_type_counts + ) + ) + log.debug( + "Total worker counts: {}".format( + self.worker_map.total_worker_type_counts + ) + ) self.worker_map.remove_worker(w_id) proc = self.worker_procs.pop(w_id.decode()) if not proc.poll(): try: proc.wait(timeout=1) except subprocess.TimeoutExpired: - logger.warning(f"[WORKER_REMOVE] Timeout waiting for worker {w_id} process to terminate") - logger.debug(f"[WORKER_REMOVE] Removing worker {w_id} process object") - logger.debug(f"[WORKER_REMOVE] Worker processes: {self.worker_procs}") + log.warning( + "[WORKER_REMOVE] Timeout waiting for worker %s process to " + "terminate", + w_id, + ) + log.debug(f"[WORKER_REMOVE] Removing worker {w_id} process object") + log.debug(f"[WORKER_REMOVE] Worker processes: {self.worker_procs}") if test: return pickle.loads(message) except Exception: - logger.exception("Unhandled exception while processing worker messages") + log.exception("Unhandled exception while processing worker messages") def remove_task(self, task_id: str): task_type = self.task_type_mapping.pop(task_id) @@ -529,34 +611,43 @@ def send_task_to_worker(self, task_type): task = self.task_queues[task_type].get() worker_id = self.worker_map.get_worker(task_type) - logger.debug("Sending task {} to {}".format(task.task_id, worker_id)) + log.debug(f"Sending task {task.task_id} to {worker_id}") # TODO: Some duplication of work could be avoided here - to_send = [worker_id, pickle.dumps(task.task_id), pickle.dumps(task.container_id), task.pack()] + to_send = [ + worker_id, + pickle.dumps(task.task_id), + pickle.dumps(task.container_id), + task.pack(), + ] self.funcx_task_socket.send_multipart(to_send) self.worker_map.update_worker_idle(task_type) if task.task_id != "KILL": - logger.debug(f"Set task {task.task_id} to RUNNING") + log.debug(f"Set task {task.task_id} 
to RUNNING") self.task_status_deltas[task.task_id] = TaskStatusCode.RUNNING - self.task_worker_map[task.task_id] = {'worker_id': worker_id, - 'task_type': task_type} - logger.debug("Sending complete!") + self.task_worker_map[task.task_id] = { + "worker_id": worker_id, + "task_type": task_type, + } + log.debug("Sending complete!") def _status_report_loop(self, kill_event): - logger.debug("[STATUS] Manager status reporting loop starting") + log.debug("[STATUS] Manager status reporting loop starting") while not kill_event.is_set(): msg = ManagerStatusReport( self.task_status_deltas, self.container_switch_count, ) - logger.info(f"[STATUS] Sending status report to interchange: {msg.task_statuses}") + log.info( + f"[STATUS] Sending status report to interchange: {msg.task_statuses}" + ) self.pending_result_queue.put(msg) - logger.info("[STATUS] Clearing task deltas") + log.info("[STATUS] Clearing task deltas") self.task_status_deltas.clear() time.sleep(self.heartbeat_period) def push_results(self, kill_event, max_result_batch_size=1): - """ Listens on the pending_result_queue and sends out results via 0mq + """Listens on the pending_result_queue and sends out results via 0mq Parameters: ----------- @@ -564,10 +655,12 @@ def push_results(self, kill_event, max_result_batch_size=1): Event to let the thread know when it is time to die. """ - logger.debug("[RESULT_PUSH_THREAD] Starting thread") + log.debug("[RESULT_PUSH_THREAD] Starting thread") - push_poll_period = max(10, self.poll_period) / 1000 # push_poll_period must be atleast 10 ms - logger.debug("[RESULT_PUSH_THREAD] push poll period: {}".format(push_poll_period)) + push_poll_period = ( + max(10, self.poll_period) / 1000 + ) # push_poll_period must be atleast 10 ms + log.debug(f"[RESULT_PUSH_THREAD] push poll period: {push_poll_period}") last_beat = time.time() items = [] @@ -575,8 +668,10 @@ def push_results(self, kill_event, max_result_batch_size=1): while not kill_event.is_set(): try: r = self.pending_result_queue.get(block=True, timeout=push_poll_period) - # This avoids the interchange searching and attempting to unpack every message in case it's a - # status report. (Would be better to use Task Messages eventually to make this more uniform) + # This avoids the interchange searching and attempting to unpack every + # message in case it's a status report. + # (It would be better to use Task Messages eventually to make this more + # uniform) # TODO: use task messages, and don't have to prepend if isinstance(r, ManagerStatusReport): items.insert(0, r.pack()) @@ -585,31 +680,38 @@ def push_results(self, kill_event, max_result_batch_size=1): except queue.Empty: pass except Exception as e: - logger.exception("[RESULT_PUSH_THREAD] Got an exception: {}".format(e)) - - # If we have reached poll_period duration or timer has expired, we send results - if len(items) >= self.max_queue_size or time.time() > last_beat + push_poll_period: + log.exception(f"[RESULT_PUSH_THREAD] Got an exception: {e}") + + # If we have reached poll_period duration or timer has expired, we send + # results + if ( + len(items) >= self.max_queue_size + or time.time() > last_beat + push_poll_period + ): last_beat = time.time() if items: self.result_outgoing.send_multipart(items) items = [] - logger.critical("[RESULT_PUSH_THREAD] Exiting") + log.critical("[RESULT_PUSH_THREAD] Exiting") def remove_worker_init(self, worker_type): """ - Kill/Remove a worker of a given worker_type. + Kill/Remove a worker of a given worker_type. - Add a kill message to the task_type queue. 
+ Add a kill message to the task_type queue. - Assumption : All workers of the same type are uniform, and therefore don't discriminate when killing. + Assumption : All workers of the same type are uniform, and therefore don't + discriminate when killing. """ - logger.debug("[WORKER_REMOVE] Appending KILL message to worker queue {}".format(worker_type)) + log.debug( + "[WORKER_REMOVE] Appending KILL message to worker queue {}".format( + worker_type + ) + ) self.worker_map.start_remove_worker(worker_type) - task = Task(task_id='KILL', - container_id='RAW', - task_buffer='KILL') + task = Task(task_id="KILL", container_id="RAW", task_buffer="KILL") self.task_queues[worker_type].put(task) def start(self): @@ -620,122 +722,158 @@ def start(self): Forward results """ - if self.worker_type and self.scheduler_mode == 'hard': - logger.debug("[MANAGER] Start an initial worker with worker type {}".format(self.worker_type)) - self.worker_procs.update(self.worker_map.add_worker(worker_id=str(self.worker_map.worker_id_counter), - worker_type=self.worker_type, - container_cmd_options=self.container_cmd_options, - address=self.address, - debug=self.debug, - uid=self.uid, - logdir=self.logdir, - worker_port=self.worker_port)) - - logger.debug("Initial workers launched") + if self.worker_type and self.scheduler_mode == "hard": + log.debug( + "[MANAGER] Start an initial worker with worker type {}".format( + self.worker_type + ) + ) + self.worker_procs.update( + self.worker_map.add_worker( + worker_id=str(self.worker_map.worker_id_counter), + worker_type=self.worker_type, + container_cmd_options=self.container_cmd_options, + address=self.address, + debug=self.debug, + uid=self.uid, + logdir=self.logdir, + worker_port=self.worker_port, + ) + ) + + log.debug("Initial workers launched") self._result_pusher_thread.start() self._status_report_thread.start() self.pull_tasks(self._kill_event) - logger.info("Waiting") + log.info("Waiting") def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-d", "--debug", action='store_true', - help="Count of apps to launch") - parser.add_argument("-l", "--logdir", default="process_worker_pool_logs", - help="Process worker pool log directory") - parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1], - help="Unique identifier string for Manager") - parser.add_argument("-b", "--block_id", default=None, - help="Block identifier string for Manager") - parser.add_argument("-c", "--cores_per_worker", default="1.0", - help="Number of cores assigned to each worker process. Default=1.0") - parser.add_argument("-t", "--task_url", required=True, - help="REQUIRED: ZMQ url for receiving tasks") - parser.add_argument("--max_workers", default=float('inf'), - help="Caps the maximum workers that can be launched, default:infinity") - parser.add_argument("--hb_period", default=30, - help="Heartbeat period in seconds. Uses manager default unless set") - parser.add_argument("--hb_threshold", default=120, - help="Heartbeat threshold in seconds. 
Uses manager default unless set") - parser.add_argument("--poll", default=10, - help="Poll period used in milliseconds") - parser.add_argument("--worker_type", default=None, - help="Fixed worker type of manager") - parser.add_argument("--worker_mode", default="singularity_reuse", - help=("Choose the mode of operation from " - "(no_container, singularity_reuse, singularity_single_use")) - parser.add_argument("--container_cmd_options", default="", - help=("Container cmd options to add to container startup cmd")) - parser.add_argument("--scheduler_mode", default="soft", - help=("Choose the mode of scheduler " - "(hard, soft")) - parser.add_argument("-r", "--result_url", required=True, - help="REQUIRED: ZMQ url for posting results") - parser.add_argument("--log_max_bytes", default=256 * 1024 * 1024, - help="The maximum bytes per logger file in bytes") - parser.add_argument("--log_backup_count", default=1, - help="The number of backup (must be non-zero) per logger file") + parser.add_argument( + "-d", "--debug", action="store_true", help="Count of apps to launch" + ) + parser.add_argument( + "-l", + "--logdir", + default="process_worker_pool_logs", + help="Process worker pool log directory", + ) + parser.add_argument( + "-u", + "--uid", + default=str(uuid.uuid4()).split("-")[-1], + help="Unique identifier string for Manager", + ) + parser.add_argument( + "-b", "--block_id", default=None, help="Block identifier string for Manager" + ) + parser.add_argument( + "-c", + "--cores_per_worker", + default="1.0", + help="Number of cores assigned to each worker process. Default=1.0", + ) + parser.add_argument( + "-t", "--task_url", required=True, help="REQUIRED: ZMQ url for receiving tasks" + ) + parser.add_argument( + "--max_workers", + default=float("inf"), + help="Caps the maximum workers that can be launched, default:infinity", + ) + parser.add_argument( + "--hb_period", + default=30, + help="Heartbeat period in seconds. Uses manager default unless set", + ) + parser.add_argument( + "--hb_threshold", + default=120, + help="Heartbeat threshold in seconds. 
Uses manager default unless set", + ) + parser.add_argument("--poll", default=10, help="Poll period used in milliseconds") + parser.add_argument( + "--worker_type", default=None, help="Fixed worker type of manager" + ) + parser.add_argument( + "--worker_mode", + default="singularity_reuse", + help=( + "Choose the mode of operation from " + "(no_container, singularity_reuse, singularity_single_use" + ), + ) + parser.add_argument( + "--container_cmd_options", + default="", + help=("Container cmd options to add to container startup cmd"), + ) + parser.add_argument( + "--scheduler_mode", + default="soft", + help=("Choose the mode of scheduler " "(hard, soft"), + ) + parser.add_argument( + "-r", + "--result_url", + required=True, + help="REQUIRED: ZMQ url for posting results", + ) args = parser.parse_args() - try: - os.makedirs(os.path.join(args.logdir, args.uid)) - except FileExistsError: - pass + os.makedirs(os.path.join(args.logdir, args.uid), exist_ok=True) + setup_logging( + logfile=os.path.join(args.logdir, args.uid, "manager.log"), debug=args.debug + ) try: - global logger - # TODO The config options for the rotatingfilehandler need to be implemented and checked so that it is user configurable - logger = set_file_logger(os.path.join(args.logdir, args.uid, 'manager.log'), - name='funcx_manager', - level=logging.DEBUG if args.debug is True else logging.INFO, - max_bytes=float(args.log_max_bytes), # TODO: Test if this still works on forwarder_rearch_1 - backup_count=int(args.log_backup_count)) # TODO: Test if this still works on forwarder_rearch_1 - - logger.info("Python version: {}".format(sys.version)) - logger.info("Debug logging: {}".format(args.debug)) - logger.info("Log dir: {}".format(args.logdir)) - logger.info("Manager ID: {}".format(args.uid)) - logger.info("Block ID: {}".format(args.block_id)) - logger.info("cores_per_worker: {}".format(args.cores_per_worker)) - logger.info("task_url: {}".format(args.task_url)) - logger.info("result_url: {}".format(args.result_url)) - logger.info("hb_period: {}".format(args.hb_period)) - logger.info("hb_threshold: {}".format(args.hb_threshold)) - logger.info("max_workers: {}".format(args.max_workers)) - logger.info("poll_period: {}".format(args.poll)) - logger.info("worker_mode: {}".format(args.worker_mode)) - logger.info("container_cmd_options: {}".format(args.container_cmd_options)) - logger.info("scheduler_mode: {}".format(args.scheduler_mode)) - logger.info("worker_type: {}".format(args.worker_type)) - logger.info("log_max_bytes: {}".format(args.log_max_bytes)) - logger.info("log_backup_count: {}".format(args.log_backup_count)) - - manager = Manager(task_q_url=args.task_url, - result_q_url=args.result_url, - uid=args.uid, - block_id=args.block_id, - cores_per_worker=float(args.cores_per_worker), - max_workers=args.max_workers if args.max_workers == float('inf') else int(args.max_workers), - heartbeat_threshold=int(args.hb_threshold), - heartbeat_period=int(args.hb_period), - logdir=args.logdir, - debug=args.debug, - worker_mode=args.worker_mode, - container_cmd_options=args.container_cmd_options, - scheduler_mode=args.scheduler_mode, - worker_type=args.worker_type, - poll_period=int(args.poll)) + log.info(f"Python version: {sys.version}") + log.info(f"Debug logging: {args.debug}") + log.info(f"Log dir: {args.logdir}") + log.info(f"Manager ID: {args.uid}") + log.info(f"Block ID: {args.block_id}") + log.info(f"cores_per_worker: {args.cores_per_worker}") + log.info(f"task_url: {args.task_url}") + log.info(f"result_url: {args.result_url}") + 
log.info(f"hb_period: {args.hb_period}") + log.info(f"hb_threshold: {args.hb_threshold}") + log.info(f"max_workers: {args.max_workers}") + log.info(f"poll_period: {args.poll}") + log.info(f"worker_mode: {args.worker_mode}") + log.info(f"container_cmd_options: {args.container_cmd_options}") + log.info(f"scheduler_mode: {args.scheduler_mode}") + log.info(f"worker_type: {args.worker_type}") + + manager = Manager( + task_q_url=args.task_url, + result_q_url=args.result_url, + uid=args.uid, + block_id=args.block_id, + cores_per_worker=float(args.cores_per_worker), + max_workers=args.max_workers + if args.max_workers == float("inf") + else int(args.max_workers), + heartbeat_threshold=int(args.hb_threshold), + heartbeat_period=int(args.hb_period), + logdir=args.logdir, + debug=args.debug, + worker_mode=args.worker_mode, + container_cmd_options=args.container_cmd_options, + scheduler_mode=args.scheduler_mode, + worker_type=args.worker_type, + poll_period=int(args.poll), + ) manager.start() except Exception as e: - logger.critical("process_worker_pool exiting from an exception") - logger.exception("Caught error: {}".format(e)) + log.critical("process_worker_pool exiting from an exception") + log.exception(f"Caught error: {e}") raise else: - logger.info("process_worker_pool main event loop exiting normally") + log.info("process_worker_pool main event loop exiting normally") print("PROCESS_WORKER_POOL main event loop exiting normally") diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py index 6e69ee823..fd2d00e24 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py @@ -1,33 +1,43 @@ #!/usr/bin/env python3 -import logging import argparse -import zmq -import sys +import logging +import os import pickle import signal -import os +import sys +import zmq from parsl.app.errors import RemoteExceptionWrapper -from funcx import set_file_logger from funcx.serialize import FuncXSerializer -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task +from funcx_endpoint.executors.high_throughput.messages import Message +from funcx_endpoint.logging_config import setup_logging + +log = logging.getLogger(__name__) + +DEFAULT_RESULT_SIZE_LIMIT_MB = 10 +DEFAULT_RESULT_SIZE_LIMIT_B = DEFAULT_RESULT_SIZE_LIMIT_MB * 1024 * 1024 class MaxResultSizeExceeded(Exception): - """ Result produced by the function exceeds the maximum supported result size threshold of 512000B """ + Result produced by the function exceeds the maximum supported result size + threshold""" + def __init__(self, result_size, result_size_limit): self.result_size = result_size self.result_size_limit = result_size_limit def __str__(self): - return f"Task result of {self.result_size}B exceeded current limit of {self.result_size_limit}B" + return ( + f"Task result of {self.result_size}B exceeded current " + f"limit of {self.result_size_limit}B" + ) -class FuncXWorker(object): - """ The FuncX worker +class FuncXWorker: + """The FuncX worker Parameters ---------- @@ -40,16 +50,9 @@ class FuncXWorker(object): port : int Port at which the manager can be reached - logdir : str - Logging directory - - debug : Bool - Enables debug logging - result_size_limit : int Maximum result size allowed in Bytes - Default = 10 MB == 10 * (2**20) Bytes - + Default = 10 MB Funcx worker will use the REP sockets to: task = 
recv () @@ -57,29 +60,26 @@ class FuncXWorker(object): send(result) """ - def __init__(self, worker_id, address, port, logdir, debug=False, worker_type='RAW', result_size_limit=512000): + def __init__( + self, + worker_id, + address, + port, + worker_type="RAW", + result_size_limit=DEFAULT_RESULT_SIZE_LIMIT_B, + ): self.worker_id = worker_id self.address = address self.port = port - self.logdir = logdir - self.debug = debug self.worker_type = worker_type self.serializer = FuncXSerializer() self.serialize = self.serializer.serialize self.deserialize = self.serializer.deserialize self.result_size_limit = result_size_limit - global logger - logger = set_file_logger(os.path.join(logdir, f'funcx_worker_{worker_id}.log'), - name="worker_log", - level=logging.DEBUG if debug else logging.INFO) - - logger.info('Initializing worker {}'.format(worker_id)) - logger.info('Worker is of type: {}'.format(worker_type)) - - if debug: - logger.debug('Debug logging enabled') + log.info(f"Initializing worker {worker_id}") + log.info(f"Worker is of type: {worker_type}") self.context = zmq.Context() self.poller = zmq.Poller() @@ -88,85 +88,93 @@ def __init__(self, worker_id, address, port, logdir, debug=False, worker_type='R self.task_socket = self.context.socket(zmq.DEALER) self.task_socket.setsockopt(zmq.IDENTITY, self.identity) - logger.info('Trying to connect to : tcp://{}:{}'.format(self.address, self.port)) - self.task_socket.connect('tcp://{}:{}'.format(self.address, self.port)) + log.info(f"Trying to connect to : tcp://{self.address}:{self.port}") + self.task_socket.connect(f"tcp://{self.address}:{self.port}") self.poller.register(self.task_socket, zmq.POLLIN) signal.signal(signal.SIGTERM, self.handler) def handler(self, signum, frame): - logger.error("Signal handler called with signal", signum) + log.error("Signal handler called with signal", signum) sys.exit(1) def registration_message(self): - return {'worker_id': self.worker_id, - 'worker_type': self.worker_type} + return {"worker_id": self.worker_id, "worker_type": self.worker_type} def start(self): - logger.info("Starting worker") + log.info("Starting worker") result = self.registration_message() - task_type = b'REGISTER' - logger.debug("Sending registration") - self.task_socket.send_multipart([task_type, # Byte encoded - pickle.dumps(result)]) + task_type = b"REGISTER" + log.debug("Sending registration") + self.task_socket.send_multipart( + [task_type, pickle.dumps(result)] # Byte encoded + ) while True: - logger.debug("Waiting for task") + log.debug("Waiting for task") p_task_id, p_container_id, msg = self.task_socket.recv_multipart() task_id = pickle.loads(p_task_id) container_id = pickle.loads(p_container_id) - logger.debug("Received task_id:{} with task:{}".format(task_id, msg)) + log.debug(f"Received task_id:{task_id} with task:{msg}") result = None task_type = None if task_id == "KILL": task = Message.unpack(msg) - if task.task_buffer.decode('utf-8') == "KILL": - logger.info("[KILL] -- Worker KILL message received! ") - task_type = b'WRKR_DIE' + if task.task_buffer.decode("utf-8") == "KILL": + log.info("[KILL] -- Worker KILL message received! 
") + task_type = b"WRKR_DIE" else: - logger.exception("Caught an exception of non-KILL message for KILL task") + log.exception( + "Caught an exception of non-KILL message for KILL task" + ) continue else: - logger.debug("Executing task...") + log.debug("Executing task...") try: result = self.execute_task(msg) serialized_result = self.serialize(result) - if sys.getsizeof(serialized_result) > self.result_size_limit: - raise MaxResultSizeExceeded(sys.getsizeof(serialized_result), - self.result_size_limit) - + if len(serialized_result) > self.result_size_limit: + raise MaxResultSizeExceeded( + len(serialized_result), self.result_size_limit + ) except Exception as e: - logger.exception(f"Caught an exception {e}") - result_package = {'task_id': task_id, - 'container_id': container_id, - 'exception': self.serialize( - RemoteExceptionWrapper(*sys.exc_info()))} + log.exception(f"Caught an exception {e}") + result_package = { + "task_id": task_id, + "container_id": container_id, + "exception": self.serialize( + RemoteExceptionWrapper(*sys.exc_info()) + ), + } else: - logger.debug("Execution completed without exception") - result_package = {'task_id': task_id, - 'container_id': container_id, - 'result': serialized_result} + log.debug("Execution completed without exception") + result_package = { + "task_id": task_id, + "container_id": container_id, + "result": serialized_result, + } result = result_package - task_type = b'TASK_RET' + task_type = b"TASK_RET" - logger.debug("Sending result") + log.debug("Sending result") - self.task_socket.send_multipart([task_type, # Byte encoded - pickle.dumps(result)]) + self.task_socket.send_multipart( + [task_type, pickle.dumps(result)] # Byte encoded + ) - if task_type == b'WRKR_DIE': - logger.info("*** WORKER {} ABOUT TO DIE ***".format(self.worker_id)) + if task_type == b"WRKR_DIE": + log.info(f"*** WORKER {self.worker_id} ABOUT TO DIE ***") # Kill the worker after accepting death in message to manager. sys.exit() # We need to return here to allow for sys.exit mocking in tests return - logger.warning("Broke out of the loop... dying") + log.warning("Broke out of the loop... dying") def execute_task(self, message): """Deserialize the buffer and execute the task. @@ -174,34 +182,53 @@ def execute_task(self, message): Returns the result or throws exception. 
""" task = Message.unpack(message) - f, args, kwargs = self.serializer.unpack_and_deserialize(task.task_buffer.decode('utf-8')) + f, args, kwargs = self.serializer.unpack_and_deserialize( + task.task_buffer.decode("utf-8") + ) return f(*args, **kwargs) def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-w", "--worker_id", required=True, - help="ID of worker from process_worker_pool") - parser.add_argument("-t", "--type", required=False, - help="Container type of worker", default="RAW") - parser.add_argument("-a", "--address", required=True, - help="Address for the manager, eg X,Y,") - parser.add_argument("-p", "--port", required=True, - help="Internal port at which the worker connects to the manager") - parser.add_argument("--logdir", required=True, - help="Directory path where worker log files written") - parser.add_argument("-d", "--debug", action='store_true', - help="Directory path where worker log files written") - + parser.add_argument( + "-w", "--worker_id", required=True, help="ID of worker from process_worker_pool" + ) + parser.add_argument( + "-t", "--type", required=False, help="Container type of worker", default="RAW" + ) + parser.add_argument( + "-a", "--address", required=True, help="Address for the manager, eg X,Y," + ) + parser.add_argument( + "-p", + "--port", + required=True, + help="Internal port at which the worker connects to the manager", + ) + parser.add_argument( + "--logdir", required=True, help="Directory path where worker log files written" + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + help="Directory path where worker log files written", + ) args = parser.parse_args() - worker = FuncXWorker(args.worker_id, - args.address, - int(args.port), - args.logdir, - worker_type=args.type, - debug=args.debug, ) + + os.makedirs(args.logdir, exist_ok=True) + setup_logging( + logfile=os.path.join(args.logdir, f"funcx_worker_{args.worker_id}.log"), + debug=args.debug, + ) + + worker = FuncXWorker( + args.worker_id, + args.address, + int(args.port), + worker_type=args.type, + ) worker.start() - return if __name__ == "__main__": diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py index 3ea43e94d..10d43462d 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py @@ -1,10 +1,11 @@ import getpass + from parsl.addresses import address_by_hostname global_options = { - 'username': getpass.getuser(), - 'email': 'USER@USERDOMAIN.COM', - 'broker_address': '127.0.0.1', - 'broker_port': 8088, - 'endpoint_address': address_by_hostname(), + "username": getpass.getuser(), + "email": "USER@USERDOMAIN.COM", + "broker_address": "127.0.0.1", + "broker_port": 8088, + "endpoint_address": address_by_hostname(), } diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py index d2825ef5b..b8eae4dd9 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py @@ -1,34 +1,39 @@ #!/usr/bin/env python import argparse -from typing import Tuple, Dict - -import zmq +import collections +import json +import logging import os -import sys -import platform -import random -import time import pickle -import logging +import platform import queue +import sys 
import threading -import json -import daemon -import collections - -from logging.handlers import RotatingFileHandler +import time +from typing import Dict, Tuple +import daemon +import zmq +from parsl.app.errors import RemoteExceptionWrapper from parsl.executors.errors import ScalingFailed from parsl.version import VERSION as PARSL_VERSION -from parsl.app.errors import RemoteExceptionWrapper -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task -from funcx_endpoint.executors.high_throughput.messages import EPStatusReport, Heartbeat, TaskStatusCode -from funcx_endpoint.executors.high_throughput.messages import TaskCancel, BadCommand from funcx.sdk.client import FuncXClient -from funcx import set_file_logger -from funcx_endpoint.executors.high_throughput.interchange_task_dispatch import naive_interchange_task_dispatch from funcx.serialize import FuncXSerializer +from funcx_endpoint.executors.high_throughput.interchange_task_dispatch import ( + naive_interchange_task_dispatch, +) +from funcx_endpoint.executors.high_throughput.messages import ( + BadCommand, + EPStatusReport, + Heartbeat, + Message, + MessageType, + TaskStatusCode, +) +from funcx_endpoint.logging_config import setup_logging + +log = logging.getLogger(__name__) LOOP_SLOWDOWN = 0.0 # in seconds HEARTBEAT_CODE = (2 ** 32) - 1 @@ -36,21 +41,20 @@ class ShutdownRequest(Exception): - """ Exception raised when any async component receives a ShutdownRequest - """ + """Exception raised when any async component receives a ShutdownRequest""" def __init__(self): self.tstamp = time.time() def __repr__(self): - return "Shutdown request received at {}".format(self.tstamp) + return f"Shutdown request received at {self.tstamp}" def __str__(self): return self.__repr__() class ManagerLost(Exception): - """ Task lost due to worker loss. Worker is considered lost when multiple heartbeats + """Task lost due to worker loss. Worker is considered lost when multiple heartbeats have been missed. """ @@ -59,15 +63,14 @@ def __init__(self, worker_id): self.tstamp = time.time() def __repr__(self): - return "Task failure due to loss of manager {}".format(self.worker_id) + return f"Task failure due to loss of manager {self.worker_id}" def __str__(self): return self.__repr__() class BadRegistration(Exception): - ''' A new Manager tried to join the executor with a BadRegistration message - ''' + """A new Manager tried to join the executor with a BadRegistration message""" def __init__(self, worker_id, critical=False): self.worker_id = worker_id @@ -75,15 +78,17 @@ def __init__(self, worker_id, critical=False): self.handled = "critical" if critical else "suppressed" def __repr__(self): - return "Manager:{} caused a {} failure".format(self.worker_id, - self.handled) + return ( + f"Manager {self.worker_id} attempted to register with a bad " + f"registration message. Caused a {self.handled} failure" + ) def __str__(self): return self.__repr__() -class Interchange(object): - """ Interchange is a task orchestrator for distributed systems. +class Interchange: + """Interchange is a task orchestrator for distributed systems. 1. Asynchronously queue large volume of tasks (>100K) 2. 
Allow for workers to join and leave the union @@ -95,42 +100,36 @@ class Interchange(object): TODO: We most likely need a PUB channel to send out global commands, like shutdown """ - def __init__(self, - # - strategy=None, - poll_period=None, - heartbeat_period=None, - heartbeat_threshold=None, - working_dir=None, - provider=None, - max_workers_per_node=None, - mem_per_worker=None, - prefetch_capacity=None, - - scheduler_mode=None, - container_type=None, - container_cmd_options='', - worker_mode=None, - cold_routing_interval=10.0, - - funcx_service_address=None, - scaling_enabled=True, - # - client_address="127.0.0.1", - interchange_address="127.0.0.1", - client_ports: Tuple[int, int, int] = (50055, 50056, 50057), - worker_ports=None, - worker_port_range=(54000, 55000), - cores_per_worker=1.0, - worker_debug=False, - launch_cmd=None, - logdir=".", - logging_level=logging.INFO, - endpoint_id=None, - suppress_failure=False, - log_max_bytes=256 * 1024 * 1024, - log_backup_count=1, - ): + def __init__( + self, + strategy=None, + poll_period=None, + heartbeat_period=None, + heartbeat_threshold=None, + working_dir=None, + provider=None, + max_workers_per_node=None, + mem_per_worker=None, + prefetch_capacity=None, + scheduler_mode=None, + container_type=None, + container_cmd_options="", + worker_mode=None, + cold_routing_interval=10.0, + funcx_service_address=None, + scaling_enabled=True, + client_address="127.0.0.1", + interchange_address="127.0.0.1", + client_ports: Tuple[int, int, int] = (50055, 50056, 50057), + worker_ports=None, + worker_port_range=None, + cores_per_worker=1.0, + worker_debug=False, + launch_cmd=None, + logdir=".", + endpoint_id=None, + suppress_failure=False, + ): """ Parameters ---------- @@ -138,10 +137,12 @@ def __init__(self, Funcx config object that describes how compute should be provisioned client_address : str - The ip address at which the parsl client can be reached. Default: "127.0.0.1" + The ip address at which the parsl client can be reached. + Default: "127.0.0.1" interchange_address : str - The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1" + The ip address at which the workers will be able to reach the Interchange. + Default: "127.0.0.1" client_ports : Tuple[int, int, int] The ports at which the client can be reached @@ -150,11 +151,13 @@ def __init__(self, TODO : update worker_ports : tuple(int, int) - The specific two ports at which workers will connect to the Interchange. Default: None + The specific two ports at which workers will connect to the Interchange. + Default: None worker_port_range : tuple(int, int) - The interchange picks ports at random from the range which will be used by workers. - This is overridden when the worker_ports option is set. Defauls: (54000, 55000) + The interchange picks ports at random from the range which will be used by + workers. This is overridden when the worker_ports option is set. + Default: (54000, 55000) cores_per_worker : float cores to be assigned to each worker. Oversubscription is possible @@ -165,7 +168,8 @@ def __init__(self, For example, singularity exec {container_cmd_options} cold_routing_interval: float - The time interval between warm and cold function routing in SOFT scheduler_mode. + The time interval between warm and cold function routing in SOFT + scheduler_mode. It is ONLY used when using soft scheduler_mode. We need this to avoid container workers being idle for too long. 
But we dont't want this cold routing to occur too often, @@ -178,36 +182,22 @@ def __init__(self, logdir : str Parsl log directory paths. Logs and temp files go here. Default: '.' - logging_level : int - Logging level as defined in the logging module. Default: logging.INFO (20) - endpoint_id : str Identity string that identifies the endpoint to the broker suppress_failure : Bool - When set to True, the interchange will attempt to suppress failures. Default: False + When set to True, the interchange will attempt to suppress failures. + Default: False funcx_service_address: str - Override funcx_service_address used by the FuncXClient. If no address is specified, - the FuncXClient's default funcx_service_address is used. + Override funcx_service_address used by the FuncXClient. If no address is + specified, the FuncXClient's default funcx_service_address is used. Default: None """ self.logdir = logdir os.makedirs(self.logdir, exist_ok=True) - - global logger - logger = set_file_logger(os.path.join(self.logdir, 'interchange.log'), - name="interchange", - level=logging_level, - max_bytes=log_max_bytes, - backup_count=log_backup_count) - - logger.info("logger location {}, logger filesize: {}, logger backup count: {}".format(logger.handlers, - log_max_bytes, - log_backup_count)) - - logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id)) + log.info(f"Initializing Interchange process with Endpoint ID: {endpoint_id}") # self.max_workers_per_node = max_workers_per_node @@ -221,13 +211,10 @@ def __init__(self, self.worker_mode = worker_mode self.cold_routing_interval = cold_routing_interval - self.log_max_bytes = log_max_bytes - self.log_backup_count = log_backup_count self.working_dir = working_dir self.provider = provider self.worker_debug = worker_debug self.scaling_enabled = scaling_enabled - # self.strategy = strategy self.client_address = client_address @@ -241,26 +228,29 @@ def __init__(self, self.last_heartbeat = time.time() self.serializer = FuncXSerializer() - logger.info("Attempting connection to forwarder at {} on ports: {},{},{}".format( - client_address, client_ports[0], client_ports[1], client_ports[2])) + log.info( + "Attempting connection to forwarder at {} on ports: {},{},{}".format( + client_address, client_ports[0], client_ports[1], client_ports[2] + ) + ) self.context = zmq.Context() self.task_incoming = self.context.socket(zmq.DEALER) self.task_incoming.set_hwm(0) self.task_incoming.RCVTIMEO = 10 # in milliseconds - logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0])) - self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0])) + log.info(f"Task incoming on tcp://{client_address}:{client_ports[0]}") + self.task_incoming.connect(f"tcp://{client_address}:{client_ports[0]}") self.results_outgoing = self.context.socket(zmq.DEALER) self.results_outgoing.set_hwm(0) - logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1])) - self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1])) + log.info(f"Results outgoing on tcp://{client_address}:{client_ports[1]}") + self.results_outgoing.connect(f"tcp://{client_address}:{client_ports[1]}") self.command_channel = self.context.socket(zmq.DEALER) self.command_channel.RCVTIMEO = 1000 # in milliseconds # self.command_channel.set_hwm(0) - logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2])) - self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) 
- logger.info("Connected to forwarder") + log.info(f"Command channel on tcp://{client_address}:{client_ports[2]}") + self.command_channel.connect(f"tcp://{client_address}:{client_ports[2]}") + log.info("Connected to forwarder") self.pending_task_queue = {} self.containers = {} @@ -270,9 +260,11 @@ def __init__(self, else: self.fxs = FuncXClient() - logger.info("Interchange address is {}".format(self.interchange_address)) + log.info(f"Interchange address is {self.interchange_address}") self.worker_ports = worker_ports - self.worker_port_range = worker_port_range + self.worker_port_range = ( + worker_port_range if worker_port_range is not None else (54000, 55000) + ) self.task_outgoing = self.context.socket(zmq.ROUTER) self.task_outgoing.set_hwm(0) @@ -284,19 +276,28 @@ def __init__(self, self.worker_task_port = self.worker_ports[0] self.worker_result_port = self.worker_ports[1] - self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port)) - self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port)) + self.task_outgoing.bind(f"tcp://*:{self.worker_task_port}") + self.results_incoming.bind(f"tcp://*:{self.worker_result_port}") else: - self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*', - min_port=worker_port_range[0], - max_port=worker_port_range[1], max_tries=100) - self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*', - min_port=worker_port_range[0], - max_port=worker_port_range[1], max_tries=100) + self.worker_task_port = self.task_outgoing.bind_to_random_port( + "tcp://*", + min_port=worker_port_range[0], + max_port=worker_port_range[1], + max_tries=100, + ) + self.worker_result_port = self.results_incoming.bind_to_random_port( + "tcp://*", + min_port=worker_port_range[0], + max_port=worker_port_range[1], + max_tries=100, + ) - logger.info("Bound to ports {},{} for incoming worker connections".format( - self.worker_task_port, self.worker_result_port)) + log.info( + "Bound to ports {},{} for incoming worker connections".format( + self.worker_task_port, self.worker_result_port + ) + ) self._ready_manager_queue = {} @@ -305,36 +306,38 @@ def __init__(self, self.launch_cmd = launch_cmd self.last_core_hr_counter = 0 if not launch_cmd: - self.launch_cmd = ("funcx-manager {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--block_id={{block_id}} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--worker_mode={worker_mode} " - "--container_cmd_options='{container_cmd_options}' " - "--scheduler_mode={scheduler_mode} " - "--log_max_bytes={log_max_bytes} " - "--log_backup_count={log_backup_count} " - "--worker_type={{worker_type}} ") - - self.current_platform = {'parsl_v': PARSL_VERSION, - 'python_v': "{}.{}.{}".format(sys.version_info.major, - sys.version_info.minor, - sys.version_info.micro), - 'os': platform.system(), - 'hname': platform.node(), - 'dir': os.getcwd()} - - logger.info("Platform info: {}".format(self.current_platform)) + self.launch_cmd = ( + "funcx-manager {debug} {max_workers} " + "-c {cores_per_worker} " + "--poll {poll_period} " + "--task_url={task_url} " + "--result_url={result_url} " + "--logdir={logdir} " + "--block_id={{block_id}} " + "--hb_period={heartbeat_period} " + "--hb_threshold={heartbeat_threshold} " + "--worker_mode={worker_mode} " + "--container_cmd_options='{container_cmd_options}' " + "--scheduler_mode={scheduler_mode} " + "--worker_type={{worker_type}} 
" + ) + + self.current_platform = { + "parsl_v": PARSL_VERSION, + "python_v": "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ), + "os": platform.system(), + "hname": platform.node(), + "dir": os.getcwd(), + } + + log.info(f"Platform info: {self.current_platform}") self._block_counter = 0 try: self.load_config() except Exception: - logger.exception("Caught exception") + log.exception("Caught exception") raise self.tasks = set() @@ -344,54 +347,62 @@ def __init__(self, self.container_switch_count = {} def load_config(self): - """ Load the config - """ - logger.info("Loading endpoint local config") + """Load the config""" + log.info("Loading endpoint local config") working_dir = self.working_dir if self.working_dir is None: working_dir = os.path.join(self.logdir, "worker_logs") - logger.info("Setting working_dir: {}".format(working_dir)) + log.info(f"Setting working_dir: {working_dir}") self.provider.script_dir = working_dir - if hasattr(self.provider, 'channel'): - self.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts') - self.provider.channel.makedirs(self.provider.channel.script_dir, exist_ok=True) + if hasattr(self.provider, "channel"): + self.provider.channel.script_dir = os.path.join( + working_dir, "submit_scripts" + ) + self.provider.channel.makedirs( + self.provider.channel.script_dir, exist_ok=True + ) os.makedirs(self.provider.script_dir, exist_ok=True) debug_opts = "--debug" if self.worker_debug else "" - max_workers = "" if self.max_workers_per_node == float('inf') \ - else "--max_workers={}".format(self.max_workers_per_node) + max_workers = ( + "" + if self.max_workers_per_node == float("inf") + else f"--max_workers={self.max_workers_per_node}" + ) worker_task_url = f"tcp://{self.interchange_address}:{self.worker_task_port}" - worker_result_url = f"tcp://{self.interchange_address}:{self.worker_result_port}" - - l_cmd = self.launch_cmd.format(debug=debug_opts, - max_workers=max_workers, - cores_per_worker=self.cores_per_worker, - # mem_per_worker=self.mem_per_worker, - prefetch_capacity=self.prefetch_capacity, - task_url=worker_task_url, - result_url=worker_result_url, - nodes_per_block=self.provider.nodes_per_block, - heartbeat_period=self.heartbeat_period, - heartbeat_threshold=self.heartbeat_threshold, - poll_period=self.poll_period, - worker_mode=self.worker_mode, - container_cmd_options=self.container_cmd_options, - scheduler_mode=self.scheduler_mode, - logdir=working_dir, - log_max_bytes=self.log_max_bytes, - log_backup_count=self.log_backup_count) + worker_result_url = ( + f"tcp://{self.interchange_address}:{self.worker_result_port}" + ) + + l_cmd = self.launch_cmd.format( + debug=debug_opts, + max_workers=max_workers, + cores_per_worker=self.cores_per_worker, + # mem_per_worker=self.mem_per_worker, + prefetch_capacity=self.prefetch_capacity, + task_url=worker_task_url, + result_url=worker_result_url, + nodes_per_block=self.provider.nodes_per_block, + heartbeat_period=self.heartbeat_period, + heartbeat_threshold=self.heartbeat_threshold, + poll_period=self.poll_period, + worker_mode=self.worker_mode, + container_cmd_options=self.container_cmd_options, + scheduler_mode=self.scheduler_mode, + logdir=working_dir, + ) self.launch_cmd = l_cmd - logger.info("Launch command: {}".format(self.launch_cmd)) + log.info(f"Launch command: {self.launch_cmd}") if self.scaling_enabled: - logger.info("Scaling ...") + log.info("Scaling ...") self.scale_out(self.provider.init_blocks) def get_tasks(self, count): - """ 
Obtains a batch of tasks from the internal pending_task_queue + """Obtains a batch of tasks from the internal pending_task_queue Parameters ---------- @@ -423,7 +434,7 @@ def migrate_tasks_to_internal(self, kill_event, status_request): kill_event : threading.Event Event to let the thread know when it is time to die. """ - logger.info("[TASK_PULL_THREAD] Starting") + log.info("[TASK_PULL_THREAD] Starting") task_counter = 0 poller = zmq.Poller() poller.register(self.task_incoming, zmq.POLLIN) @@ -435,79 +446,104 @@ def migrate_tasks_to_internal(self, kill_event, status_request): self.last_heartbeat = time.time() except zmq.Again: # We just timed out while attempting to receive - logger.debug("[TASK_PULL_THREAD] {} tasks in internal queue".format(self.total_pending_task_count)) + log.debug( + "[TASK_PULL_THREAD] {} tasks in internal queue".format( + self.total_pending_task_count + ) + ) continue try: msg = Message.unpack(raw_msg) - logger.debug("[TASK_PULL_THREAD] received Message/Heartbeat? on task queue") + log.debug( + "[TASK_PULL_THREAD] received Message/Heartbeat? on task queue" + ) except Exception: - logger.exception("Failed to unpack message") + log.exception("Failed to unpack message") pass - if msg == 'STOP': + if msg == "STOP": # TODO: Yadu. This should be replaced by a proper MessageType kill_event.set() break elif isinstance(msg, Heartbeat): - logger.debug("Got heartbeat") + log.debug("Got heartbeat") else: - logger.info("[TASK_PULL_THREAD] Received Task:{}".format(msg.task_id)) + log.info(f"[TASK_PULL_THREAD] Received task:{msg}") local_container = self.get_container(msg.container_id) msg.set_local_container(local_container) if local_container not in self.pending_task_queue: - self.pending_task_queue[local_container] = queue.Queue(maxsize=10 ** 6) + self.pending_task_queue[local_container] = queue.Queue( + maxsize=10 ** 6 + ) # We pass the raw message along - self.pending_task_queue[local_container].put({'task_id': msg.task_id, - 'container_id': msg.container_id, - 'local_container': local_container, - 'raw_buffer': raw_msg}) + self.pending_task_queue[local_container].put( + { + "task_id": msg.task_id, + "container_id": msg.container_id, + "local_container": local_container, + "raw_buffer": raw_msg, + } + ) self.total_pending_task_count += 1 self.task_status_deltas[msg.task_id] = TaskStatusCode.WAITING_FOR_NODES - logger.debug(f"[TASK_PULL_THREAD] Task:{msg.task_id} is now WAITING_FOR_NODES") - logger.debug("[TASK_PULL_THREAD] pending task count: {}".format(self.total_pending_task_count)) + log.debug( + f"[TASK_PULL_THREAD] task {msg.task_id} is now WAITING_FOR_NODES" + ) + log.debug( + "[TASK_PULL_THREAD] pending task count: {}".format( + self.total_pending_task_count + ) + ) task_counter += 1 - logger.debug("[TASK_PULL_THREAD] Fetched Task:{}".format(task_counter)) + log.debug(f"[TASK_PULL_THREAD] Fetched task:{task_counter}") def get_container(self, container_uuid): - """ Get the container image location if it is not known to the interchange""" + """Get the container image location if it is not known to the interchange""" if container_uuid not in self.containers: - if container_uuid == 'RAW' or not container_uuid: - self.containers[container_uuid] = 'RAW' + if container_uuid == "RAW" or not container_uuid: + self.containers[container_uuid] = "RAW" else: try: - container = self.fxs.get_container(container_uuid, self.container_type) + container = self.fxs.get_container( + container_uuid, self.container_type + ) except Exception: - logger.exception("[FETCH_CONTAINER] Unable to 
resolve container location") - self.containers[container_uuid] = 'RAW' + log.exception( + "[FETCH_CONTAINER] Unable to resolve container location" + ) + self.containers[container_uuid] = "RAW" else: - logger.info("[FETCH_CONTAINER] Got container info: {}".format(container)) - self.containers[container_uuid] = container.get('location', 'RAW') + log.info(f"[FETCH_CONTAINER] Got container info: {container}") + self.containers[container_uuid] = container.get("location", "RAW") return self.containers[container_uuid] def get_total_tasks_outstanding(self): - """ Get the outstanding tasks in total - """ + """Get the outstanding tasks in total""" outstanding = {} for task_type in self.pending_task_queue: - outstanding[task_type] = outstanding.get(task_type, 0) + self.pending_task_queue[task_type].qsize() + outstanding[task_type] = ( + outstanding.get(task_type, 0) + + self.pending_task_queue[task_type].qsize() + ) for manager in self._ready_manager_queue: - for task_type in self._ready_manager_queue[manager]['tasks']: - outstanding[task_type] = outstanding.get(task_type, 0) + len(self._ready_manager_queue[manager]['tasks'][task_type]) + for task_type in self._ready_manager_queue[manager]["tasks"]: + outstanding[task_type] = outstanding.get(task_type, 0) + len( + self._ready_manager_queue[manager]["tasks"][task_type] + ) return outstanding def get_total_live_workers(self): - """ Get the total active workers - """ + """Get the total active workers""" active = 0 for manager in self._ready_manager_queue: - if self._ready_manager_queue[manager]['active']: - active += self._ready_manager_queue[manager]['max_worker_count'] + if self._ready_manager_queue[manager]["active"]: + active += self._ready_manager_queue[manager]["max_worker_count"] return active def get_outstanding_breakdown(self): - """ Get outstanding breakdown per manager and in the interchange queues + """Get outstanding breakdown per manager and in the interchange queues Returns ------- @@ -517,16 +553,23 @@ def get_outstanding_breakdown(self): pending_on_interchange = self.total_pending_task_count # Reporting pending on interchange is a deviation from Parsl - reply = [('interchange', pending_on_interchange, True)] + reply = [("interchange", pending_on_interchange, True)] for manager in self._ready_manager_queue: - resp = (manager.decode('utf-8'), - sum([len(tids) for tids in self._ready_manager_queue[manager]['tasks'].values()]), - self._ready_manager_queue[manager]['active']) + resp = ( + manager.decode("utf-8"), + sum( + [ + len(tids) + for tids in self._ready_manager_queue[manager]["tasks"].values() + ] + ), + self._ready_manager_queue[manager]["active"], + ) reply.append(resp) return reply def _hold_block(self, block_id): - """ Sends hold command to all managers which are in a specific block + """Sends hold command to all managers which are in a specific block Parameters ---------- @@ -534,13 +577,15 @@ def _hold_block(self, block_id): Block identifier of the block to be put on hold """ for manager in self._ready_manager_queue: - if self._ready_manager_queue[manager]['active'] and \ - self._ready_manager_queue[manager]['block_id'] == block_id: - logger.debug("[HOLD_BLOCK]: Sending hold to manager: {}".format(manager)) + if ( + self._ready_manager_queue[manager]["active"] + and self._ready_manager_queue[manager]["block_id"] == block_id + ): + log.debug(f"[HOLD_BLOCK]: Sending hold to manager: {manager}") self.hold_manager(manager) def hold_manager(self, manager): - """ Put manager on hold + """Put manager on hold Parameters ---------- @@ 
-548,66 +593,72 @@ def hold_manager(self, manager): Manager id to be put on hold while being killed """ if manager in self._ready_manager_queue: - self._ready_manager_queue[manager]['active'] = False + self._ready_manager_queue[manager]["active"] = False def _status_report_loop(self, kill_event, status_report_queue: queue.Queue): - logger.debug("[STATUS] Status reporting loop starting") + log.debug("[STATUS] Status reporting loop starting") while not kill_event.is_set(): - logger.debug(f"Endpoint id : {self.endpoint_id}, {type(self.endpoint_id)}") + log.debug(f"Endpoint id : {self.endpoint_id}, {type(self.endpoint_id)}") msg = EPStatusReport( - self.endpoint_id, - self.get_status_report(), - self.task_status_deltas + self.endpoint_id, self.get_status_report(), self.task_status_deltas + ) + log.debug( + "[STATUS] Sending status report to executor, and clearing task deltas." ) - logger.debug("[STATUS] Sending status report to executor, and clearing task deltas.") status_report_queue.put(msg.pack()) self.task_status_deltas.clear() time.sleep(self.heartbeat_period) def _command_server(self, kill_event): - """ Command server to run async command to the interchange + """Command server to run async command to the interchange - We want to be able to receive the following not yet implemented/updated commands: + We want to be able to receive the following not yet implemented/updated + commands: - OutstandingCount - ListManagers (get outstanding broken down by manager) - HoldWorker - Shutdown """ - logger.debug("[COMMAND] Command Server Starting") + log.debug("[COMMAND] Command Server Starting") while not kill_event.is_set(): try: buffer = self.command_channel.recv() - logger.debug(f"[COMMAND] Received command request {buffer}") + log.debug(f"[COMMAND] Received command request {buffer}") command = Message.unpack(buffer) if command.type is MessageType.TASK_CANCEL: - logger.info(f"[COMMAND] Received TASK_CANCEL for Task:{command.task_id}") + log.info( + f"[COMMAND] Received TASK_CANCEL for Task:{command.task_id}" + ) self.enqueue_task_cancel(command.task_id) reply = command elif command.type is MessageType.HEARTBEAT_REQ: - logger.info("[COMMAND] Received synchonous HEARTBEAT_REQ from hub") - logger.info(f"[COMMAND] Replying with Heartbeat({self.endpoint_id})") + log.info("[COMMAND] Received synchonous HEARTBEAT_REQ from hub") + log.info(f"[COMMAND] Replying with Heartbeat({self.endpoint_id})") reply = Heartbeat(self.endpoint_id) else: - logger.error(f"Received unsupported message type:{command.type} on command channel") + log.error( + f"Received unsupported message type:{command.type} on " + "command channel" + ) reply = BadCommand(f"Unknown command type: {command.type}") - logger.debug("[COMMAND] Reply: {}".format(reply)) + log.debug(f"[COMMAND] Reply: {reply}") self.command_channel.send(reply.pack()) except zmq.Again: - logger.debug("[COMMAND] is alive") + log.debug("[COMMAND] is alive") continue def enqueue_task_cancel(self, task_id): - """ Cancel a task on the interchange + """Cancel a task on the interchange Here are the task states and responses we issue here - 1. Task is pending in queues -> we add task to a trap to capture while in dispatch - and delegate cancel to the manager the task is assigned to + 1. Task is pending in queues -> we add task to a trap to capture while in + dispatch and delegate cancel to the manager the task is assigned to 2. Task is in a transitionary state between pending in queue and dispatched -> task is added pre-emptively to trap 3. 
Task is pending on a manager -> we delegate cancellation to manager @@ -618,14 +669,17 @@ def enqueue_task_cancel(self, task_id): race-condition. Since the task can't be dispatched before scheduling is complete, either must work. """ - logger.debug(f"Received task_cancel request for Task:{task_id}") + log.debug(f"Received task_cancel request for Task:{task_id}") self.task_cancel_pending_trap[task_id] = task_id for manager in self._ready_manager_queue: - for task_type in self._ready_manager_queue[manager]['tasks']: - for tid in self._ready_manager_queue[manager]['tasks'][task_type]: + for task_type in self._ready_manager_queue[manager]["tasks"]: + for tid in self._ready_manager_queue[manager]["tasks"][task_type]: if tid == task_id: - logger.debug(f"Task:{task_id} is running, moving task_cancel message onto queue") + log.debug( + f"Task:{task_id} is running, " + "moving task_cancel message onto queue" + ) self.task_cancel_running_queue.put((manager, task_id)) self.task_cancel_pending_trap.pop(task_id, None) break @@ -639,14 +693,14 @@ def stop(self): self._command_thread.join() def start(self, poll_period=None): - """ Start the Interchange + """Start the Interchange Parameters: ---------- poll_period : int poll_period in milliseconds """ - logger.info("Incoming ports bound") + log.info("Incoming ports bound") if poll_period is None: poll_period = self.poll_period @@ -656,25 +710,33 @@ def start(self, poll_period=None): self._kill_event = threading.Event() self._status_request = threading.Event() - self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal, - args=(self._kill_event, self._status_request, )) + self._task_puller_thread = threading.Thread( + target=self.migrate_tasks_to_internal, + args=( + self._kill_event, + self._status_request, + ), + ) self._task_puller_thread.start() - self._command_thread = threading.Thread(target=self._command_server, - args=(self._kill_event, )) + self._command_thread = threading.Thread( + target=self._command_server, args=(self._kill_event,) + ) self._command_thread.start() status_report_queue = queue.Queue() - self._status_report_thread = threading.Thread(target=self._status_report_loop, - args=(self._kill_event, status_report_queue)) + self._status_report_thread = threading.Thread( + target=self._status_report_loop, + args=(self._kill_event, status_report_queue), + ) self._status_report_thread.start() try: - logger.info("Starting strategy.") + log.info("Starting strategy.") self.strategy.start(self) except RuntimeError: # This is raised when re-registering an endpoint as strategy already exists - logger.exception("Failed to start strategy.") + log.exception("Failed to start strategy.") poller = zmq.Poller() # poller.register(self.task_incoming, zmq.POLLIN) @@ -688,17 +750,22 @@ def start(self, poll_period=None): interesting_managers = set() # This value records when the last cold routing in soft mode happens - # When the cold routing in soft mode happens, it may cause worker containers to switch - # Cold routing is to reduce the number idle workers of specific task types on the managers - # when there are not enough tasks of those types in the task queues on interchange + # When the cold routing in soft mode happens, it may cause worker containers to + # switch + # Cold routing is to reduce the number idle workers of specific task types on + # the managers when there are not enough tasks of those types in the task queues + # on interchange last_cold_routing_time = time.time() while not self._kill_event.is_set(): self.socks = 
dict(poller.poll(timeout=poll_period)) # Listen for requests for work - if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN: - logger.debug("[MAIN] starting task_outgoing section") + if ( + self.task_outgoing in self.socks + and self.socks[self.task_outgoing] == zmq.POLLIN + ): + log.debug("[MAIN] starting task_outgoing section") message = self.task_outgoing.recv_multipart() manager = message[0] @@ -706,88 +773,111 @@ def start(self, poll_period=None): reg_flag = False try: - msg = json.loads(message[1].decode('utf-8')) + msg = json.loads(message[1].decode("utf-8")) reg_flag = True except Exception: - logger.warning("[MAIN] Got a non-json registration message from manager:{}".format( - manager)) - logger.debug("[MAIN] Message :\n{}\n".format(message)) + log.warning( + "[MAIN] Got a non-json registration message from " + "manager:%s", + manager, + ) + log.debug(f"[MAIN] Message :\n{message}\n") # By default we set up to ignore bad nodes/registration messages. - self._ready_manager_queue[manager] = {'last': time.time(), - 'reg_time': time.time(), - 'free_capacity': {'total_workers': 0}, - 'max_worker_count': 0, - 'active': True, - 'tasks': collections.defaultdict(set), - 'total_tasks': 0} + self._ready_manager_queue[manager] = { + "last": time.time(), + "reg_time": time.time(), + "free_capacity": {"total_workers": 0}, + "max_worker_count": 0, + "active": True, + "tasks": collections.defaultdict(set), + "total_tasks": 0, + } if reg_flag is True: interesting_managers.add(manager) - logger.info("[MAIN] Adding manager: {} to ready queue".format(manager)) + log.info(f"[MAIN] Adding manager: {manager} to ready queue") self._ready_manager_queue[manager].update(msg) - logger.info("[MAIN] Registration info for manager {}: {}".format(manager, msg)) - - if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or - msg['parsl_v'] != self.current_platform['parsl_v']): - logger.warn("[MAIN] Manager {} has incompatible version info with the interchange".format(manager)) - - if self.suppress_failure is False: - logger.debug("Setting kill event") - self._kill_event.set() - e = ManagerLost(manager) - result_package = {'task_id': -1, - 'exception': self.serializer.serialize(e)} - pkl_package = pickle.dumps(result_package) - self.results_outgoing.send(pickle.dumps([pkl_package])) - logger.warning("[MAIN] Sent failure reports, unregistering manager") - else: - logger.debug("[MAIN] Suppressing shutdown due to version incompatibility") - + log.info( + "[MAIN] Registration info for manager {}: {}".format( + manager, msg + ) + ) + + if ( + msg["python_v"].rsplit(".", 1)[0] + != self.current_platform["python_v"].rsplit(".", 1)[0] + or msg["parsl_v"] != self.current_platform["parsl_v"] + ): + log.info( + f"[MAIN] Manager:{manager} version:{msg['python_v']} " + "does not match the interchange" + ) else: # Registration has failed. 
if self.suppress_failure is False: - logger.debug("Setting kill event for bad manager") + log.debug("Setting kill event for bad manager") self._kill_event.set() e = BadRegistration(manager, critical=True) - result_package = {'task_id': -1, - 'exception': self.serializer.serialize(e)} + result_package = { + "task_id": -1, + "exception": self.serializer.serialize(e), + } pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pickle.dumps([pkl_package])) else: - logger.debug("[MAIN] Suppressing bad registration from manager:{}".format( - manager)) + log.debug( + "[MAIN] Suppressing bad registration from manager: %s", + manager, + ) else: - self._ready_manager_queue[manager]['last'] = time.time() - if message[1] == b'HEARTBEAT': - logger.debug("[MAIN] Manager {} sends heartbeat".format(manager)) - self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE]) + self._ready_manager_queue[manager]["last"] = time.time() + if message[1] == b"HEARTBEAT": + log.debug(f"[MAIN] Manager {manager} sends heartbeat") + self.task_outgoing.send_multipart( + [manager, b"", PKL_HEARTBEAT_CODE] + ) else: manager_adv = pickle.loads(message[1]) - logger.debug("[MAIN] Manager {} requested {}".format(manager, manager_adv)) - self._ready_manager_queue[manager]['free_capacity'].update(manager_adv) - self._ready_manager_queue[manager]['free_capacity']['total_workers'] = sum(manager_adv['free'].values()) + log.debug( + "[MAIN] Manager {} requested {}".format( + manager, manager_adv + ) + ) + self._ready_manager_queue[manager]["free_capacity"].update( + manager_adv + ) + self._ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ] = sum(manager_adv["free"].values()) interesting_managers.add(manager) - # If we had received any requests, check if there are tasks that could be passed + # If we had received any requests, check if there are tasks that could be + # passed - logger.debug("[MAIN] Managers count (total/interesting): {}/{}".format( - len(self._ready_manager_queue), - len(interesting_managers))) + log.debug( + "[MAIN] Managers count (total/interesting): {}/{}".format( + len(self._ready_manager_queue), len(interesting_managers) + ) + ) if time.time() - last_cold_routing_time > self.cold_routing_interval: - task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers, - self.pending_task_queue, - self._ready_manager_queue, - scheduler_mode=self.scheduler_mode, - cold_routing=True) + task_dispatch, dispatched_task = naive_interchange_task_dispatch( + interesting_managers, + self.pending_task_queue, + self._ready_manager_queue, + scheduler_mode=self.scheduler_mode, + cold_routing=True, + ) last_cold_routing_time = time.time() else: - task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers, - self.pending_task_queue, - self._ready_manager_queue, - scheduler_mode=self.scheduler_mode, - cold_routing=False) + task_dispatch, dispatched_task = naive_interchange_task_dispatch( + interesting_managers, + self.pending_task_queue, + self._ready_manager_queue, + scheduler_mode=self.scheduler_mode, + cold_routing=False, + ) self.total_pending_task_count -= dispatched_task @@ -796,9 +886,12 @@ def start(self, poll_period=None): while True: try: manager, task_id = self.task_cancel_running_queue.get(block=False) - logger.debug(f"[MAIN] Task:{task_id} on manager:{manager} is now CANCELLED while running") - cancel_message = pickle.dumps(('TASK_CANCEL', task_id)) - self.task_outgoing.send_multipart([manager, b'', cancel_message]) + log.debug( + 
f"[MAIN] Task:{task_id} on manager:{manager} is " + "now CANCELLED while running" + ) + cancel_message = pickle.dumps(("TASK_CANCEL", task_id)) + self.task_outgoing.send_multipart([manager, b"", cancel_message]) except queue.Empty: break @@ -806,109 +899,175 @@ def start(self, poll_period=None): for manager in task_dispatch: tasks = task_dispatch[manager] if tasks: - logger.info("[MAIN] Sending task message {} to manager {}".format(tasks, manager)) - serialised_raw_tasks_buffer = pickle.dumps(tasks) - self.task_outgoing.send_multipart([manager, b'', serialised_raw_tasks_buffer]) + log.info( + "[MAIN] Sending task message {} to manager {}".format( + tasks, manager + ) + ) + serializd_raw_tasks_buffer = pickle.dumps(tasks) + self.task_outgoing.send_multipart( + [manager, b"", serializd_raw_tasks_buffer] + ) for task in tasks: task_id = task["task_id"] - if self.task_cancel_pending_trap and task_id in self.task_cancel_pending_trap: - logger.info(f"[MAIN] Task:{task_id} CANCELLED before launch") - cancel_message = pickle.dumps(('TASK_CANCEL', task_id)) - self.task_outgoing.send_multipart([manager, b'', cancel_message]) + if ( + self.task_cancel_pending_trap + and task_id in self.task_cancel_pending_trap + ): + log.info(f"[MAIN] Task:{task_id} CANCELLED before launch") + cancel_message = pickle.dumps(("TASK_CANCEL", task_id)) + self.task_outgoing.send_multipart( + [manager, b"", cancel_message] + ) self.task_cancel_pending_trap.pop(task_id) else: - logger.debug(f"[MAIN] Task:{task_id} is now WAITING_FOR_LAUNCH") - self.task_status_deltas[task_id] = TaskStatusCode.WAITING_FOR_LAUNCH + log.debug( + f"[MAIN] Task:{task_id} is now WAITING_FOR_LAUNCH" + ) + self.task_status_deltas[ + task_id + ] = TaskStatusCode.WAITING_FOR_LAUNCH # Receive any results and forward to client - if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN: - logger.debug("[MAIN] entering results_incoming section") + if ( + self.results_incoming in self.socks + and self.socks[self.results_incoming] == zmq.POLLIN + ): + log.debug("[MAIN] entering results_incoming section") manager, *b_messages = self.results_incoming.recv_multipart() if manager not in self._ready_manager_queue: - logger.warning("[MAIN] Received a result from a un-registered manager: {}".format(manager)) + log.warning( + "[MAIN] Received a result from a un-registered manager: %s", + manager, + ) else: - # We expect the batch of messages to be (optionally) a task status update message - # followed by 0 or more task results + # We expect the batch of messages to be (optionally) a task status + # update message followed by 0 or more task results try: - logger.debug("[MAIN] Trying to unpack ") + log.debug("[MAIN] Trying to unpack ") manager_report = Message.unpack(b_messages[0]) if manager_report.task_statuses: - logger.info(f"[MAIN] Got manager status report: {manager_report.task_statuses}") + log.info( + "[MAIN] Got manager status report: %s", + manager_report.task_statuses, + ) self.task_status_deltas.update(manager_report.task_statuses) - self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE]) + self.task_outgoing.send_multipart( + [manager, b"", PKL_HEARTBEAT_CODE] + ) b_messages = b_messages[1:] - self._ready_manager_queue[manager]['last'] = time.time() - self.container_switch_count[manager] = manager_report.container_switch_count - logger.info(f"[MAIN] Got container switch count: {self.container_switch_count}") + self._ready_manager_queue[manager]["last"] = time.time() + self.container_switch_count[ + manager 
+ ] = manager_report.container_switch_count + log.info( + "[MAIN] Got container switch count: %s", + self.container_switch_count, + ) except Exception: pass if len(b_messages): - logger.info("[MAIN] Got {} result items in batch".format(len(b_messages))) + log.info( + "[MAIN] Got {} result items in batch".format( + len(b_messages) + ) + ) for b_message in b_messages: r = pickle.loads(b_message) - logger.debug("[MAIN] Received result for Task:{} from {}".format(r['task_id'], - manager)) - task_type = self.containers[r['container_id']] - logger.debug(f"[MAIN] Removing for manager:{manager} from {self._ready_manager_queue}") - if r['task_id'] in self.task_status_deltas: - del self.task_status_deltas[r['task_id']] - self._ready_manager_queue[manager]['tasks'][task_type].remove(r['task_id']) - self._ready_manager_queue[manager]['total_tasks'] -= len(b_messages) + log.debug( + "[MAIN] Received result for task {} from {}".format( + r["task_id"], manager + ) + ) + task_type = self.containers[r["container_id"]] + log.debug( + "[MAIN] Removing for manager: %s from %s", + manager, + self._ready_manager_queue, + ) + if r["task_id"] in self.task_status_deltas: + del self.task_status_deltas[r["task_id"]] + self._ready_manager_queue[manager]["tasks"][task_type].remove( + r["task_id"] + ) + self._ready_manager_queue[manager]["total_tasks"] -= len(b_messages) # TODO: handle this with a Task message or something? - # previously used this; switched to mono-message, self.results_outgoing.send_multipart(b_messages) + # previously used this; switched to mono-message, + # self.results_outgoing.send_multipart(b_messages) self.results_outgoing.send(pickle.dumps(b_messages)) - logger.debug("[MAIN] Current tasks: {}".format(self._ready_manager_queue[manager]['tasks'])) - logger.debug("[MAIN] leaving results_incoming section") - - # Send status reports from this main thread to avoid thread-safety on zmq sockets + log.debug( + "[MAIN] Current tasks: {}".format( + self._ready_manager_queue[manager]["tasks"] + ) + ) + log.debug("[MAIN] leaving results_incoming section") + + # Send status reports from this main thread to avoid thread-safety on zmq + # sockets try: packed_status_report = status_report_queue.get(block=False) - logger.debug(f"[MAIN] forwarding status report: {packed_status_report}") + log.debug(f"[MAIN] forwarding status report: {packed_status_report}") self.results_outgoing.send(packed_status_report) except queue.Empty: pass - # logger.debug("[MAIN] entering bad_managers section") - bad_managers = [manager for manager in self._ready_manager_queue if - time.time() - self._ready_manager_queue[manager]['last'] > self.heartbeat_threshold] + # log.debug("[MAIN] entering bad_managers section") + bad_managers = [ + manager + for manager in self._ready_manager_queue + if time.time() - self._ready_manager_queue[manager]["last"] + > self.heartbeat_threshold + ] bad_manager_msgs = [] for manager in bad_managers: - logger.debug("[MAIN] Last: {} Current: {}".format(self._ready_manager_queue[manager]['last'], time.time())) - logger.warning("[MAIN] Too many heartbeats missed for manager {}".format(manager)) + log.debug( + "[MAIN] Last: {} Current: {}".format( + self._ready_manager_queue[manager]["last"], time.time() + ) + ) + log.warning(f"[MAIN] Too many heartbeats missed for manager {manager}") e = ManagerLost(manager) - for task_type in self._ready_manager_queue[manager]['tasks']: - for tid in self._ready_manager_queue[manager]['tasks'][task_type]: + for task_type in self._ready_manager_queue[manager]["tasks"]: + for 
tid in self._ready_manager_queue[manager]["tasks"][task_type]: try: raise ManagerLost(manager) except Exception: - result_package = {'task_id': tid, 'exception': self.serializer.serialize(RemoteExceptionWrapper(*sys.exc_info()))} + result_package = { + "task_id": tid, + "exception": self.serializer.serialize( + RemoteExceptionWrapper(*sys.exc_info()) + ), + } pkl_package = pickle.dumps(result_package) bad_manager_msgs.append(pkl_package) - logger.warning("[MAIN] Sent failure reports, unregistering manager {}".format(manager)) - self._ready_manager_queue.pop(manager, 'None') + log.warning( + "[MAIN] Sent failure reports, unregistering manager {}".format( + manager + ) + ) + self._ready_manager_queue.pop(manager, "None") if manager in interesting_managers: interesting_managers.remove(manager) if bad_manager_msgs: self.results_outgoing.send(pickle.dumps(bad_manager_msgs)) - logger.debug("[MAIN] ending one main loop iteration") + log.debug("[MAIN] ending one main loop iteration") if self._status_request.is_set(): - logger.info("status request response") + log.info("status request response") result_package = self.get_status_report() pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) - logger.info("[MAIN] Sent info response") + log.info("[MAIN] Sent info response") self._status_request.clear() delta = time.time() - start - logger.info("Processed {} tasks in {} seconds".format(count, delta)) - logger.warning("Exiting") + log.info(f"Processed {count} tasks in {delta} seconds") + log.warning("Exiting") def get_status_report(self): - """ Get utilization numbers - """ + """Get utilization numbers""" total_cores = 0 total_mem = 0 core_hrs = 0 @@ -920,36 +1079,43 @@ def get_status_report(self): live_workers = self.get_total_live_workers() for manager in self._ready_manager_queue: - total_cores += self._ready_manager_queue[manager]['cores'] - total_mem += self._ready_manager_queue[manager]['mem'] - active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time']) + total_cores += self._ready_manager_queue[manager]["cores"] + total_mem += self._ready_manager_queue[manager]["mem"] + active_dur = abs( + time.time() - self._ready_manager_queue[manager]["reg_time"] + ) core_hrs += (active_dur * total_cores) / 3600 - if self._ready_manager_queue[manager]['active']: + if self._ready_manager_queue[manager]["active"]: active_managers += 1 - free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers'] - - result_package = {'task_id': -2, - 'info': {'total_cores': total_cores, - 'total_mem': total_mem, - 'new_core_hrs': core_hrs - self.last_core_hr_counter, - 'total_core_hrs': round(core_hrs, 2), - 'managers': num_managers, - 'active_managers': active_managers, - 'total_workers': live_workers, - 'idle_workers': free_capacity, - 'pending_tasks': pending_tasks, - 'outstanding_tasks': outstanding_tasks, - 'worker_mode': self.worker_mode, - 'scheduler_mode': self.scheduler_mode, - 'scaling_enabled': self.scaling_enabled, - 'mem_per_worker': self.mem_per_worker, - 'cores_per_worker': self.cores_per_worker, - 'prefetch_capacity': self.prefetch_capacity, - 'max_blocks': self.provider.max_blocks, - 'min_blocks': self.provider.min_blocks, - 'max_workers_per_node': self.max_workers_per_node, - 'nodes_per_block': self.provider.nodes_per_block - }} + free_capacity += self._ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ] + + result_package = { + "task_id": -2, + "info": { + "total_cores": total_cores, + "total_mem": total_mem, + 
"new_core_hrs": core_hrs - self.last_core_hr_counter, + "total_core_hrs": round(core_hrs, 2), + "managers": num_managers, + "active_managers": active_managers, + "total_workers": live_workers, + "idle_workers": free_capacity, + "pending_tasks": pending_tasks, + "outstanding_tasks": outstanding_tasks, + "worker_mode": self.worker_mode, + "scheduler_mode": self.scheduler_mode, + "scaling_enabled": self.scaling_enabled, + "mem_per_worker": self.mem_per_worker, + "cores_per_worker": self.cores_per_worker, + "prefetch_capacity": self.prefetch_capacity, + "max_blocks": self.provider.max_blocks, + "min_blocks": self.provider.min_blocks, + "max_workers_per_node": self.max_workers_per_node, + "nodes_per_block": self.provider.nodes_per_block, + }, + } self.last_core_hr_counter = core_hrs return result_package @@ -965,22 +1131,30 @@ def scale_out(self, blocks=1, task_type=None): if self.provider: self._block_counter += 1 external_block_id = str(self._block_counter) - if not task_type and self.scheduler_mode == 'hard': - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW') + if not task_type and self.scheduler_mode == "hard": + launch_cmd = self.launch_cmd.format( + block_id=external_block_id, worker_type="RAW" + ) else: - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type) + launch_cmd = self.launch_cmd.format( + block_id=external_block_id, worker_type=task_type + ) if not task_type: internal_block = self.provider.submit(launch_cmd, 1) else: internal_block = self.provider.submit(launch_cmd, 1, task_type) - logger.debug("Launched block {}->{}".format(external_block_id, internal_block)) + log.debug(f"Launched block {external_block_id}->{internal_block}") if not internal_block: - raise(ScalingFailed(self.config.provider.label, - "Attempts to provision nodes via provider has failed")) + raise ( + ScalingFailed( + self.config.provider.label, + "Attempts to provision nodes via provider has failed", + ) + ) self.blocks[external_block_id] = internal_block self.block_id_map[internal_block] = external_block_id else: - logger.error("No execution provider available") + log.error("No execution provider available") r = None return r @@ -998,14 +1172,22 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): if block_ids is None: block_ids = [] if task_type: - logger.info("Scaling in blocks of specific task type {}. Let the provider decide which to kill".format(task_type)) + log.info( + "Scaling in blocks of specific task type %s. 
Let the provider decide " + "which to kill", + task_type, + ) if self.scaling_enabled and self.provider: to_kill, r = self.provider.cancel(blocks, task_type) - logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r)) + log.info(f"Get the killed blocks: {to_kill}, and status: {r}") for job in to_kill: - logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job)) + log.info( + "[scale_in] Getting the block_id map {} for job {}".format( + self.block_id_map, job + ) + ) block_id = self.block_id_map[job] - logger.info("[scale_in] Holding block {}".format(block_id)) + log.info(f"[scale_in] Holding block {block_id}") self._hold_block(block_id) self.blocks.pop(block_id) return r @@ -1029,13 +1211,16 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): return r def provider_status(self): - """ Get status of all blocks from the provider - """ + """Get status of all blocks from the provider""" status = [] if self.provider: - logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values()))) + log.debug( + "[MAIN] Getting the status of {} blocks.".format( + list(self.blocks.values()) + ) + ) status = self.provider.status(list(self.blocks.values())) - logger.debug("[MAIN] The status is {}".format(status)) + log.debug(f"[MAIN] The status is {status}") return status @@ -1043,59 +1228,85 @@ def provider_status(self): def starter(comm_q, *args, **kwargs): """Start the interchange process - The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__ + The executor is expected to call this function. The args, kwargs match that of the + Interchange.__init__ """ - # logger = multiprocessing.get_logger() ic = Interchange(*args, **kwargs) - comm_q.put((ic.worker_task_port, - ic.worker_result_port)) + comm_q.put((ic.worker_task_port, ic.worker_result_port)) ic.start() def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-c", "--client_address", required=True, - help="Client address") - parser.add_argument("--client_ports", required=True, - help="client ports as a triple of outgoing,incoming,command") - parser.add_argument("--worker_port_range", - help="Worker port range as a tuple") - parser.add_argument("-l", "--logdir", default="./parsl_worker_logs", - help="Parsl worker log directory") - parser.add_argument("-p", "--poll_period", - help="REQUIRED: poll period used for main thread") - parser.add_argument("--worker_ports", default=None, - help="OPTIONAL, pair of workers ports to listen on, eg --worker_ports=50001,50005") - parser.add_argument("--suppress_failure", action='store_true', - help="Enables suppression of failures") - parser.add_argument("--endpoint_id", default=None, - help="Endpoint ID, used to identify the endpoint to the remote broker") - parser.add_argument("--hb_threshold", - help="Heartbeat threshold in seconds") - parser.add_argument("--config", default=None, - help="Configuration object that describes provisioning") - parser.add_argument("-d", "--debug", action='store_true', - help="Enables debug logging") + parser.add_argument("-c", "--client_address", required=True, help="Client address") + parser.add_argument( + "--client_ports", + required=True, + help="client ports as a triple of outgoing,incoming,command", + ) + parser.add_argument("--worker_port_range", help="Worker port range as a tuple") + parser.add_argument( + "-l", + "--logdir", + default="./parsl_worker_logs", + help="Parsl worker log directory", + ) + parser.add_argument( + 
"-p", "--poll_period", help="REQUIRED: poll period used for main thread" + ) + parser.add_argument( + "--worker_ports", + default=None, + help="OPTIONAL, pair of workers ports to listen on, " + "eg --worker_ports=50001,50005", + ) + parser.add_argument( + "--suppress_failure", + action="store_true", + help="Enables suppression of failures", + ) + parser.add_argument( + "--endpoint_id", + default=None, + help="Endpoint ID, used to identify the endpoint to the remote broker", + ) + parser.add_argument("--hb_threshold", help="Heartbeat threshold in seconds") + parser.add_argument( + "--config", + default=None, + help="Configuration object that describes provisioning", + ) + parser.add_argument( + "-d", "--debug", action="store_true", help="Enables debug logging" + ) print("Starting HTEX Intechange") - args = parser.parse_args() - optionals = {} - optionals['suppress_failure'] = args.suppress_failure - optionals['logdir'] = os.path.abspath(args.logdir) - optionals['client_address'] = args.client_address - optionals['client_ports'] = [int(i) for i in args.client_ports.split(',')] - optionals['endpoint_id'] = args.endpoint_id - optionals['config'] = args.config + args = parser.parse_args() - if args.debug: - optionals['logging_level'] = logging.DEBUG + args.logdir = os.path.abspath(args.logdir) if args.worker_ports: - optionals['worker_ports'] = [int(i) for i in args.worker_ports.split(',')] + args.worker_ports = [int(i) for i in args.worker_ports.split(",")] if args.worker_port_range: - optionals['worker_port_range'] = [int(i) for i in args.worker_port_range.split(',')] + args.worker_port_range = [int(i) for i in args.worker_port_range.split(",")] + + os.makedirs(args.logdir, exist_ok=True) + setup_logging( + logfile=os.path.join(args.logdir, "interchange.log"), + debug=args.debug, + console_enabled=False, + ) with daemon.DaemonContext(): - ic = Interchange(**optionals) + ic = Interchange( + logdir=args.logdir, + suppress_failure=args.suppress_failure, + client_address=args.client_address, + client_ports=[int(i) for i in args.client_ports.split(",")], + endpoint_id=args.endpoint_id, + config=args.config, + worker_ports=args.worker_ports, + worker_port_range=args.worker_port_range, + ) ic.start() diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py index b6017f802..fae1b1a4f 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py @@ -1,50 +1,57 @@ -import math -import random -import queue -import logging import collections +import logging +import queue +import random -logger = logging.getLogger("interchange.task_dispatch") -logger.info("Interchange task dispatch started") +log = logging.getLogger(__name__) +log.info("Interchange task dispatch started") -def naive_interchange_task_dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='hard', - cold_routing=False): +def naive_interchange_task_dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="hard", + cold_routing=False, +): """ This is an initial task dispatching algorithm for interchange. - It returns a dictionary, whose key is manager, and the value is the list of tasks to be sent to manager, - and the total number of dispatched tasks. 
+ It returns a dictionary, whose key is manager, and the value is the list of tasks + to be sent to manager, and the total number of dispatched tasks. """ - if scheduler_mode == 'hard': - return dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='hard') + if scheduler_mode == "hard": + return dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="hard", + ) - elif scheduler_mode == 'soft': + elif scheduler_mode == "soft": task_dispatch, dispatched_tasks = {}, 0 - loops = ['warm'] if not cold_routing else ['warm', 'cold'] + loops = ["warm"] if not cold_routing else ["warm", "cold"] for loop in loops: - task_dispatch, dispatched_tasks = dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='soft', - loop=loop, - task_dispatch=task_dispatch, - dispatched_tasks=dispatched_tasks) + task_dispatch, dispatched_tasks = dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="soft", + loop=loop, + task_dispatch=task_dispatch, + dispatched_tasks=dispatched_tasks, + ) return task_dispatch, dispatched_tasks -def dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='hard', - loop='warm', - task_dispatch=None, - dispatched_tasks=0): +def dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="hard", + loop="warm", + task_dispatch=None, + dispatched_tasks=0, +): """ This is the core task dispatching algorithm for interchange. The algorithm depends on the scheduler mode and which loop. @@ -55,99 +62,129 @@ def dispatch(interesting_managers, shuffled_managers = list(interesting_managers) random.shuffle(shuffled_managers) for manager in shuffled_managers: - tasks_inflight = ready_manager_queue[manager]['total_tasks'] - real_capacity = min(ready_manager_queue[manager]['free_capacity']['total_workers'], - ready_manager_queue[manager]['max_worker_count'] - tasks_inflight) - if (real_capacity and ready_manager_queue[manager]['active']): - if scheduler_mode == 'hard': - tasks, tids = get_tasks_hard(pending_task_queue, - ready_manager_queue[manager], - real_capacity) + tasks_inflight = ready_manager_queue[manager]["total_tasks"] + real_capacity = min( + ready_manager_queue[manager]["free_capacity"]["total_workers"], + ready_manager_queue[manager]["max_worker_count"] - tasks_inflight, + ) + if real_capacity and ready_manager_queue[manager]["active"]: + if scheduler_mode == "hard": + tasks, tids = get_tasks_hard( + pending_task_queue, ready_manager_queue[manager], real_capacity + ) else: - tasks, tids = get_tasks_soft(pending_task_queue, - ready_manager_queue[manager], - real_capacity, - loop=loop) - logger.debug("[MAIN] Get tasks {} from queue".format(tasks)) + tasks, tids = get_tasks_soft( + pending_task_queue, + ready_manager_queue[manager], + real_capacity, + loop=loop, + ) + log.debug(f"[MAIN] Get tasks {tasks} from queue") if tasks: for task_type in tids: # This line is a set update, not dict update - ready_manager_queue[manager]['tasks'][task_type].update(tids[task_type]) - logger.debug("[MAIN] The tasks on manager {} is {}".format(manager, ready_manager_queue[manager]['tasks'])) - ready_manager_queue[manager]['total_tasks'] += len(tasks) + ready_manager_queue[manager]["tasks"][task_type].update( + tids[task_type] + ) + log.debug( + "[MAIN] The tasks on manager {} is {}".format( + manager, ready_manager_queue[manager]["tasks"] + ) + ) + 
ready_manager_queue[manager]["total_tasks"] += len(tasks) if manager not in task_dispatch: task_dispatch[manager] = [] task_dispatch[manager] += tasks dispatched_tasks += len(tasks) - logger.debug("[MAIN] Assigned tasks {} to manager {}".format(tids, manager)) - if ready_manager_queue[manager]['free_capacity']['total_workers'] > 0: - logger.debug("[MAIN] Manager {} still has free_capacity {}".format(manager, ready_manager_queue[manager]['free_capacity']['total_workers'])) + log.debug(f"[MAIN] Assigned tasks {tids} to manager {manager}") + if ready_manager_queue[manager]["free_capacity"]["total_workers"] > 0: + log.debug( + "[MAIN] Manager {} still has free_capacity {}".format( + manager, + ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ], + ) + ) else: - logger.debug("[MAIN] Manager {} is now saturated".format(manager)) + log.debug(f"[MAIN] Manager {manager} is now saturated") interesting_managers.remove(manager) else: interesting_managers.remove(manager) - logger.debug("The task dispatch of {} loop is {}, in total {} tasks".format(loop, task_dispatch, dispatched_tasks)) + log.debug( + "The task dispatch of {} loop is {}, in total {} tasks".format( + loop, task_dispatch, dispatched_tasks + ) + ) return task_dispatch, dispatched_tasks def get_tasks_hard(pending_task_queue, manager_ads, real_capacity): tasks = [] tids = collections.defaultdict(set) - task_type = manager_ads['worker_type'] + task_type = manager_ads["worker_type"] if not task_type: - logger.warning("Using hard scheduler mode but with manager worker type unset. Use soft scheduler mode. Set this in the config.") + log.warning( + "Using hard scheduler mode but with manager worker type unset. " + "Use soft scheduler mode. Set this in the config." + ) return tasks, tids if task_type not in pending_task_queue: - logger.debug("No task of type {}. Exiting task fetching.".format(task_type)) + log.debug(f"No task of type {task_type}. 
Exiting task fetching.") return tasks, tids # dispatch tasks of available types on manager - if task_type in manager_ads['free_capacity']['free']: - while manager_ads['free_capacity']['free'][task_type] > 0 and real_capacity > 0: + if task_type in manager_ads["free_capacity"]["free"]: + while manager_ads["free_capacity"]["free"][task_type] > 0 and real_capacity > 0: try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + log.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free'][task_type] -= 1 - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"][task_type] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 # dispatch tasks to unused slots based on the manager type - logger.debug("Second round of task fetching in hard mode") - while manager_ads['free_capacity']['free']["unused"] > 0 and real_capacity > 0: + log.debug("Second round of task fetching in hard mode") + while manager_ads["free_capacity"]["free"]["unused"] > 0 and real_capacity > 0: try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + log.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free']['unused'] -= 1 - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"]["unused"] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 return tasks, tids -def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): +def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop="warm"): tasks = [] tids = collections.defaultdict(set) # Warm routing to dispatch tasks - if loop == 'warm': - for task_type in manager_ads['free_capacity']['free']: + if loop == "warm": + for task_type in manager_ads["free_capacity"]["free"]: # Dispatch tasks that are of the available container types on the manager - if task_type != 'unused': - type_inflight = len(manager_ads['tasks'].get(task_type, set())) - type_capacity = min(manager_ads['free_capacity']['free'][task_type], - manager_ads['free_capacity']['total'][task_type] - type_inflight) - while manager_ads['free_capacity']['free'][task_type] > 0 and real_capacity > 0 and type_capacity > 0: + if task_type != "unused": + type_inflight = len(manager_ads["tasks"].get(task_type, set())) + type_capacity = min( + manager_ads["free_capacity"]["free"][task_type], + manager_ads["free_capacity"]["total"][task_type] - type_inflight, + ) + while ( + manager_ads["free_capacity"]["free"][task_type] > 0 + and real_capacity > 0 + and type_capacity > 0 + ): try: if task_type not in pending_task_queue: break @@ -155,11 +192,11 @@ def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + log.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free'][task_type] -= 1 - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"][task_type] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 type_capacity -= 1 # Dispatch tasks to unused container slots on the manager @@ -167,18 +204,21 @@ def 
get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): task_types = list(pending_task_queue.keys()) random.shuffle(task_types) for task_type in task_types: - while (manager_ads['free_capacity']['free']['unused'] > 0 and - manager_ads['free_capacity']['total_workers'] > 0 and real_capacity > 0): + while ( + manager_ads["free_capacity"]["free"]["unused"] > 0 + and manager_ads["free_capacity"]["total_workers"] > 0 + and real_capacity > 0 + ): try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + log.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free']['unused'] -= 1 - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"]["unused"] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 return tasks, tids @@ -188,19 +228,19 @@ def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): # This is needed to avoid workers being idle for too long # Potential issues may be that it could kill containers of short tasks frequently # Tune cold_routing_interval in the config to balance such a tradeoff - logger.debug("Cold function routing!") + log.debug("Cold function routing!") task_types = list(pending_task_queue.keys()) random.shuffle(task_types) for task_type in task_types: - while manager_ads['free_capacity']['total_workers'] > 0 and real_capacity > 0: + while manager_ads["free_capacity"]["total_workers"] > 0 and real_capacity > 0: try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + log.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 return tasks, tids diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py index 1ce44dd48..357f8e32c 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py @@ -1,5 +1,8 @@ import platform -if platform.system() == 'Darwin': + +if platform.system() == "Darwin": from parsl.executors.high_throughput.mac_safe_queue import MacSafeQueue as mpQueue else: from multiprocessing import Queue as mpQueue + +__all__ = ("mpQueue",) diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py index 785ed0325..1e7087148 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py @@ -3,9 +3,8 @@ from abc import ABC, abstractmethod from enum import Enum, auto from struct import Struct -from typing import Tuple -MESSAGE_TYPE_FORMATTER = Struct('b') +MESSAGE_TYPE_FORMATTER = Struct("b") class MessageType(Enum): @@ -23,8 +22,8 @@ def pack(self): @classmethod def unpack(cls, buffer): - mtype, = MESSAGE_TYPE_FORMATTER.unpack_from(buffer, offset=0) - return MessageType(mtype), buffer[MESSAGE_TYPE_FORMATTER.size:] + (mtype,) = MESSAGE_TYPE_FORMATTER.unpack_from(buffer, offset=0) + return MessageType(mtype), buffer[MESSAGE_TYPE_FORMATTER.size :] class TaskStatusCode(int, 
Enum): @@ -36,10 +35,7 @@ class TaskStatusCode(int, Enum): CANCELLED = auto() -COMMAND_TYPES = { - MessageType.HEARTBEAT_REQ, - MessageType.TASK_CANCEL -} +COMMAND_TYPES = {MessageType.HEARTBEAT_REQ, MessageType.TASK_CANCEL} class Message(ABC): @@ -94,9 +90,12 @@ class Task(Message): """ Task message from the forwarder->interchange """ + type = MessageType.TASK - def __init__(self, task_id: str, container_id: str, task_buffer: str, raw_buffer=None): + def __init__( + self, task_id: str, container_id: str, task_buffer: str, raw_buffer=None + ): super().__init__() self.task_id = task_id self.container_id = container_id @@ -106,15 +105,17 @@ def __init__(self, task_id: str, container_id: str, task_buffer: str, raw_buffer def pack(self) -> bytes: if self.raw_buffer is None: - add_ons = f'TID={self.task_id};CID={self.container_id};{self.task_buffer}' - self.raw_buffer = add_ons.encode('utf-8') + add_ons = f"TID={self.task_id};CID={self.container_id};{self.task_buffer}" + self.raw_buffer = add_ons.encode("utf-8") return self.type.pack() + self.raw_buffer @classmethod def unpack(cls, raw_buffer: bytes): - b_tid, b_cid, task_buf = raw_buffer.decode('utf-8').split(';', 2) - return cls(b_tid[4:], b_cid[4:], task_buf.encode('utf-8'), raw_buffer=raw_buffer) + b_tid, b_cid, task_buf = raw_buffer.decode("utf-8").split(";", 2) + return cls( + b_tid[4:], b_cid[4:], task_buf.encode("utf-8"), raw_buffer=raw_buffer + ) def set_local_container(self, container_id): self.local_container = container_id @@ -122,9 +123,12 @@ def set_local_container(self, container_id): class HeartbeatReq(Message): """ - Synchronous request for a Heartbeat. This is sent from the Forwarder to the endpoint on start to get - an initial connection and ensure liveness. + Synchronous request for a Heartbeat. + + This is sent from the Forwarder to the endpoint on start to get an initial + connection and ensure liveness. """ + type = MessageType.HEARTBEAT_REQ @property @@ -145,8 +149,10 @@ def pack(self): class Heartbeat(Message): """ - Generic Heartbeat message, sent in both directions between Forwarder and Interchange. + Generic Heartbeat message, sent in both directions between Forwarder and + Interchange. """ + type = MessageType.HEARTBEAT def __init__(self, endpoint_id): @@ -163,9 +169,11 @@ def pack(self): class EPStatusReport(Message): """ - Status report for an endpoint, sent from Interchange to Forwarder. Includes EP-wide info such as utilization, - as well as per-task status information. + Status report for an endpoint, sent from Interchange to Forwarder. + + Includes EP-wide info such as utilization, as well as per-task status information. """ + type = MessageType.EP_STATUS_REPORT def __init__(self, endpoint_id, ep_status_report, task_statuses): @@ -191,9 +199,10 @@ def pack(self): class ManagerStatusReport(Message): """ - Status report sent from the Manager to the Interchange, which mostly just amounts to saying which tasks are now - RUNNING. + Status report sent from the Manager to the Interchange, which mostly just amounts + to saying which tasks are now RUNNING. 
""" + type = MessageType.MANAGER_STATUS_REPORT def __init__(self, task_statuses, container_switch_count): @@ -203,7 +212,7 @@ def __init__(self, task_statuses, container_switch_count): @classmethod def unpack(cls, msg): - container_switch_count = int.from_bytes(msg[:10], 'little') + container_switch_count = int.from_bytes(msg[:10], "little") msg = msg[10:] jsonified = msg.decode("ascii") task_statuses = json.loads(jsonified) @@ -212,7 +221,11 @@ def unpack(cls, msg): def pack(self): # TODO: do better than JSON? jsonified = json.dumps(self.task_statuses) - return self.type.pack() + self.container_switch_count.to_bytes(10, 'little') + jsonified.encode("ascii") + return ( + self.type.pack() + + self.container_switch_count.to_bytes(10, "little") + + jsonified.encode("ascii") + ) class ResultsAck(Message): @@ -220,6 +233,7 @@ class ResultsAck(Message): Results acknowledgement to acknowledge a task result was received by the forwarder. Sent from forwarder->interchange """ + type = MessageType.RESULTS_ACK def __init__(self, task_id): @@ -236,8 +250,11 @@ def pack(self): class TaskCancel(Message): """ - Synchronous request for to cancel a Task. This is sent from the Executor to the Interchange + Synchronous request for to cancel a Task. + + This is sent from the Executor to the Interchange """ + type = MessageType.TASK_CANCEL def __init__(self, task_id): @@ -254,8 +271,10 @@ def pack(self): class BadCommand(Message): """ - Error message send to indicate that a command is either unknown, malformed or unsupported. + Error message send to indicate that a command is either + unknown, malformed or unsupported. """ + type = MessageType.BAD_COMMAND def __init__(self, reason: str): diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py index 9f4be156c..399e772d8 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py @@ -1,24 +1,27 @@ -from queue import Queue import logging +import os import random import subprocess import time -import os +from queue import Queue -logger = logging.getLogger("funcx_manager.worker_map") +log = logging.getLogger(__name__) -class WorkerMap(object): - """ WorkerMap keeps track of workers - """ +class WorkerMap: + """WorkerMap keeps track of workers""" def __init__(self, max_worker_count): self.max_worker_count = max_worker_count - self.total_worker_type_counts = {'unused': self.max_worker_count} - self.ready_worker_type_counts = {'unused': self.max_worker_count} + self.total_worker_type_counts = {"unused": self.max_worker_count} + self.ready_worker_type_counts = {"unused": self.max_worker_count} self.pending_worker_type_counts = {} - self.worker_queues = {} # a dict to keep track of all the worker_queues with the key of work_type - self.worker_types = {} # a dict to keep track of all the worker_types with the key of worker_id + self.worker_queues = ( + {} + ) # a dict to keep track of all the worker_queues with the key of work_type + self.worker_types = ( + {} + ) # a dict to keep track of all the worker_types with the key of worker_id self.worker_id_counter = 0 # used to create worker_ids # Only spin up containers if active_workers + pending_workers < max_workers. 
@@ -32,17 +35,22 @@ def __init__(self, max_worker_count): self.worker_idle_since = {} def register_worker(self, worker_id, worker_type): - """ Add a new worker - """ - logger.debug("In register worker worker_id: {} type:{}".format(worker_id, worker_type)) + """Add a new worker""" + log.debug(f"In register worker worker_id: {worker_id} type:{worker_type}") self.worker_types[worker_id] = worker_type if worker_type not in self.worker_queues: self.worker_queues[worker_type] = Queue() - self.total_worker_type_counts[worker_type] = self.total_worker_type_counts.get(worker_type, 0) + 1 - self.ready_worker_type_counts[worker_type] = self.ready_worker_type_counts.get(worker_type, 0) + 1 - self.pending_worker_type_counts[worker_type] = self.pending_worker_type_counts.get(worker_type, 0) - 1 + self.total_worker_type_counts[worker_type] = ( + self.total_worker_type_counts.get(worker_type, 0) + 1 + ) + self.ready_worker_type_counts[worker_type] = ( + self.ready_worker_type_counts.get(worker_type, 0) + 1 + ) + self.pending_worker_type_counts[worker_type] = ( + self.pending_worker_type_counts.get(worker_type, 0) - 1 + ) self.pending_workers -= 1 self.active_workers += 1 self.worker_queues[worker_type].put(worker_id) @@ -52,31 +60,33 @@ def register_worker(self, worker_id, worker_type): self.to_die_count[worker_type] = 0 def start_remove_worker(self, worker_type): - """ Increase the to_die_count in prep for a worker getting removed""" + """Increase the to_die_count in prep for a worker getting removed""" self.to_die_count[worker_type] += 1 def remove_worker(self, worker_id): - """ Remove the worker from the WorkerMap + """Remove the worker from the WorkerMap - Should already be KILLed by this point. + Should already be KILLed by this point. """ worker_type = self.worker_types[worker_id] self.active_workers -= 1 self.total_worker_type_counts[worker_type] -= 1 self.to_die_count[worker_type] -= 1 - self.total_worker_type_counts['unused'] += 1 - self.ready_worker_type_counts['unused'] += 1 - - def spin_up_workers(self, - next_worker_q, - mode='no_container', - container_cmd_options='', - address=None, - debug=None, - uid=None, - logdir=None, - worker_port=None): - """ Helper function to call 'remove' for appropriate workers in 'new_worker_map'. + self.total_worker_type_counts["unused"] += 1 + self.ready_worker_type_counts["unused"] += 1 + + def spin_up_workers( + self, + next_worker_q, + mode="no_container", + container_cmd_options="", + address=None, + debug=None, + uid=None, + logdir=None, + worker_port=None, + ): + """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. 
Parameters ---------- @@ -101,37 +111,58 @@ def spin_up_workers(self, """ spin_ups = {} - logger.debug("[SPIN UP] Next Worker Qsize: {}".format(len(next_worker_q))) - logger.debug("[SPIN UP] Active Workers: {}".format(self.active_workers)) - logger.debug("[SPIN UP] Pending Workers: {}".format(self.pending_workers)) - logger.debug("[SPIN UP] Max Worker Count: {}".format(self.max_worker_count)) - - if len(next_worker_q) > 0 and self.active_workers + self.pending_workers < self.max_worker_count: - logger.debug("[SPIN UP] Spinning up new workers!") - logger.debug(f"[SPIN up] Empty slots: {self.max_worker_count - self.active_workers - self.pending_workers}") - logger.debug(f"[SPIN up] New workers: {len(next_worker_q)}") - logger.debug(f"[SPIN up] Unused slots: {self.total_worker_type_counts['unused']}") - num_slots = min(self.max_worker_count - self.active_workers - self.pending_workers, len(next_worker_q), self.total_worker_type_counts['unused']) + log.debug(f"[SPIN UP] Next Worker Qsize: {len(next_worker_q)}") + log.debug(f"[SPIN UP] Active Workers: {self.active_workers}") + log.debug(f"[SPIN UP] Pending Workers: {self.pending_workers}") + log.debug(f"[SPIN UP] Max Worker Count: {self.max_worker_count}") + + if ( + len(next_worker_q) > 0 + and self.active_workers + self.pending_workers < self.max_worker_count + ): + log.debug("[SPIN UP] Spinning up new workers!") + log.debug( + "[SPIN up] Empty slots: %s", + self.max_worker_count - self.active_workers - self.pending_workers, + ) + log.debug(f"[SPIN up] New workers: {len(next_worker_q)}") + log.debug( + f"[SPIN up] Unused slots: {self.total_worker_type_counts['unused']}" + ) + num_slots = min( + self.max_worker_count - self.active_workers - self.pending_workers, + len(next_worker_q), + self.total_worker_type_counts["unused"], + ) for _ in range(num_slots): try: - proc = self.add_worker(worker_id=str(self.worker_id_counter), - worker_type=next_worker_q.pop(0), - container_cmd_options=container_cmd_options, - mode=mode, - address=address, debug=debug, - uid=uid, - logdir=logdir, - worker_port=worker_port) + proc = self.add_worker( + worker_id=str(self.worker_id_counter), + worker_type=next_worker_q.pop(0), + container_cmd_options=container_cmd_options, + mode=mode, + address=address, + debug=debug, + uid=uid, + logdir=logdir, + worker_port=worker_port, + ) except Exception: - logger.exception("Error spinning up worker! Skipping...") + log.exception("Error spinning up worker! Skipping...") continue else: spin_ups.update(proc) return spin_ups - def spin_down_workers(self, new_worker_map, worker_max_idletime=60, need_more=False, scheduler_mode='hard'): - """ Helper function to call 'remove' for appropriate workers in 'new_worker_map'. + def spin_down_workers( + self, + new_worker_map, + worker_max_idletime=60, + need_more=False, + scheduler_mode="hard", + ): + """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. Parameters ---------- @@ -144,12 +175,28 @@ def spin_down_workers(self, new_worker_map, worker_max_idletime=60, need_more=Fa List of removed worker types. 
""" if need_more: - return self._spin_down(new_worker_map, worker_max_idletime=worker_max_idletime, scheduler_mode=scheduler_mode, check_idle=False) + return self._spin_down( + new_worker_map, + worker_max_idletime=worker_max_idletime, + scheduler_mode=scheduler_mode, + check_idle=False, + ) else: - return self._spin_down(new_worker_map, worker_max_idletime=worker_max_idletime, scheduler_mode=scheduler_mode, check_idle=True) - - def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='hard', check_idle=True): - """ Helper function to call 'remove' for appropriate workers in 'new_worker_map'. + return self._spin_down( + new_worker_map, + worker_max_idletime=worker_max_idletime, + scheduler_mode=scheduler_mode, + check_idle=True, + ) + + def _spin_down( + self, + new_worker_map, + worker_max_idletime=60, + scheduler_mode="hard", + check_idle=True, + ): + """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. Parameters ---------- @@ -157,10 +204,12 @@ def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='har {worker_type: total_number_of_containers,...}. check_idle : boolean A boolean to indicate whether to check the idle time of containers or not - If checked, that means the workloads are not so busy, - and we can leave the container workers alive until the worker_max_idletime is reached. - Otherwise, that means the workloads are busy and we need to turn of some containers to acommodate - the workers, regardless of if it reaches the worker_max_idletime. + + If checked, that means the workloads are not so busy, and we can leave the + container workers alive until the worker_max_idletime is reached. Otherwise, + that means the workloads are busy and we need to turn of some containers to + acommodate the workers, regardless of if it reaches the worker_max_idletime. + Returns --------- List of removed worker types. 
@@ -168,21 +217,41 @@ def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='har spin_downs = [] container_switch_count = 0 for worker_type in self.total_worker_type_counts: - if worker_type == 'unused': + if worker_type == "unused": continue - if check_idle and time.time() - self.worker_idle_since[worker_type] < worker_max_idletime: - logger.debug(f"[SPIN DOWN] Current time: {time.time()}") - logger.debug(f"[SPIN DOWN] Idle since: {self.worker_idle_since[worker_type]}") - logger.debug(f"[SPIN DOWN] Worker type {worker_type} has not exceeded maximum idle time {worker_max_idletime}, continuing") + if ( + check_idle + and time.time() - self.worker_idle_since[worker_type] + < worker_max_idletime + ): + log.debug(f"[SPIN DOWN] Current time: {time.time()}") + log.debug( + f"[SPIN DOWN] Idle since: {self.worker_idle_since[worker_type]}" + ) + log.debug( + "[SPIN DOWN] Worker type %s has not exceeded maximum idle " + "time %s, continuing", + worker_type, + worker_max_idletime, + ) continue - num_remove = max(0, self.total_worker_type_counts[worker_type] - self.to_die_count.get(worker_type, 0) - new_worker_map.get(worker_type, 0)) - if scheduler_mode == 'hard': + num_remove = max( + 0, + self.total_worker_type_counts[worker_type] + - self.to_die_count.get(worker_type, 0) + - new_worker_map.get(worker_type, 0), + ) + if scheduler_mode == "hard": # Leave at least one worker alive in hard mode max_remove = max(0, self.total_worker_type_counts[worker_type] - 1) num_remove = min(num_remove, max_remove) if num_remove > 0: - logger.debug("[SPIN DOWN] Removing {} workers of type {}".format(num_remove, worker_type)) + log.debug( + "[SPIN DOWN] Removing {} workers of type {}".format( + num_remove, worker_type + ) + ) for _i in range(num_remove): spin_downs.append(worker_type) # A container switching is defined as a warm container must be @@ -196,17 +265,17 @@ def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='har def add_worker( self, worker_id=None, - mode='no_container', - worker_type='RAW', + mode="no_container", + worker_type="RAW", container_cmd_options="", walltime=1, address=None, debug=None, worker_port=None, logdir=None, - uid=None + uid=None, ): - """ Launch the appropriate worker + """Launch the appropriate worker Parameters ---------- @@ -221,64 +290,76 @@ def add_worker( if worker_id is None: str(random.random()) - debug = ' --debug' if debug else '' + debug = " --debug" if debug else "" - worker_id = ' --worker_id {}'.format(worker_id) + worker_id = f" --worker_id {worker_id}" self.worker_id_counter += 1 - cmd = (f'funcx-worker {debug}{worker_id} ' - f'-a {address} ' - f'-p {worker_port} ' - f'-t {worker_type} ' - f'--logdir={os.path.join(logdir, uid)} ') + cmd = ( + f"funcx-worker {debug}{worker_id} " + f"-a {address} " + f"-p {worker_port} " + f"-t {worker_type} " + f"--logdir={os.path.join(logdir, uid)} " + ) container_uri = None - if worker_type != 'RAW': + if worker_type != "RAW": container_uri = worker_type - logger.info("Command string :\n {}".format(cmd)) - logger.info("Mode: {}".format(mode)) - logger.info("Container uri: {}".format(container_uri)) - logger.info("Container cmd options: {}".format(container_cmd_options)) - logger.info("Worker type: {}".format(worker_type)) + log.info(f"Command string :\n {cmd}") + log.info(f"Mode: {mode}") + log.info(f"Container uri: {container_uri}") + log.info(f"Container cmd options: {container_cmd_options}") + log.info(f"Worker type: {worker_type}") - if mode == 'no_container': + if mode == 
"no_container": modded_cmd = cmd - elif mode == 'singularity_reuse': + elif mode == "singularity_reuse": if container_uri is None: - logger.warning("No container is specified for singularity mode. " - "Spawning a worker in a raw process instead.") + log.warning( + "No container is specified for singularity mode. " + "Spawning a worker in a raw process instead." + ) modded_cmd = cmd elif not os.path.exists(container_uri): - logger.warning(f"Container uri {container_uri} is not found. " - "Spawning a worker in a raw process instead.") + log.warning( + f"Container uri {container_uri} is not found. " + "Spawning a worker in a raw process instead." + ) modded_cmd = cmd else: - modded_cmd = f'singularity exec {container_cmd_options} {container_uri} {cmd}' - logger.info("Command string with singularity:\n {}".format(modded_cmd)) + modded_cmd = ( + f"singularity exec {container_cmd_options} {container_uri} {cmd}" + ) + log.info(f"Command string with singularity:\n {modded_cmd}") else: raise NameError("Invalid container launch mode.") try: - proc = subprocess.Popen(modded_cmd.split(), - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - shell=False) + proc = subprocess.Popen( + modded_cmd.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + shell=False, + ) except Exception: - logger.exception("Got an error in worker launch") + log.exception("Got an error in worker launch") raise - self.total_worker_type_counts['unused'] -= 1 - self.ready_worker_type_counts['unused'] -= 1 - self.pending_worker_type_counts[worker_type] = self.pending_worker_type_counts.get(worker_type, 0) + 1 + self.total_worker_type_counts["unused"] -= 1 + self.ready_worker_type_counts["unused"] -= 1 + self.pending_worker_type_counts[worker_type] = ( + self.pending_worker_type_counts.get(worker_type, 0) + 1 + ) self.pending_workers += 1 return {str(self.worker_id_counter - 1): proc} def get_next_worker_q(self, new_worker_map): - """ Helper function to generate a queue of next workers to spin up . + """Helper function to generate a queue of next workers to spin up . 
From a mapping generated by the scheduler Parameters @@ -293,20 +374,30 @@ def get_next_worker_q(self, new_worker_map): # next_worker_q = [] new_worker_list = [] - logger.debug(f"[GET_NEXT_WORKER] total_worker_type_counts: {self.total_worker_type_counts}") - logger.debug(f"[GET_NEXT_WORKER] pending_worker_type_counts: {self.pending_worker_type_counts}") + log.debug( + "[GET_NEXT_WORKER] total_worker_type_counts: %s", + self.total_worker_type_counts, + ) + log.debug( + "[GET_NEXT_WORKER] pending_worker_type_counts: %s", + self.pending_worker_type_counts, + ) for worker_type in new_worker_map: - cur_workers = self.total_worker_type_counts.get(worker_type, 0) + self.pending_worker_type_counts.get(worker_type, 0) + cur_workers = self.total_worker_type_counts.get( + worker_type, 0 + ) + self.pending_worker_type_counts.get(worker_type, 0) if new_worker_map[worker_type] > cur_workers: for _i in range(new_worker_map[worker_type] - cur_workers): # Add worker new_worker_list.append(worker_type) - # need_more is to reflect if a manager needs more workers than the current unused slots - # If yes, that means the manager needs to turn off some warm workers to serve the requests + # need_more is to reflect if a manager needs more workers than the current + # unused slots + # If yes, that means the manager needs to turn off some warm workers to serve + # the requests need_more = False - if len(new_worker_list) > self.total_worker_type_counts['unused']: + if len(new_worker_list) > self.total_worker_type_counts["unused"]: need_more = True # Randomly assign order of newly needed containers... add to spin-up queue. if len(new_worker_list) > 0: @@ -315,14 +406,12 @@ def get_next_worker_q(self, new_worker_map): return new_worker_list, need_more def update_worker_idle(self, worker_type): - """ Update the workers' last idle time by worker type - """ - logger.debug(f"[UPDATE_WORKER_IDLE] Worker idle since: {self.worker_idle_since}") + """Update the workers' last idle time by worker type""" + log.debug(f"[UPDATE_WORKER_IDLE] Worker idle since: {self.worker_idle_since}") self.worker_idle_since[worker_type] = time.time() def put_worker(self, worker): - """ Adds worker to the list of waiting workers - """ + """Adds worker to the list of waiting workers""" worker_type = self.worker_types[worker] if worker_type not in self.worker_queues: @@ -332,7 +421,7 @@ def put_worker(self, worker): self.worker_queues[worker_type].put(worker) def get_worker(self, worker_type): - """ Get a task and reduce the # of worker for that type by 1. + """Get a task and reduce the # of worker for that type by 1. Raises queue.Empty if empty """ worker = self.worker_queues[worker_type].get_nowait() @@ -340,28 +429,36 @@ def get_worker(self, worker_type): return worker def get_worker_counts(self): - """ Returns just the dict of worker_type and counts - """ + """Returns just the dict of worker_type and counts""" return self.total_worker_type_counts def ready_worker_count(self): return sum(self.ready_worker_type_counts.values()) def advertisement(self): - """ Manager capacity advertisement to interchange - The advertisement includes two parts. One is the read_worker_type_counts, - which reflects the capacity of different types of containers on the manager. - The other is the total number of workers of each type. - This include all the pending workers and to_die workers when advertising. - We need this "total" advertisement because we use killer tasks mechanisms to kill a worker. 
- When a manager is advertising, there may be some killer tssks in queue, - we want to ensure that the manager does not over-advertise its actualy capacity, - and let interchange decide if it is sending too many tasks to the manager. """ - ads = {'total': {}, 'free': {}} + Manager capacity advertisement to interchange. + + The advertisement includes two parts: + + One is the read_worker_type_counts, which reflects the capacity of different + types of containers on the manager. + + The other is the total number of workers of each type. This includes all the + pending workers and to_die workers when advertising. We need this "total" + advertisement because we use killer task mechanisms to kill a worker. When a + manager is advertising, there may be some killer tasks in queue, and we want to + ensure that the manager does not over-advertise its actual capacity. Instead, + let the interchange decide if it is sending too many tasks to the manager. + """ + ads = {"total": {}, "free": {}} total = dict(self.total_worker_type_counts) for worker_type in self.pending_worker_type_counts: - total[worker_type] = total.get(worker_type, 0) + self.pending_worker_type_counts[worker_type] - self.to_die_count.get(worker_type, 0) - ads['total'].update(total) - ads['free'].update(self.ready_worker_type_counts) + total[worker_type] = ( + total.get(worker_type, 0) + + self.pending_worker_type_counts[worker_type] + - self.to_die_count.get(worker_type, 0) + ) + ads["total"].update(total) + ads["free"].update(self.ready_worker_type_counts) return ads diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py index 4409ae694..ebcb32712 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 -import zmq -import time -import pickle import logging +import pickle +import time + +import zmq -from funcx import set_file_logger from funcx_endpoint.executors.high_throughput.messages import Message -logger = logging.getLogger(__name__) +log = logging.getLogger(__name__) -class CommandClient(object): - """ CommandClient - """ +class CommandClient: + """CommandClient""" def __init__(self, ip_address, port_range): """ @@ -29,12 +28,14 @@ def __init__(self, ip_address, port_range): self.context = zmq.Context() self.zmq_socket = self.context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) - self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address), - min_port=port_range[0], - max_port=port_range[1]) + self.port = self.zmq_socket.bind_to_random_port( + f"tcp://{ip_address}", + min_port=port_range[0], + max_port=port_range[1], + ) def run(self, message): - """ This function needs to be fast at the same time aware of the possibility of + """This function needs to be fast at the same time aware of the possibility of ZMQ pipes overflowing. The timeout increases slowly if contention is detected on ZMQ pipes. 
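# A minimal, self-contained sketch (assumed addresses and ports, not funcX code)
# of the ZMQ pattern shared by the pipe classes in this file: a DEALER socket
# with no high-water mark, bound to a random port within a range, and checked
# for writability before sending, as TasksOutgoing.put does with a growing
# timeout.
import pickle
import zmq

context = zmq.Context()
sock = context.socket(zmq.DEALER)
sock.set_hwm(0)  # do not drop messages at a high-water mark
port = sock.bind_to_random_port("tcp://127.0.0.1", min_port=55000, max_port=55100)

poller = zmq.Poller()
poller.register(sock, zmq.POLLOUT)

timeout_ms = 1  # TasksOutgoing.put starts small and backs off
if sock in dict(poller.poll(timeout_ms)):
    sock.send(pickle.dumps({"task_id": "example-task"}))  # hypothetical payload
else:
    timeout_ms += 1  # pipe is full or no peer yet; retry with a larger timeout

sock.close()
context.term()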
@@ -51,9 +52,8 @@ def close(self): self.context.term() -class TasksOutgoing(object): - """ Outgoing task queue from the executor to the Interchange - """ +class TasksOutgoing: + """Outgoing task queue from the executor to the Interchange""" def __init__(self, ip_address, port_range): """ @@ -69,14 +69,16 @@ def __init__(self, ip_address, port_range): self.context = zmq.Context() self.zmq_socket = self.context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) - self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address), - min_port=port_range[0], - max_port=port_range[1]) + self.port = self.zmq_socket.bind_to_random_port( + f"tcp://{ip_address}", + min_port=port_range[0], + max_port=port_range[1], + ) self.poller = zmq.Poller() self.poller.register(self.zmq_socket, zmq.POLLOUT) def put(self, message, max_timeout=1000): - """ This function needs to be fast at the same time aware of the possibility of + """This function needs to be fast at the same time aware of the possibility of ZMQ pipes overflowing. The timeout increases slowly if contention is detected on ZMQ pipes. @@ -90,7 +92,8 @@ def put(self, message, max_timeout=1000): message : py object Python object to send max_timeout : int - Max timeout in milliseconds that we will wait for before raising an exception + Max timeout in milliseconds that we will wait for before raising an + exception Raises ------ @@ -108,11 +111,15 @@ def put(self, message, max_timeout=1000): return else: timeout_ms += 1 - logger.debug("Not sending due to full zmq pipe, timeout: {} ms".format(timeout_ms)) + log.debug( + "Not sending due to full zmq pipe, timeout: {} ms".format( + timeout_ms + ) + ) current_wait += timeout_ms # Send has failed. - logger.debug("Remote side has been unresponsive for {}".format(current_wait)) + log.debug(f"Remote side has been unresponsive for {current_wait}") raise zmq.error.Again def close(self): @@ -120,9 +127,8 @@ def close(self): self.context.term() -class ResultsIncoming(object): - """ Incoming results queue from the Interchange to the executor - """ +class ResultsIncoming: + """Incoming results queue from the Interchange to the executor""" def __init__(self, ip_address, port_range): """ @@ -138,9 +144,11 @@ def __init__(self, ip_address, port_range): self.context = zmq.Context() self.results_receiver = self.context.socket(zmq.DEALER) self.results_receiver.set_hwm(0) - self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address), - min_port=port_range[0], - max_port=port_range[1]) + self.port = self.results_receiver.bind_to_random_port( + f"tcp://{ip_address}", + min_port=port_range[0], + max_port=port_range[1], + ) def get(self, block=True, timeout=None): block_messages = self.results_receiver.recv() @@ -150,7 +158,10 @@ def get(self, block=True, timeout=None): try: res = Message.unpack(block_messages) except Exception: - logger.exception(f"Message in results queue is not pickle/Message formatted:{block_messages}") + log.exception( + "Message in results queue is not pickle/Message formatted: %s", + block_messages, + ) return res def request_close(self): diff --git a/funcx_endpoint/funcx_endpoint/logging_config.py b/funcx_endpoint/funcx_endpoint/logging_config.py new file mode 100644 index 000000000..d4e660689 --- /dev/null +++ b/funcx_endpoint/funcx_endpoint/logging_config.py @@ -0,0 +1,71 @@ +""" +This module contains logging configuration for the funcx-endpoint application. 
+""" + +import logging +import logging.config +import logging.handlers +import pathlib +import typing as t + +log = logging.getLogger(__name__) + +_DEFAULT_LOGFILE = str(pathlib.Path.home() / ".funcx" / "endpoint.log") + + +def setup_logging( + *, + logfile: t.Optional[str] = None, + console_enabled: bool = True, + debug: bool = False +) -> None: + if logfile is None: + logfile = _DEFAULT_LOGFILE + + default_config = { + "version": 1, + "formatters": { + "streamfmt": { + "format": "%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + "filefmt": { + "format": ( + "%(asctime)s.%(msecs)03d " + "%(name)s:%(lineno)d [%(levelname)s] %(message)s" + ), + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "streamfmt", + }, + "logfile": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "filename": logfile, + "formatter": "filefmt", + "maxBytes": 100 * 1024 * 1024, + "backupCount": 1, + }, + }, + "loggers": { + "funcx_endpoint": { + "level": "DEBUG" if debug else "INFO", + "handlers": ["console", "logfile"] if console_enabled else ["logfile"], + }, + # configure for the funcx SDK as well + "funcx": { + "level": "DEBUG" if debug else "WARNING", + "handlers": ["logfile", "console"] if console_enabled else ["logfile"], + }, + }, + } + + logging.config.dictConfig(default_config) + + if debug: + log.debug("debug logging enabled") diff --git a/funcx_endpoint/funcx_endpoint/mock_broker/README.rst b/funcx_endpoint/funcx_endpoint/mock_broker/README.rst deleted file mode 100644 index 9bc236d39..000000000 --- a/funcx_endpoint/funcx_endpoint/mock_broker/README.rst +++ /dev/null @@ -1,53 +0,0 @@ -Notes -===== - - -We want the mock_broker to be hosting a REST service. This service will have the following routes: - -/register ---------- - -This route expects a POST with a json payload that identifies the endpoint info and responds with a -json response. 
- -For eg: - -POST payload:: - - { - 'python_v': '3.6', - 'os': 'Linux', - 'hname': 'borgmachine2', - 'username': 'yadu', - 'funcx_v': '0.0.1' - } - - -Response payload:: - - { - 'endpoint_id': endpoint_id, - 'task_url': 'tcp://55.77.66.22:50001', - 'result_url': 'tcp://55.77.66.22:50002', - 'command_port': 'tcp://55.77.66.22:50003' - } - - - - -Architecture and Notes ----------------------- - -The endpoint registers and receives the information - -``` - TaskQ ResultQ - | | -REST /register--> Forwarder----->Executor Client - ^ | ^ - | | | - | v | - | +-------------> Interchange -User ----> Endpoint ----| - +--> Provider -``` diff --git a/funcx_endpoint/funcx_endpoint/mock_broker/__init__.py b/funcx_endpoint/funcx_endpoint/mock_broker/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funcx_endpoint/funcx_endpoint/mock_broker/forwarder.py b/funcx_endpoint/funcx_endpoint/mock_broker/forwarder.py deleted file mode 100644 index 3e0bbd430..000000000 --- a/funcx_endpoint/funcx_endpoint/mock_broker/forwarder.py +++ /dev/null @@ -1,191 +0,0 @@ -import logging -from functools import partial -import uuid -import os -import queue -from multiprocessing import Queue - -from multiprocessing import Process -from funcx import set_file_logger - - -def double(x): - return x * 2 - - -def failer(x): - return x / 0 - - -class Forwarder(Process): - """ Forwards tasks/results between the executor and the queues - - Tasks_Q Results_Q - | ^ - | | - V | - Executors - - Todo : We need to clarify what constitutes a task that comes down - the task pipe. Does it already have the code fragment? Or does that need to be sorted - out from some DB ? - """ - - def __init__(self, task_q, result_q, executor, endpoint_id, - logdir="forwarder", logging_level=logging.INFO): - """ - Params: - task_q : A queue object - Any queue object that has get primitives. This must be a thread-safe queue. - - result_q : A queue object - Any queue object that has put primitives. This must be a thread-safe queue. - - executor: Executor object - Executor to which tasks are to be forwarded - - endpoint_id: str - Usually a uuid4 as string that identifies the executor - - logdir: str - Path to logdir - - logging_level : int - Logging level as defined in the logging module. Default: logging.INFO (20) - - """ - super().__init__() - self.logdir = logdir - os.makedirs(self.logdir, exist_ok=True) - - global logger - logger = set_file_logger(os.path.join(self.logdir, "forwarder.{}.log".format(endpoint_id)), - level=logging_level) - - logger.info("Initializing forwarder for endpoint:{}".format(endpoint_id)) - self.task_q = task_q - self.result_q = result_q - self.executor = executor - self.endpoint_id = endpoint_id - self.internal_q = Queue() - self.client_ports = None - - def handle_app_update(self, task_id, future): - """ Triggered when the executor sees a task complete. - - This can be further optimized at the executor level, where we trigger this - or a similar function when we see a results item inbound from the interchange. - """ - logger.debug("[RESULTS] Updating result") - try: - res = future.result() - self.result_q.put(task_id, res) - except Exception: - logger.debug("Task:{} failed".format(task_id)) - # Todo : Since we caught an exception, we should wrap it here, and send it - # back onto the results queue. - else: - logger.debug("Task:{} succeeded".format(task_id)) - - def run(self): - """ Process entry point. 
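# Illustrative sketch (not part of the changeset): the removed Forwarder wires
# task completion through a future callback. The task id is bound into the
# handler with functools.partial, and the handler inspects the future's result
# or exception once the executor finishes. A self-contained version of that
# pattern, using the stdlib executor and an invented task id:
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial

log = logging.getLogger(__name__)


def handle_app_update(task_id, future):
    try:
        result = future.result()
    except Exception:
        log.exception("Task %s failed", task_id)
    else:
        log.debug("Task %s succeeded with %r", task_id, result)


with ThreadPoolExecutor(max_workers=1) as pool:
    fut = pool.submit(lambda x: x * 2, 5)
    fut.add_done_callback(partial(handle_app_update, "task-0001"))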
- """ - logger.info("[TASKS] Loop starting") - logger.info("[TASKS] Executor: {}".format(self.executor)) - - try: - self.task_q.connect() - self.result_q.connect() - except Exception: - logger.exception("Connecting to the queues have failed") - - self.executor.start() - conn_info = self.executor.connection_info - self.internal_q.put(conn_info) - logger.info("[TASKS] Endpoint connection info: {}".format(conn_info)) - - while True: - try: - task = self.task_q.get(timeout=10) - logger.debug("[TASKS] Not doing {}".format(task)) - except queue.Empty: - # This exception catching isn't very general, - # Essentially any timeout exception should be caught and ignored - logger.debug("[TASKS] Waiting for tasks") - pass - else: - # TODO: We are piping down a mock task. This needs to be fixed. - task_id = str(uuid.uuid4()) - args = [5] - kwargs = {} - fu = self.executor.submit(double, *args, **kwargs) - fu.add_done_callback(partial(self.handle_app_update, task_id)) - - logger.info("[TASKS] Terminating self due to user requested kill") - return - - @property - def connection_info(self): - """Get the client ports to which the interchange must connect to - """ - - if not self.client_ports: - self.client_ports = self.internal_q.get() - - return self.client_ports - - -def spawn_forwarder(address, - executor=None, - task_q=None, - result_q=None, - endpoint_id=uuid.uuid4(), - logging_level=logging.INFO): - """ Spawns a forwarder and returns the forwarder process for tracking. - - Parameters - ---------- - - address : str - IP Address to which the endpoint must connect - - executor : Executor object. Optional - Executor object to be instantiated. - - task_q : Queue object - Queue object matching funcx.queues.base.FuncxQueue interface - - logging_level : int - Logging level as defined in the logging module. Default: logging.INFO (20) - - endpoint_id : uuid string - Endpoint id for which the forwarder is being spawned. - - Returns: - A Forwarder object - """ - from funcx_endpoint.queues import RedisQueue - from funcx_endpoint.executors import HighThroughputExecutor as HTEX - from parsl.providers import LocalProvider - from parsl.channels import LocalChannel - - task_q = RedisQueue('task', '127.0.0.1') - result_q = RedisQueue('result', '127.0.0.1') - - if not executor: - executor = HTEX(label='htex', - provider=LocalProvider( - channel=LocalChannel), - address=address) - - fw = Forwarder(task_q, result_q, executor, - "Endpoint_{}".format(endpoint_id), - logging_level=logging_level) - fw.start() - return fw - - -if __name__ == "__main__": - - pass - # test() diff --git a/funcx_endpoint/funcx_endpoint/mock_broker/mock_broker.py b/funcx_endpoint/funcx_endpoint/mock_broker/mock_broker.py deleted file mode 100644 index 4a1a1ef8f..000000000 --- a/funcx_endpoint/funcx_endpoint/mock_broker/mock_broker.py +++ /dev/null @@ -1,75 +0,0 @@ -""" The broker service - -This REST service fields incoming registration requests from endpoints, -creates an appropriate forwarder to which the endpoint can connect up. -""" - - -import bottle -from bottle import post, run, request, route -import argparse -import json -import uuid -import sys - -from funcx_endpoint.mock_broker.forwarder import Forwarder, spawn_forwarder - - -@post('/register') -def register(): - """ Register an endpoint request - - 1. Start an executor client object corresponding to the endpoint - 2. Pass connection info back as a json response. 
- """ - - print("Request: ", request) - print("foo: ", request.app.ep_mapping) - print(json.load(request.body)) - endpoint_details = json.load(request.body) - print(endpoint_details) - - # Here we want to start an executor client. - # Make sure to not put anything into the client, until after an interchange has - # connected to avoid clogging up the pipe. Submits will block if the client has - # no endpoint connected. - endpoint_id = str(uuid.uuid4()) - fw = spawn_forwarder(request.app.address, endpoint_id=endpoint_id) - connection_info = fw.connection_info - ret_package = {'endpoint_id': endpoint_id} - ret_package.update(connection_info) - print("Ret_package : ", ret_package) - - print("Ep_id: ", endpoint_id) - request.app.ep_mapping[endpoint_id] = ret_package - return ret_package - - -@route('/list_mappings') -def list_mappings(): - return request.app.ep_mapping - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--port", default=8088, - help="Port at which the service will listen on") - parser.add_argument("-a", "--address", default='127.0.0.1', - help="Address at which the service is running") - parser.add_argument("-c", "--config", default=None, - help="Config file") - parser.add_argument("-d", "--debug", action='store_true', - help="Enables debug logging") - - args = parser.parse_args() - - app = bottle.default_app() - app.address = args.address - app.ep_mapping = {} - - try: - run(host='localhost', app=app, port=int(args.port), debug=True) - except Exception as e: - # This doesn't do anything - print("Caught exception : {}".format(e)) - exit(-1) diff --git a/funcx_endpoint/funcx_endpoint/mock_broker/mock_tester.py b/funcx_endpoint/funcx_endpoint/mock_broker/mock_tester.py deleted file mode 100644 index 07c59f39c..000000000 --- a/funcx_endpoint/funcx_endpoint/mock_broker/mock_tester.py +++ /dev/null @@ -1,32 +0,0 @@ -import argparse -import requests -import funcx -import sys -import platform -import getpass - - -def test(address): - r = requests.post(address + '/register', - json={'python_v': "{}.{}".format(sys.version_info.major, - sys.version_info.minor), - 'os': platform.system(), - 'hname': platform.node(), - 'username': getpass.getuser(), - 'funcx_v': str(funcx.__version__) - } - ) - print("Status code :", r.status_code) - print("Json : ", r.json()) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--port", default=8088, - help="Port at which the service will listen on") - parser.add_argument("-d", "--debug", action='store_true', - help="Enables debug logging") - - args = parser.parse_args() - - test("http://0.0.0.0:{}".format(args.port)) diff --git a/funcx_endpoint/funcx_endpoint/mock_broker/test.py b/funcx_endpoint/funcx_endpoint/mock_broker/test.py deleted file mode 100644 index e73b6206d..000000000 --- a/funcx_endpoint/funcx_endpoint/mock_broker/test.py +++ /dev/null @@ -1,69 +0,0 @@ -from funcx_endpoint.executors import HighThroughputExecutor as HTEX -from parsl.providers import LocalProvider -from parsl.channels import LocalChannel -import parsl -import time -parsl.set_stream_logger() - - -def double(x): - return x * 2 - - -def fail(x): - return x / 0 - - -def test_1(): - - x = HTEX(label='htex', - provider=LocalProvider( - channel=LocalChannel), - address="127.0.0.1", - ) - task_p, result_p, command_p = x.start() - print(task_p, result_p, command_p) - print("Executor initialized : ", x) - - args = [2] - kwargs = {} - f1 = x.submit(double, *args, **kwargs) - print("Sent task with 
:", f1) - args = [2] - kwargs = {} - f2 = x.submit(fail, *args, **kwargs) - - print("hi") - while True: - stop = input("Stop ? (y/n)") - if stop == "y": - break - - print("F1: {}, f2: {}".format(f1.done(), f2.done())) - x.shutdown() - - -def test_2(): - - from funcx_endpoint.executors.high_throughput.executor import executor_starter - - htex = HTEX(label='htex', - provider=LocalProvider( - channel=LocalChannel), - address="127.0.0.1") - print("Foo") - executor_starter(htex, "forwarder", "ep_01") - print("Here") - - -def test_3(): - from funcx_endpoint.mock_broker.forwarder import Forwarder, spawn_forwarder - fw = spawn_forwarder("127.0.0.1", endpoint_id="0001") - print("Spawned forwarder") - time.sleep(120) - print("Terminating") - fw.terminate() - - -if __name__ == '__main__': - test_3() diff --git a/funcx_endpoint/funcx_endpoint/providers/__init__.py b/funcx_endpoint/funcx_endpoint/providers/__init__.py index cc5002e00..72019166b 100644 --- a/funcx_endpoint/funcx_endpoint/providers/__init__.py +++ b/funcx_endpoint/funcx_endpoint/providers/__init__.py @@ -1,3 +1,3 @@ from funcx_endpoint.providers.kubernetes.kube import KubernetesProvider -__all__ = ['KubernetesProvider'] +__all__ = ["KubernetesProvider"] diff --git a/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py b/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py index a282c8ea3..7802cfa5f 100644 --- a/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py +++ b/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py @@ -1,18 +1,15 @@ import logging import queue import time - -from funcx_endpoint.providers.kubernetes.template import template_string - -logger = logging.getLogger("interchange.kube_provider") - -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import typeguard from parsl.errors import OptionalModuleMissing from parsl.providers.provider_base import ExecutionProvider from parsl.utils import RepresentationMixin +from funcx_endpoint.providers.kubernetes.template import template_string + try: from kubernetes import client, config @@ -20,6 +17,8 @@ except (ImportError, NameError, FileNotFoundError): _kubernetes_enabled = False +log = logging.getLogger(__name__) + class KubernetesProvider(ExecutionProvider, RepresentationMixin): """Kubernetes execution provider @@ -54,9 +53,10 @@ class KubernetesProvider(ExecutionProvider, RepresentationMixin): This is the memory "requests" option for resource specification on kubernetes. Check kubernetes docs for more details. Default is 250Mi. parallelism : float - Ratio of provisioned task slots to active tasks. A parallelism value of 1 represents aggressive - scaling where as many resources as possible are used; parallelism close to 0 represents - the opposite situation in which as few resources as possible (i.e., min_blocks) are used. + Ratio of provisioned task slots to active tasks. A parallelism value of 1 + represents aggressive scaling where as many resources as possible are used; + parallelism close to 0 represents the opposite situation in which as few + resources as possible (i.e., min_blocks) are used. worker_init : str Command to be run first for the workers, such as `python start.py`. 
secret : str @@ -156,23 +156,23 @@ def submit(self, cmd_string, tasks_per_node, task_type, job_name="funcx"): """ cur_timestamp = str(time.time() * 1000).split(".")[0] - job_name = "{0}-{1}".format(job_name, cur_timestamp) + job_name = f"{job_name}-{cur_timestamp}" # Use default image image = self.image if task_type == "RAW" else task_type # Set the pod name if not self.pod_name: - pod_name = "{}".format(job_name) + pod_name = f"{job_name}" else: - pod_name = "{}-{}".format(self.pod_name, cur_timestamp) + pod_name = f"{self.pod_name}-{cur_timestamp}" - logger.debug("cmd_string is {}".format(cmd_string)) + log.debug(f"cmd_string is {cmd_string}") formatted_cmd = template_string.format( command=cmd_string, worker_init=self.worker_init ) - logger.info("[KUBERNETES] Scaling out a pod with name :{}".format(pod_name)) + log.info(f"[KUBERNETES] Scaling out a pod with name :{pod_name}") self._create_pod( image=image, pod_name=pod_name, @@ -203,7 +203,7 @@ def status(self, job_ids): - ExecutionProviderExceptions or its subclasses """ # This is a hack - logger.debug("Getting Kubernetes provider status") + log.debug("Getting Kubernetes provider status") status = {} for jid in job_ids: if jid in self.resources_by_pod_name: @@ -224,7 +224,7 @@ def cancel(self, num_pods, task_type=None): break else: num_pods -= 1 - logger.info("[KUBERNETES] The to_kill pods are {}".format(to_kill)) + log.info(f"[KUBERNETES] The to_kill pods are {to_kill}") rets = self._cancel(to_kill) return to_kill, rets @@ -236,13 +236,13 @@ def _cancel(self, job_ids): [True/False...] : If the cancel operation fails the entire list will be False. """ for job in job_ids: - logger.debug("Terminating job/proc_id: {0}".format(job)) + log.debug(f"Terminating job/proc_id: {job}") # Here we are assuming that for local, the job_ids are the process id's self._delete_pod(job) self.resources_by_pod_name[job]["status"] = "CANCELLED" del self.resources_by_pod_name[job] - logger.debug( + log.debug( "[KUBERNETES] The resources in kube provider is {}".format( self.resources_by_pod_name ) @@ -291,7 +291,7 @@ def _create_pod( # Create the enviornment variables and command to initiate IPP environment_vars = client.V1EnvVar(name="TEST", value="SOME DATA") - launch_args = ["-c", "{0}".format(cmd_string)] + launch_args = ["-c", f"{cmd_string}"] volume_mounts = [] # Create mount paths for the volumes @@ -342,7 +342,7 @@ def _create_pod( api_response = self.kube_client.create_namespaced_pod( namespace=self.namespace, body=pod ) - logger.debug("Pod created. status='{0}'".format(str(api_response.status))) + log.debug(f"Pod created. status='{str(api_response.status)}'") def _delete_pod(self, pod_name): """Delete a pod""" @@ -350,7 +350,7 @@ def _delete_pod(self, pod_name): api_response = self.kube_client.delete_namespaced_pod( name=pod_name, namespace=self.namespace, body=client.V1DeleteOptions() ) - logger.debug("Pod deleted. status='{0}'".format(str(api_response.status))) + log.debug(f"Pod deleted. 
status='{str(api_response.status)}'") @property def label(self): diff --git a/funcx_endpoint/funcx_endpoint/queues/__init__.py b/funcx_endpoint/funcx_endpoint/queues/__init__.py deleted file mode 100644 index 102a72285..000000000 --- a/funcx_endpoint/funcx_endpoint/queues/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from funcx_endpoint.queues.redis.redis_q import RedisQueue - -__all__ = ['RedisQueue'] diff --git a/funcx_endpoint/funcx_endpoint/queues/base.py b/funcx_endpoint/funcx_endpoint/queues/base.py deleted file mode 100644 index 875933845..000000000 --- a/funcx_endpoint/funcx_endpoint/queues/base.py +++ /dev/null @@ -1,50 +0,0 @@ -from abc import ABCMeta, abstractmethod, abstractproperty -from funcx.utils.errors import FuncxError - - -class NotConnected(FuncxError): - """ Queue is not connected/active - """ - - def __init__(self, queue): - self.queue = queue - - def __repr__(self): - return "Queue {} is not connected. Cannot execute queue operations".format(self.queue) - - -class FuncxQueue(metaclass=ABCMeta): - """ Queue interface required by the Forwarder - - This is a metaclass that only enforces concrete implementations of - functionality by the child classes. - """ - - @abstractmethod - def connect(self, *args, **kwargs): - """ Connects and creates the queue. - The queue is not active until this is called - """ - pass - - @abstractmethod - def get(self, *args, **kwargs): - """ Get an item from the Queue - """ - pass - - @abstractmethod - def put(self, *args, **kwargs): - """ Put an item into the Queue - """ - pass - - @abstractproperty - def is_connected(self): - """ Returns the connected status of the queue. - - Returns - ------- - Bool - """ - pass diff --git a/funcx_endpoint/funcx_endpoint/queues/redis/__init__.py b/funcx_endpoint/funcx_endpoint/queues/redis/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funcx_endpoint/funcx_endpoint/queues/redis/redis_q.py b/funcx_endpoint/funcx_endpoint/queues/redis/redis_q.py deleted file mode 100644 index 3bf267b3c..000000000 --- a/funcx_endpoint/funcx_endpoint/queues/redis/redis_q.py +++ /dev/null @@ -1,98 +0,0 @@ -import redis -import json - -from funcx_endpoint.queues.base import NotConnected, FuncxQueue - - -class RedisQueue(FuncxQueue): - """ A basic redis queue - - The queue only connects when the `connect` method is called to avoid - issues with passing an object across processes. - - Parameters - ---------- - - hostname : str - Hostname of the redis server - - port : int - Port at which the redis server can be reached. 
Default: 6379 - - """ - - def __init__(self, prefix, hostname, port=6379): - """ Initialize - """ - self.hostname = hostname - self.port = port - self.redis_client = None - self.prefix = prefix - - def connect(self): - """ Connects to the Redis server - """ - try: - if not self.redis_client: - self.redis_client = redis.StrictRedis(host=self.hostname, port=self.port, decode_responses=True) - except redis.exceptions.ConnectionError: - print("ConnectionError while trying to connect to Redis@{}:{}".format(self.hostname, - self.port)) - - raise - - def get(self, timeout=1): - """ Get an item from the redis queue - - Parameters - ---------- - timeout : int - Timeout for the blocking get in seconds - """ - try: - task_list, task_id = self.redis_client.blpop(f'{self.prefix}_list', timeout=timeout) - jtask_info = self.redis_client.get(f'{self.prefix}:{task_id}') - task_info = json.loads(jtask_info) - except AttributeError: - raise NotConnected(self) - except redis.exceptions.ConnectionError: - print(f"ConnectionError while trying to connect to Redis@{self.hostname}:{self.port}") - raise - - return task_id, task_info - - def put(self, key, payload): - """ Put's the key:payload into a dict and pushes the key onto a queue - Parameters - ---------- - key : str - The task_id to be pushed - - payload : dict - Dict of task information to be stored - """ - try: - self.redis_client.set(f'{self.prefix}:{key}', json.dumps(payload)) - self.redis_client.rpush(f'{self.prefix}_list', key) - except AttributeError: - raise NotConnected(self) - except redis.exceptions.ConnectionError: - print("ConnectionError while trying to connect to Redis@{}:{}".format(self.hostname, - self.port)) - raise - - @property - def is_connected(self): - return self.redis_client is not None - - -def test(): - rq = RedisQueue('task', '127.0.0.1') - rq.connect() - rq.put("01", {'a': 1, 'b': 2}) - res = rq.get(timeout=1) - print("Result : ", res) - - -if __name__ == '__main__': - test() diff --git a/funcx_endpoint/funcx_endpoint/strategies/__init__.py b/funcx_endpoint/funcx_endpoint/strategies/__init__.py index a4c22e78d..5e6eb0ef7 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/__init__.py +++ b/funcx_endpoint/funcx_endpoint/strategies/__init__.py @@ -1,8 +1,5 @@ from funcx_endpoint.strategies.base import BaseStrategy -from funcx_endpoint.strategies.simple import SimpleStrategy from funcx_endpoint.strategies.kube_simple import KubeSimpleStrategy +from funcx_endpoint.strategies.simple import SimpleStrategy - -__all__ = ['BaseStrategy', - 'SimpleStrategy', - 'KubeSimpleStrategy'] +__all__ = ["BaseStrategy", "SimpleStrategy", "KubeSimpleStrategy"] diff --git a/funcx_endpoint/funcx_endpoint/strategies/base.py b/funcx_endpoint/funcx_endpoint/strategies/base.py index 719fd8b0f..6a4e166c6 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/base.py +++ b/funcx_endpoint/funcx_endpoint/strategies/base.py @@ -1,12 +1,11 @@ -import sys -import threading import logging +import threading import time -logger = logging.getLogger("interchange.strategy.base") +log = logging.getLogger(__name__) -class BaseStrategy(object): +class BaseStrategy: """Implements threshold-interval based flow control. 
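# Illustrative sketch (not part of the changeset): BaseStrategy below drives
# its callbacks from a daemon thread that sleeps on a threading.Event with a
# timeout. The event doubles as the shutdown signal, so close() can interrupt
# the sleep immediately, while a timeout expiry means the interval elapsed and
# the callback should fire. A stripped-down version of that wake-up loop, with
# an invented callback:
import threading
import time


def wake_up_loop(kill_event: threading.Event, interval: float, callback):
    wake_up_time = time.time() + interval
    while not kill_event.is_set():
        timed_out = not kill_event.wait(timeout=max(0, wake_up_time - time.time()))
        if kill_event.is_set():
            return  # close() was called; exit promptly
        if timed_out:
            callback()  # interval elapsed
            wake_up_time = time.time() + interval


kill = threading.Event()
t = threading.Thread(
    target=wake_up_loop, args=(kill, 1.0, lambda: print("tick")), daemon=True
)
t.start()
time.sleep(2.5)  # expect roughly two ticks
kill.set()
t.join()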
The overall goal is to trap the flow of apps from the @@ -61,7 +60,9 @@ def __init__(self, *args, threshold=20, interval=5): self._event_buffer = [] self._wake_up_time = time.time() + 1 self._kill_event = threading.Event() - self._thread = threading.Thread(target=self._wake_up_timer, args=(self._kill_event,)) + self._thread = threading.Thread( + target=self._wake_up_timer, args=(self._kill_event,) + ) self._thread.daemon = True def start(self, interchange): @@ -72,21 +73,25 @@ def start(self, interchange): Interchange to bind the strategy to """ self.interchange = interchange - if hasattr(interchange, 'provider'): - logger.debug("Strategy bounds-> init:{}, min:{}, max:{}".format( - interchange.provider.init_blocks, - interchange.provider.min_blocks, - interchange.provider.max_blocks)) + if hasattr(interchange, "provider"): + log.debug( + "Strategy bounds-> init:{}, min:{}, max:{}".format( + interchange.provider.init_blocks, + interchange.provider.min_blocks, + interchange.provider.max_blocks, + ) + ) self._thread.start() def strategize(self, *args, **kwargs): - """ Strategize is called everytime the threshold or the interval is hit - """ - logger.debug("Strategize called with {} {}".format(args, kwargs)) + """Strategize is called everytime the threshold or the interval is hit""" + log.debug(f"Strategize called with {args} {kwargs}") def _wake_up_timer(self, kill_event): - """Internal. This is the function that the thread will execute. - waits on an event so that the thread can make a quick exit when close() is called + """ + Internal. This is the function that the thread will execute. + waits on an event so that the thread can make a quick exit when close() is + called Args: - kill_event (threading.Event) : Event to wait on @@ -103,7 +108,7 @@ def _wake_up_timer(self, kill_event): return if prev == self._wake_up_time: - self.make_callback(kind='timer') + self.make_callback(kind="timer") else: print("Sleeping a bit more") @@ -115,7 +120,7 @@ def notify(self, event_id): self._event_buffer.extend([event_id]) self._event_count += 1 if self._event_count >= self.threshold: - logger.debug("Eventcount >= threshold") + log.debug("Eventcount >= threshold") self.make_callback(kind="event") def make_callback(self, kind=None): @@ -135,7 +140,7 @@ def close(self): self._thread.join() -class Timer(object): +class Timer: """This timer is a simplified version of the FlowControl timer. This timer does not employ notify events. @@ -173,13 +178,16 @@ def __init__(self, callback, *args, interval=5): self._wake_up_time = time.time() + 1 self._kill_event = threading.Event() - self._thread = threading.Thread(target=self._wake_up_timer, args=(self._kill_event,)) + self._thread = threading.Thread( + target=self._wake_up_timer, args=(self._kill_event,) + ) self._thread.daemon = True self._thread.start() def _wake_up_timer(self, kill_event): """Internal. This is the function that the thread will execute. - waits on an event so that the thread can make a quick exit when close() is called + waits on an event so that the thread can make a quick exit when close() is + called Args: - kill_event (threading.Event) : Event to wait on @@ -197,18 +205,16 @@ def _wake_up_timer(self, kill_event): return if prev == self._wake_up_time: - self.make_callback(kind='timer') + self.make_callback(kind="timer") else: print("Sleeping a bit more") def make_callback(self, kind=None): - """Makes the callback and resets the timer. 
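# Illustrative sketch (not part of the changeset): the Timer variant above
# takes a callback plus its positional arguments and an interval, starts its
# daemon thread in __init__, and is stopped with close(). A small usage
# example; the heartbeat callback and the timings are invented.
import time

from funcx_endpoint.strategies.base import Timer


def heartbeat(name):
    print(f"heartbeat from {name}")


t = Timer(heartbeat, "endpoint-01", interval=1)
time.sleep(5)  # expect a few heartbeats while we wait
t.close()  # sets the kill event and joins the timer thread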
- """ + """Makes the callback and resets the timer.""" self._wake_up_time = time.time() + self.interval self.callback(*self.cb_args) def close(self): - """Merge the threads and terminate. - """ + """Merge the threads and terminate.""" self._kill_event.set() self._thread.join() diff --git a/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py b/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py index 88eb8f418..68d8ba60c 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py +++ b/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py @@ -1,19 +1,16 @@ -from funcx_endpoint.strategies.base import BaseStrategy -import math import logging +import math import time -logger = logging.getLogger("interchange.strategy.KubeSimple") +from funcx_endpoint.strategies.base import BaseStrategy + +log = logging.getLogger(__name__) class KubeSimpleStrategy(BaseStrategy): - """ Implements the simple strategy for Kubernetes - """ + """Implements the simple strategy for Kubernetes""" - def __init__(self, *args, - threshold=20, - interval=1, - max_idletime=60): + def __init__(self, *args, threshold=20, interval=1, max_idletime=60): """Initialize the flowcontrol object. We start the timer thread here @@ -27,11 +24,12 @@ def __init__(self, *args, seconds after which timer expires max_idletime: (int) - maximum idle time(seconds) allowed for resources after which strategy will try to kill them. + maximum idle time(seconds) allowed for resources after which strategy will + try to kill them. default: 60s """ - logger.info("KubeSimpleStrategy Initialized") + log.info("KubeSimpleStrategy Initialized") super().__init__(*args, threshold=threshold, interval=interval) self.max_idletime = max_idletime self.executors_idle_since = {} @@ -40,8 +38,7 @@ def strategize(self, *args, **kwargs): try: self._strategize(*args, **kwargs) except Exception as e: - logger.exception("Caught error in strategize : {}".format(e)) - pass + log.exception(f"Caught error in strategize : {e}") def _strategize(self, *args, **kwargs): max_pods = self.interchange.provider.max_blocks @@ -51,26 +48,31 @@ def _strategize(self, *args, **kwargs): managers_per_pod = 1 workers_per_pod = self.interchange.max_workers_per_node - if workers_per_pod == float('inf'): + if workers_per_pod == float("inf"): workers_per_pod = 1 parallelism = self.interchange.provider.parallelism active_tasks = self.interchange.get_total_tasks_outstanding() - logger.debug(f"Pending tasks : {active_tasks}") + log.debug(f"Pending tasks : {active_tasks}") status = self.interchange.provider_status() - logger.debug(f"Provider status : {status}") + log.debug(f"Provider status : {status}") for task_type in active_tasks.keys(): active_pods = status.get(task_type, 0) active_slots = active_pods * workers_per_pod * managers_per_pod active_tasks_per_type = active_tasks[task_type] - logger.debug( - 'Endpoint has {} active tasks of {}, {} active blocks, {} connected workers for {}'.format( - active_tasks_per_type, task_type, active_pods, - self.interchange.get_total_live_workers(), task_type)) + log.debug( + "Endpoint has %s active tasks of %s, %s active blocks, " + "%s connected workers for %s", + active_tasks_per_type, + task_type, + active_pods, + self.interchange.get_total_live_workers(), + task_type, + ) # Reset the idle time if we are currently running tasks if active_tasks_per_type > 0: @@ -79,29 +81,47 @@ def _strategize(self, *args, **kwargs): # Scale down only if there are no active tasks to avoid having to find which # workers are unoccupied if active_tasks_per_type 
== 0 and active_pods > min_pods: - # We want to make sure that max_idletime is reached before killing off resources + # We want to make sure that max_idletime is reached before killing off + # resources if not self.executors_idle_since[task_type]: - logger.debug( - "Endpoint has 0 active tasks of task type {}; starting kill timer (if idle time exceeds {}s, resources will be removed)". - format(task_type, self.max_idletime)) + log.debug( + "Endpoint has 0 active tasks of task type %s; " + "starting kill timer (if idle time exceeds %s seconds, " + "resources will be removed)", + task_type, + self.max_idletime, + ) self.executors_idle_since[task_type] = time.time() - # If we have resources idle for the max duration we have to scale_in now. - if (time.time() - self.executors_idle_since[task_type]) > self.max_idletime: - logger.info( - "Idle time has reached {}s; removing resources of task type {}".format( - self.max_idletime, task_type) + # If we have resources idle for the max duration we have to scale_in now + if ( + time.time() - self.executors_idle_since[task_type] + ) > self.max_idletime: + log.info( + "Idle time has reached %s seconds; " + "removing resources of task type %s", + self.max_idletime, + task_type, + ) + self.interchange.scale_in( + active_pods - min_pods, task_type=task_type ) - self.interchange.scale_in(active_pods - min_pods, task_type=task_type) # More tasks than the available slots. - elif active_tasks_per_type > 0 and (float(active_slots) / active_tasks_per_type) < parallelism: + elif ( + active_tasks_per_type > 0 + and (float(active_slots) / active_tasks_per_type) < parallelism + ): if active_pods < max_pods: - excess = math.ceil((active_tasks_per_type * parallelism) - active_slots) - excess_blocks = math.ceil(float(excess) / (workers_per_pod * managers_per_pod)) + excess = math.ceil( + (active_tasks_per_type * parallelism) - active_slots + ) + excess_blocks = math.ceil( + float(excess) / (workers_per_pod * managers_per_pod) + ) excess_blocks = min(excess_blocks, max_pods - active_pods) - logger.info("Requesting {} more blocks".format(excess_blocks)) + log.info(f"Requesting {excess_blocks} more blocks") self.interchange.scale_out(excess_blocks, task_type=task_type) # Immediatly scale if we are stuck with zero pods and work to do elif active_slots == 0 and active_tasks_per_type > 0: - logger.info("Requesting single pod") + log.info("Requesting single pod") self.interchange.scale_out(1, task_type=task_type) diff --git a/funcx_endpoint/funcx_endpoint/strategies/simple.py b/funcx_endpoint/funcx_endpoint/strategies/simple.py index 0e6fb3bb0..c3bffac9d 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/simple.py +++ b/funcx_endpoint/funcx_endpoint/strategies/simple.py @@ -1,21 +1,18 @@ -import math import logging +import math import time + from parsl.providers.provider_base import JobState from funcx_endpoint.strategies.base import BaseStrategy -logger = logging.getLogger("interchange.strategy.simple") +log = logging.getLogger(__name__) class SimpleStrategy(BaseStrategy): - """ Implements the simple strategy - """ + """Implements the simple strategy""" - def __init__(self, *args, - threshold=20, - interval=1, - max_idletime=60): + def __init__(self, *args, threshold=20, interval=1, max_idletime=60): """Initialize the flowcontrol object. We start the timer thread here @@ -29,25 +26,26 @@ def __init__(self, *args, seconds after which timer expires max_idletime: (int) - maximum idle time(seconds) allowed for resources after which strategy will try to kill them. 
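# Illustrative sketch (not part of the changeset): both strategies in this
# diff size their scale-out requests the same way. Desired slots are
# active_tasks * parallelism, the shortfall is rounded up to whole blocks
# (pods), and the request is capped by max_blocks. A worked example with
# invented numbers:
import math

active_tasks = 16       # outstanding tasks of one type
active_slots = 4        # slots currently provisioned
parallelism = 0.5       # lazy scaling: aim for one slot per two tasks
workers_per_block = 4   # workers_per_pod * managers_per_pod (or tasks_per_node * nodes_per_block)
active_blocks, max_blocks = 1, 8

excess = math.ceil(active_tasks * parallelism - active_slots)   # 4 more slots wanted
excess_blocks = math.ceil(excess / workers_per_block)           # rounds up to 1 block
excess_blocks = min(excess_blocks, max_blocks - active_blocks)  # respect the cap
print(excess_blocks)  # 1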
+ maximum idle time(seconds) allowed for resources after which strategy will + try to kill them. default: 60s """ - logger.info("SimpleStrategy Initialized") + log.info("SimpleStrategy Initialized") super().__init__(*args, threshold=threshold, interval=interval) self.max_idletime = max_idletime - self.executors = {'idle_since': None} + self.executors = {"idle_since": None} def strategize(self, *args, **kwargs): try: self._strategize(*args, **kwargs) except Exception as e: - logger.exception("Caught error in strategize : {}".format(e)) + log.exception(f"Caught error in strategize : {e}") pass def _strategize(self, *args, **kwargs): task_breakdown = self.interchange.get_outstanding_breakdown() - logger.info(f"Task breakdown {task_breakdown}") + log.debug(f"Task breakdown {task_breakdown}") min_blocks = self.interchange.provider.min_blocks max_blocks = self.interchange.provider.max_blocks @@ -55,7 +53,7 @@ def _strategize(self, *args, **kwargs): # Here we assume that each node has atleast 4 workers tasks_per_node = self.interchange.max_workers_per_node - if self.interchange.max_workers_per_node == float('inf'): + if self.interchange.max_workers_per_node == float("inf"): tasks_per_node = 1 nodes_per_block = self.interchange.provider.nodes_per_block @@ -63,19 +61,25 @@ def _strategize(self, *args, **kwargs): active_tasks = sum(self.interchange.get_total_tasks_outstanding().values()) status = self.interchange.provider_status() - logger.debug(f"Provider status : {status}") + log.debug(f"Provider status : {status}") running = sum([1 for x in status if x.state == JobState.RUNNING]) pending = sum([1 for x in status if x.state == JobState.PENDING]) active_blocks = running + pending active_slots = active_blocks * tasks_per_node * nodes_per_block - logger.debug('Endpoint has {} active tasks, {}/{} running/pending blocks, and {} connected workers'.format( - active_tasks, running, pending, self.interchange.get_total_live_workers())) + log.debug( + "Endpoint has %s active tasks, %s/%s running/pending blocks, " + "and %s connected workers", + active_tasks, + running, + pending, + self.interchange.get_total_live_workers(), + ) # reset kill timer if executor has active tasks - if active_tasks > 0 and self.executors['idle_since']: - self.executors['idle_since'] = None + if active_tasks > 0 and self.executors["idle_since"]: + self.executors["idle_since"] = None # Case 1 # No tasks. @@ -84,7 +88,7 @@ def _strategize(self, *args, **kwargs): # Fewer blocks that min_blocks if active_blocks <= min_blocks: # Ignore - # logger.debug("Strategy: Case.1a") + # log.debug("Strategy: Case.1a") pass # Case 1b @@ -92,24 +96,30 @@ def _strategize(self, *args, **kwargs): else: # We want to make sure that max_idletime is reached # before killing off resources - if not self.executors['idle_since']: - logger.debug("Endpoint has 0 active tasks; starting kill timer (if idle time exceeds {}s, resources will be removed)".format( - self.max_idletime) + if not self.executors["idle_since"]: + log.debug( + "Endpoint has 0 active tasks; starting kill timer " + "(if idle time exceeds %s seconds, resources will be removed)", + self.max_idletime, ) - self.executors['idle_since'] = time.time() + self.executors["idle_since"] = time.time() - idle_since = self.executors['idle_since'] + idle_since = self.executors["idle_since"] if (time.time() - idle_since) > self.max_idletime: # We have resources idle for the max duration, # we have to scale_in now. 
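# Illustrative sketch (not part of the changeset): the scale-in side of both
# strategies is an idle timer. Record when the endpoint first went idle, clear
# that mark whenever tasks show up, and only release blocks once the idle
# period exceeds max_idletime. The bookkeeping, reduced to a few lines with an
# invented scale_in callable:
import time

max_idletime = 60.0
idle_since = None


def on_strategy_tick(active_tasks, scale_in):
    global idle_since
    if active_tasks > 0:
        idle_since = None         # busy again: reset the kill timer
        return
    if idle_since is None:
        idle_since = time.time()  # first tick with nothing to do
    elif time.time() - idle_since > max_idletime:
        scale_in()                # idle long enough: release resources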
- logger.debug("Idle time has reached {}s; removing resources".format( - self.max_idletime) + log.debug( + "Idle time has reached {}s; removing resources".format( + self.max_idletime + ) ) self.interchange.scale_in(active_blocks - min_blocks) else: pass - # logger.debug("Strategy: Case.1b. Waiting for timer : {0}".format(idle_since)) + # log.debug( + # "Strategy: Case.1b. Waiting for timer : %s", idle_since + # ) # Case 2 # More tasks than the available slots. @@ -118,26 +128,28 @@ def _strategize(self, *args, **kwargs): # We have the max blocks possible if active_blocks >= max_blocks: # Ignore since we already have the max nodes - # logger.debug("Strategy: Case.2a") + # log.debug("Strategy: Case.2a") pass # Case 2b else: - # logger.debug("Strategy: Case.2b") + # log.debug("Strategy: Case.2b") excess = math.ceil((active_tasks * parallelism) - active_slots) - excess_blocks = math.ceil(float(excess) / (tasks_per_node * nodes_per_block)) + excess_blocks = math.ceil( + float(excess) / (tasks_per_node * nodes_per_block) + ) excess_blocks = min(excess_blocks, max_blocks - active_blocks) - logger.debug("Requesting {} more blocks".format(excess_blocks)) + log.debug(f"Requesting {excess_blocks} more blocks") self.interchange.scale_out(excess_blocks) elif active_slots == 0 and active_tasks > 0: # Case 4 # Check if slots are being lost quickly ? - logger.debug("Requesting single slot") + log.debug("Requesting single slot") if active_blocks < max_blocks: self.interchange.scale_out(1) # Case 3 # tasks ~ slots else: - # logger.debug("Strategy: Case 3") + # log.debug("Strategy: Case 3") pass diff --git a/funcx_endpoint/funcx_endpoint/strategies/test.py b/funcx_endpoint/funcx_endpoint/strategies/test.py index e3745231f..b9543776d 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/test.py +++ b/funcx_endpoint/funcx_endpoint/strategies/test.py @@ -4,8 +4,7 @@ from funcx_endpoint.strategies import SimpleStrategy -class MockInterchange(object): - +class MockInterchange: def __init__(self, max_blocks=1, tasks=10): self.tasks_pending = tasks self.max_blocks = max_blocks @@ -22,7 +21,7 @@ def get_outstanding_breakdown(self): this_round = self.tasks_pending self.tasks_pending = 0 - current = [('interchange', this_round, this_round)] + current = [("interchange", this_round, this_round)] for i in range(self.managers): current.extend((f"manager_{i}", 1, 1)) self.status.put(current) @@ -35,11 +34,11 @@ def scale_out(self): def create_data(self): q = queue.Queue() items = [ - [('interchange', 0, 0)], - [('interchange', 0, 0)], - [('interchange', 0, 0)], - [('interchange', self.tasks_pending, self.tasks_pending)], - [('interchange', self.tasks_pending, self.tasks_pending)] + [("interchange", 0, 0)], + [("interchange", 0, 0)], + [("interchange", 0, 0)], + [("interchange", self.tasks_pending, self.tasks_pending)], + [("interchange", self.tasks_pending, self.tasks_pending)], ] [q.put(i) for i in items] diff --git a/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py b/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py index d3863b01a..c5ae7a44d 100644 --- a/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py +++ b/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py @@ -1,5 +1,6 @@ import threading import time + from pytest import fixture from funcx_endpoint.executors.high_throughput.interchange import Interchange @@ -17,7 +18,6 @@ def no_op_worker(): class TestKubeSimple: - @fixture def mock_interchange(self, mocker): mock_interchange = 
mocker.MagicMock(Interchange) @@ -29,8 +29,10 @@ def mock_interchange(self, mocker): mock_interchange.config.provider.max_blocks = 4 mock_interchange.config.provider.nodes_per_block = 1 mock_interchange.config.provider.parallelism = 1.0 - mock_interchange.get_total_tasks_outstanding = mocker.Mock(return_value={'RAW': 0}) - mock_interchange.provider_status = mocker.Mock(return_value={'RAW': 16}) + mock_interchange.get_total_tasks_outstanding = mocker.Mock( + return_value={"RAW": 0} + ) + mock_interchange.provider_status = mocker.Mock(return_value={"RAW": 16}) mock_interchange.get_total_live_workers = mocker.Mock(return_value=0) mock_interchange.scale_in = mocker.Mock() mock_interchange.scale_out = mocker.Mock() @@ -48,7 +50,8 @@ def kube_strategy(self): def test_no_tasks_no_pods(self, mock_interchange, kube_strategy): mock_interchange.get_outstanding_breakdown.return_value = [ - ('interchange', 0, True)] + ("interchange", 0, True) + ] mock_interchange.get_total_tasks_outstanding.return_value = [] kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") @@ -57,8 +60,8 @@ def test_no_tasks_no_pods(self, mock_interchange, kube_strategy): def test_scale_in_with_no_tasks(self, mock_interchange, kube_strategy): # First there is work to do and pods are scaled up - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 16} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 16} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -66,7 +69,7 @@ def test_scale_in_with_no_tasks(self, mock_interchange, kube_strategy): # Now tasks are all done, but pods are still running. Idle time has not yet # been reached, so the pods will still be running. - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 0} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() @@ -81,8 +84,8 @@ def test_scale_in_with_no_tasks(self, mock_interchange, kube_strategy): def test_task_arrives_during_idle_time(self, mock_interchange, kube_strategy): # First there is work to do and pods are scaled up - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 16} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 16} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -90,20 +93,20 @@ def test_task_arrives_during_idle_time(self, mock_interchange, kube_strategy): # Now tasks are all done, but pods are still running. Idle time has not yet # been reached, so the pods will still be running. 
- mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 0} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() # Now add a new task - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 1} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() # Verify that idle time is reset time.sleep(5) - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 0} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() @@ -112,8 +115,8 @@ def test_task_backlog_within_parallelism(self, mock_interchange, kube_strategy): # Aggressive scaling so new tasks will create new pods mock_interchange.config.provider.parallelism = 1.0 mock_interchange.config.provider.max_blocks = 16 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -123,8 +126,8 @@ def test_task_backlog_gated_by_parallelism(self, mock_interchange, kube_strategy # Lazy scaling, so just a single new task won't spawn a new pod mock_interchange.config.provider.parallelism = 0.5 mock_interchange.config.provider.max_blocks = 16 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -133,8 +136,8 @@ def test_task_backlog_gated_by_parallelism(self, mock_interchange, kube_strategy def test_task_backlog_gated_by_max_blocks(self, mock_interchange, kube_strategy): mock_interchange.config.provider.parallelism = 1.0 mock_interchange.config.provider.max_blocks = 8 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -143,8 +146,8 @@ def test_task_backlog_gated_by_max_blocks(self, mock_interchange, kube_strategy) def test_task_backlog_already_max_blocks(self, mock_interchange, kube_strategy): mock_interchange.config.provider.parallelism = 1.0 mock_interchange.config.provider.max_blocks = 8 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 16} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 16} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") 
mock_interchange.scale_in.assert_not_called() @@ -152,8 +155,8 @@ def test_task_backlog_already_max_blocks(self, mock_interchange, kube_strategy): def test_scale_when_no_pods(self, mock_interchange, kube_strategy): mock_interchange.config.provider.parallelism = 0.01 # Very lazy scaling - mock_interchange.provider_status.return_value = {'RAW': 0} - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 1} + mock_interchange.provider_status.return_value = {"RAW": 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() diff --git a/funcx_endpoint/funcx_endpoint/tests/test_cancel.py b/funcx_endpoint/funcx_endpoint/tests/test_cancel.py index ac19c3dcc..20bc8faa9 100644 --- a/funcx_endpoint/funcx_endpoint/tests/test_cancel.py +++ b/funcx_endpoint/funcx_endpoint/tests/test_cancel.py @@ -1,3 +1,4 @@ +import logging import os import random import time @@ -7,13 +8,12 @@ import pytest from parsl.providers import LocalProvider -import funcx from funcx_endpoint.executors import HighThroughputExecutor -import logging + logger = logging.getLogger(__name__) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def htex(): try: os.remove("interchange.log") diff --git a/funcx_endpoint/funcx_endpoint/version.py b/funcx_endpoint/funcx_endpoint/version.py index 395c3fe21..3dbef89d5 100644 --- a/funcx_endpoint/funcx_endpoint/version.py +++ b/funcx_endpoint/funcx_endpoint/version.py @@ -4,4 +4,4 @@ VERSION = __version__ # app name to send as part of requests -app_name = "funcX Endpoint v{}".format(__version__) +app_name = f"funcX Endpoint v{__version__}" diff --git a/funcx_endpoint/requirements.txt b/funcx_endpoint/requirements.txt deleted file mode 100644 index 52f7b0bf4..000000000 --- a/funcx_endpoint/requirements.txt +++ /dev/null @@ -1,45 +0,0 @@ -requests>=2.20.0,<3 -globus_sdk<3 -funcx>=0.3.3,<0.4.0 - -# table printing used in list-endpoints -texttable>=1.6.4,<2 - -# although psutil does not declare itself to use semver, it appears to offer -# strong backwards-compatibility promises based on its changelog, usage, and -# history -# -# TODO: re-evaluate bound after we have an answer of some kind from psutil -# see: -# https://github.com/giampaolo/psutil/issues/2002 -psutil<6 - -# provides easy daemonization of the endpoint -python-daemon>=2,<3 - -# TODO: replace use of `typer` with `click` because -# 1. `typer` is a thin wrapper over `click` offering very minimal additional -# functionality -# 2. 
`click` follows semver and releases new major versions when known -# backwards-incompatible changes are introduced, making our application -# safer to distribute -typer==0.4.0 - - -# disallow use of 22.3.0; the whl package on some platforms causes ZMQ issues -# -# NOTE: 22.3.0 introduced a patched version of libzmq.so to the wheel packaging -# which may be the source of the problems , the problem can be fixed by -# building from source, which may mean there's an issue in the packaged library -# further investigation may be needed if the issue persists in the next pyzmq -# release -pyzmq>=22.0.0,!=22.3.0 - -# TODO: evaluate removal of the 'retry' library after the update to -# globus-sdk v3, which provides automatic retries on all API calls -retry==0.9.2 - -# 'parsl' is a core requirement of the funcx-endpoint, essential to a range -# of different features and functions -# pin exact versions because it does not use semver -parsl==1.1.0 diff --git a/funcx_endpoint/setup.py b/funcx_endpoint/setup.py index fb0c266d0..a6581328d 100644 --- a/funcx_endpoint/setup.py +++ b/funcx_endpoint/setup.py @@ -1,20 +1,70 @@ import os -from setuptools import setup, find_packages + +from setuptools import find_packages, setup + +REQUIRES = [ + "requests>=2.20.0,<3", + "globus_sdk<3", + "funcx>=0.3.3,<0.4.0", + # table printing used in list-endpoints + "texttable>=1.6.4,<2", + # although psutil does not declare itself to use semver, it appears to offer + # strong backwards-compatibility promises based on its changelog, usage, and + # history + # + # TODO: re-evaluate bound after we have an answer of some kind from psutil + # see: + # https://github.com/giampaolo/psutil/issues/2002 + "psutil<6", + # provides easy daemonization of the endpoint + "python-daemon>=2,<3", + # TODO: replace use of `typer` with `click` because + # 1. `typer` is a thin wrapper over `click` offering very minimal additional + # functionality + # 2. 
`click` follows semver and releases new major versions when known + # backwards-incompatible changes are introduced, making our application + # safer to distribute + "typer==0.4.0", + # disallow use of 22.3.0; the whl package on some platforms causes ZMQ issues + # + # NOTE: 22.3.0 introduced a patched version of libzmq.so to the wheel packaging + # which may be the source of the problems , the problem can be fixed by + # building from source, which may mean there's an issue in the packaged library + # further investigation may be needed if the issue persists in the next pyzmq + # release + "pyzmq>=22.0.0,!=22.3.0", + # TODO: evaluate removal of the 'retry' library after the update to + # globus-sdk v3, which provides automatic retries on all API calls + "retry==0.9.2", + # 'parsl' is a core requirement of the funcx-endpoint, essential to a range + # of different features and functions + # pin exact versions because it does not use semver + "parsl==1.1.0", +] + +TEST_REQUIRES = [ + "pytest>=5.2", + "coverage>=5.2", + "codecov==2.1.8", + "pytest-mock==3.2.0", + "flake8>=3.8", +] + version_ns = {} with open(os.path.join("funcx_endpoint", "version.py")) as f: exec(f.read(), version_ns) -version = version_ns['VERSION'] - -with open('requirements.txt') as f: - install_requires = f.readlines() +version = version_ns["VERSION"] setup( - name='funcx-endpoint', + name="funcx-endpoint", version=version, packages=find_packages(), - description='funcX: High Performance Function Serving for Science', - install_requires=install_requires, + description="funcX: High Performance Function Serving for Science", + install_requires=REQUIRES, + extras_require={ + "test": TEST_REQUIRES, + }, python_requires=">=3.6.0", classifiers=[ "Development Status :: 3 - Alpha", @@ -23,23 +73,23 @@ "Natural Language :: English", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering" - ], - keywords=[ - "funcX", - "FaaS", - "Function Serving" + "Topic :: Scientific/Engineering", ], - entry_points={'console_scripts': - ['funcx-endpoint=funcx_endpoint.endpoint.endpoint:cli_run', - 'funcx-interchange=funcx_endpoint.executors.high_throughput.interchange:cli_run', - 'funcx-manager=funcx_endpoint.executors.high_throughput.funcx_manager:cli_run', - 'funcx-worker=funcx_endpoint.executors.high_throughput.funcx_worker:cli_run', - ] + keywords=["funcX", "FaaS", "Function Serving"], + entry_points={ + "console_scripts": [ + "funcx-endpoint=funcx_endpoint.endpoint.endpoint:cli_run", + "funcx-interchange" + "=funcx_endpoint.executors.high_throughput.interchange:cli_run", + "funcx-manager" + "=funcx_endpoint.executors.high_throughput.funcx_manager:cli_run", + "funcx-worker" + "=funcx_endpoint.executors.high_throughput.funcx_worker:cli_run", + ] }, include_package_data=True, - author='funcX team', - author_email='labs@globus.org', + author="funcX team", + author_email="labs@globus.org", license="Apache License, Version 2.0", - url="https://github.com/funcx-faas/funcx" + url="https://github.com/funcx-faas/funcx", ) diff --git a/funcx_endpoint/test-requirements.txt b/funcx_endpoint/test-requirements.txt deleted file mode 100644 index fcb0abfc0..000000000 --- a/funcx_endpoint/test-requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -pytest>=5.2 -coverage>=5.2 -codecov==2.1.8 -pytest-mock==3.2.0 -flake8>=3.8 diff --git a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py index 79e089719..7e3430d87 100644 
--- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py @@ -1,12 +1,14 @@ import os + import pytest -from funcx_endpoint.endpoint.endpoint import app from typer.testing import CliRunner +from funcx_endpoint.endpoint.endpoint import app + runner = CliRunner() -config_string = ''' +config_string = """ from funcx_endpoint.endpoint.utils.config import Config from parsl.providers import LocalProvider @@ -18,27 +20,32 @@ max_blocks=1, ), funcx_service_address='https://api.funcx.org/v1' -)''' - - -class TestEndpoint: - - @pytest.fixture(autouse=True) - def test_setup_teardown(self, mocker): - mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - yield - - def test_non_configured_endpoint(self, mocker): - result = runner.invoke(app, ["start", "newendpoint"]) - assert 'newendpoint' in result.stdout - assert 'not configured' in result.stdout - - def test_using_outofdate_config(self, mocker): - mock_loader = mocker.patch('funcx_endpoint.endpoint.endpoint.os.path.join') - mock_loader.return_value = './config.py' - config_file = open("./config.py", "w") - config_file.write(config_string) - config_file.close() - result = runner.invoke(app, ["start", "newendpoint"]) - os.remove("./config.py") - assert isinstance(result.exception, TypeError) +)""" + + +@pytest.fixture(autouse=True) +def patch_funcx_client(mocker): + mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") + + +@pytest.fixture(autouse=True) +def adjust_default_logfile(tmp_path, monkeypatch): + logfile = str(tmp_path / "endpoint.log") + monkeypatch.setattr("funcx_endpoint.logging_config._DEFAULT_LOGFILE", logfile) + + +def test_non_configured_endpoint(mocker): + result = runner.invoke(app, ["start", "newendpoint"]) + assert "newendpoint" in result.stdout + assert "not configured" in result.stdout + + +def test_using_outofdate_config(mocker): + mock_loader = mocker.patch("funcx_endpoint.endpoint.endpoint.os.path.join") + mock_loader.return_value = "./config.py" + config_file = open("./config.py", "w") + config_file.write(config_string) + config_file.close() + result = runner.invoke(app, ["start", "newendpoint"]) + os.remove("./config.py") + assert isinstance(result.exception, TypeError) diff --git a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py index f63cb2dd4..d89c27ac5 100644 --- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py @@ -1,26 +1,25 @@ -from funcx_endpoint.endpoint.endpoint_manager import EndpointManager -from importlib.machinery import SourceFileLoader -import os +import json import logging -import sys +import os import shutil -import pytest -import json -from pytest import fixture +from importlib.machinery import SourceFileLoader from unittest.mock import ANY -from globus_sdk import GlobusHTTPResponse, GlobusAPIError + +import pytest +from globus_sdk import GlobusAPIError, GlobusHTTPResponse from requests import Response -logger = logging.getLogger('mock_funcx') +from funcx_endpoint.endpoint.endpoint_manager import EndpointManager + +logger = logging.getLogger("mock_funcx") class TestStart: - @pytest.fixture(autouse=True) def test_setup_teardown(self): # Code that will run before your test, for example: - funcx_dir = f'{os.getcwd()}' + funcx_dir = f"{os.getcwd()}" config_dir = os.path.join(funcx_dir, "mock_endpoint") assert not 
os.path.exists(config_dir) # A test function will be run at this point @@ -40,20 +39,27 @@ def test_double_configure(self): manager.configure_endpoint("mock_endpoint", None) assert os.path.exists(config_dir) - with pytest.raises(Exception, match='ConfigExists'): + with pytest.raises(Exception, match="ConfigExists"): manager.configure_endpoint("mock_endpoint", None) def test_start(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - reg_info = {'endpoint_id': 'abcde12345', - 'address': 'localhost', - 'client_ports': '8080'} + mock_client = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + reg_info = { + "endpoint_id": "abcde12345", + "address": "localhost", + "client_ports": "8080", + } mock_client.return_value.register_endpoint.return_value = reg_info - mock_zmq_create = mocker.patch("zmq.auth.create_certificates", - return_value=("public/key/file", None)) - mock_zmq_load = mocker.patch("zmq.auth.load_certificate", - return_value=("12345abcde".encode(), "12345abcde".encode())) + mock_zmq_create = mocker.patch( + "zmq.auth.create_certificates", return_value=("public/key/file", None) + ) + mock_zmq_load = mocker.patch( + "zmq.auth.load_certificate", + return_value=(b"12345abcde", b"12345abcde"), + ) mock_context = mocker.patch("daemon.DaemonContext") @@ -61,28 +67,36 @@ def test_start(self, mocker): mock_context.return_value.__enter__.return_value = None mock_context.return_value.__exit__.return_value = None - mock_context.return_value.pidfile.path = '' + mock_context.return_value.pidfile.path = "" - mock_daemon = mocker.patch.object(EndpointManager, 'daemon_launch', - return_value=None) + mock_daemon = mocker.patch.object( + EndpointManager, "daemon_launch", return_value=None + ) - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 - mock_pidfile = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile') + mock_pidfile = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile" + ) mock_pidfile.return_value = None - mock_results_ack_handler = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler') + mock_results_ack_handler = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler" + ) manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() manager.start_endpoint("mock_endpoint", None, endpoint_config) - mock_zmq_create.assert_called_with(os.path.join(config_dir, "certificates"), "endpoint") + mock_zmq_create.assert_called_with( + os.path.join(config_dir, "certificates"), "endpoint" + ) mock_zmq_load.assert_called_with("public/key/file") funcx_client_options = { @@ -90,20 +104,24 @@ def test_start(self, mocker): "check_endpoint_version": True, } - mock_daemon.assert_called_with('123456', - config_dir, - os.path.join(config_dir, "certificates"), - endpoint_config, - reg_info, - funcx_client_options, - mock_results_ack_handler.return_value) - - mock_context.assert_called_with(working_directory=config_dir, - umask=0o002, - pidfile=None, - stdout=ANY, # 
open(os.path.join(config_dir, './interchange.stdout'), 'w+'), - stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), - detach_process=True) + mock_daemon.assert_called_with( + "123456", + config_dir, + os.path.join(config_dir, "certificates"), + endpoint_config, + reg_info, + funcx_client_options, + mock_results_ack_handler.return_value, + ) + + mock_context.assert_called_with( + working_directory=config_dir, + umask=0o002, + pidfile=None, + stdout=ANY, # open(os.path.join(config_dir, './interchange.stdout'), 'w+'), + stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), + detach_process=True, + ) def test_start_registration_error(self, mocker): """This tests what happens if a 400 error response comes back from the @@ -115,70 +133,84 @@ def test_start_registration_error(self, mocker): mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") base_r = Response() - base_r.headers = { - "Content-Type": "json" - } + base_r.headers = {"Content-Type": "json"} base_r.status_code = 400 r = GlobusHTTPResponse(base_r) r.status_code = base_r.status_code r.headers = base_r.headers - mock_register_endpoint = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.register_endpoint') + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.register_endpoint" + ) mock_register_endpoint.side_effect = GlobusAPIError(r) - mock_zmq_create = mocker.patch("zmq.auth.create_certificates", - return_value=("public/key/file", None)) - mock_zmq_load = mocker.patch("zmq.auth.load_certificate", - return_value=("12345abcde".encode(), "12345abcde".encode())) + mock_zmq_create = mocker.patch( + "zmq.auth.create_certificates", return_value=("public/key/file", None) + ) + mock_zmq_load = mocker.patch( + "zmq.auth.load_certificate", + return_value=(b"12345abcde", b"12345abcde"), + ) - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 - mock_pidfile = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile') + mock_pidfile = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile" + ) mock_pidfile.return_value = None - mocker.patch('funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler') + mocker.patch("funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler") manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() with pytest.raises(GlobusAPIError): manager.start_endpoint("mock_endpoint", None, endpoint_config) - mock_zmq_create.assert_called_with(os.path.join(config_dir, "certificates"), "endpoint") + mock_zmq_create.assert_called_with( + os.path.join(config_dir, "certificates"), "endpoint" + ) mock_zmq_load.assert_called_with("public/key/file") def test_start_registration_5xx_error(self, mocker): - """This tests what happens if a 500 error response comes back from the - initial endpoint registration. It is expected that this exception should - NOT be raised and that the interchange should be started without any registration - info being passed in. 
The registration should then be retried in the interchange - daemon, because a 5xx error suggests that there is a temporary service issue - that will resolve on its own. mock_zmq_create and mock_zmq_load are being - asserted against because this zmq setup happens before registration occurs. + """ + This tests what happens if a 500 error response comes back from the initial + endpoint registration. + + It is expected that this exception should NOT be raised and that the interchange + should be started without any registration info being passed in. The + registration should then be retried in the interchange daemon, because a 5xx + error suggests that there is a temporary service issue that will resolve on its + own. mock_zmq_create and mock_zmq_load are being asserted against because this + zmq setup happens before registration occurs. """ mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") base_r = Response() - base_r.headers = { - "Content-Type": "json" - } + base_r.headers = {"Content-Type": "json"} base_r.status_code = 500 r = GlobusHTTPResponse(base_r) r.status_code = base_r.status_code r.headers = base_r.headers - mock_register_endpoint = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.register_endpoint') + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.register_endpoint" + ) mock_register_endpoint.side_effect = GlobusAPIError(r) - mock_zmq_create = mocker.patch("zmq.auth.create_certificates", - return_value=("public/key/file", None)) - mock_zmq_load = mocker.patch("zmq.auth.load_certificate", - return_value=("12345abcde".encode(), "12345abcde".encode())) + mock_zmq_create = mocker.patch( + "zmq.auth.create_certificates", return_value=("public/key/file", None) + ) + mock_zmq_load = mocker.patch( + "zmq.auth.load_certificate", + return_value=(b"12345abcde", b"12345abcde"), + ) mock_context = mocker.patch("daemon.DaemonContext") @@ -186,29 +218,37 @@ def test_start_registration_5xx_error(self, mocker): mock_context.return_value.__enter__.return_value = None mock_context.return_value.__exit__.return_value = None - mock_context.return_value.pidfile.path = '' + mock_context.return_value.pidfile.path = "" - mock_daemon = mocker.patch.object(EndpointManager, 'daemon_launch', - return_value=None) + mock_daemon = mocker.patch.object( + EndpointManager, "daemon_launch", return_value=None + ) - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 - mock_pidfile = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile') + mock_pidfile = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile" + ) mock_pidfile.return_value = None - mock_results_ack_handler = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler') + mock_results_ack_handler = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler" + ) manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() manager.start_endpoint("mock_endpoint", None, endpoint_config) - mock_zmq_create.assert_called_with(os.path.join(config_dir, "certificates"), 
"endpoint") + mock_zmq_create.assert_called_with( + os.path.join(config_dir, "certificates"), "endpoint" + ) mock_zmq_load.assert_called_with("public/key/file") funcx_client_options = { @@ -216,29 +256,37 @@ def test_start_registration_5xx_error(self, mocker): "check_endpoint_version": True, } - # We should expect reg_info in this test to be None when passed into daemon_launch - # because a 5xx GlobusAPIError was raised during registration + # We should expect reg_info in this test to be None when passed into + # daemon_launch because a 5xx GlobusAPIError was raised during registration reg_info = None - mock_daemon.assert_called_with('123456', - config_dir, - os.path.join(config_dir, "certificates"), - endpoint_config, - reg_info, - funcx_client_options, - mock_results_ack_handler.return_value) - - mock_context.assert_called_with(working_directory=config_dir, - umask=0o002, - pidfile=None, - stdout=ANY, # open(os.path.join(config_dir, './interchange.stdout'), 'w+'), - stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), - detach_process=True) + mock_daemon.assert_called_with( + "123456", + config_dir, + os.path.join(config_dir, "certificates"), + endpoint_config, + reg_info, + funcx_client_options, + mock_results_ack_handler.return_value, + ) + + mock_context.assert_called_with( + working_directory=config_dir, + umask=0o002, + pidfile=None, + stdout=ANY, # open(os.path.join(config_dir, './interchange.stdout'), 'w+'), + stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), + detach_process=True, + ) def test_start_without_executors(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - mock_client.return_value.register_endpoint.return_value = {'endpoint_id': 'abcde12345', - 'address': 'localhost', - 'client_ports': '8080'} + mock_client = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + mock_client.return_value.register_endpoint.return_value = { + "endpoint_id": "abcde12345", + "address": "localhost", + "client_ports": "8080", + } mock_context = mocker.patch("daemon.DaemonContext") @@ -246,101 +294,136 @@ def test_start_without_executors(self, mocker): mock_context.return_value.__enter__.return_value = None mock_context.return_value.__exit__.return_value = None - mock_context.return_value.pidfile.path = '' + mock_context.return_value.pidfile.path = "" - class mock_load(): - class mock_executors(): + class mock_load: + class mock_executors: executors = None + config = mock_executors() manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - with pytest.raises(Exception, match=f'Endpoint config file at {config_dir} is missing executor definitions'): + with pytest.raises( + Exception, + match=f"Endpoint config file at {config_dir} is " + "missing executor definitions", + ): manager.start_endpoint("mock_endpoint", None, mock_load()) def test_daemon_launch(self, mocker): - mock_interchange = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange') + mock_interchange = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange" + ) mock_interchange.return_value.start.return_value = None mock_interchange.return_value.stop.return_value = None manager = EndpointManager(funcx_dir=os.getcwd()) - manager.name = 'test' + manager.name = "test" config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") mock_optionals = {} - 
mock_optionals['logdir'] = config_dir + mock_optionals["logdir"] = config_dir manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() funcx_client_options = {} - manager.daemon_launch('mock_endpoint_uuid', config_dir, 'mock_keys_dir', endpoint_config, None, funcx_client_options, None) - - mock_interchange.assert_called_with(endpoint_config.config, - endpoint_id='mock_endpoint_uuid', - keys_dir='mock_keys_dir', - endpoint_dir=config_dir, - endpoint_name=manager.name, - reg_info=None, - funcx_client_options=funcx_client_options, - results_ack_handler=None, - **mock_optionals) + manager.daemon_launch( + "mock_endpoint_uuid", + config_dir, + "mock_keys_dir", + endpoint_config, + None, + funcx_client_options, + None, + ) + + mock_interchange.assert_called_with( + endpoint_config.config, + endpoint_id="mock_endpoint_uuid", + keys_dir="mock_keys_dir", + endpoint_dir=config_dir, + endpoint_name=manager.name, + reg_info=None, + funcx_client_options=funcx_client_options, + results_ack_handler=None, + **mock_optionals, + ) def test_with_funcx_config(self, mocker): - mock_interchange = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange') + mock_interchange = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange" + ) mock_interchange.return_value.start.return_value = None mock_interchange.return_value.stop.return_value = None mock_optionals = {} - mock_optionals['interchange_address'] = '127.0.0.1' + mock_optionals["interchange_address"] = "127.0.0.1" mock_funcx_config = {} - mock_funcx_config['endpoint_address'] = '127.0.0.1' + mock_funcx_config["endpoint_address"] = "127.0.0.1" manager = EndpointManager(funcx_dir=os.getcwd()) - manager.name = 'test' + manager.name = "test" config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - mock_optionals['logdir'] = config_dir + mock_optionals["logdir"] = config_dir manager.funcx_config = mock_funcx_config manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() funcx_client_options = {} - manager.daemon_launch('mock_endpoint_uuid', config_dir, 'mock_keys_dir', endpoint_config, None, funcx_client_options, None) - - mock_interchange.assert_called_with(endpoint_config.config, - endpoint_id='mock_endpoint_uuid', - keys_dir='mock_keys_dir', - endpoint_dir=config_dir, - endpoint_name=manager.name, - reg_info=None, - funcx_client_options=funcx_client_options, - results_ack_handler=None, - **mock_optionals) + manager.daemon_launch( + "mock_endpoint_uuid", + config_dir, + "mock_keys_dir", + endpoint_config, + None, + funcx_client_options, + None, + ) + + mock_interchange.assert_called_with( + endpoint_config.config, + endpoint_id="mock_endpoint_uuid", + keys_dir="mock_keys_dir", + endpoint_dir=config_dir, + endpoint_name=manager.name, + reg_info=None, + funcx_client_options=funcx_client_options, + results_ack_handler=None, + **mock_optionals, + ) def test_check_endpoint_json_no_json_no_uuid(self, mocker): - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 manager = 
EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - assert '123456' == manager.check_endpoint_json(os.path.join(config_dir, 'endpoint.json'), None) + assert "123456" == manager.check_endpoint_json( + os.path.join(config_dir, "endpoint.json"), None + ) def test_check_endpoint_json_no_json_given_uuid(self, mocker): manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - assert '234567' == manager.check_endpoint_json(os.path.join(config_dir, 'endpoint.json'), '234567') + assert "234567" == manager.check_endpoint_json( + os.path.join(config_dir, "endpoint.json"), "234567" + ) def test_check_endpoint_json_given_json(self, mocker): manager = EndpointManager(funcx_dir=os.getcwd()) @@ -348,8 +431,10 @@ def test_check_endpoint_json_given_json(self, mocker): manager.configure_endpoint("mock_endpoint", None) - mock_dict = {'endpoint_id': 'abcde12345'} - with open(os.path.join(config_dir, 'endpoint.json'), "w") as fd: + mock_dict = {"endpoint_id": "abcde12345"} + with open(os.path.join(config_dir, "endpoint.json"), "w") as fd: json.dump(mock_dict, fd) - assert 'abcde12345' == manager.check_endpoint_json(os.path.join(config_dir, 'endpoint.json'), '234567') + assert "abcde12345" == manager.check_endpoint_json( + os.path.join(config_dir, "endpoint.json"), "234567" + ) diff --git a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py index 0f78c8ab2..2d8f41311 100644 --- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py @@ -1,26 +1,22 @@ -from funcx_endpoint.endpoint.endpoint_manager import EndpointManager -from funcx_endpoint.endpoint.interchange import EndpointInterchange -from funcx_endpoint.endpoint.register_endpoint import register_endpoint -from importlib.machinery import SourceFileLoader -import os import logging -import sys +import os import shutil +from importlib.machinery import SourceFileLoader + import pytest -import json -from pytest import fixture -from unittest.mock import ANY -logger = logging.getLogger('mock_funcx') +from funcx_endpoint.endpoint.endpoint_manager import EndpointManager +from funcx_endpoint.endpoint.interchange import EndpointInterchange + +logger = logging.getLogger("mock_funcx") class TestStart: - @pytest.fixture(autouse=True) def test_setup_teardown(self): # Code that will run before your test, for example: - funcx_dir = f'{os.getcwd()}' + funcx_dir = f"{os.getcwd()}" config_dir = os.path.join(funcx_dir, "mock_endpoint") assert not os.path.exists(config_dir) # A test function will be run at this point @@ -34,62 +30,72 @@ def test_endpoint_id(self, mocker): manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - keys_dir = os.path.join(config_dir, 'certificates') + keys_dir = os.path.join(config_dir, "certificates") optionals = {} - optionals['client_address'] = '127.0.0.1' - optionals['client_ports'] = (8080, 8081, 8082) - optionals['logdir'] = './mock_endpoint' + optionals["client_address"] = "127.0.0.1" + optionals["client_ports"] = (8080, 8081, 8082) + optionals["logdir"] = "./mock_endpoint" manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + 
endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() for executor in endpoint_config.config.executors: executor.passthrough = False - ic = EndpointInterchange(endpoint_config.config, - endpoint_id='mock_endpoint_id', - keys_dir=keys_dir, - **optionals) + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id="mock_endpoint_id", + keys_dir=keys_dir, + **optionals, + ) for executor in ic.executors.values(): - assert executor.endpoint_id == 'mock_endpoint_id' + assert executor.endpoint_id == "mock_endpoint_id" def test_register_endpoint(self, mocker): mock_client = mocker.patch("funcx_endpoint.endpoint.interchange.FuncXClient") mock_client.return_value = None - mock_register_endpoint = mocker.patch('funcx_endpoint.endpoint.interchange.register_endpoint') - mock_register_endpoint.return_value = {'endpoint_id': 'abcde12345', - 'public_ip': '127.0.0.1', - 'tasks_port': 8080, - 'results_port': 8081, - 'commands_port': 8082, } + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.interchange.register_endpoint" + ) + mock_register_endpoint.return_value = { + "endpoint_id": "abcde12345", + "public_ip": "127.0.0.1", + "tasks_port": 8080, + "results_port": 8081, + "commands_port": 8082, + } manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - keys_dir = os.path.join(config_dir, 'certificates') + keys_dir = os.path.join(config_dir, "certificates") optionals = {} - optionals['client_address'] = '127.0.0.1' - optionals['client_ports'] = (8080, 8081, 8082) - optionals['logdir'] = './mock_endpoint' + optionals["client_address"] = "127.0.0.1" + optionals["client_ports"] = (8080, 8081, 8082) + optionals["logdir"] = "./mock_endpoint" manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() for executor in endpoint_config.config.executors: executor.passthrough = False - ic = EndpointInterchange(endpoint_config.config, - endpoint_id='mock_endpoint_id', - keys_dir=keys_dir, - **optionals) + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id="mock_endpoint_id", + keys_dir=keys_dir, + **optionals, + ) ic.register_endpoint() - assert ic.client_address == '127.0.0.1' + assert ic.client_address == "127.0.0.1" assert ic.client_ports == (8080, 8081, 8082) def test_start_no_reg_info(self, mocker): @@ -100,38 +106,47 @@ def test_start_no_reg_info(self, mocker): mock_client = mocker.patch("funcx_endpoint.endpoint.interchange.FuncXClient") mock_client.return_value = None - mock_register_endpoint = mocker.patch('funcx_endpoint.endpoint.interchange.register_endpoint') - mock_register_endpoint.return_value = {'endpoint_id': 'abcde12345', - 'public_ip': '127.0.0.1', - 'tasks_port': 8080, - 'results_port': 8081, - 'commands_port': 8082, } + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.interchange.register_endpoint" + ) + mock_register_endpoint.return_value = { + "endpoint_id": "abcde12345", + "public_ip": "127.0.0.1", + "tasks_port": 8080, + "results_port": 8081, + "commands_port": 8082, + } manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - keys_dir = os.path.join(config_dir, 'certificates') + keys_dir = os.path.join(config_dir, "certificates") optionals = {} - optionals['client_address'] = 
'127.0.0.1' - optionals['client_ports'] = (8080, 8081, 8082) - optionals['logdir'] = './mock_endpoint' + optionals["client_address"] = "127.0.0.1" + optionals["client_ports"] = (8080, 8081, 8082) + optionals["logdir"] = "./mock_endpoint" manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() for executor in endpoint_config.config.executors: executor.passthrough = False - mock_quiesce = mocker.patch.object(EndpointInterchange, 'quiesce', - return_value=None) - mock_main_loop = mocker.patch.object(EndpointInterchange, '_main_loop', - return_value=None) - - ic = EndpointInterchange(endpoint_config.config, - endpoint_id='mock_endpoint_id', - keys_dir=keys_dir, - **optionals) + mock_quiesce = mocker.patch.object( + EndpointInterchange, "quiesce", return_value=None + ) + mock_main_loop = mocker.patch.object( + EndpointInterchange, "_main_loop", return_value=None + ) + + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id="mock_endpoint_id", + keys_dir=keys_dir, + **optionals, + ) ic.results_outgoing = mocker.Mock() diff --git a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py index 39cedb217..84386a2eb 100644 --- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py @@ -1,23 +1,20 @@ -from funcx_endpoint.endpoint.register_endpoint import register_endpoint -import os import logging -import sys +import os import shutil + import pytest -import json -from pytest import fixture -from unittest.mock import ANY -logger = logging.getLogger('mock_funcx') +from funcx_endpoint.endpoint.register_endpoint import register_endpoint +logger = logging.getLogger("mock_funcx") -class TestRegisterEndpoint: +class TestRegisterEndpoint: @pytest.fixture(autouse=True) def test_setup_teardown(self): # Code that will run before your test, for example: - funcx_dir = f'{os.getcwd()}' + funcx_dir = f"{os.getcwd()}" config_dir = os.path.join(funcx_dir, "mock_endpoint") assert not os.path.exists(config_dir) # A test function will be run at this point @@ -27,22 +24,34 @@ def test_setup_teardown(self): shutil.rmtree(config_dir) def test_register_endpoint_no_endpoint_id(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - mock_client.return_value.register_endpoint.return_value = {'status': 'okay'} + mock_client = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + mock_client.return_value.register_endpoint.return_value = {"status": "okay"} funcx_dir = os.getcwd() config_dir = os.path.join(funcx_dir, "mock_endpoint") - with pytest.raises(Exception, match='Endpoint ID was not included in the service\'s registration response.'): - register_endpoint(mock_client(), 'mock_endpoint_uuid', config_dir, 'test') + with pytest.raises( + Exception, + match="Endpoint ID was not included in the service's " + "registration response.", + ): + register_endpoint(mock_client(), "mock_endpoint_uuid", config_dir, "test") def test_register_endpoint_int_endpoint_id(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - mock_client.return_value.register_endpoint.return_value = {'status': 'okay', - 'endpoint_id': 123456} + mock_client = 
mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + mock_client.return_value.register_endpoint.return_value = { + "status": "okay", + "endpoint_id": 123456, + } funcx_dir = os.getcwd() config_dir = os.path.join(funcx_dir, "mock_endpoint") - with pytest.raises(Exception, match='Endpoint ID sent by the service was not a string.'): - register_endpoint(mock_client(), 'mock_endpoint_uuid', config_dir, 'test') + with pytest.raises( + Exception, match="Endpoint ID sent by the service was not a string." + ): + register_endpoint(mock_client(), "mock_endpoint_uuid", config_dir, "test") diff --git a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py index 7b59a77cd..7e42a2062 100644 --- a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py +++ b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py @@ -1,27 +1,28 @@ -from funcx_endpoint.executors.high_throughput.funcx_manager import Manager -from funcx_endpoint.executors.high_throughput.messages import Task -import queue -import logging -import pickle -import zmq import os +import pickle +import queue import shutil + import pytest +from funcx_endpoint.executors.high_throughput.funcx_manager import Manager +from funcx_endpoint.executors.high_throughput.messages import Task -class TestManager: +class TestManager: @pytest.fixture(autouse=True) def test_setup_teardown(self): - os.makedirs(os.path.join(os.getcwd(), 'mock_uid')) + os.makedirs(os.path.join(os.getcwd(), "mock_uid")) yield - shutil.rmtree(os.path.join(os.getcwd(), 'mock_uid')) + shutil.rmtree(os.path.join(os.getcwd(), "mock_uid")) def test_remove_worker_init(self, mocker): # zmq is being mocked here because it was making tests hang - mocker.patch('funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context') + mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context" + ) - manager = Manager(logdir='./', uid="mock_uid") + manager = Manager(logdir="./", uid="mock_uid") manager.worker_map.to_die_count["RAW"] = 0 manager.task_queues["RAW"] = queue.Queue() @@ -33,21 +34,33 @@ def test_remove_worker_init(self, mocker): def test_poll_funcx_task_socket(self, mocker): # zmq is being mocked here because it was making tests hang - mocker.patch('funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context') + mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context" + ) - mock_worker_map = mocker.patch('funcx_endpoint.executors.high_throughput.funcx_manager.WorkerMap') + mock_worker_map = mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_manager.WorkerMap" + ) - manager = Manager(logdir='./', uid="mock_uid") + manager = Manager(logdir="./", uid="mock_uid") manager.task_queues["RAW"] = queue.Queue() manager.logdir = "./" - manager.worker_type = 'RAW' - manager.worker_procs['0'] = 'proc' + manager.worker_type = "RAW" + manager.worker_procs["0"] = "proc" - manager.funcx_task_socket.recv_multipart.return_value = b'0', b'REGISTER', pickle.dumps({'worker_type': 'RAW'}) + manager.funcx_task_socket.recv_multipart.return_value = ( + b"0", + b"REGISTER", + pickle.dumps({"worker_type": "RAW"}), + ) manager.poll_funcx_task_socket(test=True) - mock_worker_map.return_value.register_worker.assert_called_with(b'0', 'RAW') + mock_worker_map.return_value.register_worker.assert_called_with(b"0", "RAW") - 
manager.funcx_task_socket.recv_multipart.return_value = b'0', b'WRKR_DIE', pickle.dumps(None) + manager.funcx_task_socket.recv_multipart.return_value = ( + b"0", + b"WRKR_DIE", + pickle.dumps(None), + ) manager.poll_funcx_task_socket(test=True) - mock_worker_map.return_value.remove_worker.assert_called_with(b'0') + mock_worker_map.return_value.remove_worker.assert_called_with(b"0") assert len(manager.worker_procs) == 0 diff --git a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py index 87e6f17e0..2f562b872 100644 --- a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py +++ b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py @@ -1,32 +1,37 @@ -from funcx_endpoint.executors.high_throughput.funcx_worker import FuncXWorker -from funcx_endpoint.executors.high_throughput.messages import Task import os import pickle +from funcx_endpoint.executors.high_throughput.funcx_worker import FuncXWorker +from funcx_endpoint.executors.high_throughput.messages import Task + class TestWorker: def test_register_and_kill(self, mocker): # we need to mock sys.exit here so that the worker while loop # can exit without the test being killed - mocker.patch('funcx_endpoint.executors.high_throughput.funcx_worker.sys.exit') + mocker.patch("funcx_endpoint.executors.high_throughput.funcx_worker.sys.exit") - mock_context = mocker.patch('funcx_endpoint.executors.high_throughput.funcx_worker.zmq.Context') + mock_context = mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_worker.zmq.Context" + ) # the worker will receive tasks and send messages on this mock socket mock_socket = mocker.Mock() mock_context.return_value.socket.return_value = mock_socket # send a kill message on the mock socket - task = Task(task_id='KILL', - container_id='RAW', - task_buffer='KILL') - mock_socket.recv_multipart.return_value = (pickle.dumps("KILL"), pickle.dumps("abc"), task.pack()) + task = Task(task_id="KILL", container_id="RAW", task_buffer="KILL") + mock_socket.recv_multipart.return_value = ( + pickle.dumps("KILL"), + pickle.dumps("abc"), + task.pack(), + ) # calling worker.start begins a while loop, where first a REGISTER # message is sent out, then the worker receives the KILL task, which # triggers a WRKR_DIE message to be sent before the while loop exits - worker = FuncXWorker('0', '127.0.0.1', 50001, os.getcwd()) + worker = FuncXWorker("0", "127.0.0.1", 50001, os.getcwd()) worker.start() # these 2 calls to send_multipart happen in a sequence - call1 = mocker.call([b'REGISTER', pickle.dumps(worker.registration_message())]) - call2 = mocker.call([b'WRKR_DIE', pickle.dumps(None)]) + call1 = mocker.call([b"REGISTER", pickle.dumps(worker.registration_message())]) + call2 = mocker.call([b"WRKR_DIE", pickle.dumps(None)]) mock_socket.send_multipart.assert_has_calls([call1, call2]) diff --git a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py index 8712d3f49..53e9c2d2a 100644 --- a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py +++ b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py @@ -1,21 +1,26 @@ -from funcx_endpoint.executors.high_throughput.worker_map import WorkerMap import logging import os +from funcx_endpoint.executors.high_throughput.worker_map import WorkerMap + 
class TestWorkerMap: def test_add_worker(self, mocker): - mock_popen = mocker.patch('funcx_endpoint.executors.high_throughput.worker_map.subprocess.Popen') - mock_popen.return_value = 'proc' + mock_popen = mocker.patch( + "funcx_endpoint.executors.high_throughput.worker_map.subprocess.Popen" + ) + mock_popen.return_value = "proc" worker_map = WorkerMap(1) - worker = worker_map.add_worker(worker_id='0', - address='127.0.0.1', - debug=logging.DEBUG, - uid='test1', - logdir=os.getcwd(), - worker_port=50001) + worker = worker_map.add_worker( + worker_id="0", + address="127.0.0.1", + debug=logging.DEBUG, + uid="test1", + logdir=os.getcwd(), + worker_port=50001, + ) - assert list(worker.keys()) == ['0'] - assert worker['0'] == 'proc' + assert list(worker.keys()) == ["0"] + assert worker["0"] == "proc" assert worker_map.worker_id_counter == 1 diff --git a/funcx_endpoint/tests/integration/test_batch_submit.py b/funcx_endpoint/tests/integration/test_batch_submit.py index 6098a4986..c49002991 100644 --- a/funcx_endpoint/tests/integration/test_batch_submit.py +++ b/funcx_endpoint/tests/integration/test_batch_submit.py @@ -1,10 +1,10 @@ -import json -import sys import argparse import time + import funcx from funcx.sdk.client import FuncXClient from funcx.serialize import FuncXSerializer + fxs = FuncXSerializer() # funcx.set_stream_logger() @@ -16,35 +16,38 @@ def double(x): def test(fxc, ep_id, task_count=10): - fn_uuid = fxc.register_function(double, - description="Yadu double") + fn_uuid = fxc.register_function(double, description="Yadu double") print("FN_UUID : ", fn_uuid) start = time.time() - task_ids = fxc.map_run(list(range(task_count)), endpoint_id=ep_id, function_id=fn_uuid) + task_ids = fxc.map_run( + list(range(task_count)), endpoint_id=ep_id, function_id=fn_uuid + ) delta = time.time() - start - print("Time to launch {} tasks: {:8.3f} s".format(task_count, delta)) - print("Got {} tasks_ids ".format(len(task_ids))) + print(f"Time to launch {task_count} tasks: {delta:8.3f} s") + print(f"Got {len(task_ids)} tasks_ids ") for _i in range(3): x = fxc.get_batch_result(task_ids) - complete_count = sum([1 for t in task_ids if t in x and x[t].get('pending', False)]) - print("Batch status : {}/{} complete".format(complete_count, len(task_ids))) + complete_count = sum( + [1 for t in task_ids if t in x and x[t].get("pending", False)] + ) + print(f"Batch status : {complete_count}/{len(task_ids)} complete") if complete_count == len(task_ids): break time.sleep(2) delta = time.time() - start - print("Time to complete {} tasks: {:8.3f} s".format(task_count, delta)) - print("Throughput : {:8.3f} Tasks/s".format(task_count / delta)) + print(f"Time to complete {task_count} tasks: {delta:8.3f} s") + print(f"Throughput : {task_count / delta:8.3f} Tasks/s") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) parser.add_argument("-c", "--count", default="10") args = parser.parse_args() print("FuncX version : ", funcx.__version__) - fxc = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1') + fxc = FuncXClient(funcx_service_address="https://dev.funcx.org/api/v1") test(fxc, args.endpoint, task_count=int(args.count)) diff --git a/funcx_endpoint/tests/integration/test_config.py b/funcx_endpoint/tests/integration/test_config.py index 80f3c6a42..47ed5c71f 100644 --- a/funcx_endpoint/tests/integration/test_config.py +++ b/funcx_endpoint/tests/integration/test_config.py @@ -1,12 +1,13 @@ -from 
funcx_endpoint.endpoint.utils.config import Config +import logging import os + import funcx -import logging +from funcx_endpoint.endpoint.utils.config import Config config = Config() -if __name__ == '__main__': +if __name__ == "__main__": funcx.set_stream_logger() logger = logging.getLogger(__file__) @@ -18,43 +19,50 @@ print("Loading : ", config) # Set script dir config.provider.script_dir = working_dir - config.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts') + config.provider.channel.script_dir = os.path.join(working_dir, "submit_scripts") config.provider.channel.makedirs(config.provider.channel.script_dir, exist_ok=True) os.makedirs(config.provider.script_dir, exist_ok=True) debug_opts = "--debug" if config.worker_debug else "" - max_workers = "" if config.max_workers_per_node == float('inf') \ - else "--max_workers={}".format(config.max_workers_per_node) + max_workers = ( + "" + if config.max_workers_per_node == float("inf") + else f"--max_workers={config.max_workers_per_node}" + ) worker_task_url = "tcp://127.0.0.1:54400" worker_result_url = "tcp://127.0.0.1:54401" - launch_cmd = ("funcx-worker {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--mode={worker_mode} " - "--container_image={container_image} ") - - l_cmd = launch_cmd.format(debug=debug_opts, - max_workers=max_workers, - cores_per_worker=config.cores_per_worker, - prefetch_capacity=config.prefetch_capacity, - task_url=worker_task_url, - result_url=worker_result_url, - nodes_per_block=config.provider.nodes_per_block, - heartbeat_period=config.heartbeat_period, - heartbeat_threshold=config.heartbeat_threshold, - poll_period=config.poll_period, - worker_mode=config.worker_mode, - container_image=None, - logdir=working_dir) + launch_cmd = ( + "funcx-worker {debug} {max_workers} " + "-c {cores_per_worker} " + "--poll {poll_period} " + "--task_url={task_url} " + "--result_url={result_url} " + "--logdir={logdir} " + "--hb_period={heartbeat_period} " + "--hb_threshold={heartbeat_threshold} " + "--mode={worker_mode} " + "--container_image={container_image} " + ) + + l_cmd = launch_cmd.format( + debug=debug_opts, + max_workers=max_workers, + cores_per_worker=config.cores_per_worker, + prefetch_capacity=config.prefetch_capacity, + task_url=worker_task_url, + result_url=worker_result_url, + nodes_per_block=config.provider.nodes_per_block, + heartbeat_period=config.heartbeat_period, + heartbeat_threshold=config.heartbeat_threshold, + poll_period=config.poll_period, + worker_mode=config.worker_mode, + container_image=None, + logdir=working_dir, + ) config.launch_cmd = l_cmd - print("Launch command: {}".format(config.launch_cmd)) + print(f"Launch command: {config.launch_cmd}") if config.scaling_enabled: print("About to scale things") diff --git a/funcx_endpoint/tests/integration/test_containers.py b/funcx_endpoint/tests/integration/test_containers.py index a1bf16da9..0ee5cf4be 100644 --- a/funcx_endpoint/tests/integration/test_containers.py +++ b/funcx_endpoint/tests/integration/test_containers.py @@ -1,7 +1,5 @@ -import json -import sys import argparse -import time + import funcx from funcx.sdk.client import FuncXClient @@ -12,9 +10,11 @@ def container_sum(event): def test(fxc, ep_id): - fn_uuid = fxc.register_function(container_sum, - container_uuid='3861862b-152e-49a4-b15e-9a5da4205cad', - description="New sum function 
defined without string spec") + fn_uuid = fxc.register_function( + container_sum, + container_uuid="3861862b-152e-49a4-b15e-9a5da4205cad", + description="New sum function defined without string spec", + ) print("FN_UUID : ", fn_uuid) task_id = fxc.run([1, 2, 3, 9001], endpoint_id=ep_id, function_id=fn_uuid) @@ -22,7 +22,7 @@ def test(fxc, ep_id): print("Got from status :", r) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) args = parser.parse_args() diff --git a/funcx_endpoint/tests/integration/test_deserialization.py b/funcx_endpoint/tests/integration/test_deserialization.py index 36edb91f0..3e9c9b382 100644 --- a/funcx_endpoint/tests/integration/test_deserialization.py +++ b/funcx_endpoint/tests/integration/test_deserialization.py @@ -1,6 +1,7 @@ -from funcx.serialize import FuncXSerializer import numpy as np +from funcx.serialize import FuncXSerializer + def double(x, y=3): return x * y diff --git a/funcx_endpoint/tests/integration/test_executor.py b/funcx_endpoint/tests/integration/test_executor.py index 59686d2d0..0fe63d01e 100644 --- a/funcx_endpoint/tests/integration/test_executor.py +++ b/funcx_endpoint/tests/integration/test_executor.py @@ -1,6 +1,4 @@ from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor -import logging -from funcx import set_file_logger def double(x): diff --git a/funcx_endpoint/tests/integration/test_executor_passthrough.py b/funcx_endpoint/tests/integration/test_executor_passthrough.py index 842f5e51f..b7f0a1bc2 100644 --- a/funcx_endpoint/tests/integration/test_executor_passthrough.py +++ b/funcx_endpoint/tests/integration/test_executor_passthrough.py @@ -1,13 +1,12 @@ -from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor -import logging -from funcx import set_file_logger -import uuid -from funcx.serialize import FuncXSerializer -from funcx_endpoint.executors.high_throughput.messages import Message, Task -import time import pickle +import time +import uuid from multiprocessing import Queue +from funcx.serialize import FuncXSerializer +from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor +from funcx_endpoint.executors.high_throughput.messages import Task + def double(x): return x * 2 @@ -17,8 +16,7 @@ def double(x): results_queue = Queue() # set_file_logger('executor.log', name='funcx_endpoint', level=logging.DEBUG) - htex = HighThroughputExecutor(interchange_local=True, - passthrough=True) + htex = HighThroughputExecutor(interchange_local=True, passthrough=True) htex.start(results_passthrough=results_queue) htex._start_remote_interchange_process() @@ -31,12 +29,11 @@ def double(x): fn_code = fx_serializer.serialize(double) ser_code = fx_serializer.pack_buffers([fn_code]) - ser_params = fx_serializer.pack_buffers([fx_serializer.serialize(args), - fx_serializer.serialize(kwargs)]) + ser_params = fx_serializer.pack_buffers( + [fx_serializer.serialize(args), fx_serializer.serialize(kwargs)] + ) - payload = Task(task_id, - 'RAW', - ser_code + ser_params) + payload = Task(task_id, "RAW", ser_code + ser_params) f = htex.submit_raw(payload.pack()) time.sleep(0.5) @@ -44,7 +41,7 @@ def double(x): result_package = results_queue.get() # print("Result package : ", result_package) r = pickle.loads(result_package) - result = fx_serializer.deserialize(r['result']) + result = fx_serializer.deserialize(r["result"]) print(f"Result:{i}: {result}") print("All done") diff --git 
a/funcx_endpoint/tests/integration/test_interchange.py b/funcx_endpoint/tests/integration/test_interchange.py index 6a697026d..967510b6e 100644 --- a/funcx_endpoint/tests/integration/test_interchange.py +++ b/funcx_endpoint/tests/integration/test_interchange.py @@ -1,22 +1,21 @@ import argparse -from funcx_endpoint.endpoint.utils.config import Config -from funcx_endpoint.executors.high_throughput.interchange import Interchange import funcx +from funcx_endpoint.endpoint.utils.config import Config +from funcx_endpoint.executors.high_throughput.interchange import Interchange funcx.set_stream_logger() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-a", "--address", required=True, - help="Address") - parser.add_argument("-c", "--client_ports", required=True, - help="ports") + parser.add_argument("-a", "--address", required=True, help="Address") + parser.add_argument("-c", "--client_ports", required=True, help="ports") args = parser.parse_args() config = Config() - ic = Interchange(client_address=args.address, - client_ports=[int(i) for i in args.client_ports.split(',')], - ) + ic = Interchange( + client_address=args.address, + client_ports=[int(i) for i in args.client_ports.split(",")], + ) ic.start() print("Interchange started") diff --git a/funcx_endpoint/tests/integration/test_per_func_batch.py b/funcx_endpoint/tests/integration/test_per_func_batch.py index 3103ba1e6..774696a8b 100644 --- a/funcx_endpoint/tests/integration/test_per_func_batch.py +++ b/funcx_endpoint/tests/integration/test_per_func_batch.py @@ -35,13 +35,13 @@ def test_batch3(a, b, c=2, d=2): task_ids = fx.batch_run(batch) delta = time.time() - start -print("Time to launch {} tasks: {:8.3f} s".format(task_count * len(func_ids), delta)) -print("Got {} tasks_ids ".format(len(task_ids))) +print(f"Time to launch {task_count * len(func_ids)} tasks: {delta:8.3f} s") +print(f"Got {len(task_ids)} tasks_ids ") for _i in range(10): x = fx.get_batch_result(task_ids) complete_count = sum([1 for t in task_ids if t in x and x[t].get("pending", False)]) - print("Batch status : {}/{} complete".format(complete_count, len(task_ids))) + print(f"Batch status : {complete_count}/{len(task_ids)} complete") if complete_count == len(task_ids): print(x) break diff --git a/funcx_endpoint/tests/integration/test_redis.py b/funcx_endpoint/tests/integration/test_redis.py deleted file mode 100644 index b17645259..000000000 --- a/funcx_endpoint/tests/integration/test_redis.py +++ /dev/null @@ -1,70 +0,0 @@ -import argparse -from funcx.serialize import FuncXSerializer -from funcx_endpoint.queues import RedisQueue -import time - - -def slow_double(i, duration=0): - import time - time.sleep(duration) - return i * 4 - - -def test(endpoint_id=None, tasks=10, duration=1, hostname=None, port=None): - tasks_rq = RedisQueue(f'task_{endpoint_id}', hostname) - results_rq = RedisQueue('results', hostname) - fxs = FuncXSerializer() - - ser_code = fxs.serialize(slow_double) - fn_code = fxs.pack_buffers([ser_code]) - - tasks_rq.connect() - results_rq.connect() - - while True: - try: - _ = results_rq.get(timeout=1) - except Exception: - print("No more results left") - break - - start = time.time() - for i in range(tasks): - ser_args = fxs.serialize([i]) - ser_kwargs = fxs.serialize({'duration': duration}) - input_data = fxs.pack_buffers([ser_args, ser_kwargs]) - payload = fn_code + input_data - container_id = "odd" if i % 2 else "even" - tasks_rq.put(f"0{i};{container_id}", payload) - - d1 = time.time() - 
start - print("Time to launch {} tasks: {:8.3f} s".format(tasks, d1)) - - print(f"Launched {tasks} tasks") - for _i in range(tasks): - _ = results_rq.get(timeout=300) - # print("Result : ", res) - - delta = time.time() - start - print("Time to complete {} tasks: {:8.3f} s".format(tasks, delta)) - print("Throughput : {:8.3f} Tasks/s".format(tasks / delta)) - return delta - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("-r", "--redis_hostname", required=True, - help="Hostname of the Redis server") - parser.add_argument("-e", "--endpoint_id", required=True, - help="Endpoint_id") - parser.add_argument("-d", "--duration", required=True, - help="Duration of the tasks") - parser.add_argument("-c", "--count", required=True, - help="Number of tasks") - - args = parser.parse_args() - - test(endpoint_id=args.endpoint_id, - hostname=args.redis_hostname, - duration=int(args.duration), - tasks=int(args.count)) diff --git a/funcx_endpoint/tests/integration/test_registration.py b/funcx_endpoint/tests/integration/test_registration.py index 1d69c4d72..e12f750e6 100644 --- a/funcx_endpoint/tests/integration/test_registration.py +++ b/funcx_endpoint/tests/integration/test_registration.py @@ -1,9 +1,8 @@ from funcx.sdk.client import FuncXClient - if __name__ == "__main__": fxc = FuncXClient() print(fxc) - fxc.register_endpoint('foobar', None) + fxc.register_endpoint("foobar", None) diff --git a/funcx_endpoint/tests/integration/test_serialization.py b/funcx_endpoint/tests/integration/test_serialization.py index b93dc30c9..bf2101161 100644 --- a/funcx_endpoint/tests/integration/test_serialization.py +++ b/funcx_endpoint/tests/integration/test_serialization.py @@ -8,7 +8,7 @@ def foo(x, y=3): def test_1(): jb = concretes.json_base64() - d = jb.serialize(([2], {'y': 10})) + d = jb.serialize(([2], {"y": 10})) args, kwargs = jb.deserialize(d) result = foo(*args, **kwargs) print(result) @@ -22,7 +22,7 @@ def test_2(): fn = jb.deserialize(f) print(fn) - assert fn(2) == 6, "Expected 6 got {}".format(fn(2)) + assert fn(2) == 6, f"Expected 6 got {fn(2)}" def test_code_1(): @@ -79,6 +79,7 @@ def bar(x, y=5): def test_overall(): from funcx.serialize.facade import FuncXSerializer + fxs = FuncXSerializer() print(fxs._list_methods()) @@ -87,7 +88,7 @@ def test_overall(): print(fxs.deserialize(x)) -if __name__ == '__main__': +if __name__ == "__main__": # test_1() # test_2() diff --git a/funcx_endpoint/tests/integration/test_status.py b/funcx_endpoint/tests/integration/test_status.py index 441665426..d406009d5 100644 --- a/funcx_endpoint/tests/integration/test_status.py +++ b/funcx_endpoint/tests/integration/test_status.py @@ -23,8 +23,9 @@ def sum_yadu_new01(event): def test(fxc, ep_id): - fn_uuid = fxc.register_function(sum_yadu_new01, - description="New sum function defined without string spec") + fn_uuid = fxc.register_function( + sum_yadu_new01, description="New sum function defined without string spec" + ) print("FN_UUID : ", fn_uuid) task_id = fxc.run([1, 2, 3, 9001], endpoint_id=ep_id, function_id=fn_uuid) @@ -34,6 +35,7 @@ def test(fxc, ep_id): def platinfo(): import platform + return platform.uname() @@ -42,23 +44,21 @@ def div_by_zero(x): def test2(fxc, ep_id): - fn_uuid = fxc.register_function(platinfo, - description="Get platform info") + fn_uuid = fxc.register_function(platinfo, description="Get platform info") print("FN_UUID : ", fn_uuid) task_id = fxc.run(endpoint_id=ep_id, function_id=fn_uuid) time.sleep(2) r = fxc.get_task_status(task_id) - if 'details' in r: - s_buf 
= r['details']['result'] + if "details" in r: + s_buf = r["details"]["result"] print("Result : ", fxs.deserialize(s_buf)) else: print("Got from status :", r) def test3(fxc, ep_id): - fn_uuid = fxc.register_function(div_by_zero, - description="Div by zero") + fn_uuid = fxc.register_function(div_by_zero, description="Div by zero") print("FN_UUID : ", fn_uuid) task_id = fxc.run(1099, endpoint_id=ep_id, function_id=fn_uuid) @@ -70,8 +70,7 @@ def test3(fxc, ep_id): def test4(fxc, ep_id): - fn_uuid = fxc.register_function(platinfo, - description="Get platform info") + fn_uuid = fxc.register_function(platinfo, description="Get platform info") print("FN_UUID : ", fn_uuid) task_id = fxc.run(endpoint_id=ep_id, function_id=fn_uuid) @@ -80,7 +79,7 @@ def test4(fxc, ep_id): print("Got result : ", r) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) args = parser.parse_args() diff --git a/funcx_endpoint/tests/integration/test_submits.py b/funcx_endpoint/tests/integration/test_submits.py index 9b7eb45d3..db1f67f5a 100644 --- a/funcx_endpoint/tests/integration/test_submits.py +++ b/funcx_endpoint/tests/integration/test_submits.py @@ -1,5 +1,3 @@ -import json -import sys import argparse from funcx.sdk.client import FuncXClient @@ -29,16 +27,18 @@ def sum_yadu_new01(event): def test(fxc, ep_id): - fn_uuid = fxc.register_function(sum_yadu_new01, - ep_id, # TODO: We do not need ep id here - description="New sum function defined without string spec") + fn_uuid = fxc.register_function( + sum_yadu_new01, + ep_id, # TODO: We do not need ep id here + description="New sum function defined without string spec", + ) print("FN_UUID : ", fn_uuid) res = fxc.run([1, 2, 3, 99], endpoint_id=ep_id, function_id=fn_uuid) print(res) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) args = parser.parse_args() diff --git a/funcx_endpoint/tests/integration/test_throttling.py b/funcx_endpoint/tests/integration/test_throttling.py index 4729a637c..96dc5d36a 100644 --- a/funcx_endpoint/tests/integration/test_throttling.py +++ b/funcx_endpoint/tests/integration/test_throttling.py @@ -6,44 +6,47 @@ pytest test_throttling.py """ -import pytest -import globus_sdk from unittest.mock import Mock -from funcx.sdk.utils.throttling import (ThrottledBaseClient, - MaxRequestSizeExceeded, - MaxRequestsExceeded) +import globus_sdk +import pytest + +from funcx.sdk.utils.throttling import ( + MaxRequestsExceeded, + MaxRequestSizeExceeded, + ThrottledBaseClient, +) @pytest.fixture def mock_globus_sdk(monkeypatch): - monkeypatch.setattr(globus_sdk.base.BaseClient, '__init__', Mock()) + monkeypatch.setattr(globus_sdk.base.BaseClient, "__init__", Mock()) def test_size_throttling_on_small_requests(mock_globus_sdk): cli = ThrottledBaseClient() # Should not raise - jb = {'not': 'big enough'} - cli.throttle_request_size('POST', '/my_rest_endpoint', json_body=jb) + jb = {"not": "big enough"} + cli.throttle_request_size("POST", "/my_rest_endpoint", json_body=jb) # Should not raise for these methods - cli.throttle_request_size('GET', '/my_rest_endpoint') - cli.throttle_request_size('PUT', '/my_rest_endpoint') - cli.throttle_request_size('DELETE', '/my_rest_endpoint') + cli.throttle_request_size("GET", "/my_rest_endpoint") + cli.throttle_request_size("PUT", "/my_rest_endpoint") + cli.throttle_request_size("DELETE", "/my_rest_endpoint") def 
test_size_throttle_on_large_request(mock_globus_sdk): cli = ThrottledBaseClient() # Test with ~2mb sized POST - jb = {'is': 'l' + 'o' * 2 * 2 ** 20 + 'ng'} + jb = {"is": "l" + "o" * 2 * 2 ** 20 + "ng"} with pytest.raises(MaxRequestSizeExceeded): - cli.throttle_request_size('POST', '/my_rest_endpoint', json_body=jb) + cli.throttle_request_size("POST", "/my_rest_endpoint", json_body=jb) # Test on text request - data = 'B' + 'i' * 2 * 2 ** 20 + 'gly' + data = "B" + "i" * 2 * 2 ** 20 + "gly" with pytest.raises(MaxRequestSizeExceeded): - cli.throttle_request_size('POST', '/my_rest_endpoint', text_body=data) + cli.throttle_request_size("POST", "/my_rest_endpoint", text_body=data) def test_low_threshold_requests_does_not_raise(mock_globus_sdk): diff --git a/funcx_endpoint/tests/smoke_tests/conftest.py b/funcx_endpoint/tests/smoke_tests/conftest.py index 981fa78bd..af87f1145 100644 --- a/funcx_endpoint/tests/smoke_tests/conftest.py +++ b/funcx_endpoint/tests/smoke_tests/conftest.py @@ -1,10 +1,12 @@ +import collections import json import os +import time import pytest -from globus_sdk import ConfidentialAppAuthClient, AccessTokenAuthorizer +from globus_sdk import AccessTokenAuthorizer, ConfidentialAppAuthClient + from funcx import FuncXClient -from funcx.sdk.executor import FuncXExecutor # the non-tutorial endpoint will be required, with the following priority order for # finding the ID: @@ -18,13 +20,25 @@ _LOCAL_ENDPOINT_ID = os.getenv("FUNCX_LOCAL_ENDPOINT_ID") _CONFIGS = { + "dev": { + "client_args": { + "funcx_service_address": "https://api.dev.funcx.org/v2", + "results_ws_uri": "wss://api.dev.funcx.org/ws/v2/", + }, + # assert versions are as expected on dev + "forwarder_min_version": "0.3.5", + "api_min_version": "0.3.5", + # This fn is public and searchable + "public_hello_fn_uuid": "f84351f9-6f82-45d8-8eca-80d8f73645be", + "endpoint_uuid": _LOCAL_ENDPOINT_ID, + }, "prod": { # By default tests are against production, which means we do not need to pass # any arguments to the client object (default will point at prod stack) "client_args": {}, # assert versions are as expected on prod - "forwarder_version": "0.3.3", - "api_version": "0.3.3", + "forwarder_min_version": "0.3.5", + "api_min_version": "0.3.5", # This fn is public and searchable "public_hello_fn_uuid": "b0a5d1a0-2b22-4381-b899-ba73321e41e0", # For production tests, the target endpoint should be the tutorial_endpoint @@ -80,13 +94,13 @@ def pytest_addoption(parser): "--api-client-id", metavar="api-client-id", default=None, - help="The API Client ID. Used for github actions testing" + help="The API Client ID. Used for github actions testing", ) parser.addoption( "--api-client-secret", metavar="api-client-secret", default=None, - help="The API Client Secret. Used for github actions testing" + help="The API Client Secret. 
Used for github actions testing", ) @@ -124,14 +138,22 @@ def funcx_test_config(pytestconfig, funcx_test_config_name): if api_client_id and api_client_secret: client = ConfidentialAppAuthClient(api_client_id, api_client_secret) - scopes = ["https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all", - "urn:globus:auth:scope:search.api.globus.org:all", - "openid"] - - token_response = client.oauth2_client_credentials_tokens(requested_scopes=scopes) - fx_token = token_response.by_resource_server['funcx_service']['access_token'] - search_token = token_response.by_resource_server['search.api.globus.org']['access_token'] - openid_token = token_response.by_resource_server['auth.globus.org']['access_token'] + scopes = [ + "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all", + "urn:globus:auth:scope:search.api.globus.org:all", + "openid", + ] + + token_response = client.oauth2_client_credentials_tokens( + requested_scopes=scopes + ) + fx_token = token_response.by_resource_server["funcx_service"]["access_token"] + search_token = token_response.by_resource_server["search.api.globus.org"][ + "access_token" + ] + openid_token = token_response.by_resource_server["auth.globus.org"][ + "access_token" + ] fx_auth = AccessTokenAuthorizer(fx_token) search_auth = AccessTokenAuthorizer(search_token) @@ -148,22 +170,10 @@ def funcx_test_config(pytestconfig, funcx_test_config_name): def fxc(funcx_test_config): client_args = funcx_test_config["client_args"] fxc = FuncXClient(**client_args) + fxc.throttling_enabled = False return fxc -@pytest.fixture(scope="session") -def async_fxc(funcx_test_config): - client_args = funcx_test_config["client_args"] - fxc = FuncXClient(**client_args, asynchronous=True) - return fxc - - -@pytest.fixture(scope="session") -def fx(fxc): - fx = FuncXExecutor(fxc) - return fx - - @pytest.fixture def endpoint(funcx_test_config): return funcx_test_config["endpoint_uuid"] @@ -175,3 +185,44 @@ def tutorial_funcion_id(funcx_test_config): if not funcid: pytest.skip("test requires a pre-defined public hello function") return funcid + + +FuncResult = collections.namedtuple( + "FuncResult", ["func_id", "task_id", "result", "response"] +) + + +@pytest.fixture +def submit_function_and_get_result(fxc): + def submit_fn( + endpoint_id, func=None, func_args=None, func_kwargs=None, initial_sleep=0 + ): + if callable(func): + func_id = fxc.register_function(func) + else: + func_id = func + + if func_args is None: + func_args = () + if func_kwargs is None: + func_kwargs = {} + + task_id = fxc.run( + *func_args, endpoint_id=endpoint_id, function_id=func_id, **func_kwargs + ) + + if initial_sleep: + time.sleep(initial_sleep) + + result = None + response = None + for attempt in range(10): + response = fxc.get_task(task_id) + if response.get("pending") is False: + result = response.get("result") + else: + time.sleep(attempt) + + return FuncResult(func_id, task_id, result, response) + + return submit_fn diff --git a/funcx_endpoint/tests/smoke_tests/test_async.py b/funcx_endpoint/tests/smoke_tests/test_async.py deleted file mode 100644 index 4a9eb5afe..000000000 --- a/funcx_endpoint/tests/smoke_tests/test_async.py +++ /dev/null @@ -1,19 +0,0 @@ -import asyncio -import random - - -def squared(x): - return x ** 2 - - -async def simple_task(fxc, endpoint): - squared_function = fxc.register_function(squared) - x = random.randint(0, 100) - task = fxc.run(x, endpoint_id=endpoint, function_id=squared_function) - result = await asyncio.wait_for(task, 60) - assert result == 
squared(x), "Got wrong answer" - - -def test_simple(async_fxc, endpoint): - """Testing basic async functionality""" - async_fxc.loop.run_until_complete(simple_task(async_fxc, endpoint)) diff --git a/funcx_endpoint/tests/smoke_tests/test_executor.py b/funcx_endpoint/tests/smoke_tests/test_executor.py deleted file mode 100644 index 9795ebe2b..000000000 --- a/funcx_endpoint/tests/smoke_tests/test_executor.py +++ /dev/null @@ -1,14 +0,0 @@ -import random - - -def double(x): - return x * 2 - - -def test_executor_basic(fx, endpoint): - """Test executor interface""" - - x = random.randint(0, 100) - fut = fx.submit(double, x, endpoint_id=endpoint) - - assert fut.result(timeout=60) == x * 2, "Got wrong answer" diff --git a/funcx_endpoint/tests/smoke_tests/test_running_functions.py b/funcx_endpoint/tests/smoke_tests/test_running_functions.py index 6e78bd3f8..e875d3ead 100644 --- a/funcx_endpoint/tests/smoke_tests/test_running_functions.py +++ b/funcx_endpoint/tests/smoke_tests/test_running_functions.py @@ -1,14 +1,12 @@ import time -def test_run_pre_registered_function(fxc, endpoint, tutorial_funcion_id): +def test_run_pre_registered_function( + endpoint, tutorial_funcion_id, submit_function_and_get_result +): """This test confirms that we are connected to the default production DB""" - fn_id = fxc.run(endpoint_id=endpoint, function_id=tutorial_funcion_id) - - time.sleep(30) - - result = fxc.get_result(fn_id) - assert result == "Hello World!", f"Expected result: Hello World!, got {result}" + r = submit_function_and_get_result(endpoint, func=tutorial_funcion_id) + assert r.result == "Hello World!" def double(x): diff --git a/funcx_endpoint/tests/smoke_tests/test_s3_indirect.py b/funcx_endpoint/tests/smoke_tests/test_s3_indirect.py new file mode 100644 index 000000000..8b73401d0 --- /dev/null +++ b/funcx_endpoint/tests/smoke_tests/test_s3_indirect.py @@ -0,0 +1,59 @@ +import pytest + +from funcx_endpoint.executors.high_throughput.funcx_worker import MaxResultSizeExceeded + + +def large_result_producer(size: int) -> str: + return bytearray(size) + + +def large_arg_consumer(data: str) -> int: + return len(data) + + +@pytest.mark.parametrize("size", [200, 2000, 20000, 200000]) +def test_allowed_result_sizes(submit_function_and_get_result, endpoint, size): + """funcX should allow all listed result sizes which are under 512KB limit""" + + r = submit_function_and_get_result( + endpoint, func=large_result_producer, func_args=(size,) + ) + assert len(r.result) == size + + +def test_result_size_too_large(submit_function_and_get_result, endpoint): + """ + funcX should raise a MaxResultSizeExceeded exception when results exceeds 10MB + limit + """ + r = submit_function_and_get_result( + endpoint, func=large_result_producer, func_args=(11 * 1024 * 1024,) + ) + assert r.result is None + assert "exception" in r.response + # the exception that comes back is a wrapper, so we must "reraise()" to get the + # true error out + with pytest.raises(MaxResultSizeExceeded): + r.response["exception"].reraise() + + +@pytest.mark.parametrize("size", [200, 2000, 20000, 200000]) +def test_allowed_arg_sizes(submit_function_and_get_result, endpoint, size): + """funcX should allow all listed result sizes which are under 512KB limit""" + r = submit_function_and_get_result( + endpoint, func=large_arg_consumer, func_args=(bytearray(size),) + ) + assert r.result == size + + +@pytest.mark.skip(reason="As of 0.3.4, an arg size limit is not being enforced") +def test_arg_size_too_large(submit_function_and_get_result, endpoint, 
size=55000000): + """funcX should raise an exception for objects larger than some limit, + which we are yet to define. This does not work right now. + """ + + r = submit_function_and_get_result( + endpoint, func=large_result_producer, func_args=(bytearray(550000),) + ) + assert r.result is None + assert "exception" in r.response diff --git a/funcx_endpoint/tests/smoke_tests/test_version.py b/funcx_endpoint/tests/smoke_tests/test_version.py index d9a18c6df..121f1a5cb 100644 --- a/funcx_endpoint/tests/smoke_tests/test_version.py +++ b/funcx_endpoint/tests/smoke_tests/test_version.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import requests @@ -13,11 +15,13 @@ def test_web_service(fxc, endpoint, funcx_test_config): ) service_version = response.json() - api_version = funcx_test_config.get("api_version") - if api_version is not None: + api_min_version = funcx_test_config.get("api_min_version") + if api_min_version is not None: + parsed_min = LooseVersion(api_min_version) + parsed_service = LooseVersion(service_version) assert ( - service_version == api_version - ), f"Expected API version:{api_version}, got {service_version}" + parsed_service >= parsed_min + ), f"Expected API version >={api_min_version}, got {service_version}" def test_forwarder(fxc, endpoint, funcx_test_config): @@ -32,11 +36,13 @@ def test_forwarder(fxc, endpoint, funcx_test_config): ) forwarder_version = response.json()["forwarder"] - expected_version = funcx_test_config.get("forwarder_version") - if expected_version: + min_version = funcx_test_config.get("forwarder_min_version") + if min_version: + parsed_min = LooseVersion(min_version) + parsed_forwarder = LooseVersion(forwarder_version) assert ( - forwarder_version == expected_version - ), f"Expected Forwarder version:{expected_version}, got {forwarder_version}" + parsed_forwarder >= parsed_min + ), f"Expected Forwarder version >= {min_version}, got {forwarder_version}" def say_hello(): diff --git a/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py b/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py index b92b9224c..b9ed0123d 100644 --- a/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py +++ b/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py @@ -1,9 +1,11 @@ -import time -import logging import argparse -import sys import copy -from globus_sdk import ConfidentialAppAuthClient, AccessTokenAuthorizer +import logging +import sys +import time + +from globus_sdk import AccessTokenAuthorizer, ConfidentialAppAuthClient + from funcx.sdk.client import FuncXClient @@ -11,11 +13,20 @@ def identity(x): return x -class TestTutorial(): - - def __init__(self, fx_auth, search_auth, openid_auth, - endpoint_id, func, expected, - args=None, timeout=15, concurrency=1, tol=1e-5): +class TestTutorial: + def __init__( + self, + fx_auth, + search_auth, + openid_auth, + endpoint_id, + func, + expected, + args=None, + timeout=15, + concurrency=1, + tol=1e-5, + ): self.endpoint_id = endpoint_id self.func = func self.expected = expected @@ -23,16 +34,20 @@ def __init__(self, fx_auth, search_auth, openid_auth, self.timeout = timeout self.concurrency = concurrency self.tol = tol - self.fxc = FuncXClient(fx_authorizer=fx_auth, - search_authorizer=search_auth, - openid_authorizer=openid_auth) + self.fxc = FuncXClient( + fx_authorizer=fx_auth, + search_authorizer=search_auth, + openid_authorizer=openid_auth, + ) self.func_uuid = self.fxc.register_function(self.func) self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.DEBUG) handler = 
logging.StreamHandler(sys.stdout) handler.setLevel(logging.DEBUG) - formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s") + formatter = logging.Formatter( + "%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s" + ) handler.setFormatter(formatter) self.logger.addHandler(handler) @@ -40,14 +55,18 @@ def run(self): try: submissions = [] for _ in range(self.concurrency): - task = self.fxc.run(self.args, endpoint_id=self.endpoint_id, function_id=self.func_uuid) + task = self.fxc.run( + self.args, endpoint_id=self.endpoint_id, function_id=self.func_uuid + ) submissions.append(task) time.sleep(self.timeout) unfinished = copy.deepcopy(submissions) while True: - unfinished[:] = [task for task in unfinished if self.fxc.get_task(task)['pending']] + unfinished[:] = [ + task for task in unfinished if self.fxc.get_task(task)["pending"] + ] if not unfinished: break time.sleep(self.timeout) @@ -56,45 +75,53 @@ def run(self): for task in submissions: result = self.fxc.get_result(task) if abs(result - self.expected) > self.tol: - self.logger.exception(f'Difference for task {task}. ' - f'Returned: {result}, Expected: {self.expected}') + self.logger.exception( + f"Difference for task {task}. " + f"Returned: {result}, Expected: {self.expected}" + ) else: success += 1 - self.logger.info(f'{success}/{self.concurrency} tasks completed successfully') + self.logger.info( + f"{success}/{self.concurrency} tasks completed successfully" + ) except KeyboardInterrupt: - self.logger.info('Cancelled by keyboard interruption') + self.logger.info("Cancelled by keyboard interruption") except Exception as e: - self.logger.exception(f'Encountered exception: {e}') + self.logger.exception(f"Encountered exception: {e}") raise if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--tutorial", required=True, - help="Tutorial Endpoint ID") - parser.add_argument("-i", "--id", required=True, - help="API_CLIENT_ID for Globus") - parser.add_argument("-s", "--secret", required=True, - help="API_CLIENT_SECRET for Globus") + parser.add_argument("-t", "--tutorial", required=True, help="Tutorial Endpoint ID") + parser.add_argument("-i", "--id", required=True, help="API_CLIENT_ID for Globus") + parser.add_argument( + "-s", "--secret", required=True, help="API_CLIENT_SECRET for Globus" + ) args = parser.parse_args() client = ConfidentialAppAuthClient(args.id, args.secret) - scopes = ["https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all", - "urn:globus:auth:scope:search.api.globus.org:all", - "openid"] + scopes = [ + "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all", + "urn:globus:auth:scope:search.api.globus.org:all", + "openid", + ] token_response = client.oauth2_client_credentials_tokens(requested_scopes=scopes) - fx_token = token_response.by_resource_server['funcx_service']['access_token'] - search_token = token_response.by_resource_server['search.api.globus.org']['access_token'] - openid_token = token_response.by_resource_server['auth.globus.org']['access_token'] + fx_token = token_response.by_resource_server["funcx_service"]["access_token"] + search_token = token_response.by_resource_server["search.api.globus.org"][ + "access_token" + ] + openid_token = token_response.by_resource_server["auth.globus.org"]["access_token"] fx_auth = AccessTokenAuthorizer(fx_token) search_auth = AccessTokenAuthorizer(search_token) openid_auth = AccessTokenAuthorizer(openid_token) val = 1 - tt = TestTutorial(fx_auth, search_auth, 
openid_auth, - args.tutorial, identity, val, args=val) + tt = TestTutorial( + fx_auth, search_auth, openid_auth, args.tutorial, identity, val, args=val + ) tt.run() diff --git a/funcx_endpoint/tests/version_mismatch_tests/README.rst b/funcx_endpoint/tests/version_mismatch_tests/README.rst new file mode 100644 index 000000000..24b74400e --- /dev/null +++ b/funcx_endpoint/tests/version_mismatch_tests/README.rst @@ -0,0 +1,45 @@ +Setting up the envs +------------------- + +Step.1: To run these tests first create 4 conda envs with the appropriate python3 versions + +```bash +conda create --name=funcx_version_mismatch_py3.6 'python=3.6' +conda create --name=funcx_version_mismatch_py3.7 'python=3.7' +conda create --name=funcx_version_mismatch_py3.8 'python=3.8' +conda create --name=funcx_version_mismatch_py3.9 'python=3.9' +``` + +Step.2: Next checkout the branch `relax_version_match_constraints` and run the `update_all.sh` script +to install the locally checked out code. Run the `update_all.sh` script like this: + +```bash + +./update_all.sh +``` + +Step.3: Update the `config.py` file with the path to your `conda.sh` script. + +Create an endpoint +------------------ + +Step.4: You need an endpoint running locally named `mismatched` + +``` +funcx-endpoint configure mismatched +``` + +You do not need to start, or configure this EP. The tests below will copy over configs. + +Running the tests +----------------- + +Step.5: Run the tests like this: + +``` +bash -i $PWD/run_test_matrix.sh $PWD +``` + + + + diff --git a/funcx_endpoint/tests/version_mismatch_tests/config.py b/funcx_endpoint/tests/version_mismatch_tests/config.py new file mode 100644 index 000000000..bd075d6b4 --- /dev/null +++ b/funcx_endpoint/tests/version_mismatch_tests/config.py @@ -0,0 +1,40 @@ +import os + +from parsl.providers import LocalProvider + +from funcx_endpoint.endpoint.utils.config import Config +from funcx_endpoint.executors import HighThroughputExecutor + +CONDA_ENV = os.environ["WORKER_CONDA_ENV"] +print(f"Using conda env:{CONDA_ENV} for worker_init") + +config = Config( + executors=[ + HighThroughputExecutor( + provider=LocalProvider( + init_blocks=1, + min_blocks=0, + max_blocks=1, + # FIX ME: Update conda.sh file to match your paths + worker_init=( + "source ~/anaconda3/etc/profile.d/conda.sh; " + f"conda activate {CONDA_ENV}; " + "python3 --version" + ), + ), + ) + ], + funcx_service_address="https://api2.funcx.org/v2", +) + +# For now, visible_to must be a list of URNs for globus auth users or groups, e.g.: +# urn:globus:auth:identity:{user_uuid} +# urn:globus:groups:id:{group_uuid} +meta = { + "name": "0.3.3", + "description": "", + "organization": "", + "department": "", + "public": False, + "visible_to": [], +} diff --git a/funcx_endpoint/tests/version_mismatch_tests/run_test_matrix.sh b/funcx_endpoint/tests/version_mismatch_tests/run_test_matrix.sh new file mode 100644 index 000000000..0bf598e34 --- /dev/null +++ b/funcx_endpoint/tests/version_mismatch_tests/run_test_matrix.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +fn + +pwd=$1 +cd $pwd + +VERSIONS=("3.6" "3.7" "3.8" "3.9") +conda init bash + +# This will fail if the endpoint configuration already exists +# funcx-endpoint configure mismatched + +# Copy over configs +cp config.py ~/.funcx/mismatched + + +start_endpoint() { + worker_v=$1 + ep_v=$2 + export WORKER_CONDA_ENV=$w_env_name + export WORKER_CONDA_ENV=$w_env_name + echo "Running ep with $WORKER_CONDA_ENV with python $worker_v" + funcx-endpoint start mismatched + endpoint_id=$(funcx-endpoint list | grep 
mismatched | awk '{print $6}') + sleep 2 + python3 $pwd/test_mismatched.py -e $endpoint_id -w $worker_v -v $ep_v + if [ $? -eq 0 ]; then + echo "TEST PASSED, EP_PY_V:$ep_v WORKER_PY_V:$worker_v" + else + echo "TEST FAILED, EP_PY_V:$ep_v WORKER_PY_V:$worker_v" + return 1 + fi + echo "Stopping endpoint in 2s" + sleep 2 + funcx-endpoint stop mismatched + sleep 2 +} + +run_test () { + funcx_endpoint list mismatched + python3 test_mismatched.py $worker_v +} + + +for ep_v in ${VERSIONS[*]} +# for ep_v in "3.7" +do + for worker_v in ${VERSIONS[*]} + # for worker_v in "3.9" + do + ep_env_name="funcx_version_mismatch_py$ep_v" + w_env_name="funcx_version_mismatch_py$worker_v" + echo "Testing EP:python=$ep_v against Worker:python=$worker_v" + conda activate $ep_env_name + which funcx-endpoint + start_endpoint $worker_v $ep_v + if [ $? -ne 0 ] ; then + echo Aborting tests due to failure. + exit 1 + fi + done +done + diff --git a/funcx_endpoint/tests/version_mismatch_tests/test_mismatched.py b/funcx_endpoint/tests/version_mismatch_tests/test_mismatched.py new file mode 100644 index 000000000..a189e2e24 --- /dev/null +++ b/funcx_endpoint/tests/version_mismatch_tests/test_mismatched.py @@ -0,0 +1,112 @@ +import argparse + +from funcx import FuncXClient +from funcx.sdk.executor import FuncXExecutor +from funcx_endpoint.executors.high_throughput.interchange import ManagerLost + + +def get_version(): + import sys + + return f"{sys.version_info.major}.{sys.version_info.minor}" + + +def raise_error(): + import sys + + raise ValueError(f"{sys.version_info.major}.{sys.version_info.minor}") + + +def kill_managers(): + import os + + os.system("killall funcx-manager") + + +def test_worker_version(fx, ep_id, ep_version, version): + import sys + + print(f"Running a version check against endpoint:{ep_id}") + future = fx.submit(get_version, endpoint_id=ep_id) + print(f"Future launched with future:{future}") + try: + print(f"Expected worker_version : {version}, actual: {future.result()}") + assert ( + future.result(timeout=10) == version + ), f"Expected worker version:{version} Got:{future.result()}" + + except Exception: + print(f"Expected worker_version : {version}, actual: {future.result()}") + sys.exit(1) + else: + print(f"Worker returned the expected version:{future.result()}") + + +def test_app_exception(fx, ep_id, ep_version, version): + import sys + + print(f"Checking exceptions from app on endpoint:{ep_id}") + future = fx.submit(raise_error, endpoint_id=ep_id) + print(f"Future launched with future:{future}") + try: + future.result(timeout=120) + except ValueError: + print("Worker returned the correct exception") + except Exception as e: + print("Wrong exception type...") + print(f"Wrong exception type, {type(e)}") + print(f"Expected ValueError, actual: {repr(e)}") + print("Exiting because of wrong exception") + sys.exit(1) + else: + print("No exception, expected ValueError") + sys.exit(1) + + +def test_kill_manager(fx, ep_id, ep_version, version): + import sys + + print("Testing manager kill to hopefully provoke a ManagerLost") + + future = fx.submit(kill_managers, endpoint_id=ep_id) + print(f"Future launched with future:{future}") + try: + future.result(timeout=600) # leave a longish time for this timeout... 
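+        # killing funcx-manager means no result will ever come back for this task;
+        # the expectation (checked below) is that the interchange notices the lost
+        # manager and fails the future with ManagerLost before the 600s timeout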
+ except ManagerLost as me: + print(f"Worker returned the correct exception: {repr(me)}") + except Exception as e: + print(f"Expected ValueError, actual: {repr(e)}") + print("Exiting...") + sys.exit(1) + else: + print("No exception, expected ValueError") + sys.exit(1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-e", + "--endpoint_id", + required=True, + help="Target endpoint to send functions to", + ) + parser.add_argument( + "-v", + "--ep_version", + required=True, + help="EP VERSION", + ) + parser.add_argument( + "-w", + "--worker_version", + required=True, + help="Target endpoint to send functions to", + ) + + args = parser.parse_args() + + fx = FuncXExecutor(FuncXClient()) + test_worker_version(fx, args.endpoint_id, args.ep_version, args.worker_version) + test_app_exception(fx, args.endpoint_id, args.ep_version, args.worker_version) + test_kill_manager(fx, args.endpoint_id, args.ep_version, args.worker_version) diff --git a/funcx_endpoint/tests/version_mismatch_tests/update_all.sh b/funcx_endpoint/tests/version_mismatch_tests/update_all.sh new file mode 100755 index 000000000..5d1b6a952 --- /dev/null +++ b/funcx_endpoint/tests/version_mismatch_tests/update_all.sh @@ -0,0 +1,20 @@ +#!/bin/bash -ex + +VERSIONS=("3.6" "3.7" "3.8" "3.9") +SRC_PATH=$1 + +conda init bash +for py_v in ${VERSIONS[*]} +do + env_name="funcx_version_mismatch_py$py_v" + echo "Updating $env_name" + conda deactivate + conda activate $env_name + echo $CONDA_PREFIX + pushd . + cd $SRC_PATH + pip install ./funcx_sdk ./funcx_endpoint + pip install parsl + popd +done + diff --git a/funcx_sdk/funcx/__init__.py b/funcx_sdk/funcx/__init__.py index d17d8b1f6..500ed6b20 100644 --- a/funcx_sdk/funcx/__init__.py +++ b/funcx_sdk/funcx/__init__.py @@ -1,12 +1,11 @@ """ funcX : Fast function serving for clouds, clusters and supercomputers. 
""" -import logging - from funcx.sdk.version import VERSION __author__ = "The funcX team" __version__ = VERSION from funcx.sdk.client import FuncXClient -from funcx.utils.loggers import set_file_logger, set_stream_logger + +__all__ = ("FuncXClient",) diff --git a/funcx_sdk/funcx/sdk/__init__.py b/funcx_sdk/funcx/sdk/__init__.py index 7775be9a4..8435c7f08 100644 --- a/funcx_sdk/funcx/sdk/__init__.py +++ b/funcx_sdk/funcx/sdk/__init__.py @@ -1 +1,3 @@ from funcx.sdk.version import VERSION + +__all__ = ("VERSION",) diff --git a/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py b/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py index a60a34348..02938097d 100644 --- a/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py +++ b/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from asyncio import AbstractEventLoop, QueueEmpty +from asyncio import AbstractEventLoop import dill import websockets @@ -13,7 +13,7 @@ from funcx.sdk.asynchronous.funcx_task import FuncXTask -logger = logging.getLogger("asyncio") +log = logging.getLogger(__name__) class WebSocketPollingTask: @@ -72,8 +72,9 @@ def __init__( # the WebSocket server immediately self.running_task_group_ids.add(self.init_task_group_id) - # Set event loop explicitly since event loop can only be fetched automatically in main thread - # when batch submission is enabled, the task submission is in a new thread + # Set event loop explicitly since event loop can only be fetched automatically + # in main thread when batch submission is enabled, the task submission is in a + # new thread asyncio.set_event_loop(self.loop) self.task_group_ids_queue = asyncio.Queue() self.pending_tasks = {} @@ -91,12 +92,13 @@ async def init_ws(self, start_message_handlers=True): self.ws = await websockets.connect( self.results_ws_uri, extra_headers=headers ) - # initial Globus authentication happens during the HTTP portion of the handshake, - # so an invalid handshake means that the user was not authenticated + # initial Globus authentication happens during the HTTP portion of the + # handshake, so an invalid handshake means that the user was not authenticated except InvalidStatusCode as e: if e.status_code == 404: raise Exception( - "WebSocket service responsed with a 404. Please ensure you set the correct results_ws_uri" + "WebSocket service responsed with a 404. 
" + "Please ensure you set the correct results_ws_uri" ) else: raise e @@ -117,7 +119,19 @@ async def send_outgoing(self, queue: asyncio.Queue): task_group_id = await queue.get() await self.ws.send(task_group_id) - async def handle_incoming(self, pending_futures, auto_close=False): + async def handle_incoming(self, pending_futures, auto_close=False) -> bool: + """ + + Parameters + ---------- + pending_futures + auto_close + + Returns + ------- + True -- If connection is closing from internal shutdown process + False -- External disconnect - to be handled by reconnect logic + """ while True: try: raw_data = await asyncio.wait_for(self.ws.recv(), timeout=1.0) @@ -125,23 +139,26 @@ async def handle_incoming(self, pending_futures, auto_close=False): pass except ConnectionClosedOK: if self.closed_by_main_thread: - logger.debug("WebSocket connection closed by main thread") + log.info("WebSocket connection closed by main thread") + return True else: - logger.error("WebSocket connection closed unexpectedly") - return + log.info("WebSocket connection closed by remote-side") + return False else: data = json.loads(raw_data) task_id = data["task_id"] if task_id in pending_futures: if await self.set_result(task_id, data, pending_futures): - return + return True else: - # This scenario occurs rarely using non-batching mode, - # but quite often in batching mode. - # When submitting tasks in batch with batch_run, - # some task results may be received by websocket before the response of batch_run, + # This scenario occurs rarely using non-batching mode, but quite + # often in batching mode. + # + # When submitting tasks in batch with batch_run, some task results + # may be received by websocket before the response of batch_run, # and pending_futures do not have the futures for the tasks yet. - # We store these in unknown_results and process when their futures are ready. + # We store these in unknown_results and process when their futures + # are ready. self.unknown_results[task_id] = data # Handle the results received but not processed before @@ -149,7 +166,7 @@ async def handle_incoming(self, pending_futures, auto_close=False): for task_id in unprocessed_task_ids: data = self.unknown_results.pop(task_id) if await self.set_result(task_id, data, pending_futures): - return + return True async def set_result(self, task_id, data, pending_futures): """Sets the result of a future with given task_id in the pending_futures map, @@ -186,10 +203,10 @@ async def set_result(self, task_id, data, pending_futures): else: future.set_exception(Exception(data["reason"])) except Exception: - logger.exception("Caught unexpected exception while setting results") + log.exception("Caught unexpected exception while setting results") - # When the counter hits 0 we always exit. This guarantees that that - # if the counter increments to 1 on the executor, this handler needs to be restarted. + # When the counter hits 0 we always exit. This guarantees that that if the + # counter increments to 1 on the executor, this handler needs to be restarted. 
if self.atomic_controller is not None: count = self.atomic_controller.decrement() # Only close when count == 0 and unknown_results are empty @@ -199,6 +216,11 @@ async def set_result(self, task_id, data, pending_futures): return True return False + async def close(self): + """Close underlying web-sockets, does not stop listeners directly""" + await self.ws.close() + self.ws = None + def put_task_group_id(self, task_group_id): # prevent the task_group_id from being sent to the WebSocket server # multiple times @@ -216,14 +238,19 @@ def add_task(self, task: FuncXTask): def get_auth_header(self): """ - Gets an Authorization header to be sent during the WebSocket handshake. Based on - header setting in the Globus SDK: https://github.com/globus/globus-sdk-python/blob/main/globus_sdk/base.py + Gets an Authorization header to be sent during the WebSocket handshake. Returns ------- Key-value tuple of the Authorization header (key, value) """ + # TODO: under SDK v3 this will be + # + # return ( + # "Authorization", + # self.funcx_client.authorizer.get_authorization_header()` + # ) headers = dict() self.funcx_client.authorizer.set_authorization_header(headers) header_name = "Authorization" diff --git a/funcx_sdk/funcx/sdk/client.py b/funcx_sdk/funcx/sdk/client.py index 94d7108f9..238ad75c3 100644 --- a/funcx_sdk/funcx/sdk/client.py +++ b/funcx_sdk/funcx/sdk/client.py @@ -28,6 +28,8 @@ logger = logging.getLogger(__name__) +_FUNCX_HOME = os.path.join("~", ".funcx") + class FuncXClient(FuncXErrorHandlingClient): """Main class for interacting with the funcX service @@ -48,7 +50,7 @@ class FuncXClient(FuncXErrorHandlingClient): def __init__( self, http_timeout=None, - funcx_home=os.path.join("~", ".funcx"), + funcx_home=_FUNCX_HOME, force_login=False, fx_authorizer=None, search_authorizer=None, diff --git a/funcx_sdk/funcx/sdk/error_handling_client.py b/funcx_sdk/funcx/sdk/error_handling_client.py index b1336ad6e..6faaad319 100644 --- a/funcx_sdk/funcx/sdk/error_handling_client.py +++ b/funcx_sdk/funcx/sdk/error_handling_client.py @@ -6,7 +6,9 @@ class FuncXErrorHandlingClient(ThrottledBaseClient): - """Class which handles errors from GET, POST, and DELETE requests before proceeding""" + """ + Class which handles errors from GET, POST, and DELETE requests before proceeding + """ def get(self, path, **kwargs): try: diff --git a/funcx_sdk/funcx/sdk/executor.py b/funcx_sdk/funcx/sdk/executor.py index 61b599688..382421618 100644 --- a/funcx_sdk/funcx/sdk/executor.py +++ b/funcx_sdk/funcx/sdk/executor.py @@ -10,7 +10,7 @@ from funcx.sdk.asynchronous.ws_polling_task import WebSocketPollingTask -logger = logging.getLogger("asyncio") +log = logging.getLogger(__name__) class AtomicController: @@ -104,7 +104,7 @@ def __init__( atexit.register(self.shutdown) if self.batch_enabled: - logger.info("Batch submission enabled.") + log.info("Batch submission enabled.") self.start_batching_thread() def start_batching_thread(self): @@ -117,7 +117,7 @@ def start_batching_thread(self): ) self.task_submit_thread.daemon = True self.task_submit_thread.start() - logger.info("Started task submit thread") + log.info("Started task submit thread") def submit(self, function, *args, endpoint_id=None, container_uuid=None, **kwargs): """Initiate an invocation @@ -145,7 +145,7 @@ def submit(self, function, *args, endpoint_id=None, container_uuid=None, **kwarg if function not in self._function_registry: # Please note that this is a partial implementation, not all function # registration options are fleshed out here. 
- logger.debug(f"Function:{function} is not registered. Registering") + log.debug(f"Function:{function} is not registered. Registering") try: function_id = self.funcx_client.register_function( function, @@ -153,11 +153,11 @@ def submit(self, function, *args, endpoint_id=None, container_uuid=None, **kwarg container_uuid=container_uuid, ) except Exception: - logger.error(f"Error in registering {function.__name__}") + log.error(f"Error in registering {function.__name__}") raise else: self._function_registry[function] = function_id - logger.debug(f"Function registered with id:{function_id}") + log.debug(f"Function registered with id:{function_id}") task_id = self._task_counter self._task_counter += 1 @@ -190,13 +190,13 @@ def task_submit_thread(self, kill_event): while not kill_event.is_set(): messages = self._get_tasks_in_batch() if messages: - logger.info( + log.info( "[TASK_SUBMIT_THREAD] Submitting {} tasks to funcX".format( len(messages) ) ) self._submit_tasks(messages) - logger.info("[TASK_SUBMIT_THREAD] Exiting") + log.info("[TASK_SUBMIT_THREAD] Exiting") def _submit_tasks(self, messages): """Submit a batch of tasks""" @@ -212,12 +212,12 @@ def _submit_tasks(self, messages): batch.add( *args, **kwargs, endpoint_id=endpoint_id, function_id=function_id ) - logger.debug(f"[TASK_SUBMIT_THREAD] Adding msg {msg} to funcX batch") + log.debug(f"[TASK_SUBMIT_THREAD] Adding msg {msg} to funcX batch") try: batch_tasks = self.funcx_client.batch_run(batch) - logger.debug(f"Batch submitted to task_group: {self.task_group_id}") + log.debug(f"Batch submitted to task_group: {self.task_group_id}") except Exception: - logger.error( + log.error( "[TASK_SUBMIT_THREAD] Error submitting {} tasks to funcX".format( len(messages) ) @@ -231,7 +231,10 @@ def _submit_tasks(self, messages): self.poller_thread.atomic_controller.increment() def _get_tasks_in_batch(self): - """Get tasks from task_outgoing queue in batch, either by interval or by batch size""" + """ + Get tasks from task_outgoing queue in batch, + either by interval or by batch size + """ messages = [] start = time.time() while True: @@ -252,7 +255,7 @@ def shutdown(self): self.poller_thread.shutdown() if self.batch_enabled: self._kill_event.set() - logger.debug(f"Executor:{self.label} shutting down") + log.debug(f"Executor:{self.label} shutting down") def noop(): @@ -305,19 +308,29 @@ def start(self): self.thread = threading.Thread(target=self.event_loop_thread, args=(eventloop,)) self.thread.daemon = True self.thread.start() - logger.debug("Started web_socket_poller thread") + log.debug("Started web_socket_poller thread") def event_loop_thread(self, eventloop): asyncio.set_event_loop(eventloop) eventloop.run_until_complete(self.web_socket_poller()) async def web_socket_poller(self): - # TODO: if WebSocket connection fails, we should either retry connecting and back off - # or we should set an exception to all of the outstanding futures - await self.ws_handler.init_ws(start_message_handlers=False) - await self.ws_handler.handle_incoming( - self._function_future_map, auto_close=True - ) + """Start ws and listen for tasks. 
+ If a remote disconnect breaks the ws, close the ws and reconnect""" + while True: + await self.ws_handler.init_ws(start_message_handlers=False) + status = await self.ws_handler.handle_incoming( + self._function_future_map, auto_close=True + ) + if status is False: + # handle_incoming broke from a remote side disconnect + # we should close and re-connect + log.info("Attempting ws close") + await self.ws_handler.close() + log.info("Attempting ws re-connect") + else: + # clean exit + break def shutdown(self): if self.ws_handler is None: diff --git a/funcx_sdk/funcx/sdk/search.py b/funcx_sdk/funcx/sdk/search.py index ed644fa56..016c190e3 100644 --- a/funcx_sdk/funcx/sdk/search.py +++ b/funcx_sdk/funcx/sdk/search.py @@ -81,7 +81,8 @@ def search_function(self, q, offset=0, limit=DEFAULT_SEARCH_LIMIT, advanced=Fals # print(res) # Restructure results to look like the data dict in FuncXClient - # see the JSON structure of res.data: https://docs.globus.org/api/search/search/#gsearchresult + # see the JSON structure of res.data: + # https://docs.globus.org/api/search/search/#gsearchresult gmeta = response.data["gmeta"] results = [] for item in gmeta: @@ -130,7 +131,7 @@ def search_endpoint(self, q, scope="all", owner_id=None): } elif scope == "shared-with-me": # TODO: filter for public=False AND owner != self._owner_uuid - # but...need to build advanced query for that, because GFilters cannot do NOT + # need to build advanced query for that, because GFilters cannot do NOT # raise Exception('This scope has not been implemented') scope_filter = { "type": "match_all", diff --git a/funcx_sdk/funcx/sdk/utils/futures.py b/funcx_sdk/funcx/sdk/utils/futures.py index 117bd6ca8..8e570dba2 100644 --- a/funcx_sdk/funcx/sdk/utils/futures.py +++ b/funcx_sdk/funcx/sdk/utils/futures.py @@ -2,7 +2,6 @@ Credit: Logan Ward """ -import json from concurrent.futures import Future from threading import Thread from time import sleep diff --git a/funcx_sdk/funcx/tests/README.rst b/funcx_sdk/funcx/tests/README.rst index 2e0f34801..9579d3586 100644 --- a/funcx_sdk/funcx/tests/README.rst +++ b/funcx_sdk/funcx/tests/README.rst @@ -24,9 +24,8 @@ From Pypi conda create -y --name funcx_testing_py3.8 python=3.8 conda activate funcx_testing_py3.8 - pip install funcx==0.0.6a5 + pip install 'funcx[test]==0.0.6a5' pip install funcx-endpoint==0.0.6a5 - pip install -r ./funcx_sdk/test-requirements.txt From Source ^^^^^^^^^^^ @@ -37,9 +36,8 @@ Here's a sequence of steps that should be copy-pastable: conda create -y --name funcx_testing_py3.8 python=3.8 conda activate funcx_testing_py3.8 - pip install ./funcx_sdk/ + pip install './funcx_sdk[test] pip install ./funcx_endpoint/ - pip install -r ./funcx_sdk/test-requirements.txt Setup an endpoint ----------------- diff --git a/funcx_sdk/funcx/tests/conftest.py b/funcx_sdk/funcx/tests/conftest.py index 6381cba1f..9a52ad832 100644 --- a/funcx_sdk/funcx/tests/conftest.py +++ b/funcx_sdk/funcx/tests/conftest.py @@ -4,7 +4,7 @@ from funcx.sdk.executor import FuncXExecutor config = { - "funcx_service_address": "https://api2.funcx.org/v2", # For testing against local k8s + "funcx_service_address": "https://api2.funcx.org/v2", "endpoint_uuid": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "results_ws_uri": "wss://api2.funcx.org/ws/v2/", } diff --git a/funcx_sdk/funcx/tests/test_errors/test_invalid_endpoint.py b/funcx_sdk/funcx/tests/test_errors/test_invalid_endpoint.py index 98f9d6ceb..6ac59713f 100644 --- a/funcx_sdk/funcx/tests/test_errors/test_invalid_endpoint.py +++ 
b/funcx_sdk/funcx/tests/test_errors/test_invalid_endpoint.py @@ -1,6 +1,5 @@ import pytest -from funcx.sdk.client import FuncXClient from funcx.utils.response_errors import EndpointNotFound diff --git a/funcx_sdk/funcx/tests/test_errors/test_invalid_function.py b/funcx_sdk/funcx/tests/test_errors/test_invalid_function.py index a4c1d47c4..70ffdbb1a 100644 --- a/funcx_sdk/funcx/tests/test_errors/test_invalid_function.py +++ b/funcx_sdk/funcx/tests/test_errors/test_invalid_function.py @@ -1,8 +1,5 @@ -import time - import pytest -from funcx.sdk.client import FuncXClient from funcx.utils.response_errors import FunctionNotFound diff --git a/funcx_sdk/funcx/tests/test_executor.py b/funcx_sdk/funcx/tests/test_executor.py index 2251f45bf..2206021ab 100644 --- a/funcx_sdk/funcx/tests/test_executor.py +++ b/funcx_sdk/funcx/tests/test_executor.py @@ -177,7 +177,8 @@ def test_batch_delays(batch_fx, endpoint): # test locally: python3 test_executor.py -e -# test on dev: python3 test_executor.py -s https://api2.dev.funcx.org/v2 -w wss://api2.dev.funcx.org/ws/v2/ -e +# test on dev: +# python3 test_executor.py -s https://api2.dev.funcx.org/v2 -w wss://api2.dev.funcx.org/ws/v2/ -e # noqa:E501 if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/funcx_sdk/funcx/tests/test_longduration.py b/funcx_sdk/funcx/tests/test_longduration.py new file mode 100644 index 000000000..c52d47bfc --- /dev/null +++ b/funcx_sdk/funcx/tests/test_longduration.py @@ -0,0 +1,37 @@ +import random +import time + + +def double(x): + return x * 2 + + +def failing_task(): + raise IndexError() + + +def delay_n(n): + import time + + time.sleep(n) + return n + + +def noop(): + return + + +def test_random_delay(fx, endpoint, base_delay=600, n=10): + """Tests tasks that run 10mins which is the websocket disconnect period""" + + futures = {} + for _i in range(n): + delay = base_delay + random.randint(10, 30) + fut = fx.submit(delay_n, delay, endpoint_id=endpoint) + futures[fut] = delay + + time.sleep(3) + + for fut in futures: + assert fut.result(timeout=700) == futures[fut] + print(f"I slept for {fut.result()} seconds") diff --git a/funcx_sdk/funcx/tests/test_performance/test_performance.py b/funcx_sdk/funcx/tests/test_performance/test_performance.py index 8cf6e3795..f25193622 100644 --- a/funcx_sdk/funcx/tests/test_performance/test_performance.py +++ b/funcx_sdk/funcx/tests/test_performance/test_performance.py @@ -2,8 +2,6 @@ import pytest -from funcx.sdk.client import FuncXClient - def double(x): return x * 2 diff --git a/funcx_sdk/funcx/tests/test_result_size.py b/funcx_sdk/funcx/tests/test_result_size.py index ec18b8a11..40a9fbbb7 100644 --- a/funcx_sdk/funcx/tests/test_result_size.py +++ b/funcx_sdk/funcx/tests/test_result_size.py @@ -1,6 +1,5 @@ import pytest -from funcx.sdk.client import FuncXClient from funcx.utils.errors import TaskPending from funcx_endpoint.executors.high_throughput.funcx_worker import MaxResultSizeExceeded @@ -46,8 +45,11 @@ def test_allowed_result_sizes(fxc, endpoint, size): assert len(x) == size, "Result size does not match excepted size" -def test_result_size_too_large(fxc, endpoint, size=550000): - """funcX should raise a MaxResultSizeExceeded exception when results exceeds 512KB limit""" +def test_result_size_too_large(fxc, endpoint, size=11 * 1024 * 1024): + """ + funcX should raise a MaxResultSizeExceeded exception when results exceeds 10MB + limit + """ fn_uuid = fxc.register_function( large_result_producer, endpoint, description="LargeResultProducer" ) diff 
--git a/funcx_sdk/funcx/utils/loggers.py b/funcx_sdk/funcx/utils/loggers.py deleted file mode 100644 index 6eb7b0f51..000000000 --- a/funcx_sdk/funcx/utils/loggers.py +++ /dev/null @@ -1,78 +0,0 @@ -import logging -from logging.handlers import RotatingFileHandler - -file_format_string = ( - "%(asctime)s.%(msecs)03d %(name)s:%(lineno)d [%(levelname)s] %(message)s" -) - - -stream_format_string = "%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s" - - -def set_file_logger( - filename, - name="funcx", - level=logging.DEBUG, - format_string=None, - max_bytes=100 * 1024 * 1024, - backup_count=1, -): - """Add a stream log handler. - - Args: - - filename (string): Name of the file to write logs to - - name (string): Logger name - - level (logging.LEVEL): Set the logging level. - - format_string (string): Set the format string - - maxBytes: The maximum bytes per logger file, default: 100MB - - backupCount: The number of backup (must be non-zero) per logger file, default: 1 - - Returns: - - None - """ - if format_string is None: - format_string = file_format_string - - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - handler = RotatingFileHandler( - filename, maxBytes=max_bytes, backupCount=backup_count - ) - handler.setLevel(level) - formatter = logging.Formatter(format_string, datefmt="%Y-%m-%d %H:%M:%S") - handler.setFormatter(formatter) - logger.addHandler(handler) - - ws_logger = logging.getLogger("asyncio") - ws_logger.addHandler(handler) - return logger - - -def set_stream_logger(name="funcx", level=logging.DEBUG, format_string=None): - """Add a stream log handler. - - Args: - - name (string) : Set the logger name. - - level (logging.LEVEL) : Set to logging.DEBUG by default. - - format_string (string) : Set to None by default. - - Returns: - - None - """ - if format_string is None: - format_string = stream_format_string - - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - handler = logging.StreamHandler() - handler.setLevel(level) - formatter = logging.Formatter(format_string, datefmt="%Y-%m-%d %H:%M:%S") - handler.setFormatter(formatter) - logger.addHandler(handler) - - ws_logger = logging.getLogger("asyncio") - ws_logger.addHandler(handler) - return logger - - -logging.getLogger("funcx").addHandler(logging.NullHandler()) diff --git a/funcx_sdk/funcx/utils/response_errors.py b/funcx_sdk/funcx/utils/response_errors.py index 085e14c9d..6b9a9c33c 100644 --- a/funcx_sdk/funcx/utils/response_errors.py +++ b/funcx_sdk/funcx/utils/response_errors.py @@ -2,8 +2,10 @@ from enum import Enum -# IMPORTANT: new error codes can be added, but existing error codes must not be changed once published. -# changing existing error codes will cause problems with users that have older SDK versions +# IMPORTANT: new error codes can be added, but existing error codes must not be changed +# once published. 
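+# (for example, unpack() below maps the integer code back to an exception class, so
+# reusing a value would surface as the wrong exception type for older clients)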
+# changing existing error codes will cause problems with users that have older SDK +# versions class ResponseErrorCode(int, Enum): UNKNOWN_ERROR = 0 USER_UNAUTHENTICATED = 1 @@ -86,8 +88,9 @@ def unpack(cls, res_data): ): try: # if the response error code is not recognized here because the - # user is not using the latest SDK version, an exception will occur here - # which we will pass in order to give the user a generic exception below + # user is not using the latest SDK version, an exception will occur + # here, which we will pass in order to give the user a generic + # exception below res_error_code = ResponseErrorCode(res_data["code"]) error_class = None if res_error_code is ResponseErrorCode.USER_UNAUTHENTICATED: @@ -179,9 +182,10 @@ def __init__(self): class UserNotFound(FuncxResponseError): - """User not found exception. This error should only be used when the server must - look up a user in order to fulfill the user's request body. If the request only - fails because the user is unauthenticated, UserUnauthenticated should be used instead. + """ + User not found exception. This error should only be used when the server must look + up a user in order to fulfill the user's request body. If the request only fails + because the user is unauthenticated, UserUnauthenticated should be used instead. """ code = ResponseErrorCode.USER_NOT_FOUND @@ -399,7 +403,10 @@ class EndpointOutdated(FuncxResponseError): def __init__(self, min_ep_version): self.error_args = [min_ep_version] - self.reason = f"Endpoint is out of date. Minimum supported endpoint version is {min_ep_version}" + self.reason = ( + "Endpoint is out of date. " + f"Minimum supported endpoint version is {min_ep_version}" + ) class TaskGroupNotFound(FuncxResponseError): diff --git a/funcx_sdk/requirements.txt b/funcx_sdk/requirements.txt deleted file mode 100644 index 810968bcf..000000000 --- a/funcx_sdk/requirements.txt +++ /dev/null @@ -1,18 +0,0 @@ -# request sending and authorization tools -requests>=2.20.0 -globus-sdk<3 - -# 'websockets' is used for the client-side websocket listener -websockets==9.1 - -# table printing used in search result rendering -texttable>=1.6.4,<2 - -# versions >=0.2.3 requires globus-sdk v3 -# TODO: update pin to latest when globus-sdk is updated -fair_research_login==0.2.2 - -# dill is an extension of `pickle` to a wider array of native python types -# pin to the latest version, as 'dill' is not at 1.0 and does not have a clear -# versioning and compatibility policy -dill==0.3.4 diff --git a/funcx_sdk/setup.py b/funcx_sdk/setup.py index 5688332d7..f12aff36a 100644 --- a/funcx_sdk/setup.py +++ b/funcx_sdk/setup.py @@ -2,20 +2,47 @@ from setuptools import find_namespace_packages, setup +REQUIRES = [ + # request sending and authorization tools + "requests>=2.20.0", + "globus-sdk<3", + # 'websockets' is used for the client-side websocket listener + "websockets==9.1", + # table printing used in search result rendering + "texttable>=1.6.4,<2", + # versions >=0.2.3 requires globus-sdk v3 + # TODO: update pin to latest when globus-sdk is updated + "fair_research_login==0.2.2", + # dill is an extension of `pickle` to a wider array of native python types + # pin to the latest version, as 'dill' is not at 1.0 and does not have a clear + # versioning and compatibility policy + "dill==0.3.4", +] + +TEST_REQUIRES = [ + "flake8==3.8.0", + "numpy", + "pytest", +] +DEV_REQUIRES = TEST_REQUIRES + [ + "pre-commit", +] + version_ns = {} with open(os.path.join("funcx", "sdk", "version.py")) as f: exec(f.read(), 
version_ns) version = version_ns["VERSION"] -with open("requirements.txt") as f: - install_requires = f.readlines() - setup( name="funcx", version=version, packages=find_namespace_packages(include=["funcx", "funcx.*"]), description="funcX: High Performance Function Serving for Science", - install_requires=install_requires, + install_requires=REQUIRES, + extras_require={ + "dev": DEV_REQUIRES, + "test": TEST_REQUIRES, + }, python_requires=">=3.6.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/funcx_sdk/test-requirements.txt b/funcx_sdk/test-requirements.txt deleted file mode 100644 index c6574756d..000000000 --- a/funcx_sdk/test-requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -flake8==3.8.0 -numpy -pytest - -# TODO: when we refactor dev requirements to be kept in extras, move this into -# `funcx_sdk[dev]` and the other requirements into `funcx_sdk[test]` -# for an example of this kind of config, see attrs: -# https://github.com/python-attrs/attrs/blob/fb154878ebd2758d2a3b4dc518d21fd4f73e12d2/setup.py#L64-L67 -pre-commit diff --git a/helm/README.md b/helm/README.md index 96550a97d..6a308e1c7 100644 --- a/helm/README.md +++ b/helm/README.md @@ -7,20 +7,18 @@ permissions to create the worker pod. ## How to Use First you need to install valid funcX credentials into the cluster's -namespace. Launch a local version of the endpoint to get it to populate your -`~/.funcx/credentials/funcx_sdk_tokens.json` with -``` -funcx-endpoint start -``` - +namespace. If you've used the funcX client, these will already be available +in your home directory's `.funcx/credentials` folder. If not, they can easily +be created with: +```shell +pip install funcx +python -c "from funcx.sdk.client import FuncXClient; FuncXClient()" +```` It will prompt you with an authentication URL to visit and ask you to paste the -resulting token. After it completes you can stop your endpoint with -``` -funcx-endpoint stop -``` +resulting token. -cd to your `~/.funcx/credentials` directory and install the keys file as a -kubernetes secret. +Now that you have a valid funcX token, cd to your `~/.funcx/credentials` +directory and install the keys file as a kubernetes secret. ```shell script kubectl create secret generic funcx-sdk-tokens --from-file=funcx_sdk_tokens.json
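# optional sanity check: confirm the secret was created in the current namespace
kubectl get secret funcx-sdk-tokens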