diff --git a/executorlib/interactive/flux.py b/executorlib/interactive/flux.py index 6537bbef..2c324828 100644 --- a/executorlib/interactive/flux.py +++ b/executorlib/interactive/flux.py @@ -7,7 +7,7 @@ from executorlib.standalone.interactive.spawner import BaseSpawner -def validate_max_workers(max_workers, cores, threads_per_core): +def validate_max_workers(max_workers: int, cores: int, threads_per_core: int): handle = flux.Flux() cores_total = flux.resource.list.resource_list(handle).get().up.ncores cores_requested = max_workers * cores * threads_per_core diff --git a/executorlib/interactive/slurm.py b/executorlib/interactive/slurm.py index b662061a..8c962529 100644 --- a/executorlib/interactive/slurm.py +++ b/executorlib/interactive/slurm.py @@ -6,8 +6,8 @@ SLURM_COMMAND = "srun" -def validate_max_workers(max_workers, cores, threads_per_core): - cores_total = os.environ["SLURM_NTASKS"] * os.environ["SLURM_CPUS_PER_TASK"] +def validate_max_workers(max_workers: int, cores: int, threads_per_core: int): + cores_total = int(os.environ["SLURM_NTASKS"]) * int(os.environ["SLURM_CPUS_PER_TASK"]) cores_requested = max_workers * cores * threads_per_core if cores_total < cores_requested: raise ValueError(