diff --git a/pippin/classifiers/scone.py b/pippin/classifiers/scone.py index e00e46b3..3fe9db2f 100644 --- a/pippin/classifiers/scone.py +++ b/pippin/classifiers/scone.py @@ -3,6 +3,7 @@ import subprocess from pathlib import Path import yaml +import pandas as pd from pippin.classifiers.classifier import Classifier from pippin.config import get_config, get_output_loc, mkdirs @@ -22,16 +23,18 @@ class SconeClassifier(Classifier): MASK_FIT: TEST # partial match on lcfit name MODE: train/predict OPTS: - FITOPT: someLabel # Exact match to fitopt in a fitopt file. USED FOR TRAINING ONLY - FEATURES: x1 c zHD # Columns out of fitres file to use as features - MODEL: someName # exact name of training classification task + GPU: True + CATEGORICAL: False + NUM_WAVELENGTH_BINS: 32 + NUM_MJD_BINS: 180 + REMAKE_HEATMAPS: False + NUM_EPOCHS: 400 + IA_FRACTION: 0.5 OUTPUTS: ======== - name : name given in the yml - output_dir: top level output directory - prob_column_name: name of the column to get probabilities out of - predictions_filename: location of csv filename with id/probs + predictions.csv: list of snids and associated predictions + training_history.csv: training history output from keras """ @@ -40,6 +43,7 @@ def __init__(self, name, output_dir, config, dependencies, mode, options, index= self.global_config = get_config() self.options = options + self.gpu = self.options.get("GPU", True) self.conda_env = self.global_config["SCONE"]["conda_env_cpu"] if not self.gpu else self.global_config["SCONE"]["conda_env_gpu"] self.path_to_classifier = self.global_config["SCONE"]["location"] @@ -53,8 +57,8 @@ def __init__(self, name, output_dir, config, dependencies, mode, options, index= self.logfile = os.path.join(self.output_dir, "output.log") - remake_heatmaps = self.options.get("REMAKE_HEATMAPS") - self.keep_heatmaps = not remake_heatmaps if remake_heatmaps is not None else True + remake_heatmaps = self.options.get("REMAKE_HEATMAPS", False) + self.keep_heatmaps = not remake_heatmaps def classify(self, mode): if self.gpu: @@ -93,6 +97,7 @@ def classify(self, mode): # check success of intermediate steps and don't redo them if successful heatmaps_created = self._heatmap_creation_success() and self.keep_heatmaps + failed = False if os.path.exists(self.done_file): self.logger.debug(f"Found done file at {self.done_file}") with open(self.done_file) as f: @@ -135,6 +140,8 @@ def train(self): def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path): config = {} config["categorical"] = self.options.get("CATEGORICAL", False) + # TODO: replace num epochs with autostop: stop training when slope plateaus? + # TODO: how to choose optimal batch size? config["num_epochs"] = self.options.get("NUM_EPOCHS") config["metadata_paths"] = metadata_paths config["lcdata_paths"] = lcdata_paths @@ -143,6 +150,7 @@ def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path): config["num_mjd_bins"] = self.options.get("NUM_MJD_BINS", 180) config["Ia_fraction"] = self.options.get("IA_FRACTION", 0.5) config["donefile"] = self.done_file + config["output_path"] = self.output_dir config["mode"] = mode config["sn_type_id_to_name"] = {42: "SNII", 52: "SNIax", @@ -161,10 +169,21 @@ def _check_completion(self, squeue): with open(self.done_file) as f: if "FAILURE" in f.read().upper(): return Task.FINISHED_FAILURE + + pred_path = os.path.join(self.output_dir, "predictions.csv") + predictions = pd.read_csv(pred_path) + predictions = predictions.rename(columns={"pred": self.get_prob_column_name()}) + predictions.to_csv(pred_path, index=False) + self.logger.info(f"Predictions file can be found at {pred_path}") return Task.FINISHED_SUCCESS return self.check_for_job(squeue, self.job_base_name) def _heatmap_creation_success(self): + if not os.path.exists(self.done_file): + return False + with open(self.done_file, "r") as donefile: + if "CREATE HEATMAPS FAILURE" in donefile.read(): + return False return os.path.exists(self.heatmaps_path) and os.path.exists(os.path.join(self.heatmaps_path, "done.log")) @staticmethod diff --git a/pippin/external/SNANA_FITS_to_pd.py b/pippin/external/SNANA_FITS_to_pd.py index 3fa24000..f002b726 100644 --- a/pippin/external/SNANA_FITS_to_pd.py +++ b/pippin/external/SNANA_FITS_to_pd.py @@ -54,10 +54,10 @@ def read_fits(fname,drop_separators=False): df_header = df_header.rename(columns={"SNID":"object_id", "SNTYPE": "true_target", "PEAKMJD": "true_peakmjd", "REDSHIFT_FINAL": "true_z", "MWEBV": "mwebv"}) df_header.replace({"true_target": {120: 42, 20: 42, 121: 42, 21: 42, 122: 42, 22: 42, 130: 62, 30: 62, 131: 62, 31: 62, 101: 90, 1: 90, 102: 52, 2: 52, 104: 64, 4: 64, 103: 95, 3: 95, 191: 67, 91: 67}}, inplace=True) - print(df_header) df_phot = df_phot.rename(columns={"SNID":"object_id", "MJD": "mjd", "FLT": "passband", "FLUXCAL": "flux", "FLUXCALERR": "flux_err"}) - df_phot.replace({"passband": - {b'u ': 0, b'g ': 1, b'r ': 2, b'i ': 3, b'z ': 4, b'Y ': 5}}, inplace=True) + passband_dict = {"passband": {b'u ': 0, b'g ': 1, b'r ': 2, b'i ': 3, b'z ': 4, b'Y ': 5}} + df_phot = df_phot[df_phot.passband.isin(passband_dict["passband"])] + df_phot.replace(passband_dict, inplace=True) return df_header, df_phot