From 223650dfe8438d26346f601806cc64a2eaeaefcb Mon Sep 17 00:00:00 2001 From: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> Date: Fri, 15 Oct 2021 15:30:52 -0400 Subject: [PATCH] Simplify CLI with hydra instantiate (#230) * update poetry * more error checking in draw_blends for MR * bug in this line when using basic measure function and MR * catch error related to meas_band_num in metrics * use hydra instantiate and call rather than available_* * update documentation * test MR on main, remove error checking now done by hydra * Revert "update poetry" This reverts commit aad74eac86c4bed2811e773773038cd423935531. --- btk/draw_blends.py | 20 +++++++-------- btk/main.py | 29 +++------------------ btk/measure.py | 1 - btk/metrics.py | 24 ++++++++++++----- btk/sampling_functions.py | 8 ------ conf/catalog/catsim.yaml | 2 ++ conf/catalog/cosmos.yaml | 4 +++ conf/config.yaml | 5 +--- conf/draw_blends/catsim.yaml | 24 +++++++++-------- conf/draw_blends/cosmos.yaml | 24 +++++++++-------- conf/draw_blends/galsim_hub.yaml | 30 +++++++++++----------- conf/hydra/help/btk_help.yaml | 14 +++++----- conf/sampling/default.yaml | 9 +++---- conf/sampling/galsim_hub.yaml | 9 +++---- conf/sampling/group_sampling.yaml | 13 +++++----- conf/sampling/group_sampling_numbered.yaml | 13 +++++----- tests/test_main.py | 16 +++--------- 17 files changed, 109 insertions(+), 136 deletions(-) create mode 100644 conf/catalog/catsim.yaml create mode 100644 conf/catalog/cosmos.yaml diff --git a/btk/draw_blends.py b/btk/draw_blends.py index ee0b32e88..d90e76ef2 100644 --- a/btk/draw_blends.py +++ b/btk/draw_blends.py @@ -3,6 +3,7 @@ import os from abc import ABC from abc import abstractmethod +from collections.abc import Iterable from itertools import chain import galsim @@ -169,14 +170,20 @@ def __init__( if isinstance(surveys, Survey): self.surveys = [surveys] self.check_compatibility(surveys) - elif isinstance(surveys, list): + elif isinstance(surveys, Iterable): for s in surveys: if not isinstance(s, Survey): - raise TypeError("surveys must be a Survey object or a list of Survey objects.") + raise TypeError( + f"surveys must be a Survey object or an Iterable of Survey objects, but" + f"Iterable contained object of type {type(s)}" + ) self.check_compatibility(s) self.surveys = surveys else: - raise TypeError("surveys must be a Survey object or a list of Survey objects.") + raise TypeError( + f"surveys must be a Survey object or an Iterable of Survey objects," + f"but surveys is type {type(surveys)}" + ) self.is_multiresolution = len(self.surveys) > 1 self.stamp_size = stamp_size @@ -643,10 +650,3 @@ def render_single(self, entry, filt, psf, survey, extra_data): nx=pix_stamp_size, ny=pix_stamp_size, scale=survey.pixel_scale, dtype=np.float64 ) return galaxy_image - - -available_draw_blends = { - "CatsimGenerator": CatsimGenerator, - "CosmosGenerator": CosmosGenerator, - "GalsimHubGenerator": GalsimHubGenerator, -} diff --git a/btk/main.py b/btk/main.py index 9d60e2d14..337453196 100644 --- a/btk/main.py +++ b/btk/main.py @@ -1,44 +1,21 @@ """Implements main function to run BTK end-to-end.""" from collections.abc import Iterable +from hydra.utils import instantiate from omegaconf import OmegaConf -from btk.catalog import available_catalogs -from btk.draw_blends import available_draw_blends from btk.measure import available_measure_functions from btk.measure import MeasureGenerator from btk.metrics import MetricsGenerator -from btk.sampling_functions import available_sampling_functions from btk.survey import _get_survey_from_cfg def main(cfg: OmegaConf): """Run BTK from end-to-end using a hydra configuration object.""" - # get catalog - if cfg.catalog.name not in available_catalogs: - raise ValueError(f"Catalog '{cfg.catalog.name}' is not implemented in BTK.") - catalog = available_catalogs[cfg.catalog.name].from_file(cfg.catalog.catalog_files) - - # get sampling function - if cfg.sampling.name not in available_sampling_functions: - raise ValueError(f"Sampling function '{cfg.sampling.name}' is not implemented in BTK.") - sampling_function = available_sampling_functions[cfg.sampling.name](**cfg.sampling.kwargs) - - # get survey(s) to be used. - if not isinstance(cfg.surveys, Iterable): - cfg.surveys = [cfg.surveys] - - surveys = [] - for survey_name in cfg.surveys: - survey = _get_survey_from_cfg(cfg.surveys[survey_name]) - surveys.append(survey) + surveys = [_get_survey_from_cfg(cfg.surveys[survey_name]) for survey_name in cfg.surveys] # get draw blends generator. - if cfg.draw_blends.name not in available_draw_blends: - raise ValueError("DrawBlendGenerator specified is not implemented in BTK.") - draw_blend_generator = available_draw_blends[cfg.draw_blends.name]( - catalog, sampling_function, surveys, **cfg.draw_blends.kwargs - ) + draw_blend_generator = instantiate(cfg.draw_blends, surveys=surveys) # get measure_functions. measure_functions = [] diff --git a/btk/measure.py b/btk/measure.py index 029129113..936cb3eaf 100644 --- a/btk/measure.py +++ b/btk/measure.py @@ -111,7 +111,6 @@ def basic_measure( if is_multiresolution: if surveys is None: raise ValueError("surveys are required in order to use the MR feature.") - surveys = kwargs.get("surveys", None) survey_name = surveys[0].name coadd = np.mean(batch["blend_images"][survey_name][idx], axis=channel_indx) wcs = batch["wcs"][survey_name] diff --git a/btk/metrics.py b/btk/metrics.py index b6316a26b..8872af427 100644 --- a/btk/metrics.py +++ b/btk/metrics.py @@ -53,6 +53,7 @@ """ import os +from collections.abc import Iterable import astropy.table import galsim @@ -730,20 +731,31 @@ def __next__(self): """Returns metric results calculated on one batch.""" blend_results, measure_results = next(self.measure_generator) surveys = self.measure_generator.draw_blend_generator.surveys + meas_band_num = self.meas_band_num metrics_results = {} for meas_func in measure_results["catalog"].keys(): if self.is_multiresolution: + if not isinstance(meas_band_num, Iterable) and hasattr(meas_band_num, "__len__"): + raise ValueError( + f"meas_band_num is required to be an finite length iterable," + f"instead type was {type(meas_band_num)}" + ) + if not len(meas_band_num) == len(surveys): + raise ValueError( + f"meas_band_num should be an iterable of exactly the same len as the" + f"number of surveys: {len(surveys)}, instead len is {len(meas_band_num)}" + ) metrics_results_f = {} for i, surv in enumerate(blend_results["isolated_images"].keys()): additional_params = { "psf": blend_results["psf"][surv], "pixel_scale": surveys[i].pixel_scale, - "meas_band_num": self.meas_band_num[i], + "meas_band_num": meas_band_num[i], "verbose": self.verbose, } noise_threshold = self.noise_threshold_factor * np.sqrt( - get_mean_sky_level(surveys[i], surveys[i].filters[self.meas_band_num[i]]) + get_mean_sky_level(surveys[i], surveys[i].filters[meas_band_num[i]]) ) target_meas = {} for k in self.target_meas.keys(): @@ -756,7 +768,7 @@ def __next__(self): measure_results["segmentation"][meas_func][surv], measure_results["deblended_images"][meas_func][surv], noise_threshold, - self.meas_band_num[i], + meas_band_num[i], target_meas, channels_last=self.measure_generator.channels_last, save_path=os.path.join(self.save_path, meas_func, surv) @@ -771,11 +783,11 @@ def __next__(self): additional_params = { "psf": blend_results["psf"], "pixel_scale": surveys[0].pixel_scale, - "meas_band_num": self.meas_band_num, + "meas_band_num": meas_band_num, "verbose": self.verbose, } noise_threshold = self.noise_threshold_factor * np.sqrt( - get_mean_sky_level(surveys[0], surveys[0].filters[self.meas_band_num]) + get_mean_sky_level(surveys[0], surveys[0].filters[meas_band_num]) ) target_meas = {} for k in self.target_meas.keys(): @@ -788,7 +800,7 @@ def __next__(self): measure_results["segmentation"][meas_func], measure_results["deblended_images"][meas_func], noise_threshold, - self.meas_band_num, + meas_band_num, target_meas, channels_last=self.measure_generator.channels_last, save_path=os.path.join(self.save_path, meas_func, surveys[0].name) diff --git a/btk/sampling_functions.py b/btk/sampling_functions.py index e49252b5c..6ad75e019 100644 --- a/btk/sampling_functions.py +++ b/btk/sampling_functions.py @@ -437,11 +437,3 @@ def __call__(self, table, **kwargs): ) assert len(no_boundary) <= self.max_number, message return no_boundary - - -available_sampling_functions = { - "DefaultSampling": DefaultSampling, - "DefaultSamplingGalsimHub": DefaultSamplingGalsimHub, - "GroupSampling": GroupSampling, - "GroupSamplingNumbered": GroupSamplingNumbered, -} diff --git a/conf/catalog/catsim.yaml b/conf/catalog/catsim.yaml new file mode 100644 index 000000000..f7b25eefe --- /dev/null +++ b/conf/catalog/catsim.yaml @@ -0,0 +1,2 @@ +_target_: btk.catalog.CatsimCatalog.from_file +catalog_files: ${paths.data}/sample_input_catalog.fits diff --git a/conf/catalog/cosmos.yaml b/conf/catalog/cosmos.yaml new file mode 100644 index 000000000..3d906d47f --- /dev/null +++ b/conf/catalog/cosmos.yaml @@ -0,0 +1,4 @@ +name: btk.catalog.CosmosCatalog.from_file +catalog_files: + - ${paths.data}/cosmos/real_galaxy_catalog_23.5_example.fits + - ${paths.data}/cosmos/real_galaxy_catalog_23.5_example_fits.fits diff --git a/conf/config.yaml b/conf/config.yaml index 0c31366a4..c1c5d4f7e 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -1,15 +1,12 @@ defaults: - _self_ + - catalog: catsim - sampling: default - draw_blends: catsim - surveys: - Rubin - override hydra/help: btk_help -catalog: - name: catsim - catalog_files: ${paths.data}/sample_input_catalog.fits - ### all arguments ### # system diff --git a/conf/draw_blends/catsim.yaml b/conf/draw_blends/catsim.yaml index 3ea676887..0b8636778 100644 --- a/conf/draw_blends/catsim.yaml +++ b/conf/draw_blends/catsim.yaml @@ -1,11 +1,13 @@ -name: "CatsimGenerator" -kwargs: - batch_size: ${batch_size} - stamp_size: ${stamp_size} - cpus: ${cpus} - verbose: ${verbose} - add_noise: ${add_noise} - shifts: Null - indexes: Null - channels_last: ${channels_last} - save_path: ${save_path} +_target_: btk.draw_blends.CatsimGenerator +catalog: ${catalog} +sampling_function: ${sampling} +surveys: Null +batch_size: ${batch_size} +stamp_size: ${stamp_size} +cpus: ${cpus} +verbose: ${verbose} +add_noise: ${add_noise} +shifts: Null +indexes: Null +channels_last: ${channels_last} +save_path: ${save_path} diff --git a/conf/draw_blends/cosmos.yaml b/conf/draw_blends/cosmos.yaml index a02645574..1cf026a2d 100644 --- a/conf/draw_blends/cosmos.yaml +++ b/conf/draw_blends/cosmos.yaml @@ -1,11 +1,13 @@ -name: "CosmosGenerator" -kwargs: - batch_size: ${batch_size} - stamp_size: ${stamp_size} - cpus: ${cpus} - verbose: ${verbose} - add_noise: ${add_noise} - shifts: Null - indexes: Null - channels_last: ${channels_last} - save_path: ${save_path} +_target_: btk.draw_blends.CosmosGenerator +catalog: ${catalog} +sampling_function: ${sampling} +surveys: Null +batch_size: ${batch_size} +stamp_size: ${stamp_size} +cpus: ${cpus} +verbose: ${verbose} +add_noise: ${add_noise} +shifts: Null +indexes: Null +channels_last: ${channels_last} +save_path: ${save_path} diff --git a/conf/draw_blends/galsim_hub.yaml b/conf/draw_blends/galsim_hub.yaml index 1d663f9bd..51fa1fa10 100644 --- a/conf/draw_blends/galsim_hub.yaml +++ b/conf/draw_blends/galsim_hub.yaml @@ -1,15 +1,15 @@ -name: "GalsimHubGenerator" -kwargs: - catalog: ${catalog} - surveys: Null - batch_size: ${batch_size} - stamp_size: ${stamp_size} - cpus: ${cpus} - verbose: ${verbose} - add_noise: ${add_noise} - shifts: Null - indexes: Null - channels_last: ${channels_last} - galsim_hub_model: ${galsim_hub_model} - param_names: ${param_names} - save_path: ${save_path} +_target_: btk.draw_blends.GalsimHubGenerator +catalog: ${catalog} +sampling: ${sampling} +surveys: Null +batch_size: ${batch_size} +stamp_size: ${stamp_size} +cpus: ${cpus} +verbose: ${verbose} +add_noise: ${add_noise} +shifts: Null +indexes: Null +channels_last: ${channels_last} +galsim_hub_model: ${galsim_hub_model} +param_names: ${param_names} +save_path: ${save_path} diff --git a/conf/hydra/help/btk_help.yaml b/conf/hydra/help/btk_help.yaml index 1e7dc22e2..383eee48b 100644 --- a/conf/hydra/help/btk_help.yaml +++ b/conf/hydra/help/btk_help.yaml @@ -27,7 +27,7 @@ template: |- Assuming that BTK has been pip installed, you can run btk from the command line like e.g. btk sampling=default draw_blends=catsim max_number=3 save_path=/directory/to/save/results cpus=1 - verbose=False surveys=[Rubin, HST] surveys.Rubin.airmass=1.1 + verbose=False surveys=[Rubin,HST] surveys.Rubin.airmass=1.1 sampling=default catalog.name=catsim use_metrics=['detection', 'segmentation'] (...) You need to create the directory to save results yourself (preferably an empty directory) and specify its absolute path when you run the CLI via the `save_path` parameter. @@ -39,9 +39,9 @@ template: |- {default, galsim_hub, group_sampling, group_sampling_numbered} * catalog: Attribute group consisting of two sub-attributes. - * catalog.name: Name of the BTK catalog class, options: {catsim, cosmos} + * catalog: Name of the BTK catalog class, options: {catsim, cosmos} * catalog.catalog_files: Path to files containing catalog information. The 'catsim' - catalog requires one path, while the `cosmos` type requires two paths specified as + catalog requires one path, while the 'cosmos' type requires two paths specified as a list. (see documentation for more details). * surveys: Name of the survey(s) you want to use, options are @@ -49,7 +49,7 @@ template: |- in conf/surveys. You can pass in a list of surveys for multi-resolution studies too. For example: - btk surveys=[Rubin, HST] (...) + btk surveys=[Rubin,HST] (...) Assuming that you want to use e.g. the Rubin survey default parameters but with a couple of changes, you can modify individual parameters of a given survey directly from the @@ -83,14 +83,14 @@ template: |- * galsim_hub: Attribute group consisting of options: * model: Which galsim_hub model to use (default: 'hub:Lanusse2020') * param_names: list of the parameters with which the generation is parametrized; this - is unique to each model (default: `['flux_radious, 'mag_auto', 'zphot']`). + is unique to each model (default: `['flux_radius', 'mag_auto', 'zphot']`). - NOTE: This flag is only used when `draw_blends=galsim_hub` + NOTE: This flag has no effect unless `draw_blends=galsim_hub` * measure_kwargs: Dictionary or list of dictionaries containing the keyword arguments to be passed in to each measure_function. - * measure_functions: List of measure_functions to be ran, options {'basic', 'sep'}. + * measure_functions: List of measure_functions to be ran, options {basic, sep}. * use_metrics: List of metrics to return, options are: {'detection', 'segmentation', 'reconstruction'} diff --git a/conf/sampling/default.yaml b/conf/sampling/default.yaml index bf5413229..f300d0400 100644 --- a/conf/sampling/default.yaml +++ b/conf/sampling/default.yaml @@ -1,5 +1,4 @@ -name: "DefaultSampling" -kwargs: - max_number: ${max_number} - stamp_size: ${stamp_size} - maxshift: ${max_shift} +_target_: btk.sampling_functions.DefaultSampling +max_number: ${max_number} +stamp_size: ${stamp_size} +maxshift: ${max_shift} diff --git a/conf/sampling/galsim_hub.yaml b/conf/sampling/galsim_hub.yaml index 135dfdd6a..18ba1e05a 100644 --- a/conf/sampling/galsim_hub.yaml +++ b/conf/sampling/galsim_hub.yaml @@ -1,5 +1,4 @@ -name: "DefaultSamplingGalsimHub" -kwargs: - max_number: ${max_number} - stamp_size: ${stamp_size} - maxshift: ${max_shift} +_target_: btk.sampling_functions.DefaultSamplingGalsimHub +max_number: ${max_number} +stamp_size: ${stamp_size} +maxshift: ${max_shift} diff --git a/conf/sampling/group_sampling.yaml b/conf/sampling/group_sampling.yaml index 97781b1ba..54bec9c85 100644 --- a/conf/sampling/group_sampling.yaml +++ b/conf/sampling/group_sampling.yaml @@ -1,7 +1,6 @@ -name: "GroupSamplingFunction" -kwargs: - max_number: ${max_number} - stamp_size: ${stamp_size} - maxshift: ${max_shift} - wld_catalog_name: ${paths.data}/sample_group_catalog.fits - shift: Null +_target_: btk.sampling_functions.GroupSampling +max_number: ${max_number} +stamp_size: ${stamp_size} +maxshift: ${max_shift} +wld_catalog_name: ${paths.data}/sample_group_catalog.fits +shift: Null diff --git a/conf/sampling/group_sampling_numbered.yaml b/conf/sampling/group_sampling_numbered.yaml index dc4ff548c..882dac7ba 100644 --- a/conf/sampling/group_sampling_numbered.yaml +++ b/conf/sampling/group_sampling_numbered.yaml @@ -1,7 +1,6 @@ -name: "GroupSamplingFunctionNumbered" -kwargs: - max_number: ${max_number} - stamp_size: ${stamp_size} - maxshift: ${max_shift} - wld_catalog_name: ${paths.data}/sample_group_catalog.fits - shift: Null +_target_: btk.sampling_functions.GroupSamplingNumbered +max_number: ${max_number} +stamp_size: ${stamp_size} +maxshift: ${max_shift} +wld_catalog_name: ${paths.data}/sample_group_catalog.fits +shift: Null diff --git a/tests/test_main.py b/tests/test_main.py index f11c672d1..62d9987c2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -21,8 +21,8 @@ def test_main(): # test survey CLI cfg = get_cfg(overrides={"surveys": "Rubin"}) main(cfg) - cfg = get_cfg(overrides={"surveys": ["Rubin", "HST"]}) - # TODO: Do end to end with multiple surveys once MR measure function implemented. + cfg = get_cfg(overrides={"surveys": ["Rubin", "DES"], "meas_band_num": [0, 0]}) + main(cfg) def test_CLI(): @@ -32,17 +32,7 @@ def test_CLI(): def test_errors(): - cfg = get_cfg(overrides={"catalog.name": "MyCatalog"}) - with pytest.raises(ValueError) as excinfo: - main(cfg) - assert "not implemented" in str(excinfo.value) - - cfg = get_cfg(overrides={"sampling.name": "MySampling"}) - with pytest.raises(ValueError) as excinfo: - main(cfg) - assert "not implemented" in str(excinfo.value) - - cfg = get_cfg(overrides={"draw_blends.name": "MyDrawBlends"}) + cfg = get_cfg(overrides={"measure.measure_functions": ["NotExistantMeasureFunction"]}) with pytest.raises(ValueError) as excinfo: main(cfg) assert "not implemented" in str(excinfo.value)