Skip to content

Commit

Permalink
Merge pull request #153 from dessn/tmp
Browse files Browse the repository at this point in the history
SCONE: make slurm memory a user-facing option
  • Loading branch information
helenqu authored Jan 24, 2024
2 parents 7d06f45 + d36170d commit 2a9094a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion pippin/classifiers/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def get_num_ranseed(sim_tasks, lcfit_tasks):
name = config["CLASSIFIER"]
cls = ClassifierFactory.get(name)
options = config.get("OPTS", {})
if options == None:
Task.fail_config(f"Classifier {clas_name} has no OPTS specified -- either remove the OPTS keyword or specify some options under it")
if "MODE" not in config:
Task.fail_config(f"Classifier task {clas_name} needs to specify MODE as train or predict")
mode = config["MODE"].lower()
Expand All @@ -169,7 +171,7 @@ def get_num_ranseed(sim_tasks, lcfit_tasks):
mode = Classifier.PREDICT

# Prevent mode = predict and SIM_FRACTION < 1
if mode == Classifier.PREDICT and "SIM_FRACTION" in options and options["SIM_FRACTION"] < 1:
if mode == Classifier.PREDICT and options.get("SIM_FRACTION", 1) > 1:
Task.fail_config("SIM_FRACTION must be 1 (all sims included) for predict mode")

# Validate that train is not used on certain classifiers
Expand Down
6 changes: 3 additions & 3 deletions pippin/classifiers/scone.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def make_heatmaps_sbatch_header(self):
header_dict = {
"REPLACE_LOGFILE": self.heatmaps_log_path,
"REPLACE_WALLTIME": "12:00:00", #TODO: change to scale with # of heatmaps expected
"REPLACE_MEM": "64GB",
"REPLACE_MEM": self.options.get("HEATMAPS_MEM", "32GB"),
}
heatmaps_sbatch_header = self.make_sbatch_header("HEATMAPS_BATCH_FILE", header_dict)

Expand All @@ -111,7 +111,7 @@ def make_model_sbatch_script(self):
header_dict = {
"REPLACE_NAME": self.job_base_name,
"REPLACE_LOGFILE": str(Path(self.output_dir) / "output.log"),
"REPLACE_MEM": "32GB",
"REPLACE_MEM": self.options.get("MODEL_MEM", "64GB"),
"REPLACE_WALLTIME": "4:00:00" if self.gpu else "12:00:00", # 4h is max for gpu
}
model_sbatch_header = self.make_sbatch_header("MODEL_BATCH_FILE", header_dict, use_gpu=self.gpu)
Expand Down Expand Up @@ -221,7 +221,7 @@ def _make_config(self, metadata_paths, lcdata_paths, mode, heatmaps_created):
config["batch_size"] = self.options.get("BATCH_SIZE", 32) # TODO: replace with percentage of total size?
config["Ia_fraction"] = self.options.get("IA_FRACTION", 0.5)
config["output_path"] = self.output_dir
config["trained_model"] = self.options.get("MODEL", False)
config["trained_model"] = self.options.get("MODEL", None)
config["kcor_file"] = self.options.get("KCOR_FILE", None)
config["mode"] = mode
config["job_base_name"] = self.job_base_name
Expand Down

0 comments on commit 2a9094a

Please sign in to comment.