Skip to content

Commit

Permalink
Merge pull request #42 from tdmorello/feat/pytest-benchmark
Browse files Browse the repository at this point in the history
Add pytest-benchmark
  • Loading branch information
yfukai authored Mar 24, 2022
2 parents a7b7336 + 65956de commit 76ff9b3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def tests(session: Session) -> None:
"""Run the test suite."""
session.install(".")
session.install("pytest", "pytest-cov", "xdoctest")
session.install("pytest", "pytest-benchmark", "pytest-cov", "xdoctest")
session.install("opencv-python")
session.run("pytest")

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ dev =
opencv-python
pre-commit
pytest
pytest-benchmark
pytest-cov
scipy
xdoctest
21 changes: 17 additions & 4 deletions tests/test_tools_dct2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
scipy.fft.idct(arr_input.T, norm="ortho").T, norm="ortho"
)

backends = ["JAX", "OPENCV", "SCIPY"]
# backends = ["JAX", "OPENCV", "SCIPY"]
backends = ["OPENCV", "SCIPY"]


@pytest.mark.parametrize("backend", backends)
def test_dct_backends(backend):
if backend == "JAX":
return

dct2d = DCT_BACKENDS[backend].dct2d
idct2d = DCT_BACKENDS[backend].idct2d

Expand All @@ -43,6 +41,8 @@ def test_dct_backends(backend):
def test_dct_backend_import(monkeypatch, backend):
import basicpy.tools.dct2d_tools

idct2d = DCT_BACKENDS[backend].idct2d

monkeypatch.setenv("BASIC_DCT_BACKEND", backend)
importlib.reload(basicpy.tools.dct2d_tools)

Expand All @@ -69,3 +69,16 @@ def test_unrecognized_backend(monkeypatch):
def test_backend_not_installed(monkeypatch, backend):
# TODO mimic package not installed by removing from path?
...


### BENCHMARKING ###
@pytest.mark.parametrize("backend", backends)
def test_dct_backends_benchmark_dct2d(backend, benchmark):
dct2d = DCT_BACKENDS[backend].dct2d
benchmark(dct2d, arr_input)


@pytest.mark.parametrize("backend", backends)
def test_dct_backends_benchmark_idct2d(backend, benchmark):
idct2d = DCT_BACKENDS[backend].idct2d
benchmark(idct2d, arr_input)

0 comments on commit 76ff9b3

Please sign in to comment.