Skip to content

Commit

Permalink
Make kt masks more efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jun 30, 2024
1 parent fdcfb38 commit 1ed8960
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,15 +2457,15 @@ class KtRadialMaskFunc(KtBaseMaskFunc):
uniform_range : bool, optional
If True then an acceleration will be uniformly sampled between the two values, by default False.
crop_corner : bool, optional
If True, the mask is cropped to the corners. Default: True.
If True, the mask is cropped to the corners. Default: False.
"""

def __init__(
self,
accelerations: Union[list[Number], tuple[Number, ...]],
center_fractions: Union[list[float], tuple[float, ...]],
uniform_range: bool = False,
crop_corner: bool = True,
crop_corner: bool = False,
) -> None:
"""Inits :class:`KtRadialMaskFunc`.
Expand All @@ -2478,7 +2478,7 @@ def __init__(
uniform_range : bool, optional
If True then an acceleration will be uniformly sampled between the two values, by default False.
crop_corner : bool, optional
If True, the mask is cropped to the corners. Default: True.
If True, the mask is cropped to the corners. Default: False.
"""
super().__init__(
accelerations=accelerations,
Expand Down Expand Up @@ -2517,40 +2517,40 @@ def mask_func(
(nt, num_rows, num_cols) = shape[-4:-1]

with temp_seed(self.rng, seed):

center_fraction, acceleration = self.choose_acceleration()
num_low_freqs = int(round(num_cols * center_fraction))

offset_angle = self.rng.uniform(0, 360)

acs_mask = self.zero_pad_to_center(np.ones((nt, num_low_freqs, num_low_freqs)), [nt, num_rows, num_cols])
acs_mask = centered_disk_mask((num_rows, num_cols), center_fraction)
num_low_freqs = acs_mask.sum()
acs_mask = np.tile(acs_mask, (nt, 1, 1))

if return_acs:
return torch.from_numpy(acs_mask.astype(bool)[np.newaxis, ..., np.newaxis])

adjusted_acceleration = (acceleration * (num_low_freqs**2 - num_rows * num_cols)) / (
num_low_freqs**2 * acceleration - num_rows * num_cols
adjusted_acceleration = (acceleration * (num_low_freqs - num_rows * num_cols)) / (
num_low_freqs * acceleration - num_rows * num_cols
)

rate = 1 / adjusted_acceleration
beams = int(rate * np.mean([num_rows, num_cols])) # beams is the number of angles
num_beams = int(rate * np.mean([num_rows, num_cols])) # num_beams is the number of angles

if self.crop_corner:
temp_size = max(num_rows, num_cols)
else:
temp_size = int(np.sqrt(num_rows**2 + num_cols**2))
temp_size = int(np.sqrt(2) * max(num_rows, num_cols))

aux = np.zeros((temp_size, temp_size))
aux[int(temp_size / 2), :] = 1

mask = np.zeros((nt, num_rows, num_cols))
for i in range(nt):
angles = np.linspace(0 + offset_angle * i, 180 + offset_angle * i + 1, beams)
mask_t = np.zeros((num_rows, num_cols))
for ang in angles:
temp = self.crop_center(rotate(aux, ang, reshape=False, order=0), num_rows, num_cols)
mask_t += temp
mask[i] = mask_t
base_mask = np.sum(
[rotate(aux, angle, reshape=False, order=0) for angle in np.linspace(0, 180, num_beams)], axis=0
)
mask = [self.crop_center(base_mask, num_rows, num_cols)]

nt_angles = np.linspace(offset_angle, offset_angle + 180, nt)
for angle in nt_angles[:-1]:
mask.append(self.crop_center(rotate(base_mask, angle, reshape=False, order=0), num_rows, num_cols))
mask = np.stack(mask, 0)

mask = mask + acs_mask
mask = mask > 0
Expand Down Expand Up @@ -2628,13 +2628,14 @@ def mask_func(
center_fraction, acceleration = self.choose_acceleration()
num_low_freqs = int(round(num_cols * center_fraction))

# Fully sampled rectangle region
acs_mask = self.zero_pad_to_center(np.ones((nt, num_rows, num_low_freqs)), [nt, num_rows, num_cols])

if return_acs:
return torch.from_numpy(acs_mask.astype(bool)[np.newaxis, ..., np.newaxis])

adjusted_acceleration = int(
(acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
adjusted_acceleration = (acceleration * (num_low_freqs - num_cols)) / (
num_low_freqs * acceleration - num_cols
)

ptmp = np.zeros(num_cols)
Expand All @@ -2657,9 +2658,9 @@ def mask_func(

ph, ti = self.resolve_duplicates_on_kt_grid(ph, ti, num_cols, nt)
samp = np.zeros((num_cols, nt), dtype=int)
indices = num_cols * (ti + (nt // 2)) + (ph + (num_cols // 2))
indices[indices <= 0] = 1 # Ensure indices are within bounds
samp.ravel()[indices.astype(int)] = 1
inds = np.round(num_cols * (ti + nt // 2) + (ph + num_cols // 2)).astype(int)
inds[inds <= 0] = 1 # Ensure indices are within bounds
samp.ravel()[inds] = 1

mask = np.tile(samp, (num_rows, 1, 1)).transpose(2, 0, 1)
mask = mask + acs_mask
Expand Down Expand Up @@ -2752,13 +2753,14 @@ def mask_func(
center_fraction, acceleration = self.choose_acceleration()
num_low_freqs = int(round(num_cols * center_fraction))

# Fully sampled rectangle region
acs_mask = self.zero_pad_to_center(np.ones((nt, num_rows, num_low_freqs)), [nt, num_rows, num_cols])

if return_acs:
return torch.from_numpy(acs_mask.astype(bool)[np.newaxis, ..., np.newaxis])

adjusted_acceleration = int(
(acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
adjusted_acceleration = (acceleration * (num_low_freqs - num_cols)) / (
num_low_freqs * acceleration - num_cols
)

p1 = np.arange(-num_cols // 2, num_cols // 2)
Expand All @@ -2770,7 +2772,7 @@ def mask_func(

sigma = num_cols / self.std_scale # Std of the Gaussian envelope for sampling density

prob = 0.1 + self.alpha / (1 - self.alpha + 1e-10) * np.exp(-((p1) ** 2) / (sigma**2))
prob = 0.1 + self.alpha / (1 - self.alpha + 1e-10) * np.exp(-(p1**2) / (sigma**2))

ind = 0
for i in range(-nt // 2, nt // 2):
Expand All @@ -2784,10 +2786,10 @@ def mask_func(
ind += n_tmp

ph, ti = self.resolve_duplicates_on_kt_grid(ph, ti, num_cols, nt)
samp = np.zeros((num_cols, nt), dtype=int)
inds = num_cols * (ti + nt // 2) + (ph + num_cols // 2)
inds = inds.astype(int)
samp = np.zeros((nt, num_cols), dtype=int)
inds = np.round(num_cols * (ti + nt // 2) + (ph + num_cols // 2)).astype(int)
samp.ravel()[inds] = 1
samp = samp.T

mask = np.tile(samp, (num_rows, 1, 1)).transpose(2, 0, 1)
mask = mask + acs_mask
Expand Down

0 comments on commit 1ed8960

Please sign in to comment.