Skip to content

Commit

Permalink
Merge pull request #44 from Samreay/scone
Browse files Browse the repository at this point in the history
fix docstring, write predictions.csv file
  • Loading branch information
helenqu authored Mar 20, 2021
2 parents 244ed41 + e943ff3 commit 9e0fb3f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
37 changes: 28 additions & 9 deletions pippin/classifiers/scone.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess
from pathlib import Path
import yaml
import pandas as pd

from pippin.classifiers.classifier import Classifier
from pippin.config import get_config, get_output_loc, mkdirs
Expand All @@ -22,16 +23,18 @@ class SconeClassifier(Classifier):
MASK_FIT: TEST # partial match on lcfit name
MODE: train/predict
OPTS:
FITOPT: someLabel # Exact match to fitopt in a fitopt file. USED FOR TRAINING ONLY
FEATURES: x1 c zHD # Columns out of fitres file to use as features
MODEL: someName # exact name of training classification task
GPU: True
CATEGORICAL: False
NUM_WAVELENGTH_BINS: 32
NUM_MJD_BINS: 180
REMAKE_HEATMAPS: False
NUM_EPOCHS: 400
IA_FRACTION: 0.5
OUTPUTS:
========
name : name given in the yml
output_dir: top level output directory
prob_column_name: name of the column to get probabilities out of
predictions_filename: location of csv filename with id/probs
predictions.csv: list of snids and associated predictions
training_history.csv: training history output from keras
"""

Expand All @@ -40,6 +43,7 @@ def __init__(self, name, output_dir, config, dependencies, mode, options, index=
self.global_config = get_config()
self.options = options

self.gpu = self.options.get("GPU", True)
self.conda_env = self.global_config["SCONE"]["conda_env_cpu"] if not self.gpu else self.global_config["SCONE"]["conda_env_gpu"]
self.path_to_classifier = self.global_config["SCONE"]["location"]

Expand All @@ -53,8 +57,8 @@ def __init__(self, name, output_dir, config, dependencies, mode, options, index=

self.logfile = os.path.join(self.output_dir, "output.log")

remake_heatmaps = self.options.get("REMAKE_HEATMAPS")
self.keep_heatmaps = not remake_heatmaps if remake_heatmaps is not None else True
remake_heatmaps = self.options.get("REMAKE_HEATMAPS", False)
self.keep_heatmaps = not remake_heatmaps

def classify(self, mode):
if self.gpu:
Expand Down Expand Up @@ -93,6 +97,7 @@ def classify(self, mode):
# check success of intermediate steps and don't redo them if successful
heatmaps_created = self._heatmap_creation_success() and self.keep_heatmaps

failed = False
if os.path.exists(self.done_file):
self.logger.debug(f"Found done file at {self.done_file}")
with open(self.done_file) as f:
Expand Down Expand Up @@ -135,6 +140,8 @@ def train(self):
def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path):
config = {}
config["categorical"] = self.options.get("CATEGORICAL", False)
# TODO: replace num epochs with autostop: stop training when slope plateaus?
# TODO: how to choose optimal batch size?
config["num_epochs"] = self.options.get("NUM_EPOCHS")
config["metadata_paths"] = metadata_paths
config["lcdata_paths"] = lcdata_paths
Expand All @@ -143,6 +150,7 @@ def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path):
config["num_mjd_bins"] = self.options.get("NUM_MJD_BINS", 180)
config["Ia_fraction"] = self.options.get("IA_FRACTION", 0.5)
config["donefile"] = self.done_file
config["output_path"] = self.output_dir
config["mode"] = mode
config["sn_type_id_to_name"] = {42: "SNII",
52: "SNIax",
Expand All @@ -161,10 +169,21 @@ def _check_completion(self, squeue):
with open(self.done_file) as f:
if "FAILURE" in f.read().upper():
return Task.FINISHED_FAILURE

pred_path = os.path.join(self.output_dir, "predictions.csv")
predictions = pd.read_csv(pred_path)
predictions = predictions.rename(columns={"pred": self.get_prob_column_name()})
predictions.to_csv(pred_path, index=False)
self.logger.info(f"Predictions file can be found at {pred_path}")
return Task.FINISHED_SUCCESS
return self.check_for_job(squeue, self.job_base_name)

def _heatmap_creation_success(self):
    """Report whether the heatmap-creation step completed successfully.

    Returns False when the done file does not exist or when it contains
    the marker "CREATE HEATMAPS FAILURE". Otherwise returns True only if
    both the heatmaps directory and a "done.log" file inside it exist.
    """
    # Without a done file there is no record of a previous run to trust.
    if not os.path.exists(self.done_file):
        return False

    with open(self.done_file, "r") as handle:
        contents = handle.read()

    # The done file records a failure of the heatmap step explicitly.
    if "CREATE HEATMAPS FAILURE" in contents:
        return False

    # Success requires the heatmaps directory plus its own completion log.
    if not os.path.exists(self.heatmaps_path):
        return False
    return os.path.exists(os.path.join(self.heatmaps_path, "done.log"))

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions pippin/external/SNANA_FITS_to_pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def read_fits(fname,drop_separators=False):
df_header = df_header.rename(columns={"SNID":"object_id", "SNTYPE": "true_target", "PEAKMJD": "true_peakmjd", "REDSHIFT_FINAL": "true_z", "MWEBV": "mwebv"})
df_header.replace({"true_target":
{120: 42, 20: 42, 121: 42, 21: 42, 122: 42, 22: 42, 130: 62, 30: 62, 131: 62, 31: 62, 101: 90, 1: 90, 102: 52, 2: 52, 104: 64, 4: 64, 103: 95, 3: 95, 191: 67, 91: 67}}, inplace=True)
print(df_header)
df_phot = df_phot.rename(columns={"SNID":"object_id", "MJD": "mjd", "FLT": "passband", "FLUXCAL": "flux", "FLUXCALERR": "flux_err"})
df_phot.replace({"passband":
{b'u ': 0, b'g ': 1, b'r ': 2, b'i ': 3, b'z ': 4, b'Y ': 5}}, inplace=True)
passband_dict = {"passband": {b'u ': 0, b'g ': 1, b'r ': 2, b'i ': 3, b'z ': 4, b'Y ': 5}}
df_phot = df_phot[df_phot.passband.isin(passband_dict["passband"])]
df_phot.replace(passband_dict, inplace=True)

return df_header, df_phot

Expand Down

0 comments on commit 9e0fb3f

Please sign in to comment.