Skip to content

Commit

Permalink
Merge pull request #37 from Samreay/pa_test
Browse files Browse the repository at this point in the history
Added batch_size, num_layers, and hidden_dim to Supernnova. Removed s…
  • Loading branch information
OmegaLambda1998 authored Feb 16, 2021
2 parents ea36e6b + 80cceac commit addfc10
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
9 changes: 9 additions & 0 deletions pippin/classifiers/supernnova.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __init__(self, name, output_dir, config, dependencies, mode, options, index=
self.cyclic = options.get("CYCLIC", True)
self.seed = options.get("SEED", 0)
self.clean = config.get("CLEAN", True)
self.batch_size = options.get("BATCH_SIZE", 128)
self.num_layers = options.get("NUM_LAYERS", 2)
self.hidden_dim = options.get("HIDDEN_DIM", 32)
self.validate_model()

assert self.norm in [
Expand Down Expand Up @@ -165,6 +168,9 @@ def classify(self, training):
fit = self.get_fit_dependency()
fit_dir = f"" if fit is None else f"--fits_dir {fit['fitres_dirs'][self.index]}"
cyclic = "--cyclic" if self.variant in ["vanilla", "variational"] and self.cyclic else ""
batch_size = f"--batch_size {self.batch_size}"
num_layers = f"--num_layers {self.num_layers}"
hidden_dim = f"--hidden_dim {self.hidden_dim}"
variant = f"--model {self.variant}"
if self.variant == "bayesian":
variant += " --num_inference_samples 20"
Expand Down Expand Up @@ -213,6 +219,9 @@ def classify(self, training):
"cuda": "--use_cuda" if self.gpu else "",
"clean_command": f"rm -rf {self.dump_dir}/processed" if self.clean else "",
"seed": f"--seed {self.seed}" if self.seed else "",
"batch_size": batch_size,
"num_layers": num_layers,
"hidden_dim": hidden_dim
}

format_dict = {
Expand Down
3 changes: 0 additions & 3 deletions pippin/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,9 @@ def update_header(self, header_dict):
replace_dict = {"job-name": "REPLACE_NAME", "output": "REPLACE_LOGFILE", "time": "REPLACE_WALLTIME", "mem-per-cpu": "REPLACE_MEM"} # Allows for replacing just the REPLACE_ key"
for key, value in header_dict.items():
if key in replace_dict.keys():# REPLACE_ replacements
self.logger.debug(f"Replacing {key}: {replace_dict[key]} -> {value}")
self.sbatch_header = self.sbatch_header.replace(replace_dict[key], value)
self.logger.debug(self.sbatch_header)
else: # SBATCH replacements
lines = self.sbatch_header.split('\n')
self.logger.debug(f"SBATCH replacement: {key} -> {value}")
line = f"#SBATCH --{key}={value}"
if f'--{key}=' in self.sbatch_header: # Replace the sbatch key
idx = [i for i in range(len(lines)) if f'--{key}=' in lines[i]][0]
Expand Down
2 changes: 1 addition & 1 deletion pippin/tasks/supernnova
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ if [ $? -ne 0 ]; then
echo FAILURE > {done_file2}
else
echo "#################TIMING Database done now, starting classifier: `date`"
python run.py {cuda} --sntypes '{sntypes}' --done_file {done_file} --batch_size 20 --dump_dir {dump_dir} {cyclic} {variant} {model} {phot} {redshift} {norm} {seed} {command}
python run.py {cuda} --sntypes '{sntypes}' --done_file {done_file} {batch_size} {num_layers} {hidden_dim} --dump_dir {dump_dir} {cyclic} {variant} {model} {phot} {redshift} {norm} {seed} {command}
if [ $? -eq 0 ]; then
{clean_command}
echo SUCCESS > {done_file2}
Expand Down

0 comments on commit addfc10

Please sign in to comment.