Skip to content

Commit

Permalink
tests: add tests for image and shift
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Apr 5, 2024
1 parent 181dbbb commit a0b6e06
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/erlab/analysis/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def gradient_magnitude(
mode: str = "nearest",
cval: float = 0.0,
) -> npt.NDArray[np.float64]:
"""Calculate the gradient magnitude of an input array.
r"""Calculate the gradient magnitude of an input array.
The gradient magnitude is calculated as defined in Ref. :cite:p:`He2017`, using
given :math:`\Delta x` and :math:`\Delta y` values.
Expand Down
32 changes: 18 additions & 14 deletions src/erlab/analysis/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def shift(
darr
The array to shift.
shift
The amount of shift to be applied along the specified dimension. If `shift` is a
DataArray, different shifts can be applied to different coordinates. The
dimensions of `shift` must be a subset of the dimensions of `darr`. For more
information, see the note below. If `shift` is a `float`, the same shift is
applied to all values along dimension `along`. This is equivalent to providing a
0-dimensional DataArray.
The amount of shift to be applied along the specified dimension. If
:code:`shift` is a DataArray, different shifts can be applied to different
coordinates. The dimensions of :code:`shift` must be a subset of the dimensions
of `darr`. For more information, see the note below. If :code:`shift` is a
`float`, the same shift is applied to all values along dimension `along`. This
is equivalent to providing a 0-dimensional DataArray.
along
Name of the dimension along which the shift is applied.
shift_coords
Expand Down Expand Up @@ -149,28 +149,32 @@ def correct_with_edge(
**shift_kwargs,
):
"""
Corrects the given data array `darr` using the edge correction method.
Corrects the given data array `darr` with the given values or fit result.
Parameters
----------
darr
The input data array to be corrected.
modelresult
The model result that contains the fermi edge information. It can be an instance
of `lmfit.model.ModelResult`, a numpy array, or a callable function that takes
an array of angles and returns the corresponding energy value.
of `lmfit.model.ModelResult`, a numpy array containing the edge position at each
angle, or a callable function that takes an array of angles and returns the
corresponding energy value.
shift_coords
Whether to shift the coordinates of the data array. Defaults to True.
If `True`, the coordinates of the output data will be changed so that the output
contains all the values of the original data. If `False`, the coordinates and
shape of the original data will be retained, and only the data will be shifted.
Defaults to `False`.
plot
Whether to plot the original and corrected data arrays. Defaults to False.
Whether to plot the original and corrected data arrays. Defaults to `False`.
plot_kw
Additional keyword arguments for the plot. Defaults to None.
Additional keyword arguments for the plot. Defaults to `None`.
**shift_kwargs
Additional keyword arguments to `shift`.
Additional keyword arguments to `erlab.analysis.utilities.shift`.
Returns
-------
xarray.DataArray
corrected : xarray.DataArray
The edge corrected data.
"""
if plot_kw is None:
Expand Down
176 changes: 176 additions & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import numpy as np
import xarray as xr
import erlab.analysis as era


def test_gaussian_filter():
# Create a test input DataArray
darr = xr.DataArray(np.arange(50, step=2).reshape((5, 5)), dims=["x", "y"])

# Define the expected output
expected_output = xr.DataArray(
np.array(
[
[3, 5, 7, 8, 10],
[10, 12, 14, 15, 17],
[20, 22, 24, 25, 27],
[29, 31, 33, 34, 36],
[36, 38, 40, 41, 43],
]
),
dims=["x", "y"],
)

# Apply the gaussian_filter function
result = era.image.gaussian_filter(darr, sigma={"x": 1.0, "y": 1.0})

# Check if the result matches the expected output
assert np.allclose(result, expected_output)


def test_gaussian_laplace():
# Create a test input DataArray
darr = xr.DataArray(np.arange(50, step=2).reshape((5, 5)), dims=["x", "y"])

# Define the expected output
expected_output = xr.DataArray(
np.array(
[
[4, 4, 4, 4, 4],
[2, 2, 2, 2, 2],
[0, 0, 0, 0, 0],
[-2, -2, -2, -2, -2],
[-4, -4, -4, -4, -4],
]
),
dims=["x", "y"],
)

# Apply the gaussian_laplace function
result = era.image.gaussian_laplace(darr, sigma={"x": 1.0, "y": 1.0})

# Check if the result matches the expected output
assert np.allclose(result, expected_output)

# Additional test case
# Define the expected output
expected_output2 = xr.DataArray(
np.array(
[
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 0, 0, 0, 0],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1],
]
),
dims=["x", "y"],
)

# Apply the gaussian_laplace function
result2 = era.image.gaussian_laplace(darr, sigma=2.0)

# Check if the result matches the expected output
assert np.allclose(result2, expected_output2)


def test_laplace():
# Create a test input DataArray
darr = xr.DataArray(np.arange(50, step=2).reshape((5, 5)), dims=["x", "y"])

# Define the expected output
expected_output = xr.DataArray(
np.array(
[
[12, 10, 10, 10, 8],
[2, 0, 0, 0, -2],
[2, 0, 0, 0, -2],
[2, 0, 0, 0, -2],
[-8, -10, -10, -10, -12],
]
),
dims=["x", "y"],
)

# Apply the laplace function
result = era.image.laplace(darr)

# Check if the result matches the expected output
assert np.allclose(result, expected_output)


def test_minimum_gradient():
# Create a test input DataArray
darr = xr.DataArray(np.arange(50, step=2).reshape((5, 5)), dims=["x", "y"])

# Define the expected output
expected_output = xr.DataArray(
np.array(
[
[0.0, 0.13608276, 0.27216553, 0.40824829, 0.58345997],
[0.49507377, 0.58834841, 0.68640647, 0.78446454, 0.89113279],
[0.99014754, 1.07863874, 1.17669681, 1.27475488, 1.38620656],
[1.48522131, 1.56892908, 1.66698715, 1.76504522, 1.88128033],
[2.91729983, 2.85773803, 2.9938208, 3.12990356, 3.17887766],
]
),
dims=["x", "y"],
)

# Apply the minimum_gradient function
result = era.image.minimum_gradient(darr).astype(np.float32)

# Check if the result matches the expected output
assert np.allclose(result, expected_output)


def test_scaled_laplace():
# Create a test input DataArray
darr = xr.DataArray(
np.arange(50, step=2).reshape((5, 5)).astype(float), dims=["x", "y"]
)

# Define the expected output
expected_output = xr.DataArray(
np.array(
[
[12.0, 10.0, 10.0, 10.0, 8.0],
[2.0, 0.0, 0.0, 0.0, -2.0],
[2.0, 0.0, 0.0, 0.0, -2.0],
[2.0, 0.0, 0.0, 0.0, -2.0],
[-8.0, -10.0, -10.0, -10.0, -12.0],
]
),
dims=["x", "y"],
)

# Apply the scaled_laplace function
result = era.image.scaled_laplace(darr)

# Check if the result matches the expected output
assert np.allclose(result, expected_output)


def test_curvature():
# Create a test input DataArray
darr = xr.DataArray(np.arange(50, step=2).reshape((5, 5)), dims=["x", "y"]) ** 2

# Define the expected output
expected_output = xr.DataArray(
np.array(
[
[0.11852942, 0.11855069, 0.11778077, 0.11184719, 0.10558571],
[0.16448492, 0.16107288, 0.15683689, 0.14772279, 0.1385244],
[0.17403091, 0.16649966, 0.15876051, 0.14719689, 0.1361662],
[0.09486956, 0.09038468, 0.08598027, 0.0783563, 0.07130584],
[0.05139264, 0.04942051, 0.04746361, 0.04241876, 0.03781145],
]
),
dims=["x", "y"],
)

# Apply the curvature function
result = era.image.curvature(darr).astype(np.float32)

# Check if the result matches the expected output
assert np.allclose(result, expected_output)
25 changes: 25 additions & 0 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
import xarray as xr
from erlab.analysis.utilities import shift


def test_shift():
# Create a test input DataArray
darr = xr.DataArray(
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(float), dims=["x", "y"]
)

# Create a test shift DataArray
shift_arr = xr.DataArray([1, 0, 2], dims=["x"])

# Perform the shift operation
shifted = shift(darr, shift_arr, along="y")

# Define the expected result
expected = xr.DataArray(
np.array([[np.nan, 1.0, 2.0], [4.0, 5.0, 6.0], [np.nan, np.nan, 7.0]]),
dims=["x", "y"],
)

# Check if the shifted array matches the expected result
assert np.allclose(shifted, expected, equal_nan=True)

0 comments on commit a0b6e06

Please sign in to comment.