Skip to content

Commit

Permalink
implements friends of friends sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
atorchylo committed Sep 19, 2023
1 parent 9c12c98 commit 7fd7663
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions btk/sampling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import astropy
import numpy as np
from astropy.table import Table
from astropy.coordinates import SkyCoord
from fast3tree import find_friends_of_friends

from btk.utils import DEFAULT_SEED

Expand Down Expand Up @@ -429,3 +431,105 @@ def __call__(self, table: Table):
blend_table["dec"] *= 3600

return blend_table


class FriendsOfFriendsSampling(SamplingFunction):
    """Randomly selects one group of galaxies found by Friends-of-Friends clustering.

    Friends-of-friends clustering gives two galaxies the same group index when they
    are connected by a chain of galaxies in which each consecutive pair is separated
    by at most `link_distance`.

    This sampling function explicitly uses the spatial information in the input
    catalog to generate scenes of galaxies. However, blends might not always be
    returned as a result.
    """

    def __init__(
        self,
        max_number: int = 10,
        min_number: int = 2,
        link_distance: float = 2.5,
        stamp_size: float = 24.0,
        seed: int = DEFAULT_SEED,
        min_mag: float = -np.inf,
        max_mag: float = 25.3,
        mag_name: str = "i_ab",
    ):
        """Initializes the FriendsOfFriendsSampling sampling function.

        Args:
            max_number: Defined in parent class. Here: largest group size (after the
                magnitude cut) that may be sampled.
            min_number: Defined in parent class. Here: smallest group size (after the
                magnitude cut) that may be sampled.
            link_distance: Linking length (arcsec); galaxies closer than this are
                placed in the same group.
            stamp_size: Size of the desired stamp (arcsec).
            seed: Seed to initialize randomness for reproducibility.
            min_mag: Minimum magnitude allowed in samples (exclusive).
            max_mag: Maximum magnitude allowed in samples (inclusive).
            mag_name: Name of the magnitude column in the catalog.
        """
        super().__init__(max_number=max_number, min_number=min_number, seed=seed)
        self.stamp_size = stamp_size
        self.link_distance = link_distance
        self.max_number = max_number
        self.min_number = min_number
        self.max_mag = max_mag
        self.min_mag = min_mag
        self.mag_name = mag_name

    @staticmethod
    def _arcsec2dist(sep, r=1.0):
        """Converts an angular separation in arcsec to a chord length on a sphere of radius `r`."""
        # chord = 2 * r * sin(theta / 2), with theta converted from arcsec to radians
        return np.sin(np.deg2rad(sep / 3600.0 / 2.0)) * 2.0 * r

    def _precompute_friends_of_friends_index(self, table: Table):
        """Computes galaxy groups using friends of friends algorithm on cartesian coordinates.

        Adds a `friends_of_friends_id` column to `table` in place, so later calls
        with the same table can skip the clustering step.
        """
        # Unit-sphere cartesian positions, one (x, y, z) row per catalog entry.
        points = SkyCoord(table["ra"], table["dec"], unit=("deg", "deg")).cartesian.xyz.value.T
        group_ids = find_friends_of_friends(
            points,
            self._arcsec2dist(self.link_distance),
            reassign_group_indices=False,
        )
        table["friends_of_friends_id"] = group_ids

    def __call__(self, table: Table):
        """Samples galaxies from input catalog to make scene.

        We assume the input catalog has `ra` and `dec` in degrees, like CATSIM does.

        Args:
            table: Input catalog. A `friends_of_friends_id` column is added to it in
                place if it is not already present.

        Returns:
            Table with the galaxies of one randomly chosen group, with `ra` and `dec`
            re-centered on the group mean and converted to arcsec.

        Raises:
            ValueError: If the magnitude column is missing, if no group has a size in
                `[min_number, max_number]` after the magnitude cut, or if the sampled
                group does not fit in the stamp.
        """
        # Validate first so a bad catalog fails before the clustering step runs
        # and before the input table is modified.
        if self.mag_name not in table.colnames:
            raise ValueError(f"Catalog must have '{self.mag_name}' column.")

        # group galaxies by indices
        if "friends_of_friends_id" not in table.colnames:
            self._precompute_friends_of_friends_index(table)

        # filter by magnitude
        mag = table[self.mag_name]
        blend_table = table[(mag <= self.max_mag) & (mag > self.min_mag)]

        # filter by number of galaxies per group (counted after the magnitude cut)
        fof_ids = np.asarray(blend_table["friends_of_friends_id"])
        unique_ids, counts = np.unique(fof_ids, return_counts=True)
        candidates = unique_ids[(counts >= self.min_number) & (counts <= self.max_number)]
        if candidates.size == 0:
            raise ValueError(
                f"No groups with number of galaxies in [{self.min_number}, {self.max_number}], "
                f"found using a `link_distance` of {self.link_distance}"
            )

        # Draw with the seeded generator so that `seed` controls the result.
        # NOTE(review): the parent class is expected to create `self.rng` from `seed`;
        # fall back to NumPy's global RNG if it does not.
        rng = getattr(self, "rng", np.random)
        group_id = rng.choice(candidates)
        group = blend_table[fof_ids == group_id]

        # check if galaxies fit in the window
        ra = group["ra"]
        dec = group["dec"]
        ra_extent = (ra.max() - ra.min()) * 3600
        dec_extent = (dec.max() - dec.min()) * 3600
        if ra_extent > self.stamp_size or dec_extent > self.stamp_size:
            raise ValueError(
                "sampled galaxies don't fit in the stamp size, increase the stamp size, "
                "decrease `max_number` or `link_distance`."
            )

        # recenter catalog 'ra' and 'dec' to the center of the stamp
        # and convert from degrees to arcsec
        group["ra"] = (ra - ra.mean()) * 3600
        group["dec"] = (dec - dec.mean()) * 3600

        return group

0 comments on commit 7fd7663

Please sign in to comment.