Added more tests for linearize and make_empirical_fisher_vp #405

Merged · 16 commits · Dec 6, 2023
Empty file added: tests/robust/__init__.py
223 changes: 223 additions & 0 deletions tests/robust/robust_fixtures.py
@@ -0,0 +1,223 @@
import math
from typing import Callable, Optional, Tuple, TypedDict, TypeVar

import pyro
import pyro.distributions as dist
import torch
from pyro.nn import PyroModule

from chirho.observational.handlers import condition
from chirho.robust.internals.utils import ParamDict
from chirho.robust.ops import Point

pyro.settings.set(module_local_params=True)
T = TypeVar("T")


class SimpleModel(PyroModule):
    def forward(self):
        a = pyro.sample("a", dist.Normal(0, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(a, 1))
            return pyro.sample("y", dist.Normal(a + b, 1))


class SimpleGuide(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loc_a = torch.nn.Parameter(torch.rand(()))
        self.loc_b = torch.nn.Parameter(torch.rand((3,)))

    def forward(self):
        a = pyro.sample("a", dist.Normal(self.loc_a, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(self.loc_b, 1))
            return {"a": a, "b": b}


class GaussianModel(PyroModule):
    def __init__(self, cov_mat: torch.Tensor):
        super().__init__()
        self.register_buffer("cov_mat", cov_mat)

    def forward(self, loc):
        pyro.sample(
            "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat)
        )


# Note: `gaussian_log_prob` is separate from the GaussianModel above because of upstream obstacles
# in the interaction between `pyro.nn.PyroModule` and `torch.func`.
# See https://github.com/BasisResearch/chirho/issues/393
def gaussian_log_prob(params: ParamDict, data_point: Point[T], cov_mat) -> T:
    with pyro.validation_enabled(False):
        return dist.MultivariateNormal(
            loc=params["loc"], covariance_matrix=cov_mat
        ).log_prob(data_point["x"])
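
As an aside, a minimal sketch of how gaussian_log_prob can be paired with make_empirical_fisher_vp (not part of the diff; the call pattern mirrors the test file below, while the shapes, values, and the use of functools.partial to bind cov_mat are assumptions):

import functools

import torch
from chirho.robust.internals.linearize import make_empirical_fisher_vp

# Bind cov_mat so the log-prob has the (params, data) signature that
# make_empirical_fisher_vp is called with in the test below.
cov_mat = torch.eye(2)  # hypothetical 2-dimensional Gaussian
func_log_prob = functools.partial(gaussian_log_prob, cov_mat=cov_mat)

params = {"loc": torch.zeros(2)}
data = {"x": torch.randn(10, 2)}  # 10 hypothetical datapoints

fvp = make_empirical_fisher_vp(func_log_prob, params, data)
v = {"loc": torch.ones(2)}
print(fvp(v)["loc"])  # Fisher-vector product, keyed like params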


class DataConditionedModel(PyroModule):
r"""
Helper class for conditioning on data.
"""

def __init__(self, model: PyroModule):
super().__init__()
self.model = model

def forward(self, D: Point[torch.Tensor]):
with condition(data=D):
# Assume first dimension corresponds to # of datapoints
N = D[next(iter(D))].shape[0]
return self.model.forward(N=N)


class HighDimLinearModel(pyro.nn.PyroModule):
    def __init__(
        self,
        p: int,
        link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),
        prior_scale: Optional[float] = None,
    ):
        super().__init__()
        self.p = p
        self.link_fn = link_fn
        if prior_scale is None:
            self.prior_scale = 1 / math.sqrt(self.p)
        else:
            self.prior_scale = prior_scale

    def sample_outcome_weights(self):
        return pyro.sample(
            "outcome_weights",
            dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1),
        )

    def sample_intercept(self):
        return pyro.sample("intercept", dist.Normal(0.0, 1.0))

    def sample_propensity_weights(self):
        return pyro.sample(
            "propensity_weights",
            dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1),
        )

    def sample_treatment_weight(self):
        return pyro.sample("treatment_weight", dist.Normal(0.0, 1.0))

    def sample_covariate_loc_scale(self):
        loc = pyro.sample(
            "covariate_loc", dist.Normal(0.0, 1.0).expand((self.p,)).to_event(1)
        )
        scale = pyro.sample(
            "covariate_scale", dist.LogNormal(0, 1).expand((self.p,)).to_event(1)
        )
        return loc, scale

    def forward(self, N: int = 1):
        intercept = self.sample_intercept()
        outcome_weights = self.sample_outcome_weights()
        propensity_weights = self.sample_propensity_weights()
        tau = self.sample_treatment_weight()
        x_loc, x_scale = self.sample_covariate_loc_scale()
        with pyro.plate("obs", N, dim=-1):
            X = pyro.sample("X", dist.Normal(x_loc, x_scale).to_event(1))
            A = pyro.sample(
                "A",
                dist.Bernoulli(
                    logits=torch.einsum("...np,...p->...n", X, propensity_weights)
                ),
            )
            return pyro.sample(
                "Y",
                self.link_fn(
                    torch.einsum("...np,...p->...n", X, outcome_weights)
                    + A * tau
                    + intercept
                ),
            )
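
For orientation, a minimal sketch of how DataConditionedModel wraps HighDimLinearModel (not part of the diff; shapes and values are hypothetical):

model = HighDimLinearModel(p=2)
conditioned = DataConditionedModel(model)

# Site names must match the model's sample sites ("X", "A", "Y"); the
# first dimension of each tensor is the number of datapoints, so N=5 here.
D = {
    "X": torch.randn(5, 2),
    "A": torch.ones(5),
    "Y": torch.randn(5),
}
conditioned(D)  # runs the model with all three sites observed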


class KnownCovariateDistModel(HighDimLinearModel):
    def sample_covariate_loc_scale(self):
        return torch.zeros(self.p), torch.ones(self.p)


class BenchmarkLinearModel(HighDimLinearModel):
    def __init__(
        self,
        p: int,
        link_fn: Callable[..., dist.Distribution],
        alpha: int,
        beta: int,
        treatment_weight: float = 0.0,
    ):
        super().__init__(p, link_fn)
        self.alpha = alpha  # sparsity of propensity weights
        self.beta = beta  # sparsity of outcome weights
        self.treatment_weight = treatment_weight

    def sample_outcome_weights(self):
        outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p)
        outcome_weights[self.beta :] = 0.0
        return outcome_weights

    def sample_treatment_null_weight(self):
        return torch.tensor(0.0)

    def sample_propensity_weights(self):
        propensity_weights = 1 / math.sqrt(self.alpha) * torch.ones(self.p)
        propensity_weights[self.alpha :] = 0.0
        return propensity_weights

    def sample_treatment_weight(self):
        return torch.tensor(self.treatment_weight)

    def sample_intercept(self):
        return torch.tensor(0.0)

    def sample_covariate_loc_scale(self):
        return torch.zeros(self.p), torch.ones(self.p)


class MLEGuide(torch.nn.Module):
    def __init__(self, mle_est: ParamDict):
        super().__init__()
        self.names = list(mle_est.keys())
        for name, value in mle_est.items():
            setattr(self, name + "_param", torch.nn.Parameter(value))

    def forward(self, *args, **kwargs):
        for name in self.names:
            value = getattr(self, name + "_param")
            pyro.sample(name, dist.Delta(value))
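
A minimal sketch of MLEGuide in isolation (not part of the diff; the parameter names and values are hypothetical):

mle_est = {
    "intercept": torch.tensor(0.0),
    "treatment_weight": torch.tensor(1.0),
}
guide = MLEGuide(mle_est)
guide()  # samples "intercept" and "treatment_weight" from point-mass Deltas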


class ATETestPoint(TypedDict):
    X: torch.Tensor
    A: torch.Tensor
    Y: torch.Tensor


class ATEParamDict(TypedDict):
    propensity_weights: torch.Tensor
    outcome_weights: torch.Tensor
    treatment_weight: torch.Tensor
    intercept: torch.Tensor


def closed_form_ate_correction(
    X_test: ATETestPoint, theta: ATEParamDict
) -> Tuple[torch.Tensor, torch.Tensor]:
    X = X_test["X"]
    A = X_test["A"]
    Y = X_test["Y"]
    pi_X = torch.sigmoid(X.mv(theta["propensity_weights"]))
    mu_X = (
        X.mv(theta["outcome_weights"])
        + A * theta["treatment_weight"]
        + theta["intercept"]
    )
    analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)
    analytic_correction = analytic_eif_at_test_pts.mean()
    return analytic_correction, analytic_eif_at_test_pts
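
A quick numerical sketch of closed_form_ate_correction (not part of the diff; the inputs are hypothetical): with all-zero weights, pi_X = sigmoid(0) = 0.5 and mu_X = 0, so the EIF reduces to (A/0.5 - (1-A)/0.5) * Y.

theta: ATEParamDict = {
    "propensity_weights": torch.zeros(2),
    "outcome_weights": torch.zeros(2),
    "treatment_weight": torch.tensor(0.0),
    "intercept": torch.tensor(0.0),
}
X_test: ATETestPoint = {
    "X": torch.randn(4, 2),
    "A": torch.tensor([0.0, 1.0, 0.0, 1.0]),
    "Y": torch.randn(4),
}
correction, eif = closed_form_ate_correction(X_test, theta)
# eif == (A/0.5 - (1-A)/0.5) * Y elementwise; correction is its mean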
61 changes: 61 additions & 0 deletions tests/robust/test_internals_compositions.py
@@ -0,0 +1,61 @@
import functools
import warnings

import pyro
import torch
from pyro.poutine.seed_messenger import SeedMessenger

from chirho.robust.internals.linearize import (
conjugate_gradient_solve,
make_empirical_fisher_vp,
)
from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood
from chirho.robust.internals.utils import make_functional_call

from .robust_fixtures import SimpleGuide, SimpleModel

pyro.settings.set(module_local_params=True)


def test_empirical_fisher_vp_nmclikelihood_cg_composition():
Comment thread on this line:

Contributor: If this is the test you expect to fail without #408, can you mark it with pytest.mark.xfail(reason="fails without fix in https://github.com/BasisResearch/chirho/pull/408")?

@agrawalraj (Contributor, Author), Dec 6, 2023: Given that the change in #408 now just uses SeedMessenger in the linearize body, I can fix these errors by doing the same in this test. I'll push up these changes now.

@agrawalraj (Contributor, Author), Dec 6, 2023: @eb8680 just made this change!

Contributor: That's fine, but remember to remove the SeedMessenger here in #408, since you want this test to exercise the changes made to NMCLogPredictiveLikelihood there.

@agrawalraj (Contributor, Author): Ok, in #408 I followed your suggestion (#408 (review)) and did not actually modify NMCLogPredictiveLikelihood. I just checked that

    log_prob = NMCLogPredictiveLikelihood(
        model, guide, num_samples=3, max_plate_nesting=3
    )
    log_prob_params, func_log_prob = make_functional_call(log_prob)
    func_log_prob = SeedMessenger(123)(func_log_prob)

works in the test_nmc_likelihood_seeded test in #408. I think we can move this discussion to #408 if you have other suggestions on this point!

    model = SimpleModel()
    guide = SimpleGuide()
    model(), guide()  # initialize
    log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100)
    log_prob_params, func_log_prob = make_functional_call(log_prob)
    func_log_prob = SeedMessenger(123)(func_log_prob)

    predictive = pyro.infer.Predictive(
        model, guide=guide, num_samples=1000, parallel=True, return_sites=["y"]
    )
    predictive_params, func_predictive = make_functional_call(predictive)

    cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=10)

    with torch.no_grad():
        data = func_predictive(predictive_params)
    fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data)

    v = {k: torch.ones_like(v) for k, v in log_prob_params.items()}

    assert fvp(v)["guide.loc_a"].abs().max() > 0  # sanity check for non-zero fvp

    solve_one = cg_solver(fvp, v)
    solve_two = cg_solver(fvp, v)

    if solve_one["guide.loc_a"].abs().max() > 1e6:
        warnings.warn(
            "solve_one['guide.loc_a'] is large (max entry={}).".format(
                solve_one["guide.loc_a"].abs().max()
            )
        )

    if solve_one["guide.loc_b"].abs().max() > 1e6:
        warnings.warn(
            "solve_one['guide.loc_b'] is large (max entry={}).".format(
                solve_one["guide.loc_b"].abs().max()
            )
        )

    assert torch.allclose(solve_one["guide.loc_a"], solve_two["guide.loc_a"])
    assert torch.allclose(solve_one["guide.loc_b"], solve_two["guide.loc_b"])
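
For context on the thread above, a minimal sketch of the SeedMessenger pattern (not part of the diff; the sampled site is hypothetical): wrapping a stochastic callable reseeds the RNG on every call, so repeated calls return identical samples.

import pyro
import pyro.distributions as dist
from pyro.poutine.seed_messenger import SeedMessenger

def draw():
    return pyro.sample("z", dist.Normal(0.0, 1.0))

seeded_draw = SeedMessenger(123)(draw)
assert seeded_draw() == seeded_draw()  # same seed on each call, same sample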