diff --git a/direct/engine.py b/direct/engine.py
index c599a206..6b31ae0f 100644
--- a/direct/engine.py
+++ b/direct/engine.py
@@ -460,6 +460,14 @@ def validation_loop(
                 curr_data_loader,
                 loss_fns,
             )
+            if isinstance(visualize_slices, tuple):
+                (visualize_slices, visualize_registration_slices) = visualize_slices
+            else:
+                visualize_registration_slices = None
+            if isinstance(visualize_target, tuple):
+                (visualize_target, visualize_registration_target) = visualize_target
+            else:
+                visualize_registration_target = None

             if experiment_directory:
                 json_output_fn = experiment_directory / f"metrics_val_{curr_dataset_name}_{iter_idx}.json"
@@ -487,6 +495,12 @@ def validation_loop(
                 visualize_slices = self.process_slices_for_visualization(visualize_slices, visualize_target)
                 storage.add_image(f"{key_prefix}prediction", visualize_slices)

+                if visualize_registration_slices is not None:
+                    visualize_registration_slices = self.process_slices_for_visualization(
+                        visualize_registration_slices, visualize_registration_target
+                    )
+                    storage.add_image(f"{key_prefix}registration_prediction", visualize_registration_slices)
+
                 if visualize_mask is not None:
                     visualize_mask = make_grid(
                         crop_to_largest([normalize_image(image) for image in visualize_mask], pad_value=0),
@@ -504,6 +518,16 @@ def validation_loop(
                     )
                     storage.add_image(f"{key_prefix}target", visualize_target)

+                    if visualize_registration_target is not None:
+                        visualize_registration_target = make_grid(
+                            crop_to_largest(
+                                [normalize_image(image) for image in visualize_registration_target], pad_value=0
+                            ),
+                            nrow=self.cfg.logging.tensorboard.num_images,  # type: ignore
+                            scale_each=True,
+                        )
+                        storage.add_image(f"{key_prefix}registration_target", visualize_registration_target)
+
             self.logger.info("Done evaluation of %s at iteration %s.", str(curr_dataset_name), str(iter_idx))

         self.model.train()
diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py
index 9f5eea53..85e151b4 100644
--- a/direct/nn/mri_models.py
+++ b/direct/nn/mri_models.py
@@ -900,6 +900,9 @@ def reconstruct_volumes(  # type: ignore
         last_filename = None  # At the start of evaluation, there are no filenames.
         curr_volume = None
         curr_target = None
+        if "registration_model" in self.models:
+            curr_registration_volume = None
+            curr_registration_target = None
         instance_counter = 0
         filenames_seen = 0

@@ -930,6 +933,9 @@
             # Compute output
             iteration_output = self._do_iteration(data, loss_fns=loss_fns, regularizer_fns=regularizer_fns)
             output = iteration_output.output_image
+            if "registration_model" in self.models:
+                output, registered_output = output
+
             sampling_mask = iteration_output.sampling_mask
             if sampling_mask is not None:
                 sampling_mask = sampling_mask.squeeze(-1).float()  # Last dimension is 1 (complex dim)
@@ -942,6 +948,13 @@
                 resolution=resolution,
                 complex_axis=self._complex_dim,
             )
+            if "registration_model" in self.models:
+                registered_output_abs = _process_output(
+                    registered_output,
+                    scaling_factors,
+                    resolution=resolution,
+                    complex_axis=self._complex_dim,
+                )

             if add_target:
                 target_abs = _process_output(
@@ -951,9 +964,23 @@
                     complex_axis=self._complex_dim,
                 )

+                if "registration_model" in self.models:
+                    registration_target_abs = _process_output(
+                        data["reference_image"],
+                        scaling_factors,
+                        resolution=resolution,
+                        complex_axis=self._complex_dim,
+                    )
+
             if curr_volume is None:
                 volume_size = len(data_loader.batch_sampler.sampler.volume_indices[filename])  # type: ignore
                 curr_volume = torch.zeros(*(volume_size, *output_abs.shape[1:]), dtype=output_abs.dtype)
+
+                if "registration_model" in self.models:
+                    curr_registration_volume = torch.zeros(
+                        *(volume_size, *registered_output_abs.shape[1:]), dtype=registered_output_abs.dtype
+                    )
+
                 curr_mask = (
                     torch.zeros(*(volume_size, *sampling_mask.shape[1:]), dtype=sampling_mask.dtype)
                     if sampling_mask is not None
@@ -962,12 +989,22 @@
                 loss_dict_list.append(loss_dict)
                 if add_target:
                     curr_target = curr_volume.clone()
+                    if "registration_model" in self.models:
+                        curr_registration_target = curr_registration_volume.clone()[:, :, 0]

             curr_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = output_abs.cpu()
+            if "registration_model" in self.models:
+                curr_registration_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = (
+                    registered_output_abs.cpu()
+                )
             if sampling_mask is not None:
                 curr_mask[instance_counter : instance_counter + output_abs.shape[0], ...] = sampling_mask.cpu()
             if add_target:
                 curr_target[instance_counter : instance_counter + output_abs.shape[0], ...] = target_abs.cpu()  # type: ignore
+                if "registration_model" in self.models:
+                    curr_registration_target[instance_counter : instance_counter + output_abs.shape[0], ...] = (
+                        registration_target_abs.cpu()
+                    )

             instance_counter += output_abs.shape[0]

@@ -985,6 +1022,13 @@
                 )
                 # Maybe not needed.
                 del data
+
+                if "registration_model" in self.models:
+                    curr_volume = (curr_volume, curr_registration_volume)
+
+                if add_target and "registration_model" in self.models:
+                    curr_target = (curr_target, curr_registration_target)
+
                 yield (
                     (curr_volume, curr_target, curr_mask, reduce_list_of_dicts(loss_dict_list), filename)
                     if add_target
@@ -1014,19 +1058,51 @@ def reconstruct_and_evaluate(  # type: ignore
             )
         ):
             volume, target, mask, volume_loss_dict, filename = output
+            if isinstance(volume, tuple):
+                volume, registration_volume = volume
+            else:
+                registration_volume = None
+            if isinstance(target, tuple):
+                target, registration_target = target
+            else:
+                registration_target = None
             if self.ndim == 3:
                 # Put slice and time data together
                 sc, c, z, x, y = volume.shape
                 volume_for_eval = volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
+                if registration_volume is not None:
+                    registration_volume_for_eval = registration_volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
                 target_for_eval = target.clone().transpose(1, 2).reshape(sc * z, c, x, y)
+                if registration_target is not None:
+                    registration_target_for_eval = (
+                        registration_target.clone()
+                        .unsqueeze(2)
+                        .transpose(1, 2)
+                        .tile(1, z, 1, 1, 1)
+                        .reshape(sc * z, c, x, y)
+                    )
             else:
                 volume_for_eval = volume.clone()
                 target_for_eval = target.clone()
+                if registration_volume is not None or registration_target is not None:
+                    raise NotImplementedError("Registration not implemented for 2D data.")

             curr_metrics = {
                 metric_name: metric_fn(target_for_eval, volume_for_eval).clone().item()
                 for metric_name, metric_fn in inf_metrics.items()
             }
+
+            if registration_volume is not None and registration_target is not None:
+                curr_metrics.update(
+                    {
+                        "registration_"
+                        + metric_name: metric_fn(registration_target_for_eval, registration_volume_for_eval)
+                        .clone()
+                        .item()
+                        for metric_name, metric_fn in inf_metrics.items()
+                    }
+                )
+
             del target, target_for_eval

             curr_metrics_string = ", ".join([f"{x}: {float(y)}" for x, y in curr_metrics.items()])
@@ -1083,6 +1159,8 @@ def evaluate(  # type: ignore
         visualize_slices: list[np.ndarray] = []
         visualize_mask: list[np.ndarray] = []
         visualize_target: list[np.ndarray] = []
+        visualize_registration_slices: list[np.ndarray] | None = None
+        visualize_registration_target: list[np.ndarray] | None = None

         for _, output in enumerate(
             self.reconstruct_volumes(
@@ -1090,14 +1168,35 @@
             )
         ):
             volume, target, mask, volume_loss_dict, filename = output
+            if isinstance(volume, tuple):
+                volume, registration_volume = volume
+            else:
+                registration_volume = None
+            if isinstance(target, tuple):
+                target, registration_target = target
+            else:
+                registration_target = None
+
             if self.ndim == 3:
                 # Put slice and time data together
                 sc, c, z, x, y = volume.shape
                 volume_for_eval = volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
                 target_for_eval = target.clone().transpose(1, 2).reshape(sc * z, c, x, y)
+
+                if registration_volume is not None:
+                    registration_volume_for_eval = registration_volume.clone().transpose(1, 2).reshape(sc * z, c, x, y)
+                    registration_target_for_eval = (
+                        registration_target.clone()
+                        .unsqueeze(2)
+                        .transpose(1, 2)
+                        .tile(1, z, 1, 1, 1)
+                        .reshape(sc * z, c, x, y)
+                    )
             else:
                 volume_for_eval = volume.clone()
                 target_for_eval = target.clone()
+                if registration_volume is not None or registration_target is not None:
+                    raise NotImplementedError("Registration not implemented for 2D data.")

             curr_metrics = {
                 metric_name: metric_fn(target_for_eval, volume_for_eval).clone()
@@ -1105,6 +1204,17 @@ def evaluate(  # type: ignore
             }
             del volume_for_eval, target_for_eval

+            # Calculate image metrics for registered images
+            if registration_volume is not None and registration_target is not None:
+                curr_metrics.update(
+                    {
+                        "registration_"
+                        + metric_name: metric_fn(registration_target_for_eval, registration_volume_for_eval).clone()
+                        for metric_name, metric_fn in volume_metrics.items()
+                    }
+                )
+                del registration_volume_for_eval, registration_target_for_eval
+
             curr_metrics_string = ", ".join([f"{x}: {float(y)}" for x, y in curr_metrics.items()])
             self.logger.info("Metrics for %s: %s", filename, curr_metrics_string)
             # TODO: Path can be tricky if it is not unique (e.g. image.h5)
@@ -1119,10 +1229,23 @@
                 target = torch.cat([target[:, :, _] for _ in range(0, z)], dim=2)
                 mask = torch.cat([mask[:, :, _] for _ in range(0, mask.shape[2])], dim=2)

+                # Also visualize registration items
+                if registration_volume is not None:
+                    if visualize_registration_slices is None:
+                        visualize_registration_slices = []
+                        visualize_registration_target = []
+                    registration_target = torch.cat([registration_target] * registration_volume.shape[2], dim=2)
+                    registration_volume = torch.cat(
+                        [registration_volume[:, :, _] for _ in range(0, registration_volume.shape[2])], dim=2
+                    )
+
             visualize_slices.append(volume[volume.shape[0] // 2])
             if mask is not None:
                 visualize_mask.append(mask[mask.shape[0] // 2])
             visualize_target.append(target[target.shape[0] // 2])
+            if registration_volume is not None:
+                visualize_registration_slices.append(registration_volume[registration_volume.shape[0] // 2])
+                visualize_registration_target.append(registration_target[registration_target.shape[0] // 2])

         # Average loss dict
         loss_dict = reduce_list_of_dicts(val_losses)
@@ -1136,6 +1259,9 @@
         if len(visualize_mask) == 0:
             visualize_mask = None

+        if visualize_registration_slices is not None:
+            visualize_slices = (visualize_slices, visualize_registration_slices)
+            visualize_target = (visualize_target, visualize_registration_target)
         return loss_dict, all_gathered_metrics, visualize_slices, visualize_mask, visualize_target

     def compute_model_per_coil(self, model_name: str, data: torch.Tensor) -> torch.Tensor:
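The change above threads a second, registration-specific output through the evaluation path: when "registration_model" is present in self.models, reconstruct_volumes yields (volume, registration_volume) and (target, registration_target) pairs, evaluate packs the corresponding visualization lists into tuples, and validation_loop unpacks them again with isinstance checks. The snippet below is a minimal sketch, not part of the diff, of how a downstream caller could unpack one yielded item; the helper name unpack_reconstruction is hypothetical, it assumes add_target=True (so that five items are yielded), and it only mirrors the isinstance-based handling added in reconstruct_and_evaluate and evaluate.

def unpack_reconstruction(output):
    # Hypothetical helper; assumes reconstruct_volumes was called with add_target=True,
    # so each yielded item is (volume, target, mask, loss_dict, filename).
    volume, target, mask, loss_dict, filename = output
    if isinstance(volume, tuple):  # a registration model was configured
        volume, registration_volume = volume
    else:
        registration_volume = None
    if isinstance(target, tuple):
        target, registration_target = target
    else:
        registration_target = None
    return volume, registration_volume, target, registration_target, mask, loss_dict, filename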