Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable 16 bit output #222

Merged
merged 18 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/pytom_tm/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,13 @@ def match_template(argv=None):
"For this method please see STOPGAP as a reference: "
"https://doi.org/10.1107/S205979832400295X .",
)
additional_group.add_argument(
"--half-precision",
action="store_true",
default=False,
required=False,
help="Return and save all output in float16 instead of the default float32",
)
additional_group.add_argument(
"--rng-seed",
type=int,
Expand Down Expand Up @@ -959,6 +966,7 @@ def match_template(argv=None):
random_phase_correction=args.random_phase_correction,
rng_seed=args.rng_seed,
defocus_handedness=args.defocus_handedness,
output_dtype=np.float16 if args.half_precision else np.float32,
)

score_volume, angle_volume = run_job_parallel(
Expand Down
5 changes: 3 additions & 2 deletions src/pytom_tm/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,10 @@ def write_mrc(
Returns
-------
"""
if data.dtype != np.float32:
if data.dtype not in [np.float32, np.float16]:
logging.warning(
"data for mrc writing is not np.float32 will convert to np.float32"
"data for mrc writing is not np.float32 or np.float16, will convert to "
"np.float32"
)
data = data.astype(np.float32)
mrcfile.write(
Expand Down
19 changes: 18 additions & 1 deletion src/pytom_tm/tmjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def load_json_to_tmjob(
with open(file_name, "r") as fstream:
data = json.load(fstream)

# wrangle dtypes
output_dtype = data.get("output_dtype", "float32")
output_dtype = np.dtype(output_dtype)

job = TMJob(
data["job_key"],
data["log_level"],
Expand Down Expand Up @@ -75,6 +79,7 @@ def load_json_to_tmjob(
random_phase_correction=data.get("random_phase_correction", False),
rng_seed=data.get("rng_seed", 321),
defocus_handedness=data.get("defocus_handedness", 0),
output_dtype=output_dtype,
)
# if the file originates from an old version set the phase shift for compatibility
if (
Expand Down Expand Up @@ -280,6 +285,7 @@ def __init__(
random_phase_correction: bool = False,
rng_seed: int = 321,
defocus_handedness: int = 0,
output_dtype: np.dtype = np.float32,
):
"""
Parameters
Expand Down Expand Up @@ -350,6 +356,8 @@ def __init__(
-1 = inverted
0 = don't correct offsets (preferred if unknown)
1 = regular (as in Pyle & Zianetti (2021))
output_dtype: np.dtype, default np.float32
output score volume dtype, options are np.float32 and np.float16
"""
self.mask = mask
self.mask_is_spherical = mask_is_spherical
Expand Down Expand Up @@ -539,6 +547,9 @@ def __init__(
# version number of the job
self.pytom_tm_version_number = pytom_tm_version_number

# output dtype
self.output_dtype = output_dtype

def copy(self) -> TMJob:
"""Create a copy of the TMJob

Expand Down Expand Up @@ -576,6 +587,8 @@ def write_to_json(self, file_name: pathlib.Path) -> None:
for key, value in d.items():
if isinstance(value, pathlib.Path):
d[key] = str(value)
# wrangle dtype conversion
d["output_dtype"] = str(np.dtype(d["output_dtype"]))
with open(file_name, "w") as fstream:
json.dump(d, fstream, indent=4)

Expand Down Expand Up @@ -794,7 +807,7 @@ def merge_sub_jobs(
job.whole_start[1] : job.whole_start[1] + sub_scores.shape[1],
job.whole_start[2] : job.whole_start[2] + sub_scores.shape[2],
] = sub_angles
return scores, angles
return scores.astype(self.output_dtype), angles

def start_job(
self, gpu_id: int, return_volumes: bool = False
Expand Down Expand Up @@ -975,6 +988,10 @@ def start_job(

del tm # delete the template matching plan

# cast to correct dtype
score_volume = score_volume.astype(self.output_dtype)
angle_volume = angle_volume

if return_volumes:
return score_volume, angle_volume
else: # otherwise write them out with job_key
Expand Down
38 changes: 37 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import unittest
from pytom_tm.io import read_mrc, read_mrc_meta_data
import pathlib
import warnings
import contextlib
from tempfile import TemporaryDirectory
import numpy as np
import mrcfile

from pytom_tm.io import read_mrc, read_mrc_meta_data, write_mrc

FAILING_MRC = pathlib.Path(__file__).parent.joinpath(
pathlib.Path("Data/human_ribo_mask_32_8_5.mrc")
Expand All @@ -25,6 +29,11 @@ def setUp(self):

self.addCleanup(stack.close)

# prep temporary directory
tempdir = TemporaryDirectory()
self.tempdirname = tempdir.name
self.addCleanup(tempdir.cleanup)

def test_read_mrc_minor_broken(self):
# Test if this mrc can be read and if the approriate logs are printed
with self.assertLogs(level="WARNING") as cm:
Expand All @@ -49,3 +58,30 @@ def test_read_mrc_meta_data(self):
self.assertEqual(len(cm.output), 1)
self.assertIn(FAILING_MRC.name, cm.output[0])
self.assertIn("make sure this is correct", cm.output[0])

def test_half_precision_read_write_cycle(self):
array = np.random.rand(27).reshape((3, 3, 3)).astype(np.float16)
fname = pathlib.Path(self.tempdirname) / "test_half.mrc"
# Make sure no warnings are raised
with self.assertNoLogs(level="WARNING"):
write_mrc(fname, array, 1.0)
# Make sure the file can be read back
# make sure mode is as expected for float16
# https://mrcfile.readthedocs.io/en/stable/source/mrcfile.html#mrcfile.utils.dtype_from_mode
mrc = mrcfile.open(fname)
self.assertEqual(mrc.header.mode, 12)
mrc.close()
# make sure dtype is expected
mrc = read_mrc(fname)
self.assertEqual(mrc.dtype, np.float16)
# make sure data is identical
np.testing.assert_equal(mrc, array)

def test_cast_warning(self):
# make sure a warning is raised when writing an integer based array
array = np.random.rand(27).reshape((3, 3, 3)).astype(np.int32)
fname = pathlib.Path(self.tempdirname) / "test_cast.mrc"
with self.assertLogs(level="WARNING") as cm:
write_mrc(fname, array, 1.0)
self.assertEqual(len(cm.output), 1)
self.assertIn("np.float32", cm.output[0])
16 changes: 16 additions & 0 deletions tests/test_tmjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,22 @@ def test_tm_job_split_angles(self):
"should be almost identical.",
)

def test_tm_job_half_precision(self):
job = TMJob(
"0",
10,
TEST_TOMOGRAM,
TEST_TEMPLATE,
TEST_MASK,
TEST_DATA_DIR,
angle_increment=ANGULAR_SEARCH,
voxel_size=1.0,
output_dtype=np.float16,
)
s, a = job.start_job(0, return_volumes=True)
self.assertEqual(s.dtype, np.float16)
self.assertEqual(a.dtype, np.float32)

def test_extraction(self):
_ = self.job.start_job(0, return_volumes=True)

Expand Down