Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add deprecation warnings for old backends #3902

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
### Maintenance
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
- Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).
- Deprecated `sd` in version 3.7 has been replaced by `sigma` now raises `DepreciationWarning` on using `sd` in continuous, mixed and timeseries distributions. (see #3837 and #3688).
- Deprecated `sd` in version 3.7 has been replaced by `sigma` now raises `DeprecationWarning` on using `sd` in continuous, mixed and timeseries distributions. (see #3837 and #3688).
- We'll deprecate the `Text` and `SQLite` backends and the `save_trace`/`load_trace` functions, since this is now done with ArviZ. (see [#3902](https://github.com/pymc-devs/pymc3/pull/3902))
- In named models, `pm.Data` objects now get model-relative names (see [#3843](https://github.com/pymc-devs/pymc3/pull/3843)).
- `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)).
- Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)).
Expand Down
18 changes: 18 additions & 0 deletions pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import shutil
from typing import Optional, Dict, Any, List
import warnings

import numpy as np
from pymc3.backends import base
Expand Down Expand Up @@ -52,6 +53,12 @@ def save_trace(trace: MultiTrace, directory: Optional[str]=None, overwrite=False
-------
str, path to the directory where the trace was saved
"""
warnings.warn(
'The `save_trace` function will soon be removed.'
'Instead, use ArviZ to save/load traces.',
DeprecationWarning,
)

if directory is None:
directory = '.pymc_{}.trace'
idx = 1
Expand Down Expand Up @@ -89,6 +96,11 @@ def load_trace(directory: str, model=None) -> MultiTrace:
-------
pm.Multitrace that was saved in the directory
"""
warnings.warn(
'The `load_trace` function will soon be removed.'
'Instead, use ArviZ to save/load traces.',
DeprecationWarning,
)
straces = []
for subdir in glob.glob(os.path.join(directory, '*')):
if os.path.isdir(subdir):
Expand All @@ -106,6 +118,11 @@ class SerializeNDArray:

def __init__(self, directory: str):
"""Helper to save and load NDArray objects"""
warnings.warn(
'The `SerializeNDArray` class will soon be removed. '
'Instead, use ArviZ to save/load traces.',
DeprecationWarning,
)
self.directory = directory
self.metadata_path = os.path.join(self.directory, self.metadata_file)
self.samples_path = os.path.join(self.directory, self.samples_file)
Expand Down Expand Up @@ -367,6 +384,7 @@ def _slice_as_ndarray(strace, idx):

return sliced


def point_list_to_multitrace(point_list: List[Dict[str, np.ndarray]], model: Optional[Model]=None) -> MultiTrace:
'''transform point list into MultiTrace'''
_model = modelcontext(model)
Expand Down
13 changes: 13 additions & 0 deletions pymc3/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""
import numpy as np
import sqlite3
import warnings

from ..backends import base, ndarray
from . import tracetab as ttab
Expand Down Expand Up @@ -89,6 +90,12 @@ class SQLite(base.BaseTrace):
"""

def __init__(self, name, model=None, vars=None, test_point=None):
warnings.warn(
'The `SQLite` backend will soon be removed. '
'Please switch to a different backend. '
'If you have good reasons for using the SQLite backend, file an issue and tell us about them.',
DeprecationWarning,
)
super().__init__(name, model, vars, test_point)
self._var_cols = {}
self.var_inserts = {} # varname -> insert statement
Expand Down Expand Up @@ -322,6 +329,12 @@ def load(name, model=None):
-------
A MultiTrace instance
"""
warnings.warn(
'The `sqlite.load` function will soon be removed. '
'Please use ArviZ to save traces. '
'If you have good reasons for using the `load` function, file an issue and tell us about them. ',
DeprecationWarning,
)
db = _SQLiteDB(name)
db.connect()
varnames = _get_table_list(db.cursor)
Expand Down
19 changes: 19 additions & 0 deletions pymc3/backends/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import os
import re
import pandas as pd
import warnings

from ..backends import base, ndarray
from . import tracetab as ttab
Expand All @@ -57,6 +58,12 @@ class Text(base.BaseTrace):
"""

def __init__(self, name, model=None, vars=None, test_point=None):
warnings.warn(
'The `Text` backend will soon be removed. '
'Please switch to a different backend. '
'If you have good reasons for using the Text backend, file an issue and tell us about them. ',
DeprecationWarning,
)
if not os.path.exists(name):
os.mkdir(name)
super().__init__(name, model, vars, test_point)
Expand Down Expand Up @@ -185,6 +192,12 @@ def load(name, model=None):
-------
A MultiTrace instance
"""
warnings.warn(
'The `load` function will soon be removed. '
'Please use ArviZ to save traces. '
'If you have good reasons for using the `load` function, file an issue and tell us about them. ',
DeprecationWarning,
)
files = glob(os.path.join(name, 'chain-*.csv'))

if len(files) == 0:
Expand Down Expand Up @@ -224,6 +237,12 @@ def dump(name, trace, chains=None):
chains: list
Chains to dump. If None, all chains are dumped.
"""
warnings.warn(
'The `dump` function will soon be removed. '
'Please use ArviZ to save traces. '
'If you have good reasons for using the `dump` function, file an issue and tell us about them. ',
DeprecationWarning,
)
if not os.path.exists(name):
os.mkdir(name)
if chains is None:
Expand Down
7 changes: 7 additions & 0 deletions pymc3/backends/tracetab.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import pandas as pd
import warnings

from ..util import get_default_varnames

Expand All @@ -39,6 +40,12 @@ def trace_to_dataframe(trace, chains=None, varnames=None, include_transformed=Fa
If true transformed variables will be included in the resulting
DataFrame.
"""
warnings.warn(
'The `trace_to_dataframe` function will soon be removed. '
'Please use ArviZ to save traces. '
'If you have good reasons for using the `trace_to_dataframe` function, file an issue and tell us about them. ',
DeprecationWarning,
)
var_shapes = trace._straces[0].var_shapes

if varnames is None:
Expand Down