Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add state filter tests and handle readonly database #8

Merged
merged 5 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 145 additions & 138 deletions src/gridtk/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,102 @@ def get_command(self, ctx, cmd_name):
ctx.fail(f"Too many matches: {', '.join(sorted(matches))}") # noqa: RET503


def parse_job_ids(job_ids: str) -> list[int]:
    """Parse a job id specification into a flat list of integer ids.

    Accepted forms, which may be combined with commas:

    * ``"7"`` -- a single id.
    * ``"3-5"`` -- an inclusive range, equivalent to ``3,4,5``.
    * ``"4+3"`` -- a start plus a length, equivalent to ``4,5,6,7``.

    Args:
        job_ids: The raw specification. An empty (or missing) value yields
            an empty list, and empty pieces between commas are ignored.

    Returns:
        The selected ids, in the order they were given.

    Raises:
        click.BadParameter: If a piece is not a valid integer, or if a
            range is empty (for example ``"5-3"`` or ``"5+-2"``).
    """
    if not job_ids:
        return []
    try:
        if "," in job_ids:
            final_job_ids = []
            for job_id in job_ids.split(","):
                final_job_ids.extend(parse_job_ids(job_id))
            return final_job_ids
        if "-" in job_ids:
            start_str, end_str = job_ids.split("-")
            first, last = int(start_str), int(end_str)
        elif "+" in job_ids:
            start_str, length = job_ids.split("+")
            first = int(start_str)
            last = first + int(length)
        else:
            return [int(job_ids)]
        # A reversed range used to produce an empty list without any
        # warning. Reject it so a typo cannot silently select nothing.
        # NOTE(review): presumably an empty id list means "no id filter"
        # downstream -- confirm against JobManager.
        if last < first:
            raise ValueError(f"empty range: {first} > {last}")
        return list(range(first, last + 1))
    except ValueError as e:
        raise click.BadParameter(f"Invalid job id {job_ids}") from e


def parse_states(states: str) -> list[str]:
    """Normalize a list of comma-separated states to their long name format.

    Matching is case-insensitive and whitespace around each entry is
    ignored, so ``"bf, ca"`` is treated like ``"BF,CA"``. The special value
    ``ALL`` selects every state known to ``JOB_STATES_MAPPING``.

    Args:
        states: The raw specification. An empty (or missing) value yields
            an empty list.

    Returns:
        The long state names, in the order they were given.

    Raises:
        click.BadParameter: If an entry is neither a key nor a value of
            ``JOB_STATES_MAPPING``.
    """
    if not states:
        return []

    # Imported inside the function, as before; only needed once there is
    # something to validate.
    from .models import JOB_STATES_MAPPING

    # Build the list of valid long names once instead of on every iteration.
    long_names = list(JOB_STATES_MAPPING.values())
    if states.strip().upper() == "ALL":
        return long_names
    final_states = []
    for raw_state in states.split(","):
        # Tolerate "BF, CA": strip blanks before looking the entry up.
        state = raw_state.strip().upper()
        state = JOB_STATES_MAPPING.get(state, state)
        if state not in long_names:
            valid_values = " ".join([*JOB_STATES_MAPPING.keys(), *long_names])
            raise click.BadParameter(
                f"Invalid state: {state}\nValid values are: ALL {valid_values} or a comma (,) separated list of them."
            )
        final_states.append(state)
    return final_states


def job_ids_callback(ctx, param, value):
    """Turn the raw job ids option value into a list of ids.

    ``ctx`` and ``param`` are accepted but not used; only ``value`` is
    forwarded to ``parse_job_ids``.
    """
    parsed_ids = parse_job_ids(value)
    return parsed_ids


def states_callback(ctx, param, value):
    """Turn the raw states option value into a list of long state names.

    ``ctx`` and ``param`` are accepted but not used; only ``value`` is
    forwarded to ``parse_states``.
    """
    parsed_states = parse_states(value)
    return parsed_states


def job_filters(f_py=None, default_states=None):
    """Attach the shared job-selection options to a click command.

    Works both as a bare decorator (``@job_filters``) and with arguments
    (``@job_filters(default_states="BF,CA")``). The decorated function
    receives ``names``, ``states``, ``job_ids`` and ``dependents``.
    """
    assert callable(f_py) or f_py is None
    from .models import JOB_STATES_MAPPING

    def _job_filters_decorator(function):
        state_choices = ", ".join(
            [f"{v} ({k})" for k, v in JOB_STATES_MAPPING.items()]
        )
        jobs_help = (
            "Selects only these job ids, separated by comma. A range can also be "
            "specified in the form 'start-end' ('-j 3-5' is equivalent to "
            "'-j 3,4,5') or in the form 'start+length' ('-j 4+3' is equivalent to "
            "'-j 4,5,6,7')."
        )
        # Listed in the order they are applied to the function.
        option_decorators = (
            click.option(
                "--name",
                "names",
                multiple=True,
                help="Selects jobs based on their name. For multiple names, repeat this option.",
            ),
            click.option(
                "-s",
                "--state",
                "states",
                default=default_states,
                help=f"Selects jobs based on their states separated by comma. Possible values are {state_choices} and ALL.",
                callback=states_callback,
            ),
            click.option(
                "-j",
                "--jobs",
                "job_ids",
                help=jobs_help,
                callback=job_ids_callback,
            ),
            click.option(
                "--dependents/--no-dependents",
                default=False,
                help="Select dependents jobs (jobs that depend on selected jobs) as well.",
            ),
        )
        for add_option in option_decorators:
            function = add_option(function)
        return function

    # Bare usage passes the function directly; parametrized usage expects
    # the decorator back.
    if callable(f_py):
        return _job_filters_decorator(f_py)
    return _job_filters_decorator


@click.group(
cls=CustomGroup,
context_settings={
Expand All @@ -53,14 +149,14 @@ def get_command(self, ctx, cmd_name):
"--database",
help="Path to the database file.",
default=Path("jobs.sql3"),
type=click.Path(path_type=Path, file_okay=True, dir_okay=False, writable=True),
type=click.Path(path_type=Path, file_okay=True, dir_okay=False),
)
@click.option(
"-l",
"--logs-dir",
help="Path to the logs directory.",
default=Path("logs"),
type=click.Path(path_type=Path, file_okay=False, dir_okay=True, writable=True),
type=click.Path(path_type=Path, file_okay=False, dir_okay=True),
)
@click.pass_context
def cli(ctx, database, logs_dir):
Expand Down Expand Up @@ -271,96 +367,6 @@ def submit(
session.commit()


def parse_job_ids(job_ids: str) -> list[int]:
    """Parse a job id specification into a flat list of integer ids.

    Accepted forms, which may be combined with commas:

    * ``"7"`` -- a single id.
    * ``"3-5"`` -- an inclusive range, equivalent to ``3,4,5``.
    * ``"4+3"`` -- a start plus a length, equivalent to ``4,5,6,7``.

    Args:
        job_ids: The raw specification. An empty (or missing) value yields
            an empty list, and empty pieces between commas are ignored.

    Returns:
        The selected ids, in the order they were given.

    Raises:
        click.BadParameter: If a piece is not a valid integer, or if a
            range is empty (for example ``"5-3"`` or ``"5+-2"``).
    """
    if not job_ids:
        return []
    try:
        if "," in job_ids:
            final_job_ids = []
            for job_id in job_ids.split(","):
                final_job_ids.extend(parse_job_ids(job_id))
            return final_job_ids
        if "-" in job_ids:
            start_str, end_str = job_ids.split("-")
            first, last = int(start_str), int(end_str)
        elif "+" in job_ids:
            start_str, length = job_ids.split("+")
            first = int(start_str)
            last = first + int(length)
        else:
            return [int(job_ids)]
        # A reversed range used to produce an empty list without any
        # warning. Reject it so a typo cannot silently select nothing.
        # NOTE(review): presumably an empty id list means "no id filter"
        # downstream -- confirm against JobManager.
        if last < first:
            raise ValueError(f"empty range: {first} > {last}")
        return list(range(first, last + 1))
    except ValueError as e:
        raise click.BadParameter(f"Invalid job id {job_ids}") from e


def parse_states(states: str) -> list[str]:
    """Normalize a list of comma-separated states to their long name format.

    Matching is case-insensitive and whitespace around each entry is
    ignored, so ``"bf, ca"`` is treated like ``"BF,CA"``. The special value
    ``ALL`` selects every state known to ``JOB_STATES_MAPPING``.

    Args:
        states: The raw specification. An empty (or missing) value yields
            an empty list.

    Returns:
        The long state names, in the order they were given.

    Raises:
        click.BadParameter: If an entry is neither a key nor a value of
            ``JOB_STATES_MAPPING``.
    """
    if not states:
        return []

    # Imported inside the function, as before; only needed once there is
    # something to validate.
    from .models import JOB_STATES_MAPPING

    # Build the list of valid long names once instead of on every iteration.
    long_names = list(JOB_STATES_MAPPING.values())
    if states.strip().upper() == "ALL":
        return long_names
    final_states = []
    for raw_state in states.split(","):
        # Tolerate "BF, CA": strip blanks before looking the entry up.
        state = raw_state.strip().upper()
        state = JOB_STATES_MAPPING.get(state, state)
        if state not in long_names:
            raise click.BadParameter(f"Invalid state: {state}")
        final_states.append(state)
    return final_states


def job_ids_callback(ctx, param, value):
    """Convert the raw job ids option value by delegating to ``parse_job_ids``.

    Neither ``ctx`` nor ``param`` is consulted.
    """
    job_id_list = parse_job_ids(value)
    return job_id_list


def states_callback(ctx, param, value):
    """Convert the raw states option value by delegating to ``parse_states``.

    Neither ``ctx`` nor ``param`` is consulted.
    """
    state_list = parse_states(value)
    return state_list


def job_filters(f_py=None, default_states=None):
    """Add the common job-selection options to a click command function.

    Can be applied bare (``@job_filters``) or called with arguments
    (``@job_filters(default_states="BF,CA")``). The wrapped function gains
    the ``names``, ``states``, ``job_ids`` and ``dependents`` parameters.
    """
    assert callable(f_py) or f_py is None
    from .models import JOB_STATES_MAPPING

    def _job_filters_decorator(function):
        known_states = ", ".join(
            [f"{v} ({k})" for k, v in JOB_STATES_MAPPING.items()]
        )
        with_names = click.option(
            "--name",
            "names",
            multiple=True,
            help="Selects jobs based on their name. For multiple names, repeat this option.",
        )
        with_states = click.option(
            "-s",
            "--state",
            "states",
            default=default_states,
            help=f"Selects jobs based on their states separated by comma. Possible values are {known_states} and ALL.",
            callback=states_callback,
        )
        # TODO: explain range notation in the help text below.
        with_job_ids = click.option(
            "-j",
            "--jobs",
            "job_ids",
            help="Selects only these job ids, separated by comma.",
            callback=job_ids_callback,
        )
        with_dependents = click.option(
            "--dependents/--no-dependents",
            default=False,
            help="Select dependents jobs (jobs that depend on selected jobs) as well.",
        )
        # Apply innermost first, in the same order as before.
        return with_dependents(with_job_ids(with_states(with_names(function))))

    if not callable(f_py):
        # Parametrized usage: hand back the decorator itself.
        return _job_filters_decorator
    return _job_filters_decorator(f_py)


@cli.command()
@job_filters(default_states="BF,CA,F,NF,OOM,TO")
@click.pass_context
Expand All @@ -384,6 +390,48 @@ def resubmit(
session.commit()


@cli.command(name="list")
@job_filters
@click.pass_context
def list_jobs(
    ctx: click.Context,
    job_ids: list[int],
    states: list[str],
    names: list[str],
    dependents: bool,
):
    """List jobs in the queue, similar to sacct and squeue."""
    from tabulate import tabulate

    from .manager import JobManager

    manager: JobManager = ctx.meta["job_manager"]
    with manager as session:
        selected_jobs = manager.list_jobs(
            job_ids=job_ids, states=states, names=names, dependents=dependents
        )
        # Column name -> list of cell values, one entry per job.
        columns = defaultdict(list)
        for job in selected_jobs:
            log_path = job.output_files[0].resolve()
            try:
                # Show the path relative to the current directory when possible.
                log_path = log_path.relative_to(Path.cwd().resolve())
            except ValueError:
                # Not located under the current directory: keep the
                # resolved absolute path.
                pass

            row = {
                "job-id": job.id,
                "slurm-id": job.grid_id,
                "nodes": job.nodes,
                "state": f"{job.state} ({job.exit_code})",
                "job-name": job.name,
                "output": log_path,
                "dependencies": ",".join(
                    [str(dep_job) for dep_job in job.dependencies_ids]
                ),
                "command": "gridtk submit " + " ".join(job.command),
            }
            for column, cell in row.items():
                columns[column].append(cell)
        click.echo(tabulate(columns, headers="keys"))
        session.commit()


@cli.command()
@job_filters
@click.pass_context
Expand All @@ -407,44 +455,26 @@ def stop(
session.commit()


@cli.command(name="list")
@cli.command()
@job_filters
@click.pass_context
def list_jobs(
def delete(
ctx: click.Context,
job_ids: list[int],
states: list[str],
names: list[str],
dependents: bool,
):
"""List jobs in the queue, similar to sacct and squeue."""
from tabulate import tabulate

"""Delete a job from the queue."""
from .manager import JobManager

job_manager: JobManager = ctx.meta["job_manager"]
with job_manager as session:
jobs = job_manager.list_jobs(
jobs = job_manager.delete_jobs(
job_ids=job_ids, states=states, names=names, dependents=dependents
)
table = defaultdict(list)
for job in jobs:
table["job-id"].append(job.id)
table["slurm-id"].append(job.grid_id)
table["nodes"].append(job.nodes)
table["state"].append(f"{job.state} ({job.exit_code})")
table["job-name"].append(job.name)
table["output"].append(
job_manager.logs_dir
/ job.output_files[0]
.resolve()
.relative_to(job_manager.logs_dir.resolve())
)
table["dependencies"].append(
",".join([str(dep_job) for dep_job in job.dependencies_ids])
)
table["command"].append("gridtk submit " + " ".join(job.command))
click.echo(tabulate(table, headers="keys"))
click.echo(f"Deleted job {job.id} with slurm id {job.grid_id}")
session.commit()


Expand Down Expand Up @@ -502,28 +532,5 @@ def report(
session.commit()


@cli.command()
@job_filters
@click.pass_context
def delete(
    ctx: click.Context,
    job_ids: list[int],
    states: list[str],
    names: list[str],
    dependents: bool,
):
    """Delete a job from the queue."""
    from .manager import JobManager

    manager: JobManager = ctx.meta["job_manager"]
    # The same filters are forwarded unchanged to the manager.
    selection = {
        "job_ids": job_ids,
        "states": states,
        "names": names,
        "dependents": dependents,
    }
    with manager as session:
        # Report every job returned by delete_jobs, one line each.
        for removed in manager.delete_jobs(**selection):
            click.echo(f"Deleted job {removed.id} with slurm id {removed.grid_id}")
        session.commit()


# Run the command group only when this module is executed as a script.
if __name__ == "__main__":
    cli()
19 changes: 16 additions & 3 deletions src/gridtk/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from collections.abc import Iterable
from pathlib import Path
from typing import Any
from typing import Any, Optional

import sqlalchemy

Expand Down Expand Up @@ -98,15 +98,26 @@ def get_dependent_jobs_recursive(jobs: Iterable[Job]) -> list[Job]:
class JobManager:
"""Implements a job manager for Slurm."""

def __init__(self, database: Path, logs_dir: Path) -> None:
def __init__(
self, database: Path, logs_dir: Path, read_only: Optional[bool] = None
) -> None:
self.database = Path(database)
# check if database exists and is read-only
if (
read_only is None
and self.database.exists()
and not os.access(self.database, os.W_OK)
):
read_only = True
self.read_only = read_only
self.engine = create_engine(f"sqlite:///{self.database}", echo=False)
self.logs_dir = Path(logs_dir)
self.logs_dir.mkdir(exist_ok=True)

def __enter__(self):
# opens a new session and returns it
Base.metadata.create_all(self.engine)
if not self.read_only:
Base.metadata.create_all(self.engine)
self._session = Session(self.engine)
self._session.begin()
return self._session
Expand Down Expand Up @@ -165,6 +176,8 @@ def submit_job(self, name, command, array, dependencies):

def update_jobs(self) -> None:
"""Update the status of all jobs."""
if self.read_only:
return
jobs_by_grid_id: dict[int, Job] = dict()
query = self.session.query(Job)
for job in query.all():
Expand Down
Loading
Loading