Scone: use slurm dependencies to submit heatmaps and model jobs simultaneously #104

Merged
merged 4 commits into from Oct 3, 2022
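
For context on the change itself: instead of the Python-side polling loop that previously waited for the heatmap jobs to finish before sbatching the model job, the ordering is now delegated to slurm job dependencies. A minimal sketch of that pattern is below (script paths and the helper name are illustrative, not taken from this diff):

```python
# Hypothetical sketch of the slurm-dependency pattern: submit the heatmap
# jobs first, then submit the model job gated on their successful completion,
# so no Python process has to stay alive and poll the queue.
import re
import subprocess

def submit_with_dependency(heatmap_scripts, model_script):
    job_ids = []
    for script in heatmap_scripts:
        # sbatch prints "Submitted batch job <id>"; grab the id.
        result = subprocess.run(["sbatch", script], capture_output=True, text=True, check=True)
        job_ids.append(re.search(r"\d+", result.stdout).group(0))

    # afterok: the model job only starts once every heatmap job exits with code 0.
    dependency = "afterok:" + ":".join(job_ids)
    subprocess.run(["sbatch", f"--dependency={dependency}", model_script], check=True)
```

In the diff below, the model job's sbatch script path is written into the heatmap config (`model_sbatch_job_path`) and `classify` now calls the classifier's `run.py` rather than sbatching the model job itself, which is consistent with `run.py` handling the dependent submission.
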
152 changes: 68 additions & 84 deletions pippin/classifiers/scone.py
@@ -1,4 +1,3 @@
import os
import shutil
import subprocess
from pathlib import Path
@@ -65,11 +64,13 @@ def __init__(self, name, output_dir, config, dependencies, mode, options, index=
{task_setup}"""

self.config_path = str(output_path_obj / "model_config.yml")
self.logfile = str(output_path_obj / "output.log")
self.model_sbatch_job_path = str(output_path_obj / "job.slurm")

self.heatmaps_path = str(heatmaps_path_obj)
self.heatmaps_done_file = str(heatmaps_path_obj / "done.txt")
self.heatmaps_sbatch_header_path = str(heatmaps_path_obj / "sbatch_header.sh")

self.logfile = str(output_path_obj / "output.log")
self.heatmaps_log_path = str(heatmaps_path_obj / f"create_heatmaps__{Path(self.config_path).name.split('.')[0]}.log")

remake_heatmaps = self.options.get("REMAKE_HEATMAPS", False)
self.keep_heatmaps = not remake_heatmaps
@@ -88,68 +89,23 @@ def make_sbatch_header(self, option_name, header_dict, use_gpu=False):
header_dict = merge_dict(header_dict, self.batch_replace)
return self._update_header(sbatch_header, header_dict)
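
As a side note, the REPLACE_* keys used throughout (REPLACE_NAME, REPLACE_LOGFILE, REPLACE_WALLTIME, REPLACE_MEM) are placeholder tokens substituted into an sbatch header template. A minimal illustration of that substitution, with a made-up template (the real one comes from the configured batch file, e.g. HEATMAPS_BATCH_FILE), could look like:

```python
# Illustrative only: a toy sbatch header template with REPLACE_* placeholders
# and a straightforward string substitution, mimicking what make_sbatch_header
# / _update_header do with the configured batch file.
template = """#!/bin/bash
#SBATCH --job-name=REPLACE_NAME
#SBATCH --output=REPLACE_LOGFILE
#SBATCH --time=REPLACE_WALLTIME
#SBATCH --mem=REPLACE_MEM
"""

def fill_header(template, header_dict):
    for placeholder, value in header_dict.items():
        template = template.replace(placeholder, value)
    return template

print(fill_header(template, {
    "REPLACE_NAME": "scone_heatmaps",
    "REPLACE_LOGFILE": "create_heatmaps.log",
    "REPLACE_WALLTIME": "10:00:00",
    "REPLACE_MEM": "16GB",
}))
```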

def make_heatmaps(self, mode):
self.logger.info("heatmaps not created, creating now")
shutil.rmtree(self.output_dir, ignore_errors=True)
mkdirs(self.heatmaps_path)

sim_dep = self.get_simulation_dependency()
sim_dirs = sim_dep.output["photometry_dirs"][self.index] # if multiple realizations, get only the current one with self.index

lcdata_paths = self._get_lcdata_paths(sim_dirs)
metadata_paths = [path.replace("PHOT", "HEAD") for path in lcdata_paths]

# TODO: if externally specified batchfile exists, have to parse desired logfile path from it
header_dict = {
"REPLACE_LOGFILE": self.heatmaps_log_path,
"REPLACE_WALLTIME": "10:00:00", #TODO: change to scale with # of heatmaps expected
"REPLACE_MEM": "16GB",
}
heatmaps_sbatch_header = self.make_sbatch_header("HEATMAPS_BATCH_FILE", header_dict)
with open(self.heatmaps_sbatch_header_path, "w+") as f:
f.write(heatmaps_sbatch_header)

self._write_config_file(metadata_paths, lcdata_paths, mode, self.config_path)

# call create_heatmaps/run.py, which sbatches X create heatmaps jobs
subprocess.run([f"python {Path(self.path_to_classifier) / 'create_heatmaps/run.py'} --config_path {self.config_path}"], shell=True)

def classify(self, mode):
heatmaps_created = self._heatmap_creation_success() and self.keep_heatmaps

if not heatmaps_created:
self.heatmaps_log_path = os.path.join(self.heatmaps_path, f"create_heatmaps__{os.path.basename(self.config_path).split('.')[0]}.log")
self.make_heatmaps(mode)
# TODO: check status in a different job? but what if the job doesn't run or runs after the other ones are already completed?
# -- otherwise if ssh connection dies the classification won't run
# -- any better solution than while loop + sleep?

start_sleep_time = self.global_config["OUTPUT"]["ping_frequency"]
max_sleep_time = self.global_config["OUTPUT"]["max_ping_frequency"]
current_sleep_time = start_sleep_time

while self.num_jobs_in_queue() > 0:
self.logger.debug(f"> 0 {self.job_base_name} jobs still in the queue, sleeping for {current_sleep_time}")
time.sleep(current_sleep_time)
current_sleep_time *= 2
if current_sleep_time > max_sleep_time:
current_sleep_time = max_sleep_time

self.logger.debug("jobs done, evaluating success")
if not self._heatmap_creation_success():
self.logger.error(f"heatmaps were not created successfully, see logs at {self.heatmaps_log_path}")
return Task.FINISHED_FAILURE

# when all done, sbatch a gpu job for actual classification
self.logger.info("heatmaps created, continuing")
def make_heatmaps_sbatch_header(self):
self.logger.info("heatmaps not created, creating now")
shutil.rmtree(self.output_dir, ignore_errors=True)
mkdirs(self.heatmaps_path)

failed = False
if os.path.exists(self.done_file):
self.logger.debug(f"Found done file at {self.done_file}")
with open(self.done_file) as f:
if "FAILURE" in f.read().upper():
failed = True
# TODO: if externally specified batchfile exists, have to parse desired logfile path from it
header_dict = {
"REPLACE_LOGFILE": self.heatmaps_log_path,
"REPLACE_WALLTIME": "10:00:00", #TODO: change to scale with # of heatmaps expected
"REPLACE_MEM": "8GB",
}
heatmaps_sbatch_header = self.make_sbatch_header("HEATMAPS_BATCH_FILE", header_dict)

with open(self.heatmaps_sbatch_header_path, "w+") as f:
f.write(heatmaps_sbatch_header)

def make_model_sbatch_script(self):
header_dict = {
"REPLACE_NAME": self.job_base_name,
"REPLACE_LOGFILE": "output.log",
@@ -171,26 +127,51 @@ def classify(self, mode):
"sbatch_header": model_sbatch_header,
"task_setup": self.update_setup(setup_dict, self.task_setup['scone'])
}
slurm_output_file = Path(self.output_dir) / "job.slurm"
self.logger.info(f"Running SCONE, slurm job outputting to {slurm_output_file}")

self.logger.info(f"Running SCONE model, slurm job written to {self.model_sbatch_job_path}")
slurm_script = self.slurm.format(**format_dict)

new_hash = self.get_hash_from_string(slurm_script)
with open(self.model_sbatch_job_path, "w") as f:
f.write(slurm_script)

return slurm_script

def classify(self, mode):
failed = False
if Path(self.done_file).exists():
self.logger.debug(f"Found done file at {self.done_file}")
with open(self.done_file) as f:
if "SUCCESS" not in f.read().upper():
failed = True

if self._check_regenerate(new_hash) or failed:
self.logger.debug("Regenerating")

with open(slurm_output_file, "w") as f:
f.write(slurm_script)
heatmaps_created = self._heatmap_creation_success() and self.keep_heatmaps
if not heatmaps_created:
self.make_heatmaps_sbatch_header()

sim_dep = self.get_simulation_dependency()
sim_dirs = sim_dep.output["photometry_dirs"][self.index] # if multiple realizations, get only the current one with self.index

lcdata_paths = self._get_lcdata_paths(sim_dirs)
metadata_paths = [path.replace("PHOT", "HEAD") for path in lcdata_paths]

self._write_config_file(metadata_paths, lcdata_paths, mode, self.config_path, heatmaps_created)

slurm_script = self.make_model_sbatch_script()
new_hash = self.get_hash_from_string(slurm_script)

self.save_new_hash(new_hash)
self.logger.info(f"Submitting batch job {slurm_output_file}")
self.logger.info(f"Submitting batch job {self.model_sbatch_job_path}")

# TODO: nersc needs `module load esslurm` to sbatch gpu jobs, maybe make
# this shell command to a file so diff systems can define their own
subprocess.run(f"sbatch {slurm_output_file}", cwd=self.output_dir, shell=True)
subprocess.run([f"python {Path(self.path_to_classifier) / 'run.py'} --config_path {self.config_path}"], shell=True)
else:
self.logger.info("Hash check passed, not rerunning")
self.should_be_done()

return True
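
Regarding the `module load esslurm` TODO a few lines up: one way to make the submission command configurable per system, rather than hard-coding `sbatch` for GPU jobs, is sketched below (the `SBATCH_COMMAND` option name is hypothetical and not part of this diff):

```python
# Hypothetical sketch: let each system define its own submit command so that,
# e.g., NERSC can prepend "module load esslurm" before sbatching GPU jobs.
import subprocess

def submit_gpu_job(script_path, cwd, submit_cmd="sbatch"):
    # submit_cmd could be read from task options, e.g.
    # options.get("SBATCH_COMMAND", "sbatch")
    subprocess.run(f"{submit_cmd} {script_path}", cwd=cwd, shell=True, check=True)

# Usage on a system that needs esslurm:
# submit_gpu_job("job.slurm", cwd=output_dir, submit_cmd="module load esslurm && sbatch")
```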

def predict(self):
@@ -204,23 +185,26 @@ def _get_types(self):
t = self.get_simulation_dependency().output
return t["types"]

def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path):
def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path, heatmaps_created):
config = {}

# environment configuration
config["init_env_heatmaps"] = self.init_env_heatmaps
config["init_env"] = self.init_env

# info for heatmap creation
config["metadata_paths"] = metadata_paths
config["lcdata_paths"] = lcdata_paths
config["num_wavelength_bins"] = self.options.get("NUM_WAVELENGTH_BINS", 32)
config["num_mjd_bins"] = self.options.get("NUM_MJD_BINS", 180)
config["heatmaps_path"] = self.heatmaps_path
config["sbatch_header_path"] = self.heatmaps_sbatch_header_path
if not heatmaps_created:
config["sbatch_header_path"] = self.heatmaps_sbatch_header_path

config["heatmaps_donefile"] = self.heatmaps_done_file
config["heatmaps_logfile"] = self.heatmaps_log_path
config["sim_fraction"] = self.options.get("SIM_FRACTION", 1) # 1/sim_fraction % of simulated SNe will be used for the model
config["heatmaps_path"] = self.heatmaps_path
config["model_sbatch_job_path"] = self.model_sbatch_job_path
config["num_wavelength_bins"] = self.options.get("NUM_WAVELENGTH_BINS", 32)
config["num_mjd_bins"] = self.options.get("NUM_MJD_BINS", 180)
config["metadata_paths"] = metadata_paths
config["lcdata_paths"] = lcdata_paths

# info for classification model
config["categorical"] = self.options.get("CATEGORICAL", False)
Expand All @@ -242,29 +226,29 @@ def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path):
cfgfile.write(yaml.dump(config))

def _check_completion(self, squeue):
if os.path.exists(self.done_file):
if Path(self.done_file).exists():
self.logger.debug(f"Found done file at {self.done_file}")
with open(self.done_file) as f:
if "FAILURE" in f.read().upper():
if "SUCCESS" not in f.read().upper():
return Task.FINISHED_FAILURE

pred_path = os.path.join(self.output_dir, "predictions.csv")
pred_path = str(Path(self.output_dir) / "predictions.csv")
predictions = pd.read_csv(pred_path)
predictions = predictions[["snid", "pred_labels"]] # make sure snid is the first col
predictions = predictions.rename(columns={"pred_labels": self.get_prob_column_name()})
predictions.to_csv(pred_path, index=False)
self.logger.info(f"Predictions file can be found at {pred_path}")
self.output.update({"model_filename": self.options.get("MODEL", os.path.join(self.output_dir, "trained_model")), "predictions_filename": pred_path})
self.output.update({"model_filename": self.options.get("MODEL", str(Path(self.output_dir) / "trained_model")), "predictions_filename": pred_path})
return Task.FINISHED_SUCCESS
return self.check_for_job(squeue, self.job_base_name)

def _heatmap_creation_success(self):
if not os.path.exists(self.heatmaps_done_file):
if not Path(self.heatmaps_done_file).exists():
return False
with open(self.heatmaps_done_file, "r") as donefile:
if "CREATE HEATMAPS FAILURE" in donefile.read():
return False
return os.path.exists(self.heatmaps_path) and os.path.exists(os.path.join(self.heatmaps_path, "done.log"))
return Path(self.heatmaps_path).exists() and (Path(self.heatmaps_path) / "done.log").exists()

def num_jobs_in_queue(self):
print("rerun num jobs in queue")
@@ -274,7 +258,7 @@ def num_jobs_in_queue(self):

@staticmethod
def _get_lcdata_paths(sim_dir):
lcdata_paths = [f.path for f in os.scandir(sim_dir) if "PHOT" in f.path]
lcdata_paths = [str(f) for f in Path(sim_dir).iterdir() if "PHOT" in f.name]
return lcdata_paths

@staticmethod