from __future__ import annotations
from packaging import version
import pathlib
import warnings
import copy
import itertools as itt
import numpy as np
import numpy.typing as npt
import json
import logging
from typing import Optional, Union
from scipy.fft import next_fast_len, rfftn, irfftn
from pytom_tm.angles import get_angle_list
from pytom_tm.matching import TemplateMatchingGPU
from pytom_tm.weights import (
create_wedge,
power_spectrum_profile,
profile_to_weighting,
create_gaussian_band_pass,
)
from pytom_tm.io import read_mrc_meta_data, read_mrc, write_mrc, UnequalSpacingError
from pytom_tm import __version__ as PYTOM_TM_VERSION
def load_json_to_tmjob(
file_name: pathlib.Path, load_for_extraction: bool = True
) -> TMJob:
"""Load a previous job that was stored with TMJob.write_to_json().
Parameters
----------
file_name: pathlib.Path
path to TMJob json file
load_for_extraction: bool, default True
whether a finished job is loaded from disk for extraction; the default is
True as this function is currently only called for pytom_extract_candidates
and pytom_estimate_roc, which run on previously finished jobs
Returns
-------
job: TMJob
initialized TMJob
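Examples
--------
A minimal usage sketch; the .json path below is hypothetical:
>>> job = load_json_to_tmjob(pathlib.Path("results/tomo200_job.json"))  # doctest: +SKIP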
"""
with open(file_name, "r") as fstream:
data = json.load(fstream)
job = TMJob(
data["job_key"],
data["log_level"],
pathlib.Path(data["tomogram"]),
pathlib.Path(data["template"]),
pathlib.Path(data["mask"]),
pathlib.Path(data["output_dir"]),
angle_increment=data.get("angle_increment", data["rotation_file"]),
mask_is_spherical=data["mask_is_spherical"],
tilt_angles=data["tilt_angles"],
tilt_weighting=data["tilt_weighting"],
search_x=data["search_x"],
search_y=data["search_y"],
search_z=data["search_z"],
# Use 'get' for backwards compatibility
tomogram_mask=data.get("tomogram_mask", None),
voxel_size=data["voxel_size"],
low_pass=data["low_pass"],
# Use 'get' for backwards compatibility
high_pass=data.get("high_pass", None),
dose_accumulation=data.get("dose_accumulation", None),
ctf_data=data.get("ctf_data", None),
whiten_spectrum=data.get("whiten_spectrum", False),
rotational_symmetry=data.get("rotational_symmetry", 1),
# if version number is not in the .json, it must be 0.3.0 or older
pytom_tm_version_number=data.get("pytom_tm_version_number", "0.3.0"),
job_loaded_for_extraction=load_for_extraction,
particle_diameter=data.get("particle_diameter", None),
random_phase_correction=data.get("random_phase_correction", False),
rng_seed=data.get("rng_seed", 321),
)
# if the file originates from an old version, set the phase shift for compatibility
if (
version.parse(job.pytom_tm_version_number) < version.parse("0.6.1")
and job.ctf_data is not None
):
for tilt in job.ctf_data:
tilt["phase_shift_deg"] = 0.0
job.whole_start = data["whole_start"]
job.sub_start = data["sub_start"]
job.sub_step = data["sub_step"]
job.n_rotations = data["n_rotations"]
job.start_slice = data["start_slice"]
job.steps_slice = data["steps_slice"]
job.job_stats = data["job_stats"]
return job
def _determine_1D_fft_splits(
length: int, splits: int, overhang: int = 0
) -> list[tuple[tuple[int, int], tuple[int, int]]]:
"""Split a 1D length into FFT-optimal sizes, taking overhangs into account.
Parameters
----------
length: int
Total 1D length to split
splits: int
Number of splits to make
overhang: int, default 0
Minimal overhang/overlap to consider between splits
Returns
-------
output: list[tuple[tuple[int, int], tuple[int, int]]]
A list of splits where every split gets two tuples meaning:
[start, end) of the tomogram data in this split
[start, end) of the unique datapoints in this split
If a datapoint exists in 2 splits, we assign it as unique to the
split with the most data, or to the left one if both splits have
the same size
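Examples
--------
A small worked example; the exact boundaries depend on scipy's
next_fast_len (here 60 is already a fast length, so it is kept):
>>> _determine_1D_fft_splits(100, 2, overhang=10)
[((0, 60), (0, 50)), ((40, 100), (50, 100))]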
"""
# Everything in this code assumes default slices of [x,y) so including x but
# excluding y
data_slices = []
valid_data_slices = []
sub_len = []
# if single split return early
if splits == 1:
return [((0, length), (0, length))]
if splits > length:
warnings.warn(
"More splits than pixels were requested, will default to 1 split per pixel",
RuntimeWarning,
)
splits = length
# Ceil to guarantee that we map the whole length with enough buffer
min_len = int(np.ceil(length / splits)) + overhang
min_unique_len = min_len - overhang
no_overhang_left = 0
while True:
if no_overhang_left == 0:
# Treat first split specially, only right overhang
split_length = next_fast_len(min_len)
data_slices.append((0, split_length))
valid_data_slices.append((0, split_length - overhang))
no_overhang_left = split_length - overhang
sub_len.append(split_length)
elif no_overhang_left + min_unique_len >= length:
# Last slice, only overhang to the left
split_length = next_fast_len(min_len)
data_slices.append((length - split_length, length))
valid_data_slices.append((length - split_length + overhang, length))
sub_len.append(split_length)
break
else:
# Any other slice
split_length = next_fast_len(min_len + overhang)
left_overhang = (split_length - min_unique_len) // 2
temp_left = no_overhang_left - left_overhang
temp_right = temp_left + split_length
data_slices.append((temp_left, temp_right))
valid_data_slices.append((temp_left + overhang, temp_right - overhang))
sub_len.append(split_length)
no_overhang_left = temp_right - overhang
if split_length <= 0 or no_overhang_left <= 0:
raise RuntimeError(
f"Cannot generate legal splits for {length=}, {splits=}, {overhang=}"
)
# Now generate the best unique data points;
# we always pick the biggest data subset, or the left one on ties
unique_data = []
unique_left = 0
for i, (len1, len2) in enumerate(itt.pairwise(sub_len)):
if len1 >= len2:
right = valid_data_slices[i][1]
else:
right = valid_data_slices[i + 1][0]
unique_data.append((unique_left, right))
unique_left = right
# Add final part
if unique_left != length:
unique_data.append((unique_left, length))
# Make sure unique slices are unique and within valid data
last_right = 0
for (vd_left, vd_right), (ud_left, ud_right) in zip(valid_data_slices, unique_data):
if (
ud_left < vd_left
or ud_right > vd_right
or ud_right > length
or ud_left != last_right
): # pragma: no cover
raise RuntimeError(
f"We produced inconsistent slices for {length=}, {splits=}, {overhang=}"
)
last_right = ud_right
return list(zip(data_slices, unique_data))
class TMJobError(Exception):
"""TMJob Exception with provided message."""
def __init__(self, message):
# Call the base class constructor with the parameters it needs
super().__init__(message)
class TMJob:
def __init__(
self,
job_key: str,
log_level: int,
tomogram: pathlib.Path,
template: pathlib.Path,
mask: pathlib.Path,
output_dir: pathlib.Path,
angle_increment: Optional[Union[str, float]] = None,
mask_is_spherical: bool = True,
tilt_angles: Optional[list[float, ...]] = None,
tilt_weighting: bool = False,
search_x: Optional[list[int, int]] = None,
search_y: Optional[list[int, int]] = None,
search_z: Optional[list[int, int]] = None,
tomogram_mask: Optional[pathlib.Path] = None,
voxel_size: Optional[float] = None,
low_pass: Optional[float] = None,
high_pass: Optional[float] = None,
dose_accumulation: Optional[list[float, ...]] = None,
ctf_data: Optional[list[dict, ...]] = None,
whiten_spectrum: bool = False,
rotational_symmetry: int = 1,
pytom_tm_version_number: str = PYTOM_TM_VERSION,
job_loaded_for_extraction: bool = False,
particle_diameter: Optional[float] = None,
random_phase_correction: bool = False,
rng_seed: int = 321,
):
"""
Parameters
----------
job_key: str
job identifier
log_level: int
log level for logging module
tomogram: pathlib.Path
path to tomogram MRC
template: pathlib.Path
path to template MRC
mask: pathlib.Path
path to mask MRC
output_dir: pathlib.Path
path to output directory
angle_increment: Optional[Union[str, float]], default None
angular increment of the template search; if None it is calculated from
particle_diameter, or falls back to 7.00 degrees
mask_is_spherical: bool, default True
whether template mask is spherical, reduces computation complexity
tilt_angles: Optional[list[float, ...]], default None
tilt angles of the tilt-series used to reconstruct the tomogram; if only two
floats are provided they are used to generate a continuous wedge model
tilt_weighting: bool, default False
use advanced tilt weighting options, can be supplemented with CTF parameters
and accumulated dose
search_x: Optional[list[int, int]], default None
restrict tomogram search region along the x-axis
search_y: Optional[list[int, int]], default None
restrict tomogram search region along the y-axis
search_z: Optional[list[int, int]], default None
restrict tomogram search region along the z-axis
tomogram_mask: Optional[pathlib.Path], default None
when splitting the tomogram into subvolumes, only subjobs where
any(mask > 0) will be generated
voxel_size: Optional[float], default None
voxel size of tomogram and template (in A); if not provided it will be read
from the template/tomogram MRCs
low_pass: Optional[float], default None
optional low-pass filter (resolution in A) to apply to tomogram and template
high_pass: Optional[float], default None
optional high-pass filter (resolution in A) to apply to tomogram and
template
dose_accumulation: Optional[list[float, ...]], default None
list with dose accumulation per tilt image
ctf_data: Optional[list[dict, ...]], default None
list of dictionaries with CTF parameters per tilt image, see
pytom_tm.weights.create_ctf() for parameter definitions
whiten_spectrum: bool, default False
whether to apply spectrum whitening
rotational_symmetry: int, default 1
specify a rotational symmetry around the z-axis; this is only valid if the
symmetry axis of the template is aligned with the z-axis
pytom_tm_version_number: str, default current version
a string with the version number of pytom_tm for backward compatibility
job_loaded_for_extraction: bool, default False
flag to set for finished template matching jobs that are loaded back for
extraction; it prevents recomputation of the whitening filter
particle_diameter: Optional[float], default None
particle diameter (in Angstrom) to calculate angular search
random_phase_correction: bool, default False
run matching with a phase-randomized version of the template to correct
scores for noise
rng_seed: int, default 321
set a seed for the rng for phase randomization
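Examples
--------
A minimal sketch; all paths below are hypothetical:
>>> job = TMJob(
...     "0",
...     logging.INFO,
...     pathlib.Path("tomo200.mrc"),
...     pathlib.Path("template.mrc"),
...     pathlib.Path("mask.mrc"),
...     pathlib.Path("results/"),
...     angle_increment=7.0,
...     voxel_size=13.79,
... )  # doctest: +SKIP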
"""
self.mask = mask
self.mask_is_spherical = mask_is_spherical
self.output_dir = output_dir
self.tomogram = tomogram
self.template = template
self.tomo_id = self.tomogram.stem
try:
meta_data_tomo = read_mrc_meta_data(self.tomogram)
except UnequalSpacingError: # add information that the problem is the tomogram
raise UnequalSpacingError(
"Input tomogram voxel spacing is not equal in each dimension!"
)
try:
meta_data_template = read_mrc_meta_data(self.template)
except UnequalSpacingError: # add information that the problem is the template
raise UnequalSpacingError(
"Input template voxel spacing is not equal in each dimension!"
)
self.tomo_shape = meta_data_tomo["shape"]
self.template_shape = meta_data_template["shape"]
if voxel_size is not None:
if voxel_size <= 0:
raise ValueError(
"Invalid voxel size provided: smaller than or equal to zero."
)
self.voxel_size = voxel_size
if ( # allow tiny numerical differences that are not relevant for
# template matching
round(self.voxel_size, 3) != round(meta_data_tomo["voxel_size"], 3)
or round(self.voxel_size, 3)
!= round(meta_data_template["voxel_size"], 3)
):
logging.debug(
f"provided {self.voxel_size} tomogram "
f"{meta_data_tomo['voxel_size']} "
f"template {meta_data_template['voxel_size']}"
)
print(
"WARNING: Provided voxel size does not match voxel size annotated "
"in tomogram/template mrc."
)
elif (
round(meta_data_tomo["voxel_size"], 3)
== round(meta_data_template["voxel_size"], 3)
and meta_data_tomo["voxel_size"] > 0
):
self.voxel_size = round(meta_data_tomo["voxel_size"], 3)
else:
raise ValueError(
"Voxel size could not be assigned, either a mismatch between tomogram "
"and template or annotated as 0."
)
search_origin = [
x[0] if x is not None else 0 for x in (search_x, search_y, search_z)
]
# Check if tomogram origin is valid
if all([0 <= x < y for x, y in zip(search_origin, self.tomo_shape)]):
self.search_origin = search_origin
else:
raise ValueError("Invalid input provided for search origin of tomogram.")
# if the end is not valid, raise an error
search_end = []
for x, s in zip([search_x, search_y, search_z], self.tomo_shape):
if x is not None:
if not x[1] <= s:
raise ValueError(
"One of search end indices is larger than the tomogram "
"dimension."
)
search_end.append(x[1])
else:
search_end.append(s)
self.search_size = [
end - start for end, start in zip(search_end, self.search_origin)
]
logging.debug(f"origin, size = {self.search_origin}, {self.search_size}")
self.tomogram_mask = tomogram_mask
if tomogram_mask is not None:
temp = read_mrc(tomogram_mask)
if temp.shape != self.tomo_shape:
raise ValueError(
"Tomogram mask does not have the same number of pixels as the "
"tomogram.\n"
f"Tomogram mask shape: {temp.shape}, "
f"tomogram shape: {self.tomo_shape}"
)
if np.all(temp <= 0):
raise ValueError(
"No values larger than 0 found in the tomogram mask: "
f"{tomogram_mask}"
)
self.whole_start = None
# For the main job these are always [0,0,0] and self.search_size, for sub_jobs
# these will differ from self.search_origin and self.search_size. The main job
# only uses them to calculate the search_volume_roi for statistics. Sub jobs
# also use these to extract and place back the relevant region in the master
# job.
self.sub_start, self.sub_step = [0, 0, 0], self.search_size.copy()
# Rotation parameters
self.start_slice = 0
self.steps_slice = 1
self.rotational_symmetry = rotational_symmetry
self.particle_diameter = particle_diameter
# calculate increment from particle diameter
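# (a note on the formula below: this follows the Crowther criterion, where
# the angular increment is roughly max_res / particle_diameter in radians,
# converted here to degrees)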
if angle_increment is None:
if particle_diameter is not None:
max_res = max(
2 * self.voxel_size, low_pass if low_pass is not None else 0
)
angle_increment = np.rad2deg(max_res / particle_diameter)
else:
angle_increment = 7.0
self.rotation_file = angle_increment
# use a local name so the log_level parameter is not overwritten
if job_loaded_for_extraction:
angle_log_level = "DEBUG"
else:
angle_log_level = "INFO"
try:
angle_list = get_angle_list(
angle_increment,
sort_angles=False,
symmetry=rotational_symmetry,
log_level=angle_log_level,
)
except ValueError:
raise TMJobError("Invalid angular search provided.")
self.n_rotations = len(angle_list)
# missing wedge
self.tilt_angles = tilt_angles
self.tilt_weighting = tilt_weighting
# set the band-pass resolution shells
self.low_pass = low_pass
self.high_pass = high_pass
# set dose and ctf
self.dose_accumulation = dose_accumulation
self.ctf_data = ctf_data
self.whiten_spectrum = whiten_spectrum
self.whitening_filter = self.output_dir.joinpath(
f"{self.tomo_id}_whitening_filter.npy"
)
if self.whiten_spectrum and not job_loaded_for_extraction:
logging.info("Estimating whitening filter...")
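# the filter below is 1 / sqrt(radially averaged power spectrum) of the
# search region, which flattens (whitens) its amplitude spectrum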
weights = 1 / np.sqrt(
power_spectrum_profile(
read_mrc(self.tomogram)[
self.search_origin[0] : self.search_origin[0]
+ self.search_size[0],
self.search_origin[1] : self.search_origin[1]
+ self.search_size[1],
self.search_origin[2] : self.search_origin[2]
+ self.search_size[2],
]
)
)
weights /= weights.max() # scale to 1
np.save(self.whitening_filter, weights)
# phase randomization options
self.random_phase_correction = random_phase_correction
self.rng_seed = rng_seed
# Job details
self.job_key = job_key
self.leader = None # the job that spawned this job
self.sub_jobs = [] # if this job had no sub jobs it should be executed
# dict to keep track of job statistics
self.job_stats = None
self.log_level = log_level
# version number of the job
self.pytom_tm_version_number = pytom_tm_version_number
def copy(self) -> TMJob:
"""Create a copy of the TMJob
Returns
-------
job: TMJob
copied TMJob instance
"""
return copy.deepcopy(self)
def write_to_json(self, file_name: pathlib.Path) -> None:
"""Write job to .json file.
Parameters
----------
file_name: pathlib.Path
path to the output file
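Examples
--------
A minimal sketch; the file name is hypothetical:
>>> job.write_to_json(job.output_dir / "tomo200_job.json")  # doctest: +SKIP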
"""
d = self.__dict__.copy()
d.pop("sub_jobs")
d.pop("search_origin")
d.pop("search_size")
d["search_x"] = [
self.search_origin[0],
self.search_origin[0] + self.search_size[0],
]
d["search_y"] = [
self.search_origin[1],
self.search_origin[1] + self.search_size[1],
]
d["search_z"] = [
self.search_origin[2],
self.search_origin[2] + self.search_size[2],
]
for key, value in d.items():
if isinstance(value, pathlib.Path):
d[key] = str(value)
with open(file_name, "w") as fstream:
json.dump(d, fstream, indent=4)
def split_rotation_search(self, n: int) -> list[TMJob, ...]:
"""Split the search into sub_jobs by dividing the rotations. Sub jobs will
obtain the key self.job_key + str(i) when looping over range(n).
Parameters
----------
n: int
number of times to split the angular search
Returns
-------
sub_jobs: list[TMJob, ...]
a list of TMJobs that were split from self, the jobs are also assigned as
the TMJob.sub_jobs attribute
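Examples
--------
A minimal sketch, splitting the angular search over two workers:
>>> sub_jobs = job.split_rotation_search(2)  # doctest: +SKIP
>>> [(j.start_slice, j.steps_slice) for j in sub_jobs]  # doctest: +SKIP
[(0, 2), (1, 2)]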
"""
if len(self.sub_jobs) > 0:
raise TMJobError(
"Could not further split this job as it already has subjobs assigned!"
)
sub_jobs = []
for i in range(n):
new_job = self.copy()
new_job.start_slice = i
new_job.steps_slice = n
new_job.leader = self.job_key
new_job.job_key = self.job_key + str(i)
sub_jobs.append(new_job)
self.sub_jobs = sub_jobs
return self.sub_jobs
def split_volume_search(self, split: tuple[int, int, int]) -> list[TMJob, ...]:
"""Split the search into sub_jobs by dividing into subvolumes. The final
number of subvolumes is obtained by multiplying the elements of split
together, e.g. (2, 2, 1) results in 4 subvolumes. Sub jobs will obtain the
key self.job_key + str(i), where i enumerates the subvolumes.
The sub jobs search area of the full tomogram is defined by:
new_job.search_origin and new_job.search_size.
They are used when loading the search volume from the full tomogram.
The attribute new_job.whole_start defines how the volume maps back to the score
volume of the parent job (which can be a different size than the tomogram when
the search is restricted along x, y or z).
Finally, new_job.sub_start and new_job.sub_step extract the score and angle
map without the template overhang from the subvolume.
If self.tomogram_mask is set, we will skip subjobs where all(mask <= 0).
Parameters
----------
split: tuple[int, int, int]
tuple that defines how many times the search volume should be split into
subvolumes along each axis
Returns
-------
sub_jobs: list[TMJob, ...]
a list of TMJobs that were split from self, the jobs are also assigned as
the TMJob.sub_jobs attribute
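Examples
--------
A minimal sketch; without a tomogram mask (2, 2, 1) gives 4 sub jobs:
>>> sub_jobs = job.split_volume_search((2, 2, 1))  # doctest: +SKIP
>>> len(sub_jobs)  # doctest: +SKIP
4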
"""
if len(self.sub_jobs) > 0:
raise TMJobError(
"Could not further split this job as it already has subjobs assigned!"
)
search_size = self.search_size
if self.tomogram_mask is not None:
# This should have some positive values after the check in the __init__
tomogram_mask = read_mrc(self.tomogram_mask)
else:
tomogram_mask = None
# shape of template for overhang
overhang = self.template_shape
# use overhang//2 (+1 for odd sizes)
overhang = tuple(sum(divmod(o, 2)) for o in overhang)
x_splits = _determine_1D_fft_splits(search_size[0], split[0], overhang[0])
y_splits = _determine_1D_fft_splits(search_size[1], split[1], overhang[1])
z_splits = _determine_1D_fft_splits(search_size[2], split[2], overhang[2])
sub_jobs = []
for i, data_3D in enumerate(itt.product(x_splits, y_splits, z_splits)):
# each data point for each dim is slice(left, right) of the search space
# and slice(left, right) of the unique data points in the search space.
# See the comments on the new_job attributes below for the meaning of
# each attribute.
search_origin = tuple(
data_3D[d][0][0] + self.search_origin[d] for d in range(3)
)
search_size = tuple(dim_data[0][1] - dim_data[0][0] for dim_data in data_3D)
whole_start = tuple(dim_data[1][0] for dim_data in data_3D)
sub_start = tuple(dim_data[1][0] - dim_data[0][0] for dim_data in data_3D)
sub_step = tuple(dim_data[1][1] - dim_data[1][0] for dim_data in data_3D)
# check if any of the unique data points in this split are where
# tomo_mask > 0
if tomogram_mask is not None:
slices = [
slice(origin, origin + step)
for origin, step in zip(whole_start, sub_step)
]
if np.all(tomogram_mask[*slices] <= 0):
# No non-masked unique data-points, skipping
continue
new_job = self.copy()
new_job.leader = self.job_key
new_job.job_key = self.job_key + str(i)
# search origin with respect to the complete tomogram
new_job.search_origin = search_origin
# search size TODO: should be combined with the origin into slices
new_job.search_size = search_size
# whole start is the start of the unique data within the complete searched
# array
new_job.whole_start = whole_start
# sub_start is where the unique data starts inside the split array
new_job.sub_start = sub_start
# sub_step is the step of unique data inside the split array.
# TODO: should be slices instead
new_job.sub_step = sub_step
sub_jobs.append(new_job)
self.sub_jobs = sub_jobs
return self.sub_jobs
def merge_sub_jobs(
self, stats: Optional[list[dict, ...]] = None
) -> tuple[npt.NDArray[float], npt.NDArray[float]]:
"""Merge the sub jobs present in self.sub_jobs together to create the final
output score and angle maps.
Parameters
----------
stats: Optional[list[dict, ...]], default None
optional list of sub job statistics to merge together
Returns
-------
output: tuple[npt.NDArray[float], npt.NDArray[float]]
the merged score and angle maps from the subjobs
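Examples
--------
A minimal sketch, assuming each sub job has already been run:
>>> for sub_job in job.sub_jobs:  # doctest: +SKIP
...     stats = sub_job.start_job(gpu_id=0)
>>> scores, angles = job.merge_sub_jobs()  # doctest: +SKIP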
"""
if len(self.sub_jobs) == 0:
# read the volumes, remove them and return them
score_file, angle_file = (
self.output_dir.joinpath(f"{self.tomo_id}_scores_{self.job_key}.mrc"),
self.output_dir.joinpath(f"{self.tomo_id}_angles_{self.job_key}.mrc"),
)
result = (read_mrc(score_file), read_mrc(angle_file))
(score_file.unlink(), angle_file.unlink())
return result
if stats is not None:
search_space = sum([s["search_space"] for s in stats])
variance = sum([s["variance"] for s in stats]) / len(stats)
self.job_stats = {
"search_space": search_space,
"variance": variance,
"std": np.sqrt(variance),
}
is_subvolume_split = np.all(
np.array([x.start_slice for x in self.sub_jobs]) == 0
)
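# rotation splits give each sub job a distinct start_slice, while volume
# splits leave start_slice at 0 for every sub job, so the merge strategy
# can be decided from the start slices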
score_volumes, angle_volumes = [], []
for x in self.sub_jobs:
result = x.merge_sub_jobs()
score_volumes.append(result[0])
angle_volumes.append(result[1])
if not is_subvolume_split:
scores, angles = (
np.zeros_like(score_volumes[0]) - 1.0,
np.zeros_like(angle_volumes[0]) - 1.0,
)
for s, a in zip(score_volumes, angle_volumes):
angles = np.where(s > scores, a, angles)
# prevents race condition due to slicing
angles = np.where(s == scores, np.minimum(a, angles), angles)
scores = np.where(s > scores, s, scores)
else:
scores, angles = (
np.zeros(self.search_size, dtype=np.float32),
np.zeros(self.search_size, dtype=np.float32),
)
for job, s, a in zip(self.sub_jobs, score_volumes, angle_volumes):
sub_scores = s[
job.sub_start[0] : job.sub_start[0] + job.sub_step[0],
job.sub_start[1] : job.sub_start[1] + job.sub_step[1],
job.sub_start[2] : job.sub_start[2] + job.sub_step[2],
]
sub_angles = a[
job.sub_start[0] : job.sub_start[0] + job.sub_step[0],
job.sub_start[1] : job.sub_start[1] + job.sub_step[1],
job.sub_start[2] : job.sub_start[2] + job.sub_step[2],
]
# Then the corrected sub part needs to be placed back into the full
# volume
scores[
job.whole_start[0] : job.whole_start[0] + sub_scores.shape[0],
job.whole_start[1] : job.whole_start[1] + sub_scores.shape[1],
job.whole_start[2] : job.whole_start[2] + sub_scores.shape[2],
] = sub_scores
angles[
job.whole_start[0] : job.whole_start[0] + sub_scores.shape[0],
job.whole_start[1] : job.whole_start[1] + sub_scores.shape[1],
job.whole_start[2] : job.whole_start[2] + sub_scores.shape[2],
] = sub_angles
return scores, angles
def start_job(
self, gpu_id: int, return_volumes: bool = False
) -> Union[tuple[npt.NDArray[float], npt.NDArray[float]], dict]:
"""Run this template matching job on the specified GPU. Search statistics
of the job will always be assigned to self.job_stats.
Parameters
----------
gpu_id: int
index of the GPU to run the job on
return_volumes: bool, default False
False (default) does not return volumes but writes them to disk; set to
True to directly return the score and angle volumes
Returns
-------
output: Union[tuple[npt.NDArray[float], npt.NDArray[float]], dict]
when volumes are returned the output consists of two numpy arrays (score and
angle map), when no volumes are returned the output consists of a dictionary
with search statistics
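Examples
--------
A minimal sketch on GPU 0:
>>> stats = job.start_job(0)  # doctest: +SKIP
>>> scores, angles = job.start_job(0, return_volumes=True)  # doctest: +SKIP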
"""
# next fast fft len
logging.debug(
"Next fast fft shape: "
f"{tuple([next_fast_len(s, real=True) for s in self.search_size])}"
)
search_volume = np.zeros(
tuple([next_fast_len(s, real=True) for s in self.search_size]),
dtype=np.float32,
)
# load the (sub)volume
search_volume[
: self.search_size[0], : self.search_size[1], : self.search_size[2]
] = np.ascontiguousarray(
read_mrc(self.tomogram)[
self.search_origin[0] : self.search_origin[0] + self.search_size[0],
self.search_origin[1] : self.search_origin[1] + self.search_size[1],
self.search_origin[2] : self.search_origin[2] + self.search_size[2],
]
)
# load template and mask
template, mask = (read_mrc(self.template), read_mrc(self.mask))
# apply mask directly to prevent any wedge convolution with weird edges
template *= mask
# init tomogram and template weighting
tomo_filter, template_wedge = 1, 1
# first generate bandpass filters
if not (self.low_pass is None and self.high_pass is None):
tomo_filter *= create_gaussian_band_pass(
search_volume.shape, self.voxel_size, self.low_pass, self.high_pass
).astype(np.float32)
template_wedge *= create_gaussian_band_pass(
self.template_shape, self.voxel_size, self.low_pass, self.high_pass
).astype(np.float32)
# then multiply with optional whitening filters
if self.whiten_spectrum:
tomo_filter *= profile_to_weighting(
np.load(self.whitening_filter), search_volume.shape
).astype(np.float32)
template_wedge *= profile_to_weighting(
np.load(self.whitening_filter), self.template_shape
).astype(np.float32)
# if tilt angles are provided we can create wedge filters
if self.tilt_angles is not None:
# for the tomogram a binary wedge is generated to explicitly set the
# missing wedge region to 0
tomo_filter *= create_wedge(
search_volume.shape,
self.tilt_angles,
self.voxel_size,
cut_off_radius=1.0,
angles_in_degrees=True,
tilt_weighting=False,
).astype(np.float32)
# for the template a binary or per-tilt-weighted wedge is generated
# depending on the options
template_wedge *= create_wedge(
self.template_shape,
self.tilt_angles,
self.voxel_size,
cut_off_radius=1.0,
angles_in_degrees=True,
tilt_weighting=self.tilt_weighting,
accumulated_dose_per_tilt=self.dose_accumulation,
ctf_params_per_tilt=self.ctf_data,
).astype(np.float32)
if logging.DEBUG >= logging.root.level:
write_mrc(
self.output_dir.joinpath("template_psf.mrc"),
template_wedge,
self.voxel_size,
)
write_mrc(
self.output_dir.joinpath("template_convolved.mrc"),
irfftn(rfftn(template) * template_wedge, s=template.shape),
self.voxel_size,
)
# apply the optional band pass and whitening filter to the search region
search_volume = np.real(
irfftn(rfftn(search_volume) * tomo_filter, s=search_volume.shape)
)
# load rotation search
angle_ids = list(range(self.start_slice, self.n_rotations, self.steps_slice))
angle_list = get_angle_list(
self.rotation_file,
sort_angles=version.parse(self.pytom_tm_version_number)
> version.parse("0.3.0"),
symmetry=self.rotational_symmetry,
)
angle_list = angle_list[
slice(self.start_slice, self.n_rotations, self.steps_slice)
]
# slices for relevant part for job statistics
search_volume_roi = (
slice(self.sub_start[0], self.sub_start[0] + self.sub_step[0]),
slice(self.sub_start[1], self.sub_start[1] + self.sub_step[1]),
slice(self.sub_start[2], self.sub_start[2] + self.sub_step[2]),
)
tm = TemplateMatchingGPU(
job_id=self.job_key,
device_id=gpu_id,
volume=search_volume,
template=template,
mask=mask,
angle_list=angle_list,
angle_ids=angle_ids,
mask_is_spherical=self.mask_is_spherical,
wedge=template_wedge,
stats_roi=search_volume_roi,
noise_correction=self.random_phase_correction,
rng_seed=self.rng_seed,
)
results = tm.run()
score_volume = results[0][
: self.search_size[0], : self.search_size[1], : self.search_size[2]
]
angle_volume = results[1][
: self.search_size[0], : self.search_size[1], : self.search_size[2]
]
self.job_stats = results[2]
del tm # delete the template matching plan
if return_volumes:
return score_volume, angle_volume
else: # otherwise write them out with job_key
write_mrc(
self.output_dir.joinpath(f"{self.tomo_id}_scores_{self.job_key}.mrc"),
score_volume,
self.voxel_size,
)
write_mrc(
self.output_dir.joinpath(f"{self.tomo_id}_angles_{self.job_key}.mrc"),
angle_volume,
self.voxel_size,
)
return self.job_stats