From 09e79b4f5d0e0b3535eb0b29be6444ad39feeea3 Mon Sep 17 00:00:00 2001 From: Thomas Mansencal Date: Sun, 30 Apr 2023 12:05:08 +1200 Subject: [PATCH] Implement support for callbacks in `colour.continuous.AbstractContinuousFunction` class. --- colour/colorimetry/spectrum.py | 34 +++-- colour/colorimetry/tests/test_spectrum.py | 90 ++++++++---- colour/continuous/abstract.py | 3 +- colour/continuous/signal.py | 6 +- colour/utilities/__init__.py | 8 + colour/utilities/callback.py | 169 ++++++++++++++++++++++ colour/utilities/tests/test_callback.py | 109 ++++++++++++++ docs/colour.utilities.rst | 17 +++ 8 files changed, 395 insertions(+), 41 deletions(-) create mode 100644 colour/utilities/callback.py create mode 100644 colour/utilities/tests/test_callback.py diff --git a/colour/colorimetry/spectrum.py b/colour/colorimetry/spectrum.py index 3f6997f662..b471aae4dd 100644 --- a/colour/colorimetry/spectrum.py +++ b/colour/colorimetry/spectrum.py @@ -716,6 +716,19 @@ def __init__( self._display_name: str = self.name self.display_name = kwargs.get("display_name", self._display_name) + self._shape: SpectralShape | None = None + + def _on_domain_changed( + self, name: str, value: NDArrayFloat + ) -> NDArrayFloat: + """Invalidate *self._shape* when *self._domain* is changed.""" + if name == "_domain": + self._shape = None + + return value + + self.register_callback("on_domain_changed", _on_domain_changed) + @property def display_name(self) -> str: """ @@ -836,17 +849,20 @@ def shape(self) -> SpectralShape: SpectralShape(500.0, 600.0, 10.0) """ - wavelengths = self.wavelengths - wavelengths_interval = interval(wavelengths) - if wavelengths_interval.size != 1: - runtime_warning( - f'"{self.name}" spectral distribution is not uniform, using ' - f"minimum interval!" + if self._shape is None: + wavelengths = self.wavelengths + wavelengths_interval = interval(wavelengths) + if wavelengths_interval.size != 1: + runtime_warning( + f'"{self.name}" spectral distribution is not uniform, ' + "using minimum interval!" + ) + + self._shape = SpectralShape( + wavelengths[0], wavelengths[-1], min(wavelengths_interval) ) - return SpectralShape( - wavelengths[0], wavelengths[-1], min(wavelengths_interval) - ) + return self._shape def interpolate( self, diff --git a/colour/colorimetry/tests/test_spectrum.py b/colour/colorimetry/tests/test_spectrum.py index 413a49e3b8..83d76e10cd 100644 --- a/colour/colorimetry/tests/test_spectrum.py +++ b/colour/colorimetry/tests/test_spectrum.py @@ -1515,30 +1515,35 @@ def test_interpolate(self): SpectralDistribution.interpolate` method. """ + shape = SpectralShape(self._sd.shape.start, self._sd.shape.end, 1) + sd = reshape_sd(self._sd, shape, "Interpolate") np.testing.assert_array_almost_equal( - reshape_sd( - self._sd, - SpectralShape(self._sd.shape.start, self._sd.shape.end, 1), - "Interpolate", - ).values, + sd.values, DATA_SAMPLE_INTERPOLATED, decimal=7, ) + self.assertEqual(sd.shape, shape) + shape = SpectralShape( + self._non_uniform_sd.shape.start, + self._non_uniform_sd.shape.end, + 1, + ) + sd = reshape_sd(self._non_uniform_sd, shape, "Interpolate") np.testing.assert_allclose( - reshape_sd( - self._non_uniform_sd, - SpectralShape( - self._non_uniform_sd.shape.start, - self._non_uniform_sd.shape.end, - 1, - ), - "Interpolate", - ).values, + sd.values, DATA_SAMPLE_INTERPOLATED_NON_UNIFORM, rtol=0.0000001, atol=0.0000001, ) + self.assertEqual( + sd.shape, + SpectralShape( + np.ceil(self._non_uniform_sd.shape.start), + np.floor(self._non_uniform_sd.shape.end), + 1, + ), + ) def test_extrapolate(self): """ @@ -1556,8 +1561,9 @@ def test_extrapolate(self): sd = SpectralDistribution( np.linspace(0, 1, 10), np.linspace(25, 35, 10) ) + shape = SpectralShape(10, 50, 10) sd.extrapolate( - SpectralShape(10, 50, 10), + shape, extrapolator_kwargs={ "method": "Linear", "left": None, @@ -1602,6 +1608,17 @@ def test_normalise(self): self._sd.copy().normalise(100).values, DATA_SAMPLE_NORMALISED ) + def test_callback_on_domain_changed(self): + """ + Test :class:`colour.colorimetry.spectrum.\ +SpectralDistribution` *on_domain_changed* callback. + """ + + sd = self._sd.copy() + self.assertEqual(sd.shape, SpectralShape(340, 820, 20)) + sd[840] = 0 + self.assertEqual(sd.shape, SpectralShape(340, 840, 20)) + class TestMultiSpectralDistributions(unittest.TestCase): """ @@ -1760,26 +1777,25 @@ def test_interpolate(self): """ # pylint: disable=E1102 - msds = reshape_msds( - self._sample_msds, - SpectralShape( - self._sample_msds.shape.start, self._sample_msds.shape.end, 1 - ), - "Interpolate", + shape = SpectralShape( + self._sample_msds.shape.start, self._sample_msds.shape.end, 1 ) + msds = reshape_msds(self._sample_msds, shape, "Interpolate") for signal in msds.signals.values(): np.testing.assert_array_almost_equal( signal.values, DATA_SAMPLE_INTERPOLATED, decimal=7 ) + self.assertEqual(msds.shape, shape) # pylint: disable=E1102 + shape = SpectralShape( + self._non_uniform_sample_msds.shape.start, + self._non_uniform_sample_msds.shape.end, + 1, + ) msds = reshape_msds( self._non_uniform_sample_msds, - SpectralShape( - self._non_uniform_sample_msds.shape.start, - self._non_uniform_sample_msds.shape.end, - 1, - ), + shape, "Interpolate", ) for signal in msds.signals.values(): @@ -1789,6 +1805,14 @@ def test_interpolate(self): rtol=0.0000001, atol=0.0000001, ) + self.assertEqual( + msds.shape, + SpectralShape( + np.ceil(self._non_uniform_sample_msds.shape.start), + np.floor(self._non_uniform_sample_msds.shape.end), + 1, + ), + ) def test_extrapolate(self): """ @@ -1853,7 +1877,7 @@ def test_trim(self): def test_normalise(self): """ - Test :func:`colour.colorimetry.spectrum. + Test :func:`colour.colorimetry.spectrum. MultiSpectralDistributions.normalise` method. """ @@ -1875,6 +1899,18 @@ def test_to_sds(self): self.assertEqual(sd.name, self._labels[i]) self.assertEqual(sd.display_name, self._display_labels[i]) + def test_callback_on_domain_changed(self): + """ + Test underlying :class:`colour.colorimetry.spectrum.\ +SpectralDistribution` *on_domain_changed* callback when used with + :class:`colour.colorimetry.spectrum.MultiSpectralDistributions` class. + """ + + msds = self._msds.copy() + self.assertEqual(msds.shape, SpectralShape(380, 780, 5)) + msds[785] = 0 + self.assertEqual(msds.shape, SpectralShape(380, 785, 5)) + class TestReshapeSd(unittest.TestCase): """ diff --git a/colour/continuous/abstract.py b/colour/continuous/abstract.py index f62bc8b786..33f54a42cb 100644 --- a/colour/continuous/abstract.py +++ b/colour/continuous/abstract.py @@ -29,6 +29,7 @@ Type, ) from colour.utilities import ( + MixinCallback, as_float, attest, closest, @@ -49,7 +50,7 @@ ] -class AbstractContinuousFunction(ABC): +class AbstractContinuousFunction(ABC, MixinCallback): """ Define the base class for abstract continuous function. diff --git a/colour/continuous/signal.py b/colour/continuous/signal.py index 9142e23fca..eb8ea7f8a4 100644 --- a/colour/continuous/signal.py +++ b/colour/continuous/signal.py @@ -948,8 +948,7 @@ def _fill_domain_nan( variable. """ - self._domain = fill_nan(self._domain, method, default) - self._function = None # Invalidate the underlying continuous function. + self.domain = fill_nan(self.domain, method, default) def _fill_range_nan( self, @@ -974,8 +973,7 @@ def _fill_range_nan( variable. """ - self._range = fill_nan(self._range, method, default) - self._function = None # Invalidate the underlying continuous function. + self.range = fill_nan(self.range, method, default) def arithmetical_operation( self, diff --git a/colour/utilities/__init__.py b/colour/utilities/__init__.py index 820f14477c..a516926d33 100644 --- a/colour/utilities/__init__.py +++ b/colour/utilities/__init__.py @@ -11,6 +11,10 @@ LazyCanonicalMapping, Node, ) +from .callback import ( + Callback, + MixinCallback, +) from .common import ( CacheRegistry, CACHE_REGISTRY, @@ -124,6 +128,10 @@ "LazyCanonicalMapping", "Node", ] +__all__ += [ + "Callback", + "MixinCallback", +] __all__ += [ "CacheRegistry", "CACHE_REGISTRY", diff --git a/colour/utilities/callback.py b/colour/utilities/callback.py new file mode 100644 index 0000000000..999c09cf3c --- /dev/null +++ b/colour/utilities/callback.py @@ -0,0 +1,169 @@ +""" +Callback Management +=================== + +Defines the callback management objects. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from colour.hints import ( + Any, + Callable, + List, +) + +__author__ = "Colour Developers" +__copyright__ = "Copyright 2013 Colour Developers" +__license__ = "New BSD License - https://opensource.org/licenses/BSD-3-Clause" +__maintainer__ = "Colour Developers" +__email__ = "colour-developers@colour-science.org" +__status__ = "Production" + +__all__ = [ + "Callback", + "MixinCallback", +] + + +@dataclass +class Callback: + """ + Define a callback. + + Parameters + ---------- + name + Callback name. + function + Callback callable. + """ + + name: str + function: Callable + + +class MixinCallback: + """ + A mixin providing support for callbacks. + + Attributes + ---------- + - :attr:`~colour.utilities.MixinCallback.callbacks` + - :attr:`~colour.utilities.MixinCallback.__setattr__` + + Methods + ------- + - :meth:`~colour.utilities.MixinCallback.register_callback` + - :meth:`~colour.utilities.MixinCallback.unregister_callback` + + Examples + -------- + >>> class WithCallback(MixinCallback): + ... def __init__(self): + ... super().__init__() + ... self.attribute_a = "a" + ... + >>> with_callback = WithCallback() + >>> def _on_attribute_a_changed(self, name: str, value: str) -> str: + ... if name == "attribute_a": + ... value = value.upper() + ... return value + >>> with_callback.register_callback( + ... "on_attribute_a_changed", _on_attribute_a_changed + ... ) + >>> with_callback.attribute_a = "a" + >>> with_callback.attribute_a + 'A' + """ + + def __init__(self) -> None: + super().__init__() + + self._callbacks: List = [] + + @property + def callbacks(self) -> List: + """ + Getter property for the callbacks. + + Returns + ------- + :class:`list` + Callbacks. + """ + + return self._callbacks + + def __setattr__(self, name: str, value: Any) -> None: + """ + Set given value to the attribute with given name. + + Parameters + ---------- + attribute + Attribute to set the value of. + value + Value to set the attribute with. + """ + + if hasattr(self, "_callbacks"): + for callback in self._callbacks: + value = callback.function(self, name, value) + + super().__setattr__(name, value) + + def register_callback(self, name: str, function: Callable) -> None: + """ + Register the callback with given name. + + Parameters + ---------- + name + Callback name. + function + Callback callable. + + Examples + -------- + >>> class WithCallback(MixinCallback): + ... def __init__(self): + ... super().__init__() + ... + >>> with_callback = WithCallback() + >>> with_callback.register_callback("callback", lambda *args: None) + >>> with_callback.callbacks # doctest: +SKIP + [Callback(name='callback', function= at 0x10fcf3420>)] + """ + + self._callbacks.append(Callback(name, function)) + + def unregister_callback(self, name: str) -> None: + """ + Unregister the callback with given name. + + Parameters + ---------- + name + Callback name. + + Examples + -------- + >>> class WithCallback(MixinCallback): + ... def __init__(self): + ... super().__init__() + ... + >>> with_callback = WithCallback() + >>> with_callback.register_callback("callback", lambda s, n, v: v) + >>> with_callback.callbacks # doctest: +SKIP + [Callback(name='callback', function= at 0x10fcf3420>)] + >>> with_callback.unregister_callback("callback") + >>> with_callback.callbacks + [] + """ + + self._callbacks = [ + callback for callback in self._callbacks if callback.name != name + ] diff --git a/colour/utilities/tests/test_callback.py b/colour/utilities/tests/test_callback.py new file mode 100644 index 0000000000..fb48c8357d --- /dev/null +++ b/colour/utilities/tests/test_callback.py @@ -0,0 +1,109 @@ +# !/usr/bin/env python +"""Define the unit tests for the :mod:`colour.utilities.callback` module.""" + +from __future__ import annotations + +import unittest + +from colour.utilities import MixinCallback + +__author__ = "Colour Developers" +__copyright__ = "Copyright 2013 Colour Developers" +__license__ = "New BSD License - https://opensource.org/licenses/BSD-3-Clause" +__maintainer__ = "Colour Developers" +__email__ = "colour-developers@colour-science.org" +__status__ = "Production" + +__all__ = [ + "TestMixinCallback", +] + + +class TestMixinCallback(unittest.TestCase): + """ + Define :class:`colour.utilities.callback.MixinCallback` class unit + tests methods. + """ + + def setUp(self): + """Initialise the common tests attributes.""" + + class WithCallback(MixinCallback): + """Test :class:`MixinCallback` class.""" + + def __init__(self): + super().__init__() + + self.attribute_a = "a" + + self._with_callback = WithCallback() + + def _on_attribute_a_changed(self, name: str, value: str) -> str: + """Transform *self._attribute_a* to uppercase.""" + + if name == "attribute_a": + value = value.upper() + + if getattr(self, name) != "a": + raise RuntimeError( + '"self" was not able to retrieve class instance value!' + ) + + return value + + self._on_attribute_a_changed = _on_attribute_a_changed + + def test_required_attributes(self): + """Test the presence of required attributes.""" + + required_attributes = ("callbacks",) + + for attribute in required_attributes: + self.assertIn(attribute, dir(MixinCallback)) + + def test_required_methods(self): + """Test the presence of required methods.""" + + required_methods = ( + "__init__", + "register_callback", + "unregister_callback", + ) + + for method in required_methods: + self.assertIn(method, dir(MixinCallback)) + + def test_register_callback(self): + """ + Test :class:`colour.utilities.callback.MixinCallback.register_callback` + method. + """ + + self._with_callback.register_callback( + "on_attribute_a_changed", self._on_attribute_a_changed + ) + + self._with_callback.attribute_a = "a" + self.assertEqual(self._with_callback.attribute_a, "A") + self.assertEqual(len(self._with_callback.callbacks), 1) + + def test_unregister_callback(self): + """ + Test :class:`colour.utilities.callback.MixinCallback.unregister_callback` + method. + """ + + if len(self._with_callback.callbacks) == 0: + self._with_callback.register_callback( + "on_attribute_a_changed", self._on_attribute_a_changed + ) + + self.assertEqual(len(self._with_callback.callbacks), 1) + self._with_callback.unregister_callback("on_attribute_a_changed") + self.assertEqual(len(self._with_callback.callbacks), 0) + self._with_callback.attribute_a = "a" + self.assertEqual(self._with_callback.attribute_a, "a") + + +if __name__ == "__main__": + unittest.main() diff --git a/docs/colour.utilities.rst b/docs/colour.utilities.rst index 99a255cd02..36821b3eb4 100644 --- a/docs/colour.utilities.rst +++ b/docs/colour.utilities.rst @@ -1,6 +1,23 @@ Utilities ========= +Callback Management +------------------- + +``colour`` + + +``colour.utilities`` + +.. currentmodule:: colour.utilities + +.. autosummary:: + :toctree: generated/ + :template: class.rst + + Callback + MixinCallback + Common ------