Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove save_trace and load_trace function #5123

Merged
merged 2 commits into from
Oct 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,7 @@ def __set_compiler_flags():

from pymc import gp, ode, sampling
from pymc.aesaraf import *
from pymc.backends import (
load_trace,
predictions_to_inference_data,
save_trace,
to_inference_data,
)
from pymc.backends import predictions_to_inference_data, to_inference_data
from pymc.backends.tracetab import *
from pymc.bart import *
from pymc.blocking import *
Expand Down
7 changes: 1 addition & 6 deletions pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,4 @@

"""
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc.backends.ndarray import (
NDArray,
load_trace,
point_list_to_multitrace,
save_trace,
)
from pymc.backends.ndarray import NDArray, point_list_to_multitrace
88 changes: 0 additions & 88 deletions pymc/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,94 +32,6 @@
from pymc.model import Model, modelcontext


def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=False) -> str:
    """Save a multitrace to disk (deprecated).

    TODO: Also save warnings.

    This is a custom data format for PyMC traces. Each chain is written into
    its own sub-directory of ``directory`` (named after the chain index) by
    ``SerializeNDArray``. See
    https://docs.scipy.org/doc/numpy/neps/npy-format.html for more
    information about the numpy file format.

    Parameters
    ----------
    trace: pm.MultiTrace
        trace to save to disk
    directory: str (optional)
        path to a directory to save the trace. When omitted, the first
        non-existing path of the form ``.pymc_<n>.trace`` (n = 1, 2, ...)
        is used.
    overwrite: bool (default False)
        whether to overwrite an existing directory.

    Returns
    -------
    str, path to the directory where the trace was saved

    Raises
    ------
    TypeError
        If ``trace`` is not a ``MultiTrace``.
    OSError
        If ``directory`` already exists and ``overwrite`` is False.
    """
    # The two literals are concatenated by the parser, so the trailing space
    # after the first sentence is required. stacklevel=2 attributes the
    # warning to the caller instead of this module.
    warnings.warn(
        "The `save_trace` function will soon be removed. "
        "Instead, use `arviz.to_netcdf` to save traces.",
        FutureWarning,
        stacklevel=2,
    )

    if not isinstance(trace, MultiTrace):
        raise TypeError(
            "You are attempting to save an InferenceData object but this function "
            "works only for MultiTrace objects. Use `arviz.to_netcdf` instead"
        )

    if directory is None:
        # Probe for the first unused default directory name.
        template = ".pymc_{}.trace"
        idx = 1
        while os.path.exists(template.format(idx)):
            idx += 1
        directory = template.format(idx)

    if os.path.isdir(directory):
        if not overwrite:
            raise OSError(
                "Cautiously refusing to overwrite the already existing {}! Please supply "
                "a different directory, or set `overwrite=True`".format(directory)
            )
        # Start from a clean directory so stale chains do not linger.
        shutil.rmtree(directory)
    os.makedirs(directory)

    for chain, ndarray in trace._straces.items():
        SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
    return directory


def load_trace(directory: str, model=None) -> MultiTrace:
    """Load a multitrace that has been written to file (deprecated).

    The model used for the trace must be passed in, or the command
    must be run in a model context.

    Parameters
    ----------
    directory: str
        Path to a pymc serialized trace
    model: pm.Model (optional)
        Model used to create the trace. Can also be inferred from context

    Returns
    -------
    pm.MultiTrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If ``directory`` contains no sub-directories to load chains from.
    """
    # The two literals are concatenated by the parser, so the trailing space
    # after the first sentence is required. stacklevel=2 attributes the
    # warning to the caller instead of this module.
    warnings.warn(
        "The `load_trace` function will soon be removed. "
        "Instead, use `arviz.from_netcdf` to load traces.",
        FutureWarning,
        stacklevel=2,
    )
    # Every sub-directory of `directory` is treated as one serialized chain.
    straces = [
        SerializeNDArray(subdir).load(model)
        for subdir in glob.glob(os.path.join(directory, "*"))
        if os.path.isdir(subdir)
    ]
    if not straces:
        raise TraceDirectoryError("%s is not a PyMC saved chain directory." % directory)
    return base.MultiTrace(straces)


class SerializeNDArray:
metadata_file = "metadata.json"
samples_file = "samples.npz"
Expand Down
75 changes: 0 additions & 75 deletions pymc/tests/test_ndarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,78 +205,3 @@ def test_combine_true_squeeze_true(self):
expected = np.concatenate([self.x, self.y])
result = base._squeeze_cat([self.x, self.y], True, True)
npt.assert_equal(result, expected)


class TestSaveLoad:
    """Round-trip tests for ``pm.save_trace`` / ``pm.load_trace``."""

    @staticmethod
    def model(rng_seeder=None):
        """Build the three-variable model shared by the tests in this class."""
        with pm.Model(rng_seeder=rng_seeder) as built:
            x = pm.Normal("x", 0, 1)
            y = pm.Normal("y", x, 1, observed=2)
            z = pm.Normal("z", x + y, 1)
        return built

    @classmethod
    def setup_class(cls):
        """Sample once from the shared model and keep the MultiTrace on the class."""
        with TestSaveLoad.model():
            cls.trace = pm.sample(return_inferencedata=False)

    def test_save_new_model(self, tmpdir_factory):
        """Saving over an existing directory requires ``overwrite=True``."""
        target = str(tmpdir_factory.mktemp("data"))
        returned = pm.save_trace(self.trace, target, overwrite=True)
        assert returned == target

        with pm.Model() as model:
            w = pm.Normal("w", 0, 1)
            fresh = pm.sample(return_inferencedata=False)

        # The directory already holds a trace, so a plain save must refuse.
        with pytest.raises(OSError):
            pm.save_trace(fresh, target)

        pm.save_trace(fresh, target, overwrite=True)
        with model:
            reloaded = pm.load_trace(target)

        assert (fresh["w"] == reloaded["w"]).all()

    def test_save_and_load(self, tmpdir_factory):
        """Variables and sampler statistics survive a save/load round trip."""
        target = str(tmpdir_factory.mktemp("data"))
        returned = pm.save_trace(self.trace, target, overwrite=True)
        assert returned == target

        trace2 = pm.load_trace(target, model=TestSaveLoad.model())

        for name in ("x", "z"):
            assert (self.trace[name] == trace2[name]).all()

        assert self.trace.stat_names == trace2.stat_names
        for stat in self.trace.stat_names:
            assert all(self.trace[stat] == trace2[stat]), (
                "Restored value of statistic %s does not match stored value" % stat
            )

    def test_bad_load(self, tmpdir_factory):
        """Loading from an empty directory raises ``TraceDirectoryError``."""
        empty_dir = str(tmpdir_factory.mktemp("data"))
        with pytest.raises(pm.TraceDirectoryError):
            pm.load_trace(empty_dir, model=TestSaveLoad.model())

    def test_sample_posterior_predictive(self, tmpdir_factory):
        """Posterior predictive draws match between original and reloaded trace."""
        target = str(tmpdir_factory.mktemp("data"))
        returned = pm.save_trace(self.trace, target, overwrite=True)
        assert returned == target

        # Same seed for both runs so the draws are directly comparable.
        with TestSaveLoad.model(rng_seeder=np.random.RandomState(10)):
            ppc = pm.sample_posterior_predictive(self.trace)

        with TestSaveLoad.model(rng_seeder=np.random.RandomState(10)):
            trace2 = pm.load_trace(target)
            ppc2 = pm.sample_posterior_predictive(trace2)

        for key, expected in ppc.items():
            assert (expected == ppc2[key]).all()