diff --git a/btk/draw_blends.py b/btk/draw_blends.py index a6c4f9710..1b698219d 100644 --- a/btk/draw_blends.py +++ b/btk/draw_blends.py @@ -133,6 +133,7 @@ def __init__( save_path=None, seed=DEFAULT_SEED, apply_shear=False, + augment_data=False, ): """Initializes the DrawBlendsGenerator class. @@ -164,6 +165,10 @@ def __init__( as None, results will not be saved. seed (int): Integer seed for reproducible random noise realizations. apply_shear (float): Whether to apply the shear specified in catalogs to galaxies. + If set to True, sampling function must add 'g1', 'g2' columns. + augment_data (float): If set to True, augment data by adding a random rotation to every + galaxy drawn. Rotation added is recorded via `btk_rotation` column + output. """ self.blend_generator = BlendGenerator( catalog, sampling_function, batch_size, shifts, indexes, verbose @@ -174,6 +179,7 @@ def __init__( self.batch_size = self.blend_generator.batch_size self.max_number = self.blend_generator.max_number self.apply_shear = apply_shear + self.augment_data = augment_data if isinstance(surveys, Survey): self.surveys = [surveys] @@ -369,6 +375,12 @@ def render_mini_batch(self, blend_list, psf, wcs, survey, seedseq_minibatch, ext blend.add_column(x_peak) blend.add_column(y_peak) + # add rotation, if requested + if self.augment_data: + rng = np.random.default_rng(seedseq_minibatch.generate_state(1)) + theta = rng.uniform(0, 360, size=len(blend)) + blend.add_column(Column(theta), name="btk_rotation") + n_bands = len(survey.available_filters) iso_image_multi = np.zeros((self.max_number, n_bands, pix_stamp_size, pix_stamp_size)) blend_image_multi = np.zeros((n_bands, pix_stamp_size, pix_stamp_size)) @@ -506,6 +518,8 @@ def render_single(self, entry, filt, psf, survey, extra_data): pix_stamp_size = int(self.stamp_size / survey.pixel_scale.to_value("arcsec")) try: gal = get_catsim_galaxy(entry, filt, survey) + if self.augment_data: + gal.rotate(galsim.Angle(entry["btk_rotation"], unit=galsim.degrees)) if self.apply_shear: if "g1" in entry.keys() and "g2" in entry.keys(): gal = gal.shear(g1=entry["g1"], g2=entry["g2"]) @@ -628,6 +642,8 @@ def render_single(self, entry, filt, psf, survey, extra_data): gal = galsim_catalog.makeGalaxy( entry["btk_index"], gal_type=self.gal_type, noise_pad_size=0 ).withFlux(gal_flux) + if self.augment_data: + gal.rotate(galsim.Angle(entry["btk_rotation"], unit=galsim.degrees)) if self.apply_shear: if "g1" in entry.keys() and "g2" in entry.keys(): gal = gal.shear(g1=entry["g1"], g2=entry["g2"])