diff --git a/btk/sampling_functions.py b/btk/sampling_functions.py index f350bddd..62b48276 100644 --- a/btk/sampling_functions.py +++ b/btk/sampling_functions.py @@ -4,7 +4,9 @@ import astropy import numpy as np +from astropy.coordinates import SkyCoord from astropy.table import Table +from fast3tree import find_friends_of_friends from btk.utils import DEFAULT_SEED @@ -429,3 +431,110 @@ def __call__(self, table: Table): blend_table["dec"] *= 3600 return blend_table + + +class FriendsOfFriendsSampling(SamplingFunction): + """Randomly selects galaxies using Friends Of Friends clustering algorithm. + + Friends of friends clustering algorithm assigns a group of object same index if one can + each member of the group is within link_distance of any other member in the group. + This sampling function explicitly uses the spatial information in the input catalog to + generate scenes of galaxies. However, blends might not always be returned as a result + if the provided `link_distance` is too small. + """ + + def __init__( + self, + max_number: int = 10, + min_number: int = 2, + link_distance: int = 2.5, + stamp_size: float = 24.0, + seed: int = DEFAULT_SEED, + min_mag: float = -np.inf, + max_mag: float = 25.3, + mag_name: str = "i_ab", + ): + """Initializes the FriendsOfFriendsSampling sampling function. + + Args: + max_number: Defined in parent class + min_number: Defined in parent class + link_distance: Minimum linkage distance to form a group (arcsec). + stamp_size: Size of the desired stamp (arcsec). + seed: Seed to initialize randomness for reproducibility. + min_mag: Minimum magnitude allowed in samples + max_mag: Maximum magnitude allowed in samples. + mag_name: Name of the magnitude column in the catalog. + """ + super().__init__(max_number=max_number, min_number=min_number, seed=seed) + self.stamp_size = stamp_size + self.link_distance = link_distance + self.max_number = max_number + self.min_number = min_number + self.max_mag = max_mag + self.min_mag = min_mag + self.mag_name = mag_name + + @staticmethod + def _arcsec2dist(sep, r=1.0): + """Converts arcsec separation to a cartesian distance.""" + return np.sin(np.deg2rad(sep / 3600.0 / 2.0)) * 2.0 * r + + def _precompute_friends_of_friends_index(self, table: Table): + """Computes galaxy groups using friends of friends algorithm on cartesian coordinates.""" + points = SkyCoord(table["ra"], table["dec"], unit=("deg", "deg")).cartesian.xyz.value.T + group_ids = find_friends_of_friends( + points, self._arcsec2dist(self.link_distance), reassign_group_indices=False + ) + table["friends_of_friends_id"] = group_ids + + def __call__(self, table: Table): + """Samples galaxies from input catalog to make scene. + + We assume the input catalog has `ra` and `dec` in degrees, like CATSIM does. + """ + # group galaxies by indices + if "friends_of_friends_id" not in table.colnames: + self._precompute_friends_of_friends_index(table) + + # filter by magnitude + if self.mag_name not in table.colnames: + raise ValueError(f"Catalog must have '{self.mag_name}' column.") + cond1 = table[self.mag_name] <= self.max_mag + cond2 = table[self.mag_name] > self.min_mag + cond = cond1 & cond2 + blend_table = table[cond] + + # filter by number of galaxies + num_galaxies = blend_table.to_pandas()["friends_of_friends_id"].value_counts() + cond1 = num_galaxies >= self.min_number + cond2 = num_galaxies <= self.max_number + if (cond1 & cond2).sum() == 0: + raise ValueError( + f"No groups with number of galaxies in [{self.min_number}, {self.max_number}], " + "found using a `link_distance` of {self.link_distance}" + ) + group_id = np.random.choice(num_galaxies[cond1 & cond2].index) + indices = blend_table["friends_of_friends_id"] == group_id + + # check if galaxies fit in the window + ra = blend_table["ra"] + dec = blend_table["dec"] + ra_out_of_bound = (ra[indices].max() - ra[indices].min()) * 3600 > self.stamp_size + dec_out_of_bounds = (dec[indices].max() - dec[indices].min()) * 3600 > self.stamp_size + if ra_out_of_bound or dec_out_of_bounds: + raise ValueError( + "sampled galaxies don't fit in the stamp size, increase the stamp size, " + "decrease `max_number` or `link_distance`." + ) + + # recenter catalog 'ra' and 'dec' to the center of the stamp + blend_table = blend_table[indices] + blend_table["ra"] = ra[indices] - ra[indices].mean() + blend_table["dec"] = dec[indices] - dec[indices].mean() + + # finally, convert to arcsec + blend_table["ra"] *= 3600 + blend_table["dec"] *= 3600 + + return blend_table diff --git a/poetry.lock b/poetry.lock index f499f043..dc19f214 100644 --- a/poetry.lock +++ b/poetry.lock @@ -850,6 +850,19 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +[[package]] +name = "fast3tree" +version = "0.4.1" +description = "A Python wrapper of Peter Behroozi's fast3tree code." +optional = false +python-versions = "*" +files = [ + {file = "fast3tree-0.4.1.tar.gz", hash = "sha256:4d4abe9af5965949b551241d46edeaa278bfbe960dabdc8828caaac98d080b25"}, +] + +[package.dependencies] +numpy = ">=1.0" + [[package]] name = "fastjsonschema" version = "2.19.0" @@ -1143,20 +1156,20 @@ files = [ [[package]] name = "importlib-metadata" -version = "6.8.0" +version = "6.9.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, - {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, + {file = "importlib_metadata-6.9.0-py3-none-any.whl", hash = "sha256:1c8dc6839ddc9771412596926f24cb5a553bbd40624ee2c7e55e531542bed3b8"}, + {file = "importlib_metadata-6.9.0.tar.gz", hash = "sha256:e8acb523c335a91822674e149b46c0399ec4d328c4d1f6e49c273da5ff0201b9"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] @@ -3935,4 +3948,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1,<3.12" -content-hash = "c24334d1dcdb16024c97e25ddfb9c8c2bb9adcf1c72b8b9fa67311e1b121ba49" +content-hash = "713787cf6872bdfb418e56926822a64a5408b658a19fe252338d014728ba91ac" diff --git a/pyproject.toml b/pyproject.toml index 51ee82d6..fa464fe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ galsim = ">=2.4.9" python = "^3.8.1,<3.12" pre-commit = "^3.3.3" h5py = "^3.9.0" +fast3tree = "^0.4.1" [tool.poetry.dev-dependencies] black = ">=23.3.0"