Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

min_number argument in sampling func #374

Merged
merged 4 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion btk/create_blend_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
)

self.max_number = self.sampling_function.max_number
self.min_number = self.sampling_function.min_number

def __iter__(self):
"""Returns an iterable which is the object itself."""
Expand All @@ -55,7 +56,7 @@ def __iter__(self):
def __next__(self):
"""Generates a list of blend tables of len batch_size.

Each blend table has entries numbered between 1 and max_number, corresponding
Each blend table has entries numbered between min_number and max_number, corresponding
to overlapping objects in the blend.

Returns:
Expand All @@ -77,6 +78,11 @@ def __next__(self):
f"Number of objects per blend must be "
f"less than max_number: {len(blend_table)} <= {self.max_number}"
)
if len(blend_table) < self.min_number:
raise ValueError(
f"Number of objects per blend must be "
f"greater than min_number: {len(blend_table)} >= {self.min_number}"
)
blend_tables.append(blend_table)
return blend_tables

Expand Down
22 changes: 15 additions & 7 deletions btk/sampling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ class SamplingFunction(ABC):
galaxies chosen for the blend.
"""

def __init__(self, max_number, seed=DEFAULT_SEED):
def __init__(self, max_number, min_number=1, seed=DEFAULT_SEED):
"""Initializes the SamplingFunction.

Args:
max_number (int): maximum number of catalog entries returned from sample.
min_number (int): minimum number of catalog entries returned from sample.
seed (int): Seed to initialize randomness for reproducibility.
"""
self.max_number = max_number
self.min_number = min_number

if isinstance(seed, int):
self.rng = np.random.default_rng(seed)
Expand All @@ -60,17 +62,20 @@ def compatible_catalogs(self):
class DefaultSampling(SamplingFunction):
"""Default sampling function used for producing blend tables."""

def __init__(self, max_number=2, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED):
def __init__(
self, max_number=2, min_number=1, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED
):
"""Initializes default sampling function.

Args:
max_number (int): Defined in parent class
min_number (int): Defined in parent class
stamp_size (float): Size of the desired stamp.
max_shift (float): Magnitude of maximum value of shift. If None then it
is set as one-tenth the stamp size. (in arcseconds)
seed (int): Seed to initialize randomness for reproducibility.
"""
super().__init__(max_number, seed)
super().__init__(max_number=max_number, min_number=min_number, seed=seed)
self.stamp_size = stamp_size
self.max_shift = max_shift if max_shift else self.stamp_size / 10.0

Expand Down Expand Up @@ -104,7 +109,7 @@ def __call__(self, table, shifts=None, indexes=None):
Returns:
Astropy.table with entries corresponding to one blend.
"""
number_of_objects = self.rng.integers(1, self.max_number + 1)
number_of_objects = self.rng.integers(self.min_number, self.max_number)
(q,) = np.where(table["ref_mag"] <= 25.3)

if indexes is None:
Expand Down Expand Up @@ -133,17 +138,20 @@ class BasicSampling(SamplingFunction):
Includes magnitude cut, restriction on the shape, shift randomization.
"""

def __init__(self, max_number=4, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED):
def __init__(
self, max_number=4, min_number=1, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED
):
"""Initializes the basic sampling function.

Args:
max_number (int): Defined in parent class
min_number (int): Defined in parent class
stamp_size (float): Size of the desired stamp.
max_shift (float): Magnitude of maximum value of shift. If None then it
is set as one-tenth the stamp size. (in arcseconds)
seed (int): Seed to initialize randomness for reproducibility.
"""
super().__init__(max_number, seed)
super().__init__(max_number=max_number, min_number=min_number, seed=seed)
self.stamp_size = stamp_size
self.max_shift = max_shift if max_shift else self.stamp_size / 10.0

Expand All @@ -169,7 +177,7 @@ def __call__(self, table, **kwargs):
Returns:
Table with entries corresponding to one blend.
"""
number_of_objects = self.rng.integers(0, self.max_number)
number_of_objects = self.rng.integers(self.min_number, self.max_number)
a = np.hypot(table["a_d"], table["a_b"])
cond = (a <= 2) & (a > 0.2)
(q_bright,) = np.where(cond & (table["ref_mag"] <= 24))
Expand Down