Skip to content

Commit

Permalink
fix deprecation messages
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Oct 27, 2021
1 parent 9e7e8aa commit 2fcb64b
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions pymc/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,29 +58,35 @@ def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=Fal
warnings.warn(
"The `save_trace` function will soon be removed."
"Instead, use `arviz.to_netcdf` to save traces.",
DeprecationWarning,
FutureWarning,
)

if directory is None:
directory = ".pymc_{}.trace"
idx = 1
while os.path.exists(directory.format(idx)):
idx += 1
directory = directory.format(idx)

if os.path.isdir(directory):
if overwrite:
shutil.rmtree(directory)
else:
raise OSError(
"Cautiously refusing to overwrite the already existing {}! Please supply "
"a different directory, or set `overwrite=True`".format(directory)
)
os.makedirs(directory)

for chain, ndarray in trace._straces.items():
SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
return directory
if isinstance(trace, MultiTrace):
if directory is None:
directory = ".pymc_{}.trace"
idx = 1
while os.path.exists(directory.format(idx)):
idx += 1
directory = directory.format(idx)

if os.path.isdir(directory):
if overwrite:
shutil.rmtree(directory)
else:
raise OSError(
"Cautiously refusing to overwrite the already existing {}! Please supply "
"a different directory, or set `overwrite=True`".format(directory)
)
os.makedirs(directory)

for chain, ndarray in trace._straces.items():
SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
return directory
else:
raise TypeError(
f"You are attempting to save an InferenceData object but this function "
"works only for MultiTrace objects. Use `arviz.to_netcdf` instead"
)


def load_trace(directory: str, model=None) -> MultiTrace:
Expand All @@ -103,7 +109,7 @@ def load_trace(directory: str, model=None) -> MultiTrace:
warnings.warn(
"The `load_trace` function will soon be removed."
"Instead, use `arviz.from_netcdf` to load traces.",
DeprecationWarning,
FutureWarning,
)
straces = []
for subdir in glob.glob(os.path.join(directory, "*")):
Expand All @@ -125,7 +131,7 @@ def __init__(self, directory: str):
warnings.warn(
"The `SerializeNDArray` class will soon be removed. "
"Instead, use ArviZ to save/load traces.",
DeprecationWarning,
FutureWarning,
)
self.directory = directory
self.metadata_path = os.path.join(self.directory, self.metadata_file)
Expand Down

0 comments on commit 2fcb64b

Please sign in to comment.