diff --git a/direct/nn/transformers/config.py b/direct/nn/transformers/config.py index 91eb01d1..4cc6a34d 100644 --- a/direct/nn/transformers/config.py +++ b/direct/nn/transformers/config.py @@ -42,9 +42,11 @@ class ImageDomainMRIViT3DConfig(MRIViTConfig): class KSpaceDomainMRIViT2DConfig(MRIViTConfig): average_size: tuple[int, int] = (320, 320) patch_size: tuple[int, int] = (16, 16) + compute_per_coil: bool = True @dataclass class KSpaceDomainMRIViT3DConfig(MRIViTConfig): average_size: tuple[int, int] = (320, 320, 320) patch_size: tuple[int, int] = (16, 16, 16) + compute_per_coil: bool = True