diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml
index 68b3bd53..d6ae3858 100644
--- a/.github/workflows/coverage.yml
+++ b/.github/workflows/coverage.yml
@@ -44,7 +44,7 @@ jobs:
           path: coverage.xml
           repo_token: ${{ secrets.GITHUB_TOKEN }}
           pull_request_number: ${{ steps.get-pr.outputs.PR }}
-          minimum_coverage: 87
+          minimum_coverage: 88
           show_missing: True
           fail_below_threshold: True
           link_missing_lines: True
diff --git a/pyproject.toml b/pyproject.toml
index 679f8767..7993bee7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "pytom-match-pick"
-version = "0.7.2"
+version = "0.7.3"
 description = "PyTOM's GPU template matching module as an independent package"
 readme = "README.md"
 license = {file = "LICENSE"}
diff --git a/src/pytom_tm/entry_points.py b/src/pytom_tm/entry_points.py
index 1b49e893..cfada298 100644
--- a/src/pytom_tm/entry_points.py
+++ b/src/pytom_tm/entry_points.py
@@ -533,6 +533,23 @@ def extract_candidates(argv=None):
         action=ParseLogging,
         help="Can be set to `info` or `debug`",
     )
+    parser.add_argument(
+        "--tophat-bins",
+        type=int,
+        required=False,
+        default=50,
+        action=LargerThanZero,
+        help="Number of bins to use in the histogram of occurrences in the "
+        "tophat transform code (for both the estimation and the plotting).",
+    )
+    parser.add_argument(
+        "--plot-bins",
+        type=int,
+        required=False,
+        default=20,
+        action=LargerThanZero,
+        help="Number of bins to use for the occurrences vs LCC_max plot.",
+    )
     # ---8<--- [end:extract_candidates_usage]
@@ -552,6 +569,8 @@ def extract_candidates(argv=None):
         tophat_connectivity=args.tophat_connectivity,
         relion5_compat=args.relion5_compat,
         ignore_tomogram_mask=args.ignore_tomogram_mask,
+        tophat_bins=args.tophat_bins,
+        plot_bins=args.plot_bins,
     )
 
     # write out as a RELION type starfile
diff --git a/src/pytom_tm/extract.py b/src/pytom_tm/extract.py
index 21375e74..c8ef50c3 100644
--- a/src/pytom_tm/extract.py
+++ b/src/pytom_tm/extract.py
@@ -32,6 +32,7 @@ def predict_tophat_mask(
     n_false_positives: float = 1.0,
     create_plot: bool = True,
     tophat_connectivity: int = 1,
+    bins: int = 50,
 ) -> npt.NDArray[bool]:
     """This function gets as input a score map and returns a peak mask as
     determined with a tophat transform.
@@ -61,7 +62,8 @@
         whether to plot the gaussian fit and cut-off estimation
     tophat_connectivity: int, default 1
         connectivity of binary structure
-
+    bins: int, default 50
+        number of bins to use for the histogram for estimation and plotting
     Returns
     -------
     peak_mask: npt.NDArray[bool]
@@ -73,7 +75,7 @@
             rank=3, connectivity=tophat_connectivity
         ),
     )
-    y, bins = np.histogram(tophat.flatten(), bins=50)
+    y, bins = np.histogram(tophat.flatten(), bins=bins)
     bin_centers = (bins[:-1] + bins[1:]) / 2
     x_raw, y_raw = (
         bin_centers[2:],
@@ -154,6 +156,8 @@ def extract_particles(
     tophat_connectivity: int = 1,
     relion5_compat: bool = False,
     ignore_tomogram_mask: bool = False,
+    tophat_bins: int = 50,
+    plot_bins: int = 20,
 ) -> tuple[pd.DataFrame, list[float, ...]]:
     """
     Parameters
     ----------
@@ -185,6 +189,10 @@ def extract_particles(
         Debug option to force the code to ignore job.tomogram_mask and input mask.
         Allows for re-exctraction without rerunning the TM job (assuming the scores
         volume seems reasonable)
+    tophat_bins: int, default 50
+        The number of bins to use in the tophat histogram
+    plot_bins: int, default 20
+        The number of bins to use for the plot histogram of occurrences
 
     Returns
     -------
@@ -208,6 +216,7 @@
             n_false_positives=n_false_positives,
             create_plot=create_plot,
             tophat_connectivity=tophat_connectivity,
+            bins=tophat_bins,
         )
         score_volume *= (
             predicted_peaks  # multiply with predicted peaks to keep only those
@@ -351,7 +360,7 @@
        output = output.rename(columns=column_change)
 
    if plotting_available and create_plot:
-        y, bins = np.histogram(scores, bins=20)
+        y, bins = np.histogram(scores, bins=plot_bins)
        x = (bins[1:] + bins[:-1]) / 2
        hist_step = bins[1] - bins[0]
        # add more starting values for background Gaussian
diff --git a/tests/test_tmjob.py b/tests/test_tmjob.py
index c5fcf9a2..ad0ab175 100644
--- a/tests/test_tmjob.py
+++ b/tests/test_tmjob.py
@@ -17,7 +17,8 @@
 LOCATION = (77, 26, 40)
 ANGLE_ID = 100
 ANGULAR_SEARCH = "38.53"
-TEST_DATA_DIR = pathlib.Path(__file__).parent.joinpath("test_data")
+TEMP_DIR = TemporaryDirectory()
+TEST_DATA_DIR = pathlib.Path(TEMP_DIR.name)
 TEST_TOMOGRAM = TEST_DATA_DIR.joinpath("tomogram.mrc")
 TEST_BROKEN_TOMOGRAM_MASK = TEST_DATA_DIR.joinpath("broken_tomogram_mask.mrc")
 TEST_WRONG_SIZE_TOMO_MASK = TEST_DATA_DIR.joinpath("wrong_size_tomogram_mask.mrc")
@@ -133,26 +134,7 @@ def setUpClass(cls) -> None:
 
     @classmethod
     def tearDownClass(cls) -> None:
-        TEST_MASK.unlink()
-        TEST_BROKEN_TOMOGRAM_MASK.unlink()
-        TEST_WRONG_SIZE_TOMO_MASK.unlink()
-        TEST_EXTRACTION_MASK_OUTSIDE.unlink()
-        TEST_EXTRACTION_MASK_INSIDE.unlink()
-        TEST_TEMPLATE.unlink()
-        TEST_TEMPLATE_UNEQUAL_SPACING.unlink()
-        TEST_TEMPLATE_WRONG_VOXEL_SIZE.unlink()
-        TEST_TOMOGRAM.unlink()
-        TEST_SCORES.unlink()
-        TEST_ANGLES.unlink()
-        TEST_CUSTOM_ANGULAR_SEARCH.unlink()
-        # the whitening filter might not exist if the job with spectrum whitening
-        # failed, so the unlinking needs to allow this (with missing_ok=True) to ensure
-        # clean up of the test directory
-        TEST_WHITENING_FILTER.unlink(missing_ok=True)
-        TEST_JOB_JSON.unlink()
-        TEST_JOB_JSON_WHITENING.unlink()
-        TEST_JOB_OLD_VERSION.unlink()
-        TEST_DATA_DIR.rmdir()
+        TEMP_DIR.cleanup()
 
     def setUp(self):
         self.job = TMJob(
@@ -630,7 +612,7 @@ def test_tm_job_half_precision(self):
         self.assertEqual(s.dtype, np.float16)
         self.assertEqual(a.dtype, np.float32)
 
-    def test_extraction(self):
+    def test_extractions(self):
         _ = self.job.start_job(0, return_volumes=True)
 
         # extract particles after running the job
@@ -755,6 +737,23 @@ def test_extraction(self):
             create_plot=False,
         )
 
+        # Test extraction with tophat filter and plotting
+        df, scores = extract_particles(
+            job,
+            5,
+            100,
+            tomogram_mask_path=TEST_EXTRACTION_MASK_INSIDE,
+            create_plot=True,
+            tophat_filter=True,
+        )
+        self.assertNotEqual(
+            len(scores),
+            0,
+            msg="We expected a detected particle with an extraction mask that "
+            "covers the object.",
+        )
+        # We don't look for the plots; they might be skipped if no plotting is available
+
     def test_get_defocus_offsets(self):
         tilt_angles = list(range(-51, 54, 3))
         x_offset_um = 200 * 13.79 * 1e-4
diff --git a/tests/test_broken_imports.py b/tests/testzz_broken_imports.py
similarity index 97%
rename from tests/test_broken_imports.py
rename to tests/testzz_broken_imports.py
index 23a6c826..0bde9499 100644
--- a/tests/test_broken_imports.py
+++ b/tests/testzz_broken_imports.py
@@ -1,3 +1,5 @@
+# This file is named testzz_* as it should run last,
+# because it permanently destroys the imports
 # No imports of pytom_tm outside of the methods
 import unittest
 from importlib import reload
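For context, a minimal usage sketch of the two new keyword arguments on `extract_particles` (the CLI equivalents added above are `--tophat-bins` and `--plot-bins`). This is not part of the diff: the job-file path is a placeholder, `load_json_to_tmjob` is assumed to be the loader used by the extraction entry point, and the comments on the positional arguments follow the call shown in `test_extractions`.

```python
# Hypothetical usage sketch of the new tophat_bins / plot_bins arguments.
# The job file name is a placeholder; load_json_to_tmjob is an assumed import.
from pytom_tm.tmjob import load_json_to_tmjob
from pytom_tm.extract import extract_particles

job = load_json_to_tmjob("tomogram_job.json")  # placeholder path to a finished TM job
df, scores = extract_particles(
    job,
    5,    # particle radius in voxels, as in the test call above
    100,  # maximum number of candidates to extract
    tophat_filter=True,
    tophat_bins=50,  # new: bins for the tophat-transform histogram (was fixed at 50)
    plot_bins=20,    # new: bins for the occurrences vs LCC_max plot (was fixed at 20)
)
```

Both keywords default to the previously hard-coded values, so existing calls keep their behaviour.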
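A standalone sketch of the histogram step that the new `bins` parameter of `predict_tophat_mask` controls; the data here is synthetic and only illustrates the two lines changed in `extract.py`.

```python
# Illustration of the np.histogram step changed above: the number of bins now
# comes from the `bins` argument instead of the hard-coded 50. Synthetic data.
import numpy as np

rng = np.random.default_rng(seed=0)
tophat = rng.normal(loc=0.0, scale=0.1, size=(64, 64, 64))  # stand-in for the tophat transform

bins = 50  # previously hard-coded, now configurable via --tophat-bins
y, edges = np.histogram(tophat.flatten(), bins=bins)
bin_centers = (edges[:-1] + edges[1:]) / 2  # same bin centering used in extract.py
```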
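The `test_tmjob.py` change replaces the long list of per-file `unlink()` calls with a single module-level `TemporaryDirectory`. A minimal sketch of that pattern follows; the class, file name, and test below are illustrative, not from the repository.

```python
# Minimal sketch of the TemporaryDirectory pattern used in test_tmjob.py above.
import pathlib
import unittest
from tempfile import TemporaryDirectory

TEMP_DIR = TemporaryDirectory()              # created once when the module is imported
TEST_DATA_DIR = pathlib.Path(TEMP_DIR.name)  # all generated test files live under here

class TestWithTempData(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        (TEST_DATA_DIR / "example.txt").write_text("test data")

    @classmethod
    def tearDownClass(cls) -> None:
        TEMP_DIR.cleanup()  # one call removes the directory and everything inside it

    def test_file_exists(self):
        self.assertTrue((TEST_DATA_DIR / "example.txt").exists())
```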
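The rename to `testzz_broken_imports.py` relies on test discovery still matching the file while sorting it last. A quick illustration under the assumption that unittest's default discovery pattern (`test*.py`, loaded in sorted order) is used:

```python
# Why the testzz_ prefix still gets collected but runs last with unittest discovery:
# the default pattern is "test*.py" and matching files are loaded in sorted order.
import fnmatch

files = ["test_tmjob.py", "test_weights.py", "testzz_broken_imports.py"]
matched = sorted(f for f in files if fnmatch.fnmatch(f, "test*.py"))
print(matched)  # ['test_tmjob.py', 'test_weights.py', 'testzz_broken_imports.py']
```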