diff --git a/btk/draw_blends.py b/btk/draw_blends.py index b4cb86188..3952449bf 100644 --- a/btk/draw_blends.py +++ b/btk/draw_blends.py @@ -165,10 +165,12 @@ def __init__( if isinstance(surveys, Survey): self.surveys = [surveys] + self.check_compatibility(surveys) elif isinstance(surveys, list): for s in surveys: if not isinstance(s, Survey): raise TypeError("surveys must be a Survey object or a list of Survey objects.") + self.check_compatibility(s) self.surveys = surveys else: raise TypeError("surveys must be a Survey object or a list of Survey objects.") @@ -180,6 +182,13 @@ def __init__( self.channels_last = channels_last self.save_path = save_path + def check_compatibility(self, survey): + """Checks that the compatibility between the survey, the catalog and the generator. + + This should be implemented in subclasses. + """ + pass + def __iter__(self): """Returns iterable which is the object itself.""" return self @@ -431,6 +440,24 @@ class CatsimGenerator(DrawBlendsGenerator): compatible_catalogs = ("CatsimCatalog",) + def check_compatibility(self, survey): + """Checks the compatibility between the catalog and a given survey. + + Args: + survey (btk.survey.Survey): Survey to check + """ + if type(self.catalog).__name__ not in self.compatible_catalogs: + raise ValueError( + f"The catalog provided is of the wrong type. The types of " + f"catalogs available for the {type(self).__name__} are {self.compatible_catalogs}" + ) + for f in survey.filters: + if f.name + "_ab" not in self.catalog.table.keys(): + raise ValueError( + f"The {f.name} filter of the survey {survey.name} " + f"has no associated magnitude in the given catalog." + ) + def render_single(self, entry, filt, psf, survey, extra_data): """Returns the Galsim Image of an isolated galaxy.""" if self.verbose: @@ -456,6 +483,26 @@ class CosmosGenerator(DrawBlendsGenerator): compatible_catalogs = ("CosmosCatalog",) + def check_compatibility(self, survey): + """Checks the compatibility between the catalog and a given survey. + + Args: + survey (btk.survey.Survey): Survey to check + """ + if type(self.catalog).__name__ not in self.compatible_catalogs: + raise ValueError( + f"The catalog provided is of the wrong type. The types of " + f"catalogs available for the {type(self).__name__} are {self.compatible_catalogs}" + ) + if "ref_mag" not in self.catalog.table.keys(): + for f in survey.filters: + if f"{survey.name}_{f.name}" not in self.catalog.table.keys(): + raise ValueError( + f"The {f.name} filter of the survey {survey.name} " + f"has no associated magnitude in the given catalog, " + f"and the catalog does not contain a 'ref_mag' column" + ) + def render_single(self, entry, filt, psf, survey, extra_data): """Returns the Galsim Image of an isolated galaxy.""" galsim_catalog = self.catalog.get_galsim_catalog() @@ -489,6 +536,18 @@ class GalsimHubGenerator(DrawBlendsGenerator): compatible_catalogs = ("CosmosCatalog",) + def check_compatibility(self, survey): + """Checks the compatibility between the catalog and a given survey. + + Args: + survey (btk.survey.Survey): Survey to check + """ + if type(self.catalog).__name__ not in self.compatible_catalogs: + raise ValueError( + f"The catalog provided is of the wrong type. The types of " + f"catalogs available for the {type(self).__name__} are {self.compatible_catalogs}" + ) + def __init__( self, catalog, diff --git a/tests/test_error_cases.py b/tests/test_error_cases.py index 03bc0ae23..444519186 100644 --- a/tests/test_error_cases.py +++ b/tests/test_error_cases.py @@ -2,7 +2,9 @@ from conftest import data_dir from btk.catalog import CatsimCatalog +from btk.catalog import CosmosCatalog from btk.draw_blends import CatsimGenerator +from btk.draw_blends import CosmosGenerator from btk.draw_blends import get_catsim_galaxy from btk.draw_blends import SourceNotVisible from btk.sampling_functions import DefaultSampling @@ -14,6 +16,16 @@ CATALOG_PATH = data_dir / "sample_input_catalog.fits" +COSMOS_CATALOG_PATHS = [ + str(data_dir / "cosmos/real_galaxy_catalog_23.5_example.fits"), + str(data_dir / "cosmos/real_galaxy_catalog_23.5_example_fits.fits"), +] + +COSMOS_EXT_CATALOG_PATHS = [ + str(data_dir / "cosmos/real_galaxy_catalog_26_extension_example.fits"), + str(data_dir / "cosmos/real_galaxy_catalog_26_extension_example_fits.fits"), +] + def test_sampling_no_max_number(): class TestSamplingFunction(SamplingFunction): @@ -207,3 +219,47 @@ def test_psf(): get_psf_from_file("tests/example_psf", get_surveys("Rubin")) get_psf_from_file("tests/multi_psf", get_surveys("Rubin")) # The case where the folder is empty cannot be tested as you cannot add an empty folder to git + + +def test_incompatible_catalogs(): + stamp_size = 24.0 + batch_size = 8 + cpus = 1 + add_noise = True + + catalog = CatsimCatalog.from_file(CATALOG_PATH) + sampling_function = DefaultSampling(stamp_size=stamp_size) + with pytest.raises(ValueError): + # Wrong generator + draw_generator = CosmosGenerator( # noqa: F841 + catalog, + sampling_function, + get_surveys("Rubin"), + stamp_size=stamp_size, + batch_size=batch_size, + cpus=cpus, + add_noise=add_noise, + ) + with pytest.raises(ValueError): + # Missing filter + draw_generator = CatsimGenerator( # noqa: F841 + catalog, + sampling_function, + get_surveys("HST"), + stamp_size=stamp_size, + batch_size=batch_size, + cpus=cpus, + add_noise=add_noise, + ) + + catalog = CosmosCatalog.from_file(COSMOS_CATALOG_PATHS) + with pytest.raises(ValueError): + draw_generator = CatsimGenerator( # noqa: F841 + catalog, + sampling_function, + get_surveys("Rubin"), + stamp_size=stamp_size, + batch_size=batch_size, + cpus=cpus, + add_noise=add_noise, + )