Skip to content

Commit

Permalink
Merge pull request #22 from LSSTDESC/#16
Browse files Browse the repository at this point in the history
  • Loading branch information
sowmyakth authored Oct 1, 2019
2 parents 403f708 + 5956655 commit f36dfa4
Show file tree
Hide file tree
Showing 18 changed files with 1,556 additions and 16,593 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
[![Documentation Status](https://readthedocs.org/projects/blendingtoolkit/badge/?version=latest)](https://blendingtoolkit.readthedocs.io/en/latest/?badge=latest)

# BlendingToolKit
Tools to create blend catalogs, produce training samples and implement blending metrics.
Framework for fast generation and analysis of galaxy blends catalogs. This toolkit is a convenient way of
producing multi-band postage stamp images of blend scenes.

Documentation can be found at https://blendingtoolkit.readthedocs.io/en/latest/

Expand Down
28 changes: 14 additions & 14 deletions btk/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@


class Metrics_params(object):
def __init__(self, meas_generator, param):
def __init__(self, meas_generator, sim_param):
"""Class describing functions to return results of
detection/deblending/measurement algorithm in meas_generator. Each
blend results yielded by the meas_generator for a batch.
"""
self.meas_generator = meas_generator
self.param = param
self.sim_param = sim_param

def get_detections(self):
"""
Expand All @@ -31,9 +31,9 @@ def get_detections(self):
and 'dy' respectively, in pixels from bottom left corner as (0, 0).
"""
# Astropy table with entries corresponding to true sources
true_tables = [astropy.table.Table()] * self.param.batch_size
true_tables = [astropy.table.Table()] * self.config.batch_size
# Astropy table with entries corresponding to detections
detected_tables = [astropy.table.Table()] * self.param.batch_size
detected_tables = [astropy.table.Table()] * self.config.batch_size
return true_tables, detected_tables

def get_segmentation(self):
Expand Down Expand Up @@ -150,19 +150,19 @@ def get_detection_match(true_table, detected_table):
norm_dist = dist/norm_size[:, np.newaxis]
detected_table['dSigma_min'] = np.min(norm_dist, axis=0)
detected_table['d_min'] = np.min(dist, axis=0)
detection_threshold1 = 0.5
detection_threshold1 = 5
condlist1 = [
np.min(norm_dist, axis=0) <= detection_threshold1,
np.min(norm_dist, axis=0) > detection_threshold1]
choicelist1 = [np.argmin(norm_dist, axis=0), -1]
np.min(dist, axis=0) <= detection_threshold1,
np.min(dist, axis=0) > detection_threshold1]
choicelist1 = [np.argmin(dist, axis=0), -1]
match_id1 = np.select(condlist1, choicelist1)
detected_table['match_true_id1'] = match_id1
detected_table['match_galtileid1'] = true_table['galtileid'][match_id1]
detection_threshold2 = 5
detection_threshold2 = 0.5
condlist2 = [
np.min(dist, axis=0) <= detection_threshold2,
np.min(dist, axis=0) > detection_threshold2]
choicelist2 = [np.argmin(dist, axis=0), -1]
np.min(norm_dist, axis=0) <= detection_threshold2,
np.min(norm_dist, axis=0) > detection_threshold2]
choicelist2 = [np.argmin(norm_dist, axis=0), -1]
match_id2 = np.select(condlist2, choicelist2)
detected_table['match_true_id2'] = match_id2
detected_table['match_galtileid2'] = true_table['galtileid'][match_id2]
Expand Down Expand Up @@ -317,13 +317,13 @@ def run(Metrics_params, test_size=1000, dSigma_detection=True):
batch_detection_result = Metrics_params.get_detections()
if (
len(batch_detection_result[0]) != len(batch_detection_result[1]) or
len(batch_detection_result[0]) != Metrics_params.param.batch_size
len(batch_detection_result[0]) != Metrics_params.sim_param.batch_size
):
raise ValueError("Metrics_params.get_detections output must be "
"two lists of astropy table of length batch size."
f" Found {len(batch_detection_result[0])}, "
f"{len(batch_detection_result[1])}, "
f"{ Metrics_params.param.batch_size}")
f"{ Metrics_params.sim_param.batch_size}")
true_table, detected_table, detection_summary = evaluate_detection(
batch_detection_result[0], batch_detection_result[1],
batch_index=i)
Expand Down
1 change: 1 addition & 0 deletions btk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Simulation_params(object):
in the sky background during a full exposure.
verbose: If true, prints description at multiple steps.
"""

def __init__(self, catalog_name, max_number=2,
batch_size=8, stamp_size=24,
survey_name="LSST",
Expand Down
3 changes: 2 additions & 1 deletion btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def draw_isolated(Args, galaxy, iso_obs):
no_margin=False,
verbose_render=False)
iso_render_engine.render_galaxy(
galaxy, no_partials=True, calculate_bias=False, no_analysis=True)
galaxy, variations_x=None, variations_s=None, variations_g=None,
no_fisher=True, calculate_bias=False, no_analysis=True)
return iso_obs


Expand Down
24 changes: 13 additions & 11 deletions btk/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_rgb(image, min_val=None, max_val=None):
if max_val is None:
max_val = image.max(axis=-1).max(axis=-1)
new_image = np.transpose(image, axes=(1, 2, 0))
new_image = (new_image - min_val) / (max_val - min_val)*255
new_image = (new_image - min_val) / (max_val - min_val) * 255
new_image[new_image < 0] = 0
new_image[new_image > 255] = 255
return new_image.astype(np.uint8)
Expand All @@ -47,8 +47,8 @@ def get_rgb_image(image, normalize_with_image=None):
uint8 array [height, width, bands] of the input image.
"""
try:
import scarlet
if normalize_with_image:
import scarlet.display
if normalize_with_image is not None:
norm = scarlet.display.Asinh(img=normalize_with_image, Q=20)
else:
norm = None
Expand Down Expand Up @@ -94,7 +94,7 @@ def plot_blends(blend_images, blend_list, detected_centers=None,
raise ValueError(f"band_indices must be a list with 3 entries, not \
{band_indices}")
if detected_centers is None:
detected_centers = [[]]*batch_size
detected_centers = [[]] * batch_size
if (len(detected_centers) != batch_size or
blend_images.shape[0] != batch_size):
raise ValueError(f"Length of detected_centers and length of blend_list\
Expand Down Expand Up @@ -163,15 +163,17 @@ def plot_with_isolated(blend_images, isolated_images, blend_list,
raise ValueError(f"band_indices must be a list with 3 entries, not \
{band_indices}")
if detected_centers is None:
detected_centers = [[]]*b_size
detected_centers = [[]] * b_size
if (len(detected_centers) != b_size or len(isolated_images) != b_size or
blend_images.shape[0] != b_size):
raise ValueError(f"Length of detected_centers and length of blend_list\
must be equal to first dimension of blend_images, found \
{len(detected_centers), len(blend_list), len(blend_images)}")
for i in range(len(blend_list)):
images = np.transpose(blend_images[i], axes=(2, 0, 1))
blend_img_rgb = get_rgb_image(images[band_indices])
blend_img_rgb = get_rgb_image(
images[band_indices],
normalize_with_image=images[band_indices])
plt.figure(figsize=(2, 2))
plt.imshow(blend_img_rgb)
plt.title(f"{len(blend_list[i])} objects")
Expand All @@ -184,7 +186,7 @@ def plot_with_isolated(blend_images, isolated_images, blend_list,
plt.show()
iso_blend = isolated_images[i]
num = iso_blend.shape[0]
plt.figure(figsize=(2*num, 2))
plt.figure(figsize=(2 * num, 2))
for j in range(num):
iso_images = np.transpose(iso_blend[j], axes=(2, 0, 1))
iso_img_rgb = get_rgb_image(
Expand Down Expand Up @@ -256,10 +258,10 @@ def plot_metrics_summary(summary, num, ax=None):
ax.imshow(results_table, origin='left', cmap=plt.cm.Blues)
ax.set_xlabel("# true objects")
# Don't print zero'th column
ax.set_xlim([0.5, num+0.5])
ax.set_xlim([0.5, num + 0.5])
ax.set_ylabel("# correctly detected objects")
ax.set_xticks(np.arange(1, num+1, 1.0))
ax.set_yticks(np.arange(0, num+2, 1.0))
ax.set_xticks(np.arange(1, num + 1, 1.0))
ax.set_yticks(np.arange(0, num + 2, 1.0))
for (j, i), label in np.ndenumerate(results_table):
if i == 0:
# Don't print efficiency for zero'th column
Expand All @@ -270,7 +272,7 @@ def plot_metrics_summary(summary, num, ax=None):
ax.text(i, j, f"{label:.1f}%",
ha='center', va='center', color=color)
if i == j:
rect = patches.Rectangle((i-0.5, j-0.5), 1, 1, linewidth=2,
rect = patches.Rectangle((i - 0.5, j - 0.5), 1, 1, linewidth=2,
edgecolor='mediumpurple',
facecolor='none')
ax.add_patch(rect)
Loading

0 comments on commit f36dfa4

Please sign in to comment.