Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace scipy.interp2d by scipy.RegularGridInterpolator #312

Merged
merged 3 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions simpa/utils/deformation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import matplotlib.pyplot as plt
from simpa.utils import Tags
from scipy.interpolate import interp2d
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import gaussian_filter
import numpy as np

Expand All @@ -24,25 +24,19 @@ def create_deformation_settings(bounds_mm, maximum_z_elevation_mm=1, filter_sigm
x_positions_vector = np.linspace(bounds_mm[0][0], bounds_mm[0][1], number_of_boundary_points[0])
y_positions_vector = np.linspace(bounds_mm[1][0], bounds_mm[1][1], number_of_boundary_points[1])

xx, yy = np.meshgrid(x_positions_vector, y_positions_vector, indexing='ij')

# Add random permutations to the y-axis of the division knots
for x_idx, x_position in enumerate(x_positions_vector):
for y_idx, y_position in enumerate(y_positions_vector):
scaling_value = (np.cos(x_position / (bounds_mm[0][1] * (cosine_scaling_factor / np.pi)) -
np.pi/(cosine_scaling_factor * 2)) ** 2 *
np.cos(y_position / (bounds_mm[1][1] * (cosine_scaling_factor / np.pi)) -
np.pi/(cosine_scaling_factor * 2)) ** 2)

surface_elevations[x_idx, y_idx] = scaling_value * surface_elevations[x_idx, y_idx]
all_scaling_value = np.multiply.outer(
np.cos(x_positions_vector / (bounds_mm[0][1] * (cosine_scaling_factor / np.pi)) - np.pi / (cosine_scaling_factor * 2)) ** 2,
np.cos(y_positions_vector / (bounds_mm[1][1] * (cosine_scaling_factor / np.pi)) - np.pi / (cosine_scaling_factor * 2)) ** 2)
surface_elevations *= all_scaling_value

# This rescales and sets the maximum to 0.
surface_elevations = surface_elevations * maximum_z_elevation_mm
de_facto_max_elevation = np.max(surface_elevations)
surface_elevations = surface_elevations - de_facto_max_elevation

deformation_settings[Tags.DEFORMATION_X_COORDINATES_MM] = xx
deformation_settings[Tags.DEFORMATION_Y_COORDINATES_MM] = yy
deformation_settings[Tags.DEFORMATION_X_COORDINATES_MM] = x_positions_vector
deformation_settings[Tags.DEFORMATION_Y_COORDINATES_MM] = y_positions_vector
deformation_settings[Tags.DEFORMATION_Z_ELEVATIONS_MM] = surface_elevations
deformation_settings[Tags.MAX_DEFORMATION_MM] = de_facto_max_elevation

Expand All @@ -66,7 +60,7 @@ def get_functional_from_deformation_settings(deformation_settings: dict):
z_elevations_mm = deformation_settings[Tags.DEFORMATION_Z_ELEVATIONS_MM]
order = "cubic"

functional_mm = interp2d(x_coordinates_mm, y_coordinates_mm, z_elevations_mm, kind=order)
functional_mm = RegularGridInterpolator(points=[x_coordinates_mm, y_coordinates_mm], values=z_elevations_mm, method=order)
return functional_mm


Expand All @@ -81,12 +75,12 @@ def get_functional_from_deformation_settings(deformation_settings: dict):
x_pos_vector = np.linspace(x_bounds[0], x_bounds[1], 100)
y_pos_vector = np.linspace(y_bounds[0], y_bounds[1], 100)

_xx, _yy = np.meshgrid(x_pos_vector, y_pos_vector, indexing='ij')
eval_points = tuple(np.meshgrid(x_pos_vector, y_pos_vector, indexing='ij'))

values = functional(x_pos_vector, y_pos_vector)
values = functional(eval_points)
max_elevation = -np.min(values)

plt3d = plt.figure().gca(projection='3d')
plt3d.plot_surface(_xx, _yy, values, cmap="viridis")
plt3d.set_zlim(-max_elevation, 0)
ax = plt.figure().add_subplot(projection='3d')
ax.plot_surface(eval_points[0], eval_points[1], values, cmap="viridis")
ax.set_zlim(-max_elevation, 0)
plt.show()
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ def get_enclosed_indices(self):

if self.do_deformation:
# the deformation functional needs mm as inputs and returns the result in reverse indexing order...
deformation_values_mm = self.deformation_functional_mm(torch.arange(self.volume_dimensions_voxels[0]) *
self.voxel_spacing,
torch.arange(self.volume_dimensions_voxels[1]) *
self.voxel_spacing).T
eval_points = torch.meshgrid(torch.arange(self.volume_dimensions_voxels[0]) * self.voxel_spacing,
torch.arange(self.volume_dimensions_voxels[1]) * self.voxel_spacing, indexing='ij')
deformation_values_mm = self.deformation_functional_mm(eval_points)
deformation_values_mm = deformation_values_mm.reshape(self.volume_dimensions_voxels[0],
self.volume_dimensions_voxels[1], 1, 1)
deformation_values_mm = torch.tile(torch.as_tensor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ def get_enclosed_indices(self):
target_vector_voxels = target_vector_voxels[:, :, :, 2]
if self.do_deformation:
# the deformation functional needs mm as inputs and returns the result in reverse indexing order...
deformation_values_mm = self.deformation_functional_mm(torch.arange(self.volume_dimensions_voxels[0], dtype=torch.float) *
self.voxel_spacing,
torch.arange(self.volume_dimensions_voxels[1], dtype=torch.float) *
self.voxel_spacing).T
eval_points = torch.meshgrid(torch.arange(self.volume_dimensions_voxels[0], dtype=torch.float) * self.voxel_spacing,
torch.arange(self.volume_dimensions_voxels[1], dtype=torch.float) * self.voxel_spacing, indexing='ij')
deformation_values_mm = self.deformation_functional_mm(eval_points)
target_vector_voxels = (target_vector_voxels + torch.from_numpy(deformation_values_mm.reshape(
self.volume_dimensions_voxels[0],
self.volume_dimensions_voxels[1], 1)).to(self.torch_device) / self.voxel_spacing).float()
Expand Down
4 changes: 2 additions & 2 deletions simpa/utils/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ class Tags:

DEFORMATION_X_COORDINATES_MM = "deformation_x_coordinates"
"""
Mesh that defines the x coordinates of the deformation.\n
Array that defines the x coordinates of the deformation.\n
Usage: adapter versatile_volume_creation, naming convention
"""

DEFORMATION_Y_COORDINATES_MM = "deformation_y_coordinates"
"""
Mesh that defines the y coordinates of the deformation.\n
Array that defines the y coordinates of the deformation.\n
Usage: adapter versatile_volume_creation, naming convention
"""

Expand Down
Loading