Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BART: add partial dependence plots and individual conditional expectation plots #5091

Merged
merged 7 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
- `pm.DensityDist` no longer accepts the `logp` as its first position argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- `pm.DensityDist` now accepts distribution parameters as positional arguments. Passing them as a dictionary in the `observed` keyword argument is no longer supported and will raise an error (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- The signature of the `logp` and `random` functions that can be passed into a `pm.DensityDist` has been changed (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- Generalize BART. A BART variable can be combined with other random variables. The `inv_link` argument has been removed (see [4914](https://github.com/pymc-devs/pymc3/pull/4914)).
- Move BART to its own module (see [5058](https://github.com/pymc-devs/pymc3/pull/5058)).
- ...

### New Features
Expand All @@ -32,6 +34,8 @@
- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
- `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cummulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- BART: add linear response, increase number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
- BART: add partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
- ...

### Maintenance
Expand Down
1 change: 1 addition & 0 deletions pymc/bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@

from pymc.bart.bart import BART
from pymc.bart.pgbart import PGBART
from pymc.bart.utils import plot_dependence, predict

__all__ = ["BART", "PGBART"]
32 changes: 1 addition & 31 deletions pymc/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,35 +39,7 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):

@classmethod
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
size = kwargs.pop("size", None)
X_new = kwargs.pop("X_new", None)
all_trees = cls.all_trees
if all_trees:

if size is None:
size = ()
elif isinstance(size, int):
size = [size]

flatten_size = 1
for s in size:
flatten_size *= s

idx = rng.randint(len(all_trees), size=flatten_size)

if X_new is None:
pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += tree.predict_output()
else:
pred = np.zeros((flatten_size, X_new.shape[0]))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += np.array([tree.predict_out_of_sample(x, cls.m) for x in X_new])
return pred.reshape((*size, -1))
else:
return np.full_like(cls.Y, cls.Y.mean())
return np.full_like(cls.Y, cls.Y.mean())


bart = BARTRV()
Expand Down Expand Up @@ -115,15 +87,13 @@ def __new__(
**kwargs,
):

cls.all_trees = []
X, Y = preprocess_XY(X, Y)

bart_op = type(
f"BART_{name}",
(BARTRV,),
dict(
name="BART",
all_trees=cls.all_trees,
inplace=False,
initval=Y.mean(),
X=X,
Expand Down
28 changes: 11 additions & 17 deletions pymc/bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging

from copy import copy
from typing import Any, Dict, List, Tuple

import aesara
Expand Down Expand Up @@ -121,7 +122,7 @@ class PGBART(ArrayStepShared):
name = "bartsampler"
default_blocked = False
generates_stats = True
stats_dtypes = [{"variable_inclusion": np.ndarray}]
stats_dtypes = [{"variable_inclusion": np.ndarray, "bart_trees": np.ndarray}]

def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
_log.warning("BART is experimental. Use with caution.")
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
tree_id=0,
leaf_node_value=self.init_mean / self.m,
idx_data_points=np.arange(self.num_observations, dtype="int32"),
m=self.m,
)
self.mean = fast_mean()
self.linear_fit = fast_linear_fit()
Expand All @@ -169,8 +171,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo

self.tune = True
self.idx = 0
self.iter = 0
self.sum_trees = []
self.batch = batch

if self.batch == "auto":
Expand All @@ -193,12 +193,12 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
self.init_likelihood,
)
self.all_particles.append(p)
self.all_trees = np.array([p.tree for p in self.all_particles])
super().__init__(vars, shared)

def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
point_map_info = q.point_map_info
sum_trees_output = q.data

variable_inclusion = np.zeros(self.num_variates, dtype="int")

if self.idx == self.m:
Expand All @@ -212,7 +212,6 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
particles = self.init_particles(tree_id)
# Compute the sum of trees without the tree we are attempting to replace
self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
self.idx += 1

# The old tree is not growing so we update the weights only once.
self.update_weight(particles[0])
Expand Down Expand Up @@ -258,6 +257,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
# Get the new tree and update
new_particle = np.random.choice(particles, p=normalized_weights)
new_tree = new_particle.tree
self.all_trees[self.idx] = new_tree
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
self.all_particles[tree_id] = new_particle
sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
Expand All @@ -268,17 +268,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.ssv = SampleSplittingVariable(self.split_prior)
else:
self.batch = max(1, int(self.m * 0.2))
self.iter += 1
self.sum_trees.append(new_tree)
if not self.iter % self.m:
# XXX update the all_trees variable in BARTRV to be used in the rng_fn method
# this fails for chains > 1 as the variable is not shared between proccesses
self.bart.all_trees.append(self.sum_trees)
self.sum_trees = []
for index in new_particle.used_variates:
variable_inclusion[index] += 1
self.idx += 1

stats = {"variable_inclusion": variable_inclusion}
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
return sum_trees_output, [stats]

Expand Down Expand Up @@ -526,11 +520,11 @@ def linear_fit(X, Y):
xbar = np.sum(X) / n
ybar = np.sum(Y) / n

if np.all(X == xbar):
b = 0
den = X @ X - n * xbar ** 2
if den > 1e-10:
b = (X @ Y - n * xbar * ybar) / den
else:
b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)

b = 0
a = ybar - b * xbar
Y_fit = a + b * X
return Y_fit, [a, b, 0]
Expand Down
18 changes: 10 additions & 8 deletions pymc/bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,23 @@ class Tree:
Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
num_observations : int
Number of observations used to fit BART.

m : int
Number of trees

Parameters
----------
tree_id : int, optional
num_observations : int, optional
"""

def __init__(self, tree_id=0, num_observations=0):
def __init__(self, tree_id=0, num_observations=0, m=0):
self.tree_structure = {}
self.num_nodes = 0
self.idx_leaf_nodes = []
self.idx_prunable_split_nodes = []
self.tree_id = tree_id
self.num_observations = num_observations
self.m = m

def __getitem__(self, index):
return self.get_node(index)
Expand Down Expand Up @@ -94,16 +96,14 @@ def predict_output(self):

return output.astype(aesara.config.floatX)

def predict_out_of_sample(self, X, m):
def predict_out_of_sample(self, X):
"""
Predict output of tree for an unobserved point x.

Parameters
----------
X : numpy array
Unobserved point
m : int
Number of trees

Returns
-------
Expand All @@ -116,7 +116,7 @@ def predict_out_of_sample(self, X, m):
return leaf_node.value
else:
x = X[split_variable].item()
y_x = (linear_params[0] + linear_params[1] * x) / m
y_x = (linear_params[0] + linear_params[1] * x) / self.m
return y_x + linear_params[2]

def _traverse_tree(self, x, node_index=0, split_variable=None):
Expand Down Expand Up @@ -170,20 +170,22 @@ def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_no
self.idx_prunable_split_nodes.remove(parent_index)

@staticmethod
def init_tree(tree_id, leaf_node_value, idx_data_points):
def init_tree(tree_id, leaf_node_value, idx_data_points, m):
"""

Parameters
----------
tree_id
leaf_node_value
idx_data_points
m : int
number of trees in BART

Returns
-------

"""
new_tree = Tree(tree_id, len(idx_data_points))
new_tree = Tree(tree_id, len(idx_data_points), m)
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
return new_tree

Expand Down
Loading