From e51b9d3d7dbfef713ec39414c6ae7671ba4aa817 Mon Sep 17 00:00:00 2001
From: Osvaldo Martin
Date: Sat, 14 Nov 2020 14:46:27 -0300
Subject: [PATCH] Add Bayesian Additive Regression Trees (BARTs) (#4183)

* update from master

* black

* minor fix

* clean code

* blackify

* fix error residuals

* use a low number of max_stages for the first iteration, remove unnecessary errors

* use Rockova prior, refactor prior leaf prob computation

* clean code add docstring

* reduce code

* speed-up by fitting a subset of trees per step

* choose max

* improve docstrings

* refactor and clean code

* clean docstrings

* add tests and minor fixes.

Co-authored-by: aloctavodia
Co-authored-by: jmloyola

* remove space.

Co-authored-by: aloctavodia
Co-authored-by: jmloyola

* add variable importance report

* use ValueError

* wip return mean and std variable importance

* update variable importance report

* update release notes, remove vi hdi report

* test variable importance

* fix test

Co-authored-by: jmloyola
---
 RELEASE-NOTES.md                |   4 +-
 pymc3/distributions/__init__.py |   4 +
 pymc3/distributions/bart.py     | 252 ++++++++++++++++++++++++++++
 pymc3/distributions/tree.py     | 182 +++++++++++++++++++++
 pymc3/model.py                  |   8 +-
 pymc3/sampling.py               |   6 +
 pymc3/step_methods/__init__.py  |   2 +
 pymc3/step_methods/hmc/nuts.py  |   4 +-
 pymc3/step_methods/pgbart.py    | 280 ++++++++++++++++++++++++++++++++
 pymc3/tests/test_sampling.py    |  14 ++
 10 files changed, 750 insertions(+), 6 deletions(-)
 create mode 100644 pymc3/distributions/bart.py
 create mode 100644 pymc3/distributions/tree.py
 create mode 100644 pymc3/step_methods/pgbart.py

diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 9199eaa9313..f238735f754 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -16,11 +16,13 @@
 - `sample_posterior_predictive_w` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#4042](https://github.com/pymc-devs/pymc3/pull/4042))
 - Added `pymc3.gp.cov.Circular` kernel for Gaussian Processes on circular domains, e.g. the unit circle (see [#4082](https://github.com/pymc-devs/pymc3/pull/4082)).
 - Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926))
-- Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/3926))
+- Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/4115))
 - Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
+- Add Bayesian Additive Regression Trees (BARTs) [#4183](https://github.com/pymc-devs/pymc3/pull/4183))
 - Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).
+
 ## PyMC3 3.9.3 (11 August 2020)
 ### Maintenance
diff --git a/pymc3/distributions/__init__.py b/pymc3/distributions/__init__.py
index fce98766f05..2a671d3293b 100644
--- a/pymc3/distributions/__init__.py
+++ b/pymc3/distributions/__init__.py
@@ -99,8 +99,11 @@
 from .timeseries import MvGaussianRandomWalk
 from .timeseries import MvStudentTRandomWalk
 
+from .bart import BART
+
 from .bound import Bound
 
+
 __all__ = [
     "Uniform",
     "Flat",
@@ -177,4 +180,5 @@
     "Moyal",
     "Simulator",
     "fast_sample_posterior_predictive",
+    "BART",
 ]
diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py
new file mode 100644
index 00000000000..bba94f8d22b
--- /dev/null
+++ b/pymc3/distributions/bart.py
@@ -0,0 +1,252 @@
+# Copyright 2020 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from .distribution import NoDistribution
+from .tree import Tree, SplitNode, LeafNode
+
+__all__ = ["BART"]
+
+
+class BaseBART(NoDistribution):
+    def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
+        self.X = X
+        self.Y = Y
+        super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)
+
+        if self.X.ndim != 2:
+            raise ValueError("The design matrix X must have two dimensions")
+
+        if self.Y.ndim != 1:
+            raise ValueError("The response vector Y must have one dimension")
+        if self.X.shape[0] != self.Y.shape[0]:
+            raise ValueError(
+                "The design matrix X and the response vector Y must have the same number of rows"
+            )
+        if not isinstance(m, int):
+            raise ValueError("The number of trees m must be of type int")
+        if m < 1:
+            raise ValueError("The number of trees m must be greater than zero")
+
+        if alpha <= 0 or 1 <= alpha:
+            raise ValueError(
+                "The value for the alpha parameter for the tree structure "
+                "must be in the interval (0, 1)"
+            )
+
+        self.num_observations = X.shape[0]
+        self.num_variates = X.shape[1]
+        self.m = m
+        self.alpha = alpha
+        self.trees = self.init_list_of_trees()
+        self.mean = fast_mean()
+        self.prior_prob_leaf_node = compute_prior_probability(alpha)
+
+    def init_list_of_trees(self):
+        initial_value_leaf_nodes = self.Y.mean() / self.m
+        initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
+        list_of_trees = []
+        for i in range(self.m):
+            new_tree = Tree.init_tree(
+                tree_id=i,
+                leaf_node_value=initial_value_leaf_nodes,
+                idx_data_points=initial_idx_data_points_leaf_nodes,
+            )
+            list_of_trees.append(new_tree)
+        # Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A. and Bleich, J.
+        # bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013.
+        # The sum_trees_output will contain the sum of the predicted output for all trees.
+        # When R_j is needed we subtract the current predicted output for tree T_j.
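+        # For example, with m = 3 trees that each predict Y.mean() / 3, sum_trees_output
+        # starts at Y.mean() everywhere, and the residuals for tree T_j are then obtained
+        # as R_j = Y - (sum_trees_output - T_j(X)), without looping over the other m - 1 trees.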
+        self.sum_trees_output = np.full_like(self.Y, self.Y.mean())
+
+        return list_of_trees
+
+    def __iter__(self):
+        return iter(self.trees)
+
+    def _repr_latex_(self):
+        raise NotImplementedError
+
+    def get_available_predictors(self, idx_data_points_split_node):
+        possible_splitting_variables = []
+        for j in range(self.num_variates):
+            x_j = self.X[idx_data_points_split_node, j]
+            x_j = x_j[~np.isnan(x_j)]
+            for i in range(1, len(x_j)):
+                if x_j[i - 1] != x_j[i]:
+                    possible_splitting_variables.append(j)
+                    break
+        return possible_splitting_variables
+
+    def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
+        x_j = self.X[idx_data_points_split_node, idx_split_variable]
+        x_j = x_j[~np.isnan(x_j)]
+        values, indices = np.unique(x_j, return_index=True)
+        # The last value is not considered since, if we chose it as the value of
+        # the splitting rule assignment, it would leave the right subtree empty.
+        return values[:-1], indices[:-1]
+
+    def grow_tree(self, tree, index_leaf_node):
+        # This can be unsuccessful when there are no available predictors
+        current_node = tree.get_node(index_leaf_node)
+
+        available_predictors = self.get_available_predictors(current_node.idx_data_points)
+
+        if not available_predictors:
+            return False, None
+
+        index_selected_predictor = discrete_uniform_sampler(len(available_predictors))
+        selected_predictor = available_predictors[index_selected_predictor]
+        available_splitting_rules, _ = self.get_available_splitting_rules(
+            current_node.idx_data_points, selected_predictor
+        )
+        index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
+        selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
+        new_split_node = SplitNode(
+            index=index_leaf_node,
+            idx_split_variable=selected_predictor,
+            split_value=selected_splitting_rule,
+        )
+
+        left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
+            new_split_node, current_node.idx_data_points
+        )
+
+        left_node_value = self.draw_leaf_value(left_node_idx_data_points)
+        right_node_value = self.draw_leaf_value(right_node_idx_data_points)
+
+        new_left_node = LeafNode(
+            index=current_node.get_idx_left_child(),
+            value=left_node_value,
+            idx_data_points=left_node_idx_data_points,
+        )
+        new_right_node = LeafNode(
+            index=current_node.get_idx_right_child(),
+            value=right_node_value,
+            idx_data_points=right_node_idx_data_points,
+        )
+        tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
+
+        return True, index_selected_predictor
+
+    def get_new_idx_data_points(self, current_split_node, idx_data_points):
+        idx_split_variable = current_split_node.idx_split_variable
+        split_value = current_split_node.split_value
+
+        left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
+        left_node_idx_data_points = idx_data_points[left_idx]
+        right_node_idx_data_points = idx_data_points[~left_idx]
+
+        return left_node_idx_data_points, right_node_idx_data_points
+
+    def get_residuals(self):
+        """Compute the residuals."""
+        R_j = self.Y - self.sum_trees_output
+        return R_j
+
+    def get_residuals_loo(self, tree):
+        """Compute the residuals, leaving the passed tree out of the summed prediction."""
+        R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
+        return R_j
+
+    def draw_leaf_value(self, idx_data_points):
+        """Draw the residual mean."""
+        R_j = self.get_residuals()[idx_data_points]
+        draw = self.mean(R_j)
+        return draw
+
+
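+# For example, with alpha = 0.25 the prior probability that a node at depth d
+# is a leaf is 1 - alpha ** d: 0.75 at depth 1, 0.9375 at depth 2, 0.984375 at
+# depth 3, so the prior strongly favors shallow trees.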
+def compute_prior_probability(alpha):
+    """
+    Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
+    Taken from equation 19 in [Rockova2018].
+
+    Parameters
+    ----------
+    alpha : float
+
+    Returns
+    -------
+    list with probabilities for leaf nodes
+
+    References
+    ----------
+    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
+        arXiv e-prints.
+    """
+    prior_leaf_prob = [0]
+    depth = 1
+    while prior_leaf_prob[-1] < 1:
+        prior_leaf_prob.append(1 - alpha ** depth)
+        depth += 1
+    return prior_leaf_prob
+
+
+def fast_mean():
+    """If available, use Numba to speed up the computation of the mean."""
+    try:
+        from numba import jit
+    except ImportError:
+        return np.mean
+
+    @jit
+    def mean(a):
+        count = a.shape[0]
+        suma = 0
+        for i in range(count):
+            suma += a[i]
+        return suma / count
+
+    return mean
+
+
+def discrete_uniform_sampler(upper_value):
+    """Draw from the uniform distribution with bounds [0, upper_value)."""
+    return int(np.random.random() * upper_value)
+
+
+class BART(BaseBART):
+    """
+    BART distribution.
+
+    Distribution representing a sum over trees.
+
+    Parameters
+    ----------
+    X : array-like
+        The design matrix.
+    Y : array-like
+        The response vector.
+    m : int
+        Number of trees.
+    alpha : float
+        Controls the prior probability over the depth of the trees. Must be in the interval
+        (0, 1), although it is recommended to be in the interval (0, 0.5].
+    """
+
+    def __init__(self, X, Y, m=200, alpha=0.25):
+        super().__init__(X, Y, m, alpha)
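+
+    # A minimal usage sketch (hypothetical data X, Y; it mirrors the test added in
+    # pymc3/tests/test_sampling.py at the end of this patch):
+    #
+    #     with pm.Model():
+    #         mu = pm.BART("mu", X, Y, m=20)
+    #         sigma = pm.HalfNormal("sigma", 1)
+    #         y = pm.Normal("y", mu, sigma, observed=Y)
+    #         trace = pm.sample()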
+
+    def _str_repr(self, name=None, dist=None, formatting="plain"):
+        if dist is None:
+            dist = self
+        X = (type(self.X),)
+        Y = (type(self.Y),)
+        alpha = self.alpha
+        m = self.m
+
+        if formatting == "latex":
+            return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
+        else:
+            return f"{name} ~ BART(alpha = {alpha}, m = {m})"
diff --git a/pymc3/distributions/tree.py b/pymc3/distributions/tree.py
new file mode 100644
index 00000000000..3b2c098e896
--- /dev/null
+++ b/pymc3/distributions/tree.py
@@ -0,0 +1,182 @@
+# Copyright 2020 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import math
+from copy import deepcopy
+
+
+class Tree:
+    """Full binary tree.
+
+    A full binary tree is a tree where each node has exactly zero or two children.
+    This structure is used as the basic component of the Bayesian Additive Regression Tree (BART).
+
+    Attributes
+    ----------
+    tree_structure : dict
+        A dictionary that represents the nodes stored in breadth-first order, based on the
+        array method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
+        The dictionary's keys are integers that represent the nodes' positions.
+        The dictionary's values are objects of type SplitNode or LeafNode that represent the
+        nodes of the tree itself.
+    num_nodes : int
+        Total number of nodes.
+    idx_leaf_nodes : list
+        List with the index of the leaf nodes of the tree.
+    idx_prunable_split_nodes : list
+        List with the index of the prunable splitting nodes of the tree. A splitting node is
+        prunable if both its children are leaf nodes.
+    tree_id : int
+        Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
+
+    Parameters
+    ----------
+    tree_id : int, optional
+    """
+
+    def __init__(self, tree_id=0):
+        self.tree_structure = {}
+        self.num_nodes = 0
+        self.idx_leaf_nodes = []
+        self.idx_prunable_split_nodes = []
+        self.tree_id = tree_id
+
+    def __getitem__(self, index):
+        return self.get_node(index)
+
+    def __setitem__(self, index, node):
+        self.set_node(index, node)
+
+    def copy(self):
+        return deepcopy(self)
+
+    def get_node(self, index):
+        return self.tree_structure[index]
+
+    def set_node(self, index, node):
+        self.tree_structure[index] = node
+        self.num_nodes += 1
+        if isinstance(node, LeafNode):
+            self.idx_leaf_nodes.append(index)
+
+    def delete_node(self, index):
+        current_node = self.get_node(index)
+        if isinstance(current_node, LeafNode):
+            self.idx_leaf_nodes.remove(index)
+        del self.tree_structure[index]
+        self.num_nodes -= 1
+
+    def predict_output(self, num_observations):
+        output = np.zeros(num_observations)
+        for node_index in self.idx_leaf_nodes:
+            current_node = self.get_node(node_index)
+            output[current_node.idx_data_points] = current_node.value
+        return output
+
+    def _traverse_tree(self, x, node_index=0):
+        """
+        Traverse the tree starting from a particular node given an unobserved point.
+
+        Parameters
+        ----------
+        x : np.ndarray
+        node_index : int
+
+        Returns
+        -------
+        LeafNode
+        """
+        current_node = self.get_node(node_index)
+        if isinstance(current_node, SplitNode):
+            x_i = x[current_node.idx_split_variable]
+            # Missing values and values above the split value go to the right subtree,
+            # everything else goes to the left subtree.
+            if not np.isnan(x_i) and x_i <= current_node.split_value:
+                final_node = self._traverse_tree(x, current_node.get_idx_left_child())
+            else:
+                final_node = self._traverse_tree(x, current_node.get_idx_right_child())
+        else:
+            final_node = current_node
+        return final_node
+
+    def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
+        """
+        Grow the tree from a particular node.
+
+        Parameters
+        ----------
+        index_leaf_node : int
+        new_split_node : SplitNode
+        new_left_node : LeafNode
+        new_right_node : LeafNode
+        """
+        current_node = self.get_node(index_leaf_node)
+
+        self.delete_node(index_leaf_node)
+        self.set_node(index_leaf_node, new_split_node)
+        self.set_node(new_left_node.index, new_left_node)
+        self.set_node(new_right_node.index, new_right_node)
+
+        # The new SplitNode is a prunable node since both its children are leaf nodes.
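+        # For example, growing the leaf at index 3 creates children at indexes 7 and 8
+        # (in the array representation, the children of node i are 2 * i + 1 and 2 * i + 2).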
+        self.idx_prunable_split_nodes.append(index_leaf_node)
+        # If the parent of the node from which the tree is growing was a prunable node,
+        # remove it from the list, since one of its children is a SplitNode now.
+        parent_index = current_node.get_idx_parent_node()
+        if parent_index in self.idx_prunable_split_nodes:
+            self.idx_prunable_split_nodes.remove(parent_index)
+
+    @staticmethod
+    def init_tree(tree_id, leaf_node_value, idx_data_points):
+        """
+        Create a new tree with a single leaf node at the root.
+
+        Parameters
+        ----------
+        tree_id : int
+        leaf_node_value : float
+        idx_data_points : array-like
+
+        Returns
+        -------
+        Tree
+        """
+        new_tree = Tree(tree_id)
+        new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
+        return new_tree
+
+
+class BaseNode:
+    def __init__(self, index):
+        self.index = index
+        self.depth = int(math.floor(math.log(index + 1, 2)))
+
+    def get_idx_parent_node(self):
+        return (self.index - 1) // 2
+
+    def get_idx_left_child(self):
+        return self.index * 2 + 1
+
+    def get_idx_right_child(self):
+        return self.get_idx_left_child() + 1
+
+
+class SplitNode(BaseNode):
+    def __init__(self, index, idx_split_variable, split_value):
+        super().__init__(index)
+
+        self.idx_split_variable = idx_split_variable
+        self.split_value = split_value
+
+
+class LeafNode(BaseNode):
+    def __init__(self, index, value, idx_data_points):
+        super().__init__(index)
+        self.value = value
+        self.idx_data_points = idx_data_points
diff --git a/pymc3/model.py b/pymc3/model.py
index 126c7537427..a4fbc462466 100644
--- a/pymc3/model.py
+++ b/pymc3/model.py
@@ -1997,10 +1997,12 @@ def as_iterargs(data):
 
 
 def all_continuous(vars):
-    """Check that vars not include discrete variables, excepting
-    ObservedRVs."""
+    """Check that vars does not include discrete or BART variables, excepting ObservedRVs."""
+
     vars_ = [var for var in vars if not isinstance(var, pm.model.ObservedRV)]
-    if any([var.dtype in pm.discrete_types for var in vars_]):
+    if any(
+        [(var.dtype in pm.discrete_types or isinstance(var.distribution, pm.BART)) for var in vars_]
+    ):
         return False
     else:
         return True
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index d6d40992931..128270e7f37 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -51,6 +51,7 @@
     Slice,
     CompoundStep,
     arraystep,
+    PGBART,
 )
 from .util import (
     update_start_vals,
@@ -90,6 +91,7 @@
     BinaryGibbsMetropolis,
     Slice,
     CategoricalGibbsMetropolis,
+    PGBART,
 )
 
 ArrayLike = Union[np.ndarray, List[float]]
@@ -604,6 +606,10 @@ def sample(
     trace.report._n_draws = n_draws
     trace.report._t_sampling = t_sampling
 
+    if "variable_inclusion" in trace.stat_names:
+        variable_inclusion = np.stack(trace.get_sampler_stats("variable_inclusion")).mean(0)
+        trace.report.variable_importance = variable_inclusion / variable_inclusion.sum()
+
     n_chains = len(trace.chains)
     _log.info(
         f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
diff --git a/pymc3/step_methods/__init__.py b/pymc3/step_methods/__init__.py
index d41a8d5af3d..a4afd48a3d1 100644
--- a/pymc3/step_methods/__init__.py
+++ b/pymc3/step_methods/__init__.py
@@ -34,3 +34,5 @@
 from .slicer import Slice
 
 from .elliptical_slice import EllipticalSlice
+
+from .pgbart import PGBART
diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py
index 782a350654c..7ffbf106d41 100644
--- a/pymc3/step_methods/hmc/nuts.py
+++ b/pymc3/step_methods/hmc/nuts.py
@@ -23,7 +23,7 @@
 from pymc3.math import logbern, logdiffexp_numpy
 from pymc3.theanof import floatX
 from pymc3.vartypes import continuous_types
-
+from ...distributions import BART
 
 __all__ = ["NUTS"]
 
@@ -196,7 +196,7 @@ def _hamiltonian_step(self, start, p0, step_size):
     @staticmethod
     def competence(var, has_grad):
         """Check how appropriate this class is for sampling a random variable."""
-        if var.dtype in continuous_types and has_grad:
+        if var.dtype in continuous_types and has_grad and not isinstance(var.distribution, BART):
             return Competence.IDEAL
         return Competence.INCOMPATIBLE
diff --git a/pymc3/step_methods/pgbart.py b/pymc3/step_methods/pgbart.py
new file mode 100644
index 00000000000..1d6a503c4c7
--- /dev/null
+++ b/pymc3/step_methods/pgbart.py
@@ -0,0 +1,280 @@
+# Copyright 2020 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import numpy as np
+from theano import function as theano_function
+
+from .arraystep import ArrayStepShared, Competence
+from ..distributions import BART
+from ..distributions.tree import Tree
+from ..model import modelcontext
+from ..theanof import inputvars, make_shared_replacements, join_nonshared_inputs
+
+_log = logging.getLogger("pymc3")
+
+
+class PGBART(ArrayStepShared):
+    """
+    Particle Gibbs BART sampling step.
+
+    Parameters
+    ----------
+    vars : list
+        List of variables for the sampler.
+    num_particles : int
+        Number of particles for the conditional SMC sampler. Defaults to 10.
+    max_stages : int
+        Maximum number of iterations of the conditional SMC sampler. Defaults to 5000.
+    chunk : int or "auto"
+        Number of trees fitted per step. Defaults to "auto", which fits 10% of the `m` trees.
+    model : PyMC Model
+        Optional model for the sampling step. Defaults to None (taken from context).
+
+    References
+    ----------
+    .. [Lakshminarayanan2015] Lakshminarayanan, B. and Roy, D.M. and Teh, Y. W., (2015),
+        Particle Gibbs for Bayesian Additive Regression Trees.
+        arXiv e-prints.
+    """
+
+    name = "bartsampler"
+    default_blocked = False
+    generates_stats = True
+    stats_dtypes = [{"variable_inclusion": np.ndarray}]
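+    # Each astep() call reports how many times each predictor was used in a new
+    # splitting rule through the "variable_inclusion" statistic; pm.sample()
+    # aggregates these counts into trace.report.variable_importance (see the
+    # pymc3/sampling.py hunk above).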
+
+    def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):
+        _log.warning("The BART model is experimental. Use with caution.")
+        model = modelcontext(model)
+        vars = inputvars(vars)
+        self.bart = vars[0].distribution
+
+        self.tune = True
+        self.idx = 0
+        if chunk == "auto":
+            self.chunk = max(1, int(self.bart.m * 0.1))
+        else:
+            self.chunk = chunk
+        self.num_particles = num_particles
+        self.log_num_particles = np.log(num_particles)
+        self.indices = list(range(1, num_particles))
+        self.max_stages = max_stages
+        self.old_trees_particles_list = []
+        for i in range(self.bart.m):
+            p = ParticleTree(self.bart.trees[i], self.bart.prior_prob_leaf_node)
+            self.old_trees_particles_list.append(p)
+
+        shared = make_shared_replacements(vars, model)
+        self.likelihood_logp = logp([model.datalogpt], vars, shared)
+        super().__init__(vars, shared)
+
+    def astep(self, _):
+        bart = self.bart
+        num_observations = bart.num_observations
+        variable_inclusion = np.zeros(bart.num_variates, dtype="int")
+
+        # For the tuning phase we restrict max_stages to a low number, otherwise it is
+        # almost certain we will reach max_stages, given that our first set of m trees
+        # is not good at all.
+        # Can we set max_stages as a function of the number of variables/dimensions?
+        if self.tune:
+            max_stages = 5
+        else:
+            max_stages = self.max_stages
+
+        if self.idx == bart.m:
+            self.idx = 0
+
+        for idx in range(self.idx, self.idx + self.chunk):
+            if idx >= bart.m:
+                break
+            self.idx += 1
+            tree = bart.trees[idx]
+            R_j = bart.get_residuals_loo(tree)
+            # Generate an initial set of SMC particles;
+            # at the end of the algorithm we return one of these particles as the new tree.
+            particles = self.init_particles(tree.tree_id, R_j, bart.num_observations)
+
+            for t in range(1, max_stages):
+                # Get old particle at stage t
+                particles[0] = self.get_old_tree_particle(tree.tree_id, t)
+                # Sample each particle (try to grow each tree)
+                for c in range(1, self.num_particles):
+                    particles[c].sample_tree_sequential(bart)
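+                # A particle whose tree just grew a split that fits the residuals better
+                # will get new_likelihood > old_likelihood_logp below, increasing its
+                # weight and therefore its chance of surviving the resampling step.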
+                # Update weights. Since the prior is used as the proposal, the weights
+                # are updated additively, by the difference between the new and old
+                # log-likelihoods (i.e., the log of the likelihood ratio).
+                for p_idx, p in enumerate(particles):
+                    new_likelihood = self.likelihood_logp(p.tree.predict_output(num_observations))
+                    p.log_weight += new_likelihood - p.old_likelihood_logp
+                    p.old_likelihood_logp = new_likelihood
+
+                # Normalize weights
+                W, normalized_weights = self.normalize(particles)
+
+                # Resample all but the first particle
+                re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
+                new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
+                particles[1:] = particles[new_indices]
+
+                # Set the new weights
+                w_t = W - self.log_num_particles
+                for p in particles:
+                    p.log_weight = w_t
+
+                # Check if particles can keep growing, otherwise stop iterating
+                non_available_nodes_for_expansion = np.ones(self.num_particles - 1)
+                for c in range(1, self.num_particles):
+                    if len(particles[c].expansion_nodes) != 0:
+                        non_available_nodes_for_expansion[c - 1] = 0
+                if np.all(non_available_nodes_for_expansion):
+                    break
+
+            # Get the new tree and update
+            new_tree = np.random.choice(particles, p=normalized_weights)
+            self.old_trees_particles_list[tree.tree_id] = new_tree
+            bart.trees[idx] = new_tree.tree
+            new_prediction = new_tree.tree.predict_output(num_observations)
+            bart.sum_trees_output = bart.Y - R_j + new_prediction
+
+            if not self.tune:
+                for index in new_tree.used_variates:
+                    variable_inclusion[index] += 1
+
+        stats = {"variable_inclusion": variable_inclusion}
+
+        return bart.sum_trees_output, [stats]
+
+    @staticmethod
+    def competence(var, has_grad):
+        """PGBART is only suitable for BART distributions."""
+        if isinstance(var.distribution, BART):
+            return Competence.IDEAL
+        return Competence.INCOMPATIBLE
+
+    def normalize(self, particles):
+        """Use the logsumexp trick to get W and softmax to get normalized_weights."""
+        log_w = np.array([p.log_weight for p in particles])
+        log_w_max = log_w.max()
+        log_w_ = log_w - log_w_max
+        w_ = np.exp(log_w_)
+        w_sum = w_.sum()
+        W = log_w_max + np.log(w_sum)
+        normalized_weights = w_ / w_sum
+        # Stabilize weights to avoid assigning exactly zero probability to a particle
+        normalized_weights += 1e-12
+
+        return W, normalized_weights
+
+    def get_old_tree_particle(self, tree_id, t):
+        old_tree_particle = self.old_trees_particles_list[tree_id]
+        old_tree_particle.set_particle_to_step(t)
+        return old_tree_particle
+
+    def init_particles(self, tree_id, R_j, num_observations):
+        """Initialize particles."""
+        # The first particle is from the tree we are trying to replace
+        prev_tree = self.get_old_tree_particle(tree_id, 0)
+        likelihood = self.likelihood_logp(prev_tree.tree.predict_output(num_observations))
+        prev_tree.old_likelihood_logp = likelihood
+        prev_tree.log_weight = likelihood - self.log_num_particles
+        particles = [prev_tree]
+
+        # The rest of the particles are identically initialized
+        initial_value_leaf_nodes = R_j.mean()
+        initial_idx_data_points_leaf_nodes = np.arange(num_observations, dtype="int32")
+        new_tree = Tree.init_tree(
+            tree_id=tree_id,
+            leaf_node_value=initial_value_leaf_nodes,
+            idx_data_points=initial_idx_data_points_leaf_nodes,
+        )
+        likelihood_logp = self.likelihood_logp(new_tree.predict_output(num_observations))
+        log_weight = likelihood_logp - self.log_num_particles
+        for i in range(1, self.num_particles):
+            particles.append(
+                ParticleTree(new_tree, self.bart.prior_prob_leaf_node, log_weight, likelihood_logp)
+            )
+
+        return np.array(particles)
+
+    def resample(self, particles, weights):
+        """Resample a set of particles given their weights."""
+        particles = np.random.choice(particles, size=len(particles), p=weights)
+        return particles
+
+
+class ParticleTree:
+    """Particle tree."""
+
+    def __init__(self, tree, prior_prob_leaf_node, log_weight=0, likelihood=0):
+        self.tree = tree.copy()  # keeps the tree we are currently sampling
+        self.expansion_nodes = tree.idx_leaf_nodes.copy()  # This should be the array [0]
+        self.tree_history = [self.tree]
+        self.expansion_nodes_history = [self.expansion_nodes]
+        self.log_weight = log_weight
+        self.prior_prob_leaf_node = prior_prob_leaf_node
+        self.old_likelihood_logp = likelihood
+        self.used_variates = []
+
+    def sample_tree_sequential(self, bart):
+        if self.expansion_nodes:
+            index_leaf_node = self.expansion_nodes.pop(0)
+            # Probability that this node will remain a leaf node
+            prob_leaf = self.prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+
+            if prob_leaf < np.random.random():
+                grow_successful, index_selected_predictor = bart.grow_tree(
+                    self.tree, index_leaf_node
+                )
+                if grow_successful:
+                    # Add the indexes of the new leaf nodes
+                    new_indexes = self.tree.idx_leaf_nodes[-2:]
+                    self.expansion_nodes.extend(new_indexes)
+                    self.used_variates.append(index_selected_predictor)
+
+            self.tree_history.append(self.tree)
+            self.expansion_nodes_history.append(self.expansion_nodes)
+
+    def set_particle_to_step(self, t):
+        if len(self.tree_history) <= t:
+            self.tree = self.tree_history[-1]
+            self.expansion_nodes = self.expansion_nodes_history[-1]
+        else:
+            self.tree = self.tree_history[t]
+            self.expansion_nodes = self.expansion_nodes_history[t]
+
+
+def logp(out_vars, vars, shared):
+    """Compile a Theano function of the model given the input and output variables.
+
+    Parameters
+    ----------
+    out_vars : list
+        containing :class:`pymc3.Distribution` for the output variables
+    vars : list
+        containing :class:`pymc3.Distribution` for the input variables
+    shared : list
+        containing :class:`theano.tensor.Tensor` for dependent shared data
+    """
+    out_list, inarray0 = join_nonshared_inputs(out_vars, vars, shared)
+    f = theano_function([inarray0], out_list[0])
+    f.trust_input = True
+    return f
diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py
index 389b3eb4520..a415f648ea4 100644
--- a/pymc3/tests/test_sampling.py
+++ b/pymc3/tests/test_sampling.py
@@ -166,6 +166,20 @@ def test_trace_report(self, step_cls, discard):
         assert isinstance(trace.report.t_sampling, float)
         pass
 
+    def test_trace_report_bart(self):
+        X = np.random.normal(0, 1, size=(3, 250)).T
+        Y = np.random.normal(0, 1, size=250)
+        X[:, 0] = np.random.normal(Y, 0.1)
+
+        with pm.Model() as model:
+            mu = pm.BART("mu", X, Y, m=20)
+            sigma = pm.HalfNormal("sigma", 1)
+            y = pm.Normal("y", mu, sigma, observed=Y)
+            trace = pm.sample(500, tune=100, random_seed=3415)
+        var_imp = trace.report.variable_importance
+        assert var_imp[0] > var_imp[1:].sum()
+        npt.assert_almost_equal(var_imp.sum(), 1)
+
     def test_return_inferencedata(self):
         with self.model:
             kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())