Skip to content

Commit

Permalink
Allow optional dependencies in Classification tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
OmegaLambda1998 committed May 2, 2024
1 parent 55862c6 commit 4918476
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 3 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ CLASSIFICATION:
COMBINE_MASK: [SIM_IA, SIM_CC] # optional mask to combine multiple sim runs into one classification job (e.g. separate CC and Ia sims). NOTE: currently not compatible with SuperNNova/SNIRF
OPTS:
MODEL: file_or_label # only needed in predict mode, how to find the trained classifier
OPTIONAL_MASK: # mask for optional dependencies. Not all classifiers make use of this
OPTIONAL_MASK_SIM: # mask for optional sim dependencies. Not all classifiers make use of this
OPTIONAL_MASK_FIT: # mask for optional lcfit dependencies. Not all classifiers make use of this
WHATEVER_THE: CLASSIFIER_NEEDS
```

Expand Down Expand Up @@ -1000,6 +1003,13 @@ however you want.)
You'll also notice a very simple `_check_completion` method, and a `get_requirements` method. The latter returns a two-tuple of booleans, indicating
whether the classifier needs photometry and light curve fitting results respectively. For the NearestNeighbour code, it classifies based
only on SALT2 features, so I return `(False, True)`.
You can also define a `get_optional_requirements` method which, like `get_requirements`, returns a two-tuple of booleans, indicating whether the classifier needs photometry and light curve fitting results *for this particular run*. By default, this method returns:
- `True, True` if `OPTIONAL_MASK` set in `OPTS`
- `True, False` if `OPTIONAL_MASK_SIM` set in `OPTS`
- `False, True` if `OPTIONAL_MASK_FIT` set in `OPTS`
- `False, False` otherwise.

If you define your own method based on classifier-specific requirements, then these `OPTIONAL_MASK*` keys can still be set to choose which tasks are optionally included. If they are not set, then the normal `MASK`, `MASK_SIM`, and `MASK_FIT` are used instead. Note that if *no* masks are set then *every* sim or lcfit task will be included.

Finally, you'll need to add your classifier into the ClassifierFactory in `classifiers/factory.py`, so that I can link a class name
in the YAML configuration to your actual class. Yeah yeah, I could use reflection or dynamic module scanning or similar, but I've had issues getting
Expand Down
83 changes: 80 additions & 3 deletions pippin/classifiers/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@ class Classifier(Task):
==============
CLASSIFICATION:
label:
MASK: TEST # partial match on sim and classifier
MASK: TEST # partial match on sim name and lcfit name
MASK_SIM: TEST # partial match on sim name
MASK_FIT: TEST # partial match on lcfit name
COMBINE_MASK: TEST1,TEST2 # combining multiple masks (e.g. SIM_Ia,SIM_CC)
COMBINE_MASK: TEST1,TEST2 # combining multiple masks (e.g. SIM_Ia,SIM_CC) - *exact* match on sim name and lcfit name
MODE: train/predict # Some classifiers dont need training and so you can set to predict straight away
OPTS:
# Masks for optional dependencies. Since whether an optional dependency is allowed is classifier specific, this is a classifier opt
# If any are defined, they take precedence, otherwise use above masks for optional dependencies too
OPTIONAL_MASK: TEST
OPTIONAL_MASK_SIM: TEST
OPTIONAL_MASK_FIT: TEST
CHANGES_FOR_INDIVIDUAL_CLASSIFIERS
OUTPUTS:
Expand Down Expand Up @@ -68,6 +75,27 @@ def get_requirements(config):
"""
return True, True


@staticmethod
def get_optional_requirements(config):
    """ Return what data *may* optionally be used by the classifier.

    Default behaviour depends only on which keys are present in ``config``.
    The mask values themselves are not inspected here, so a key with an
    empty value still counts as present:

        OPTIONAL_MASK present      -> True, True
        OPTIONAL_MASK_SIM present  -> True, False
        OPTIONAL_MASK_FIT present  -> False, True
        none of them present       -> False, False

    If more than one key is present the individual results are OR-ed.

    :param config: the input dictionary `OPTS` from the config file. ``None`` is treated as an empty dictionary.
    :return: a two tuple - (may use simulation photometry, may use a fitres file)
    """
    # NOTE(review): presumably a blank `OPTS:` entry arrives here as None;
    # a membership test on None would raise TypeError, so fall back to {}.
    if config is None:
        config = {}

    has_general_mask = "OPTIONAL_MASK" in config
    opt_sim = has_general_mask or ("OPTIONAL_MASK_SIM" in config)
    opt_fit = has_general_mask or ("OPTIONAL_MASK_FIT" in config)
    return opt_sim, opt_fit

def get_fit_dependency(self, output=True):
fit_deps = []
for t in self.dependencies:
Expand Down Expand Up @@ -180,6 +208,55 @@ def get_num_ranseed(sim_tasks, lcfit_tasks):

needs_sim, needs_lc = cls.get_requirements(options)

# Load in all optional tasks
opt_sim, opt_lc = cls.get_optional_requirements(options)
opt_deps = []
if opt_sim or opt_lc:
# Get all optional masks
mask = options.get("OPTIONAL_MASK", "")
mask_sim = options.get("OPTIONAL_MASK_SIM", "")
mask_fit = options.get("OPTIONAL_MASK_FIT", "")

# If no optional masks are set, use base masks
if not any([mask, mask_sim, mask_fit]):
mask = config.get("MASK", "")
mask_sim = config.get("MASK_SIM", "")
mask_fit = config.get("MASK_FIT", "")

# Get optional sim tasks
optional_sim_tasks = []
if opt_sim:
if not any([mask, mask_sim]):
Task.logger.debug(f"No optional sim masks set, all sim tasks included as dependendencies")
optional_sim_tasks = sim_tasks
else:
for s in sim_tasks:
if mask_sim and mask_sim in s.name:
optional_sim_tasks.append(s)
elif mask and mask in s.name:
optional_sim_tasks.append(s)
if len(optional_sim_tasks) == 0:
Task.logger.warn(f"Optional SIM dependency but no matching sim tasks for MASK: {mask} or MASK_SIM: {mask_sim}")
else:
Task.logger.debug(f"Found {len(optional_sim_tasks)} optional SIM dependencies")
# Get optional lcfit tasks
optional_lcfit_tasks = []
if opt_lc:
if not any([mask, mask_fit]):
Task.logger.debug(f"No optional lcfit masks set, all lcfit tasks included as dependendencies")
optional_lcfit_tasks = lcfit_tasks
else:
for l in lcfit_tasks:
if mask_fit and mask_fit in l.name:
optional_lcfit_tasks.append(l)
elif mask and mask in l.name:
optional_lcfit_tasks.append(l)
if len(optional_lcfit_tasks) == 0:
Task.logger.warn(f"Optional LCFIT dependency but no matching lcfit tasks for MASK: {mask} or MASK_FIT: {mask_fit}")
else:
Task.logger.debug(f"Found {len(optional_lcfit_tasks)} optional LCFIT dependencies")
opt_deps = optional_sim_tasks + optional_lcfit_tasks

runs = []
if "COMBINE_MASK" in config:
combined_tasks = []
Expand Down Expand Up @@ -253,7 +330,7 @@ def get_num_ranseed(sim_tasks, lcfit_tasks):
len(folders) == 1
), f"Training requires one version of the lcfits, you have {len(folders)} for lcfit task {l}. Make sure your training sim doesn't set RANSEED_CHANGE"

deps = sim_deps + fit_deps
deps = sim_deps + fit_deps + opt_deps

sim_name = "_".join([s.name for s in sim_deps if s is not None]) if len(sim_deps) > 0 else None
fit_name = "_".join([l.name for l in fit_deps if l is not None]) if len(fit_deps) > 0 else None
Expand Down
25 changes: 25 additions & 0 deletions tests/config_files/valid_classify_sim_with_lcfit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
SIM:
EXAMPLESIM:
IA_G10_DES3YR:
BASE: surveys/sdss/sims_ia/sn_ia_g10_sdss_3yr.input
II:
BASE: surveys/sdss/sims_cc/sn_ii_templates.input
Ibc:
BASE: surveys/sdss/sims_cc/sn_ibc_templates.input
GLOBAL:
NGEN_UNIT: 1
RANSEED_REPEAT: 10 12345
SOLID_ANGLE: 10

LCFIT:
D:
BASE: surveys/des/lcfit_nml/des_5yr.nml

CLASSIFICATION:
PERFECT:
CLASSIFIER: PerfectClassifier
MODE: predict
OPTS:
OPTIONAL_MASK_FIT: "D"
PROB_IA: 1.0
PROB_CC: 0.0
16 changes: 16 additions & 0 deletions tests/test_valid_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,22 @@ def test_classifier_lcfit_config_valid():
assert task.output["prob_column_name"] == "PROB_FITPROBTEST"
assert len(task.dependencies) == 2

def test_classifier_sim_with_opt_lcfit_config_valid():
    """Check the config whose classifier sets OPTIONAL_MASK_FIT under OPTS.

    The classifier task is expected to end up with two dependencies: the
    simulation task followed by the light curve fit task.
    """
    manager = get_manager(yaml="tests/config_files/valid_classify_sim_with_lcfit.yml", check=True)
    all_tasks = manager.tasks

    # Exactly one sim, one lcfit and one classifier task, in that order
    assert len(all_tasks) == 3
    sim_task, lcfit_task, classifier_task = all_tasks[0], all_tasks[1], all_tasks[2]
    assert isinstance(sim_task, SNANASimulation)
    assert isinstance(lcfit_task, SNANALightCurveFit)
    assert isinstance(classifier_task, PerfectClassifier)

    # The classifier is also the last task in the list
    last_task = all_tasks[-1]
    assert last_task.name == "PERFECT"
    assert last_task.output["prob_column_name"] == "PROB_PERFECT"

    # Sim dependency first, then the optional lcfit dependency
    dependencies = last_task.dependencies
    assert len(dependencies) == 2
    first_dep, second_dep = dependencies[0], dependencies[1]
    assert isinstance(first_dep, SNANASimulation)
    assert isinstance(second_dep, SNANALightCurveFit)

def test_agg_config_valid():

Expand Down

0 comments on commit 4918476

Please sign in to comment.