Model API cleanup #6309

Merged: 6 commits, Nov 18, 2022
11 changes: 5 additions & 6 deletions pymc/backends/arviz.py
@@ -215,12 +215,11 @@ def __init__(
}

self.dims = {} if dims is None else dims
if hasattr(self.model, "RV_dims"):
model_dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in self.model.RV_dims.items()
}
self.dims = {**model_dims, **self.dims}
model_dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in self.model.named_vars_to_dims.items()
}
self.dims = {**model_dims, **self.dims}
if sample_dims is None:
sample_dims = ["chain", "draw"]
self.sample_dims = sample_dims
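Note on the hunk above: the `hasattr` guard is dropped because every `Model` now exposes `named_vars_to_dims`, and explicitly passed dims still win because they are unpacked last. A minimal sketch of that precedence (the variable and coordinate names are illustrative):

```python
import pymc as pm

with pm.Model(coords={"city": ["A", "B", "C"]}) as model:
    x = pm.Normal("x", 0.0, 1.0, dims="city")

# Mirror of the new merge logic: model dims first, user-supplied dims last.
model_dims = {
    name: [d for d in dims_ if d is not None]
    for name, dims_ in model.named_vars_to_dims.items()
}
user_dims = {"x": ["city_alias"]}
merged = {**model_dims, **user_dims}
print(merged)  # {'x': ['city_alias']} -- the user-supplied entry wins
```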
2 changes: 1 addition & 1 deletion pymc/data.py
@@ -718,6 +718,6 @@ def Data(
length=xshape[d],
)

model.add_random_variable(x, dims=dims)
model.add_named_variable(x, dims=dims)

return x
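The rename also makes `pm.Data` honest about what it registers: a data container is a named variable, not a random variable. A small usage sketch (`mutable=False` is passed explicitly to avoid the deprecation warning of this era):

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Data("x", np.arange(3), mutable=False)

# The container is tracked by name but is neither free nor observed:
assert "x" in model.named_vars
assert not model.free_RVs and not model.observed_RVs
```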
24 changes: 3 additions & 21 deletions pymc/initial_point.py
@@ -52,29 +52,11 @@ def convert_str_to_rv_dict(
return initvals


def filter_rvs_to_jitter(step) -> Set[TensorVariable]:
"""Find the set of RVs for which the responsible step methods ask for
the addition of jitter to the initial point.

Parameters
----------
step : BlockedStep or CompoundStep
One or many step methods that were assigned model variables.

Returns
-------
rvs_to_jitter : set
The random variables for which jitter should be added.
"""
# TODO: implement this
return set()


def make_initial_point_fns_per_chain(
*,
model,
overrides: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
jitter_rvs: Set[TensorVariable],
jitter_rvs: Optional[Set[TensorVariable]] = None,
chains: int,
) -> List[Callable]:
"""Create an initial point function for each chain, as defined by initvals
@@ -87,7 +69,7 @@ def make_initial_point_fns_per_chain(
overrides : optional, list or dict
Initial value strategy overrides that should take precedence over the defaults from the model.
A sequence of None or dicts will be treated as chain-wise strategies and must have the same length as `seeds`.
jitter_rvs : set
jitter_rvs : set, optional
Random variable tensors for which U(-1, 1) jitter shall be applied.
(To the transformed space if applicable.)
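Since `jitter_rvs` now defaults to `None`, callers that want no jitter can simply omit it. A sketch under the assumption that `None` is treated as an empty set, as the new default implies:

```python
import pymc as pm
from pymc.initial_point import make_initial_point_fns_per_chain

with pm.Model() as model:
    x = pm.Normal("x", initval=0.5)

# No jitter_rvs argument: initial points come straight from the initval strategies.
fns = make_initial_point_fns_per_chain(model=model, overrides=None, chains=2)
print([fn(seed) for fn, seed in zip(fns, (1, 2))])  # both chains start at x=0.5
```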

@@ -151,7 +133,7 @@ def make_initial_point_fn(

sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
**model.initial_values,
**model.rvs_to_initial_values,
**sdict_overrides,
}

90 changes: 49 additions & 41 deletions pymc/model.py
@@ -550,35 +550,33 @@ def __init__(
self.name = self._validate_name(name)
self.check_bounds = check_bounds

self._initial_values: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]] = {}

if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
self.named_vars_to_dims = treedict(parent=self.parent.named_vars_to_dims)
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
self.free_RVs = treelist(parent=self.parent.free_RVs)
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
self.deterministics = treelist(parent=self.parent.deterministics)
self.potentials = treelist(parent=self.parent.potentials)
self._coords = self.parent._coords
self._RV_dims = treedict(parent=self.parent._RV_dims)
self._dim_lengths = self.parent._dim_lengths
else:
self.named_vars = treedict()
self.named_vars_to_dims = treedict()
self.values_to_rvs = treedict()
self.rvs_to_values = treedict()
self.rvs_to_transforms = treedict()
self.rvs_to_total_sizes = treedict()
self.rvs_to_initial_values = treedict()
self.free_RVs = treelist()
self.observed_RVs = treelist()
self.auto_deterministics = treelist()
self.deterministics = treelist()
self.potentials = treelist()
self._coords = {}
self._RV_dims = treedict()
self._dim_lengths = {}
self.add_coords(coords)
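All of these containers are `treedict`/`treelist` instances so that nested (sub)models mirror their registrations into the parent. A small sketch of that behavior:

```python
import pymc as pm

with pm.Model(name="outer") as outer:
    x = pm.Normal("x")
    with pm.Model(name="inner") as inner:
        y = pm.Normal("y")

# Child registrations propagate upward: the outer model sees both variables.
assert set(inner.named_vars) <= set(outer.named_vars)
print(sorted(outer.named_vars))  # prefixed names for x and the nested y
```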

@@ -972,7 +970,11 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:

Entries in the tuples may be ``None``, if the RV dimension was not given a name.
"""
return self._RV_dims
warnings.warn(
"Model.RV_dims is deprecated. User Model.named_vars_to_dims instead.",
FutureWarning,
)
return self.named_vars_to_dims
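The old property survives as a deprecated alias, so downstream code keeps working while emitting a `FutureWarning`. Sketch of the expected behavior:

```python
import warnings
import pymc as pm

with pm.Model(coords={"obs": range(3)}) as model:
    x = pm.Normal("x", dims="obs")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    dims = model.RV_dims  # deprecated alias for named_vars_to_dims

assert any(issubclass(w.category, FutureWarning) for w in caught)
assert dims == model.named_vars_to_dims
```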

@property
def coords(self) -> Dict[str, Union[Tuple, None]]:
@@ -1124,15 +1126,18 @@ def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]:
Keys are the random variables (as returned by e.g. ``pm.Uniform()``) and
values are the numeric/symbolic initial values, strings denoting the strategy to get them, or None.
"""
return self._initial_values
warnings.warn(
"Model.initial_values is deprecated. Use Model.rvs_to_initial_values instead."
)
return self.rvs_to_initial_values

def set_initval(self, rv_var, initval):
"""Sets an initial value (strategy) for a random variable."""
if initval is not None and not isinstance(initval, (Variable, str)):
# Convert scalars or array-like inputs to ndarrays
initval = rv_var.type.filter(initval)

self.initial_values[rv_var] = initval
self.rvs_to_initial_values[rv_var] = initval
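`set_initval` now writes into the public `rvs_to_initial_values` mapping; scalar and array-like inputs are filtered into the variable's type first. Sketch:

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")

model.set_initval(x, 0.5)
val = model.rvs_to_initial_values[x]
# The scalar was converted to an ndarray by rv_var.type.filter:
assert isinstance(val, np.ndarray) and val == np.array(0.5)
```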

def set_data(
self,
@@ -1167,7 +1172,7 @@ def set_data(
if isinstance(values, list):
values = np.array(values)
values = convert_observed_data(values)
dims = self.RV_dims.get(name, None) or ()
dims = self.named_vars_to_dims.get(name, None) or ()
coords = coords or {}

if values.ndim != shared_object.ndim:
@@ -1257,7 +1262,7 @@ def set_data(
shared_object.set_value(values)
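For `set_data`, the dims lookup now goes through `named_vars_to_dims`. A hedged sketch of resizing a `MutableData` container whose dim was created with a mutable length:

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    data = pm.MutableData("data", np.zeros(3), dims="obs")

# "obs" is recorded in named_vars_to_dims with a resizable length,
# so growing the container is allowed:
with model:
    pm.set_data({"data": np.ones(4)})
print(data.get_value().shape)  # (4,)
```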

def register_rv(
self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET, initval=None
self, rv_var, name, observed=None, total_size=None, dims=None, transform=UNSET, initval=None
):
"""Register an (un)observed random variable with the model.

@@ -1266,9 +1271,8 @@ def register_rv(
rv_var: TensorVariable
name: str
Intended name for the model variable.
data: array_like (optional)
If data is provided, the variable is observed. If None,
the variable is unobserved.
observed: array_like (optional)
Data values for observed variables.
total_size: scalar
upscales logp of variable with ``coef = total_size/var.shape[0]``
dims: tuple
@@ -1295,31 +1299,31 @@
if dname not in self.dim_lengths:
self.add_coord(dname, values=None, length=rv_var.shape[d])

if data is None:
if observed is None:
self.free_RVs.append(rv_var)
self.create_value_var(rv_var, transform)
self.add_random_variable(rv_var, dims)
self.add_named_variable(rv_var, dims)
self.set_initval(rv_var, initval)
else:
if (
isinstance(data, Variable)
and not isinstance(data, (GenTensorVariable, Minibatch))
and data.owner is not None
isinstance(observed, Variable)
and not isinstance(observed, (GenTensorVariable, Minibatch))
and observed.owner is not None
# The only Aesara operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
and not (
isinstance(data.owner.op, Elemwise)
and isinstance(data.owner.op.scalar_op, Cast)
isinstance(observed.owner.op, Elemwise)
and isinstance(observed.owner.op.scalar_op, Cast)
)
):
raise TypeError(
"Variables that depend on other nodes cannot be used for observed data."
f"The data variable was: {data}"
f"The data variable was: {observed}"
)

# `rv_var` is potentially changed by `make_obs_var`,
# for example into a new graph for imputation of missing data.
rv_var = self.make_obs_var(rv_var, data, dims, transform)
rv_var = self.make_obs_var(rv_var, observed, dims, transform)

return rv_var
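`register_rv` is usually reached through a distribution constructor, where the keyword has always been `observed` at the user level; the rename aligns the internal signature with it. Usage hitting both branches:

```python
import numpy as np
import pymc as pm

data = np.random.default_rng(42).normal(size=10)

with pm.Model() as model:
    mu = pm.Normal("mu")                      # observed=None -> free RV branch
    y = pm.Normal("y", mu=mu, observed=data)  # observed data -> observed RV branch

assert mu in model.free_RVs and y in model.observed_RVs
```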

@@ -1425,14 +1429,15 @@ def make_obs_var(
observed_rv_var.tag.observations = nonmissing_data

self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
self.add_random_variable(observed_rv_var)
self.add_named_variable(observed_rv_var)
self.observed_RVs.append(observed_rv_var)

# Create deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
rv_var = at.zeros(data.shape)
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
rv_var = Deterministic(name, rv_var, self, dims, auto=True)
rv_var = Deterministic(name, rv_var, self, dims)

else:
if sps.issparse(data):
Expand All @@ -1441,7 +1446,7 @@ def make_obs_var(
data = at.as_tensor_variable(data, name=name)
rv_var.tag.observations = data
self.create_value_var(rv_var, transform=None, value_var=data)
self.add_random_variable(rv_var, dims)
self.add_named_variable(rv_var, dims)
self.observed_RVs.append(rv_var)

return rv_var
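With the `auto` flag gone from `Deterministic`, the node that stitches observed and imputed values back together is registered like any user deterministic. A sketch using a masked array to trigger imputation:

```python
import numpy as np
import pymc as pm

data = np.ma.masked_values([1.0, -999.0, 2.0], value=-999.0)

with pm.Model() as model:
    mu = pm.Normal("mu")
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=data)  # one missing entry

# The combined "y" node now lives in the regular deterministics list:
print([d.name for d in model.deterministics])
```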
@@ -1481,15 +1486,18 @@ def create_value_var(
value_var.tag.test_value = transform.forward(
value_var, *rv_var.owner.inputs
).tag.test_value
self.named_vars[value_var.name] = value_var
self.rvs_to_transforms[rv_var] = transform
self.rvs_to_values[rv_var] = value_var
self.values_to_rvs[value_var] = rv_var

return value_var

def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
"""Add a random variable to the named variables of the model."""
def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
"""Add a random graph variable to the named variables of the model.

This can include several types of variables such as basic_RVs, Data, Deterministics,
and Potentials.
"""
if self.named_vars.tree_contains(var.name):
raise ValueError(f"Variable name {var.name} already exists.")

@@ -1501,7 +1509,7 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
if any(var.name == dim for dim in dims):
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
self._RV_dims[var.name] = dims
self.named_vars_to_dims[var.name] = dims

self.named_vars[var.name] = var
if not hasattr(self, self.name_of(var.name)):
@@ -1705,14 +1713,17 @@ def check_start_vals(self, start):
None
"""
start_points = [start] if isinstance(start, dict) else start

value_names_to_dtypes = {value.name: value.dtype for value in self.value_vars}
value_names_set = set(value_names_to_dtypes.keys())
for elem in start_points:

for k, v in elem.items():
elem[k] = np.asarray(v, dtype=self[k].dtype)
elem[k] = np.asarray(v, dtype=value_names_to_dtypes[k])

if not set(elem.keys()).issubset(self.named_vars.keys()):
extra_keys = ", ".join(set(elem.keys()) - set(self.named_vars.keys()))
valid_keys = ", ".join(self.named_vars.keys())
if not set(elem.keys()).issubset(value_names_set):
extra_keys = ", ".join(set(elem.keys()) - value_names_set)
valid_keys = ", ".join(value_names_set)
raise KeyError(
"Some start parameters do not appear in the model!\n"
f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
@@ -1899,7 +1910,7 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
}


def Deterministic(name, var, model=None, dims=None, auto=False):
def Deterministic(name, var, model=None, dims=None):
"""Create a named deterministic variable.

Deterministic nodes are only deterministic given all of their inputs, i.e.
@@ -1962,11 +1973,8 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
"""
model = modelcontext(model)
var = var.copy(model.name_for(name))
if auto:
model.auto_deterministics.append(var)
else:
model.deterministics.append(var)
model.add_random_variable(var, dims)
model.deterministics.append(var)
model.add_named_variable(var, dims)

from pymc.printing import str_for_potential_or_deterministic
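With `auto` removed, every `Deterministic` call lands in `model.deterministics`. Minimal usage:

```python
import aesara.tensor as at
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    x_sq = pm.Deterministic("x_sq", at.sqr(x))

assert x_sq in model.deterministics  # single registration path for user code
```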

@@ -1998,7 +2006,7 @@ def Potential(name, var, model=None):
model = modelcontext(model)
var.name = model.name_for(name)
model.potentials.append(var)
model.add_random_variable(var)
model.add_named_variable(var)

from pymc.printing import str_for_potential_or_deterministic
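`Potential` follows the same registration path. A small usage sketch adding a soft constraint to the joint log-probability:

```python
import aesara.tensor as at
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    # Penalize negative x with an extra logp term:
    pen = pm.Potential("pen", at.switch(x < 0, -10.0, 0.0))

assert pen in model.potentials and "pen" in model.named_vars
```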

4 changes: 2 additions & 2 deletions pymc/model_graph.py
@@ -200,10 +200,10 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,

for var_name in self.vars_to_plot(var_names):
v = self.model[var_name]
if var_name in self.model.RV_dims:
if var_name in self.model.named_vars_to_dims:
plate_label = " x ".join(
f"{d} ({self._eval(self.model.dim_lengths[d])})"
for d in self.model.RV_dims[var_name]
for d in self.model.named_vars_to_dims[var_name]
)
else:
plate_label = " x ".join(map(str, self._eval(v.shape)))
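Plate labels therefore come straight from `named_vars_to_dims` whenever dims were declared. Illustrative (rendering needs the graphviz package installed):

```python
import pymc as pm

with pm.Model(coords={"city": ["A", "B", "C"]}) as model:
    x = pm.Normal("x", dims="city")

graph = pm.model_to_graphviz(model)
# The plate for x is labeled "city (3)" rather than a bare shape "3".
```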