Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

phase randomization score map correction #182

Merged
merged 22 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/pytom_tm/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
BetweenZeroAndOne,
)
from pytom_tm.tmjob import load_json_to_tmjob
from os import urandom


def _parse_argv(argv=None):
Expand Down Expand Up @@ -756,6 +757,26 @@ def match_template(argv=None):
"apply it to the tomogram patch and template. Effectively puts more weight on "
"high resolution features and sharpens the correlation peaks.",
)
additional_group = parser.add_argument_group('Additional options')
additional_group.add_argument(
"-r",
"--random-phase-correction",
action="store_true",
default=False,
required=False,
help="Run template matching simultaneously with a phase randomized version of "
"the template, and subtract this 'noise' map from the final score map. "
"For this method please see STOPGAP as a reference: "
"https://doi.org/10.1107/S205979832400295X ."
)
additional_group.add_argument(
"--rng-seed",
action=LargerThanZero,
default=int.from_bytes(urandom(8)),
required=False,
help="Specify a seed for the random number generator used for phase "
"randomization for consistent results!"
)
device_group = parser.add_argument_group('Device control')
device_group.add_argument(
"-g",
Expand Down Expand Up @@ -836,6 +857,8 @@ def match_template(argv=None):
whiten_spectrum=args.spectral_whitening,
rotational_symmetry=args.z_axis_rotational_symmetry,
particle_diameter=args.particle_diameter,
random_phase_correction=args.random_phase_correction,
rng_seed=args.rng_seed,
)

score_volume, angle_volume = run_job_parallel(
Expand Down
143 changes: 112 additions & 31 deletions src/pytom_tm/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import voltools as vt
import gc
from typing import Optional
from cupyx.scipy.fft import rfftn, irfftn, fftshift
from cupyx.scipy.fft import rfftn, irfftn
from tqdm import tqdm
from pytom_tm.correlation import mean_under_mask, std_under_mask
from packaging import version
from pytom_tm.template import phase_randomize_template


class TemplateMatchingPlan:
Expand All @@ -17,7 +17,8 @@ def __init__(
template: npt.NDArray[float],
mask: npt.NDArray[float],
device_id: int,
wedge: Optional[npt.NDArray[float]] = None
wedge: Optional[npt.NDArray[float]] = None,
phase_randomized_template: Optional[npt.NDArray[float]] = None,
):
"""Initialize a template matching plan. All the necessary cupy arrays will be allocated on the GPU.

Expand All @@ -34,6 +35,8 @@ def __init__(
wedge: Optional[npt.NDArray[float]], default None
3D numpy array that contains the Fourier space weighting for the template, it should be in Fourier
reduced form, with dimensions (sx, sx, sx // 2 + 1)
phase_randomized_template: Optional[npt.NDArray[float]], default None
initialize the plan with a phase randomized version of the template for noise correction
"""
# Search volume + and fft transform plan for the volume
volume_shape = volume.shape
Expand Down Expand Up @@ -65,6 +68,16 @@ def __init__(
self.scores = cp.ones(volume_shape, dtype=cp.float32)*-1000
self.angles = cp.ones(volume_shape, dtype=cp.float32)*-1000

self.random_phase_template_texture = None
self.noise_scores = None
if phase_randomized_template is not None:
self.random_phase_template_texture = vt.StaticVolume(
cp.asarray(phase_randomized_template, dtype=cp.float32, order='C'),
interpolation='filt_bspline',
device=f'gpu:{device_id}',
)
self.noise_scores = cp.ones(volume_shape, dtype=cp.float32)*-1000

# wait for stream to complete the work
cp.cuda.stream.get_current_stream().synchronize()

Expand Down Expand Up @@ -92,7 +105,9 @@ def __init__(
angle_ids: list[int],
mask_is_spherical: bool = True,
wedge: Optional[npt.NDArray[float]] = None,
stats_roi: Optional[tuple[slice, slice, slice]] = None
stats_roi: Optional[tuple[slice, slice, slice]] = None,
noise_correction: bool = False,
rng_seed: int = 321,
):
"""Initialize a template matching run.

Expand All @@ -101,7 +116,8 @@ def __init__(
- pyTME: https://github.com/KosinskiLab/pyTME

The precalculation of conjugated FTs of the tomo was (AFAIK) introduced
by STOPGAP!
by STOPGAP! Also, they introduced simultaneous matching with a phase randomized
version of the template. https://doi.org/10.1107/S205979832400295X

Parameters
----------
Expand All @@ -127,6 +143,11 @@ def __init__(
stats_roi: Optional[tuple[slice, slice, slice]], default None
region of interest to calculate statistics on the search volume, default will just take the full search
volume
noise_correction: bool, default False
initialize template matching with a phase randomized version of the template that is used to subtract
background noise from the score map; expense is more gpu memory and computation time
rng_seed: int, default 321
seed for rng in phase randomization
"""
cp.cuda.Device(device_id).use()

Expand All @@ -146,8 +167,22 @@ def __init__(
)
else:
self.stats_roi = stats_roi
self.noise_correction = noise_correction

# create a 'random noise' version of the template
shuffled_template = (
phase_randomize_template(template, rng_seed)
if noise_correction else None
)

self.plan = TemplateMatchingPlan(volume, template, mask, device_id, wedge=wedge)
self.plan = TemplateMatchingPlan(
volume,
template,
mask,
device_id,
wedge=wedge,
phase_randomized_template=shuffled_template,
)

def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:
"""Run the template matching job.
Expand Down Expand Up @@ -226,28 +261,7 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:
rotation_units='rad'
)

if self.plan.wedge is not None:
# Add wedge to the template after rotating
self.plan.template = irfftn(
rfftn(self.plan.template) * self.plan.wedge,
s=self.plan.template.shape
)

# Normalize and mask template
mean = mean_under_mask(self.plan.template, self.plan.mask, mask_weight=self.plan.mask_weight)
std = std_under_mask(self.plan.template, self.plan.mask, mean, mask_weight=self.plan.mask_weight)
self.plan.template = ((self.plan.template - mean) / std) * self.plan.mask

# Paste in center
self.plan.template_padded[pad_index] = self.plan.template

# Fast local correlation function between volume and template, norm is the standard deviation at each
# point in the volume in the masked area
self.plan.ccc_map = (
irfftn(self.plan.volume_rft_conj * rfftn(self.plan.template_padded),
s=self.plan.template_padded.shape)
/ self.plan.std_volume
)
self.correlate(pad_index)

# Update the scores and angle_lists
update_results_kernel(
Expand All @@ -262,6 +276,31 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:
square_sum_kernel(self.plan.ccc_map * roi_mask) / roi_size
)

if self.noise_correction:
# Rotate noise template texture into template
self.plan.random_phase_template_texture.transform(
rotation=(rotation[0], rotation[1], rotation[2]),
rotation_order='rzxz',
output=self.plan.template,
rotation_units='rad'
)

self.correlate(pad_index)

# update noise scores results
update_noise_template_results_kernel(
self.plan.noise_scores,
self.plan.ccc_map,
self.plan.noise_scores,
)

# do the noise correction on the scores map: subtract the noise scores first,
# and then add the noise mean to ensure stats are consistent
if self.noise_correction:
self.plan.scores = (
(self.plan.scores - self.plan.noise_scores) + self.plan.noise_scores.mean()
)

# Get correct orientation back!
# Use same method as William Wan's STOPGAP
# (https://doi.org/10.1107/S205979832400295X): the search volume is Fourier
Expand All @@ -283,6 +322,39 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]:

return results

def correlate(self, padding_index: tuple[slice, slice, slice]) -> None:
    """Correlate the current (already rotated) template with the tomogram.

    Applies the wedge weighting (if any) to the rotated template, normalizes
    it under the mask, pastes it into the zero-padded template volume, and
    computes the masked local correlation map via FFTs. Results are written
    in place: ``self.plan.ccc_map`` holds the correlation map, and
    ``self.plan.template`` / ``self.plan.template_padded`` are overwritten
    as side effects.

    Parameters
    ----------
    padding_index: tuple[slice, slice, slice]
        Location to pad template after weighting and normalization
    """
    if self.plan.wedge is not None:
        # Add wedge to the template after rotating; multiplying in Fourier
        # space applies the missing-wedge weighting to the rotated template
        self.plan.template = irfftn(
            rfftn(self.plan.template) * self.plan.wedge,
            s=self.plan.template.shape
        )

    # Normalize (zero mean, unit std under the mask) and mask the template
    mean = mean_under_mask(self.plan.template, self.plan.mask,
                           mask_weight=self.plan.mask_weight)
    std = std_under_mask(self.plan.template, self.plan.mask, mean,
                         mask_weight=self.plan.mask_weight)
    self.plan.template = ((self.plan.template - mean) / std) * self.plan.mask

    # Paste in center so the FFT shape matches the search volume
    self.plan.template_padded[padding_index] = self.plan.template

    # Fast local correlation function between volume and template, norm is the standard deviation at each
    # point in the volume in the masked area; the conjugated volume FT is
    # precomputed so only the template FFT is needed per rotation
    self.plan.ccc_map = (
        irfftn(self.plan.volume_rft_conj * rfftn(self.plan.template_padded),
               s=self.plan.template_padded.shape)
        / self.plan.std_volume
    )


def std_under_mask_convolution(
volume_rft_conj: cpt.NDArray[float],
Expand Down Expand Up @@ -321,13 +393,22 @@ def std_under_mask_convolution(

"""Update scores and angles if a new maximum is found."""
update_results_kernel = cp.ElementwiseKernel(
'float32 scores, float32 ccc_map, float32 angle_id',
'float32 out1, float32 out2',
'if (scores < ccc_map) {out1 = ccc_map; out2 = angle_id;}',
'float32 scores, float32 ccc_map, float32 angles',
'float32 scores_out, float32 angles_out',
'if (scores < ccc_map) {scores_out = ccc_map; angles_out = angles;}',
'update_results'
)


"""Update scores for noise template"""
# Elementwise GPU kernel tracking the running maximum of the phase-randomized
# (noise) template's correlation scores. The call site passes the same array
# as both 'scores' input and 'scores_out' output, so non-improving voxels
# retain their previous maximum.
update_noise_template_results_kernel = cp.ElementwiseKernel(
    'float32 scores, float32 ccc_map',
    'float32 scores_out',
    'if (scores < ccc_map) {scores_out = ccc_map;}',
    'update_noise_template_results'
)


"""Calculate the sum of squares in a volume. Mean is assumed to be 0 which makes this operation a lot faster."""
square_sum_kernel = cp.ReductionKernel(
'T x', # input params
Expand Down
47 changes: 46 additions & 1 deletion src/pytom_tm/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from scipy.ndimage import center_of_mass, zoom
from scipy.fft import rfftn, irfftn
from typing import Optional
from pytom_tm.weights import create_gaussian_low_pass
from pytom_tm.weights import (
create_ctf,
create_gaussian_low_pass,
radial_average,
radial_reduced_grid,
)


def generate_template_from_map(
Expand Down Expand Up @@ -103,3 +108,43 @@ def generate_template_from_map(
irfftn(rfftn(input_map) * lpf, s=input_map.shape),
input_spacing / output_spacing
)


def phase_randomize_template(
    template: npt.NDArray[float],
    seed: int = 321,
):
    """Build a noise version of the template by shuffling its Fourier phases.

    The amplitude spectrum of the input is left intact while the phases of
    all frequencies up to Nyquist are randomly permuted with a seeded RNG,
    yielding a structure-free volume with the same power spectrum.

    Parameters
    ----------
    template: npt.NDArray[float]
        input structure
    seed: int, default 321
        seed for random number generator for phase permutation

    Returns
    -------
    result: npt.NDArray[float]
        phase randomized version of the template
    """
    fourier_space = rfftn(template)
    amplitude = np.abs(fourier_space)
    phase = np.angle(fourier_space).flatten()

    # Frequency grid in Fourier-reduced form; shift the two full axes back so
    # the grid layout matches the unshifted rfftn output.
    frequencies = np.fft.ifftshift(
        radial_reduced_grid(template.shape), axes=(0, 1)
    ).flatten()
    below_nyquist = frequencies <= 1  # permute only up to Nyquist

    # Shuffle the selected phases with a seeded generator for reproducibility;
    # phases beyond Nyquist are zeroed.
    rng = np.random.default_rng(seed)
    randomized_phase = np.zeros_like(phase)
    randomized_phase[below_nyquist] = rng.permutation(phase[below_nyquist])

    # Recombine original amplitudes with the permuted phases and invert.
    return irfftn(
        amplitude * np.exp(1j * np.reshape(randomized_phase, amplitude.shape)),
        s=template.shape,
    )
18 changes: 15 additions & 3 deletions src/pytom_tm/tmjob.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
from importlib import metadata
from packaging import version
import pathlib
import os
Expand Down Expand Up @@ -66,6 +65,8 @@ def load_json_to_tmjob(file_name: pathlib.Path, load_for_extraction: bool = True
pytom_tm_version_number=data.get('pytom_tm_version_number', '0.3.0'),
job_loaded_for_extraction=load_for_extraction,
particle_diameter=data.get('particle_diameter', None),
random_phase_correction=data.get('random_phase_correction', False),
rng_seed=data.get('rng_seed', 321),
)
# if the file originates from an old version set the phase shift for compatibility
if (
Expand Down Expand Up @@ -212,6 +213,8 @@ def __init__(
pytom_tm_version_number: str = PYTOM_TM_VERSION,
job_loaded_for_extraction: bool = False,
particle_diameter: Optional[float] = None,
random_phase_correction: bool = False,
rng_seed: int = 321,
):
"""
Parameters
Expand Down Expand Up @@ -263,9 +266,12 @@ def __init__(
a string with the version number of pytom_tm for backward compatibility
job_loaded_for_extraction: bool, default False
flag to set for finished template matching jobs that are loaded back for extraction, it prevents
recalculation of the whitening filter which is unnecessary at this stage
particle_diameter: Optional[float], default None
particle diameter (in Angstrom) to calculate angular search
random_phase_correction: bool, default False,
run matching with a phase randomized version of the template to correct scores for noise
rng_seed: int, default 321
set a seed for the rng for phase randomization
"""
self.mask = mask
self.mask_is_spherical = mask_is_spherical
Expand Down Expand Up @@ -385,6 +391,10 @@ def __init__(
weights /= weights.max() # scale to 1
np.save(self.whitening_filter, weights)

# phase randomization options
self.random_phase_correction = random_phase_correction
self.rng_seed = rng_seed

# Job details
self.job_key = job_key
self.leader = None # the job that spawned this job
Expand Down Expand Up @@ -748,7 +758,9 @@ def start_job(
angle_ids=angle_ids,
mask_is_spherical=self.mask_is_spherical,
wedge=template_wedge,
stats_roi=search_volume_roi
stats_roi=search_volume_roi,
noise_correction=self.random_phase_correction,
rng_seed=self.rng_seed,
)
results = tm.run()
score_volume = results[0][:self.search_size[0], :self.search_size[1], :self.search_size[2]]
Expand Down
Loading