diff --git a/btk/create_blend_generator.py b/btk/create_blend_generator.py index aacb33c9b..4b1aa23f5 100644 --- a/btk/create_blend_generator.py +++ b/btk/create_blend_generator.py @@ -47,6 +47,7 @@ def __init__( ) self.max_number = self.sampling_function.max_number + self.min_number = self.sampling_function.min_number def __iter__(self): """Returns an iterable which is the object itself.""" @@ -55,7 +56,7 @@ def __iter__(self): def __next__(self): """Generates a list of blend tables of len batch_size. - Each blend table has entries numbered between 1 and max_number, corresponding + Each blend table has entries numbered between min_number and max_number, corresponding to overlapping objects in the blend. Returns: @@ -77,6 +78,11 @@ def __next__(self): f"Number of objects per blend must be " f"less than max_number: {len(blend_table)} <= {self.max_number}" ) + if len(blend_table) < self.min_number: + raise ValueError( + f"Number of objects per blend must be " + f"greater than min_number: {len(blend_table)} >= {self.min_number}" + ) blend_tables.append(blend_table) return blend_tables diff --git a/btk/sampling_functions.py b/btk/sampling_functions.py index 0ccd1b2e6..b3a8ced64 100644 --- a/btk/sampling_functions.py +++ b/btk/sampling_functions.py @@ -30,14 +30,16 @@ class SamplingFunction(ABC): galaxies chosen for the blend. """ - def __init__(self, max_number, seed=DEFAULT_SEED): + def __init__(self, max_number, min_number=1, seed=DEFAULT_SEED): """Initializes the SamplingFunction. Args: max_number (int): maximum number of catalog entries returned from sample. + min_number (int): minimum number of catalog entries returned from sample. seed (int): Seed to initialize randomness for reproducibility. """ self.max_number = max_number + self.min_number = min_number if isinstance(seed, int): self.rng = np.random.default_rng(seed) @@ -60,17 +62,20 @@ def compatible_catalogs(self): class DefaultSampling(SamplingFunction): """Default sampling function used for producing blend tables.""" - def __init__(self, max_number=2, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED): + def __init__( + self, max_number=2, min_number=1, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED + ): """Initializes default sampling function. Args: max_number (int): Defined in parent class + min_number (int): Defined in parent class stamp_size (float): Size of the desired stamp. max_shift (float): Magnitude of maximum value of shift. If None then it is set as one-tenth the stamp size. (in arcseconds) seed (int): Seed to initialize randomness for reproducibility. """ - super().__init__(max_number, seed) + super().__init__(max_number=max_number, min_number=min_number, seed=seed) self.stamp_size = stamp_size self.max_shift = max_shift if max_shift else self.stamp_size / 10.0 @@ -104,7 +109,7 @@ def __call__(self, table, shifts=None, indexes=None): Returns: Astropy.table with entries corresponding to one blend. """ - number_of_objects = self.rng.integers(1, self.max_number + 1) + number_of_objects = self.rng.integers(self.min_number, self.max_number) (q,) = np.where(table["ref_mag"] <= 25.3) if indexes is None: @@ -133,17 +138,20 @@ class BasicSampling(SamplingFunction): Includes magnitude cut, restriction on the shape, shift randomization. """ - def __init__(self, max_number=4, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED): + def __init__( + self, max_number=4, min_number=1, stamp_size=24.0, max_shift=None, seed=DEFAULT_SEED + ): """Initializes the basic sampling function. Args: max_number (int): Defined in parent class + min_number (int): Defined in parent class stamp_size (float): Size of the desired stamp. max_shift (float): Magnitude of maximum value of shift. If None then it is set as one-tenth the stamp size. (in arcseconds) seed (int): Seed to initialize randomness for reproducibility. """ - super().__init__(max_number, seed) + super().__init__(max_number=max_number, min_number=min_number, seed=seed) self.stamp_size = stamp_size self.max_shift = max_shift if max_shift else self.stamp_size / 10.0 @@ -169,7 +177,7 @@ def __call__(self, table, **kwargs): Returns: Table with entries corresponding to one blend. """ - number_of_objects = self.rng.integers(0, self.max_number) + number_of_objects = self.rng.integers(self.min_number, self.max_number) a = np.hypot(table["a_d"], table["a_b"]) cond = (a <= 2) & (a > 0.2) (q_bright,) = np.where(cond & (table["ref_mag"] <= 24))