Skip to content

Commit

Permalink
add rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Oct 15, 2022
1 parent 2741b0b commit 1d862ee
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
save_path=None,
seed=DEFAULT_SEED,
apply_shear=False,
augment_data=False,
):
"""Initializes the DrawBlendsGenerator class.
Expand Down Expand Up @@ -164,6 +165,10 @@ def __init__(
as None, results will not be saved.
seed (int): Integer seed for reproducible random noise realizations.
apply_shear (float): Whether to apply the shear specified in catalogs to galaxies.
If set to True, sampling function must add 'g1', 'g2' columns.
augment_data (float): If set to True, augment data by adding a random rotation to every
galaxy drawn. Rotation added is recorded via `btk_rotation` column
output.
"""
self.blend_generator = BlendGenerator(
catalog, sampling_function, batch_size, shifts, indexes, verbose
Expand All @@ -174,6 +179,7 @@ def __init__(
self.batch_size = self.blend_generator.batch_size
self.max_number = self.blend_generator.max_number
self.apply_shear = apply_shear
self.augment_data = augment_data

if isinstance(surveys, Survey):
self.surveys = [surveys]
Expand Down Expand Up @@ -369,6 +375,12 @@ def render_mini_batch(self, blend_list, psf, wcs, survey, seedseq_minibatch, ext
blend.add_column(x_peak)
blend.add_column(y_peak)

# add rotation, if requested
if self.augment_data:
rng = np.random.default_rng(seedseq_minibatch.generate_state(1))
theta = rng.uniform(0, 360, size=len(blend))
blend.add_column(Column(theta), name="btk_rotation")

n_bands = len(survey.available_filters)
iso_image_multi = np.zeros((self.max_number, n_bands, pix_stamp_size, pix_stamp_size))
blend_image_multi = np.zeros((n_bands, pix_stamp_size, pix_stamp_size))
Expand Down Expand Up @@ -506,6 +518,8 @@ def render_single(self, entry, filt, psf, survey, extra_data):
pix_stamp_size = int(self.stamp_size / survey.pixel_scale.to_value("arcsec"))
try:
gal = get_catsim_galaxy(entry, filt, survey)
if self.augment_data:
gal.rotate(galsim.Angle(entry["btk_rotation"], unit=galsim.degrees))
if self.apply_shear:
if "g1" in entry.keys() and "g2" in entry.keys():
gal = gal.shear(g1=entry["g1"], g2=entry["g2"])
Expand Down Expand Up @@ -628,6 +642,8 @@ def render_single(self, entry, filt, psf, survey, extra_data):
gal = galsim_catalog.makeGalaxy(
entry["btk_index"], gal_type=self.gal_type, noise_pad_size=0
).withFlux(gal_flux)
if self.augment_data:
gal.rotate(galsim.Angle(entry["btk_rotation"], unit=galsim.degrees))
if self.apply_shear:
if "g1" in entry.keys() and "g2" in entry.keys():
gal = gal.shear(g1=entry["g1"], g2=entry["g2"])
Expand Down

0 comments on commit 1d862ee

Please sign in to comment.