Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR: Implement support for callbacks in colour.continuous.AbstractContinuousFunction class. #1145

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions colour/colorimetry/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,19 @@ def __init__(
self._display_name: str = self.name
self.display_name = kwargs.get("display_name", self._display_name)

self._shape: SpectralShape | None = None

def _on_domain_changed(
self, name: str, value: NDArrayFloat
) -> NDArrayFloat:
"""Invalidate *self._shape* when *self._domain* is changed."""
if name == "_domain":
self._shape = None

return value

self.register_callback("on_domain_changed", _on_domain_changed)

@property
def display_name(self) -> str:
"""
Expand Down Expand Up @@ -836,17 +849,20 @@ def shape(self) -> SpectralShape:
SpectralShape(500.0, 600.0, 10.0)
"""

wavelengths = self.wavelengths
wavelengths_interval = interval(wavelengths)
if wavelengths_interval.size != 1:
runtime_warning(
f'"{self.name}" spectral distribution is not uniform, using '
f"minimum interval!"
if self._shape is None:
wavelengths = self.wavelengths
wavelengths_interval = interval(wavelengths)
if wavelengths_interval.size != 1:
runtime_warning(
f'"{self.name}" spectral distribution is not uniform, '
"using minimum interval!"
)

self._shape = SpectralShape(
wavelengths[0], wavelengths[-1], min(wavelengths_interval)
)

return SpectralShape(
wavelengths[0], wavelengths[-1], min(wavelengths_interval)
)
return self._shape

def interpolate(
self,
Expand Down
90 changes: 63 additions & 27 deletions colour/colorimetry/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,30 +1515,35 @@ def test_interpolate(self):
SpectralDistribution.interpolate` method.
"""

shape = SpectralShape(self._sd.shape.start, self._sd.shape.end, 1)
sd = reshape_sd(self._sd, shape, "Interpolate")
np.testing.assert_array_almost_equal(
reshape_sd(
self._sd,
SpectralShape(self._sd.shape.start, self._sd.shape.end, 1),
"Interpolate",
).values,
sd.values,
DATA_SAMPLE_INTERPOLATED,
decimal=7,
)
self.assertEqual(sd.shape, shape)

shape = SpectralShape(
self._non_uniform_sd.shape.start,
self._non_uniform_sd.shape.end,
1,
)
sd = reshape_sd(self._non_uniform_sd, shape, "Interpolate")
np.testing.assert_allclose(
reshape_sd(
self._non_uniform_sd,
SpectralShape(
self._non_uniform_sd.shape.start,
self._non_uniform_sd.shape.end,
1,
),
"Interpolate",
).values,
sd.values,
DATA_SAMPLE_INTERPOLATED_NON_UNIFORM,
rtol=0.0000001,
atol=0.0000001,
)
self.assertEqual(
sd.shape,
SpectralShape(
np.ceil(self._non_uniform_sd.shape.start),
np.floor(self._non_uniform_sd.shape.end),
1,
),
)

def test_extrapolate(self):
"""
Expand All @@ -1556,8 +1561,9 @@ def test_extrapolate(self):
sd = SpectralDistribution(
np.linspace(0, 1, 10), np.linspace(25, 35, 10)
)
shape = SpectralShape(10, 50, 10)
sd.extrapolate(
SpectralShape(10, 50, 10),
shape,
extrapolator_kwargs={
"method": "Linear",
"left": None,
Expand Down Expand Up @@ -1602,6 +1608,17 @@ def test_normalise(self):
self._sd.copy().normalise(100).values, DATA_SAMPLE_NORMALISED
)

def test_callback_on_domain_changed(self):
"""
Test :class:`colour.colorimetry.spectrum.\
SpectralDistribution` *on_domain_changed* callback.
"""

sd = self._sd.copy()
self.assertEqual(sd.shape, SpectralShape(340, 820, 20))
sd[840] = 0
self.assertEqual(sd.shape, SpectralShape(340, 840, 20))


class TestMultiSpectralDistributions(unittest.TestCase):
"""
Expand Down Expand Up @@ -1760,26 +1777,25 @@ def test_interpolate(self):
"""

# pylint: disable=E1102
msds = reshape_msds(
self._sample_msds,
SpectralShape(
self._sample_msds.shape.start, self._sample_msds.shape.end, 1
),
"Interpolate",
shape = SpectralShape(
self._sample_msds.shape.start, self._sample_msds.shape.end, 1
)
msds = reshape_msds(self._sample_msds, shape, "Interpolate")
for signal in msds.signals.values():
np.testing.assert_array_almost_equal(
signal.values, DATA_SAMPLE_INTERPOLATED, decimal=7
)
self.assertEqual(msds.shape, shape)

# pylint: disable=E1102
shape = SpectralShape(
self._non_uniform_sample_msds.shape.start,
self._non_uniform_sample_msds.shape.end,
1,
)
msds = reshape_msds(
self._non_uniform_sample_msds,
SpectralShape(
self._non_uniform_sample_msds.shape.start,
self._non_uniform_sample_msds.shape.end,
1,
),
shape,
"Interpolate",
)
for signal in msds.signals.values():
Expand All @@ -1789,6 +1805,14 @@ def test_interpolate(self):
rtol=0.0000001,
atol=0.0000001,
)
self.assertEqual(
msds.shape,
SpectralShape(
np.ceil(self._non_uniform_sample_msds.shape.start),
np.floor(self._non_uniform_sample_msds.shape.end),
1,
),
)

def test_extrapolate(self):
"""
Expand Down Expand Up @@ -1853,7 +1877,7 @@ def test_trim(self):

def test_normalise(self):
"""
Test :func:`colour.colorimetry.spectrum.
Test :func:`colour.colorimetry.spectrum.
MultiSpectralDistributions.normalise` method.
"""

Expand All @@ -1875,6 +1899,18 @@ def test_to_sds(self):
self.assertEqual(sd.name, self._labels[i])
self.assertEqual(sd.display_name, self._display_labels[i])

def test_callback_on_domain_changed(self):
"""
Test underlying :class:`colour.colorimetry.spectrum.\
SpectralDistribution` *on_domain_changed* callback when used with
:class:`colour.colorimetry.spectrum.MultiSpectralDistributions` class.
"""

msds = self._msds.copy()
self.assertEqual(msds.shape, SpectralShape(380, 780, 5))
msds[785] = 0
self.assertEqual(msds.shape, SpectralShape(380, 785, 5))


class TestReshapeSd(unittest.TestCase):
"""
Expand Down
3 changes: 2 additions & 1 deletion colour/continuous/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Type,
)
from colour.utilities import (
MixinCallback,
as_float,
attest,
closest,
Expand All @@ -49,7 +50,7 @@
]


class AbstractContinuousFunction(ABC):
class AbstractContinuousFunction(ABC, MixinCallback):
"""
Define the base class for abstract continuous function.

Expand Down
6 changes: 2 additions & 4 deletions colour/continuous/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,7 @@ def _fill_domain_nan(
variable.
"""

self._domain = fill_nan(self._domain, method, default)
self._function = None # Invalidate the underlying continuous function.
self.domain = fill_nan(self.domain, method, default)

def _fill_range_nan(
self,
Expand All @@ -974,8 +973,7 @@ def _fill_range_nan(
variable.
"""

self._range = fill_nan(self._range, method, default)
self._function = None # Invalidate the underlying continuous function.
self.range = fill_nan(self.range, method, default)

def arithmetical_operation(
self,
Expand Down
8 changes: 8 additions & 0 deletions colour/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
LazyCanonicalMapping,
Node,
)
from .callback import (
Callback,
MixinCallback,
)
from .common import (
CacheRegistry,
CACHE_REGISTRY,
Expand Down Expand Up @@ -124,6 +128,10 @@
"LazyCanonicalMapping",
"Node",
]
__all__ += [
"Callback",
"MixinCallback",
]
__all__ += [
"CacheRegistry",
"CACHE_REGISTRY",
Expand Down
Loading