diff --git a/README.md b/README.md index 91f0bdc2..69aacb84 100644 --- a/README.md +++ b/README.md @@ -514,6 +514,9 @@ CLASSIFICATION: COMBINE_MASK: [SIM_IA, SIM_CC] # optional mask to combine multiple sim runs into one classification job (e.g. separate CC and Ia sims). NOTE: currently not compatible with SuperNNova/SNIRF OPTS: MODEL: file_or_label # only needed in predict mode, how to find the trained classifier + OPTIONAL_MASK: # mask for optional dependencies. Not all classifiers make use of this + OPTIONAL_MASK_SIM: # mask for optional sim dependencies. Not all classifiers make use of this + OPTIONAL_MASK_FIT: # mask for optional lcfit dependencies. Not all classifiers make use of this WHATREVER_THE: CLASSIFIER_NEEDS ``` @@ -1000,6 +1003,13 @@ however you want.) You'll also notice a very simply `_check_completion` method, and a `get_requirmenets` method. The latter returns a two-tuple of booleans, indicating whether the classifier needs photometry and light curve fitting results respectively. For the NearestNeighbour code, it classifies based only on SALT2 features, so I return `(False, True)`. +You can also define a `get_optional_requirements` method which, like `get_requirements`, returns a two-tuple of booleans, indicating whether the classifier needs photometry and light curve fitting results *for this particular run*. By default, this method returns: +- `True, True` if `OPTIONAL_MASK` set in `OPTS` +- `True, False` if `OPTIONAL_MASK_SIM` set in `OPTS` +- `False, True` if `OPTIONAL_MASK_FIT` set in `OPTS` +- `False, False` otherwise. + +If you define your own method based on classifier specific requirements, then these `OPTIONAL_MASK*` keys can still be set to choose which tasks are optionally included. If these are not set, then the normal `MASK`, `MASK_SIM`, and `MASK_FIT` are used instead. Note that if *no* masks are set then *every* sim or lcfit task will be included. 
Finally, you'll need to add your classifier into the ClassifierFactory in `classifiers/factory.py`, so that I can link a class name in the YAML configuration to your actual class. Yeah yeah, I could use reflection or dynamic module scanning or similar, but I've had issues getting diff --git a/pippin/classifiers/classifier.py b/pippin/classifiers/classifier.py index 9dc88a2e..f4498b40 100644 --- a/pippin/classifiers/classifier.py +++ b/pippin/classifiers/classifier.py @@ -14,12 +14,19 @@ class Classifier(Task): ============== CLASSIFICATION: label: - MASK: TEST # partial match on sim and classifier + MASK: TEST # partial match on sim name and lcfit name MASK_SIM: TEST # partial match on sim name MASK_FIT: TEST # partial match on lcfit name - COMBINE_MASK: TEST1,TEST2 # combining multiple masks (e.g. SIM_Ia,SIM_CC) + COMBINE_MASK: TEST1,TEST2 # combining multiple masks (e.g. SIM_Ia,SIM_CC) - *exact* match on sim name and lcfit name + MODE: train/predict # Some classifiers don't need training and so you can set to predict straight away OPTS: + # Masks for optional dependencies. Since whether an optional dependency is allowed is classifier specific, this is a classifier opt + # If any are defined, they take precedence, otherwise use above masks for optional dependencies too + OPTIONAL_MASK: TEST + OPTIONAL_MASK_SIM: TEST + OPTIONAL_MASK_FIT: TEST + CHANGES_FOR_INDIVIDUAL_CLASSIFIERS OUTPUTS: @@ -68,6 +75,27 @@ def get_requirements(config): """ return True, True + + @staticmethod + def get_optional_requirements(config): + """ Return what data *may* be used by the classifier. 
+ Default behaviour: + if OPTIONAL_MASK != "": + True, True + if OPTIONAL_MASK_SIM != "": + True, False + if OPTIONAL_MASK_FIT != "": + False, True + + + :param config: the input dictionary `OPTS` from the config file + :return: a two tuple - (needs simulation photometry, needs a fitres file) + """ + + opt_sim = ("OPTIONAL_MASK" in config) or ("OPTIONAL_MASK_SIM" in config) + opt_fit = ("OPTIONAL_MASK" in config) or ("OPTIONAL_MASK_FIT" in config) + return opt_sim, opt_fit + def get_fit_dependency(self, output=True): fit_deps = [] for t in self.dependencies: @@ -180,6 +208,55 @@ def get_num_ranseed(sim_tasks, lcfit_tasks): needs_sim, needs_lc = cls.get_requirements(options) + # Load in all optional tasks + opt_sim, opt_lc = cls.get_optional_requirements(options) + opt_deps = [] + if opt_sim or opt_lc: + # Get all optional masks + mask = options.get("OPTIONAL_MASK", "") + mask_sim = options.get("OPTIONAL_MASK_SIM", "") + mask_fit = options.get("OPTIONAL_MASK_FIT", "") + + # If no optional masks are set, use base masks + if not any([mask, mask_sim, mask_fit]): + mask = config.get("MASK", "") + mask_sim = config.get("MASK_SIM", "") + mask_fit = config.get("MASK_FIT", "") + + # Get optional sim tasks + optional_sim_tasks = [] + if opt_sim: + if not any([mask, mask_sim]): + Task.logger.debug(f"No optional sim masks set, all sim tasks included as dependencies") + optional_sim_tasks = sim_tasks + else: + for s in sim_tasks: + if mask_sim and mask_sim in s.name: + optional_sim_tasks.append(s) + elif mask and mask in s.name: + optional_sim_tasks.append(s) + if len(optional_sim_tasks) == 0: + Task.logger.warn(f"Optional SIM dependency but no matching sim tasks for MASK: {mask} or MASK_SIM: {mask_sim}") + else: + Task.logger.debug(f"Found {len(optional_sim_tasks)} optional SIM dependencies") + # Get optional lcfit tasks + optional_lcfit_tasks = [] + if opt_lc: + if not any([mask, mask_fit]): + Task.logger.debug(f"No optional lcfit masks set, all lcfit tasks included as 
dependencies") + optional_lcfit_tasks = lcfit_tasks + else: + for l in lcfit_tasks: + if mask_fit and mask_fit in l.name: + optional_lcfit_tasks.append(l) + elif mask and mask in l.name: + optional_lcfit_tasks.append(l) + if len(optional_lcfit_tasks) == 0: + Task.logger.warn(f"Optional LCFIT dependency but no matching lcfit tasks for MASK: {mask} or MASK_FIT: {mask_fit}") + else: + Task.logger.debug(f"Found {len(optional_lcfit_tasks)} optional LCFIT dependencies") + opt_deps = optional_sim_tasks + optional_lcfit_tasks + runs = [] if "COMBINE_MASK" in config: combined_tasks = [] @@ -253,7 +330,7 @@ def get_num_ranseed(sim_tasks, lcfit_tasks): len(folders) == 1 ), f"Training requires one version of the lcfits, you have {len(folders)} for lcfit task {l}. Make sure your training sim doesn't set RANSEED_CHANGE" - deps = sim_deps + fit_deps + deps = sim_deps + fit_deps + opt_deps sim_name = "_".join([s.name for s in sim_deps if s is not None]) if len(sim_deps) > 0 else None fit_name = "_".join([l.name for l in fit_deps if l is not None]) if len(fit_deps) > 0 else None diff --git a/tests/config_files/valid_classify_sim_with_lcfit.yml b/tests/config_files/valid_classify_sim_with_lcfit.yml new file mode 100644 index 00000000..97d1065c --- /dev/null +++ b/tests/config_files/valid_classify_sim_with_lcfit.yml @@ -0,0 +1,25 @@ +SIM: + EXAMPLESIM: + IA_G10_DES3YR: + BASE: surveys/sdss/sims_ia/sn_ia_g10_sdss_3yr.input + II: + BASE: surveys/sdss/sims_cc/sn_ii_templates.input + Ibc: + BASE: surveys/sdss/sims_cc/sn_ibc_templates.input + GLOBAL: + NGEN_UNIT: 1 + RANSEED_REPEAT: 10 12345 + SOLID_ANGLE: 10 + +LCFIT: + D: + BASE: surveys/des/lcfit_nml/des_5yr.nml + +CLASSIFICATION: + PERFECT: + CLASSIFIER: PerfectClassifier + MODE: predict + OPTS: + OPTIONAL_MASK_FIT: "D" + PROB_IA: 1.0 + PROB_CC: 0.0 diff --git a/tests/test_valid_config.py b/tests/test_valid_config.py index cc9e2c5b..88033ab1 100644 --- a/tests/test_valid_config.py +++ b/tests/test_valid_config.py @@ -133,6 +133,22 
@@ def test_classifier_lcfit_config_valid(): assert task.output["prob_column_name"] == "PROB_FITPROBTEST" assert len(task.dependencies) == 2 +def test_classifier_sim_with_opt_lcfit_config_valid(): + manager = get_manager(yaml="tests/config_files/valid_classify_sim_with_lcfit.yml", check=True) + tasks = manager.tasks + + assert len(tasks) == 3 + assert isinstance(tasks[0], SNANASimulation) + assert isinstance(tasks[1], SNANALightCurveFit) + assert isinstance(tasks[2], PerfectClassifier) + + task = tasks[-1] + assert task.name == "PERFECT" + assert task.output["prob_column_name"] == "PROB_PERFECT" + deps = task.dependencies + assert len(deps) == 2 + assert isinstance(deps[0], SNANASimulation) + assert isinstance(deps[1], SNANALightCurveFit) def test_agg_config_valid():