diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 2e693b4e..a00fe57b 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -44,7 +44,7 @@ jobs: path: coverage.xml repo_token: ${{ secrets.GITHUB_TOKEN }} pull_request_number: ${{ steps.get-pr.outputs.PR }} - minimum_coverage: 79 + minimum_coverage: 80 show_missing: True fail_below_threshold: True link_missing_lines: True diff --git a/pyproject.toml b/pyproject.toml index 835a4bee..00b01df2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pytom-match-pick" -version = "0.6.1" +version = "0.7.0" description = "PyTOM's GPU template matching module as an independent package" readme = "README.md" license = {file = "LICENSE"} diff --git a/src/pytom_tm/matching.py b/src/pytom_tm/matching.py index c5104cbe..a41baf07 100644 --- a/src/pytom_tm/matching.py +++ b/src/pytom_tm/matching.py @@ -36,32 +36,34 @@ def __init__( reduced form, with dimensions (sx, sx, sx // 2 + 1) """ # Search volume + and fft transform plan for the volume - self.volume = cp.asarray(volume, dtype=cp.float32, order='C') - self.volume_rft = rfftn(self.volume) + volume_shape = volume.shape + cp_vol = cp.asarray(volume, dtype=cp.float32, order='C') + self.volume_rft_conj = rfftn(cp_vol).conj() + self.volume_sq_rft_conj = rfftn(cp_vol ** 2).conj() # Explicit fft plan is no longer necessary as cupy generates a plan behind the scene which leads to # comparable timings # Array for storing local standard deviations - self.std_volume = cp.zeros_like(volume, dtype=cp.float32) + self.std_volume = cp.zeros(volume_shape, dtype=cp.float32) # Data for the mask self.mask = cp.asarray(mask, dtype=cp.float32, order='C') self.mask_texture = vt.StaticVolume(self.mask, interpolation='filt_bspline', device=f'gpu:{device_id}') - self.mask_padded = cp.zeros_like(self.volume).astype(cp.float32) + self.mask_padded = cp.zeros(volume_shape, dtype=cp.float32) self.mask_weight = self.mask.sum() # weight of the mask # Init template data self.template = cp.asarray(template, dtype=cp.float32, order='C') self.template_texture = vt.StaticVolume(self.template, interpolation='filt_bspline', device=f'gpu:{device_id}') - self.template_padded = cp.zeros_like(self.volume) + self.template_padded = cp.zeros(volume_shape, dtype=cp.float32) # fourier binary wedge weight for the template self.wedge = cp.asarray(wedge, order='C', dtype=cp.float32) if wedge is not None else None # Initialize result volumes - self.ccc_map = cp.zeros_like(self.volume) - self.scores = cp.ones_like(self.volume)*-1000 - self.angles = cp.ones_like(self.volume)*-1000 + self.ccc_map = cp.zeros(volume_shape, dtype=cp.float32) + self.scores = cp.ones(volume_shape, dtype=cp.float32)*-1000 + self.angles = cp.ones(volume_shape, dtype=cp.float32)*-1000 # wait for stream to complete the work cp.cuda.stream.get_current_stream().synchronize() @@ -69,9 +71,11 @@ def __init__( def clean(self) -> None: """Remove all stored cupy arrays from the GPU's memory pool.""" gpu_memory_pool = cp.get_default_memory_pool() - del self.volume, self.volume_rft, self.mask, self.mask_texture, self.mask_padded, self.template, ( - self.template_texture), self.template_padded, self.wedge, self.ccc_map, self.scores, self.angles, ( - self.std_volume) + del ( + self.volume_rft_conj, self.volume_sq_rft_conj, self.mask, self.mask_texture, self.mask_padded, + self.template, self.template_texture, self.template_padded, self.wedge, self.ccc_map, self.scores, + self.angles, self.std_volume + ) gc.collect() gpu_memory_pool.free_all_blocks() @@ -92,6 +96,13 @@ def __init__( ): """Initialize a template matching run. + For other great implementations see: + - STOPGAP: https://github.com/wan-lab-vanderbilt/STOPGAP + - pyTME: https://github.com/KosinskiLab/pyTME + + The precalculation of conjugated FTs of the tomo was (AFAIK) introduced + by STOPGAP! + Parameters ---------- job_id: str @@ -162,19 +173,28 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: sxv, syv, szv = self.plan.template_padded.shape cxv, cyv, czv = sxv // 2, syv // 2, szv // 2 - # calculate roi size - roi_size = self.plan.volume[self.stats_roi].size + # create slice for padding + pad_index = ( + slice(cxv - cxt, cxv + cxt + mx), + slice(cyv - cyt, cyv + cyt + my), + slice(czv - czt, czv + czt + mz), + ) + + # calculate roi mask + shift = cp.floor(cp.array(self.plan.scores.shape) / 2).astype(int) + 1 + roi_mask = cp.zeros(self.plan.scores.shape, dtype=bool) + roi_mask[self.stats_roi] = True + roi_mask = cp.flip(cp.roll(roi_mask, -shift, (0, 1, 2))) + roi_size = self.plan.scores[roi_mask].size if self.mask_is_spherical: # Then we only need to calculate std volume once - self.plan.mask_padded[cxv - cxt:cxv + cxt + mx, - cyv - cyt:cyv + cyt + my, - czv - czt:czv + czt + mz] = self.plan.mask + self.plan.mask_padded[pad_index] = self.plan.mask self.plan.std_volume = std_under_mask_convolution( - self.plan.volume, + self.plan.volume_rft_conj, + self.plan.volume_sq_rft_conj, self.plan.mask_padded, self.plan.mask_weight, - volume_rft=self.plan.volume_rft - ) + ) * self.plan.mask_weight # Track iterations with a tqdm progress bar for i in tqdm(range(len(self.angle_ids))): @@ -189,16 +209,14 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: output=self.plan.mask, rotation_units='rad' ) - self.plan.mask_padded[cxv - cxt:cxv + cxt + mx, - cyv - cyt:cyv + cyt + my, - czv - czt:czv + czt + mz] = self.plan.mask + self.plan.mask_padded[pad_index] = self.plan.mask # Std volume needs to be recalculated for every rotation of the mask, expensive step self.plan.std_volume = std_under_mask_convolution( - self.plan.volume, + self.plan.volume_rft_conj, + self.plan.volume_sq_rft_conj, self.plan.mask_padded, self.plan.mask_weight, - volume_rft=self.plan.volume_rft, - ) + ) * self.plan.mask_weight # Rotate template self.plan.template_texture.transform( @@ -213,7 +231,7 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: self.plan.template = irfftn( rfftn(self.plan.template) * self.plan.wedge, s=self.plan.template.shape - ).real + ) # Normalize and mask template mean = mean_under_mask(self.plan.template, self.plan.mask, mask_weight=self.plan.mask_weight) @@ -221,16 +239,14 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: self.plan.template = ((self.plan.template - mean) / std) * self.plan.mask # Paste in center - self.plan.template_padded[cxv - cxt:cxv + cxt + mx, - cyv - cyt:cyv + cyt + my, - czv - czt:czv + czt + mz] = self.plan.template + self.plan.template_padded[pad_index] = self.plan.template # Fast local correlation function between volume and template, norm is the standard deviation at each # point in the volume in the masked area - self.plan.ccc_map = fftshift( - irfftn(self.plan.volume_rft * rfftn(self.plan.template_padded).conj(), - s=self.plan.template_padded.shape).real - / (self.plan.mask_weight * self.plan.std_volume) + self.plan.ccc_map = ( + irfftn(self.plan.volume_rft_conj * rfftn(self.plan.template_padded), + s=self.plan.template_padded.shape) + / self.plan.std_volume ) # Update the scores and angle_lists @@ -243,11 +259,18 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: ) self.stats['variance'] += ( - square_sum_kernel( - self.plan.ccc_map[self.stats_roi] - ) / roi_size + square_sum_kernel(self.plan.ccc_map * roi_mask) / roi_size ) + # Get correct orientation back! + # Use same method as William Wan's STOPGAP + # (https://doi.org/10.1107/S205979832400295X): the search volume is Fourier + # transformed and conjugated before the iterations this means the eventual + # score map needs to be flipped back. The map is also rolled due to the ftshift + # effect of a Fourier space correlation function. + self.plan.scores = cp.roll(cp.flip(self.plan.scores), shift, axis=(0, 1, 2)) + self.plan.angles = cp.roll(cp.flip(self.plan.angles), shift, axis=(0, 1, 2)) + self.stats['search_space'] = int(roi_size * len(self.angle_ids)) self.stats['variance'] = float(self.stats['variance'] / len(self.angle_ids)) self.stats['std'] = float(cp.sqrt(self.stats['variance'])) @@ -262,66 +285,40 @@ def run(self) -> tuple[npt.NDArray[float], npt.NDArray[float], dict]: def std_under_mask_convolution( - volume: cpt.NDArray[float], + volume_rft_conj: cpt.NDArray[float], + volume_sq_rft_conj: cpt.NDArray[float], padded_mask: cpt.NDArray[float], mask_weight: float, - volume_rft: Optional[cpt.NDArray[complex]] = None ) -> cpt.NDArray[float]: """Calculate the local standard deviation under the mask for each position in the volume. Calculation is done in Fourier space as this is a convolution between volume and mask. Parameters ---------- - volume: cpt.NDArray[float] - cupy array to calculate local std in + volume_rft_conj: cpt.NDArray[float] + complex conjugate of the rft of the search volume + volume_sq_rft_conj: cpt.NDArray[float] + complex conjugate of the rft of the squared search volume padded_mask: cpt.NDArray[float] template mask that has been padded to dimensions of volume mask_weight: float weight of the mask, usually calculated as mask.sum() - volume_rft: Optional[cpt.NDArray[float]], default None - optionally provide a precalculated reduced Fourier transform of volume to save computation Returns ------- std_v: cpt.NDArray[float] array with local standard deviations in volume """ - volume_rft = rfftn(volume) if volume_rft is None else volume_rft + padded_mask_rft = rfftn(padded_mask) std_v = ( - mean_under_mask_convolution(rfftn(volume ** 2), padded_mask, mask_weight) - - mean_under_mask_convolution(volume_rft, padded_mask, mask_weight) ** 2 + irfftn(volume_sq_rft_conj * padded_mask_rft, s=padded_mask.shape) / mask_weight - + (irfftn(volume_rft_conj * padded_mask_rft, s=padded_mask.shape) / mask_weight) ** 2 ) std_v[std_v <= cp.float32(1e-18)] = 1 # prevent potential sqrt of negative value and division by zero std_v = cp.sqrt(std_v) return std_v -def mean_under_mask_convolution( - volume_rft: cpt.NDArray[complex], - mask: cpt.NDArray[float], - mask_weight: float -) -> cpt.NDArray[float]: - """Calculate local mean in volume under the masked region. - - Parameters - ---------- - volume_rft: cpt.NDArray[complex] - array containing the rfftn of the volume - mask: cpt.NDArray[float] - mask to calculate the mean under - mask_weight: float - weight of the mask, usually calculated as mask.sum() - - Returns - ------- - mean: cpt.NDArray[float] - array with local means under the mask - """ - return irfftn( - volume_rft * rfftn(mask).conj(), s=mask.shape - ).real / mask_weight - - """Update scores and angles if a new maximum is found.""" update_results_kernel = cp.ElementwiseKernel( 'float32 scores, float32 ccc_map, float32 angle_id', diff --git a/tests/test_template_matching.py b/tests/test_template_matching.py index 9573619e..d6960949 100644 --- a/tests/test_template_matching.py +++ b/tests/test_template_matching.py @@ -17,7 +17,7 @@ def setUp(self): self.gpu_id = 'gpu:0' self.angles = load_angle_list(files('pytom_tm.angle_lists').joinpath('angles_38.53_256.txt')) - def test_search(self): + def test_search_spherical_mask(self): angle_id = 100 rotation = self.angles[angle_id] loc = (77, 26, 40) @@ -31,8 +31,15 @@ def test_search(self): device='cpu' ) - tm = TemplateMatchingGPU(0, 0, self.volume, self.template, self.mask, self.angles, list(range(len( - self.angles)))) + tm = TemplateMatchingGPU( + 0, + 0, + self.volume, + self.template, + self.mask, + self.angles, + list(range(len(self.angles))), + ) score_volume, angle_volume, stats = tm.run() ind = np.unravel_index(score_volume.argmax(), self.volume.shape) @@ -42,3 +49,37 @@ def test_search(self): self.assertEqual(stats['search_space'], 256000000, msg='Search space should exactly equal this value') self.assertAlmostEqual(stats['std'], 0.005175, places=5, msg='Standard deviation of the search should be almost equal') + + def test_search_non_spherical_mask(self): + angle_id = 100 + rotation = self.angles[angle_id] + loc = (77, 26, 40) + self.volume[loc[0] - self.t_size // 2: loc[0] + self.t_size // 2, + loc[1] - self.t_size // 2: loc[1] + self.t_size // 2, + loc[2] - self.t_size // 2: loc[2] + self.t_size // 2] = vt.transform( + self.template, + rotation=rotation, + rotation_units='rad', + rotation_order='rzxz', + device='cpu' + ) + + tm = TemplateMatchingGPU( + 0, + 0, + self.volume, + self.template, + self.mask, + self.angles, + list(range(len(self.angles))), + mask_is_spherical=False, + ) + score_volume, angle_volume, stats = tm.run() + + ind = np.unravel_index(score_volume.argmax(), self.volume.shape) + self.assertTrue(score_volume.max() > 0.99, msg='lcc max value lower than expected') + self.assertEqual(angle_id, angle_volume[ind]) + self.assertSequenceEqual(loc, ind) + self.assertEqual(stats['search_space'], 256000000, msg='Search space should exactly equal this value') + self.assertAlmostEqual(stats['std'], 0.005175, places=4, + msg='Standard deviation of the search should be almost equal')