diff --git a/btk/blend_batch.py b/btk/blend_batch.py
index 4524b6c3..54827de4 100644
--- a/btk/blend_batch.py
+++ b/btk/blend_batch.py
@@ -1,15 +1,15 @@
 """Class which stores all relevant data for blends."""
-import json
 import os
-import pickle
 from dataclasses import dataclass
 from typing import List, Optional, Union

 import galsim
+import h5py
 import numpy as np
+from astropy.io.misc.hdf5 import read_table_hdf5, write_table_hdf5
 from astropy.table import Table

-from btk.survey import Survey, make_wcs
+from btk.survey import Survey, get_surveys, make_wcs


 @dataclass
@@ -68,70 +68,79 @@ def __repr__(self) -> str:
         return string

     def save(self, path: str, batch_number: int = 0):
-        """Save the batch to disk.
+        """Save the batch to disk using hdf5 format.

         Args:
             path (str): Path to save the batch to.
             batch_number (int): Number of the batch.
         """
-        os.makedirs(path, exist_ok=True)
-        np.save(os.path.join(path, f"blend_images_{batch_number}.npy"), self.blend_images)
-        np.save(os.path.join(path, f"isolated_images_{batch_number}.npy"), self.isolated_images)
-
-        with open(os.path.join(path, f"psf_{batch_number}.pickle"), "wb") as f:
-            pickle.dump(self.psf, f)
-
-        with open(os.path.join(path, f"catalog_list_{batch_number}.pickle"), "wb") as f:
-            pickle.dump(self.catalog_list, f)
-
-        # save general info about blend
-        with open(os.path.join(path, "blend.json"), "w", encoding="utf-8") as f:
-            json.dump(
-                {
-                    "batch_size": self.batch_size,
-                    "max_n_sources": self.max_n_sources,
-                    "stamp_size": self.stamp_size,
-                    "survey_name": self.survey.name,
-                },
-                f,
-            )
+        fpath = os.path.join(path, f"blend_{batch_number}.hdf5")
+
+        with h5py.File(fpath, "w") as f:
+            # save blend and isolated images
+            f.create_dataset("blend_images", data=self.blend_images)
+            f.create_dataset("isolated_images", data=self.isolated_images)
+
+            # save psfs
+            # first convert psfs to numpy array
+            psf_array = self.get_numpy_psf()
+            f.create_dataset("psf", data=psf_array)
+
+            # save catalog using astropy functions
+            # (this is faster than saving as numpy array)
+            for ii, catalog in enumerate(self.catalog_list):
+                write_table_hdf5(catalog, f, path=f"catalog_list/{ii}")
+
+            # save general info about blend
+            f.attrs["batch_size"] = self.batch_size
+            f.attrs["max_n_sources"] = self.max_n_sources
+            f.attrs["stamp_size"] = self.stamp_size
+            f.attrs["survey_name"] = self.survey.name

     @classmethod
     def load(cls, path: str, batch_number: int = 0):
-        """Load the batch from disk.
+        """Load the batch from hdf5 format.

         Args:
             path (str): Path to load the batch from.
             batch_number (int): Number of the batch.
         """
-        # load general infrom about blend
-        with open(os.path.join(path, "blend.json"), "r", encoding="utf-8") as f:
-            blend_info = json.load(f)
-        batch_size = blend_info["batch_size"]
-        max_n_sources = blend_info["max_n_sources"]
-        stamp_size = blend_info["stamp_size"]
-        survey_name = blend_info["survey_name"]
+        # file path
+        fpath = os.path.join(path, f"blend_{batch_number}.hdf5")
+
+        # open file
+        with h5py.File(fpath, "r") as f:
+            # load blend and isolated images
+            blend_images = f["blend_images"][:]
+            isolated_images = f["isolated_images"][:]

-        blend_images = np.load(os.path.join(path, f"blend_images_{batch_number}.npy"))
-        isolated_images = np.load(os.path.join(path, f"isolated_images_{batch_number}.npy"))
+            # load psfs
+            psf_list = [galsim.Image(psf) for psf in f["psf"][:]]

-        # load psfs
-        with open(os.path.join(path, f"psf_{batch_number}.pickle"), "rb") as f:
-            psf = pickle.load(f)
+            # load catalog
+            catalog_list = []
+            for ii in range(f.attrs["batch_size"]):
+                catalog_list.append(read_table_hdf5(f, path=f"catalog_list/{ii}"))

-        # load catalog
-        with open(os.path.join(path, f"catalog_list_{batch_number}.pickle"), "rb") as f:
-            catalog_list = pickle.load(f)
+            # load general info about blend
+            batch_size = f.attrs["batch_size"]
+            max_n_sources = f.attrs["max_n_sources"]
+            stamp_size = f.attrs["stamp_size"]
+            survey_name = f.attrs["survey_name"]

+        # create survey
+        survey = get_surveys(survey_name)
+
+        # create class
         return cls(
-            batch_size,
-            max_n_sources,
-            stamp_size,
-            survey_name,
-            blend_images,
-            isolated_images,
-            catalog_list,
-            psf,
+            batch_size=batch_size,
+            max_n_sources=max_n_sources,
+            stamp_size=stamp_size,
+            survey=survey,
+            blend_images=blend_images,
+            isolated_images=isolated_images,
+            catalog_list=catalog_list,
+            psf=psf_list,
         )


@@ -388,46 +397,64 @@ def __repr__(self) -> str:
         return string

     def save(self, path: str, batch_number: int = 0):
-        """Save batch of measure results to disk."""
-        save_dir = os.path.join(path, str(batch_number))
-        if not os.path.exists(save_dir):
-            os.makedirs(save_dir)
+        """Save batch of measure results to disk in hdf5 format."""
+        fpath = os.path.join(path, f"deblend_{batch_number}.hdf5")
+        with h5py.File(fpath, "w") as f:
+            # save catalog with astropy hdf5 functions
+            for ii, catalog in enumerate(self.catalog_list):
+                write_table_hdf5(catalog, f, path=f"catalog_list/{ii}")
+
+            # save segmentation
             if self.segmentation is not None:
-                np.save(os.path.join(save_dir, "segmentation"), self.segmentation)
+                f.create_dataset("segmentation", data=self.segmentation)
+
+            # save deblended images
             if self.deblended_images is not None:
-                np.save(os.path.join(save_dir, "deblended_images"), self.deblended_images)
-        with open(os.path.join(save_dir, "catalog_list.pickle"), "wb") as f:
-            pickle.dump(self.catalog_list, f)
-
-        # save general info about class
-        with open(os.path.join(path, "meas.json"), "w", encoding="utf-8") as f:
-            json.dump(
-                {
-                    "batch_size": self.batch_size,
-                    "max_n_sources": self.max_n_sources,
-                    "image_size": self.image_size,
-                    "n_bands": self.n_bands,
-                },
-                f,
-            )
+                f.create_dataset("deblended_images", data=self.deblended_images)
+
+            # save general info about class
+            f.attrs["batch_size"] = self.batch_size
+            f.attrs["max_n_sources"] = self.max_n_sources
+            f.attrs["image_size"] = self.image_size
+            f.attrs["n_bands"] = self.n_bands

     @classmethod
     def load(cls, path: str, batch_number: int = 0):
-        """Load batch of measure results from disk."""
-        load_dir = os.path.join(path, str(batch_number))
-        with open(os.path.join(path, "meas.json"), "r", encoding="utf-8") as f:
-            meas_config = json.load(f)
-
-        with open(os.path.join(load_dir, "catalog_list.pickle"), "rb") as f:
-            catalog_list = pickle.load(f)
-        segmentation, deblended_images = None, None
-        if os.path.exists(os.path.join(load_dir, "segmentation.npy")):
-            segmentation = np.load(os.path.join(load_dir, "segmentation.npy"))
-        if os.path.exists(os.path.join(load_dir, "deblended_images.npy")):
-            deblended_images = np.load(os.path.join(load_dir, "deblended_images.npy"))
+        """Load batch of measure results from hdf5 file on disk."""
+        fpath = os.path.join(path, f"deblend_{batch_number}.hdf5")
+
+        # open file
+        with h5py.File(fpath, "r") as f:
+            # load catalog with astropy hdf5 functions
+            catalog_list = []
+            for ii in range(f.attrs["batch_size"]):
+                catalog_list.append(read_table_hdf5(f, path=f"catalog_list/{ii}"))
+
+            # load segmentation
+            if "segmentation" in f.keys():
+                segmentation = f["segmentation"][:]
+            else:
+                segmentation = None
+
+            # load deblended images
+            if "deblended_images" in f.keys():
+                deblended_images = f["deblended_images"][:]
+            else:
+                deblended_images = None
+
+            # load general info about class
+            batch_size = f.attrs["batch_size"]
+            max_n_sources = f.attrs["max_n_sources"]
+            image_size = f.attrs["image_size"]
+            n_bands = f.attrs["n_bands"]
+
+        # create class
         return cls(
+            batch_size=batch_size,
+            max_n_sources=max_n_sources,
             catalog_list=catalog_list,
+            n_bands=n_bands,
+            image_size=image_size,
             segmentation=segmentation,
             deblended_images=deblended_images,
-            **meas_config,
         )
"h5py-3.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9da9e7e63376c32704e37ad4cea2dceae6964cee0d8515185b3ab9cbd6b947bc"}, + {file = "h5py-3.9.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4e20897c88759cbcbd38fb45b507adc91af3e0f67722aa302d71f02dd44d286"}, + {file = "h5py-3.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbf5225543ca35ce9f61c950b73899a82be7ba60d58340e76d0bd42bf659235a"}, + {file = "h5py-3.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:36408f8c62f50007d14e000f9f3acf77e103b9e932c114cbe52a3089e50ebf94"}, + {file = "h5py-3.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:23e74b878bbe1653ab34ca49b83cac85529cd0b36b9d625516c5830cc5ca2eac"}, + {file = "h5py-3.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f457089c5d524b7998e3649bc63240679b8fb0a3859ea53bbb06841f3d755f1"}, + {file = "h5py-3.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6284061f3214335e1eec883a6ee497dbe7a79f19e6a57fed2dd1f03acd5a8cb"}, + {file = "h5py-3.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7a745efd0d56076999b52e8da5fad5d30823bac98b59c68ae75588d09991a"}, + {file = "h5py-3.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:79bbca34696c6f9eeeb36a91776070c49a060b2879828e2c8fa6c58b8ed10dd1"}, + {file = "h5py-3.9.0.tar.gz", hash = "sha256:e604db6521c1e367c6bd7fad239c847f53cc46646f2d2651372d05ae5e95f817"}, +] + +[package.dependencies] +numpy = ">=1.17.3" + [[package]] name = "identify" version = "2.5.26" @@ -4022,4 +4055,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1,<3.12" -content-hash = "8c503c7f593e4ba2aae3be859ee2adbe34f43c1fca1c03ba18b5a4f19f31b9c4" +content-hash = "c24334d1dcdb16024c97e25ddfb9c8c2bb9adcf1c72b8b9fa67311e1b121ba49" diff --git a/pyproject.toml b/pyproject.toml index 22a08da0..6cc6c484 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ sep = ">=1.2.1" galsim = ">=2.4.9" python = "^3.8.1,<3.12" pre-commit = "^3.3.3" +h5py = "^3.9.0" [tool.poetry.dev-dependencies] black = ">=23.3.0" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e4add883..63810aec 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,7 @@ """Test pipeline as a whole at a high-level.""" +import tempfile + import numpy as np import btk @@ -98,6 +100,20 @@ def test_pipeline(data_dir): iso_images_matched2 = matching.match_pred_arrays(iso_images2) mse(iso_images_matched1, iso_images_matched2) + # test saving + with tempfile.TemporaryDirectory() as tmpdirname: + blend_batch.save(tmpdirname, 0) + blend_batch2 = btk.blend_batch.BlendBatch.load(tmpdirname, 0) + + deblend_batch.save(tmpdirname, 0) + deblend_batch2 = btk.blend_batch.DeblendBatch.load(tmpdirname, 0) + + assert blend_batch.batch_size == blend_batch2.batch_size + assert blend_batch.stamp_size == blend_batch2.stamp_size + + assert deblend_batch.batch_size == deblend_batch2.batch_size + assert deblend_batch.image_size == deblend_batch2.image_size + def test_sep(data_dir): """Check we always detect single bright objects."""