Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

small speedboost by precalculating conjugation #165

Merged
merged 15 commits into from
May 8, 2024
138 changes: 67 additions & 71 deletions src/pytom_tm/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,42 +36,45 @@ def __init__(
reduced form, with dimensions (sx, sx, sx // 2 + 1)
"""
# Search volume and fft transform plan for the volume
self.volume = cp.asarray(volume, dtype=cp.float32, order='C')
self.volume_rft = rfftn(self.volume)
volume_shape = volume.shape
self.volume_rft_conj = rfftn(cp.asarray(volume, dtype=cp.float32, order='C')).conj()
self.volume_sq_rft_conj = rfftn(cp.asarray(volume, dtype=cp.float32, order='C') ** 2).conj()
McHaillet marked this conversation as resolved.
Show resolved Hide resolved
# An explicit fft plan is no longer necessary: cupy generates a plan behind the scenes, which leads to
# comparable timings

# Array for storing local standard deviations
self.std_volume = cp.zeros_like(volume, dtype=cp.float32)
self.std_volume = cp.zeros(volume_shape, dtype=cp.float32)

# Data for the mask
self.mask = cp.asarray(mask, dtype=cp.float32, order='C')
self.mask_texture = vt.StaticVolume(self.mask, interpolation='filt_bspline', device=f'gpu:{device_id}')
self.mask_padded = cp.zeros_like(self.volume).astype(cp.float32)
self.mask_padded = cp.zeros(volume_shape, dtype=cp.float32)
self.mask_weight = self.mask.sum() # weight of the mask

# Init template data
self.template = cp.asarray(template, dtype=cp.float32, order='C')
self.template_texture = vt.StaticVolume(self.template, interpolation='filt_bspline', device=f'gpu:{device_id}')
self.template_padded = cp.zeros_like(self.volume)
self.template_padded = cp.zeros(volume_shape, dtype=cp.float32)

# fourier binary wedge weight for the template
self.wedge = cp.asarray(wedge, order='C', dtype=cp.float32) if wedge is not None else None

# Initialize result volumes
self.ccc_map = cp.zeros_like(self.volume)
self.scores = cp.ones_like(self.volume)*-1000
self.angles = cp.ones_like(self.volume)*-1000
self.ccc_map = cp.zeros(volume_shape, dtype=cp.float32)
self.scores = cp.ones(volume_shape, dtype=cp.float32)*-1000
self.angles = cp.ones(volume_shape, dtype=cp.float32)*-1000

# wait for stream to complete the work
cp.cuda.stream.get_current_stream().synchronize()

def clean(self) -> None:
"""Remove all stored cupy arrays from the GPU's memory pool."""
gpu_memory_pool = cp.get_default_memory_pool()
del self.volume, self.volume_rft, self.mask, self.mask_texture, self.mask_padded, self.template, (
self.template_texture), self.template_padded, self.wedge, self.ccc_map, self.scores, self.angles, (
self.std_volume)
del (
self.volume_rft_conj, self.volume_sq_rft_conj, self.mask, self.mask_texture, self.mask_padded,
self.template, self.template_texture, self.template_padded, self.wedge, self.ccc_map, self.scores,
self.angles, self.std_volume
)
gc.collect()
gpu_memory_pool.free_all_blocks()

Expand All @@ -92,6 +95,13 @@ def __init__(
):
"""Initialize a template matching run.

For other great implementations see:
- STOPGAP: https://github.com/wan-lab-vanderbilt/STOPGAP
- pyTME: https://github.com/KosinskiLab/pyTME

The precalculation of the conjugated Fourier transforms of the tomogram was
(as far as I know) introduced by STOPGAP!

Parameters
----------
job_id: str
Expand Down Expand Up @@ -162,19 +172,28 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:
sxv, syv, szv = self.plan.template_padded.shape
cxv, cyv, czv = sxv // 2, syv // 2, szv // 2

# calculate roi size
roi_size = self.plan.volume[self.stats_roi].size
# create slice for padding
pad_index = (
slice(cxv - cxt, cxv + cxt + mx),
slice(cyv - cyt, cyv + cyt + my),
slice(czv - czt, czv + czt + mz),
)

# calculate roi mask
shift = cp.floor(cp.array(self.plan.scores.shape) / 2).astype(int) + 1
roi_mask = cp.zeros(self.plan.scores.shape, dtype=bool)
roi_mask[self.stats_roi] = True
roi_mask = cp.flip(cp.roll(roi_mask, -shift, (0, 1, 2)))
roi_size = self.plan.scores[roi_mask].size

if self.mask_is_spherical: # Then we only need to calculate std volume once
self.plan.mask_padded[cxv - cxt:cxv + cxt + mx,
cyv - cyt:cyv + cyt + my,
czv - czt:czv + czt + mz] = self.plan.mask
self.plan.mask_padded[pad_index] = self.plan.mask
self.plan.std_volume = std_under_mask_convolution(
self.plan.volume,
self.plan.volume_rft_conj,
self.plan.volume_sq_rft_conj,
self.plan.mask_padded,
self.plan.mask_weight,
volume_rft=self.plan.volume_rft
)
) * self.plan.mask_weight

# Track iterations with a tqdm progress bar
for i in tqdm(range(len(self.angle_ids))):
Expand All @@ -189,16 +208,14 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:
output=self.plan.mask,
rotation_units='rad'
)
self.plan.mask_padded[cxv - cxt:cxv + cxt + mx,
cyv - cyt:cyv + cyt + my,
czv - czt:czv + czt + mz] = self.plan.mask
self.plan.mask_padded[pad_index] = self.plan.mask
McHaillet marked this conversation as resolved.
Show resolved Hide resolved
# Std volume needs to be recalculated for every rotation of the mask, expensive step
self.plan.std_volume = std_under_mask_convolution(
self.plan.volume,
self.plan.volume_rft_conj,
self.plan.volume_sq_rft_conj,
self.plan.mask_padded,
self.plan.mask_weight,
volume_rft=self.plan.volume_rft,
)
) * self.plan.mask_weight

# Rotate template
self.plan.template_texture.transform(
Expand All @@ -213,24 +230,22 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:
self.plan.template = irfftn(
rfftn(self.plan.template) * self.plan.wedge,
s=self.plan.template.shape
).real
)

# Normalize and mask template
mean = mean_under_mask(self.plan.template, self.plan.mask, mask_weight=self.plan.mask_weight)
std = std_under_mask(self.plan.template, self.plan.mask, mean, mask_weight=self.plan.mask_weight)
self.plan.template = ((self.plan.template - mean) / std) * self.plan.mask

# Paste in center
self.plan.template_padded[cxv - cxt:cxv + cxt + mx,
cyv - cyt:cyv + cyt + my,
czv - czt:czv + czt + mz] = self.plan.template
self.plan.template_padded[pad_index] = self.plan.template

# Fast local correlation function between volume and template, norm is the standard deviation at each
# point in the volume in the masked area
self.plan.ccc_map = fftshift(
irfftn(self.plan.volume_rft * rfftn(self.plan.template_padded).conj(),
s=self.plan.template_padded.shape).real
/ (self.plan.mask_weight * self.plan.std_volume)
self.plan.ccc_map = (
irfftn(self.plan.volume_rft_conj * rfftn(self.plan.template_padded),
s=self.plan.template_padded.shape)
/ self.plan.std_volume
)

# Update the scores and angle_lists
Expand All @@ -243,11 +258,18 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:
)

self.stats['variance'] += (
square_sum_kernel(
self.plan.ccc_map[self.stats_roi]
) / roi_size
square_sum_kernel(self.plan.ccc_map * roi_mask) / roi_size
)

# Get correct orientation back!
# Use the same method as William Wan's STOPGAP
# (https://doi.org/10.1107/S205979832400295X): the search volume is Fourier
# transformed and conjugated before the iterations, which means the eventual
# score map needs to be flipped back. The map is also rolled due to the fftshift
# effect of a Fourier space correlation function.
self.plan.scores = cp.roll(cp.flip(self.plan.scores), shift, axis=(0, 1, 2))
self.plan.angles = cp.roll(cp.flip(self.plan.angles), shift, axis=(0, 1, 2))

self.stats['search_space'] = int(roi_size * len(self.angle_ids))
self.stats['variance'] = float(self.stats['variance'] / len(self.angle_ids))
self.stats['std'] = float(cp.sqrt(self.stats['variance']))
Expand All @@ -262,66 +284,40 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:


def std_under_mask_convolution(
volume: cpt.NDArray[float],
volume_rft_conj: cpt.NDArray[float],
volume_sq_rft_conj: cpt.NDArray[float],
McHaillet marked this conversation as resolved.
Show resolved Hide resolved
padded_mask: cpt.NDArray[float],
mask_weight: float,
volume_rft: Optional[cpt.NDArray[complex]] = None
) -> cpt.NDArray[float]:
"""Calculate the local standard deviation under the mask for each position in the volume. Calculation is done in
Fourier space as this is a convolution between volume and mask.

Parameters
----------
volume: cpt.NDArray[float]
cupy array to calculate local std in
volume_rft_conj: cpt.NDArray[float]
complex conjugate of the rft of the search volume
volume_sq_rft_conj: cpt.NDArray[float]
complex conjugate of the rft of the squared search volume
padded_mask: cpt.NDArray[float]
template mask that has been padded to dimensions of volume
mask_weight: float
weight of the mask, usually calculated as mask.sum()
volume_rft: Optional[cpt.NDArray[float]], default None
optionally provide a precalculated reduced Fourier transform of volume to save computation

Returns
-------
std_v: cpt.NDArray[float]
array with local standard deviations in volume
"""
volume_rft = rfftn(volume) if volume_rft is None else volume_rft
padded_mask_rft = rfftn(padded_mask)
std_v = (
mean_under_mask_convolution(rfftn(volume ** 2), padded_mask, mask_weight) -
mean_under_mask_convolution(volume_rft, padded_mask, mask_weight) ** 2
irfftn(volume_sq_rft_conj * padded_mask_rft, s=padded_mask.shape) / mask_weight -
(irfftn(volume_rft_conj * padded_mask_rft, s=padded_mask.shape) / mask_weight) ** 2
)
std_v[std_v <= cp.float32(1e-18)] = 1 # prevent potential sqrt of negative value and division by zero
std_v = cp.sqrt(std_v)
return std_v


def mean_under_mask_convolution(
        volume_rft: cpt.NDArray[complex],
        mask: cpt.NDArray[float],
        mask_weight: float
) -> cpt.NDArray[float]:
    """Compute the local mean of a volume under a mask via Fourier space.

    The reduced Fourier transform of the volume is multiplied with the complex
    conjugate of the reduced Fourier transform of the mask, transformed back
    to real space with the shape of the mask, and divided by the mask weight.

    Parameters
    ----------
    volume_rft: cpt.NDArray[complex]
        rfftn of the volume
    mask: cpt.NDArray[float]
        mask under which the mean is calculated; its shape sets the shape of
        the output
    mask_weight: float
        weight of the mask, usually calculated as mask.sum()

    Returns
    -------
    mean: cpt.NDArray[float]
        local means under the mask, one value per position
    """
    # Conjugating the mask transform gives the transform of the flipped mask.
    mask_rft_conj = rfftn(mask).conj()
    # Back in real space: the mask-weighted sum of the volume per position.
    masked_sum = irfftn(volume_rft * mask_rft_conj, s=mask.shape).real
    return masked_sum / mask_weight


"""Update scores and angles if a new maximum is found."""
update_results_kernel = cp.ElementwiseKernel(
'float32 scores, float32 ccc_map, float32 angle_id',
Expand Down