Skip to content

Commit

Permalink
Improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Sep 21, 2023
1 parent d6d6d58 commit 781322e
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
10 changes: 4 additions & 6 deletions direct/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def __init__(
compute_mask: bool = False,
kspace_context: Optional[str] = None,
) -> None:
# pylint: disable=too-many-arguments
"""Inits :class:`CMRxReconDataset`.
Parameters
Expand Down Expand Up @@ -464,7 +465,6 @@ def __init__(
will be loaded (3D data). Default: None.
"""
# pylint: disable=too-many-arguments
self.logger = logging.getLogger(type(self).__name__)

self.root = pathlib.Path(data_root)
Expand Down Expand Up @@ -576,11 +576,9 @@ def verify_extra_mat_integrity(image_fn, _, extra_mats):
for key in extra_mats:
mat_key, path = extra_mats[key]
extra_fn = path / image_fn.name
try:
with h5py.File(extra_fn, "r") as file:
_ = file[mat_key].shape
except Exception as exc:
raise ValueError(f"Reading of {extra_fn} for key {mat_key} failed: {exc}.") from exc
with h5py.File(extra_fn, "r") as file:
_ = file[mat_key].shape
return

def __len__(self):
return len(self.data)
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/vsharp/vsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __init__(
# pylint: disable=too-many-locals
super().__init__()
for extra_key in kwargs:
if extra_key != "model_name" or extra_key.startswith("image_"):
if extra_key != "model_name" and not extra_key.startswith("image_"):
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.num_steps = num_steps
self.num_steps_dc_gd = num_steps_dc_gd
Expand Down
12 changes: 4 additions & 8 deletions direct/nn/vsharp/vsharp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ def _do_iteration(
loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, output_image, None, auxiliary_loss_weights[i]
)

loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, None, output_kspace, auxiliary_loss_weights[i]
)
# Compute loss on k-space
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace)

loss = sum(loss_dict.values()) # type: ignore

Expand Down Expand Up @@ -229,10 +227,8 @@ def _do_iteration(
loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, output_image, None, auxiliary_loss_weights[i]
)

loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, None, output_kspace, auxiliary_loss_weights[i]
)
# Compute loss on k-space
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace)

loss = sum(loss_dict.values()) # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def create_test_cfg(
)
@pytest.mark.parametrize(
"loss_fns",
[["l1_loss", "ssim_loss", "l2_loss"]],
[["l1_loss", "ssim_loss", "l2_loss", "snr_loss", "psnr_loss"]],
)
@pytest.mark.parametrize(
"train_iters, val_iters, checkpointer_iters",
Expand Down
12 changes: 11 additions & 1 deletion tests/tests_nn/test_vsharp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,17 @@ def test_unet_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filters, n
)
@pytest.mark.parametrize(
"loss_fns",
[["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]],
[
[
"l1_loss",
"hfen_l1_loss",
"hfen_l2_loss",
"hfen_l1_norm_loss",
"hfen_l2_norm_loss",
"kspace_nmse_loss",
"kspace_nmae_loss",
]
],
)
@pytest.mark.parametrize(
"num_steps, num_steps_dc_gd, num_filters, num_pool_layers",
Expand Down

0 comments on commit 781322e

Please sign in to comment.