
Commit e9cb280
Merge pull request #104 from dessn/scone
Scone: use slurm dependencies to submit heatmaps and model jobs simultaneously
helenqu authored Oct 3, 2022
2 parents c9b0955 + cb2057e commit e9cb280
Showing 1 changed file with 68 additions and 84 deletions.
152 changes: 68 additions & 84 deletions pippin/classifiers/scone.py
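
The headline change: instead of Pippin sleeping in a polling loop until the heatmap jobs drain from the queue, the heatmap and model jobs are now submitted together, and Slurm's native dependency mechanism holds the model job until the heatmaps finish. A minimal sketch of that pattern, with hypothetical script names (SCONE's run.py owns the actual submission logic):

import subprocess

# --parsable makes sbatch print just the job id, so it can be captured.
heatmap_job_id = subprocess.run(
    ["sbatch", "--parsable", "create_heatmaps.slurm"],  # hypothetical script
    capture_output=True, text=True, check=True,
).stdout.strip()

# Slurm releases this job only once the heatmap job exits 0 (afterok).
subprocess.run(
    ["sbatch", f"--dependency=afterok:{heatmap_job_id}", "job.slurm"],
    check=True,
)

With afterok, a failed heatmap job leaves the model job held indefinitely, which is why the done-file checks in the diff below still matter.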
@@ -1,4 +1,3 @@
import os
import shutil
import subprocess
from pathlib import Path
@@ -65,11 +64,13 @@ def __init__(self, name, output_dir, config, dependencies, mode, options, index=
{task_setup}"""

self.config_path = str(output_path_obj / "model_config.yml")
self.logfile = str(output_path_obj / "output.log")
self.model_sbatch_job_path = str(output_path_obj / "job.slurm")

self.heatmaps_path = str(heatmaps_path_obj)
self.heatmaps_done_file = str(heatmaps_path_obj / "done.txt")
self.heatmaps_sbatch_header_path = str(heatmaps_path_obj / "sbatch_header.sh")

self.logfile = str(output_path_obj / "output.log")
self.heatmaps_log_path = str(heatmaps_path_obj / f"create_heatmaps__{Path(self.config_path).name.split('.')[0]}.log")

remake_heatmaps = self.options.get("REMAKE_HEATMAPS", False)
self.keep_heatmaps = not remake_heatmaps
@@ -88,68 +89,23 @@ def make_sbatch_header(self, option_name, header_dict, use_gpu=False):
header_dict = merge_dict(header_dict, self.batch_replace)
return self._update_header(sbatch_header, header_dict)

def make_heatmaps(self, mode):
self.logger.info("heatmaps not created, creating now")
shutil.rmtree(self.output_dir, ignore_errors=True)
mkdirs(self.heatmaps_path)

sim_dep = self.get_simulation_dependency()
sim_dirs = sim_dep.output["photometry_dirs"][self.index] # if multiple realizations, get only the current one with self.index

lcdata_paths = self._get_lcdata_paths(sim_dirs)
metadata_paths = [path.replace("PHOT", "HEAD") for path in lcdata_paths]

# TODO: if externally specified batchfile exists, have to parse desired logfile path from it
header_dict = {
"REPLACE_LOGFILE": self.heatmaps_log_path,
"REPLACE_WALLTIME": "10:00:00", #TODO: change to scale with # of heatmaps expected
"REPLACE_MEM": "16GB",
}
heatmaps_sbatch_header = self.make_sbatch_header("HEATMAPS_BATCH_FILE", header_dict)
with open(self.heatmaps_sbatch_header_path, "w+") as f:
f.write(heatmaps_sbatch_header)

self._write_config_file(metadata_paths, lcdata_paths, mode, self.config_path)

# call create_heatmaps/run.py, which sbatches X create heatmaps jobs
subprocess.run([f"python {Path(self.path_to_classifier) / 'create_heatmaps/run.py'} --config_path {self.config_path}"], shell=True)

def classify(self, mode):
heatmaps_created = self._heatmap_creation_success() and self.keep_heatmaps

if not heatmaps_created:
self.heatmaps_log_path = os.path.join(self.heatmaps_path, f"create_heatmaps__{os.path.basename(self.config_path).split('.')[0]}.log")
self.make_heatmaps(mode)
# TODO: check status in a different job? but what if the job doesn't run or runs after the other ones are already completed?
# -- otherwise if ssh connection dies the classification won't run
# -- any better solution than while loop + sleep?

start_sleep_time = self.global_config["OUTPUT"]["ping_frequency"]
max_sleep_time = self.global_config["OUTPUT"]["max_ping_frequency"]
current_sleep_time = start_sleep_time

while self.num_jobs_in_queue() > 0:
self.logger.debug(f"> 0 {self.job_base_name} jobs still in the queue, sleeping for {current_sleep_time}")
time.sleep(current_sleep_time)
current_sleep_time *= 2
if current_sleep_time > max_sleep_time:
current_sleep_time = max_sleep_time

self.logger.debug("jobs done, evaluating success")
if not self._heatmap_creation_success():
self.logger.error(f"heatmaps were not created successfully, see logs at {self.heatmaps_log_path}")
return Task.FINISHED_FAILURE

# when all done, sbatch a gpu job for actual classification
self.logger.info("heatmaps created, continuing")
def make_heatmaps_sbatch_header(self):
self.logger.info("heatmaps not created, creating now")
shutil.rmtree(self.output_dir, ignore_errors=True)
mkdirs(self.heatmaps_path)

failed = False
if os.path.exists(self.done_file):
self.logger.debug(f"Found done file at {self.done_file}")
with open(self.done_file) as f:
if "FAILURE" in f.read().upper():
failed = True
# TODO: if externally specified batchfile exists, have to parse desired logfile path from it
header_dict = {
"REPLACE_LOGFILE": self.heatmaps_log_path,
"REPLACE_WALLTIME": "10:00:00", #TODO: change to scale with # of heatmaps expected
"REPLACE_MEM": "8GB",
}
heatmaps_sbatch_header = self.make_sbatch_header("HEATMAPS_BATCH_FILE", header_dict)

with open(self.heatmaps_sbatch_header_path, "w+") as f:
f.write(heatmaps_sbatch_header)
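
make_sbatch_header renders either the user-supplied HEATMAPS_BATCH_FILE or a default template, substituting the REPLACE_* tokens above. A sketch of that substitution step, assuming a plain string replace (the real _update_header lives elsewhere in pippin and may differ):

# Hypothetical stand-in for pippin's _update_header.
def render_sbatch_header(template: str, header_dict: dict) -> str:
    # Swap each REPLACE_* token for its configured value.
    for token, value in header_dict.items():
        template = template.replace(token, str(value))
    return template

template = (
    "#SBATCH --time=REPLACE_WALLTIME\n"
    "#SBATCH --mem=REPLACE_MEM\n"
    "#SBATCH --output=REPLACE_LOGFILE\n"
)
header = render_sbatch_header(template, {
    "REPLACE_WALLTIME": "10:00:00",
    "REPLACE_MEM": "8GB",
    "REPLACE_LOGFILE": "/scratch/heatmaps/create_heatmaps.log",  # hypothetical path
})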

def make_model_sbatch_script(self):
header_dict = {
"REPLACE_NAME": self.job_base_name,
"REPLACE_LOGFILE": "output.log",
@@ -171,26 +127,51 @@ def classify(self, mode):
"sbatch_header": model_sbatch_header,
"task_setup": self.update_setup(setup_dict, self.task_setup['scone'])
}
slurm_output_file = Path(self.output_dir) / "job.slurm"
self.logger.info(f"Running SCONE, slurm job outputting to {slurm_output_file}")

self.logger.info(f"Running SCONE model, slurm job written to {self.model_sbatch_job_path}")
slurm_script = self.slurm.format(**format_dict)

new_hash = self.get_hash_from_string(slurm_script)
with open(self.model_sbatch_job_path, "w") as f:
f.write(slurm_script)

return slurm_script
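
The model script itself is the class-level self.slurm template (its tail is visible at the top of this diff) filled in with str.format. A toy rendering under illustrative values:

slurm_template = """{sbatch_header}
{task_setup}"""

script = slurm_template.format(
    sbatch_header="#!/bin/bash\n#SBATCH --job-name=SCONE_TRAIN",  # illustrative
    task_setup="conda activate scone\npython classify.py",        # illustrative
)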

def classify(self, mode):
failed = False
if Path(self.done_file).exists():
self.logger.debug(f"Found done file at {self.done_file}")
with open(self.done_file) as f:
if "SUCCESS" not in f.read().upper():
failed = True

if self._check_regenerate(new_hash) or failed:
self.logger.debug("Regenerating")

with open(slurm_output_file, "w") as f:
f.write(slurm_script)
heatmaps_created = self._heatmap_creation_success() and self.keep_heatmaps
if not heatmaps_created:
self.make_heatmaps_sbatch_header()

sim_dep = self.get_simulation_dependency()
sim_dirs = sim_dep.output["photometry_dirs"][self.index] # if multiple realizations, get only the current one with self.index

lcdata_paths = self._get_lcdata_paths(sim_dirs)
metadata_paths = [path.replace("PHOT", "HEAD") for path in lcdata_paths]

self._write_config_file(metadata_paths, lcdata_paths, mode, self.config_path, heatmaps_created)

slurm_script = self.make_model_sbatch_script()
new_hash = self.get_hash_from_string(slurm_script)

self.save_new_hash(new_hash)
self.logger.info(f"Submitting batch job {slurm_output_file}")
self.logger.info(f"Submitting batch job {self.model_sbatch_job_path}")

# TODO: nersc needs `module load esslurm` to sbatch gpu jobs, maybe make
# this shell command to a file so diff systems can define their own
subprocess.run(f"sbatch {slurm_output_file}", cwd=self.output_dir, shell=True)
subprocess.run([f"python {Path(self.path_to_classifier) / 'run.py'} --config_path {self.config_path}"], shell=True)
else:
self.logger.info("Hash check passed, not rerunning")
self.should_be_done()

return True
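
Regeneration is keyed on a hash of the fully rendered slurm script, so any change to the batch header, paths, or task setup forces a resubmission, while an unchanged config short-circuits to should_be_done(). A sketch of the idea, assuming a SHA-256 digest (pippin's actual get_hash_from_string may differ):

import hashlib

def get_hash_from_string(s: str) -> str:
    # Any edit to the rendered script changes the digest and forces a rerun.
    return hashlib.sha256(s.encode()).hexdigest()

slurm_script = "#!/bin/bash\n#SBATCH --mem=8GB\npython run.py"  # illustrative
new_hash = get_hash_from_string(slurm_script)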

def predict(self):
@@ -204,23 +185,26 @@ def _get_types(self):
t = self.get_simulation_dependency().output
return t["types"]

def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path):
def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path, heatmaps_created):
config = {}

# environment configuration
config["init_env_heatmaps"] = self.init_env_heatmaps
config["init_env"] = self.init_env

# info for heatmap creation
config["metadata_paths"] = metadata_paths
config["lcdata_paths"] = lcdata_paths
config["num_wavelength_bins"] = self.options.get("NUM_WAVELENGTH_BINS", 32)
config["num_mjd_bins"] = self.options.get("NUM_MJD_BINS", 180)
config["heatmaps_path"] = self.heatmaps_path
config["sbatch_header_path"] = self.heatmaps_sbatch_header_path
if not heatmaps_created:
config["sbatch_header_path"] = self.heatmaps_sbatch_header_path

config["heatmaps_donefile"] = self.heatmaps_done_file
config["heatmaps_logfile"] = self.heatmaps_log_path
config["sim_fraction"] = self.options.get("SIM_FRACTION", 1) # 1/sim_fraction % of simulated SNe will be used for the model
config["heatmaps_path"] = self.heatmaps_path
config["model_sbatch_job_path"] = self.model_sbatch_job_path
config["num_wavelength_bins"] = self.options.get("NUM_WAVELENGTH_BINS", 32)
config["num_mjd_bins"] = self.options.get("NUM_MJD_BINS", 180)
config["metadata_paths"] = metadata_paths
config["lcdata_paths"] = lcdata_paths

# info for classification model
config["categorical"] = self.options.get("CATEGORICAL", False)
Expand All @@ -242,29 +226,29 @@ def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path):
cfgfile.write(yaml.dump(config))
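
For reference, the config written above for a training run might be built like this (hypothetical paths and env commands; the keys mirror _write_config_file, and sbatch_header_path appears only when heatmaps still need to be generated):

import yaml

config = {
    "init_env_heatmaps": "source activate scone_cpu",  # assumption
    "init_env": "source activate scone_gpu",           # assumption
    "sbatch_header_path": "/out/heatmaps/sbatch_header.sh",
    "heatmaps_donefile": "/out/heatmaps/done.txt",
    "heatmaps_logfile": "/out/heatmaps/create_heatmaps__model_config.log",
    "sim_fraction": 1,                                 # use every simulated SN
    "heatmaps_path": "/out/heatmaps",
    "model_sbatch_job_path": "/out/job.slurm",
    "num_wavelength_bins": 32,
    "num_mjd_bins": 180,
    "metadata_paths": ["/sim/PIP_SIM_HEAD.FITS.gz"],   # hypothetical
    "lcdata_paths": ["/sim/PIP_SIM_PHOT.FITS.gz"],     # hypothetical
    "categorical": False,
    "mode": "train",
}
print(yaml.dump(config))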

def _check_completion(self, squeue):
if os.path.exists(self.done_file):
if Path(self.done_file).exists():
self.logger.debug(f"Found done file at {self.done_file}")
with open(self.done_file) as f:
if "FAILURE" in f.read().upper():
if "SUCCESS" not in f.read().upper():
return Task.FINISHED_FAILURE

pred_path = os.path.join(self.output_dir, "predictions.csv")
pred_path = str(Path(self.output_dir) / "predictions.csv")
predictions = pd.read_csv(pred_path)
predictions = predictions[["snid", "pred_labels"]] # make sure snid is the first col
predictions = predictions.rename(columns={"pred_labels": self.get_prob_column_name()})
predictions.to_csv(pred_path, index=False)
self.logger.info(f"Predictions file can be found at {pred_path}")
self.output.update({"model_filename": self.options.get("MODEL", os.path.join(self.output_dir, "trained_model")), "predictions_filename": pred_path})
self.output.update({"model_filename": self.options.get("MODEL", str(Path(self.output_dir) / "trained_model")), "predictions_filename": pred_path})
return Task.FINISHED_SUCCESS
return self.check_for_job(squeue, self.job_base_name)

def _heatmap_creation_success(self):
if not os.path.exists(self.heatmaps_done_file):
if not Path(self.heatmaps_done_file).exists():
return False
with open(self.heatmaps_done_file, "r") as donefile:
if "CREATE HEATMAPS FAILURE" in donefile.read():
return False
return os.path.exists(self.heatmaps_path) and os.path.exists(os.path.join(self.heatmaps_path, "done.log"))
return Path(self.heatmaps_path).exists() and (Path(self.heatmaps_path) / "done.log").exists()

def num_jobs_in_queue(self):
print("rerun num jobs in queue")
@@ -274,7 +258,7 @@ def num_jobs_in_queue(self):

@staticmethod
def _get_lcdata_paths(sim_dir):
lcdata_paths = [f.path for f in os.scandir(sim_dir) if "PHOT" in f.path]
lcdata_paths = [str(f) for f in Path(sim_dir).iterdir() if "PHOT" in str(f)]
return lcdata_paths

@staticmethod
