diff --git a/tardis/conftest.py b/tardis/conftest.py index 5eabacc395e..13f6cd84027 100644 --- a/tardis/conftest.py +++ b/tardis/conftest.py @@ -12,6 +12,8 @@ from tardis.atomic import AtomData from tardis.io.config_reader import Configuration from tardis.io.util import yaml_load_config_file +from tardis.model import Radial1DModel +from tardis.model.density import HomologousDensity ### # Astropy @@ -117,3 +119,29 @@ def included_he_atomic_data(test_data_path): def tardis_config_verysimple(): return yaml_load_config_file( 'tardis/io/tests/data/tardis_configv1_verysimple.yml') + +### +# HDF Fixtures +### + +@pytest.fixture(scope="session") +def hdf_file_path(tmpdir_factory): + path = tmpdir_factory.mktemp('hdf_buffer').join('test.hdf') + return str(path) + +@pytest.fixture(scope="session") +def hdf_config(): + filename = 'tardis_configv1_verysimple.yml' + path = os.path.abspath(os.path.join('tardis/io/tests/data/', filename)) + config = Configuration.from_yaml(path) + return config + +@pytest.fixture(scope="session") +def model(hdf_config): + model = Radial1DModel.from_config(hdf_config) + return model + +@pytest.fixture(scope="session") +def homologous_density(hdf_config): + density = HomologousDensity.from_config(hdf_config) + return density \ No newline at end of file diff --git a/tardis/model/base.py b/tardis/model/base.py index 03ed6401184..f35306e53a7 100644 --- a/tardis/model/base.py +++ b/tardis/model/base.py @@ -6,13 +6,13 @@ from tardis.util import quantity_linspace, element_symbol2atomic_number from tardis.io.model_reader import read_density_file, read_abundances_file -from tardis.io.util import HDFWriter +from tardis.io.util import HDFWriter as HDFWriterMixin from density import HomologousDensity logger = logging.getLogger(__name__) -class Radial1DModel(HDFWriter, object): +class Radial1DModel(HDFWriterMixin): """An object that hold information about the individual shells. Parameters diff --git a/tardis/model/density.py b/tardis/model/density.py index 4f172c9eea3..4bad6e5abeb 100644 --- a/tardis/model/density.py +++ b/tardis/model/density.py @@ -1,9 +1,9 @@ import numpy as np from tardis.util import quantity_linspace -from tardis.io.util import HDFWriter +from tardis.io.util import HDFWriter as HDFWriterMixin -class HomologousDensity(HDFWriter, object): +class HomologousDensity(HDFWriterMixin): """A class that holds an initial density and time Parameters diff --git a/tardis/model/tests/test_base.py b/tardis/model/tests/test_base.py index 155609f9552..83b67ec674b 100644 --- a/tardis/model/tests/test_base.py +++ b/tardis/model/tests/test_base.py @@ -209,32 +209,17 @@ def test_ascii_reader_exponential_law(): # Save and Load ### - -@pytest.fixture(scope="module") -def hdf_file_path(tmpdir_factory): - path = tmpdir_factory.mktemp('hdf_buffer').join('model.hdf') - return str(path) - - -@pytest.fixture(scope="module") -def actual_model(): - filename = 'tardis_configv1_verysimple.yml' - config = Configuration.from_yaml(data_path(filename)) - model = Radial1DModel.from_config(config) - return model - - @pytest.fixture(scope="module", autouse=True) -def to_hdf_buffer(hdf_file_path, actual_model): - actual_model.to_hdf(hdf_file_path) +def to_hdf_buffer(hdf_file_path, model): + model.to_hdf(hdf_file_path) model_scalar_attrs = ['t_inner'] @pytest.mark.parametrize("attr", model_scalar_attrs) -def test_hdf_model_scalars(hdf_file_path, actual_model, attr): +def test_hdf_model_scalars(hdf_file_path, model, attr): path = os.path.join('model', 'scalars') expected = pd.read_hdf(hdf_file_path, path)[attr] - actual = getattr(actual_model, attr) + actual = getattr(model, attr) if hasattr(actual, 'cgs'): actual = actual.cgs.value assert_almost_equal(actual, expected) @@ -242,10 +227,10 @@ def test_hdf_model_scalars(hdf_file_path, actual_model, attr): model_nparray_attrs = ['w', 'v_inner', 'v_outer'] @pytest.mark.parametrize("attr", model_nparray_attrs) -def test_hdf_model_nparray(hdf_file_path, actual_model, attr): +def test_hdf_model_nparray(hdf_file_path, model, attr): path = os.path.join('model', attr) expected = pd.read_hdf(hdf_file_path, path) - actual = getattr(actual_model, attr) + actual = getattr(model, attr) if hasattr(actual, 'cgs'): actual = actual.cgs.value assert_almost_equal(actual, expected.values) diff --git a/tardis/model/tests/test_density.py b/tardis/model/tests/test_density.py index f30dab81418..c1721cb514d 100644 --- a/tardis/model/tests/test_density.py +++ b/tardis/model/tests/test_density.py @@ -3,30 +3,11 @@ import pytest from numpy.testing import assert_almost_equal -from tardis.io.config_reader import Configuration -from tardis.model.density import HomologousDensity - - ### # Save and Load ### -def data_path(filename): - return os.path.join('tardis/io/tests/data/', filename) - -@pytest.fixture(scope="module") -def hdf_file_path(tmpdir_factory): - path = tmpdir_factory.mktemp('hdf_buffer').join('density.hdf') - return str(path) - -@pytest.fixture(scope="module") -def homologous_density(): - filename = 'tardis_configv1_verysimple.yml' - config = Configuration.from_yaml(data_path(filename)) - density = HomologousDensity.from_config(config) - return density - -@pytest.fixture(scope="module",autouse=True) +@pytest.fixture(scope="module", autouse=True) def to_hdf_buffer(hdf_file_path,homologous_density): homologous_density.to_hdf(hdf_file_path)