Skip to content

Commit

Permalink
Skip tests for optional/extra dependencies when not installed (#1113)
Browse files Browse the repository at this point in the history
* Skip tests for optional/extra dependencies when not installed

* Update changelog

* Fix pystan_version()

* Fix pylint issues

* Relocate numba specific tests in test_diagnostics (and skip if not installed)

* Relocate numba specific tests in test_utils (and skip if not installed)

* Relocate numba specific tests in test_stats (and skip if not installed)

* Add CI environment variable

* Add custom importorskip for CI

* Use internal `importorskip`

* Displayed skipped files in pytest output (due to importorskip)

* Move CI env variable detection into `importorskip`

* Test `importorskip` for local/ci machines with `monkeypatch`

* Ignore vscode config

* Properly test for local machine with monkeypatch deleting CI env

* Add back in some parts of `pytest.importorskip` to our `importorskip` to fix errors

* Clarify reason text when test skipped for lack of pystan/pystan3

* Refactor helper function to test running on CI machine

* Ensure individual tests for external requirements only skip locally

* Use `importlib.import_module` in `importorskip`

* Attempt to fix failing CI imports

* Revert import method

* Correct skip logic

* Breakup import statements

* Correct `pydocstyle` issues

* Correct pystan skip logic
  • Loading branch information
hectormz authored Mar 31, 2020
1 parent f76da26 commit 9909aab
Show file tree
Hide file tree
Showing 23 changed files with 452 additions and 264 deletions.
2 changes: 2 additions & 0 deletions .azure-pipelines/azure-pipelines-base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ jobs:
variables:
- name: NUMBA_DISABLE_JIT
value: 1
- name: ARVIZ_CI_MACHINE
value: 1
timeoutInMinutes: 360
strategy:
matrix:
Expand Down
2 changes: 2 additions & 0 deletions .azure-pipelines/azure-pipelines-external.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ jobs:
variables:
- name: NUMBA_DISABLE_JIT
value: 1
- name: ARVIZ_CI_MACHINE
value: 1
timeoutInMinutes: 360
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ target/

# IDE configs
.idea/
.vscode/

saved_animations/

Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* Allow xarray.Dataarray input for plots.(#1120)
* Skip test for optional/extra dependencies when not installed (#1113)
### Maintenance and fixes
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
* Fixed hist kind of `plot_dist` with multidimensional input (#1115)
Expand Down Expand Up @@ -212,4 +213,3 @@
## v0.3.0 (2018 Dec 14)

* First Beta Release

100 changes: 9 additions & 91 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
"""Test Diagnostic methods"""
# pylint: disable=redefined-outer-name, no-member, too-many-public-methods
import os

import numpy as np
from numpy.testing import assert_almost_equal, assert_array_almost_equal
import pandas as pd
import pytest
from numpy.testing import assert_almost_equal, assert_array_almost_equal

from ...data import load_arviz_data, from_cmdstan
from ...data import from_cmdstan, load_arviz_data
from ...plots.plot_utils import xarray_var_iter
from ...stats import bfmi, rhat, ess, mcse, geweke
from ...rcparams import rcParams
from ...stats import bfmi, ess, geweke, mcse, rhat
from ...stats.diagnostics import (
ks_summary,
_conv_quantile,
_ess,
_ess_quantile,
_multichain_statistics,
_mc_error,
_multichain_statistics,
_rhat,
_rhat_rank,
_z_scale,
_conv_quantile,
_split_chains,
_z_scale,
ks_summary,
)
from ...utils import Numba
from ...rcparams import rcParams

# For tests only, recommended value should be closer to 1.01-1.05
# See discussion in https://github.com/stan-dev/rstan/pull/618
Expand Down Expand Up @@ -536,85 +536,3 @@ def test_split_chain_dims(self, chains, draws):
if chains is None:
chains = 1
assert split_data.shape == (chains * 2, draws // 2)


def test_numba_bfmi():
"""Numba test for bfmi."""
state = Numba.numba_flag
school = load_arviz_data("centered_eight")
data_md = np.random.rand(100, 100, 10)
Numba.disable_numba()
non_numba = bfmi(school.posterior["mu"].values)
non_numba_md = bfmi(data_md)
Numba.enable_numba()
with_numba = bfmi(school.posterior["mu"].values)
with_numba_md = bfmi(data_md)
assert np.allclose(non_numba_md, with_numba_md)
assert np.allclose(with_numba, non_numba)
assert state == Numba.numba_flag


@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
def test_numba_rhat(method):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
Numba.disable_numba()
non_numba = rhat(school, method=method)
Numba.enable_numba()
with_numba = rhat(school, method=method)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("method", ("mean", "sd", "quantile"))
def test_numba_mcse(method, prob=None):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
if method == "quantile":
prob = 0.80
Numba.disable_numba()
non_numba = mcse(school, method=method, prob=prob)
Numba.enable_numba()
with_numba = mcse(school, method=method, prob=prob)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


def test_ks_summary_numba():
"""Numba test for ks_summary."""
state = Numba.numba_flag
data = np.random.randn(100, 100)
Numba.disable_numba()
non_numba = (ks_summary(data)["Count"]).values
Numba.enable_numba()
with_numba = (ks_summary(data)["Count"]).values
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


def test_geweke_numba():
"""Numba test for geweke."""
state = Numba.numba_flag
data = np.random.randn(100)
Numba.disable_numba()
non_numba = geweke(data)
Numba.enable_numba()
with_numba = geweke(data)
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("batches", (1, 20))
@pytest.mark.parametrize("circular", (True, False))
def test_mcse_error_numba(batches, circular):
"""Numba test for mcse_error."""
data = np.random.randn(100, 100)
state = Numba.numba_flag
Numba.disable_numba()
non_numba = _mc_error(data, batches=batches, circular=circular)
Numba.enable_numba()
with_numba = _mc_error(data, batches=batches, circular=circular)
assert np.allclose(non_numba, with_numba)
assert state == Numba.numba_flag
104 changes: 104 additions & 0 deletions arviz/tests/base_tests/test_diagnostics_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Test Diagnostic methods"""
import importlib

# pylint: disable=redefined-outer-name, no-member, too-many-public-methods
import numpy as np
import pytest

from ...data import load_arviz_data
from ..helpers import running_on_ci
from ...rcparams import rcParams
from ...stats import bfmi, geweke, mcse, rhat
from ...stats.diagnostics import _mc_error, ks_summary
from ...utils import Numba
from .test_diagnostics import data # pylint: disable=unused-import


pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name
(importlib.util.find_spec("numba") is None) & ~running_on_ci(),
reason="test requires numba which is not installed",
)

rcParams["data.load"] = "eager"


def test_numba_bfmi():
"""Numba test for bfmi."""
state = Numba.numba_flag
school = load_arviz_data("centered_eight")
data_md = np.random.rand(100, 100, 10)
Numba.disable_numba()
non_numba = bfmi(school.posterior["mu"].values)
non_numba_md = bfmi(data_md)
Numba.enable_numba()
with_numba = bfmi(school.posterior["mu"].values)
with_numba_md = bfmi(data_md)
assert np.allclose(non_numba_md, with_numba_md)
assert np.allclose(with_numba, non_numba)
assert state == Numba.numba_flag


@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
def test_numba_rhat(method):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
Numba.disable_numba()
non_numba = rhat(school, method=method)
Numba.enable_numba()
with_numba = rhat(school, method=method)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("method", ("mean", "sd", "quantile"))
def test_numba_mcse(method, prob=None):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
if method == "quantile":
prob = 0.80
Numba.disable_numba()
non_numba = mcse(school, method=method, prob=prob)
Numba.enable_numba()
with_numba = mcse(school, method=method, prob=prob)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


def test_ks_summary_numba():
"""Numba test for ks_summary."""
state = Numba.numba_flag
data = np.random.randn(100, 100)
Numba.disable_numba()
non_numba = (ks_summary(data)["Count"]).values
Numba.enable_numba()
with_numba = (ks_summary(data)["Count"]).values
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


def test_geweke_numba():
"""Numba test for geweke."""
state = Numba.numba_flag
data = np.random.randn(100)
Numba.disable_numba()
non_numba = geweke(data)
Numba.enable_numba()
with_numba = geweke(data)
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("batches", (1, 20))
@pytest.mark.parametrize("circular", (True, False))
def test_mcse_error_numba(batches, circular):
"""Numba test for mcse_error."""
data = np.random.randn(100, 100)
state = Numba.numba_flag
Numba.disable_numba()
non_numba = _mc_error(data, batches=batches, circular=circular)
Numba.enable_numba()
with_numba = _mc_error(data, batches=batches, circular=circular)
assert np.allclose(non_numba, with_numba)
assert state == Numba.numba_flag
18 changes: 18 additions & 0 deletions arviz/tests/base_tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from _pytest.outcomes import Skipped

from ..helpers import importorskip


def test_importorskip_local(monkeypatch):
"""Test ``importorskip`` run on local machine with non-existent module, which should skip."""
monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False)
with pytest.raises(Skipped):
importorskip("non-existent-function")


def test_importorskip_ci(monkeypatch):
"""Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
monkeypatch.setenv("ARVIZ_CI_MACHINE", 1)
with pytest.raises(ModuleNotFoundError):
importorskip("non-existent-function")
19 changes: 13 additions & 6 deletions arviz/tests/base_tests/test_plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# pylint: disable=redefined-outer-name
import importlib

import numpy as np
import xarray as xr
import pytest
import xarray as xr

from ...data import from_dict
from ..helpers import running_on_ci
from ...plots.plot_utils import (
make_2d,
xarray_to_ndarray,
xarray_var_iter,
get_bins,
get_coords,
filter_plotters_list,
format_sig_figs,
get_bins,
get_coords,
get_plotting_function,
make_2d,
matplotlib_kwarg_dealiaser,
xarray_to_ndarray,
xarray_var_iter,
)
from ...rcparams import rc_context

Expand Down Expand Up @@ -194,6 +197,10 @@ def test_filter_plotter_list_warning():
assert len(plotters_filtered) == 5


@pytest.mark.skipif(
(importlib.util.find_spec("bokeh") is None) & ~running_on_ci(),
reason="test requires bokeh which is not installed",
)
def test_bokeh_import():
"""Tests that correct method is returned on bokeh import"""
plot = get_plotting_function("plot_dist", "distplot", "bokeh")
Expand Down
27 changes: 16 additions & 11 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
"""Tests use the 'bokeh' backend."""
# pylint: disable=redefined-outer-name,too-many-lines
from copy import deepcopy
import bokeh.plotting as bkp
from pandas import DataFrame

import numpy as np
import pytest
from pandas import DataFrame # pylint: disable=wrong-import-position

from ...data import from_dict, load_arviz_data
from ..helpers import ( # pylint: disable=unused-import
from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
create_model,
create_multidimensional_model,
eight_schools_params,
importorskip,
models,
create_model,
multidim_models,
create_multidimensional_model,
)
from ...rcparams import rcParams, rc_context
from ...plots import (
from ...rcparams import rc_context, rcParams # pylint: disable=wrong-import-position
from ...plots import ( # pylint: disable=wrong-import-position
plot_autocorr,
plot_compare,
plot_density,
Expand All @@ -31,14 +32,18 @@
plot_loo_pit,
plot_mcse,
plot_pair,
plot_rank,
plot_trace,
plot_parallel,
plot_posterior,
plot_ppc,
plot_rank,
plot_trace,
plot_violin,
)
from ...stats import compare, loo, waic
from ...stats import compare, loo, waic # pylint: disable=wrong-import-position

# Skip tests if bokeh not installed
bkp = importorskip("bokeh.plotting") # pylint: disable=invalid-name


rcParams["data.load"] = "eager"

Expand Down
8 changes: 2 additions & 6 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,8 @@ def test_plot_trace(models, kwargs):
assert axes.shape


@pytest.mark.parametrize(
"compact", [True, False],
)
@pytest.mark.parametrize(
"combined", [True, False],
)
@pytest.mark.parametrize("compact", [True, False])
@pytest.mark.parametrize("combined", [True, False])
def test_plot_trace_legend(compact, combined):
idata = load_arviz_data("rugby")
axes = plot_trace(
Expand Down
Loading

0 comments on commit 9909aab

Please sign in to comment.