diff --git a/docs/api/_pyswarms.utils.rst b/docs/api/_pyswarms.utils.rst index 131289e2..70bf7da8 100644 --- a/docs/api/_pyswarms.utils.rst +++ b/docs/api/_pyswarms.utils.rst @@ -7,7 +7,9 @@ functionalities. .. toctree:: + pyswarms.utils.decorators pyswarms.utils.functions - pyswarms.utils.search pyswarms.utils.plotters pyswarms.utils.reporter + pyswarms.utils.search + \ No newline at end of file diff --git a/docs/api/pyswarms.utils.decorators.rst b/docs/api/pyswarms.utils.decorators.rst new file mode 100644 index 00000000..01a5a295 --- /dev/null +++ b/docs/api/pyswarms.utils.decorators.rst @@ -0,0 +1,7 @@ +pyswarms.utils.decorators package +================================= + +.. automodule:: pyswarms.utils.decorators + :members: + :undoc-members: + :show-inheritance: diff --git a/pyswarms/__init__.py b/pyswarms/__init__.py index edba6707..5bf3f4b5 100644 --- a/pyswarms/__init__.py +++ b/pyswarms/__init__.py @@ -16,5 +16,6 @@ from .single import global_best, local_best, general_optimizer from .discrete import binary +from .utils.decorators import cost -__all__ = ["global_best", "local_best", "general_optimizer", "binary"] +__all__ = ["global_best", "local_best", "general_optimizer", "binary", "cost"] diff --git a/pyswarms/utils/decorators/__init__.py b/pyswarms/utils/decorators/__init__.py new file mode 100644 index 00000000..4ef5c8b3 --- /dev/null +++ b/pyswarms/utils/decorators/__init__.py @@ -0,0 +1,9 @@ +""" +The :mod:`pyswarms.decorators` module implements a decorator that +can be used to simplify the task of writing the cost function for +an optimization run. The decorator can be directly called by using +:code:`@pyswarms.cost`. +""" +from .decorators import cost + +__all__ = ["cost"] diff --git a/pyswarms/utils/decorators/decorators.py b/pyswarms/utils/decorators/decorators.py new file mode 100644 index 00000000..f723c9aa --- /dev/null +++ b/pyswarms/utils/decorators/decorators.py @@ -0,0 +1,47 @@ +import numpy as np + + +def cost(cost_func): + """A decorator for the cost function + + This decorator allows the creation of much simpler cost functions. Instead of + writing a cost function that returns a shape of :code:`(n_particles, 0)` it enables + the usage of shorter and simpler cost functions that directly return the cost. + A simple example might be: + + .. code-block:: python + import pyswarms + import numpy as np + + @pyswarms.cost + def cost_func(x): + cost = np.abs(np.sum(x)) + return cost + + The decorator expects your cost function to use a d-dimensional array (where + d is the number of dimensions for the optimization) as and argument. + + .. note:: + Some :code:`numpy` functions return a :code:`np.ndarray` with single values in it. + Be aware of the fact that without unpacking the value the optimizer will raise + an exception. + + Parameters + ---------- + + cost_func : callable + A callable object that can be used as cost function in the optimization + (must return a :code:`float` or an :code:`int`). + + Returns + ------- + + cost_dec : callable + The vectorized output for all particles as defined by :code:`cost_func` + """ + def cost_dec(particles, **kwargs): + n_particles = particles.shape[0] + vector = np.array([cost_func(particles[i], **kwargs) for i in range(n_particles)]) + assert vector.shape == (n_particles, ), "The cost function should return a single value." + return vector + return cost_dec diff --git a/tests/utils/decorators/__init__.py b/tests/utils/decorators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/decorators/conftest.py b/tests/utils/decorators/conftest.py new file mode 100644 index 00000000..7d8d3fa1 --- /dev/null +++ b/tests/utils/decorators/conftest.py @@ -0,0 +1,10 @@ +import pytest +import numpy as np + + +@pytest.fixture() +def particles(): + shape = (np.random.randint(10, 20), np.random.randint(2, 6)) + particles_ = np.random.uniform(0, 10, shape) + print(particles_) + return particles_ diff --git a/tests/utils/decorators/test_decorators.py b/tests/utils/decorators/test_decorators.py new file mode 100644 index 00000000..03efc361 --- /dev/null +++ b/tests/utils/decorators/test_decorators.py @@ -0,0 +1,30 @@ +# Import modules +import pytest +import numpy as np + +# Import from package +from pyswarms.utils.decorators import cost + + +@pytest.mark.parametrize( + "objective_func", + [np.sum, np.prod] +) +def test_cost_decorator(objective_func, particles): + n_particles = particles.shape[0] + + def cost_func_without_decorator(x): + n_particles_in_func = x.shape[0] + cost = np.array([objective_func(x[i]) for i in range(n_particles_in_func)]) + return cost + + @cost + def cost_func_with_decorator(x): + cost = objective_func(x) + return cost + + undecorated = cost_func_without_decorator(particles) + decorated = cost_func_with_decorator(particles) + + assert np.array_equal(decorated, undecorated) + assert decorated.shape == (n_particles, )