Skip to content

Commit

Permalink
Add reservation support in slurm sync for scheduled maintenance
Browse files Browse the repository at this point in the history
  • Loading branch information
harshthakkar01 committed Aug 8, 2024
1 parent 2b62970 commit a9c754e
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from itertools import chain
from pathlib import Path
import yaml
from datetime import datetime
import subprocess
from typing import List, Dict

import util
from util import (
Expand Down Expand Up @@ -70,6 +73,10 @@
),
)

# scontrol command templates for managing Slurm maintenance reservations.
# Create: placeholders are (starttime, nodes, reservationname); the
# reservation is made as root with a fixed duration of 180 minutes.
SLURM_CREATE_RESERVATION = "scontrol create reservation user=root starttime={} duration=180 nodes={} reservationname={}"
# Delete: placeholder is the reservation name.
SLURM_DELETE_RESERVATION = "scontrol delete reservation {}"
# List all reservations; output blocks are parsed by get_slurm_reservation_maintenance.
SLURM_SHOW_RESERVATION = "scontrol show reservation"


def start_instance_op(inst):
return lkp.compute.instances().start(
Expand Down Expand Up @@ -472,6 +479,115 @@ def update_topology(lkp: util.Lookup) -> None:
log.debug("Topology configuration updated. Reconfiguring Slurm.")
util.scontrol_reconfigure(lkp)


def delete_reservations(reservation_map: Dict[str, str]) -> None:
    """Delete every Slurm reservation named in *reservation_map*.

    Each successful delete also removes the entry from the map, so the
    caller's map is emptied as a side effect.
    """
    names = list(reservation_map)
    for name in names:
        util.run(SLURM_DELETE_RESERVATION.format(name), timeout=30)
        reservation_map.pop(name)


def update_slurm_reservation_maintenance(reservation_name: str, node: str,
                                         starttime: str, reservation_map: Dict[str, str]) -> None:
    """Create or refresh the Slurm reservation for a node's maintenance window.

    Args:
        reservation_name: expected to follow the "<node>_maintenance" convention.
        node: node the upcoming maintenance applies to.
        starttime: maintenance start time, Slurm-formatted ("%Y-%m-%dT%H:%M:%S").
        reservation_map: existing maintenance reservations,
            reservation_name -> "<nodes>, <start_time>". The handled entry is
            removed so only stale reservations remain for the caller to clean up.
    """
    if reservation_map is not None and reservation_name in reservation_map:
        details = reservation_map[reservation_name]

        values = details.split(",")
        nodes = values[0].strip()
        maintenance_start_time = values[1].strip()

        # Remove the handled entry up front: anything left in the map after the
        # sync loop is stale and gets deleted by the caller. The original code
        # returned early on an unchanged window WITHOUT removing the entry,
        # which let the caller delete a still-valid reservation.
        del reservation_map[reservation_name]

        # BUG FIX: the original tested the never-assigned local `start_time`
        # (always None), so this branch never ran and a reservation whose
        # maintenance window moved was never recreated with the new time.
        if nodes and maintenance_start_time:
            if maintenance_start_time == starttime:
                return  # Window unchanged; reservation already correct.

            util.run(SLURM_DELETE_RESERVATION.format(reservation_name), timeout=30)
            util.run(SLURM_CREATE_RESERVATION.format(starttime, nodes, reservation_name), timeout=30)

    else:  # Reservation doesn't exist; create one for the maintenance window.
        util.run(SLURM_CREATE_RESERVATION.format(starttime, node, reservation_name), timeout=30)


def get_slurm_reservation_maintenance() -> Dict[str, str]:
    """Collect existing Slurm maintenance reservations.

    Parses `scontrol show reservation` output and returns a map of
    reservation_name -> "<nodes>, <start_time>" for every reservation whose
    name follows the "<nodes>_maintenance" convention.
    """
    res = util.run(SLURM_SHOW_RESERVATION, timeout=30)
    reservation_map: Dict[str, str] = {}

    # Reservations are separated by blank lines; the trailing split element
    # after the final blank line is discarded.
    for entry in res.stdout.split("\n\n")[:-1]:
        attrs = {}
        for token in entry.split():
            key, value = token.split('=', 1)  # split at the first '=' only
            attrs[key] = value

        name = attrs.get('ReservationName')
        nodes = attrs.get('Nodes')
        start_time = attrs.get('StartTime')
        if name is None or nodes is None or start_time is None:
            continue

        # Keep only reservations created for scheduled maintenance,
        # i.e. those named "<nodes>_maintenance".
        if name == nodes + "_maintenance":
            reservation_map[name] = nodes + ', ' + start_time

    return reservation_map


def get_upcoming_maintenance(lkp: util.Lookup) -> Dict[str, str]:
    """Map node name -> earliest upcoming-maintenance start time.

    Nodes whose instance properties carry no 'upcomingMaintenance' entry
    are omitted from the result.
    """
    upc_maint_map = {}

    for node, properties in lkp.instances().items():
        maintenance = None
        for prop_key, prop_value in properties.items():
            if prop_key == 'upcomingMaintenance':
                maintenance = prop_value
        if maintenance is not None:
            upc_maint_map[node] = maintenance['startTimeWindow']['earliest']

    return upc_maint_map

# Sync maintenance reservation gets upcoming maintenance notification and
# updates slurm reservation for the node during scheduled maintenance window.
def sync_maintenance_reservation(lkp: util.Lookup) -> None:
    """Reconcile Slurm maintenance reservations with upcoming GCE maintenance.

    Creates/updates a reservation per node with pending maintenance, then
    deletes reservations whose maintenance is no longer upcoming.
    """
    # Upcoming maintenance per node: [node -> earliest-maintenance-start-time].
    upc_maint_map = get_upcoming_maintenance(lkp)
    log.debug(f"upcoming-maintenance-vms: {upc_maint_map}")

    # Current slurm maintenance reservations:
    # [reservation_name -> "<nodes>, <start_time>"], where reservation_name
    # follows the "<nodes>_maintenance" convention.
    reservation_map = get_slurm_reservation_maintenance()
    log.debug(f"reservation-map: {reservation_map}")

    # Update/create a reservation per upcoming maintenance; each handled
    # reservation is removed from reservation_map as a side effect.
    for node, start in upc_maint_map.items():
        reservation_name = node + "_maintenance"
        # Convert the RFC 3339 timestamp to Slurm's reservation time format.
        # NOTE(review): the parsed timezone offset is dropped here — assumes
        # the notification time matches the cluster's time convention; confirm.
        starttime = datetime.strptime(start, "%Y-%m-%dT%H:%M:%S%z")
        formatted_starttime = starttime.strftime("%Y-%m-%dT%H:%M:%S")

        update_slurm_reservation_maintenance(
            reservation_name, node, formatted_starttime, reservation_map
        )

    # Entries still left in the map belong to maintenance that already
    # finished; delete those stale reservations.
    if reservation_map:
        delete_reservations(reservation_map)


def main():
try:
reconfigure_slurm()
Expand All @@ -483,15 +599,22 @@ def main():
sync_slurm()
except Exception:
log.exception("failed to sync instances")

try:
sync_placement_groups()
except Exception:
log.exception("failed to sync placement groups")

try:
update_topology(lkp)
except Exception:
log.exception("failed to update topology")

try:
sync_maintenance_reservation(lkp)
except Exception:
log.exception("failed to sync slurm reservation for scheduled maintenance")

try:
install_custom_scripts(check_hash=True)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def install_custom_scripts(check_hash=False):
chown_slurm(dirs.custom_scripts / par)
need_update = True
if check_hash and fullpath.exists():
# TODO: MD5 reported by gcloud may differ from the one calculated here (e.g. if blob got gzipped),
# TODO: MD5 reported by gcloud may differ from the one calculated here (e.g. if blob got gzipped),
# consider using gCRC32C
need_update = hash_file(fullpath) != blob.md5_hash
if need_update:
Expand Down Expand Up @@ -509,7 +509,6 @@ def init_log_and_parse(parser: argparse.ArgumentParser) -> argparse.Namespace:
help="Enable detailed api request output",
)
args = parser.parse_args()

loglevel = args.loglevel
if cfg.enable_debug_logging:
loglevel = logging.DEBUG
Expand Down Expand Up @@ -557,7 +556,6 @@ def log_api_request(request):
"""log.trace info about a compute API request"""
if not cfg.extra_logging_flags.get("trace_api"):
return

# output the whole request object as pretty yaml
# the body is nested json, so load it as well
rep = json.loads(request.to_json())
Expand Down Expand Up @@ -1656,6 +1654,12 @@ def instances(self, project=None, slurm_cluster_name=None):
slurm_cluster_name=slurm_cluster_name,
instance_information_fields=instance_information_fields,
)

# TODO: Merge this with all fields when upcoming maintenance is
# supported in beta.
if lkp.endpoint_versions['compute'] == 'alpha':
instance_information_fields.append("upcomingMaintenance")

instance_information_fields = sorted(set(instance_information_fields))
instance_fields = ",".join(instance_information_fields)
fields = f"items.zones.instances({instance_fields}),nextPageToken"
Expand Down Expand Up @@ -1683,7 +1687,7 @@ def properties(inst):
instance_iter = (
(inst["name"], properties(inst))
for inst in chain.from_iterable(
m["instances"] for m in result.get("items", {}).values()
m["instances"] for m in result.get("items", {}).values() if "instances" in m
)
)
instances.update(
Expand Down

0 comments on commit a9c754e

Please sign in to comment.