From c41756ef90939a4733fe2278a4ccf95c16d6d29c Mon Sep 17 00:00:00 2001
From: James Yang
Date: Mon, 7 Nov 2022 12:34:30 -0800
Subject: [PATCH] James.yang/gaussian bound (#145)

* Add gaussian tilt bound

* Add uni and multi normal

* Add uni normal test
---
 .../mini_imprint/bound/multivariate_normal.py | 130 ++++++++++++++++++
 confirm/mini_imprint/bound/normal.py          | 114 +++++++++++++++
 tests/mini_imprint/bound/test_normal.py       |  85 ++++++++++++
 3 files changed, 329 insertions(+)
 create mode 100644 confirm/mini_imprint/bound/multivariate_normal.py
 create mode 100644 confirm/mini_imprint/bound/normal.py
 create mode 100644 tests/mini_imprint/bound/test_normal.py

diff --git a/confirm/mini_imprint/bound/multivariate_normal.py b/confirm/mini_imprint/bound/multivariate_normal.py
new file mode 100644
index 00000000..dbca3063
--- /dev/null
+++ b/confirm/mini_imprint/bound/multivariate_normal.py
@@ -0,0 +1,130 @@
+import jax
+import jax.numpy as jnp
+
+
+def _quad_form(v, A):
+    """Return the quadratic form v^T A v."""
+    return v.dot(A @ v)
+
+
+class ForwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * (q-1) * v^T cov v - log(f0) / q
+    with respect to q >= 1.
+    """
+
+    def __init__(self, cov):
+        self.cov = cov
+
+    def solve(self, v, f0):
+        """
+        Return the optimal q for the given v and f0.
+
+        The stationary point sqrt(-2 * log(f0) / v^T cov v) is clipped
+        from below at 1 to satisfy the constraint.
+        """
+        logf0 = jnp.log(f0)
+        mv = _quad_form(v, self.cov)
+        q_opt = jnp.sqrt(-2 * logf0 / mv)
+        return jnp.maximum(q_opt, 1)
+
+
+class BackwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * q * v^T cov v - log(alpha) * q / (q-1)
+    with respect to q >= 1.
+    """
+
+    def __init__(self, cov):
+        self.cov = cov
+
+    def solve(self, v, alpha):
+        """
+        Return the optimal q for the given v and alpha.
+
+        The stationary point is 1 + sqrt(-2 * log(alpha) / v^T cov v).
+        """
+        mv = _quad_form(v, self.cov)
+        return 1 + jnp.sqrt(-2 * jnp.log(alpha) / mv)
+
+
+class TileForwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * (q-1) * max_v v^T cov v - log(f0) / q
+    with respect to q >= 1.
+    """
+
+    def __init__(self, cov):
+        self.cov = cov
+
+    def solve(self, vs, f0):
+        """
+        Return the optimal q for the given vs and f0.
+
+        The quadratic form is evaluated for each v along the leading axis
+        of vs and the largest value is used.
+        """
+        logf0 = jnp.log(f0)
+        mv = jnp.max(jax.vmap(_quad_form, in_axes=(0, None))(vs, self.cov))
+        q_opt = jnp.sqrt(-2 * logf0 / mv)
+        return jnp.maximum(q_opt, 1)
+
+
+class TileBackwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * q * max_v v^T cov v - log(alpha) * q / (q-1)
+    with respect to q >= 1.
+    """
+
+    def __init__(self, cov):
+        self.cov = cov
+
+    def solve(self, vs, alpha):
+        """
+        Return the optimal q for the given vs and alpha.
+
+        The quadratic form is evaluated for each v along the leading axis
+        of vs and the largest value is used.
+        """
+        mv = jnp.max(jax.vmap(_quad_form, in_axes=(0, None))(vs, self.cov))
+        return 1 + jnp.sqrt(-2 * jnp.log(alpha) / mv)
+
+
+def tilt_bound_fwd(q, cov, v, f0):
+    """Return f0 ** (1 - 1/q) * exp(0.5 * (q-1) * v^T cov v)."""
+    p_inv = 1 - 1 / q
+    expo = 0.5 * (q - 1) * _quad_form(v, cov)
+    return f0**p_inv * jnp.exp(expo)
+
+
+def tilt_bound_fwd_tile(q, cov, vs, f0):
+    """Same as tilt_bound_fwd, with the largest exponent over vs."""
+
+    def _compute_expo(v):
+        return 0.5 * (q - 1) * _quad_form(v, cov)
+
+    p_inv = 1 - 1 / q
+    max_expo = jnp.max(jax.vmap(_compute_expo, in_axes=(0,))(vs))
+    return f0**p_inv * jnp.exp(max_expo)
+
+
+def tilt_bound_bwd(q, cov, v, alpha):
+    """Return (alpha * exp(-0.5 * (q-1) * v^T cov v)) ** (q / (q-1))."""
+    p = 1 / (1 - 1 / q)
+    expo = 0.5 * (q - 1) * _quad_form(v, cov)
+    return (alpha * jnp.exp(-expo)) ** p
+
+
+def tilt_bound_bwd_tile(q, cov, vs, alpha):
+    """Same as tilt_bound_bwd, with the largest exponent over vs."""
+
+    def _compute_expo(v):
+        return 0.5 * (q - 1) * _quad_form(v, cov)
+
+    p = 1 / (1 - 1 / q)
+    max_expo = jnp.max(jax.vmap(_compute_expo, in_axes=(0,))(vs))
+    return (alpha * jnp.exp(-max_expo)) ** p
diff --git a/confirm/mini_imprint/bound/normal.py b/confirm/mini_imprint/bound/normal.py
new file mode 100644
index 00000000..c1b11a76
--- /dev/null
+++ b/confirm/mini_imprint/bound/normal.py
@@ -0,0 +1,114 @@
+import jax.numpy as jnp
+
+
+class ForwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * (q-1) * s_sq * v ** 2 - log(f0) / q
+    with respect to q >= 1, where s_sq = scale ** 2.
+    """
+
+    def __init__(self, scale):
+        self.scale = scale
+
+    def solve(self, v, f0):
+        """
+        Return the optimal q for the given v and f0.
+
+        The stationary point sqrt(-2 * log(f0)) / (scale * |v|) is clipped
+        from below at 1 to satisfy the constraint.
+        """
+        logf0 = jnp.log(f0)
+        mv_sqrt = self.scale * jnp.abs(v)
+        q_opt = jnp.sqrt(-2 * logf0) / mv_sqrt
+        return jnp.maximum(q_opt, 1)
+
+
+class BackwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * q * s_sq * v ** 2 - log(alpha) * q / (q-1)
+    with respect to q >= 1, where s_sq = scale ** 2.
+    """
+
+    def __init__(self, scale):
+        self.scale = scale
+
+    def solve(self, v, alpha):
+        """
+        Return the optimal q for the given v and alpha.
+
+        The stationary point is 1 + sqrt(-2 * log(alpha)) / (scale * |v|).
+        """
+        mv_sqrt = self.scale * jnp.abs(v)
+        return 1 + jnp.sqrt(-2 * jnp.log(alpha)) / mv_sqrt
+
+
+class TileForwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * (q-1) * s_sq * max_v v ** 2 - log(f0) / q
+    with respect to q >= 1, where s_sq = scale ** 2.
+    """
+
+    def __init__(self, scale):
+        self.scale = scale
+
+    def solve(self, vs, f0):
+        """
+        Return the optimal q for the given vs and f0.
+
+        Only the entry of vs with the largest absolute value matters.
+        """
+        logf0 = jnp.log(f0)
+        mv_sqrt = self.scale * jnp.max(jnp.abs(vs))
+        q_opt = jnp.sqrt(-2 * logf0) / mv_sqrt
+        return jnp.maximum(q_opt, 1)
+
+
+class TileBackwardQCPSolver:
+    """
+    Solves the minimization problem:
+        0.5 * q * s_sq * max_v v ** 2 - log(alpha) * q / (q-1)
+    with respect to q >= 1, where s_sq = scale ** 2.
+    """
+
+    def __init__(self, scale):
+        self.scale = scale
+
+    def solve(self, vs, alpha):
+        """
+        Return the optimal q for the given vs and alpha.
+
+        Only the entry of vs with the largest absolute value matters.
+        """
+        mv_sqrt = self.scale * jnp.max(jnp.abs(vs))
+        return 1 + jnp.sqrt(-2 * jnp.log(alpha)) / mv_sqrt
+
+
+def tilt_bound_fwd(q, scale, v, f0):
+    """Return f0 ** (1 - 1/q) * exp(0.5 * (q-1) * (scale * v) ** 2)."""
+    p_inv = 1 - 1 / q
+    expo = 0.5 * (q - 1) * (scale * v) ** 2
+    return f0**p_inv * jnp.exp(expo)
+
+
+def tilt_bound_fwd_tile(q, scale, vs, f0):
+    """Same as tilt_bound_fwd, with v replaced by max(|vs|)."""
+    p_inv = 1 - 1 / q
+    max_expo = 0.5 * (q - 1) * (scale * jnp.max(jnp.abs(vs))) ** 2
+    return f0**p_inv * jnp.exp(max_expo)
+
+
+def tilt_bound_bwd(q, scale, v, alpha):
+    """Return (alpha * exp(-0.5 * (q-1) * (scale * v) ** 2)) ** (q / (q-1))."""
+    p = 1 / (1 - 1 / q)
+    expo = 0.5 * (q - 1) * (scale * v) ** 2
+    return (alpha * jnp.exp(-expo)) ** p
+
+
+def tilt_bound_bwd_tile(q, scale, vs, alpha):
+    """Same as tilt_bound_bwd, with v replaced by max(|vs|)."""
+    p = 1 / (1 - 1 / q)
+    max_expo = 0.5 * (q - 1) * (scale * jnp.max(jnp.abs(vs))) ** 2
+    return (alpha * jnp.exp(-max_expo)) ** p
diff --git a/tests/mini_imprint/bound/test_normal.py b/tests/mini_imprint/bound/test_normal.py
new file mode 100644
index 00000000..ce974cca
--- /dev/null
+++ b/tests/mini_imprint/bound/test_normal.py
@@ -0,0 +1,85 @@
+import numpy as np
+
+import confirm.mini_imprint.bound.normal as normal
+
+
+def fwd_qcp_derivative(q, scale, v, f0):
+    """Derivative in q of the ForwardQCPSolver objective."""
+    return 0.5 * (scale * v) ** 2 + np.log(f0) / q**2
+
+
+def bwd_qcp_derivative(q, scale, v, alpha):
+    """Derivative in q of the BackwardQCPSolver objective."""
+    return 0.5 * (scale * v) ** 2 + np.log(alpha) / (q - 1) ** 2
+
+
+def tile_fwd_qcp_derivative(q, scale, vs, f0):
+    """Derivative in q of the TileForwardQCPSolver objective."""
+    mv = np.max((scale * vs) ** 2)
+    return 0.5 * mv + np.log(f0) / q**2
+
+
+def tile_bwd_qcp_derivative(q, scale, vs, alpha):
+    """Derivative in q of the TileBackwardQCPSolver objective."""
+    mv = np.max((scale * vs) ** 2)
+    return 0.5 * mv + np.log(alpha) / (q - 1) ** 2
+
+
+def test_fwd_qcp_solver():
+    scale = 2.0
+    v = -0.321
+    f0 = 0.025
+    fwd_solver = normal.ForwardQCPSolver(scale)
+    q_opt = fwd_solver.solve(v, f0)
+    q_opt_deriv = fwd_qcp_derivative(q_opt, scale, v, f0)
+    np.testing.assert_almost_equal(q_opt_deriv, 0.0)
+
+
+def test_bwd_qcp_solver():
+    scale = 2.0
+    v = -0.321
+    alpha = 0.025
+    bwd_solver = normal.BackwardQCPSolver(scale)
+    q_opt = bwd_solver.solve(v, alpha)
+    q_opt_deriv = bwd_qcp_derivative(q_opt, scale, v, alpha)
+    np.testing.assert_almost_equal(q_opt_deriv, 0.0)
+
+
+def test_tile_fwd_qcp_solver():
+    scale = 3.2
+    vs = np.array([-0.1, 0.2])
+    f0 = 0.025
+    fwd_solver = normal.TileForwardQCPSolver(scale)
+    q_opt = fwd_solver.solve(vs, f0)
+    q_opt_deriv = tile_fwd_qcp_derivative(q_opt, scale, vs, f0)
+    np.testing.assert_almost_equal(q_opt_deriv, 0.0)
+
+
+def test_tile_bwd_qcp_solver():
+    scale = 1.2
+    vs = np.array([-0.3, 0.1])
+    alpha = 0.025
+    bwd_solver = normal.TileBackwardQCPSolver(scale)
+    q_opt = bwd_solver.solve(vs, alpha)
+    q_opt_deriv = tile_bwd_qcp_derivative(q_opt, scale, vs, alpha)
+    np.testing.assert_almost_equal(q_opt_deriv, 0.0)
+
+
+def test_fwd_bwd_invariance():
+    scale = 2.0
+    v = -0.321
+    f0 = 0.025
+    q = 3.2
+    fwd_bound = normal.tilt_bound_fwd(q, scale, v, f0)
+    bwd_bound = normal.tilt_bound_bwd(q, scale, v, fwd_bound)
+    np.testing.assert_almost_equal(bwd_bound, f0)
+
+
+def test_tile_fwd_bwd_invariance():
+    scale = 1.2
+    vs = np.array([-0.3, 0.1])
+    f0 = 0.025
+    q = 5.1
+    fwd_bound = normal.tilt_bound_fwd_tile(q, scale, vs, f0)
+    bwd_bound = normal.tilt_bound_bwd_tile(q, scale, vs, fwd_bound)
+    np.testing.assert_almost_equal(bwd_bound, f0)