BART: clamp first particle to old full tree (#5011)
aloctavodia authored Sep 29, 2021
1 parent 1048b69 commit 8c59b41
Showing 1 changed file with 83 additions and 95 deletions.
178 changes: 83 additions & 95 deletions pymc/step_methods/pgbart.py
@@ -32,6 +32,59 @@
_log = logging.getLogger("pymc")


class ParticleTree:
"""
Particle tree
"""

def __init__(self, tree, log_weight, likelihood):
self.tree = tree.copy()  # keeps the tree that we care about at the moment
self.expansion_nodes = [0]
self.log_weight = log_weight
self.old_likelihood_logp = likelihood
self.used_variates = []

def sample_tree_sequential(
self,
ssv,
available_predictors,
prior_prob_leaf_node,
X,
missing_data,
sum_trees_output,
mean,
m,
normal,
mu_std,
):
tree_grew = False
if self.expansion_nodes:
index_leaf_node = self.expansion_nodes.pop(0)
# Probability that this node will remain a leaf node
prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]

if prob_leaf < np.random.random():
tree_grew, index_selected_predictor = grow_tree(
self.tree,
index_leaf_node,
ssv,
available_predictors,
X,
missing_data,
sum_trees_output,
mean,
m,
normal,
mu_std,
)
if tree_grew:
new_indexes = self.tree.idx_leaf_nodes[-2:]
self.expansion_nodes.extend(new_indexes)
self.used_variates.append(index_selected_predictor)

return tree_grew
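
For context (not shown in this diff), prior_prob_leaf_node is a depth-indexed table of probabilities that a node stays a leaf. Below is a minimal sketch of how such a table could be built from the standard BART depth prior of Chipman et al., where a node at depth d splits with probability alpha * (1 + d) ** (-beta); the names alpha, beta and max_depth are illustrative assumptions, not taken from the diff:

import numpy as np

def leaf_prior_probabilities(alpha=0.95, beta=2.0, max_depth=10):
    # Probability that a node at each depth remains a leaf, i.e. the complement
    # of the assumed split probability alpha * (1 + depth) ** (-beta).
    depths = np.arange(max_depth)
    return 1.0 - alpha * (1.0 + depths) ** (-beta)

# prior_prob_leaf_node = leaf_prior_probabilities()
# prob_leaf = prior_prob_leaf_node[depth]  # as indexed in sample_tree_sequential above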


class PGBART(ArrayStepShared):
"""
Particle Gibbs BART sampling step
@@ -108,9 +161,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo

if self.chunk == "auto":
self.chunk = max(1, int(self.m * 0.1))
self.num_particles = num_particles
self.log_num_particles = np.log(num_particles)
self.indices = list(range(1, num_particles))
self.len_indices = len(self.indices)
self.max_stages = max_stages

shared = make_shared_replacements(initial_values, vars, model)
@@ -137,24 +190,22 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
if self.idx == self.m:
self.idx = 0

for idx in range(self.idx, self.idx + self.chunk):
if idx >= self.m:
for tree_id in range(self.idx, self.idx + self.chunk):
if tree_id >= self.m:
break
tree = self.all_particles[idx].tree
sum_trees_output_noi = sum_trees_output - tree.predict_output()
self.idx += 1
# Generate an initial set of SMC particles
# at the end of the algorithm we return one of these particles as the new tree
particles = self.init_particles(tree.tree_id)
particles = self.init_particles(tree_id)
# Compute the sum of trees without the tree we are attempting to replace
self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
self.idx += 1

# The old tree is not growing so we update the weights only once.
self.update_weight(particles[0])
for t in range(self.max_stages):
# Get old particle at stage t
if t > 0:
particles[0] = self.get_old_tree_particle(tree.tree_id, t)
# sample each particle (try to grow each tree)
compute_logp = [True]
# Sample each particle (try to grow each tree), except for the first one.
for p in particles[1:]:
clp = p.sample_tree_sequential(
tree_grew = p.sample_tree_sequential(
self.ssv,
self.available_predictors,
self.prior_prob_leaf_node,
@@ -166,22 +217,14 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.normal,
self.mu_std,
)
compute_logp.append(clp)
# Update weights. Since the prior is used as the proposal,the weights
# are updated additively as the ratio of the new and old log_likelihoods
for clp, p in zip(compute_logp, particles):
if clp: # Compute the likelihood when p has changed from the previous iteration
new_likelihood = self.likelihood_logp(
sum_trees_output_noi + p.tree.predict_output()
)
p.log_weight += new_likelihood - p.old_likelihood_logp
p.old_likelihood_logp = new_likelihood
if tree_grew:
self.update_weight(p)
# Normalize weights
W_t, normalized_weights = self.normalize(particles)

# Resample all but first particle
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
particles[1:] = particles[new_indices]

# Set the new weights
@@ -200,8 +243,8 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
new_particle = np.random.choice(particles, p=normalized_weights)
new_tree = new_particle.tree
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
self.all_particles[tree.tree_id] = new_particle
sum_trees_output = sum_trees_output_noi + new_tree.predict_output()
self.all_particles[tree_id] = new_particle
sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()

if self.tune:
for index in new_particle.used_variates:
@@ -232,7 +275,7 @@ def competence(var, has_grad):
return Competence.IDEAL
return Competence.INCOMPATIBLE

def normalize(self, particles):
def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
"""
Use logsumexp trick to get W_t and softmax to get normalized_weights
"""
@@ -248,16 +291,11 @@ def normalize(self, particles):

return W_t, normalized_weights
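
For reference (not part of the diff), a self-contained sketch of the logsumexp/softmax computation the normalize docstring describes, assuming W_t is the log of the mean unnormalized weight:

import numpy as np

def normalize_log_weights(log_weights):
    log_weights = np.asarray(log_weights, dtype=float)
    log_w_max = log_weights.max()
    w_shifted = np.exp(log_weights - log_w_max)        # shift by the max for numerical stability
    normalized_weights = w_shifted / w_shifted.sum()   # softmax of the log-weights
    # logsumexp of the weights minus log(N): the log of the mean unnormalized weight
    W_t = log_w_max + np.log(w_shifted.sum()) - np.log(len(log_weights))
    return W_t, normalized_weights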

def get_old_tree_particle(self, tree_id, t):
old_tree_particle = self.all_particles[tree_id]
old_tree_particle.set_particle_to_step(t)
return old_tree_particle

def init_particles(self, tree_id):
def init_particles(self, tree_id: int) -> np.ndarray:
"""
Initialize particles
"""
p = self.get_old_tree_particle(tree_id, 0)
p = self.all_particles[tree_id]
p.log_weight = self.init_log_weight
p.old_likelihood_logp = self.init_likelihood
particles = [p]
@@ -274,68 +312,18 @@ def init_particles(self, tree_id):

return np.array(particles)

def update_weight(self, particle: ParticleTree) -> None:
"""
Update the weight of a particle
class ParticleTree:
"""
Particle tree
"""

def __init__(self, tree, log_weight, likelihood):
self.tree = tree.copy() # keeps the tree that we care at the moment
self.expansion_nodes = [0]
self.tree_history = [self.tree]
self.expansion_nodes_history = [self.expansion_nodes]
self.log_weight = log_weight
self.old_likelihood_logp = likelihood
self.used_variates = []

def sample_tree_sequential(
self,
ssv,
available_predictors,
prior_prob_leaf_node,
X,
missing_data,
sum_trees_output,
mean,
m,
normal,
mu_std,
):
clp = False
if self.expansion_nodes:
index_leaf_node = self.expansion_nodes.pop(0)
# Probability that this node will remain a leaf node
prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]

if prob_leaf < np.random.random():
clp, index_selected_predictor = grow_tree(
self.tree,
index_leaf_node,
ssv,
available_predictors,
X,
missing_data,
sum_trees_output,
mean,
m,
normal,
mu_std,
)
if clp:
new_indexes = self.tree.idx_leaf_nodes[-2:]
self.expansion_nodes.extend(new_indexes)
self.used_variates.append(index_selected_predictor)

self.tree_history.append(self.tree)
self.expansion_nodes_history.append(self.expansion_nodes)
return clp

def set_particle_to_step(self, t):
if len(self.tree_history) <= t:
t = -1
self.tree = self.tree_history[t]
self.expansion_nodes = self.expansion_nodes_history[t]
Since the prior is used as the proposal, the weights are updated additively as the ratio of
the new and old log-likelihoods.
"""
new_likelihood = self.likelihood_logp(
self.sum_trees_output_noi + particle.tree.predict_output()
)
particle.log_weight += new_likelihood - particle.old_likelihood_logp
particle.old_likelihood_logp = new_likelihood
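
For reference (not part of the diff): because the tree prior is used as the SMC proposal, the incremental importance weight reduces to a likelihood ratio, so in log space the update performed by update_weight is

    log w_t = log w_{t-1} + log p(Y | T_t) - log p(Y | T_{t-1})

where T_t is the particle's tree after the latest growth step and the likelihood is evaluated on the sum of all other trees (sum_trees_output_noi) plus the predictions of T_t.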


def preprocess_XY(X, Y):
