Skip to content

Commit

Permalink
Simplify CLI with hydra instantiate (#230)
Browse files Browse the repository at this point in the history
* update poetry

* more error checking in draw_blends for MR

* bug in this line when using basic measure function and MR

* catch error related to meas_band_num in metrics

* use hydra instantiate and call rather than available_*

* update documentation

* test MR on main, remove error checking now done by hydra

* Revert "update poetry"

This reverts commit aad74ea.
  • Loading branch information
ismael-mendoza authored Oct 15, 2021
1 parent 4061998 commit 223650d
Show file tree
Hide file tree
Showing 17 changed files with 109 additions and 136 deletions.
20 changes: 10 additions & 10 deletions btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from abc import ABC
from abc import abstractmethod
from collections.abc import Iterable
from itertools import chain

import galsim
Expand Down Expand Up @@ -169,14 +170,20 @@ def __init__(
if isinstance(surveys, Survey):
self.surveys = [surveys]
self.check_compatibility(surveys)
elif isinstance(surveys, list):
elif isinstance(surveys, Iterable):
for s in surveys:
if not isinstance(s, Survey):
raise TypeError("surveys must be a Survey object or a list of Survey objects.")
raise TypeError(
f"surveys must be a Survey object or an Iterable of Survey objects, but"
f"Iterable contained object of type {type(s)}"
)
self.check_compatibility(s)
self.surveys = surveys
else:
raise TypeError("surveys must be a Survey object or a list of Survey objects.")
raise TypeError(
f"surveys must be a Survey object or an Iterable of Survey objects,"
f"but surveys is type {type(surveys)}"
)
self.is_multiresolution = len(self.surveys) > 1

self.stamp_size = stamp_size
Expand Down Expand Up @@ -643,10 +650,3 @@ def render_single(self, entry, filt, psf, survey, extra_data):
nx=pix_stamp_size, ny=pix_stamp_size, scale=survey.pixel_scale, dtype=np.float64
)
return galaxy_image


available_draw_blends = {
"CatsimGenerator": CatsimGenerator,
"CosmosGenerator": CosmosGenerator,
"GalsimHubGenerator": GalsimHubGenerator,
}
29 changes: 3 additions & 26 deletions btk/main.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,21 @@
"""Implements main function to run BTK end-to-end."""
from collections.abc import Iterable

from hydra.utils import instantiate
from omegaconf import OmegaConf

from btk.catalog import available_catalogs
from btk.draw_blends import available_draw_blends
from btk.measure import available_measure_functions
from btk.measure import MeasureGenerator
from btk.metrics import MetricsGenerator
from btk.sampling_functions import available_sampling_functions
from btk.survey import _get_survey_from_cfg


def main(cfg: OmegaConf):
"""Run BTK from end-to-end using a hydra configuration object."""
# get catalog
if cfg.catalog.name not in available_catalogs:
raise ValueError(f"Catalog '{cfg.catalog.name}' is not implemented in BTK.")
catalog = available_catalogs[cfg.catalog.name].from_file(cfg.catalog.catalog_files)

# get sampling function
if cfg.sampling.name not in available_sampling_functions:
raise ValueError(f"Sampling function '{cfg.sampling.name}' is not implemented in BTK.")
sampling_function = available_sampling_functions[cfg.sampling.name](**cfg.sampling.kwargs)

# get survey(s) to be used.
if not isinstance(cfg.surveys, Iterable):
cfg.surveys = [cfg.surveys]

surveys = []
for survey_name in cfg.surveys:
survey = _get_survey_from_cfg(cfg.surveys[survey_name])
surveys.append(survey)
surveys = [_get_survey_from_cfg(cfg.surveys[survey_name]) for survey_name in cfg.surveys]

# get draw blends generator.
if cfg.draw_blends.name not in available_draw_blends:
raise ValueError("DrawBlendGenerator specified is not implemented in BTK.")
draw_blend_generator = available_draw_blends[cfg.draw_blends.name](
catalog, sampling_function, surveys, **cfg.draw_blends.kwargs
)
draw_blend_generator = instantiate(cfg.draw_blends, surveys=surveys)

# get measure_functions.
measure_functions = []
Expand Down
1 change: 0 additions & 1 deletion btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def basic_measure(
if is_multiresolution:
if surveys is None:
raise ValueError("surveys are required in order to use the MR feature.")
surveys = kwargs.get("surveys", None)
survey_name = surveys[0].name
coadd = np.mean(batch["blend_images"][survey_name][idx], axis=channel_indx)
wcs = batch["wcs"][survey_name]
Expand Down
24 changes: 18 additions & 6 deletions btk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"""
import os
from collections.abc import Iterable

import astropy.table
import galsim
Expand Down Expand Up @@ -730,20 +731,31 @@ def __next__(self):
"""Returns metric results calculated on one batch."""
blend_results, measure_results = next(self.measure_generator)
surveys = self.measure_generator.draw_blend_generator.surveys
meas_band_num = self.meas_band_num

metrics_results = {}
for meas_func in measure_results["catalog"].keys():
if self.is_multiresolution:
if not isinstance(meas_band_num, Iterable) and hasattr(meas_band_num, "__len__"):
raise ValueError(
f"meas_band_num is required to be an finite length iterable,"
f"instead type was {type(meas_band_num)}"
)
if not len(meas_band_num) == len(surveys):
raise ValueError(
f"meas_band_num should be an iterable of exactly the same len as the"
f"number of surveys: {len(surveys)}, instead len is {len(meas_band_num)}"
)
metrics_results_f = {}
for i, surv in enumerate(blend_results["isolated_images"].keys()):
additional_params = {
"psf": blend_results["psf"][surv],
"pixel_scale": surveys[i].pixel_scale,
"meas_band_num": self.meas_band_num[i],
"meas_band_num": meas_band_num[i],
"verbose": self.verbose,
}
noise_threshold = self.noise_threshold_factor * np.sqrt(
get_mean_sky_level(surveys[i], surveys[i].filters[self.meas_band_num[i]])
get_mean_sky_level(surveys[i], surveys[i].filters[meas_band_num[i]])
)
target_meas = {}
for k in self.target_meas.keys():
Expand All @@ -756,7 +768,7 @@ def __next__(self):
measure_results["segmentation"][meas_func][surv],
measure_results["deblended_images"][meas_func][surv],
noise_threshold,
self.meas_band_num[i],
meas_band_num[i],
target_meas,
channels_last=self.measure_generator.channels_last,
save_path=os.path.join(self.save_path, meas_func, surv)
Expand All @@ -771,11 +783,11 @@ def __next__(self):
additional_params = {
"psf": blend_results["psf"],
"pixel_scale": surveys[0].pixel_scale,
"meas_band_num": self.meas_band_num,
"meas_band_num": meas_band_num,
"verbose": self.verbose,
}
noise_threshold = self.noise_threshold_factor * np.sqrt(
get_mean_sky_level(surveys[0], surveys[0].filters[self.meas_band_num])
get_mean_sky_level(surveys[0], surveys[0].filters[meas_band_num])
)
target_meas = {}
for k in self.target_meas.keys():
Expand All @@ -788,7 +800,7 @@ def __next__(self):
measure_results["segmentation"][meas_func],
measure_results["deblended_images"][meas_func],
noise_threshold,
self.meas_band_num,
meas_band_num,
target_meas,
channels_last=self.measure_generator.channels_last,
save_path=os.path.join(self.save_path, meas_func, surveys[0].name)
Expand Down
8 changes: 0 additions & 8 deletions btk/sampling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,3 @@ def __call__(self, table, **kwargs):
)
assert len(no_boundary) <= self.max_number, message
return no_boundary


available_sampling_functions = {
"DefaultSampling": DefaultSampling,
"DefaultSamplingGalsimHub": DefaultSamplingGalsimHub,
"GroupSampling": GroupSampling,
"GroupSamplingNumbered": GroupSamplingNumbered,
}
2 changes: 2 additions & 0 deletions conf/catalog/catsim.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: btk.catalog.CatsimCatalog.from_file
catalog_files: ${paths.data}/sample_input_catalog.fits
4 changes: 4 additions & 0 deletions conf/catalog/cosmos.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name: btk.catalog.CosmosCatalog.from_file
catalog_files:
- ${paths.data}/cosmos/real_galaxy_catalog_23.5_example.fits
- ${paths.data}/cosmos/real_galaxy_catalog_23.5_example_fits.fits
5 changes: 1 addition & 4 deletions conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
defaults:
- _self_
- catalog: catsim
- sampling: default
- draw_blends: catsim
- surveys:
- Rubin
- override hydra/help: btk_help

catalog:
name: catsim
catalog_files: ${paths.data}/sample_input_catalog.fits

### all arguments ###

# system
Expand Down
24 changes: 13 additions & 11 deletions conf/draw_blends/catsim.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
name: "CatsimGenerator"
kwargs:
batch_size: ${batch_size}
stamp_size: ${stamp_size}
cpus: ${cpus}
verbose: ${verbose}
add_noise: ${add_noise}
shifts: Null
indexes: Null
channels_last: ${channels_last}
save_path: ${save_path}
_target_: btk.draw_blends.CatsimGenerator
catalog: ${catalog}
sampling_function: ${sampling}
surveys: Null
batch_size: ${batch_size}
stamp_size: ${stamp_size}
cpus: ${cpus}
verbose: ${verbose}
add_noise: ${add_noise}
shifts: Null
indexes: Null
channels_last: ${channels_last}
save_path: ${save_path}
24 changes: 13 additions & 11 deletions conf/draw_blends/cosmos.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
name: "CosmosGenerator"
kwargs:
batch_size: ${batch_size}
stamp_size: ${stamp_size}
cpus: ${cpus}
verbose: ${verbose}
add_noise: ${add_noise}
shifts: Null
indexes: Null
channels_last: ${channels_last}
save_path: ${save_path}
_target_: btk.draw_blends.CosmosGenerator
catalog: ${catalog}
sampling_function: ${sampling}
surveys: Null
batch_size: ${batch_size}
stamp_size: ${stamp_size}
cpus: ${cpus}
verbose: ${verbose}
add_noise: ${add_noise}
shifts: Null
indexes: Null
channels_last: ${channels_last}
save_path: ${save_path}
30 changes: 15 additions & 15 deletions conf/draw_blends/galsim_hub.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
name: "GalsimHubGenerator"
kwargs:
catalog: ${catalog}
surveys: Null
batch_size: ${batch_size}
stamp_size: ${stamp_size}
cpus: ${cpus}
verbose: ${verbose}
add_noise: ${add_noise}
shifts: Null
indexes: Null
channels_last: ${channels_last}
galsim_hub_model: ${galsim_hub_model}
param_names: ${param_names}
save_path: ${save_path}
_target_: btk.draw_blends.GalsimHubGenerator
catalog: ${catalog}
sampling: ${sampling}
surveys: Null
batch_size: ${batch_size}
stamp_size: ${stamp_size}
cpus: ${cpus}
verbose: ${verbose}
add_noise: ${add_noise}
shifts: Null
indexes: Null
channels_last: ${channels_last}
galsim_hub_model: ${galsim_hub_model}
param_names: ${param_names}
save_path: ${save_path}
14 changes: 7 additions & 7 deletions conf/hydra/help/btk_help.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ template: |-
Assuming that BTK has been pip installed, you can run btk from the command line like e.g.
btk sampling=default draw_blends=catsim max_number=3 save_path=/directory/to/save/results cpus=1
verbose=False surveys=[Rubin, HST] surveys.Rubin.airmass=1.1
verbose=False surveys=[Rubin,HST] surveys.Rubin.airmass=1.1
sampling=default catalog.name=catsim use_metrics=['detection', 'segmentation'] (...)
You need to create the directory to save results yourself (preferably an empty directory) and specify its absolute path when you run the CLI via the `save_path` parameter.
Expand All @@ -39,17 +39,17 @@ template: |-
{default, galsim_hub, group_sampling, group_sampling_numbered}
* catalog: Attribute group consisting of two sub-attributes.
* catalog.name: Name of the BTK catalog class, options: {catsim, cosmos}
* catalog: Name of the BTK catalog class, options: {catsim, cosmos}
* catalog.catalog_files: Path to files containing catalog information. The 'catsim'
catalog requires one path, while the `cosmos` type requires two paths specified as
catalog requires one path, while the 'cosmos' type requires two paths specified as
a list. (see documentation for more details).
* surveys: Name of the survey(s) you want to use, options are
{Rubin, HST, HSC, DES, CFHT, Euclid} and correspond to each of the config files available
in conf/surveys. You can pass in a list of surveys for multi-resolution
studies too. For example:
btk surveys=[Rubin, HST] (...)
btk surveys=[Rubin,HST] (...)
Assuming that you want to use e.g. the Rubin survey default parameters but with a couple
of changes, you can modify individual parameters of a given survey directly from the
Expand Down Expand Up @@ -83,14 +83,14 @@ template: |-
* galsim_hub: Attribute group consisting of options:
* model: Which galsim_hub model to use (default: 'hub:Lanusse2020')
* param_names: list of the parameters with which the generation is parametrized; this
is unique to each model (default: `['flux_radious, 'mag_auto', 'zphot']`).
is unique to each model (default: `['flux_radius', 'mag_auto', 'zphot']`).
NOTE: This flag is only used when `draw_blends=galsim_hub`
NOTE: This flag has no effect unless `draw_blends=galsim_hub`
* measure_kwargs: Dictionary or list of dictionaries containing the keyword arguments to be
passed in to each measure_function.
* measure_functions: List of measure_functions to be ran, options {'basic', 'sep'}.
* measure_functions: List of measure_functions to be ran, options {basic, sep}.
* use_metrics: List of metrics to return, options are:
{'detection', 'segmentation', 'reconstruction'}
Expand Down
9 changes: 4 additions & 5 deletions conf/sampling/default.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
name: "DefaultSampling"
kwargs:
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
_target_: btk.sampling_functions.DefaultSampling
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
9 changes: 4 additions & 5 deletions conf/sampling/galsim_hub.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
name: "DefaultSamplingGalsimHub"
kwargs:
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
_target_: btk.sampling_functions.DefaultSamplingGalsimHub
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
13 changes: 6 additions & 7 deletions conf/sampling/group_sampling.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: "GroupSamplingFunction"
kwargs:
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
wld_catalog_name: ${paths.data}/sample_group_catalog.fits
shift: Null
_target_: btk.sampling_functions.GroupSampling
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
wld_catalog_name: ${paths.data}/sample_group_catalog.fits
shift: Null
13 changes: 6 additions & 7 deletions conf/sampling/group_sampling_numbered.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: "GroupSamplingFunctionNumbered"
kwargs:
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
wld_catalog_name: ${paths.data}/sample_group_catalog.fits
shift: Null
_target_: btk.sampling_functions.GroupSamplingNumbered
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
wld_catalog_name: ${paths.data}/sample_group_catalog.fits
shift: Null
Loading

0 comments on commit 223650d

Please sign in to comment.