Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added possibility to save results #148

Merged
merged 18 commits into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions btk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from . import plot_utils
from . import sampling_functions
from . import survey
from . import utils
21 changes: 20 additions & 1 deletion btk/draw_blends.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module for generating batches of drawn blended images."""
import copy
import os
from abc import ABC
from abc import abstractmethod
from itertools import chain
Expand Down Expand Up @@ -148,6 +149,7 @@ def __init__(
shifts=None,
indexes=None,
channels_last=False,
save_path=None,
):
"""Initializes the DrawBlendsGenerator class.

Expand All @@ -170,6 +172,8 @@ def __init__(
channels_last (bool): Whether to return images as numpy arrays with the channel
(band) dimension as the last dimension or before the pixels
dimensions (default).
save_path (str): Path to save the results, ending by the file name root. If left
as None, results will not be saved.
thuiop marked this conversation as resolved.
Show resolved Hide resolved
"""
self.blend_generator = BlendGenerator(
catalog, sampling_function, batch_size, shifts, indexes, verbose
Expand All @@ -191,8 +195,8 @@ def __init__(

self.add_noise = add_noise
self.verbose = verbose

self.channels_last = channels_last
self.save_path = save_path

def __iter__(self):
"""Returns iterable which is the object itself."""
Expand Down Expand Up @@ -274,6 +278,20 @@ def __next__(self):
blend_images[s.name][i] = batch_results[i][0]
isolated_images[s.name][i] = batch_results[i][1]
batch_blend_cat[s.name].append(batch_results[i][2])

if self.save_path is not None:
if not os.path.exists(os.path.join(self.save_path, s.name)):
os.mkdir(os.path.join(self.save_path, s.name))
thuiop marked this conversation as resolved.
Show resolved Hide resolved

np.save(os.path.join(self.save_path, s.name, "blended"), blend_images[s.name])
np.save(os.path.join(self.save_path, s.name, "isolated"), isolated_images[s.name])
for i in range(len(batch_results)):
batch_blend_cat[s.name][i].write(
os.path.join(self.save_path, s.name, f"blend_info_{i}"),
format="ascii",
overwrite=True,
)

if len(self.surveys) > 1:
output = {
"blend_images": blend_images,
Expand All @@ -291,6 +309,7 @@ def __next__(self):
"psf": psfs[survey_name],
"wcs": wcss[survey_name],
}

return output

def render_mini_batch(self, blend_list, psf, wcs, survey, extra_data=None):
Expand Down
20 changes: 19 additions & 1 deletion btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def measure_function(batch, idx, **kwargs):
Omitted keys in the returned dictionary are automatically assigned a `None` value (except for
`catalog` which is a mandatory entry).
"""
import os
from itertools import repeat

import astropy.table
Expand Down Expand Up @@ -151,6 +152,7 @@ def __init__(
cpus=1,
verbose=False,
measure_kwargs: dict = None,
save_path=None,
):
"""Initialize measurement generator.

Expand All @@ -163,6 +165,8 @@ def __init__(
verbose (bool): Whether to print information about measurement.
measure_kwargs (dict): Dictionary containing keyword arguments to be passed
in to each of the `measure_functions`.
save_path (str): Path to save the results, ending by the file name root. If left
thuiop marked this conversation as resolved.
Show resolved Hide resolved
as None, results will not be saved.
thuiop marked this conversation as resolved.
Show resolved Hide resolved
"""
# setup and verify measure_functions.
if callable(measure_functions):
Expand All @@ -183,6 +187,7 @@ def __init__(
self.batch_size = self.draw_blend_generator.batch_size
self.channels_last = self.draw_blend_generator.channels_last
self.verbose = verbose
self.save_path = save_path

# initialize measure_kwargs dictionary.
self.measure_kwargs = {} if measure_kwargs is None else measure_kwargs
Expand Down Expand Up @@ -268,11 +273,24 @@ def __next__(self):
measure_results = {}
for i, f in enumerate(self.measure_functions):
measure_dic = {}
thuiop marked this conversation as resolved.
Show resolved Hide resolved
for key in ["catalog", "deblended_images", "segmentation"]:
for key in self.measure_params:
if measure_output[0][i][key] is not None:
measure_dic[key] = [
measure_output[j][i][key] for j in range(len(measure_output))
]
measure_results[f.__name__] = measure_dic
if self.save_path is not None:
if not os.path.exists(os.path.join(self.save_path, f.__name__)):
os.mkdir(os.path.join(self.save_path, f.__name__))

for key in ["segmentation", "deblended_images"]:
if key in measure_dic.keys():
np.save(os.path.join(self.save_path, f.__name__, key), measure_dic[key])
thuiop marked this conversation as resolved.
Show resolved Hide resolved
for j, cat in enumerate(measure_dic["catalog"]):
cat.write(
os.path.join(self.save_path, f.__name__, f"detection_catalog_{j}"),
format="ascii",
overwrite=True,
)

return blend_output, measure_results
19 changes: 19 additions & 0 deletions btk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
is the standard scalar product on vectors.

"""
import os

import astropy.table
import galsim
import numpy as np
Expand Down Expand Up @@ -509,6 +511,7 @@ def compute_metrics( # noqa: C901
meas_band_num=0,
target_meas={},
channels_last=False,
save_path=None,
):
"""Computes all requested metrics given information in a single batch from measure_generator.

Expand Down Expand Up @@ -542,6 +545,8 @@ def compute_metrics( # noqa: C901
be returned for both isolated and deblended images to compare.
channels_last (bool) : Indicates whether the images should be channels first (NCHW)
or channels last (NHWC).
save_path (str): Path to save the results, ending by the file name root. If left
as None, results will not be saved.
thuiop marked this conversation as resolved.
Show resolved Hide resolved

Returns:
results (dict) : Contains all the computed metrics. Entries are :
Expand Down Expand Up @@ -626,6 +631,13 @@ def compute_metrics( # noqa: C901
for k in reconstruction_keys:
row[k] = results["reconstruction"][k][i][j]
results["galaxy_summary"].add_row(row[0])
if save_path is not None:
if not os.path.exists(save_path):
os.mkdir(save_path)

for key in use_metrics:
np.save(os.path.join(save_path, f"{key}_metric"), results[key])
results["galaxy_summary"].write(os.path.join(save_path, "galaxy_summary"), format="ascii")

return results

Expand All @@ -640,6 +652,7 @@ def __init__(
meas_band_num=0,
target_meas={},
noise_threshold_factor=3,
save_path=None,
):
"""Initialize metrics generator.

Expand All @@ -658,12 +671,15 @@ def __init__(
applied when getting segmentations from true images. A value of 3 would
correspond to a threshold of 3 sigmas (with sigma the standard deviation of
the noise)
save_path (str): Path to save the results, ending by the file name root. If left
as None, results will not be saved.
thuiop marked this conversation as resolved.
Show resolved Hide resolved
"""
self.measure_generator: MeasureGenerator = measure_generator
self.use_metrics = use_metrics
self.meas_band_num = meas_band_num
self.target_meas = target_meas
self.noise_threshold_factor = noise_threshold_factor
self.save_path = save_path

def __next__(self):
"""Returns metric results calculated on one batch."""
Expand Down Expand Up @@ -695,6 +711,9 @@ def __next__(self):
self.meas_band_num,
target_meas,
channels_last=self.measure_generator.channels_last,
save_path=os.path.join(self.save_path, meas_func)
if self.save_path is not None
else None,
)
metrics_results[meas_func] = metrics_results_f

Expand Down
132 changes: 132 additions & 0 deletions btk/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Contains utility functions, including functions for loading saved results."""
thuiop marked this conversation as resolved.
Show resolved Hide resolved
import os

import numpy as np
from astropy.table import Table

# Keys of the dictionary returned by load_blend_results; load_all_results
# iterates over them to regroup the per-survey results under each key.
BLEND_RESULT_KEYS = ("blend_images", "isolated_images", "blend_list")


def load_blend_results(path, survey):
    """Load results exported from a DrawBlendsGenerator.

    Args:
        path (str): Path to the files. Should be the same as the save_path
            which was provided to the DrawBlendsGenerator to save the files.
        survey (str): Name of the survey for which you want to load the files.

    Returns:
        Dictionary with three entries: "blend_images" and "isolated_images"
        (the arrays read from ``blended.npy`` and ``isolated.npy``) and
        "blend_list" (a list of tables read from the ``blend_info_{i}`` files,
        one per entry along the first axis of the blended array).
    """
    survey_dir = os.path.join(path, survey)

    # NOTE(review): allow_pickle=True lets numpy unpickle arbitrary objects
    # stored in the file; only load result directories from a trusted source.
    blend_images = np.load(os.path.join(survey_dir, "blended.npy"), allow_pickle=True)
    isolated_images = np.load(os.path.join(survey_dir, "isolated.npy"), allow_pickle=True)

    # The number of tables to read is taken from the first axis of the
    # blended array, so no table is read when that axis is empty.
    blend_list = [
        Table.read(os.path.join(survey_dir, f"blend_info_{i}"), format="ascii")
        for i in range(blend_images.shape[0])
    ]

    return {
        "blend_images": blend_images,
        "isolated_images": isolated_images,
        "blend_list": blend_list,
    }


def load_measure_results(path, measure_name, n_batch):
    """Load results exported from a MeasureGenerator.

    Args:
        path (str): Path to the files. Should be the same as the save_path
            which was provided to the MeasureGenerator to save the files.
        measure_name (str): Name of the measure function for which you
            want to load the files.
        n_batch (int): Number of blends in the batch you want to load.

    Returns:
        Dictionary containing the detection catalogs under "catalog" and,
        when the corresponding ``.npy`` file exists, the arrays under
        "segmentation" and "deblended_images". A missing array file is
        reported on stdout and its key is left out of the dictionary.
    """
    measure_dir = os.path.join(path, measure_name)
    measure_results = {}

    for key in ("segmentation", "deblended_images"):
        try:
            # NOTE(review): allow_pickle=True lets numpy unpickle arbitrary
            # objects; only load result directories from a trusted source.
            measure_results[key] = np.load(
                os.path.join(measure_dir, f"{key}.npy"), allow_pickle=True
            )
        except FileNotFoundError:
            # These outputs are optional: skip the key instead of failing.
            print(f"No {key} found.")

    # One ``detection_catalog_{j}`` file is read per blend of the batch.
    measure_results["catalog"] = [
        Table.read(
            os.path.join(measure_dir, f"detection_catalog_{j}"),
            format="ascii",
        )
        for j in range(n_batch)
    ]
    return measure_results


def load_metrics_results(path, measure_name):
    """Load results exported from a MetricsGenerator.

    Args:
        path (str): Path to the files. Should be the same as the save_path
            which was provided to the MetricsGenerator to save the files.
        measure_name (str): Name of the measure function for which you
            want to load the files.

    Returns:
        Dictionary containing the "galaxy_summary" table and, for each of
        "detection", "segmentation" and "reconstruction" whose
        ``{key}_metric.npy`` file exists, the saved metric results. A missing
        metric file is reported on stdout and its key is left out.
    """
    metrics_dir = os.path.join(path, measure_name)
    metrics_results = {}

    for key in ("detection", "segmentation", "reconstruction"):
        try:
            # NOTE(review): allow_pickle=True lets numpy unpickle arbitrary
            # objects; only load result directories from a trusted source.
            metrics_results[key] = np.load(
                os.path.join(metrics_dir, f"{key}_metric.npy"), allow_pickle=True
            )
        except FileNotFoundError:
            # Not every metric has to be present: skip the key instead of failing.
            print(f"No {key} metrics found.")

    metrics_results["galaxy_summary"] = Table.read(
        os.path.join(metrics_dir, "galaxy_summary"),
        format="ascii",
    )
    return metrics_results


def load_all_results(path, surveys, measure_names, n_batch):
    """Load the saved draw, measure and metrics results in one call.

    Args:
        path (str): Path to the files. The same path is used to look up the
            blend, measure and metrics results, so it should be the save_path
            which was provided to the generators to save the files.
        surveys (list): Names of the surveys for which you want to load
            the files.
        measure_names (list): Names of the measure functions for which you
            want to load the files.
        n_batch (int): Number of blends in the batch you want to load.

    Returns:
        Tuple ``(blend_results, measure_results, metrics_results)``.
        ``blend_results`` maps each key of BLEND_RESULT_KEYS to a dictionary
        indexed by survey name; ``measure_results`` and ``metrics_results``
        are dictionaries indexed by measure function name.
    """
    # Blend results are regrouped as {result key: {survey name: value}}.
    blend_results = {key: {} for key in BLEND_RESULT_KEYS}
    for s in surveys:
        survey_results = load_blend_results(path, s)
        for key in BLEND_RESULT_KEYS:
            blend_results[key][s] = survey_results[key]

    measure_results = {}
    metrics_results = {}
    for meas in measure_names:
        measure_results[meas] = load_measure_results(path, meas, n_batch)
        metrics_results[meas] = load_metrics_results(path, meas)

    return blend_results, measure_results, metrics_results
51 changes: 51 additions & 0 deletions tests/test_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tempfile

import numpy as np

import btk
from btk.survey import Rubin


def test_save():
    """Check that results saved by the three generators can be loaded back.

    One batch is drawn, measured and scored with ``save_path`` set on every
    generator, then reloaded with ``btk.utils.load_all_results`` and compared
    with the in-memory results.
    """
    catalog_name = "data/sample_input_catalog.fits"
    stamp_size = 24.0
    batch_size = 8

    # TemporaryDirectory removes the saved files when the test ends, whereas
    # a bare mkdtemp() would leave a new directory behind after every run.
    with tempfile.TemporaryDirectory() as output_dir:
        catalog = btk.catalog.CatsimCatalog.from_file(catalog_name)
        sampling_function = btk.sampling_functions.DefaultSampling(stamp_size=stamp_size)
        draw_blend_generator = btk.draw_blends.CatsimGenerator(
            catalog,
            sampling_function,
            [Rubin],
            batch_size=batch_size,
            stamp_size=stamp_size,
            save_path=output_dir,
        )
        meas_generator = btk.measure.MeasureGenerator(
            btk.measure.sep_measure, draw_blend_generator, save_path=output_dir
        )
        metrics_generator = btk.metrics.MetricsGenerator(
            meas_generator,
            use_metrics=("detection", "segmentation", "reconstruction"),
            target_meas={"ellipticity": btk.metrics.meas_ksb_ellipticity},
            save_path=output_dir,
        )
        blend_results, measure_results, metrics_results = next(metrics_generator)

        # Reload while the directory still exists.
        blend_results2, measure_results2, metrics_results2 = btk.utils.load_all_results(
            output_dir, ["LSST"], ["sep_measure"], batch_size
        )

        np.testing.assert_array_equal(
            blend_results["blend_images"], blend_results2["blend_images"]["LSST"]
        )
        np.testing.assert_array_equal(
            measure_results["sep_measure"]["segmentation"][0],
            measure_results2["sep_measure"]["segmentation"][0],
        )
        np.testing.assert_array_equal(
            measure_results["sep_measure"]["deblended_images"][0],
            measure_results2["sep_measure"]["deblended_images"][0],
        )
        np.testing.assert_array_equal(
            metrics_results["sep_measure"]["galaxy_summary"]["distance_closest_galaxy"],
            metrics_results2["sep_measure"]["galaxy_summary"]["distance_closest_galaxy"],
        )