Skip to content

Commit

Permalink
Introduce pm.Data(..., mutable) kwarg
Browse files Browse the repository at this point in the history
By passing `pm.Data(mutable=False)` one can create a `TensorConstant` instead of a `SharedVariable`.
Data variables with known, fixed shape can enhance performance and compatibility in some situations.
`pm.ConstantData` or `pm.MutableData` wrappers are provided as alternative syntax.

This is the basis for solving pymc-devs#4441.
  • Loading branch information
michaelosthege committed Dec 30, 2021
1 parent cd54dfd commit 1428fc3
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 57 deletions.
158 changes: 105 additions & 53 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
import urllib.request

from copy import copy
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence, Union

import aesara
import aesara.tensor as at
import numpy as np
import pandas as pd

from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply
from aesara.tensor.type import TensorType
from aesara.tensor.var import TensorVariable
from aesara.tensor.var import TensorConstant, TensorVariable

import pymc as pm

Expand All @@ -40,6 +41,8 @@
"Minibatch",
"align_minibatches",
"Data",
"ConstantData",
"MutableData",
]
BASE_URL = "https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/{filename}"

Expand Down Expand Up @@ -502,9 +505,64 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
return coords


class Data:
"""Data container class that wraps :func:`aesara.shared` and lets
the model be aware of its inputs and outputs.
def ConstantData(
name: str,
value,
*,
dims: Optional[Sequence[str]] = None,
export_index_as_coords=False,
**kwargs,
) -> TensorConstant:
"""Alias for ``pm.Data(..., mutable=False)``.
Registers the ``value`` as a ``TensorConstant`` with the model.
"""
return Data(
name,
value,
dims=dims,
export_index_as_coords=export_index_as_coords,
mutable=False,
**kwargs,
)


def MutableData(
name: str,
value,
*,
dims: Optional[Sequence[str]] = None,
export_index_as_coords=False,
**kwargs,
) -> SharedVariable:
"""Alias for ``pm.Data(..., mutable=True)``.
Registers the ``value`` as a ``SharedVariable`` with the model.
"""
return Data(
name,
value,
dims=dims,
export_index_as_coords=export_index_as_coords,
mutable=True,
**kwargs,
)


def Data(
name: str,
value,
*,
dims: Optional[Sequence[str]] = None,
export_index_as_coords=False,
mutable: bool = True,
**kwargs,
) -> Union[SharedVariable, TensorConstant]:
"""Data container that registers a data variable with the model.
Depending on the ``mutable`` setting (default: True), the variable
is registered as a ``SharedVariable``, enabling it to be altered
in value and shape, but NOT in dimensionality using ``pm.set_data()``.
Parameters
----------
Expand Down Expand Up @@ -552,52 +610,46 @@ class Data:
For more information, take a look at this example notebook
https://docs.pymc.io/notebooks/data_container.html
"""
if isinstance(value, list):
value = np.array(value)

def __new__(
self,
name: str,
value,
*,
dims: Optional[Sequence[str]] = None,
export_index_as_coords=False,
**kwargs,
):
if isinstance(value, list):
value = np.array(value)

# Add data container to the named variables of the model.
try:
model = pm.Model.get_context()
except TypeError:
raise TypeError(
"No model on context stack, which is needed to instantiate a data container. "
"Add variable inside a 'with model:' block."
)
name = model.name_for(name)

# `pandas_to_array` takes care of parameter `value` and
# transforms it to something digestible for pymc
shared_object = aesara.shared(pandas_to_array(value), name, **kwargs)

if isinstance(dims, str):
dims = (dims,)
if not (dims is None or len(dims) == shared_object.ndim):
raise pm.exceptions.ShapeError(
"Length of `dims` must match the dimensions of the dataset.",
actual=len(dims),
expected=shared_object.ndim,
)

coords = determine_coords(model, value, dims)

if export_index_as_coords:
model.add_coords(coords)
elif dims:
# Register new dimension lengths
for d, dname in enumerate(dims):
if not dname in model.dim_lengths:
model.add_coord(dname, values=None, length=shared_object.shape[d])

model.add_random_variable(shared_object, dims=dims)

return shared_object
# Add data container to the named variables of the model.
try:
model = pm.Model.get_context()
except TypeError:
raise TypeError(
"No model on context stack, which is needed to instantiate a data container. "
"Add variable inside a 'with model:' block."
)
name = model.name_for(name)

# `pandas_to_array` takes care of parameter `value` and
# transforms it to something digestible for Aesara.
arr = pandas_to_array(value)
if mutable:
x = aesara.shared(arr, name, **kwargs)
else:
x = at.as_tensor_variable(arr, name, **kwargs)

if isinstance(dims, str):
dims = (dims,)
if not (dims is None or len(dims) == x.ndim):
raise pm.exceptions.ShapeError(
"Length of `dims` must match the dimensions of the dataset.",
actual=len(dims),
expected=x.ndim,
)

coords = determine_coords(model, value, dims)

if export_index_as_coords:
model.add_coords(coords)
elif dims:
# Register new dimension lengths
for d, dname in enumerate(dims):
if not dname in model.dim_lengths:
model.add_coord(dname, values=None, length=x.shape[d])

model.add_random_variable(x, dims=dims)

return x
4 changes: 2 additions & 2 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import walk
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable
from aesara.tensor.var import TensorConstant, TensorVariable

import pymc as pm

Expand Down Expand Up @@ -133,7 +133,7 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
shape = "octagon"
style = "filled"
label = f"{var_name}\n~\nPotential"
elif isinstance(v, SharedVariable):
elif isinstance(v, (SharedVariable, TensorConstant)):
shape = "box"
style = "rounded, filled"
label = f"{var_name}\n~\nData"
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def test_sample(self):

def test_sample_posterior_predictive_after_set_data(self):
with pm.Model() as model:
x = pm.Data("x", [1.0, 2.0, 3.0])
y = pm.Data("y", [1.0, 2.0, 3.0])
x = pm.MutableData("x", [1.0, 2.0, 3.0])
y = pm.ConstantData("y", [1.0, 2.0, 3.0])
beta = pm.Normal("beta", 0, 10.0)
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
trace = pm.sample(
Expand Down

0 comments on commit 1428fc3

Please sign in to comment.