Skip to content

Commit

Permalink
More reproducibility (#237)
Browse files Browse the repository at this point in the history
* only able to provide seed not rng

* only allow integer seeds, propagate to multiprocessing correctly

* propagate changes to tests

* correct seed defaults

* fix tests bc seeds were not propagated correctly before

* larger max int

* Update btk/draw_blends.py

Co-authored-by: Alexandre Boucaud <[email protected]>

* Update btk/sampling_functions.py

Co-authored-by: Alexandre Boucaud <[email protected]>

* define default seed once

* Update btk/sampling_functions.py

Co-authored-by: Alexandre Boucaud <[email protected]>

* dont need another rng embedded in render_blend

* simplify line

* seeds are now used in config

* put default seed in init

* show how to use the seed

* add seed in tutorial documentation too

* ignore warning until actually throws an error

* fix tests with new seed

* why did the precision decrease again?

* no output in first cell

Co-authored-by: Alexandre Boucaud <[email protected]>
  • Loading branch information
ismael-mendoza and aboucaud authored Oct 29, 2021
1 parent 44c88e8 commit 599a181
Show file tree
Hide file tree
Showing 18 changed files with 250 additions and 205 deletions.
2 changes: 2 additions & 0 deletions btk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
__email__ = "[email protected]"
__version__ = "0.9.3"

DEFAULT_SEED = 0

from . import catalog
from . import create_blend_generator
from . import draw_blends
Expand Down
48 changes: 19 additions & 29 deletions btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
from astropy.table import Column
from astropy.table import Table

from btk import DEFAULT_SEED
from btk.create_blend_generator import BlendGenerator
from btk.multiprocess import multiprocess
from btk.survey import get_flux
from btk.survey import get_mean_sky_level
from btk.survey import make_wcs
from btk.survey import Survey

MAX_SEED_INT = 1_000_000_000


class SourceNotVisible(Exception):
"""Custom exception to indicate that a source has no visible model components."""
Expand Down Expand Up @@ -131,7 +134,7 @@ def __init__(
indexes=None,
channels_last=False,
save_path=None,
rng=None,
seed=DEFAULT_SEED,
):
"""Initializes the DrawBlendsGenerator class.
Expand All @@ -155,8 +158,7 @@ def __init__(
dimensions (default).
save_path (str): Path to a directory where results will be saved. If left
as None, results will not be saved.
rng : Controls the random number generation. Can be an integer seed,
or a numpy.random.Generator. If None, a random seed will be used.
seed (int): Integer seed for reproducible random noise realizations.
"""
self.blend_generator = BlendGenerator(
catalog, sampling_function, batch_size, shifts, indexes, verbose
Expand Down Expand Up @@ -191,17 +193,7 @@ def __init__(
self.verbose = verbose
self.channels_last = channels_last
self.save_path = save_path

if rng is None:
self.rng = np.random.default_rng()
elif isinstance(rng, int):
self.rng = np.random.default_rng(rng)
else:
try:
rng.random()
except AttributeError:
raise AttributeError("The random generator you provided is invalid.")
self.rng = rng
self.rng = np.random.default_rng(seed)

def check_compatibility(self, survey):
"""Checks that the compatibility between the survey, the catalog and the generator.
Expand Down Expand Up @@ -257,8 +249,9 @@ def __next__(self):

input_args = []
for i in range(0, self.batch_size, mini_batch_size):
noise_seed = self.rng.integers(MAX_SEED_INT) # reproducibility
cat = copy.deepcopy(blend_cat[i : i + mini_batch_size])
input_args.append((cat, psf, wcs, s))
input_args.append((cat, psf, wcs, s, noise_seed))

# multiprocess and join results
# ideally, each cpu processes a single mini_batch
Expand Down Expand Up @@ -318,7 +311,7 @@ def __next__(self):
}
return output

def render_mini_batch(self, blend_list, psf, wcs, survey, extra_data=None):
def render_mini_batch(self, blend_list, psf, wcs, survey, noise_seed, extra_data=None):
"""Returns isolated and blended images for blend catalogs in blend_list.
Function loops over blend_list and draws blend and isolated images in each
Expand Down Expand Up @@ -358,16 +351,13 @@ def render_mini_batch(self, blend_list, psf, wcs, survey, extra_data=None):
blend.add_column(y_peak)

iso_image_multi = np.zeros(
(
self.max_number,
len(survey.filters),
pix_stamp_size,
pix_stamp_size,
)
(self.max_number, len(survey.filters), pix_stamp_size, pix_stamp_size)
)
blend_image_multi = np.zeros((len(survey.filters), pix_stamp_size, pix_stamp_size))
for b, filt in enumerate(survey.filters):
single_band_output = self.render_blend(blend, psf[b], filt, survey, extra_data[i])
single_band_output = self.render_blend(
blend, psf[b], filt, survey, noise_seed, extra_data[i]
)
blend_image_multi[b, :, :] = single_band_output[0]
iso_image_multi[:, b, :, :] = single_band_output[1]

Expand All @@ -380,7 +370,7 @@ def render_mini_batch(self, blend_list, psf, wcs, survey, extra_data=None):
index += len(blend)
return outputs

def render_blend(self, blend_catalog, psf, filt, survey, extra_data):
def render_blend(self, blend_catalog, psf, filt, survey, noise_seed, extra_data):
"""Draws image of isolated galaxies along with the blend image in the single input band.
The WLDeblending package (descwl) renders galaxies corresponding to the
Expand Down Expand Up @@ -424,7 +414,7 @@ def render_blend(self, blend_catalog, psf, filt, survey, extra_data):
if self.add_noise:
if self.verbose:
print("Noise added to blend image")
generator = galsim.random.BaseDeviate(seed=self.rng.integers(100000))
generator = galsim.random.BaseDeviate(seed=noise_seed)
noise = galsim.PoissonNoise(rng=generator, sky_level=mean_sky_level)
_blend_image.addNoise(noise)

Expand Down Expand Up @@ -585,7 +575,7 @@ def __init__(
galsim_hub_model="hub:Lanusse2020",
param_names=["flux_radius", "mag_auto", "zphot"],
save_path=None,
rng=None,
seed=DEFAULT_SEED,
): # noqa: D417
"""Initializes the GalsimHubGenerator class.
Expand All @@ -610,14 +600,14 @@ def __init__(
indexes=indexes,
channels_last=channels_last,
save_path=save_path,
rng=rng,
seed=seed,
)
import galsim_hub

self.galsim_hub_model = galsim_hub.GenerativeGalaxyModel(galsim_hub_model)
self.param_names = param_names

def render_mini_batch(self, blend_list, psf, wcs, survey):
def render_mini_batch(self, blend_list, psf, wcs, survey, seed):
"""Returns isolated and blended images for blend catalogs in blend_list.
Here we generate the images for all galaxies in the batch at the same
Expand All @@ -635,7 +625,7 @@ def render_mini_batch(self, blend_list, psf, wcs, survey):
base_images_l.append(base_images[index : index + len(blend)])
index += len(blend)

return super().render_mini_batch(blend_list, psf, wcs, survey, base_images_l)
return super().render_mini_batch(blend_list, psf, wcs, survey, seed, base_images_l)

def render_single(self, entry, filt, psf, survey, extra_data):
"""Returns the Galsim Image of an isolated galaxy."""
Expand Down
44 changes: 23 additions & 21 deletions btk/sampling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import astropy.table
import numpy as np

from btk import DEFAULT_SEED
from btk.catalog import CatsimCatalog


Expand All @@ -31,14 +32,20 @@ class SamplingFunction(ABC):
galaxies chosen for the blend.
"""

def __init__(self, max_number):
def __init__(self, max_number, seed=DEFAULT_SEED):
"""Initializes the SamplingFunction.
Args:
max_number (int): maximum number of catalog entries returned from sample.
seed (int): Seed to initialize randomness for reproducibility.
"""
self.max_number = max_number

if isinstance(seed, int):
self.rng = np.random.default_rng(seed)
else:
raise AttributeError("The seed you provided is invalid, should be an int.")

@abstractmethod
def __call__(self, table, **kwargs):
"""Outputs a sample from the given astropy table.
Expand All @@ -55,20 +62,19 @@ def compatible_catalogs(self):
class DefaultSampling(SamplingFunction):
"""Default sampling function used for producing blend tables."""

def __init__(self, max_number=2, stamp_size=24.0, maxshift=None, rng=np.random.default_rng()):
def __init__(self, max_number=2, stamp_size=24.0, maxshift=None, seed=DEFAULT_SEED):
"""Initializes default sampling function.
Args:
max_number (int): Defined in parent class
stamp_size (float): Size of the desired stamp.
maxshift (float): Magnitude of maximum value of shift. If None then it
is set as one-tenth the stamp size. (in arcseconds)
rng (numpy.random.Generator) : Controls the random number generation.
seed (int): Seed to initialize randomness for reproducibility.
"""
super().__init__(max_number)
super().__init__(max_number, seed)
self.stamp_size = stamp_size
self.maxshift = maxshift if maxshift else self.stamp_size / 10.0
self.rng = rng

@property
def compatible_catalogs(self):
Expand Down Expand Up @@ -126,20 +132,19 @@ def __call__(self, table, shifts=None, indexes=None):
class DefaultSamplingGalsimHub(SamplingFunction):
"""Default sampling function used for producing blend tables, esp. for galsim_hub."""

def __init__(self, max_number=2, stamp_size=24.0, maxshift=None, rng=np.random.default_rng()):
def __init__(self, max_number=2, stamp_size=24.0, maxshift=None, seed=DEFAULT_SEED):
"""Initialize default sampling function for galsim_hub.
Args:
max_number (int): Defined in parent class
stamp_size (float): Size of the desired stamp.
maxshift (float): Magnitude of maximum value of shift. If None then it
is set as one-tenth the stamp size. (in arcseconds)
rng (numpy.random.Generator) : Controls the random number generation.
seed (int): Seed to initialize randomness for reproducibility.
"""
super().__init__(max_number)
super().__init__(max_number, seed)
self.stamp_size = stamp_size
self.maxshift = maxshift if maxshift else self.stamp_size / 10.0
self.rng = rng

@property
def compatible_catalogs(self):
Expand Down Expand Up @@ -194,20 +199,19 @@ class BasicSampling(SamplingFunction):
Includes magnitude cut, restriction on the shape, shift randomization.
"""

def __init__(self, max_number=4, stamp_size=24.0, maxshift=None, rng=np.random.default_rng()):
def __init__(self, max_number=4, stamp_size=24.0, maxshift=None, seed=DEFAULT_SEED):
"""Initializes the basic sampling function.
Args:
max_number (int): Defined in parent class
stamp_size (float): Size of the desired stamp.
maxshift (float): Magnitude of maximum value of shift. If None then it
is set as one-tenth the stamp size. (in arcseconds)
rng (numpy.random.Generator) : Controls the random number generation.
seed (int): Seed to initialize randomness for reproducibility.
"""
super().__init__(max_number)
super().__init__(max_number, seed)
self.stamp_size = stamp_size
self.maxshift = maxshift if maxshift else self.stamp_size / 10.0
self.rng = rng

@property
def compatible_catalogs(self):
Expand Down Expand Up @@ -266,7 +270,7 @@ def __init__(
pixel_scale,
shift=None,
group_id=None,
rng=np.random.default_rng(),
seed=DEFAULT_SEED,
):
"""Blends are defined from *groups* of galaxies from a CatSim-like catalog.
Expand All @@ -280,16 +284,15 @@ def __init__(
pixel_scale (float): pixel scale of the survey, in arcseconds per pixel
shift (list): List containing shifts to apply (useful to avoid randomization)
group_id (list): List containing which group_ids to analyze (avoid randomization)
rng (numpy.random.Generator) : Controls the random number generation.
seed (int): Seed to initialize randomness for reproducibility.
"""
super().__init__(max_number)
super().__init__(max_number, seed)

self.wld_catalog = CatsimCatalog.from_file(wld_catalog_name).get_raw_catalog()
self.stamp_size = stamp_size
self.pixel_scale = pixel_scale
self.shift = shift
self.group_id = group_id
self.rng = rng

@property
def compatible_catalogs(self):
Expand Down Expand Up @@ -355,7 +358,7 @@ def __init__(
pixel_scale,
shift=None,
fmt="fits",
rng=np.random.default_rng(),
seed=DEFAULT_SEED,
):
"""Blends defined from *groups* of galaxies from a catalog previously analyzed with WLD.
Expand All @@ -374,15 +377,14 @@ def __init__(
pixel_scale (float): pixel scale of the survey, in arcseconds per pixel
fmt (str): Format of input wld_catalog used to define groups.
shift (list): List of shifts to apply (usefult to avoid randomization)
rng (numpy.random.Generator) : Controls the random number generation.
seed (int): Seed to initialize randomness for reproducibility.
"""
super().__init__(max_number)
super().__init__(max_number, seed)
self.wld_catalog = astropy.table.Table.read(wld_catalog_name, format=fmt)
self.stamp_size = stamp_size
self.pixel_scale = pixel_scale
self.group_id_count = 0
self.shift = shift
self.rng = rng

@property
def compatible_catalogs(self):
Expand Down
3 changes: 3 additions & 0 deletions conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ defaults:

### all arguments ###

# reproducibility
seed: 0 # single seed used for both sampling function AND draw blend generator

# system
save_path: Null
cpus: 1
Expand Down
1 change: 1 addition & 0 deletions conf/draw_blends/catsim.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ shifts: Null
indexes: Null
channels_last: ${channels_last}
save_path: ${save_path}
seed: ${seed}
1 change: 1 addition & 0 deletions conf/draw_blends/cosmos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ shifts: Null
indexes: Null
channels_last: ${channels_last}
save_path: ${save_path}
seed: ${seed}
1 change: 1 addition & 0 deletions conf/draw_blends/galsim_hub.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ channels_last: ${channels_last}
galsim_hub_model: ${galsim_hub_model}
param_names: ${param_names}
save_path: ${save_path}
seed: ${seed}
1 change: 1 addition & 0 deletions conf/sampling/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ _target_: btk.sampling_functions.DefaultSampling
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
seed: ${seed}
1 change: 1 addition & 0 deletions conf/sampling/galsim_hub.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ _target_: btk.sampling_functions.DefaultSamplingGalsimHub
max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
seed: ${seed}
1 change: 1 addition & 0 deletions conf/sampling/group_sampling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
wld_catalog_name: ${paths.data}/sample_group_catalog.fits
seed: ${seed}
shift: Null
1 change: 1 addition & 0 deletions conf/sampling/group_sampling_numbered.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ max_number: ${max_number}
stamp_size: ${stamp_size}
maxshift: ${max_shift}
wld_catalog_name: ${paths.data}/sample_group_catalog.fits
seed: ${seed}
shift: Null
16 changes: 15 additions & 1 deletion docs/source/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ Importing the relevant packages
import btk.draw_blends
import astropy.table

Reproducibility
''''''''''''''''

The following cell contains the seed used to generate reproducible random realizations of BTK output. Using the same seed in BTK for **both the sampling function and draw blend generator** (more info below) guarantees the galaxy images and parameters produced will be the same, even across different systems.

.. jupyter-execute::

from btk import DEFAULT_SEED
seed = DEFAULT_SEED

Every object in BTK that needs a seed uses ``DEFAULT_SEED`` implicitly. In this tutorial we explicitly show how this seed is passed in and which objects need a seed. However, we omit it in the other tutorials.


Drawing some blends
''''''''''''''''''''

Expand Down Expand Up @@ -68,7 +81,7 @@ which galaxies are drawn, with what shifts, etc. This is achieved using the ``Sa
stamp_size = 24.0 # Size of the stamp, in arcseconds
max_number = 3 # Maximum number of galaxies in a blend
max_shift = 3.0 # Maximum shift of the galaxies, in arcseconds
sampling_function = btk.sampling_functions.DefaultSampling(max_number=max_number, stamp_size=stamp_size, maxshift=max_shift)
sampling_function = btk.sampling_functions.DefaultSampling(max_number=max_number, stamp_size=stamp_size, maxshift=max_shift, seed=seed)

As a reference, here is the code for this sampling function:

Expand Down Expand Up @@ -287,6 +300,7 @@ Now that we have all the objects at our disposal, we can create the DrawBlendsGe
indexes=None,
cpus=1,
add_noise=True,
seed=seed
)

The results from the ``next`` call are stored in a dictionary; the keys are:
Expand Down
Loading

0 comments on commit 599a181

Please sign in to comment.