diff --git a/btk/__init__.py b/btk/__init__.py index 877444d84..059d2cad8 100644 --- a/btk/__init__.py +++ b/btk/__init__.py @@ -17,3 +17,4 @@ from . import plot_utils from . import sampling_functions from . import survey +from . import utils diff --git a/btk/draw_blends.py b/btk/draw_blends.py index 8fa63d9dd..0668d640e 100644 --- a/btk/draw_blends.py +++ b/btk/draw_blends.py @@ -1,5 +1,6 @@ """Module for generating batches of drawn blended images.""" import copy +import os from abc import ABC from abc import abstractmethod from itertools import chain @@ -148,6 +149,7 @@ def __init__( shifts=None, indexes=None, channels_last=False, + save_path=None, ): """Initializes the DrawBlendsGenerator class. @@ -170,6 +172,8 @@ def __init__( channels_last (bool): Whether to return images as numpy arrays with the channel (band) dimension as the last dimension or before the pixels dimensions (default). + save_path (str): Path to a directory where results will be saved. If left + as None, results will not be saved. """ self.blend_generator = BlendGenerator( catalog, sampling_function, batch_size, shifts, indexes, verbose @@ -191,8 +195,8 @@ def __init__( self.add_noise = add_noise self.verbose = verbose - self.channels_last = channels_last + self.save_path = save_path def __iter__(self): """Returns iterable which is the object itself.""" @@ -274,6 +278,20 @@ def __next__(self): blend_images[s.name][i] = batch_results[i][0] isolated_images[s.name][i] = batch_results[i][1] batch_blend_cat[s.name].append(batch_results[i][2]) + + if self.save_path is not None: + if not os.path.exists(os.path.join(self.save_path, s.name)): + os.mkdir(os.path.join(self.save_path, s.name)) + + np.save(os.path.join(self.save_path, s.name, "blended"), blend_images[s.name]) + np.save(os.path.join(self.save_path, s.name, "isolated"), isolated_images[s.name]) + for i in range(len(batch_results)): + batch_blend_cat[s.name][i].write( + os.path.join(self.save_path, s.name, f"blend_info_{i}"), + format="ascii", + overwrite=True, + ) + if len(self.surveys) > 1: output = { "blend_images": blend_images, @@ -291,6 +309,7 @@ def __next__(self): "psf": psfs[survey_name], "wcs": wcss[survey_name], } + return output def render_mini_batch(self, blend_list, psf, wcs, survey, extra_data=None): diff --git a/btk/measure.py b/btk/measure.py index 3dbd8f168..832c89aae 100644 --- a/btk/measure.py +++ b/btk/measure.py @@ -48,6 +48,7 @@ def measure_function(batch, idx, **kwargs): Omitted keys in the returned dictionary are automatically assigned a `None` value (except for `catalog` which is a mandatory entry). """ +import os from itertools import repeat import astropy.table @@ -154,6 +155,7 @@ def __init__( cpus=1, verbose=False, measure_kwargs: list = None, + save_path=None, ): """Initialize measurement generator. @@ -168,6 +170,8 @@ def __init__( to be passed in to each of the `measure_functions`. Each dictionnary is passed one time to each function, meaning that each function which be ran as many times as there are different dictionnaries. + save_path (str): Path to a directory where results will be saved. If left + as None, results will not be saved. """ # setup and verify measure_functions. if callable(measure_functions): @@ -188,6 +192,7 @@ def __init__( self.batch_size = self.draw_blend_generator.batch_size self.channels_last = self.draw_blend_generator.channels_last self.verbose = verbose + self.save_path = save_path # initialize measure_kwargs dictionary. self.measure_kwargs = [{}] if measure_kwargs is None else measure_kwargs @@ -281,5 +286,19 @@ def __next__(self): measure_output[j][i][key] for j in range(len(measure_output)) ] measure_results[f.__name__ + str(m)] = measure_dict + if self.save_path is not None: + dir_name = f.__name__ + str(m) + if not os.path.exists(os.path.join(self.save_path, dir_name)): + os.mkdir(os.path.join(self.save_path, dir_name)) + + for key in ["segmentation", "deblended_images"]: + if key in measure_dict.keys(): + np.save(os.path.join(self.save_path, dir_name, key), measure_dict[key]) + for j, cat in enumerate(measure_dict["catalog"]): + cat.write( + os.path.join(self.save_path, dir_name, f"detection_catalog_{j}"), + format="ascii", + overwrite=True, + ) return blend_output, measure_results diff --git a/btk/metrics.py b/btk/metrics.py index 1e9f6ddea..03ac527ff 100644 --- a/btk/metrics.py +++ b/btk/metrics.py @@ -53,6 +53,8 @@ is the standard scalar product on vectors. """ +import os + import astropy.table import galsim import matplotlib.pyplot as plt @@ -163,7 +165,6 @@ def get_detection_match(true_table, detected_table, f_distance=distance_center): raise KeyError("Detection table has no column y_peak") match_table = astropy.table.Table() - print(f_distance) # dist[i][j] = distance between true object i and detected object j. dist = np.zeros((len(true_table), len(detected_table))) for i, true_gal in enumerate(true_table): @@ -531,6 +532,7 @@ def compute_metrics( # noqa: C901 meas_band_num=0, target_meas={}, channels_last=False, + save_path=None, f_distance=distance_center, ): """Computes all requested metrics given information in a single batch from measure_generator. @@ -565,6 +567,8 @@ def compute_metrics( # noqa: C901 be returned for both isolated and deblended images to compare. channels_last (bool) : Indicates whether the images should be channels first (NCHW) or channels last (NHWC). + save_path (str): Path to directory where results will be saved. If left + as None, results will not be saved. f_distance (func): Function used to compute the distance between true and detected galaxies. Takes as arguments the entries corresponding to the two galaxies. By default the distance is the euclidean distance from center to center. @@ -651,6 +655,13 @@ def compute_metrics( # noqa: C901 for k in reconstruction_keys: row[k] = results["reconstruction"][k][i][j] results["galaxy_summary"].add_row(row[0]) + if save_path is not None: + if not os.path.exists(save_path): + os.mkdir(save_path) + + for key in use_metrics: + np.save(os.path.join(save_path, f"{key}_metric"), results[key]) + results["galaxy_summary"].write(os.path.join(save_path, "galaxy_summary"), format="ascii") return results @@ -665,6 +676,7 @@ def __init__( meas_band_num=0, target_meas={}, noise_threshold_factor=3, + save_path=None, f_distance=distance_center, ): """Initialize metrics generator. @@ -684,6 +696,8 @@ def __init__( applied when getting segmentations from true images. A value of 3 would correspond to a threshold of 3 sigmas (with sigma the standard deviation of the noise) + save_path (str): Path to directory where results will be saved. If left + as None, results will not be saved. f_distance (func): Function used to compute the distance between true and detected galaxies. Takes as arguments the entries corresponding to the two galaxies. By default the distance is the euclidean distance from center to center. @@ -693,6 +707,7 @@ def __init__( self.meas_band_num = meas_band_num self.target_meas = target_meas self.noise_threshold_factor = noise_threshold_factor + self.save_path = save_path self.f_distance = f_distance def __next__(self): @@ -725,6 +740,9 @@ def __next__(self): self.meas_band_num, target_meas, channels_last=self.measure_generator.channels_last, + save_path=os.path.join(self.save_path, meas_func) + if self.save_path is not None + else None, f_distance=self.f_distance, ) metrics_results[meas_func] = metrics_results_f diff --git a/btk/utils.py b/btk/utils.py new file mode 100644 index 000000000..f20f2cf99 --- /dev/null +++ b/btk/utils.py @@ -0,0 +1,133 @@ +"""Contains utility functions, including functions for loading saved results.""" +import os + +import numpy as np +from astropy.table import Table + +BLEND_RESULT_KEYS = ("blend_images", "isolated_images", "blend_list") + + +def load_blend_results(path, survey): + """Load results exported from a DrawBlendsGenerator. + + Args; + path (str): Path to the files. Should be the same as the save_path + which was provided to the DrawBlendsGenerator to save + the files. + survey (str): Name of the survey for which you want to load the files. + + Returns: + Dictionnary containing the blend images, the isolated images and the + informations about the blends. + """ + blend_images = np.load(os.path.join(path, survey, "blended.npy"), allow_pickle=True) + isolated_images = np.load(os.path.join(path, survey, "isolated.npy"), allow_pickle=True) + blend_list = [ + Table.read(os.path.join(path, survey, f"blend_info_{i}"), format="ascii") + for i in range(blend_images.shape[0]) + ] + + return { + "blend_images": blend_images, + "isolated_images": isolated_images, + "blend_list": blend_list, + } + + +def load_measure_results(path, measure_name, n_batch): + """Load results exported from a MeasureGenerator. + + Args: + path (str): Path to the files. Should be the same as the save_path + which was provided to the MeasureGenerator to save + the files. + measure_name (str): Name of the measure function for which you + want to load the files + n_batch (int): Number of blends in the batch you want to load + + Returns: + Dictionnary containing the detection catalogs, the segmentations + and the deblended images. + """ + measure_results = {} + for key in ["segmentation", "deblended_images"]: + try: + measure_results[key] = np.load( + os.path.join(path, measure_name, f"{key}.npy"), allow_pickle=True + ) + except FileNotFoundError: + print(f"No {key} found.") + + catalog = [ + Table.read( + os.path.join(path, measure_name, f"detection_catalog_{j}"), + format="ascii", + ) + for j in range(n_batch) + ] + measure_results["catalog"] = catalog + return measure_results + + +def load_metrics_results(path, measure_name): + """Load results exported from a MetricsGenerator. + + Args: + path (str): Path to the files. Should be the same as the save_path + which was provided to the MetricsGenerator to save + the files. + measure_name (str): Name of the measure function for which you + want to load the files + + Returns: + Dictionnary containing the detection catalogs, the segmentations + and the deblended images. + """ + metrics_results = {} + for key in ["detection", "segmentation", "reconstruction"]: + try: + metrics_results[key] = np.load( + os.path.join(path, measure_name, f"{key}_metric.npy"), allow_pickle=True + ) + except FileNotFoundError: + print(f"No {key} metrics found.") + + metrics_results["galaxy_summary"] = Table.read( + os.path.join(path, measure_name, "galaxy_summary"), + format="ascii", + ) + return metrics_results + + +def load_all_results(path, surveys, measure_names, n_batch, n_meas_kwargs=1): + """Load results exported from a MetricsGenerator. + + Args: + path (str): Path to the files. Should be the same as the save_path + which was provided to the MetricsGenerator to save + the files. + surveys (list): Names of the surveys for which you want to load + the files + measure_names (list): Names of the measure functions for which you + want to load the files + n_batch (int): Number of blends in the batch you want to load + + Returns: + The three dictionnaries corresponding to the results. + """ + blend_results = {} + for key in BLEND_RESULT_KEYS: + blend_results[key] = {} + measure_results = {} + metrics_results = {} + for s in surveys: + blend_results_temp = load_blend_results(path, s) + for key in BLEND_RESULT_KEYS: + blend_results[key][s] = blend_results_temp[key] + + for meas in measure_names: + for n in range(n_meas_kwargs): + measure_results[meas + str(n)] = load_measure_results(path, meas + str(n), n_batch) + metrics_results[meas + str(n)] = load_metrics_results(path, meas + str(n)) + + return blend_results, measure_results, metrics_results diff --git a/tests/test_save.py b/tests/test_save.py new file mode 100644 index 000000000..3992066c4 --- /dev/null +++ b/tests/test_save.py @@ -0,0 +1,51 @@ +import tempfile + +import numpy as np + +import btk +from btk.survey import Rubin + + +def test_save(): + output_dir = tempfile.mkdtemp() + catalog_name = "data/sample_input_catalog.fits" + stamp_size = 24.0 + batch_size = 8 + catalog = btk.catalog.CatsimCatalog.from_file(catalog_name) + sampling_function = btk.sampling_functions.DefaultSampling(stamp_size=stamp_size) + draw_blend_generator = btk.draw_blends.CatsimGenerator( + catalog, + sampling_function, + [Rubin], + batch_size=batch_size, + stamp_size=stamp_size, + save_path=output_dir, + ) + meas_generator = btk.measure.MeasureGenerator( + btk.measure.sep_measure, draw_blend_generator, save_path=output_dir + ) + metrics_generator = btk.metrics.MetricsGenerator( + meas_generator, + use_metrics=("detection", "segmentation", "reconstruction"), + target_meas={"ellipticity": btk.metrics.meas_ksb_ellipticity}, + save_path=output_dir, + ) + blend_results, measure_results, metrics_results = next(metrics_generator) + blend_results2, measure_results2, metrics_results2 = btk.utils.load_all_results( + output_dir, ["LSST"], ["sep_measure"], batch_size + ) + np.testing.assert_array_equal( + blend_results["blend_images"], blend_results2["blend_images"]["LSST"] + ) + np.testing.assert_array_equal( + measure_results["sep_measure0"]["segmentation"][0], + measure_results2["sep_measure0"]["segmentation"][0], + ) + np.testing.assert_array_equal( + measure_results["sep_measure0"]["deblended_images"][0], + measure_results2["sep_measure0"]["deblended_images"][0], + ) + np.testing.assert_array_equal( + metrics_results["sep_measure0"]["galaxy_summary"]["distance_closest_galaxy"], + metrics_results2["sep_measure0"]["galaxy_summary"]["distance_closest_galaxy"], + )