Skip to content

Commit

Permalink
Merge pull request #345 from pyiron/toplevel
Browse files Browse the repository at this point in the history
Extend QueueAdapter to support dynamic configuration
  • Loading branch information
jan-janssen authored Sep 28, 2024
2 parents 4bc7b9b + b8bf0fa commit 6ccd8be
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 55 deletions.
147 changes: 93 additions & 54 deletions pysqa/queueadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from jinja2 import Template

from pysqa.base.abstract import QueueAdapterAbstractClass
from pysqa.base.config import QueueAdapterWithConfig, read_config
from pysqa.base.core import execute_command
from pysqa.base.config import QueueAdapterWithConfig, Queues, read_config
from pysqa.base.core import QueueAdapterCore, execute_command
from pysqa.base.modular import ModularQueueAdapter


Expand Down Expand Up @@ -39,7 +39,10 @@ class QueueAdapter(QueueAdapterAbstractClass):
"""

def __init__(
self, directory: str = "~/.queues", execute_command: callable = execute_command
self,
directory: Optional[str] = None,
queue_type: Optional[str] = None,
execute_command: callable = execute_command,
):
"""
Initialize the QueueAdapter.
Expand All @@ -48,35 +51,41 @@ def __init__(
directory (str): Directory containing the queue.yaml files and corresponding templates.
execute_command (callable): Function to execute commands.
"""
queue_yaml = os.path.join(directory, "queue.yaml")
clusters_yaml = os.path.join(directory, "clusters.yaml")
self._adapter = None
if os.path.exists(queue_yaml):
self._queue_dict = {
"default": set_queue_adapter(
config=read_config(file_name=queue_yaml),
directory=directory,
execute_command=execute_command,
)
}
primary_queue = "default"
elif os.path.exists(clusters_yaml):
config = read_config(file_name=clusters_yaml)
self._queue_dict = {
k: set_queue_adapter(
config=read_config(file_name=os.path.join(directory, v)),
directory=directory,
execute_command=execute_command,
if directory is not None:
queue_yaml = os.path.join(directory, "queue.yaml")
clusters_yaml = os.path.join(directory, "clusters.yaml")
self._adapter = None
if os.path.exists(queue_yaml):
self._queue_dict = {
"default": set_queue_adapter(
config=read_config(file_name=queue_yaml),
directory=directory,
execute_command=execute_command,
)
}
primary_queue = "default"
elif os.path.exists(clusters_yaml):
config = read_config(file_name=clusters_yaml)
self._queue_dict = {
k: set_queue_adapter(
config=read_config(file_name=os.path.join(directory, v)),
directory=directory,
execute_command=execute_command,
)
for k, v in config["cluster"].items()
}
primary_queue = config["cluster_primary"]
else:
raise ValueError(
"Neither a queue.yaml file nor a clusters.yaml file were found in "
+ directory
)
for k, v in config["cluster"].items()
}
primary_queue = config["cluster_primary"]
self._adapter = self._queue_dict[primary_queue]
elif queue_type is not None:
self._queue_dict = {}
self._adapter = QueueAdapterCore(queue_type=queue_type.upper())
else:
raise ValueError(
"Neither a queue.yaml file nor a clusters.yaml file were found in "
+ directory
)
self._adapter = self._queue_dict[primary_queue]
raise ValueError()

def list_clusters(self) -> List[str]:
"""
Expand All @@ -97,14 +106,17 @@ def switch_cluster(self, cluster_name: str):
self._adapter = self._queue_dict[cluster_name]

@property
def config(self) -> dict:
def config(self) -> Union[dict, None]:
"""
Get the QueueAdapter configuration.
Returns:
dict: The QueueAdapter configuration.
"""
return self._adapter.config
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.config
else:
return None

@property
def ssh_delete_file_on_remote(self) -> bool:
Expand All @@ -114,7 +126,10 @@ def ssh_delete_file_on_remote(self) -> bool:
Returns:
bool: The value of ssh_delete_file_on_remote property.
"""
return self._adapter.ssh_delete_file_on_remote
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.ssh_delete_file_on_remote
else:
return False

@property
def remote_flag(self) -> bool:
Expand All @@ -124,37 +139,49 @@ def remote_flag(self) -> bool:
Returns:
bool: The value of remote_flag property.
"""
return self._adapter.remote_flag
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.remote_flag
else:
return False

@property
def queue_list(self) -> List[str]:
def queue_list(self) -> Union[List[str], None]:
"""
Get the list of available queues.
Returns:
List[str]: The list of available queues.
"""
return self._adapter.queue_list
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.queue_list
else:
return None

@property
def queue_view(self) -> pandas.DataFrame:
def queue_view(self) -> Union[pandas.DataFrame, None]:
"""
Get the Pandas DataFrame representation of the available queues.
Returns:
pandas.DataFrame: The Pandas DataFrame representation of the available queues.
"""
return self._adapter.queue_view
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.queue_view
else:
return None

@property
def queues(self) -> List[str]:
def queues(self) -> Union[Queues, None]:
"""
Get the list of available queues.
Returns:
List[str]: The list of available queues.
"""
return self._adapter.queues
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.queues
else:
return None

def submit_job(
self,
Expand Down Expand Up @@ -220,7 +247,10 @@ def get_job_from_remote(self, working_directory: str):
Args:
working_directory (str): The working directory.
"""
self._adapter.get_job_from_remote(working_directory=working_directory)
if isinstance(self._adapter, QueueAdapterWithConfig):
self._adapter.get_job_from_remote(working_directory=working_directory)
else:
raise TypeError()

def transfer_file_to_remote(
self,
Expand All @@ -236,11 +266,14 @@ def transfer_file_to_remote(
transfer_back (bool): Whether to transfer the file back.
delete_file_on_remote (bool): Whether to delete the file on the remote host.
"""
self._adapter.transfer_file(
file=file,
transfer_back=transfer_back,
delete_file_on_remote=delete_file_on_remote,
)
if isinstance(self._adapter, QueueAdapterWithConfig):
self._adapter.transfer_file(
file=file,
transfer_back=transfer_back,
delete_file_on_remote=delete_file_on_remote,
)
else:
raise TypeError()

def convert_path_to_remote(self, path: str) -> str:
"""
Expand All @@ -252,7 +285,10 @@ def convert_path_to_remote(self, path: str) -> str:
Returns:
str: The remote path.
"""
return self._adapter.convert_path_to_remote(path=path)
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.convert_path_to_remote(path=path)
else:
raise TypeError()

def delete_job(self, process_id: int) -> str:
"""
Expand Down Expand Up @@ -334,13 +370,16 @@ def check_queue_parameters(
Returns:
List: A list containing the checked parameters [cores, run_time_max, memory_max].
"""
return self._adapter.check_queue_parameters(
queue=queue,
cores=cores,
run_time_max=run_time_max,
memory_max=memory_max,
active_queue=active_queue,
)
if isinstance(self._adapter, QueueAdapterWithConfig):
return self._adapter.check_queue_parameters(
queue=queue,
cores=cores,
run_time_max=run_time_max,
memory_max=memory_max,
active_queue=active_queue,
)
else:
return cores, run_time_max, memory_max


def set_queue_adapter(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def test_missing_config(self):
with self.assertRaises(ValueError):
QueueAdapter(directory=os.path.join(self.path, "config/error"))

def test_no_config(self):
with self.assertRaises(ValueError):
QueueAdapter()

def test_bad_queue_template(self):
with self.assertRaises(TemplateSyntaxError):
QueueAdapter(directory=os.path.join(self.path, "config/bad_template"))
Expand Down
Loading

0 comments on commit 6ccd8be

Please sign in to comment.