Skip to content

Commit

Permalink
version 0.3.1: job annotation with module version number, symmetry pa…
Browse files Browse the repository at this point in the history
…tch, and whitening filter calculation check (#57)

* apply rotation symmetry in extraction threshold estimation

* update to 0.3.1 for bug fix

* angle sorting is crucial for symmetry

* sorting key is not neccesary + sort after checking if list is valid

* fix unittests after new sorting

* remove print line

* added note to README about rotational symmetry update

* rotational symmetry only possible around z-axis

* Update pytom_match_template.py

* remove readme warning

* added version to the job and made version loading more dynamic

* add sorting option to angles based on version number; whitening filter is now only calcualted if not detected in the output dir, otherwise neesds to be recalculated for every job init

* extract.py also needs to import the version module from packaging

* switched to importlib for detecting version

* Update src/pytom_tm/angles.py

Co-authored-by: Sander Roet <[email protected]>

---------

Co-authored-by: Sander Roet <[email protected]>
  • Loading branch information
McHaillet and sroet authored Nov 3, 2023
1 parent 0fb6ce1 commit 5dbc519
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 22 deletions.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
name='pytom-template-matching-gpu',
packages=['pytom_tm', 'pytom_tm.angle_lists'],
package_dir={'': 'src'},
version='0.3.0', # for versioning definition see https://semver.org/
version='0.3.1', # for versioning definition see https://semver.org/
description='GPU template matching from PyTOM as a lightweight pip package',
long_description=long_description,
long_description_content_type='text/markdown',
Expand All @@ -24,7 +24,8 @@
'tqdm',
'mrcfile',
'starfile',
'importlib_resources'
'importlib_resources',
'packaging',
],
extras_require={
'plotting': ['matplotlib', 'seaborn']
Expand Down
9 changes: 5 additions & 4 deletions src/bin/pytom_match_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def main():
'Alternatively, a .txt file can be provided with three Euler angles (in radians) per '
'line that define the angular search. Angle format is ZXZ anti-clockwise (see: '
'https://www.ccpem.ac.uk/user_help/rotation_conventions.php).')
parser.add_argument('--rotational-symmetry', type=int, required=False, action=LargerThanZero, default=1,
help='Integer value indicating the rotational symmetry of the template. The length of the '
'rotation search will be shortened through division by this value.')
parser.add_argument('--z-axis-rotational-symmetry', type=int, required=False, action=LargerThanZero, default=1,
help='Integer value indicating the rotational symmetry of the template around the z-axis. The '
'length of the rotation search will be shortened through division by this value. Only '
'works for template symmetry around the z-axis.')
parser.add_argument('-s', '--volume-split', nargs=3, type=int, required=False, default=[1, 1, 1],
help='Split the volume into smaller parts for the search, can be relevant if the volume does '
'not fit into GPU memory. Format is x y z, e.g. --volume-split 1 2 1')
Expand Down Expand Up @@ -125,7 +126,7 @@ def main():
dose_accumulation=args.dose_accumulation,
ctf_data=ctf_params,
whiten_spectrum=args.spectral_whitening,
rotational_symmetry=args.rotational_symmetry,
rotational_symmetry=args.z_axis_rotational_symmetry,
)

score_volume, angle_volume = run_job_parallel(job, tuple(args.volume_split), args.gpu_ids)
Expand Down
3 changes: 3 additions & 0 deletions src/pytom_tm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from importlib import metadata
__version__ = metadata.version('pytom-template-matching-gpu')

try:
import cupy
except (ModuleNotFoundError, ImportError):
Expand Down
7 changes: 4 additions & 3 deletions src/pytom_tm/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
v[0] = ANGLE_LIST_DIR.joinpath(v[0])


def load_angle_list(file_name: pathlib.Path) -> list[tuple[float, float, float]]:
def load_angle_list(file_name: pathlib.Path, sort_angles: bool = True) -> list[tuple[float, float, float]]:
with open(str(file_name)) as fstream:
lines = fstream.readlines()
angle_list = [tuple(map(float, x.strip().split(' '))) for x in lines]
if not all([len(a) == 3 for a in angle_list]):
raise ValueError('Invalid angle file provided, each line should have 3 ZXZ Euler angles!')
else:
return angle_list
if sort_angles:
angle_list.sort() # angle list needs to be sorted otherwise symmetry reduction cannot be used!
return angle_list


def convert_euler(
Expand Down
8 changes: 6 additions & 2 deletions src/pytom_tm/extract.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from packaging import version
import pandas as pd
import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -36,7 +37,10 @@ def extract_particles(

score_volume = read_mrc(job.output_dir.joinpath(f'{job.tomo_id}_scores.mrc'))
angle_volume = read_mrc(job.output_dir.joinpath(f'{job.tomo_id}_angles.mrc'))
angle_list = load_angle_list(job.rotation_file)
angle_list = load_angle_list(
job.rotation_file,
sort_angles=version.parse(job.pytom_tm_version_number) > version.parse('0.3.0')
)

# mask edges of score volume
score_volume[0: particle_radius_px, :, :] = -1
Expand All @@ -61,7 +65,7 @@ def extract_particles(
search_space = (
# wherever the score volume has not been explicitly set to -1 is the size of the search region
(score_volume > -1).sum() *
job.n_rotations
int(np.ceil(job.n_rotations / job.rotational_symmetry))
)
cut_off = erfcinv((2 * n_false_positives) / search_space) * np.sqrt(2) * sigma
logging.info(f'cut off for particle extraction: {cut_off}')
Expand Down
25 changes: 18 additions & 7 deletions src/pytom_tm/tmjob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations
from importlib import metadata
from packaging import version
import pathlib
import copy
import numpy as np
Expand Down Expand Up @@ -39,6 +41,8 @@ def load_json_to_tmjob(file_name: pathlib.Path) -> TMJob:
ctf_data=data.get('ctf_data', None),
whiten_spectrum=data.get('whiten_spectrum', False),
rotational_symmetry=data.get('rotational_symmetry', 1),
# if version number is not in the .json, it must be 0.3.0 or older
pytom_tm_version_number=data.get('pytom_tm_version_number', '0.3.0'),
)
job.rotation_file = pathlib.Path(data['rotation_file'])
job.whole_start = data['whole_start']
Expand Down Expand Up @@ -79,7 +83,8 @@ def __init__(
dose_accumulation: Optional[list[float, ...]] = None,
ctf_data: Optional[list[dict, ...]] = None,
whiten_spectrum: bool = False,
rotational_symmetry: int = 1
rotational_symmetry: int = 1,
pytom_tm_version_number: str = metadata.version('pytom-template-matching-gpu')
):
self.mask = mask
self.mask_is_spherical = mask_is_spherical
Expand Down Expand Up @@ -171,12 +176,12 @@ def __init__(
self.dose_accumulation = dose_accumulation
self.ctf_data = ctf_data
self.whiten_spectrum = whiten_spectrum

if self.whiten_spectrum:
self.whitening_filter = self.output_dir.joinpath(f'{self.tomo_id}_whitening_filter.npy')
if self.whiten_spectrum and not self.whitening_filter.exists():
logging.info('Estimating whitening filter...')
weights = 1 / np.sqrt(power_spectrum_profile(read_mrc(self.tomogram)))
weights /= weights.max() # scale to 1
np.save(self.output_dir.joinpath('whitening_filter.npy'), weights)
np.save(self.whitening_filter, weights)

# Job details
self.job_key = job_key
Expand All @@ -188,6 +193,9 @@ def __init__(

self.log_level = log_level

# version number of the job
self.pytom_tm_version_number = pytom_tm_version_number

def copy(self) -> TMJob:
return copy.deepcopy(self)

Expand Down Expand Up @@ -395,7 +403,7 @@ def start_job(
self.low_pass,
self.high_pass
) * (profile_to_weighting(
np.load(self.output_dir.joinpath('whitening_filter.npy')),
np.load(self.whitening_filter),
search_volume.shape
) if self.whiten_spectrum else 1)).astype(np.float32)

Expand All @@ -415,7 +423,7 @@ def start_job(
accumulated_dose_per_tilt=self.dose_accumulation,
ctf_params_per_tilt=self.ctf_data
) * (profile_to_weighting(
np.load(self.output_dir.joinpath('whitening_filter.npy')),
np.load(self.whitening_filter),
self.template_shape
) if self.whiten_spectrum else 1)).astype(np.float32)

Expand Down Expand Up @@ -447,7 +455,10 @@ def start_job(
int(np.ceil(self.n_rotations / self.rotational_symmetry)),
self.steps_slice
))
angle_list = load_angle_list(self.rotation_file)[slice(
angle_list = load_angle_list(
self.rotation_file,
sort_angles=version.parse(self.pytom_tm_version_number) > version.parse('0.3.0')
)[slice(
self.start_slice,
int(np.ceil(self.n_rotations / self.rotational_symmetry)),
self.steps_slice
Expand Down
2 changes: 1 addition & 1 deletion tests/test_template_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_search(self):
self.assertEqual(angle_id, angle_volume[ind])
self.assertSequenceEqual(loc, ind)
self.assertEqual(stats['search_space'], 256000000, msg='Search space should exactly equal this value')
self.assertAlmostEqual(stats['std'], 0.005175, places=6,
self.assertAlmostEqual(stats['std'], 0.005175, places=5,
msg='Standard deviation of the search should be almost equal')


Expand Down
6 changes: 3 additions & 3 deletions tests/test_tmjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_tm_job_split_volume(self):
score, angle = self.job.merge_sub_jobs()
ind = np.unravel_index(score.argmax(), score.shape)

self.assertTrue(score.max() > 0.934, msg='lcc max value lower than expected')
self.assertTrue(score.max() > 0.931, msg='lcc max value lower than expected')
self.assertEqual(ANGLE_ID, angle[ind])
self.assertSequenceEqual(LOCATION, ind)

Expand Down Expand Up @@ -204,7 +204,7 @@ def test_tm_job_split_angles(self):
score, angle = self.job.merge_sub_jobs()
ind = np.unravel_index(score.argmax(), score.shape)

self.assertTrue(score.max() > 0.934, msg='lcc max value lower than expected')
self.assertTrue(score.max() > 0.931, msg='lcc max value lower than expected')
self.assertEqual(ANGLE_ID, angle[ind])
self.assertSequenceEqual(LOCATION, ind)

Expand All @@ -217,7 +217,7 @@ def test_parallel_manager(self):
score, angle = run_job_parallel(self.job, volume_splits=(1, 3, 1), gpu_ids=[0])
ind = np.unravel_index(score.argmax(), score.shape)

self.assertTrue(score.max() > 0.934, msg='lcc max value lower than expected')
self.assertTrue(score.max() > 0.931, msg='lcc max value lower than expected')
self.assertEqual(ANGLE_ID, angle[ind])
self.assertSequenceEqual(LOCATION, ind)

Expand Down

0 comments on commit 5dbc519

Please sign in to comment.