add scone classifier #42

Merged: 3 commits, merged Mar 17, 2021
4 changes: 4 additions & 0 deletions cfg.yml
@@ -24,6 +24,10 @@ DataSkimmer:
conda_env: snn_gpu
location: $PRODUCTS/utilities/dataskim

SCONE:
conda_env: scone_cpu
location: $PRODUCTS/classifiers/scone

CosmoMC:
location: $PRODUCTS/CosmoMC/v04/CosmoMC-master
static_loc: cosmomc_static_chains
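
The new SCONE block is read by the classifier below through pippin's global config; a minimal sketch of the lookup, mirroring SconeClassifier.__init__ in this PR:

from pippin.config import get_config

global_config = get_config()
conda_env = global_config["SCONE"]["conda_env"]            # "scone_cpu"
classifier_location = global_config["SCONE"]["location"]   # "$PRODUCTS/classifiers/scone"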
203 changes: 203 additions & 0 deletions pippin/classifiers/scone.py
@@ -0,0 +1,203 @@
import os
import shutil
import subprocess
from pathlib import Path
import yaml

from pippin.classifiers.classifier import Classifier
from pippin.config import get_config, get_output_loc, mkdirs
from pippin.task import Task
from pippin.external.SNANA_FITS_to_pd import read_fits


class SconeClassifier(Classifier):
    """ SCONE classifier: builds heatmaps from simulated light curves, then
    trains or runs a convolutional neural network classifier on them.

    CONFIGURATION:
    ==============
    CLASSIFICATION:
      label:
        MASK: TEST  # partial match on sim and classifier
        MASK_SIM: TEST  # partial match on sim name
        MASK_FIT: TEST  # partial match on lcfit name
        MODE: train/predict
        OPTS:
          REMAKE_HEATMAPS: False  # regenerate heatmaps even if a previous run created them
          CATEGORICAL: False  # multi-class classification instead of binary Ia vs non-Ia
          NUM_EPOCHS: 400  # number of training epochs
          NUM_WAVELENGTH_BINS: 32  # heatmap wavelength resolution (default 32)
          NUM_MJD_BINS: 180  # heatmap MJD resolution (default 180)
          IA_FRACTION: 0.5  # fraction of Ia examples in the heatmap set (default 0.5)

    OUTPUTS:
    ========
    name: name given in the yml
    output_dir: top level output directory
    prob_column_name: name of the column to get probabilities out of
    predictions_filename: location of csv filename with id/probs

    """

    def __init__(self, name, output_dir, config, dependencies, mode, options, index=0, model_name=None):
        super().__init__(name, output_dir, config, dependencies, mode, options, index=index, model_name=model_name)
        self.global_config = get_config()
        self.options = options

        self.conda_env = self.global_config["SCONE"]["conda_env"]
        self.path_to_classifier = self.global_config["SCONE"]["location"]

        # e.g. an output_dir of ".../PIPPIN_JOB/classifiers/SCONE_TEST" yields the
        # job base name "PIPPIN_JOB__SCONE_TEST" (directory names illustrative)
        self.job_base_name = os.path.basename(Path(output_dir).parents[1]) + "__" + os.path.basename(output_dir)

        self.config_path = os.path.join(self.output_dir, "model_config.yml")
        self.heatmaps_path = os.path.join(self.output_dir, "heatmaps")
        self.csvs_path = os.path.join(self.output_dir, "sim_csvs")
        self.slurm = """{sbatch_header}
{task_setup}"""

        self.logfile = os.path.join(self.output_dir, "output.log")

        # keep existing heatmaps unless the user explicitly asks for a remake
        self.keep_heatmaps = not self.options.get("REMAKE_HEATMAPS", False)

    def classify(self, mode):
        if self.gpu:
            self.sbatch_header = self.sbatch_gpu_header
        else:
            self.sbatch_header = self.sbatch_cpu_header

        header_dict = {
            "job-name": self.job_base_name,
            "output": "output.log",
            "time": "00:55:00",  # TODO: scale based on number of heatmaps
            "mem-per-cpu": "8GB",
            "ntasks": "1",
            "cpus-per-task": "4",
        }
        self.update_header(header_dict)

        setup_dict = {
            "conda_env": self.conda_env,
            "path_to_classifier": self.path_to_classifier,
            "heatmaps_path": self.heatmaps_path,
            "config_path": self.config_path,
            "done_file": self.done_file,
        }

        format_dict = {
            "sbatch_header": self.sbatch_header,
            "task_setup": self.update_setup(setup_dict, self.task_setup["scone"]),
        }
        slurm_output_file = self.output_dir + "/job.slurm"
        self.logger.info(f"Running SCONE, slurm job outputting to {slurm_output_file}")
        slurm_script = self.slurm.format(**format_dict)

        new_hash = self.get_hash_from_string(slurm_script)

        # check success of intermediate steps and don't redo them if successful
        heatmaps_created = self._heatmap_creation_success() and self.keep_heatmaps

        failed = False
        if os.path.exists(self.done_file):
            self.logger.debug(f"Found done file at {self.done_file}")
            with open(self.done_file) as f:
                if "FAILURE" in f.read().upper():
                    failed = True

        if self._check_regenerate(new_hash) or failed:
            self.logger.debug("Regenerating")
            if not heatmaps_created:
                shutil.rmtree(self.output_dir, ignore_errors=True)
                mkdirs(self.output_dir)
            else:
                # heatmaps are reusable, so clear only the top-level files
                for f in [f.path for f in os.scandir(self.output_dir) if f.is_file()]:
                    os.remove(f)

            sim_dep = self.get_simulation_dependency()
            sim_dirs = sim_dep.output["photometry_dirs"]
            if not os.path.exists(self.csvs_path):
                os.makedirs(self.csvs_path)
            metadata_paths, lcdata_paths = self._fitres_to_csv(self._get_lcdata_paths(sim_dirs), self.csvs_path)

            self._write_config_file(metadata_paths, lcdata_paths, mode, self.config_path)  # TODO: what if they don't want to train on all sims?

            with open(slurm_output_file, "w") as f:
                f.write(slurm_script)
            self.save_new_hash(new_hash)
            self.logger.info(f"Submitting batch job {slurm_output_file}")
            subprocess.run(["sbatch", slurm_output_file], cwd=self.output_dir)
        else:
            self.logger.info("Hash check passed, not rerunning")
            self.should_be_done()
        return True

    def predict(self):
        return self.classify("predict")

    def train(self):
        return self.classify("train")

    def _write_config_file(self, metadata_paths, lcdata_paths, mode, config_path):
        config = {}
        config["categorical"] = self.options.get("CATEGORICAL", False)
        config["num_epochs"] = self.options.get("NUM_EPOCHS")
        config["metadata_paths"] = metadata_paths
        config["lcdata_paths"] = lcdata_paths
        config["heatmaps_path"] = self.heatmaps_path
        config["num_wavelength_bins"] = self.options.get("NUM_WAVELENGTH_BINS", 32)
        config["num_mjd_bins"] = self.options.get("NUM_MJD_BINS", 180)
        config["Ia_fraction"] = self.options.get("IA_FRACTION", 0.5)
        config["donefile"] = self.done_file
        config["mode"] = mode
        config["sn_type_id_to_name"] = {42: "SNII",
                                        52: "SNIax",
                                        62: "SNIbc",
                                        67: "SNIa-91bg",
                                        64: "KN",
                                        90: "SNIa",
                                        95: "SLSN-1"}

        with open(config_path, "w+") as cfgfile:
            cfgfile.write(yaml.dump(config))
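
    # For reference, a hedged sketch of the YAML the method above emits with
    # default OPTS (all paths below are illustrative, not real defaults):
    #
    #   Ia_fraction: 0.5
    #   categorical: false
    #   donefile: /out/SCONE_TEST/done.txt
    #   heatmaps_path: /out/SCONE_TEST/heatmaps
    #   lcdata_paths: [/out/SCONE_TEST/sim_csvs/SIM_PHOT.csv]
    #   metadata_paths: [/out/SCONE_TEST/sim_csvs/SIM_HEAD.csv]
    #   mode: predict
    #   num_epochs: null   # NUM_EPOCHS has no default
    #   num_mjd_bins: 180
    #   num_wavelength_bins: 32
    #   sn_type_id_to_name: {42: SNII, 52: SNIax, ...}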

    def _check_completion(self, squeue):
        if os.path.exists(self.done_file):
            self.logger.debug(f"Found done file at {self.done_file}")
            with open(self.done_file) as f:
                if "FAILURE" in f.read().upper():
                    return Task.FINISHED_FAILURE
                return Task.FINISHED_SUCCESS
        return self.check_for_job(squeue, self.job_base_name)

    def _heatmap_creation_success(self):
        return os.path.exists(self.heatmaps_path) and os.path.exists(os.path.join(self.heatmaps_path, "done.log"))

    @staticmethod
    def _get_lcdata_paths(sim_dirs):
        # SNANA photometry files are named like *PHOT.FITS.gz
        sim_paths = [f.path for sim_dir in sim_dirs for f in os.scandir(sim_dir)]
        lcdata_paths = [path for path in sim_paths if "PHOT" in path]

        return lcdata_paths

    @staticmethod
    def _fitres_to_csv(lcdata_paths, output_dir):
        csv_metadata_paths = []
        csv_lcdata_paths = []

        for path in lcdata_paths:
            csv_metadata_path = os.path.join(output_dir, os.path.basename(path).replace("PHOT.FITS.gz", "HEAD.csv"))
            csv_lcdata_path = os.path.join(output_dir, os.path.basename(path).replace(".FITS.gz", ".csv"))

            # skip the FITS -> csv conversion if a previous run already did it
            if os.path.exists(csv_metadata_path) and os.path.exists(csv_lcdata_path):
                csv_metadata_paths.append(csv_metadata_path)
                csv_lcdata_paths.append(csv_lcdata_path)
                continue

            metadata, lcdata = read_fits(path)

            metadata.to_csv(csv_metadata_path)
            lcdata.to_csv(csv_lcdata_path)

            csv_metadata_paths.append(csv_metadata_path)
            csv_lcdata_paths.append(csv_lcdata_path)

        return csv_metadata_paths, csv_lcdata_paths

    @staticmethod
    def get_requirements(options):
        # needs simulated photometry but not light-curve fit results
        return True, False
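
Once the batch job finishes, downstream tasks consume the outputs listed in the class docstring; a hedged sketch (the CSV path and probability column name are hypothetical, since both are set by the base Classifier):

import pandas as pd

preds = pd.read_csv("/out/SCONE_TEST/predictions.csv")  # "predictions_filename" output
print(preds["PROB_SCONE_TEST"].describe())              # "prob_column_name" output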
78 changes: 78 additions & 0 deletions pippin/external/SNANA_FITS_to_pd.py
@@ -0,0 +1,78 @@
import numpy as np
import pandas as pd
from pathlib import Path
from astropy.table import Table
import os

"""
SNANA simulation/data format to pandas
"""

def read_fits(fname, drop_separators=False):
    """Load SNANA-formatted data and cast it to pandas DataFrames

    Args:
        fname (str): path + name of the PHOT.FITS file
        drop_separators (bool): whether to drop the -777 separator rows

    Returns:
        (pandas.DataFrame) dataframe from the HEAD.FITS file
        (pandas.DataFrame) dataframe from the PHOT.FITS file (with SNID added)
    """

    # load photometry
    dat = Table.read(fname, format='fits')
    df_phot = dat.to_pandas()
    # failsafe: drop leading/trailing -777.0 separator rows
    if df_phot.MJD.values[-1] == -777.0:
        df_phot = df_phot.drop(df_phot.index[-1])
    if df_phot.MJD.values[0] == -777.0:
        df_phot = df_phot.drop(df_phot.index[0])

    # load header
    header = Table.read(fname.replace("PHOT", "HEAD"), format="fits")
    df_header = header.to_pandas()
    df_header["SNID"] = df_header["SNID"].astype(np.int32)

    # add SNID to phot for skimming
    arr_ID = np.zeros(len(df_phot), dtype=np.int32)
    # New light curves are identified by MJD == -777.0
    arr_idx = np.where(df_phot["MJD"].values == -777.0)[0]
    arr_idx = np.hstack((np.array([0]), arr_idx, np.array([len(df_phot)])))
    # Fill in arr_ID
    for counter in range(1, len(arr_idx)):
        start, end = arr_idx[counter - 1], arr_idx[counter]
        # index starts at zero
        arr_ID[start:end] = df_header.SNID.iloc[counter - 1]
    df_phot["SNID"] = arr_ID

    if drop_separators:
        df_phot = df_phot[df_phot.MJD != -777.000]

    # keep only the columns SCONE needs and rename them to PLAsTiCC conventions
    df_header = df_header[["SNID", "SNTYPE", "PEAKMJD", "REDSHIFT_FINAL", "MWEBV"]]
    df_phot = df_phot[["SNID", "MJD", "FLT", "FLUXCAL", "FLUXCALERR"]]
    df_header = df_header.rename(columns={"SNID": "object_id", "SNTYPE": "true_target", "PEAKMJD": "true_peakmjd", "REDSHIFT_FINAL": "true_z", "MWEBV": "mwebv"})
    # map SNANA SNTYPE codes onto PLAsTiCC class ids
    df_header.replace({"true_target":
                       {120: 42, 20: 42, 121: 42, 21: 42, 122: 42, 22: 42, 130: 62, 30: 62, 131: 62, 31: 62, 101: 90, 1: 90, 102: 52, 2: 52, 104: 64, 4: 64, 103: 95, 3: 95, 191: 67, 91: 67}}, inplace=True)
    df_phot = df_phot.rename(columns={"SNID": "object_id", "MJD": "mjd", "FLT": "passband", "FLUXCAL": "flux", "FLUXCALERR": "flux_err"})
    df_phot.replace({"passband":
                     {b'u ': 0, b'g ': 1, b'r ': 2, b'i ': 3, b'z ': 4, b'Y ': 5}}, inplace=True)

    return df_header, df_phot
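
# Example usage (hedged; the file name is illustrative):
#   df_header, df_phot = read_fits("SIM_TEST_PHOT.FITS.gz", drop_separators=True)
#   df_header columns: object_id, true_target, true_peakmjd, true_z, mwebv
#   df_phot columns:   object_id, mjd, passband, flux, flux_err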

def save_fits(df, fname):
    """Save a dataframe as a FITS table

    Arguments:
        df {pandas.DataFrame} -- data to save
        fname {str} -- output name, must end in .FITS
    """

    keep_cols = df.keys()
    df = df.reset_index()
    df = df[keep_cols]

    outtable = Table.from_pandas(df)
    Path(fname).parent.mkdir(parents=True, exist_ok=True)
    outtable.write(fname, format='fits', overwrite=True)
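
For orientation, a minimal save_fits round-trip sketch (paths illustrative; parent directories are created automatically):

df_header, df_phot = read_fits("SIM_TEST_PHOT.FITS.gz", drop_separators=True)
save_fits(df_phot, "skimmed/SIM_TEST_PHOT.FITS")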
21 changes: 21 additions & 0 deletions pippin/tasks/scone
@@ -0,0 +1,21 @@
source activate {conda_env}
cd {path_to_classifier}
if ! [ -f {heatmaps_path}/done.log ]; then
    echo {heatmaps_path}/done.log
    echo "#################TIMING heatmap creation start: `date`"
    python create_heatmaps.py --config_path {config_path}
    if ! [ $? -eq 0 ]; then
        echo "create heatmaps FAILED"
        echo create_heatmaps FAILURE >> {done_file}
        exit 1
    fi
fi
echo "#################TIMING heatmap creation done now, starting classifier: `date`"
python run_model.py --config_path {config_path}
if [ $? -eq 0 ]; then
    echo classify SUCCESS >> {done_file}
else
    echo classify FAILURE >> {done_file}
fi
echo "#################TIMING classifier finished: `date`"