class FriendsOfFriendsSampling(SamplingFunction):
    """Randomly selects one group of galaxies found by friends-of-friends clustering.

    The friends-of-friends algorithm gives the same group index to objects that are
    linked to each other, directly or through other members of the group, by
    separations of at most `link_distance`. This sampling function therefore uses the
    spatial information in the input catalog explicitly to generate scenes of galaxies.
    However, blends might not always be returned as a result.
    """

    def __init__(
        self,
        max_number: int = 10,
        min_number: int = 2,
        link_distance: float = 2.5,
        stamp_size: float = 24.0,
        seed: int = DEFAULT_SEED,
        min_mag: float = -np.inf,
        max_mag: float = 25.3,
        mag_name: str = "i_ab",
    ):
        """Initializes the FriendsOfFriendsSampling sampling function.

        Args:
            max_number: Largest group size (inclusive) that may be selected.
            min_number: Smallest group size (inclusive) that may be selected.
            link_distance: Maximum separation at which two galaxies are linked into
                the same group (arcsec).
            stamp_size: Size of the desired stamp (arcsec).
            seed: Seed to initialize randomness for reproducibility.
            min_mag: Minimum magnitude allowed in samples (exclusive).
            max_mag: Maximum magnitude allowed in samples (inclusive).
            mag_name: Name of the magnitude column in the catalog.
        """
        super().__init__(max_number=max_number, min_number=min_number, seed=seed)
        self.stamp_size = stamp_size
        self.link_distance = link_distance
        self.max_number = max_number
        self.min_number = min_number
        self.max_mag = max_mag
        self.min_mag = min_mag
        self.mag_name = mag_name

    @staticmethod
    def _arcsec2dist(sep, r=1.0):
        """Converts an angular separation in arcsec to a chord length on a sphere.

        Args:
            sep: Angular separation in arcsec.
            r: Radius of the sphere (1.0 matches unit cartesian vectors).

        Returns:
            The straight-line (chord) distance `2 * r * sin(angle / 2)`.
        """
        return np.sin(np.deg2rad(sep / 3600.0 / 2.0)) * 2.0 * r

    def _precompute_friends_of_friends_index(self, table: Table):
        """Computes galaxy groups with friends of friends on cartesian coordinates.

        The `ra` and `dec` columns are interpreted as degrees and converted to
        cartesian vectors, so that `link_distance` can be applied as a chord length.
        The resulting group ids are stored in place in a new
        `friends_of_friends_id` column of `table`.
        """
        points = SkyCoord(table["ra"], table["dec"], unit=("deg", "deg")).cartesian.xyz.value.T
        group_ids = find_friends_of_friends(
            points,
            self._arcsec2dist(self.link_distance),
            reassign_group_indices=False,
        )
        table["friends_of_friends_id"] = group_ids

    def __call__(self, table: Table) -> Table:
        """Samples galaxies from input catalog to make scene.

        We assume the input catalog has `ra` and `dec` in degrees, like CATSIM does.

        Group ids are computed on the full catalog the first time a table is seen and
        cached in its `friends_of_friends_id` column (the input table is modified in
        place); the magnitude cut is applied afterwards.

        Args:
            table: Catalog to sample from.

        Returns:
            Table containing only the galaxies of one randomly chosen group, with
            `ra` and `dec` replaced by offsets from the group mean, in arcsec.

        Raises:
            ValueError: If the magnitude column is missing, if no group has a size
                in `[min_number, max_number]` after the magnitude cut, or if the
                chosen group does not fit in `stamp_size`.
        """
        # validate before the (potentially expensive) grouping so that a bad catalog
        # fails fast and is left untouched
        if self.mag_name not in table.colnames:
            raise ValueError(f"Catalog must have '{self.mag_name}' column.")

        # group galaxies by index
        if "friends_of_friends_id" not in table.colnames:
            self._precompute_friends_of_friends_index(table)

        # filter by magnitude
        mag = table[self.mag_name]
        blend_table = table[(mag <= self.max_mag) & (mag > self.min_mag)]

        # filter by number of galaxies; counting with numpy avoids converting the
        # whole catalog to pandas just to get group sizes
        group_ids, counts = np.unique(
            np.asarray(blend_table["friends_of_friends_id"]), return_counts=True
        )
        valid = (counts >= self.min_number) & (counts <= self.max_number)
        if not valid.any():
            raise ValueError(
                f"No groups with number of galaxies in [{self.min_number}, {self.max_number}], "
                f"found using a `link_distance` of {self.link_distance}"
            )

        # Draw from the generator the parent class builds from `seed` when it exposes
        # one as `self.rng`, so that sampling is reproducible; otherwise fall back to
        # numpy's global random state.
        rng = getattr(self, "rng", np.random)
        group_id = rng.choice(group_ids[valid])
        indices = blend_table["friends_of_friends_id"] == group_id

        # check if galaxies fit in the window
        # NOTE(review): groups are built from positions on the sphere, but the extent
        # test and the recentering below use raw `ra`/`dec` differences (no wrap-around
        # handling at ra = 0/360 deg, no cos(dec) scaling) -- confirm this is intended.
        ra = blend_table["ra"][indices]
        dec = blend_table["dec"][indices]
        ra_extent = (ra.max() - ra.min()) * 3600
        dec_extent = (dec.max() - dec.min()) * 3600
        if ra_extent > self.stamp_size or dec_extent > self.stamp_size:
            raise ValueError(
                "sampled galaxies don't fit in the stamp size, increase the stamp size, "
                "decrease `max_number` or `link_distance`."
            )

        # recenter catalog 'ra' and 'dec' to the center of the stamp and convert the
        # offsets from degrees to arcsec
        blend_table = blend_table[indices]
        blend_table["ra"] = (ra - ra.mean()) * 3600
        blend_table["dec"] = (dec - dec.mean()) * 3600

        return blend_table