Skip to content

Commit

Permalink
Merge branch 'main' into cosmos-tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
thuiop authored Mar 30, 2021
2 parents e943d4e + 46d660f commit 68e1eb3
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 53 deletions.
7 changes: 1 addition & 6 deletions btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class DrawBlendsGenerator(ABC):
Batch is divided into mini batches of size blend_generator.batch_size//cpus and
each mini-batch analyzed separately. The results are then combined to output a
dict with results of entire batch. If multiprocessing is true, then each of
dict with results of entire batch. If the number of cpus is greater than one, then each of
the mini-batches are run in parallel.
"""
Expand All @@ -141,7 +141,6 @@ def __init__(
batch_size=8,
stamp_size=24,
meas_bands=("i",),
multiprocessing=False,
cpus=1,
verbose=False,
add_noise=True,
Expand All @@ -159,8 +158,6 @@ def __init__(
batch_size (int) : Number of blends generated per batch
stamp_size (float) : Size of the stamps, in arcseconds
meas_bands=("i",) : Tuple containing the bands in which the measurements are carried
multiprocessing (bool) : Indicates whether the mini batches should be ran in
parallel
cpus (int) : Number of cpus to use ; defines the number of minibatches
verbose (bool) : Indicates whether additionnal information should be printed
add_noise (bool) : Indicates if the blends should be generated with noise
Expand All @@ -178,7 +175,6 @@ def __init__(
catalog, sampling_function, batch_size, shifts, indexes, verbose
)
self.catalog = self.blend_generator.catalog
self.multiprocessing = multiprocessing
self.cpus = cpus

self.batch_size = self.blend_generator.batch_size
Expand Down Expand Up @@ -265,7 +261,6 @@ def __next__(self):
self.render_mini_batch,
input_args,
self.cpus,
self.multiprocessing,
self.verbose,
)

Expand Down
3 changes: 0 additions & 3 deletions btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def __init__(
self,
measure_functions,
draw_blend_generator,
multiprocessing=False,
cpus=1,
verbose=False,
):
Expand Down Expand Up @@ -155,7 +154,6 @@ def __init__(
ValueError("measure_functions must be a list of functions or a single function.")

self.draw_blend_generator = draw_blend_generator
self.multiprocessing = multiprocessing
self.cpus = cpus

self.batch_size = self.draw_blend_generator.batch_size
Expand Down Expand Up @@ -229,7 +227,6 @@ def __next__(self):
self.run_batch,
input_args,
self.cpus,
self.multiprocessing,
self.verbose,
)
if self.verbose:
Expand Down
4 changes: 2 additions & 2 deletions btk/multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from itertools import starmap


def multiprocess(func, input_args, cpus, multiprocessing=False, verbose=False):
if multiprocessing:
def multiprocess(func, input_args, cpus, verbose=False):
if cpus > 1:
if verbose:
print(
f"Running mini-batch of size {len(input_args)} with multiprocessing with "
Expand Down
1 change: 0 additions & 1 deletion docs/source/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ Now that we have all the objects at our disposal, we can create the DrawBlendsGe
stamp_size=stamp_size,
shifts=None,
indexes=None,
multiprocessing=False,
cpus=1,
add_noise=True,
)
Expand Down
69 changes: 46 additions & 23 deletions notebooks/intro.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/test_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def test_cosmos_galaxies():
[btk.survey.HST],
batch_size=batch_size,
stamp_size=stamp_size,
multiprocessing=False,
cpus=1,
add_noise=True,
verbose=True,
Expand Down
6 changes: 2 additions & 4 deletions tests/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
def get_draw_generator(
batch_size=8,
cpus=1,
multiprocessing=False,
add_noise=True,
fixed_parameters=False,
sampling_function=None,
Expand Down Expand Up @@ -46,7 +45,6 @@ def get_draw_generator(
stamp_size=stamp_size,
shifts=shifts,
indexes=indexes,
multiprocessing=multiprocessing,
cpus=cpus,
add_noise=add_noise,
verbose=True,
Expand All @@ -60,9 +58,9 @@ def test_multiprocessing(self):
b_size = 16
cpus = np.min([mp.cpu_count(), 16])

parallel_im_gen = get_draw_generator(b_size, cpus, multiprocessing=True, add_noise=False)
parallel_im_gen = get_draw_generator(b_size, cpus, add_noise=False)
parallel_im = next(parallel_im_gen)
serial_im_gen = get_draw_generator(b_size, cpus, multiprocessing=False, add_noise=False)
serial_im_gen = get_draw_generator(b_size, cpus=1, add_noise=False)
serial_im = next(serial_im_gen)
np.testing.assert_array_equal(parallel_im["blend_images"], serial_im["blend_images"])
np.testing.assert_array_equal(parallel_im["isolated_images"], serial_im["isolated_images"])
Expand Down
8 changes: 0 additions & 8 deletions tests/test_error_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def compatible_catalogs(self):
stamp_size = 24.0
batch_size = 8
cpus = 1
multiprocessing = False
add_noise = True

catalog = btk.catalog.CatsimCatalog.from_file(catalog_name)
Expand All @@ -35,7 +34,6 @@ def compatible_catalogs(self):
[Rubin],
stamp_size=stamp_size,
batch_size=batch_size,
multiprocessing=multiprocessing,
cpus=cpus,
add_noise=add_noise,
meas_bands=("i"),
Expand Down Expand Up @@ -63,7 +61,6 @@ def compatible_catalogs(self):
stamp_size = 24.0
batch_size = 8
cpus = 1
multiprocessing = False
add_noise = True

catalog = btk.catalog.CatsimCatalog.from_file(catalog_name)
Expand All @@ -74,7 +71,6 @@ def compatible_catalogs(self):
[Rubin],
stamp_size=stamp_size,
batch_size=batch_size,
multiprocessing=multiprocessing,
cpus=cpus,
add_noise=add_noise,
meas_bands=("i"),
Expand Down Expand Up @@ -104,7 +100,6 @@ def compatible_catalogs(self):
stamp_size = 24.0
batch_size = 8
cpus = 1
multiprocessing = False
add_noise = True

catalog = btk.catalog.CatsimCatalog.from_file(catalog_name)
Expand All @@ -115,7 +110,6 @@ def compatible_catalogs(self):
[Rubin],
stamp_size=stamp_size,
batch_size=batch_size,
multiprocessing=multiprocessing,
cpus=cpus,
add_noise=add_noise,
meas_bands=("i"),
Expand Down Expand Up @@ -153,7 +147,6 @@ def test_survey_not_list():
stamp_size = 24.0
batch_size = 8
cpus = 1
multiprocessing = False
add_noise = True

catalog = btk.catalog.CatsimCatalog.from_file(catalog_name)
Expand All @@ -165,7 +158,6 @@ def test_survey_not_list():
3,
stamp_size=stamp_size,
batch_size=batch_size,
multiprocessing=multiprocessing,
cpus=cpus,
add_noise=add_noise,
meas_bands=("i"),
Expand Down
5 changes: 2 additions & 3 deletions tests/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import btk.survey


def get_meas_generator(meas_function, multiprocessing=False, cpus=1):
def get_meas_generator(meas_function, cpus=1):
"""Returns draw generator with group sampling function"""

np.random.seed(0)
Expand Down Expand Up @@ -36,7 +36,6 @@ def get_meas_generator(meas_function, multiprocessing=False, cpus=1):
meas_generator = btk.measure.MeasureGenerator(
meas_function,
draw_blend_generator,
multiprocessing=multiprocessing,
cpus=cpus,
)
return meas_generator
Expand All @@ -62,7 +61,7 @@ def compare_sep():

def compare_sep_multiprocessing():
"""Test detection with sep"""
meas_generator = get_meas_generator(btk.measure.sep_measure, multiprocessing=True, cpus=4)
meas_generator = get_meas_generator(btk.measure.sep_measure, cpus=4)
_, results = next(meas_generator)
x_peak, y_peak = (
results[0][0]["catalog"]["x_peak"].item(),
Expand Down
2 changes: 0 additions & 2 deletions tests/test_mr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def test_multiresolution():
stamp_size = 24.0
batch_size = 8
cpus = 1
multiprocessing = False
add_noise = True

catalog = btk.catalog.CatsimCatalog.from_file(catalog_name)
Expand All @@ -20,7 +19,6 @@ def test_multiresolution():
[Rubin, HSC],
stamp_size=stamp_size,
batch_size=batch_size,
multiprocessing=multiprocessing,
cpus=cpus,
add_noise=add_noise,
meas_bands=("i", "i"),
Expand Down

0 comments on commit 68e1eb3

Please sign in to comment.