Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for hdf5 saving (issue 442) #448

Merged
merged 6 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 109 additions & 82 deletions btk/blend_batch.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Class which stores all relevant data for blends."""
import json
import os
import pickle
from dataclasses import dataclass
from typing import List, Optional, Union

import galsim
import h5py
import numpy as np
from astropy.io.misc.hdf5 import read_table_hdf5, write_table_hdf5
from astropy.table import Table

from btk.survey import Survey, make_wcs
from btk.survey import Survey, get_surveys, make_wcs


@dataclass
Expand Down Expand Up @@ -68,70 +68,79 @@
return string

def save(self, path: str, batch_number: int = 0):
"""Save the batch to disk.
"""Save the batch to disk using hdf5 format.

Args:
path (str): Path to save the batch to.
batch_number (int): Number of the batch.
"""
os.makedirs(path, exist_ok=True)
np.save(os.path.join(path, f"blend_images_{batch_number}.npy"), self.blend_images)
np.save(os.path.join(path, f"isolated_images_{batch_number}.npy"), self.isolated_images)

with open(os.path.join(path, f"psf_{batch_number}.pickle"), "wb") as f:
pickle.dump(self.psf, f)

with open(os.path.join(path, f"catalog_list_{batch_number}.pickle"), "wb") as f:
pickle.dump(self.catalog_list, f)

# save general info about blend
with open(os.path.join(path, "blend.json"), "w", encoding="utf-8") as f:
json.dump(
{
"batch_size": self.batch_size,
"max_n_sources": self.max_n_sources,
"stamp_size": self.stamp_size,
"survey_name": self.survey.name,
},
f,
)
fpath = os.path.join(path, f"blend_{batch_number}.hdf5")

with h5py.File(fpath, "w") as f:
# save blend and isolated images
f.create_dataset("blend_images", data=self.blend_images)
f.create_dataset("isolated_images", data=self.isolated_images)

# save psfs
# first convert psfs to numpy array
psf_array = self.get_numpy_psf()
f.create_dataset("psf", data=psf_array)

# save catalog using astropy functions
# (this is faster than saving as numpy array)
for ii, catalog in enumerate(self.catalog_list):
write_table_hdf5(catalog, f, path=f"catalog_list/{ii}")

# save general info about blend
f.attrs["batch_size"] = self.batch_size
f.attrs["max_n_sources"] = self.max_n_sources
f.attrs["stamp_size"] = self.stamp_size
f.attrs["survey_name"] = self.survey.name

@classmethod
def load(cls, path: str, batch_number: int = 0):
"""Load the batch from disk.
"""Load the batch from hdf5 format.

Args:
path (str): Path to load the batch from.
batch_number (int): Number of the batch.
"""
# load general info about blend
with open(os.path.join(path, "blend.json"), "r", encoding="utf-8") as f:
blend_info = json.load(f)
batch_size = blend_info["batch_size"]
max_n_sources = blend_info["max_n_sources"]
stamp_size = blend_info["stamp_size"]
survey_name = blend_info["survey_name"]
# file path
fpath = os.path.join(path, f"blend_{batch_number}.hdf5")

# open file
with h5py.File(fpath, "r") as f:
# load blend and isolated images
blend_images = f["blend_images"][:]
isolated_images = f["isolated_images"][:]

blend_images = np.load(os.path.join(path, f"blend_images_{batch_number}.npy"))
isolated_images = np.load(os.path.join(path, f"isolated_images_{batch_number}.npy"))
# load psfs
psf_list = [galsim.Image(psf) for psf in f["psf"][:]]

# load psfs
with open(os.path.join(path, f"psf_{batch_number}.pickle"), "rb") as f:
psf = pickle.load(f)
# load catalog
catalog_list = []
for ii in range(f.attrs["batch_size"]):
catalog_list.append(read_table_hdf5(f, path=f"catalog_list/{ii}"))

# load catalog
with open(os.path.join(path, f"catalog_list_{batch_number}.pickle"), "rb") as f:
catalog_list = pickle.load(f)
# load general info about blend
batch_size = f.attrs["batch_size"]
max_n_sources = f.attrs["max_n_sources"]
stamp_size = f.attrs["stamp_size"]
survey_name = f.attrs["survey_name"]

# create survey
survey = get_surveys(survey_name)

# create class
return cls(
batch_size,
max_n_sources,
stamp_size,
survey_name,
blend_images,
isolated_images,
catalog_list,
psf,
batch_size=batch_size,
max_n_sources=max_n_sources,
stamp_size=stamp_size,
survey=survey,
blend_images=blend_images,
isolated_images=isolated_images,
catalog_list=catalog_list,
psf=psf_list,
)


Expand Down Expand Up @@ -388,46 +397,64 @@
return string

def save(self, path: str, batch_number: int = 0):
"""Save batch of measure results to disk."""
save_dir = os.path.join(path, str(batch_number))
if not os.path.exists(save_dir):
os.makedirs(save_dir)
"""Save batch of measure results to disk in hdf5 format."""
fpath = os.path.join(path, f"deblend_{batch_number}.hdf5")
with h5py.File(fpath, "w") as f:
# save catalog with astropy hdf5 functions
for ii, catalog in enumerate(self.catalog_list):
write_table_hdf5(catalog, f, path=f"catalog_list/{ii}")

# save segmentation
if self.segmentation is not None:
np.save(os.path.join(save_dir, "segmentation"), self.segmentation)
f.create_dataset("segmentation", data=self.segmentation)

# save deblended images
if self.deblended_images is not None:
np.save(os.path.join(save_dir, "deblended_images"), self.deblended_images)
with open(os.path.join(save_dir, "catalog_list.pickle"), "wb") as f:
pickle.dump(self.catalog_list, f)

# save general info about class
with open(os.path.join(path, "meas.json"), "w", encoding="utf-8") as f:
json.dump(
{
"batch_size": self.batch_size,
"max_n_sources": self.max_n_sources,
"image_size": self.image_size,
"n_bands": self.n_bands,
},
f,
)
f.create_dataset("deblended_images", data=self.deblended_images)

# save general info about class
f.attrs["batch_size"] = self.batch_size
f.attrs["max_n_sources"] = self.max_n_sources
f.attrs["image_size"] = self.image_size
f.attrs["n_bands"] = self.n_bands

@classmethod
def load(cls, path: str, batch_number: int = 0):
"""Load batch of measure results from disk."""
load_dir = os.path.join(path, str(batch_number))
with open(os.path.join(path, "meas.json"), "r", encoding="utf-8") as f:
meas_config = json.load(f)

with open(os.path.join(load_dir, "catalog_list.pickle"), "rb") as f:
catalog_list = pickle.load(f)
segmentation, deblended_images = None, None
if os.path.exists(os.path.join(load_dir, "segmentation.npy")):
segmentation = np.load(os.path.join(load_dir, "segmentation.npy"))
if os.path.exists(os.path.join(load_dir, "deblended_images.npy")):
deblended_images = np.load(os.path.join(load_dir, "deblended_images.npy"))
"""Load batch of measure results from an hdf5 file on disk."""
fpath = os.path.join(path, f"deblend_{batch_number}.hdf5")

# open file
with h5py.File(fpath, "r") as f:
# load catalog with astropy hdf5 functions
catalog_list = []
for ii in range(f.attrs["batch_size"]):
catalog_list.append(read_table_hdf5(f, path=f"catalog_list/{ii}"))

# load segmentation
if "segmentation" in f.keys():
segmentation = f["segmentation"][:]
else:
segmentation = None

Check warning on line 437 in btk/blend_batch.py

View check run for this annotation

Codecov / codecov/patch

btk/blend_batch.py#L437

Added line #L437 was not covered by tests

# load deblended images
if "deblended_images" in f.keys():
deblended_images = f["deblended_images"][:]
else:
deblended_images = None

Check warning on line 443 in btk/blend_batch.py

View check run for this annotation

Codecov / codecov/patch

btk/blend_batch.py#L443

Added line #L443 was not covered by tests

# load general info about blend
batch_size = f.attrs["batch_size"]
max_n_sources = f.attrs["max_n_sources"]
image_size = f.attrs["image_size"]
n_bands = f.attrs["n_bands"]

# create class
return cls(
batch_size=batch_size,
max_n_sources=max_n_sources,
catalog_list=catalog_list,
n_bands=n_bands,
image_size=image_size,
segmentation=segmentation,
deblended_images=deblended_images,
**meas_config,
)
35 changes: 34 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ sep = ">=1.2.1"
galsim = ">=2.4.9"
python = "^3.8.1,<3.12"
pre-commit = "^3.3.3"
h5py = "^3.9.0"

[tool.poetry.dev-dependencies]
black = ">=23.3.0"
Expand Down
16 changes: 16 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test pipeline as a whole at a high-level."""

import tempfile

import numpy as np

import btk
Expand Down Expand Up @@ -98,6 +100,20 @@ def test_pipeline(data_dir):
iso_images_matched2 = matching.match_pred_arrays(iso_images2)
mse(iso_images_matched1, iso_images_matched2)

# test saving
with tempfile.TemporaryDirectory() as tmpdirname:
blend_batch.save(tmpdirname, 0)
blend_batch2 = btk.blend_batch.BlendBatch.load(tmpdirname, 0)

deblend_batch.save(tmpdirname, 0)
deblend_batch2 = btk.blend_batch.DeblendBatch.load(tmpdirname, 0)

assert blend_batch.batch_size == blend_batch2.batch_size
assert blend_batch.stamp_size == blend_batch2.stamp_size

assert deblend_batch.batch_size == deblend_batch2.batch_size
assert deblend_batch.image_size == deblend_batch2.image_size


def test_sep(data_dir):
"""Check we always detect single bright objects."""
Expand Down
Loading