diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 2e693b4e..a00fe57b 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -44,7 +44,7 @@ jobs: path: coverage.xml repo_token: ${{ secrets.GITHUB_TOKEN }} pull_request_number: ${{ steps.get-pr.outputs.PR }} - minimum_coverage: 79 + minimum_coverage: 80 show_missing: True fail_below_threshold: True link_missing_lines: True diff --git a/README.md b/README.md index e5480da2..aceb7916 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,10 @@ ![test-badge](https://github.com/SBC-Utrecht/pytom-match-pick/actions/workflows/unit-tests.yml/badge.svg?branch=main) -# GPU template matching for cryo-ET +# pytom-match-pick: GPU template matching for cryo-ET GPU template matching, originally developed in [PyTom](https://github.com/SBC-Utrecht/PyTom), as a standalone python package that can be run from the command line. +![cover_image](images/tomo200528_100_illustration.png) + ## Requires ``` @@ -80,6 +82,13 @@ The following scripts are available to run with `--help` to see parameters: - estimate an ROC curve from a job file (.json): `pytom_estimate_roc.py --help` - merge multiple star files to a single starfile: `pytom_merge_stars.py --help` +## Usage questions, ideas and solutions, engagement, etc +Please use our [github discussions](https://github.com/SBC-Utrecht/pytom-match-pick/discussions) for: + - Asking questions about bottlenecks. + - Share ideas and solutions. + - Engage with other community members about processing strategies. + - etc... + ## Developer install If you want the most up-to-date version of the code you can get install it from this repository via: diff --git a/images/tomo200528_100_illustration.png b/images/tomo200528_100_illustration.png new file mode 100644 index 00000000..16b9032e Binary files /dev/null and b/images/tomo200528_100_illustration.png differ diff --git a/pyproject.toml b/pyproject.toml index 7d2534f8..00b01df2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pytom-match-pick" -version = "0.5.0" +version = "0.7.0" description = "PyTOM's GPU template matching module as an independent package" readme = "README.md" license = {file = "LICENSE"} @@ -26,7 +26,7 @@ classifiers = [ requires-python = ">= 3.9" dependencies = [ "numpy", - "cupy", + "cupy!=13.0.*", #see https://github.com/SBC-Utrecht/pytom-match-pick/issues/106 "voltools", "tqdm", "mrcfile", diff --git a/src/pytom_tm/entry_points.py b/src/pytom_tm/entry_points.py index 91db4441..4b5e45f9 100644 --- a/src/pytom_tm/entry_points.py +++ b/src/pytom_tm/entry_points.py @@ -212,6 +212,13 @@ def pytom_create_template(argv=None): default=2.7, help="Spherical aberration in mm.", ) + parser.add_argument( + "--phase-shift", + type=float, + required=False, + default=.0, + help="Phase shift (in degrees) for the CTF to model phase plates.", + ) parser.add_argument( "--cut-after-first-zero", action="store_true", @@ -324,6 +331,7 @@ def pytom_create_template(argv=None): "spherical_aberration": args.Cs * 1e-3, "cut_after_first_zero": args.cut_after_first_zero, "flip_phase": args.flip_phase, + "phase_shift_deg": args.phase_shift, } template = generate_template_from_map( @@ -433,7 +441,8 @@ def estimate_roc(argv=None): template_matching_job = load_json_to_tmjob(args.job_file) # Set cut off to -1 to ensure the number of particles gets extracted _, lcc_max_values = extract_particles( - template_matching_job, args.radius_px, args.number_of_particles, cut_off=0 + template_matching_job, args.radius_px, args.number_of_particles, cut_off=0, + create_plot=False ) score_volume = read_mrc( @@ -581,7 +590,8 @@ def match_template(argv=None): parser = argparse.ArgumentParser( description="Run template matching. -- Marten Chaillet (@McHaillet)" ) - parser.add_argument( + io_group = parser.add_argument_group("Template, search volume, and output") + io_group.add_argument( "-t", "--template", type=pathlib.Path, @@ -589,22 +599,7 @@ def match_template(argv=None): action=CheckFileExists, help="Template; MRC file.", ) - parser.add_argument( - "-m", - "--mask", - type=pathlib.Path, - required=True, - action=CheckFileExists, - help="Mask with same box size as template; MRC file.", - ) - parser.add_argument( - "--non-spherical-mask", - action="store_true", - required=False, - help="Flag to set when the mask is not spherical. It adds the required " - "computations for non-spherical masks and roughly doubles computation time.", - ) - parser.add_argument( + io_group.add_argument( "-v", "--tomogram", type=pathlib.Path, @@ -612,7 +607,7 @@ def match_template(argv=None): action=CheckFileExists, help="Tomographic volume; MRC file.", ) - parser.add_argument( + io_group.add_argument( "-d", "--destination", type=pathlib.Path, @@ -621,52 +616,46 @@ def match_template(argv=None): action=CheckDirExists, help="Folder to store the files produced by template matching.", ) - parser.add_argument( - "-a", - "--tilt-angles", - nargs="+", - type=str, + mask_group = parser.add_argument_group("Mask") + mask_group.add_argument( + "-m", + "--mask", + type=pathlib.Path, required=True, - action=ParseTiltAngles, - help="Tilt angles of the tilt-series, either the minimum and maximum values of " - "the tilts (e.g. --tilt-angles -59.1 60.1) or a .rawtlt/.tlt file with all the " - "angles (e.g. --tilt-angles tomo101.rawtlt). In case all the tilt angles are " - "provided a more elaborate Fourier space constraint can be used", + action=CheckFileExists, + help="Mask with same box size as template; MRC file.", ) - parser.add_argument( - "--per-tilt-weighting", + mask_group.add_argument( + "--non-spherical-mask", action="store_true", - default=False, required=False, - help="Flag to activate per-tilt-weighting, only makes sense if a file with all " - "tilt angles have been provided. In case not set, while a tilt angle file is " - "provided, the minimum and maximum tilt angle are used to create a binary " - "wedge. The base functionality creates a fanned wedge where each tilt is " - "weighted by cos(tilt_angle). If dose accumulation and CTF parameters are " - "provided these will all be incorporated in the tilt-weighting.", + help="Flag to set when the mask is not spherical. It adds the required " + "computations for non-spherical masks and roughly doubles computation time.", ) - parser.add_argument( + rotation_group = parser.add_argument_group('Angular search') + rotation_group.add_argument( "--angular-search", type=str, required=True, help="Options are: [7.00, 35.76, 19.95, 90.00, 18.00, " - "12.85, 38.53, 11.00, 17.86, 25.25, 50.00, 3.00].\n" - "Alternatively, a .txt file can be provided with three Euler angles " - "(in radians) per line that define the angular search. " - "Angle format is ZXZ anti-clockwise (see: " - "https://www.ccpem.ac.uk/user_help/rotation_conventions.php).", + "12.85, 38.53, 11.00, 17.86, 25.25, 50.00, 3.00].\n" + "Alternatively, a .txt file can be provided with three Euler angles " + "(in radians) per line that define the angular search. " + "Angle format is ZXZ anti-clockwise (see: " + "https://www.ccpem.ac.uk/user_help/rotation_conventions.php).", ) - parser.add_argument( + rotation_group.add_argument( "--z-axis-rotational-symmetry", type=int, required=False, action=LargerThanZero, default=1, help="Integer value indicating the rotational symmetry of the template around " - "the z-axis. The length of the rotation search will be shortened through " - "division by this value. Only works for template symmetry around the z-axis.", + "the z-axis. The length of the rotation search will be shortened through " + "division by this value. Only works for template symmetry around the z-axis.", ) - parser.add_argument( + volume_group = parser.add_argument_group('Volume control') + volume_group.add_argument( "-s", "--volume-split", nargs=3, @@ -677,7 +666,7 @@ def match_template(argv=None): "can be relevant if the volume does not fit into GPU memory. " "Format is x y z, e.g. --volume-split 1 2 1", ) - parser.add_argument( + volume_group.add_argument( "--search-x", nargs=2, type=int, @@ -686,7 +675,7 @@ def match_template(argv=None): help="Start and end indices of the search along the x-axis, " "e.g. --search-x 10 490 ", ) - parser.add_argument( + volume_group.add_argument( "--search-y", nargs=2, type=int, @@ -695,7 +684,7 @@ def match_template(argv=None): help="Start and end indices of the search along the y-axis, " "e.g. --search-x 10 490 ", ) - parser.add_argument( + volume_group.add_argument( "--search-z", nargs=2, type=int, @@ -704,7 +693,32 @@ def match_template(argv=None): help="Start and end indices of the search along the z-axis, " "e.g. --search-x 30 230 ", ) - parser.add_argument( + filter_group = parser.add_argument_group('Filter control') + filter_group.add_argument( + "-a", + "--tilt-angles", + nargs="+", + type=str, + required=True, + action=ParseTiltAngles, + help="Tilt angles of the tilt-series, either the minimum and maximum values of " + "the tilts (e.g. --tilt-angles -59.1 60.1) or a .rawtlt/.tlt file with all the " + "angles (e.g. --tilt-angles tomo101.rawtlt). In case all the tilt angles are " + "provided a more elaborate Fourier space constraint can be used", + ) + filter_group.add_argument( + "--per-tilt-weighting", + action="store_true", + default=False, + required=False, + help="Flag to activate per-tilt-weighting, only makes sense if a file with all " + "tilt angles have been provided. In case not set, while a tilt angle file is " + "provided, the minimum and maximum tilt angle are used to create a binary " + "wedge. The base functionality creates a fanned wedge where each tilt is " + "weighted by cos(tilt_angle). If dose accumulation and CTF parameters are " + "provided these will all be incorporated in the tilt-weighting.", + ) + filter_group.add_argument( "--voxel-size-angstrom", type=float, required=False, @@ -713,7 +727,7 @@ def match_template(argv=None): "try to read from the MRC files. Argument is important for band-pass " "filtering!", ) - parser.add_argument( + filter_group.add_argument( "--low-pass", type=float, required=False, @@ -722,7 +736,7 @@ def match_template(argv=None): "if the template was already filtered to a certain resolution. " "Value is the resolution in A.", ) - parser.add_argument( + filter_group.add_argument( "--high-pass", type=float, required=False, @@ -732,7 +746,7 @@ def match_template(argv=None): "e.g. 500 could be appropriate as the CTF is often incorrectly modelled " "up to 50nm.", ) - parser.add_argument( + filter_group.add_argument( "--dose-accumulation", type=str, required=False, @@ -741,7 +755,7 @@ def match_template(argv=None): "tilt angle, assuming the same ordering of tilts as the tilt angle file. " "Format should be a .txt file with on each line a dose value in e-/A2.", ) - parser.add_argument( + filter_group.add_argument( "--defocus-file", type=str, required=False, @@ -755,28 +769,36 @@ def match_template(argv=None): "same ordering as tilt angle list. The .txt file should contain a single " "defocus value (in nm) per line.", ) - parser.add_argument( + filter_group.add_argument( "--amplitude-contrast", type=float, required=False, action=BetweenZeroAndOne, help="Amplitude contrast fraction for CTF.", ) - parser.add_argument( + filter_group.add_argument( "--spherical-abberation", type=float, required=False, action=LargerThanZero, help="Spherical abberation for CTF in mm.", ) - parser.add_argument( + filter_group.add_argument( "--voltage", type=float, required=False, action=LargerThanZero, help="Voltage for CTF in keV.", ) - parser.add_argument( + filter_group.add_argument( + "--phase-shift", + type=float, + required=False, + default=.0, + action=LargerThanZero, + help="Phase shift (in degrees) for the CTF to model phase plates.", + ) + filter_group.add_argument( "--spectral-whitening", action="store_true", default=False, @@ -785,7 +807,8 @@ def match_template(argv=None): "apply it to the tomogram patch and template. Effectively puts more weight on " "high resolution features and sharpens the correlation peaks.", ) - parser.add_argument( + device_group = parser.add_argument_group('Device control') + device_group.add_argument( "-g", "--gpu-ids", nargs="+", @@ -793,7 +816,8 @@ def match_template(argv=None): required=True, help="GPU indices to run the program on.", ) - parser.add_argument( + debug_group = parser.add_argument_group('Logging/debugging') + debug_group.add_argument( "--log", type=str, required=False, @@ -823,6 +847,7 @@ def match_template(argv=None): "amplitude": args.amplitude_contrast, "voltage": args.voltage, "cs": args.spherical_abberation, + "phase_shift_deg": args.phase_shift, } for defocus in args.defocus_file ] diff --git a/src/pytom_tm/extract.py b/src/pytom_tm/extract.py index cf8cbd29..30cf4a03 100644 --- a/src/pytom_tm/extract.py +++ b/src/pytom_tm/extract.py @@ -267,17 +267,17 @@ def extract_particles( ] *= cut_mask output = pd.DataFrame(data, columns=[ - 'ptmCoordinateX', - 'ptmCoordinateY', - 'ptmCoordinateZ', - 'ptmAngleRot', - 'ptmAngleTilt', - 'ptmAnglePsi', - 'ptmLCCmax', - 'ptmCutOff', - 'ptmSearchStd', - 'ptmDetectorPixelSize', - 'ptmMicrographName', + 'rlnCoordinateX', + 'rlnCoordinateY', + 'rlnCoordinateZ', + 'rlnAngleRot', + 'rlnAngleTilt', + 'rlnAnglePsi', + 'rlnLCCmax', + 'rlnCutOff', + 'rlnSearchStd', + 'rlnDetectorPixelSize', + 'rlnMicrographName', ]), scores if plotting_available and create_plot: diff --git a/src/pytom_tm/io.py b/src/pytom_tm/io.py index b580df32..951a5dbd 100644 --- a/src/pytom_tm/io.py +++ b/src/pytom_tm/io.py @@ -4,6 +4,7 @@ import logging import numpy.typing as npt import numpy as np +from contextlib import contextmanager from operator import attrgetter from typing import Optional, Union @@ -130,7 +131,29 @@ def write_angle_list(data: npt.NDArray[float], file_name: pathlib.Path, order: t fstream.write(' '.join([str(x) for x in [data[j, i] for j in order]]) + '\n') -def read_mrc_meta_data(file_name: pathlib.Path, permissive: bool = True) -> dict: +@contextmanager +def _wrap_mrcfile_readers(func, *args, **kwargs): + """Try to autorecover broken mrcfiles, assumes 'permissive' is a kwarg and not an arg""" + try: + mrc = func(*args, **kwargs) + except ValueError as err: + # see if permissive can safe this + logging.debug(f"mrcfile raised the following error: {err}, will try to recover") + kwargs['permissive']=True + mrc = func(*args, **kwargs) + if mrc.data is not None: + logging.warning(f"Loading {args[0]} in strict mode gave an error. " + "However, loading with 'permissive=True' did generate data, make sure this is correct!") + else: + logging.debug(f"Could not reasonably recover") + raise ValueError( + f"{args[0]} header or data is too corrupt to recover, please fix the header or data" + ) from err + yield mrc + # this should only be called after the context exists + mrc.close() + +def read_mrc_meta_data(file_name: pathlib.Path) -> dict: """Read the metadata of provided MRC file path (using mrcfile) and return as dict. If the voxel size along the x,y,and z dimensions differs a lot (not within 3 decimals) the function will raise an @@ -140,8 +163,6 @@ def read_mrc_meta_data(file_name: pathlib.Path, permissive: bool = True) -> dict ---------- file_name: pathlib.Path path to an MRC file - permissive: bool - True (default) reads file in permissive mode, setting to False will be more strict with bad MRC headers Returns ------- @@ -150,7 +171,7 @@ def read_mrc_meta_data(file_name: pathlib.Path, permissive: bool = True) -> dict 'voxel_size' containing the voxel size along x,y,z and dimensions in A units """ meta_data = {} - with mrcfile.mmap(file_name, permissive=permissive) as mrc: + with _wrap_mrcfile_readers(mrcfile.mmap, file_name) as mrc: meta_data['shape'] = tuple(map(int, attrgetter('nx', 'ny', 'nz')(mrc.header))) # allow small numerical inconsistencies in voxel size of MRC headers, sometimes seen in Warp if not all( @@ -201,7 +222,6 @@ def write_mrc( def read_mrc( file_name: pathlib.Path, - permissive: bool = True, transpose: bool = True ) -> npt.NDArray[float]: """Read an MRC file from disk. Data is transposed after reading as pytom internally uses xyz ordering and MRCs @@ -211,8 +231,6 @@ def read_mrc( ---------- file_name: pathlib.Path path to file on disk - permissive: bool, default True - True (default) reads file in permissive mode, setting to False will be more strict with bad headers transpose: bool, default True True (default) transposes the volume after reading, setting to False prevents transpose but probably not a good idea when using the functions from this module @@ -222,7 +240,7 @@ def read_mrc( data: npt.NDArray[float] returns the MRC data as a numpy array """ - with mrcfile.open(file_name, permissive=permissive) as mrc: + with _wrap_mrcfile_readers(mrcfile.open, file_name) as mrc: data = np.ascontiguousarray(mrc.data.T) if transpose else mrc.data return data diff --git a/src/pytom_tm/matching.py b/src/pytom_tm/matching.py index ea2c83a6..a41baf07 100644 --- a/src/pytom_tm/matching.py +++ b/src/pytom_tm/matching.py @@ -36,32 +36,34 @@ def __init__( reduced form, with dimensions (sx, sx, sx // 2 + 1) """ # Search volume + and fft transform plan for the volume - self.volume = cp.asarray(volume, dtype=cp.float32, order='C') - self.volume_rft = rfftn(self.volume) + volume_shape = volume.shape + cp_vol = cp.asarray(volume, dtype=cp.float32, order='C') + self.volume_rft_conj = rfftn(cp_vol).conj() + self.volume_sq_rft_conj = rfftn(cp_vol ** 2).conj() # Explicit fft plan is no longer necessary as cupy generates a plan behind the scene which leads to # comparable timings # Array for storing local standard deviations - self.std_volume = cp.zeros_like(volume, dtype=cp.float32) + self.std_volume = cp.zeros(volume_shape, dtype=cp.float32) # Data for the mask self.mask = cp.asarray(mask, dtype=cp.float32, order='C') self.mask_texture = vt.StaticVolume(self.mask, interpolation='filt_bspline', device=f'gpu:{device_id}') - self.mask_padded = cp.zeros_like(self.volume).astype(cp.float32) + self.mask_padded = cp.zeros(volume_shape, dtype=cp.float32) self.mask_weight = self.mask.sum() # weight of the mask # Init template data self.template = cp.asarray(template, dtype=cp.float32, order='C') self.template_texture = vt.StaticVolume(self.template, interpolation='filt_bspline', device=f'gpu:{device_id}') - self.template_padded = cp.zeros_like(self.volume) + self.template_padded = cp.zeros(volume_shape, dtype=cp.float32) # fourier binary wedge weight for the template self.wedge = cp.asarray(wedge, order='C', dtype=cp.float32) if wedge is not None else None # Initialize result volumes - self.ccc_map = cp.zeros_like(self.volume) - self.scores = cp.ones_like(self.volume)*-1000 - self.angles = cp.ones_like(self.volume)*-1000 + self.ccc_map = cp.zeros(volume_shape, dtype=cp.float32) + self.scores = cp.ones(volume_shape, dtype=cp.float32)*-1000 + self.angles = cp.ones(volume_shape, dtype=cp.float32)*-1000 # wait for stream to complete the work cp.cuda.stream.get_current_stream().synchronize() @@ -69,9 +71,11 @@ def __init__( def clean(self) -> None: """Remove all stored cupy arrays from the GPU's memory pool.""" gpu_memory_pool = cp.get_default_memory_pool() - del self.volume, self.volume_rft, self.mask, self.mask_texture, self.mask_padded, self.template, ( - self.template_texture), self.template_padded, self.wedge, self.ccc_map, self.scores, self.angles, ( - self.std_volume) + del ( + self.volume_rft_conj, self.volume_sq_rft_conj, self.mask, self.mask_texture, self.mask_padded, + self.template, self.template_texture, self.template_padded, self.wedge, self.ccc_map, self.scores, + self.angles, self.std_volume + ) gc.collect() gpu_memory_pool.free_all_blocks() @@ -92,6 +96,13 @@ def __init__( ): """Initialize a template matching run. + For other great implementations see: + - STOPGAP: https://github.com/wan-lab-vanderbilt/STOPGAP + - pyTME: https://github.com/KosinskiLab/pyTME + + The precalculation of conjugated FTs of the tomo was (AFAIK) introduced + by STOPGAP! + Parameters ---------- job_id: str @@ -162,19 +173,28 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: sxv, syv, szv = self.plan.template_padded.shape cxv, cyv, czv = sxv // 2, syv // 2, szv // 2 - # calculate roi size - roi_size = self.plan.volume[self.stats_roi].size + # create slice for padding + pad_index = ( + slice(cxv - cxt, cxv + cxt + mx), + slice(cyv - cyt, cyv + cyt + my), + slice(czv - czt, czv + czt + mz), + ) + + # calculate roi mask + shift = cp.floor(cp.array(self.plan.scores.shape) / 2).astype(int) + 1 + roi_mask = cp.zeros(self.plan.scores.shape, dtype=bool) + roi_mask[self.stats_roi] = True + roi_mask = cp.flip(cp.roll(roi_mask, -shift, (0, 1, 2))) + roi_size = self.plan.scores[roi_mask].size if self.mask_is_spherical: # Then we only need to calculate std volume once - self.plan.mask_padded[cxv - cxt:cxv + cxt + mx, - cyv - cyt:cyv + cyt + my, - czv - czt:czv + czt + mz] = self.plan.mask + self.plan.mask_padded[pad_index] = self.plan.mask self.plan.std_volume = std_under_mask_convolution( - self.plan.volume, + self.plan.volume_rft_conj, + self.plan.volume_sq_rft_conj, self.plan.mask_padded, self.plan.mask_weight, - volume_rft=self.plan.volume_rft - ) + ) * self.plan.mask_weight # Track iterations with a tqdm progress bar for i in tqdm(range(len(self.angle_ids))): @@ -189,16 +209,14 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: output=self.plan.mask, rotation_units='rad' ) - self.plan.mask_padded[cxv - cxt:cxv + cxt + mx, - cyv - cyt:cyv + cyt + my, - czv - czt:czv + czt + mz] = self.plan.mask + self.plan.mask_padded[pad_index] = self.plan.mask # Std volume needs to be recalculated for every rotation of the mask, expensive step self.plan.std_volume = std_under_mask_convolution( - self.plan.volume, + self.plan.volume_rft_conj, + self.plan.volume_sq_rft_conj, self.plan.mask_padded, self.plan.mask_weight, - volume_rft=self.plan.volume_rft, - ) + ) * self.plan.mask_weight # Rotate template self.plan.template_texture.transform( @@ -213,7 +231,7 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: self.plan.template = irfftn( rfftn(self.plan.template) * self.plan.wedge, s=self.plan.template.shape - ).real + ) # Normalize and mask template mean = mean_under_mask(self.plan.template, self.plan.mask, mask_weight=self.plan.mask_weight) @@ -221,16 +239,14 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: self.plan.template = ((self.plan.template - mean) / std) * self.plan.mask # Paste in center - self.plan.template_padded[cxv - cxt:cxv + cxt + mx, - cyv - cyt:cyv + cyt + my, - czv - czt:czv + czt + mz] = self.plan.template + self.plan.template_padded[pad_index] = self.plan.template # Fast local correlation function between volume and template, norm is the standard deviation at each # point in the volume in the masked area - self.plan.ccc_map = fftshift( - irfftn(self.plan.volume_rft * rfftn(self.plan.template_padded).conj(), - s=self.plan.template_padded.shape).real - / (self.plan.mask_weight * self.plan.std_volume) + self.plan.ccc_map = ( + irfftn(self.plan.volume_rft_conj * rfftn(self.plan.template_padded), + s=self.plan.template_padded.shape) + / self.plan.std_volume ) # Update the scores and angle_lists @@ -243,11 +259,18 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: ) self.stats['variance'] += ( - square_sum_kernel( - self.plan.ccc_map[self.stats_roi] - ) / roi_size + square_sum_kernel(self.plan.ccc_map * roi_mask) / roi_size ) + # Get correct orientation back! + # Use same method as William Wan's STOPGAP + # (https://doi.org/10.1107/S205979832400295X): the search volume is Fourier + # transformed and conjugated before the iterations this means the eventual + # score map needs to be flipped back. The map is also rolled due to the ftshift + # effect of a Fourier space correlation function. + self.plan.scores = cp.roll(cp.flip(self.plan.scores), shift, axis=(0, 1, 2)) + self.plan.angles = cp.roll(cp.flip(self.plan.angles), shift, axis=(0, 1, 2)) + self.stats['search_space'] = int(roi_size * len(self.angle_ids)) self.stats['variance'] = float(self.stats['variance'] / len(self.angle_ids)) self.stats['std'] = float(cp.sqrt(self.stats['variance'])) @@ -262,66 +285,40 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: def std_under_mask_convolution( - volume: cpt.NDArray[float], + volume_rft_conj: cpt.NDArray[float], + volume_sq_rft_conj: cpt.NDArray[float], padded_mask: cpt.NDArray[float], mask_weight: float, - volume_rft: Optional[cpt.NDArray[complex]] = None ) -> cpt.NDArray[float]: """Calculate the local standard deviation under the mask for each position in the volume. Calculation is done in Fourier space as this is a convolution between volume and mask. Parameters ---------- - volume: cpt.NDArray[float] - cupy array to calculate local std in + volume_rft_conj: cpt.NDArray[float] + complex conjugate of the rft of the search volume + volume_sq_rft_conj: cpt.NDArray[float] + complex conjugate of the rft of the squared search volume padded_mask: cpt.NDArray[float] template mask that has been padded to dimensions of volume mask_weight: float weight of the mask, usually calculated as mask.sum() - volume_rft: Optional[cpt.NDArray[float]], default None - optionally provide a precalculated reduced Fourier transform of volume to save computation Returns ------- std_v: cpt.NDArray[float] array with local standard deviations in volume """ - volume_rft = rfftn(volume) if volume_rft is None else volume_rft + padded_mask_rft = rfftn(padded_mask) std_v = ( - mean_under_mask_convolution(rfftn(volume ** 2), padded_mask, mask_weight) - - mean_under_mask_convolution(volume_rft, padded_mask, mask_weight) ** 2 + irfftn(volume_sq_rft_conj * padded_mask_rft, s=padded_mask.shape) / mask_weight - + (irfftn(volume_rft_conj * padded_mask_rft, s=padded_mask.shape) / mask_weight) ** 2 ) std_v[std_v <= cp.float32(1e-18)] = 1 # prevent potential sqrt of negative value and division by zero std_v = cp.sqrt(std_v) return std_v -def mean_under_mask_convolution( - volume_rft: cpt.NDArray[complex], - mask: cpt.NDArray[float], - mask_weight: float -) -> cpt.NDArray[float]: - """Calculate local mean in volume under the masked region. - - Parameters - ---------- - volume_rft: cpt.NDArray[complex] - array containing the rfftn of the volume - mask: cpt.NDArray[float] - mask to calculate the mean under - mask_weight: float - weight of the mask, usually calculated as mask.sum() - - Returns - ------- - mean: cpt.NDArray[float] - array with local means under the mask - """ - return irfftn( - volume_rft * rfftn(mask).conj(), s=mask.shape - ).real / mask_weight - - """Update scores and angles if a new maximum is found.""" update_results_kernel = cp.ElementwiseKernel( 'float32 scores, float32 ccc_map, float32 angle_id', @@ -331,18 +328,13 @@ def mean_under_mask_convolution( ) -# Temporary workaround for ReductionKernel issue in cupy 13.0.0 (see: https://github.com/cupy/cupy/issues/8184) -if version.parse(cp.__version__) == version.parse('13.0.0'): - def square_sum_kernel(x): - return (x ** 2).sum() -else: - """Calculate the sum of squares in a volume. Mean is assumed to be 0 which makes this operation a lot faster.""" - square_sum_kernel = cp.ReductionKernel( - 'T x', # input params - 'T y', # output params - 'x * x', # pre-processing expression - 'a + b', # reduction operation - 'y = a', # post-reduction output processing - '0', # identity value - 'variance' # kernel name - ) +"""Calculate the sum of squares in a volume. Mean is assumed to be 0 which makes this operation a lot faster.""" +square_sum_kernel = cp.ReductionKernel( + 'T x', # input params + 'T y', # output params + 'x * x', # pre-processing expression + 'a + b', # reduction operation + 'y = a', # post-reduction output processing + '0', # identity value + 'variance' # kernel name +) diff --git a/src/pytom_tm/tmjob.py b/src/pytom_tm/tmjob.py index e4a2334c..112edbeb 100644 --- a/src/pytom_tm/tmjob.py +++ b/src/pytom_tm/tmjob.py @@ -61,6 +61,13 @@ def load_json_to_tmjob(file_name: pathlib.Path, load_for_extraction: bool = True pytom_tm_version_number=data.get('pytom_tm_version_number', '0.3.0'), job_loaded_for_extraction=load_for_extraction, ) + # if the file originates from an old version set the phase shift for compatibility + if ( + version.parse(job.pytom_tm_version_number) < version.parse('0.6.1') and + job.ctf_data is not None + ): + for tilt in job.ctf_data: + tilt['phase_shift_deg'] = .0 job.rotation_file = pathlib.Path(data['rotation_file']) job.whole_start = data['whole_start'] job.sub_start = data['sub_start'] diff --git a/src/pytom_tm/weights.py b/src/pytom_tm/weights.py index bae83b52..c1d2ab5b 100644 --- a/src/pytom_tm/weights.py +++ b/src/pytom_tm/weights.py @@ -469,6 +469,7 @@ def _create_tilt_weighted_wedge( - 'amplitude'; fraction of amplitude contrast between 0 and 1 - 'voltage'; in keV - 'cs'; spherical abberation in mm + - 'phase_shift_deg'; phase shift for phase plates in deg Returns ------- @@ -514,7 +515,8 @@ def _create_tilt_weighted_wedge( ctf_params_per_tilt[i]['amplitude'], ctf_params_per_tilt[i]['voltage'] * 1e3, ctf_params_per_tilt[i]['cs'] * 1e-3, - flip_phase=True # creating a per tilt ctf is hard if the phase is not flipped + flip_phase=True, # creating per tilt ctf requires phase flip atm + phase_shift_deg=ctf_params_per_tilt[i]['phase_shift_deg'], ), axes=0, ) tilt[:, :, image_size // 2] = np.concatenate( @@ -571,7 +573,8 @@ def create_ctf( voltage: float, spherical_aberration: float, cut_after_first_zero: bool = False, - flip_phase: bool = False + flip_phase: bool = False, + phase_shift_deg: float = .0, ) -> npt.NDArray[float]: """Create a CTF in a 3D volume in reduced format. @@ -593,6 +596,11 @@ def create_ctf( whether to cut ctf after first zero crossing flip_phase: bool, default False make ctf fully positive/negative to imitate ctf correction by phase flipping + phase_shift_deg: float, default .0 + additional phase shift to model phase plates, similar to + `https://github.com/dtegunov/tom_deconv` except the ctf defintion in tom + produces the inverse curve of what we have here + Returns ------- @@ -609,7 +617,7 @@ def create_ctf( tan_term = np.arctan(amplitude_contrast / np.sqrt(1 - amplitude_contrast ** 2)) # determine the ctf - ctf = - np.sin(chi + tan_term) + ctf = - np.sin(chi + tan_term + np.deg2rad(phase_shift_deg)) if cut_after_first_zero: # find frequency to cut first zero def chi_1d(q): diff --git a/tests/Data/header_only.mrc b/tests/Data/header_only.mrc new file mode 100644 index 00000000..0cd033e2 Binary files /dev/null and b/tests/Data/header_only.mrc differ diff --git a/tests/Data/human_ribo_mask_32_8_5.mrc b/tests/Data/human_ribo_mask_32_8_5.mrc new file mode 100644 index 00000000..64bad422 Binary files /dev/null and b/tests/Data/human_ribo_mask_32_8_5.mrc differ diff --git a/tests/test00_parallel.py b/tests/test00_parallel.py index 340838fe..42d15318 100644 --- a/tests/test00_parallel.py +++ b/tests/test00_parallel.py @@ -76,7 +76,7 @@ def test_parallel_breaking(self): _ = run_job_parallel(self.job, volume_splits=(1, 2, 1), gpu_ids=[0, -1], unittest_mute=True) except RuntimeError: # sleep a second to make sure all children are cleaned - time.sleep(1) + time.sleep(2) self.assertEqual(len(multiprocessing.active_children()), 0, msg='a process was still lingering after a parallel job with partially invalid resources ' 'was started') diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 00000000..010da720 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,50 @@ +import unittest +from pytom_tm.io import read_mrc, read_mrc_meta_data +import pathlib +import warnings +import contextlib + +FAILING_MRC = pathlib.Path(__file__).parent.joinpath(pathlib.Path('Data/human_ribo_mask_32_8_5.mrc')) +# The below file was made with head -c 1024 human_ribo_mask_32_8_5.mrc > header_only.mrc +CORRUPT_MRC = pathlib.Path(__file__).parent.joinpath(pathlib.Path('Data/header_only.mrc')) + + +class TestBrokenMRC(unittest.TestCase): + def setUp(self): + # Mute the RuntimeWarnings comming from other code-base inside these tests + # following this SO answer: https://stackoverflow.com/a/45809502 + stack = contextlib.ExitStack() + _ = stack.enter_context(warnings.catch_warnings()) + warnings.simplefilter('ignore') + # The follwing line is better, but only works in python >= 3.11 + #_ = stack.enter_context(warnings.catch_warnings(action="ignore")) + + self.addCleanup(stack.close) + + + def test_read_mrc_minor_broken(self): + # Test if this mrc can be read and if the approriate logs are printed + with self.assertLogs(level='WARNING') as cm: + mrc = read_mrc(FAILING_MRC) + self.assertIsNotNone(mrc) + self.assertEqual(len(cm.output), 1) + self.assertIn(FAILING_MRC.name, cm.output[0]) + self.assertIn("make sure this is correct", cm.output[0]) + + def test_read_mrc_too_broken(self): + # Test if this mrc raises an error as expected + with self.assertRaises(ValueError) as err: + mrc = read_mrc(CORRUPT_MRC) + self.assertIn(CORRUPT_MRC.name, str(err.exception)) + self.assertIn("too corrupt", str(err.exception)) + + def test_read_mrc_meta_data(self): + # Test if this mrc can be read and if the approriate logs are printed + with self.assertLogs(level='WARNING') as cm: + mrc = read_mrc_meta_data(FAILING_MRC) + self.assertIsNotNone(mrc) + self.assertEqual(len(cm.output), 1) + self.assertIn(FAILING_MRC.name, cm.output[0]) + self.assertIn("make sure this is correct", cm.output[0]) + + diff --git a/tests/test_template_matching.py b/tests/test_template_matching.py index 9573619e..d6960949 100644 --- a/tests/test_template_matching.py +++ b/tests/test_template_matching.py @@ -17,7 +17,7 @@ def setUp(self): self.gpu_id = 'gpu:0' self.angles = load_angle_list(files('pytom_tm.angle_lists').joinpath('angles_38.53_256.txt')) - def test_search(self): + def test_search_spherical_mask(self): angle_id = 100 rotation = self.angles[angle_id] loc = (77, 26, 40) @@ -31,8 +31,15 @@ def test_search(self): device='cpu' ) - tm = TemplateMatchingGPU(0, 0, self.volume, self.template, self.mask, self.angles, list(range(len( - self.angles)))) + tm = TemplateMatchingGPU( + 0, + 0, + self.volume, + self.template, + self.mask, + self.angles, + list(range(len(self.angles))), + ) score_volume, angle_volume, stats = tm.run() ind = np.unravel_index(score_volume.argmax(), self.volume.shape) @@ -42,3 +49,37 @@ def test_search(self): self.assertEqual(stats['search_space'], 256000000, msg='Search space should exactly equal this value') self.assertAlmostEqual(stats['std'], 0.005175, places=5, msg='Standard deviation of the search should be almost equal') + + def test_search_non_spherical_mask(self): + angle_id = 100 + rotation = self.angles[angle_id] + loc = (77, 26, 40) + self.volume[loc[0] - self.t_size // 2: loc[0] + self.t_size // 2, + loc[1] - self.t_size // 2: loc[1] + self.t_size // 2, + loc[2] - self.t_size // 2: loc[2] + self.t_size // 2] = vt.transform( + self.template, + rotation=rotation, + rotation_units='rad', + rotation_order='rzxz', + device='cpu' + ) + + tm = TemplateMatchingGPU( + 0, + 0, + self.volume, + self.template, + self.mask, + self.angles, + list(range(len(self.angles))), + mask_is_spherical=False, + ) + score_volume, angle_volume, stats = tm.run() + + ind = np.unravel_index(score_volume.argmax(), self.volume.shape) + self.assertTrue(score_volume.max() > 0.99, msg='lcc max value lower than expected') + self.assertEqual(angle_id, angle_volume[ind]) + self.assertSequenceEqual(loc, ind) + self.assertEqual(stats['search_space'], 256000000, msg='Search space should exactly equal this value') + self.assertAlmostEqual(stats['std'], 0.005175, places=4, + msg='Standard deviation of the search should be almost equal') diff --git a/tests/test_tmjob.py b/tests/test_tmjob.py index dfdca903..6f056b2e 100644 --- a/tests/test_tmjob.py +++ b/tests/test_tmjob.py @@ -31,6 +31,7 @@ TEST_WHITENING_FILTER = TEST_DATA_DIR.joinpath('tomogram_whitening_filter.npy') TEST_JOB_JSON = TEST_DATA_DIR.joinpath('tomogram_job.json') TEST_JOB_JSON_WHITENING = TEST_DATA_DIR.joinpath('tomogram_job_whitening.json') +TEST_JOB_OLD_VERSION = TEST_DATA_DIR.joinpath('tomogram_job_old_version.json') class TestTMJob(unittest.TestCase): @@ -121,6 +122,7 @@ def tearDownClass(cls) -> None: TEST_WHITENING_FILTER.unlink(missing_ok=True) TEST_JOB_JSON.unlink() TEST_JOB_JSON_WHITENING.unlink() + TEST_JOB_OLD_VERSION.unlink() TEST_DATA_DIR.rmdir() def setUp(self): @@ -220,6 +222,18 @@ def test_load_json_to_tmjob(self): self.assertEqual(len(cm.output), 1) self.assertIn('Estimating whitening filter...', cm.output[0]) + # turn current job into 0.6.0 job with ctf params + job.pytom_tm_version_number = '0.6.0' + job.ctf_data = [] + for ctf in CTF_PARAMS: + job.ctf_data.append(ctf.copy()) + del job.ctf_data[-1]['phase_shift_deg'] + job.write_to_json(TEST_JOB_OLD_VERSION) + + # test backward compatibility with the update to 0.6.1 + job = load_json_to_tmjob(TEST_JOB_OLD_VERSION) + self.assertEqual(job.ctf_data[0]['phase_shift_deg'], .0) + def test_custom_angular_search(self): job = TMJob('0', 10, TEST_TOMOGRAM, TEST_TEMPLATE, TEST_MASK, TEST_DATA_DIR, angle_increment=TEST_CUSTOM_ANGULAR_SEARCH, voxel_size=1.) diff --git a/tests/test_weights.py b/tests/test_weights.py index 5e197694..2141579a 100644 --- a/tests/test_weights.py +++ b/tests/test_weights.py @@ -51,7 +51,8 @@ 'defocus': d, 'amplitude': AMP, 'voltage': VOL, - 'cs': CS + 'cs': CS, + 'phase_shift_deg': .0, }) DOSE_FILE = '''60.165