Skip to content

Commit

Permalink
Visualize registration items
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Aug 12, 2024
1 parent f468869 commit 8742423
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 0 deletions.
24 changes: 24 additions & 0 deletions direct/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,14 @@ def validation_loop(
curr_data_loader,
loss_fns,
)
if isinstance(visualize_slices, tuple):
(visualize_slices, visualize_registration_slices) = visualize_slices
else:
visualize_registration_slices = None
if isinstance(visualize_target, tuple):
(visualize_target, visualize_registration_target) = visualize_target
else:
visualize_registration_target = None

if experiment_directory:
json_output_fn = experiment_directory / f"metrics_val_{curr_dataset_name}_{iter_idx}.json"
Expand Down Expand Up @@ -487,6 +495,12 @@ def validation_loop(
visualize_slices = self.process_slices_for_visualization(visualize_slices, visualize_target)
storage.add_image(f"{key_prefix}prediction", visualize_slices)

if visualize_registration_slices is not None:
visualize_registration_slices = self.process_slices_for_visualization(
visualize_registration_slices, visualize_registration_target
)
storage.add_image(f"{key_prefix}registration_prediction", visualize_registration_slices)

if visualize_mask is not None:
visualize_mask = make_grid(
crop_to_largest([normalize_image(image) for image in visualize_mask], pad_value=0),
Expand All @@ -504,6 +518,16 @@ def validation_loop(
)
storage.add_image(f"{key_prefix}target", visualize_target)

if visualize_registration_target is not None:
visualize_registration_target = make_grid(
crop_to_largest(
[normalize_image(image) for image in visualize_registration_target], pad_value=0
),
nrow=self.cfg.logging.tensorboard.num_images, # type: ignore
scale_each=True,
)
storage.add_image(f"{key_prefix}registration_target", visualize_registration_target)

self.logger.info("Done evaluation of %s at iteration %s.", str(curr_dataset_name), str(iter_idx))
self.model.train()

Expand Down
126 changes: 126 additions & 0 deletions direct/nn/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,9 @@ def reconstruct_volumes( # type: ignore
last_filename = None # At the start of evaluation, there are no filenames.
curr_volume = None
curr_target = None
if "registration_model" in self.models:
curr_registration_volume = None
curr_registration_target = None
instance_counter = 0
filenames_seen = 0

Expand Down Expand Up @@ -930,6 +933,9 @@ def reconstruct_volumes( # type: ignore
# Compute output
iteration_output = self._do_iteration(data, loss_fns=loss_fns, regularizer_fns=regularizer_fns)
output = iteration_output.output_image
if "registration_model" in self.models:
output, registered_output = output

sampling_mask = iteration_output.sampling_mask
if sampling_mask is not None:
sampling_mask = sampling_mask.squeeze(-1).float() # Last dimension is 1 (complex dim)
Expand All @@ -942,6 +948,13 @@ def reconstruct_volumes( # type: ignore
resolution=resolution,
complex_axis=self._complex_dim,
)
if "registration_model" in self.models:
registered_output_abs = _process_output(
registered_output,
scaling_factors,
resolution=resolution,
complex_axis=self._complex_dim,
)

if add_target:
target_abs = _process_output(
Expand All @@ -951,9 +964,23 @@ def reconstruct_volumes( # type: ignore
complex_axis=self._complex_dim,
)

if "registration_model" in self.models:
registration_target_abs = _process_output(
data["reference_image"],
scaling_factors,
resolution=resolution,
complex_axis=self._complex_dim,
)

if curr_volume is None:
volume_size = len(data_loader.batch_sampler.sampler.volume_indices[filename]) # type: ignore
curr_volume = torch.zeros(*(volume_size, *output_abs.shape[1:]), dtype=output_abs.dtype)

if "registration_model" in self.models:
curr_registration_volume = torch.zeros(
*(volume_size, *registered_output_abs.shape[1:]), dtype=registered_output_abs.dtype
)

curr_mask = (
torch.zeros(*(volume_size, *sampling_mask.shape[1:]), dtype=sampling_mask.dtype)
if sampling_mask is not None
Expand All @@ -962,12 +989,22 @@ def reconstruct_volumes( # type: ignore
loss_dict_list.append(loss_dict)
if add_target:
curr_target = curr_volume.clone()
if "registration_model" in self.models:
curr_registration_target = curr_registration_volume.clone()[:, :, 0]

curr_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = output_abs.cpu()
if "registration_model" in self.models:
curr_registration_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = (
registered_output_abs.cpu()
)
if sampling_mask is not None:
curr_mask[instance_counter : instance_counter + output_abs.shape[0], ...] = sampling_mask.cpu()
if add_target:
curr_target[instance_counter : instance_counter + output_abs.shape[0], ...] = target_abs.cpu() # type: ignore
if "registration_model" in self.models:
curr_registration_target[instance_counter : instance_counter + output_abs.shape[0], ...] = (
registration_target_abs.cpu()
)

instance_counter += output_abs.shape[0]

Expand All @@ -985,6 +1022,13 @@ def reconstruct_volumes( # type: ignore
)
# Maybe not needed.
del data

if "registration_model" in self.models:
curr_volume = (curr_volume, curr_registration_volume)

if add_target and "registration_model" in self.models:
curr_target = (curr_target, curr_registration_target)

yield (
(curr_volume, curr_target, curr_mask, reduce_list_of_dicts(loss_dict_list), filename)
if add_target
Expand Down Expand Up @@ -1014,19 +1058,51 @@ def reconstruct_and_evaluate( # type: ignore
)
):
volume, target, mask, volume_loss_dict, filename = output
if isinstance(volume, tuple):
volume, registration_volume = volume
else:
registration_volume = None
if isinstance(target, tuple):
target, registration_target = target
else:
registration_target = None
if self.ndim == 3:
# Put slice and time data together
sc, c, z, x, y = volume.shape
volume_for_eval = volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
if registration_volume is not None:
registration_volume_for_eval = registration_volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
target_for_eval = target.clone().transpose(1, 2).reshape(sc * z, c, x, y)
if registration_target is not None:
registration_target_for_eval = (
registration_target.clone()
.unsqueeze(2)
.transpose(1, 2)
.tile(1, z, 1, 1, 1)
.reshape(sc * z, c, x, y)
)
else:
volume_for_eval = volume.clone()
target_for_eval = target.clone()
if registration_volume is not None or registration_target is not None:
raise NotImplementedError("Registration not implemented for 2D data.")

curr_metrics = {
metric_name: metric_fn(target_for_eval, volume_for_eval).clone().item()
for metric_name, metric_fn in inf_metrics.items()
}

if registration_volume is not None and registration_target is not None:
curr_metrics.update(
{
"registration_"
+ metric_name: metric_fn(registration_target_for_eval, registration_volume_for_eval)
.clone()
.item()
for metric_name, metric_fn in inf_metrics.items()
}
)

del target, target_for_eval

curr_metrics_string = ", ".join([f"{x}: {float(y)}" for x, y in curr_metrics.items()])
Expand Down Expand Up @@ -1083,28 +1159,62 @@ def evaluate( # type: ignore
visualize_slices: list[np.ndarray] = []
visualize_mask: list[np.ndarray] = []
visualize_target: list[np.ndarray] = []
visualize_registration_slices: list[np.ndarray] | None = None
visualize_registration_target: list[np.ndarray] | None = None

for _, output in enumerate(
self.reconstruct_volumes(
data_loader, loss_fns=loss_fns, add_target=True, crop=self.cfg.validation.crop # type: ignore
)
):
volume, target, mask, volume_loss_dict, filename = output
if isinstance(volume, tuple):
volume, registration_volume = volume
else:
registration_volume = None
if isinstance(target, tuple):
target, registration_target = target
else:
registration_target = None

if self.ndim == 3:
# Put slice and time data together
sc, c, z, x, y = volume.shape
volume_for_eval = volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
target_for_eval = target.clone().transpose(1, 2).reshape(sc * z, c, x, y)

if registration_volume is not None:
registration_volume_for_eval = registration_volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
registration_target_for_eval = (
registration_target.clone()
.unsqueeze(2)
.transpose(1, 2)
.tile(1, z, 1, 1, 1)
.reshape(sc * z, c, x, y)
)
else:
volume_for_eval = volume.clone()
target_for_eval = target.clone()
if registration_volume is not None or registration_target is not None:
raise NotImplementedError("Registration not implemented for 2D data.")

curr_metrics = {
metric_name: metric_fn(target_for_eval, volume_for_eval).clone()
for metric_name, metric_fn in volume_metrics.items()
}
del volume_for_eval, target_for_eval

# Calculate image metrics for registered images
if registration_volume is not None and registration_target is not None:
curr_metrics.update(
{
"registration_"
+ metric_name: metric_fn(registration_target_for_eval, registration_volume_for_eval).clone()
for metric_name, metric_fn in volume_metrics.items()
}
)
del registration_volume_for_eval, registration_target_for_eval

curr_metrics_string = ", ".join([f"{x}: {float(y)}" for x, y in curr_metrics.items()])
self.logger.info("Metrics for %s: %s", filename, curr_metrics_string)
# TODO: Path can be tricky if it is not unique (e.g. image.h5)
Expand All @@ -1119,10 +1229,23 @@ def evaluate( # type: ignore
target = torch.cat([target[:, :, _] for _ in range(0, z)], dim=2)
mask = torch.cat([mask[:, :, _] for _ in range(0, mask.shape[2])], dim=2)

# Also visualize registration items
if registration_volume is not None:
if visualize_registration_slices is None:
visualize_registration_slices = []
visualize_registration_target = []
registration_target = torch.cat([registration_target] * registration_volume.shape[2], dim=2)
registration_volume = torch.cat(
[registration_volume[:, :, _] for _ in range(0, registration_volume.shape[2])], dim=2
)

visualize_slices.append(volume[volume.shape[0] // 2])
if mask is not None:
visualize_mask.append(mask[mask.shape[0] // 2])
visualize_target.append(target[target.shape[0] // 2])
if registration_volume is not None:
visualize_registration_slices.append(registration_volume[registration_volume.shape[0] // 2])
visualize_registration_target.append(registration_target[registration_target.shape[0] // 2])

# Average loss dict
loss_dict = reduce_list_of_dicts(val_losses)
Expand All @@ -1136,6 +1259,9 @@ def evaluate( # type: ignore

if len(visualize_mask) == 0:
visualize_mask = None
if visualize_registration_slices is not None:
visualize_slices = (visualize_slices, visualize_registration_slices)
visualize_target = (visualize_target, visualize_registration_target)
return loss_dict, all_gathered_metrics, visualize_slices, visualize_mask, visualize_target

def compute_model_per_coil(self, model_name: str, data: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit 8742423

Please sign in to comment.