An attempt at a pytorch fast_inla implementation. #97

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion environment.yml
@@ -3,7 +3,7 @@ channels:
- conda-forge
dependencies:
# essentials
-  - python
+  - python=3.9
- setuptools
- jupyterlab
- numpy
2 changes: 2 additions & 0 deletions pytest.ini
@@ -1,3 +1,5 @@
[pytest]
addopts = -s --tb=short
norecursedirs = __pycache__ build bazel-*
env =
PYTORCH_ENABLE_MPS_FALLBACK=1
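Note: the `env =` key in pytest.ini is provided by the pytest-env plugin; plain pytest ignores it. `PYTORCH_ENABLE_MPS_FALLBACK=1` asks PyTorch to fall back to the CPU for operators the MPS backend does not yet implement. A minimal sketch of the same setup outside pytest, assuming the flag must be in the environment before torch is imported:

```python
import os

# Assumption: PyTorch reads this flag at startup, so it must be set
# before the first `import torch`.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

# Available since torch 1.12 on Apple-silicon builds.
print("MPS available:", torch.backends.mps.is_available())
```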
38 changes: 38 additions & 0 deletions research/berry/berrylib/benchmark.py
@@ -0,0 +1,38 @@
# flake8: noqa
import sys
import time

sys.path.append("./research/berry")
import berrylib.fast_inla as fast_inla
import test_berry
import torch


def benchmark_m1_torch():
    for device in ["cpu", "mps"]:
        print(device)
        start = time.time()
        N = 5000
        A = torch.rand(N, N, dtype=torch.float).to(device)
        B = torch.rand(N, N, dtype=torch.float).to(device)
        print("create", time.time() - start)

        start = time.time()
        for i in range(100):
            A = torch.mm(A, B) * (2 / N)
        # Copying back to the CPU forces a device synchronization, so this
        # timing covers the matrix multiplies, not just their dispatch.
        cpu_arr = A.to("cpu")
        print(cpu_arr[0, 0], cpu_arr[-1, -1])
        print("matmul", time.time() - start)


if __name__ == "__main__":
    N = 10000
    it = 4
    print("jax")
    test_berry.test_fast_inla("jax", N, it)
    # print("cpp")
    # test_berry.test_fast_inla("cpp", N, it)
    # print("numpy")
    # test_berry.test_fast_inla("numpy", N, it)
    print("pytorch")
    test_berry.test_fast_inla("pytorch", N, it)
178 changes: 176 additions & 2 deletions research/berry/berrylib/fast_inla.py
@@ -4,14 +4,15 @@
import numpy as np
import scipy.linalg
import scipy.stats
import torch
from jax.config import config
from scipy.special import logit

# This line is critical for enabling 64-bit floats.
config.update("jax_enable_x64", True)


-def fast_invert(S_in, d):
+def np_fast_invert(S_in, d):
S = np.tile(S_in, (d.shape[0], 1, 1, 1))
for k in range(d.shape[-1]):
outer = np.einsum("...i,...j->...ij", S[..., k, :], S[..., :, k])
@@ -20,6 +21,13 @@ def fast_invert(S_in, d):
return S


def pytorch_fast_invert(S, d):
    # Iterated Sherman-Morrison rank-one updates: given S = A^{-1}, this
    # overwrites S in place with (A + diag(d))^{-1}, one entry of d at a time.
    for k in range(d.shape[-1]):
        offset = d[..., k] / (1 + d[..., k] * S[..., k, k])
        S -= (offset[..., None, None] * S[..., k, None, :]) * S[..., :, None, k]
    return S
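Each pass of this loop is a Sherman-Morrison rank-one update: starting from S = A^{-1}, it produces (A + diag(d))^{-1}. A self-contained sanity check against a direct inverse (my own sketch; the matrix, the perturbation, and the float64 choice are illustrative):

```python
import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.T + n * torch.eye(n, dtype=torch.float64)  # well-conditioned SPD matrix
d = torch.rand(1, 1, n, dtype=torch.float64)         # batched diagonal perturbation

# pytorch_fast_invert expects S = A^{-1} with batch dims and mutates it in place.
S = torch.linalg.inv(A).expand(1, 1, n, n).clone()
fast = pytorch_fast_invert(S, d)
direct = torch.linalg.inv(A + torch.diag(d[0, 0]))
print(torch.allclose(fast[0, 0], direct))  # expected: True
```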


class FastINLA:
def __init__(
self,
@@ -28,6 +36,8 @@ def __init__(
sigma2_bounds=(1e-6, 1e3),
p1=0.3,
critical_value=0.85,
torch_dtype=torch.float64,
torch_device="cpu",
):
self.n_arms = n_arms
self.mu_0 = -1.34
@@ -69,13 +79,38 @@
)
)

# For Torch impl:
self.torch_dtype = torch_dtype
self.torch_device = torch_device
self.sigma2_pts_torch = torch.tensor(
self.sigma2_rule.pts, dtype=self.torch_dtype
).to(self.torch_device)
self.sigma2_wts_torch = torch.tensor(
self.sigma2_rule.wts, dtype=self.torch_dtype
).to(self.torch_device)
self.cov_torch = torch.tensor(self.cov, dtype=self.torch_dtype).to(
self.torch_device
)
self.neg_precQ_torch = torch.tensor(self.neg_precQ, dtype=self.torch_dtype).to(
self.torch_device
)
self.logprecQdet_torch = torch.tensor(
self.logprecQdet, dtype=self.torch_dtype
).to(self.torch_device)
self.log_prior_torch = torch.tensor(self.log_prior, dtype=self.torch_dtype).to(
self.torch_device
)

def rejection_inference(self, y, n, method="jax"):
_, exceedance, _, _ = self.inference(y, n, method)
return exceedance > self.critical_value

def inference(self, y, n, method="jax"):
        fncs = dict(
-            numpy=self.numpy_inference, jax=self.jax_inference, cpp=self.cpp_inference
+            numpy=self.numpy_inference,
+            jax=self.jax_inference,
+            cpp=self.cpp_inference,
+            pytorch=self.pytorch_inference,
        )
return fncs[method](y, n)[:4]

@@ -300,6 +335,145 @@ def cpp_inference(self, y, n):
)
return sigma2_post, exceedances, theta_max, theta_sigma

def pytorch_inference(self, y, n, thresh_theta=None):
if thresh_theta is None:
thresh_theta = self.thresh_theta
thresh_theta = torch.tensor(thresh_theta, dtype=self.torch_dtype).to(
self.torch_device
)

torch_y = torch.tensor(y, dtype=self.torch_dtype).to(self.torch_device)
torch_n = torch.tensor(n, dtype=self.torch_dtype).to(self.torch_device)

theta_max, hess = pytorch_optimize_mode(
torch_y,
torch_n,
self.cov_torch,
self.neg_precQ_torch,
self.mu_0,
self.logit_p1,
self.tol,
)
logjoint = self.pytorch_log_joint(torch_y, torch_n, theta_max)
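        # Laplace approximation to the sigma2 marginal on the quadrature grid:
        # log p(sigma2 | y) ~= log p(y, theta_max, sigma2) - 0.5 * log det(-H).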
log_sigma2_post = logjoint - 0.5 * torch.log(torch.linalg.det(-hess))
# log_sigma2_post = logjoint + 0.5 * torch.log(torch.linalg.det(-hess_inv))
sigma2_post = torch.exp(log_sigma2_post)
sigma2_post /= torch.sum(sigma2_post * self.sigma2_wts_torch, axis=1)[:, None]

theta_sigma = torch.empty(
(y.shape[0], self.sigma2_n, self.n_arms),
dtype=self.torch_dtype,
device=self.torch_device,
)
for i in range(self.n_arms):
rhs = [0] * self.n_arms
rhs[i] = 1.0
rhs_torch = torch.tensor(
rhs, dtype=self.torch_dtype, device=self.torch_device
)
            # PERFORMANCE NOTE: Using torch.linalg.solve is much more stable
            # than inverting the Hessian directly in 32-bit. There is likely
            # a faster way to do this.
hess_inv_row = torch.linalg.solve(hess, rhs_torch)
theta_sigma[:, :, i] = torch.sqrt(-hess_inv_row[:, :, i])
# Version for float64...
# theta_sigma = torch.sqrt(torch.diagonal(-hess_inv, dim1=2, dim2=3))
theta_mu = theta_max

exceedances = []
dist = torch.distributions.normal.Normal(0, 1)
for i in range(self.n_arms):
exc_sigma2 = 1.0 - dist.cdf(
(thresh_theta[..., None, i] - theta_mu[..., i]) / theta_sigma[..., i]
)
exc = torch.sum(
exc_sigma2 * sigma2_post * self.sigma2_wts_torch[None, :], axis=1
)
exceedances.append(exc)
return (
sigma2_post.cpu().numpy(),
torch.stack(exceedances, axis=-1).cpu().numpy(),
theta_max.cpu().numpy(),
theta_sigma.cpu().numpy(),
)
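For reference, a usage sketch of the new backend (the `y` and `n` values are arbitrary illustrations, not from this PR):

```python
import berrylib.fast_inla as fast_inla
import numpy as np

model = fast_inla.FastINLA(n_arms=4)
y = np.array([[3.0, 8.0, 10.0, 12.0]])  # successes per arm
n = np.full((1, 4), 35.0)               # trials per arm
sigma2_post, exceedances, theta_max, theta_sigma = model.inference(
    y, n, method="pytorch"
)
print(exceedances)  # per-arm exceedance probabilities
```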

def pytorch_log_joint(self, y, n, theta):
"""
theta is expected to have shape (N, n_sigma2, n_arms)
"""
theta_m0 = theta - self.mu_0
theta_adj = theta + self.logit_p1
exp_theta_adj = torch.exp(theta_adj)
MM = theta_m0[:, :, :, None].reshape((-1, self.n_arms, 1))
NN = torch.tile(self.neg_precQ_torch[None, ...], (y.shape[0], 1, 1, 1)).reshape(
(-1, self.n_arms, self.n_arms)
)
quad_term = torch.sum(torch.bmm(NN, MM) * MM, axis=(-2, -1)).reshape(
(y.shape[0], -1)
)
return (
0.5 * quad_term
+ self.logprecQdet_torch
+ torch.sum(
theta_adj * y[:, None] - n[:, None] * torch.log(exp_theta_adj + 1),
axis=-1,
)
+ self.log_prior_torch
)
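The reshape-plus-`torch.bmm` computation of `quad_term` above is a batched quadratic form theta^T Q theta. An equivalent einsum formulation (a sketch, not benchmarked against the `bmm` version):

```python
# Equivalent to the reshape + torch.bmm computation of quad_term above:
# quad_term[n, s] = sum_ij neg_precQ[s, i, j] * theta_m0[n, s, i] * theta_m0[n, s, j]
quad_term = torch.einsum(
    "sij,nsj,nsi->ns", self.neg_precQ_torch, theta_m0, theta_m0
)
```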


def pytorch_optimize_mode(y, n, cov, neg_precQ, mu_0, logit_p1, tol):
sigma2_n = neg_precQ.shape[0]
N, n_arms = y.shape
theta_max = torch.zeros((N, sigma2_n, n_arms), dtype=y.dtype, device=y.device)
    na = torch.arange(n_arms)

converged = False
for i in range(100):
theta_m0 = theta_max - mu_0
exp_theta_adj = torch.exp(theta_max + logit_p1)
C = 1.0 / (exp_theta_adj + 1)
nCeta = n[:, None] * C * exp_theta_adj

grad = (
torch.matmul(neg_precQ[None], theta_m0[:, :, :, None])[..., 0]
+ y[:, None]
- nCeta
)

# diag = nCeta * C
# Version that only works in float64.
# hess_inv = pytorch_fast_invert(
# -torch.tile(cov[None, ...], (diag.shape[0], 1, 1, 1)),
# -diag,
# )
# step = -torch.matmul(hess_inv, grad[..., None])[..., 0]

hess = torch.tile(neg_precQ[None, ...], (N, 1, 1, 1))
hess[..., na, na] -= nCeta * C

        # Take the full Newton step: theta <- theta - H^{-1} grad. Because the
        # Hessian is negative definite near the mode, this steps uphill toward
        # the maximum.

        # PERFORMANCE NOTE: Using pytorch_fast_invert is faster but has some
        # instability in float32 and is thus unsuitable for a GPU
        # implementation.
step = -torch.linalg.solve(hess, grad)

theta_max += step

        # We use a step-size convergence criterion. This seems empirically
        # sufficient, but it would also be possible to check gradient norms
        # or other common convergence criteria.
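        # For example, a gradient-norm criterion could read (untested sketch):
        #   if torch.max(torch.sum(grad**2, dim=-1)) < tol**2: ...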
if torch.max(torch.sum(step**2, dim=-1)) < tol**2:
converged = True
break

if not converged:
raise RuntimeError("Failed to identify the mode of the joint density.")

return theta_max, hess


def jax_opt(y, n, cov, neg_precQ, sigma2, logit_p1, mu_0, tol):
def step(args):
29 changes: 25 additions & 4 deletions research/berry/berrylib/test_berry.py
@@ -167,12 +167,15 @@ def test_inla_properties(method):
np.testing.assert_allclose(sigma2_integral, 1.0)


@pytest.mark.parametrize("method", ["jax", "numpy", "cpp"])
@pytest.mark.parametrize("method", ["pytorch"])
def test_fast_inla(method, N=10, iterations=1):
n_i = np.tile(np.array([20, 20, 35, 35]), (N, 1))
y_i = np.tile(np.array([0, 1, 9, 10], dtype=np.float64), (N, 1))
inla_model = fast_inla.FastINLA()

# import torch
# inla_model = fast_inla.FastINLA(torch_dtype=torch.float32, torch_device='mps')

runtimes = []
for i in range(iterations):
start = time.time()
@@ -187,10 +190,28 @@

sigma2_post, exceedances, theta_max, theta_sigma = out

-    np.testing.assert_allclose(
-        theta_max[0, 12],
+    correct_theta_max = [
+        [-0.65720195, -0.65720081, -0.65719484, -0.65719371],
+        [-0.65720546, -0.65720355, -0.65719345, -0.65719153],
+        [-0.65721851, -0.65721369, -0.65718827, -0.65718345],
+        [-0.65727494, -0.65725755, -0.65716589, -0.65714851],
+        [-0.65757963, -0.65749441, -0.65704505, -0.65695985],
+        [-0.65958489, -0.65905323, -0.65625057, -0.65571955],
+        [-0.67465157, -0.67076447, -0.65032674, -0.64647392],
+        [-0.78705011, -0.75796462, -0.60851132, -0.58150985],
+        [-1.34091076, -1.17062088, -0.44985596, -0.34985839],
+        [-2.47957009, -1.79535068, -0.28588712, -0.14714995],
+        [-3.7880899, -2.05159198, -0.23025555, -0.08632545],
+        [-5.02345932, -2.09158471, -0.21765884, -0.07316777],
         [-6.04682818, -2.09586893, -0.21474981, -0.07019088],
-        rtol=1e-3,
+        [-6.789137, -2.09644905, -0.21401509, -0.06944441],
+        [-7.21806318, -2.096633, -0.21382083, -0.06924673],
+    ]
+    print(theta_max[0])
+    np.testing.assert_allclose(
+        theta_max[0],
+        correct_theta_max,
+        rtol=1e-2,
     )
correct = np.array(
[