diff --git a/tardis/conftest.py b/tardis/conftest.py index 5eabacc395e..15222177777 100644 --- a/tardis/conftest.py +++ b/tardis/conftest.py @@ -12,6 +12,8 @@ from tardis.atomic import AtomData from tardis.io.config_reader import Configuration from tardis.io.util import yaml_load_config_file +from tardis.simulation import Simulation +from copy import deepcopy ### # Astropy @@ -90,17 +92,19 @@ def atomic_data_fname(): else: return os.path.expandvars(os.path.expanduser(atomic_data_fname)) - -@pytest.fixture -def kurucz_atomic_data(atomic_data_fname): +@pytest.fixture(scope="session") +def atomic_dataset(atomic_data_fname): atomic_data = AtomData.from_hdf5(atomic_data_fname) - if atomic_data.md5 != '21095dd25faa1683f4c90c911a00c3f8': pytest.skip('Need default Kurucz atomic dataset ' '(md5="21095dd25faa1683f4c90c911a00c3f8"') else: return atomic_data +@pytest.fixture +def kurucz_atomic_data(atomic_dataset): + atomic_data = deepcopy(atomic_dataset) + return atomic_data @pytest.fixture def test_data_path(): @@ -117,3 +121,26 @@ def included_he_atomic_data(test_data_path): def tardis_config_verysimple(): return yaml_load_config_file( 'tardis/io/tests/data/tardis_configv1_verysimple.yml') + +### +# HDF Fixtures +### + +@pytest.fixture(scope="session") +def hdf_file_path(tmpdir_factory): + path = tmpdir_factory.mktemp('hdf_buffer').join('test.hdf') + return str(path) + +@pytest.fixture(scope="session") +def config_verysimple(): + filename = 'tardis_configv1_verysimple.yml' + path = os.path.abspath(os.path.join('tardis/io/tests/data/', filename)) + config = Configuration.from_yaml(path) + return config + +@pytest.fixture(scope="session") +def simulation_verysimple(config_verysimple, atomic_dataset): + atomic_data = deepcopy(atomic_dataset) + sim = Simulation.from_config(config_verysimple, atom_data=atomic_data) + sim.iterate(4000) + return sim diff --git a/tardis/model/base.py b/tardis/model/base.py index c70ac7f513b..eda7363c135 100644 --- a/tardis/model/base.py +++ b/tardis/model/base.py @@ -6,13 +6,13 @@ from tardis.util import quantity_linspace, element_symbol2atomic_number from tardis.io.model_reader import read_density_file, read_abundances_file -from tardis.io.util import to_hdf +from tardis.io.util import HDFWriterMixin from density import HomologousDensity logger = logging.getLogger(__name__) -class Radial1DModel(object): +class Radial1DModel(HDFWriterMixin): """An object that hold information about the individual shells. Parameters @@ -57,6 +57,9 @@ class Radial1DModel(object): Shortcut for `t_radiative` """ + hdf_properties = ['t_inner', 'w', 't_radiative', 'v_inner', 'v_outer', 'homologous_density'] + hdf_name = 'model' + def __init__(self, velocity, homologous_density, abundance, time_explosion, t_inner, luminosity_requested=None, t_radiative=None, dilution_factor=None, v_boundary_inner=None, @@ -262,27 +265,6 @@ def v_boundary_outer_index(self): return None return self.raw_velocity.searchsorted(self.v_boundary_outer) + 1 - def to_hdf(self, path_or_buf, path=''): - """ - Store the model to an HDF structure. - - Parameters - ---------- - path_or_buf - Path or buffer to the HDF store - path : str - Path inside the HDF store to store the model - - Returns - ------- - None - - """ - model_path = os.path.join(path, 'model') - properties = ['t_inner', 'w', 't_radiative', 'v_inner', 'v_outer'] - to_hdf(path_or_buf, model_path, {name: getattr(self, name) for name - in properties}) - @classmethod def from_config(cls, config): """ diff --git a/tardis/model/density.py b/tardis/model/density.py index 1652722ad6b..a7676a6d5c6 100644 --- a/tardis/model/density.py +++ b/tardis/model/density.py @@ -1,9 +1,9 @@ import numpy as np from tardis.util import quantity_linspace +from tardis.io.util import HDFWriterMixin - -class HomologousDensity(object): +class HomologousDensity(HDFWriterMixin): """A class that holds an initial density and time Parameters @@ -12,6 +12,8 @@ class HomologousDensity(object): time_0 : astropy.units.Quantity """ + hdf_properties = ['density_0', 'time_0'] + def __init__(self, density_0, time_0): self.density_0 = density_0 self.time_0 = time_0 diff --git a/tardis/model/tests/test_base.py b/tardis/model/tests/test_base.py index 82dbbdbfaac..9053c7423de 100644 --- a/tardis/model/tests/test_base.py +++ b/tardis/model/tests/test_base.py @@ -1,4 +1,6 @@ import os +import pytest +import pandas as pd from astropy import units as u from numpy.testing import assert_almost_equal, assert_array_almost_equal @@ -202,3 +204,33 @@ def test_ascii_reader_exponential_law(): for i, mdens in enumerate(expected_densites): assert_almost_equal(model.density[i].value, mdens) assert model.density[i].unit == u.Unit(expected_unit) + +### +# Save and Load +### + +@pytest.fixture(scope="module", autouse=True) +def to_hdf_buffer(hdf_file_path, simulation_verysimple): + simulation_verysimple.model.to_hdf(hdf_file_path) + +model_scalar_attrs = ['t_inner'] + +@pytest.mark.parametrize("attr", model_scalar_attrs) +def test_hdf_model_scalars(hdf_file_path, simulation_verysimple, attr): + path = os.path.join('model', 'scalars') + expected = pd.read_hdf(hdf_file_path, path)[attr] + actual = getattr(simulation_verysimple.model, attr) + if hasattr(actual, 'cgs'): + actual = actual.cgs.value + assert_almost_equal(actual, expected) + +model_nparray_attrs = ['w', 'v_inner', 'v_outer'] + +@pytest.mark.parametrize("attr", model_nparray_attrs) +def test_hdf_model_nparray(hdf_file_path, simulation_verysimple, attr): + path = os.path.join('model', attr) + expected = pd.read_hdf(hdf_file_path, path) + actual = getattr(simulation_verysimple.model, attr) + if hasattr(actual, 'cgs'): + actual = actual.cgs.value + assert_almost_equal(actual, expected.values) diff --git a/tardis/model/tests/test_density.py b/tardis/model/tests/test_density.py new file mode 100644 index 00000000000..0f176c12236 --- /dev/null +++ b/tardis/model/tests/test_density.py @@ -0,0 +1,28 @@ +import os +import pandas as pd +import pytest +from numpy.testing import assert_almost_equal + +### +# Save and Load +### + +@pytest.fixture(scope="module", autouse=True) +def to_hdf_buffer(hdf_file_path,simulation_verysimple): + simulation_verysimple.model.homologous_density.to_hdf(hdf_file_path) + +def test_hdf_density_0(hdf_file_path, simulation_verysimple): + actual = simulation_verysimple.model.homologous_density.density_0 + if hasattr(actual, 'cgs'): + actual = actual.cgs.value + path = os.path.join('homologous_density','density_0') + expected = pd.read_hdf(hdf_file_path, path) + assert_almost_equal(actual, expected.values) + +def test_hdf_time_0(hdf_file_path, simulation_verysimple): + actual = simulation_verysimple.model.homologous_density.time_0 + if hasattr(actual, 'cgs'): + actual = actual.cgs.value + path = os.path.join('homologous_density','scalars') + expected = pd.read_hdf(hdf_file_path, path)['time_0'] + assert_almost_equal(actual, expected) \ No newline at end of file