Skip to content

Commit

Permalink
James.yang/gaussian bound (#145)
Browse files Browse the repository at this point in the history
* Add gaussian tilt bound

* Add uni and multi normal

* Add uni normal test
  • Loading branch information
JamesYang007 authored Nov 7, 2022
1 parent 29834a5 commit d1adfc3
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 0 deletions.
100 changes: 100 additions & 0 deletions confirm/mini_imprint/bound/multivariate_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import jax
import jax.numpy as jnp


def _quad_form(v, A):
return v.dot(A @ v)


class ForwardQCPSolver:
"""
Solves the minimization problem:
0.5 * (q-1) * v^T cov v - log(f0) / q
with respect to q >= 1.
"""

def __init__(self, cov):
self.cov = cov

def solve(self, v, f0):
logf0 = jnp.log(f0)
mv = _quad_form(v, self.cov)
q_opt = jnp.sqrt(-2 * logf0 / mv)
return jnp.maximum(q_opt, 1)


class BackwardQCPSolver:
"""
Solves the minimization problem:
0.5 * q * v^T cov v - log(alpha) * q / (q-1)
with respect to q >= 1.
"""

def __init__(self, cov):
self.cov = cov

def solve(self, v, alpha):
mv = _quad_form(v, self.cov)
return 1 + jnp.sqrt(-2 * jnp.log(alpha) / mv)


class TileForwardQCPSolver:
"""
Solves the minimization problem:
0.5 * (q-1) * max_v v^T cov v - log(f0) / q
with respect to q >= 1.
"""

def __init__(self, cov):
self.cov = cov

def solve(self, vs, f0):
logf0 = jnp.log(f0)
mv = jnp.max(jax.vmap(_quad_form, in_axes=(0, None))(vs, self.cov))
q_opt = jnp.sqrt(-2 * logf0 / mv)
return jnp.maximum(q_opt, 1)


class TileBackwardQCPSolver:
"""
Solves the minimization problem:
0.5 * q * max_v v^T cov v - log(alpha) * q / (q-1)
with respect to q >= 1.
"""

def __init__(self, cov):
self.cov = cov

def solve(self, vs, alpha):
mv = jnp.max(jax.vmap(_quad_form, in_axes=(0, None))(vs, self.cov))
return 1 + jnp.sqrt(-2 * jnp.log(alpha) / mv)


def tilt_bound_fwd(q, cov, v, f0):
p_inv = 1 - 1 / q
expo = 0.5 * (q - 1) * _quad_form(v, cov)
return f0**p_inv * jnp.exp(expo)


def tilt_bound_fwd_tile(q, cov, vs, f0):
def _compute_expo(v):
return 0.5 * (q - 1) * _quad_form(v, cov)

p_inv = 1 - 1 / q
max_expo = jnp.max(jax.vmap(_compute_expo, in_axes=(0,))(vs))
return f0**p_inv * jnp.exp(max_expo)


def tilt_bound_bwd(q, cov, v, alpha):
p = 1 / (1 - 1 / q)
expo = 0.5 * (q - 1) * _quad_form(v, cov)
return (alpha * jnp.exp(-expo)) ** p


def tilt_bound_bwd_tile(q, cov, vs, alpha):
def _compute_expo(v):
return 0.5 * (q - 1) * _quad_form(v, cov)

p = 1 / (1 - 1 / q)
max_expo = jnp.max(jax.vmap(_compute_expo, in_axes=(0,))(vs))
return (alpha * jnp.exp(-max_expo)) ** p
89 changes: 89 additions & 0 deletions confirm/mini_imprint/bound/normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import jax.numpy as jnp


class ForwardQCPSolver:
"""
Solves the minimization problem:
0.5 * (q-1) * s_sq * v ** 2 - log(f0) / q
with respect to q >= 1.
"""

def __init__(self, scale):
self.scale = scale

def solve(self, v, f0):
logf0 = jnp.log(f0)
mv_sqrt = self.scale * jnp.abs(v)
q_opt = jnp.sqrt(-2 * logf0) / mv_sqrt
return jnp.maximum(q_opt, 1)


class BackwardQCPSolver:
"""
Solves the minimization problem:
0.5 * q * s_sq * v ** 2 - log(alpha) * q / (q-1)
with respect to q >= 1.
"""

def __init__(self, scale):
self.scale = scale

def solve(self, v, alpha):
mv_sqrt = self.scale * jnp.abs(v)
return 1 + jnp.sqrt(-2 * jnp.log(alpha)) / mv_sqrt


class TileForwardQCPSolver:
"""
Solves the minimization problem:
0.5 * (q-1) * s_sq * max_v v ** 2 - log(f0) / q
with respect to q >= 1.
"""

def __init__(self, scale):
self.scale = scale

def solve(self, vs, f0):
logf0 = jnp.log(f0)
mv_sqrt = self.scale * jnp.max(jnp.abs(vs))
q_opt = jnp.sqrt(-2 * logf0) / mv_sqrt
return jnp.maximum(q_opt, 1)


class TileBackwardQCPSolver:
"""
Solves the minimization problem:
0.5 * q * s_sq * max_v v ** 2 - log(alpha) * q / (q-1)
with respect to q >= 1.
"""

def __init__(self, scale):
self.scale = scale

def solve(self, vs, alpha):
mv_sqrt = self.scale * jnp.max(jnp.abs(vs))
return 1 + jnp.sqrt(-2 * jnp.log(alpha)) / mv_sqrt


def tilt_bound_fwd(q, scale, v, f0):
p_inv = 1 - 1 / q
expo = 0.5 * (q - 1) * (scale * v) ** 2
return f0**p_inv * jnp.exp(expo)


def tilt_bound_fwd_tile(q, scale, vs, f0):
p_inv = 1 - 1 / q
max_expo = 0.5 * (q - 1) * (scale * jnp.max(jnp.abs(vs))) ** 2
return f0**p_inv * jnp.exp(max_expo)


def tilt_bound_bwd(q, scale, v, alpha):
p = 1 / (1 - 1 / q)
expo = 0.5 * (q - 1) * (scale * v) ** 2
return (alpha * jnp.exp(-expo)) ** p


def tilt_bound_bwd_tile(q, scale, vs, alpha):
p = 1 / (1 - 1 / q)
max_expo = 0.5 * (q - 1) * (scale * jnp.max(jnp.abs(vs))) ** 2
return (alpha * jnp.exp(-max_expo)) ** p
82 changes: 82 additions & 0 deletions tests/mini_imprint/bound/test_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np

import confirm.mini_imprint.bound.normal as normal


def fwd_qcp_derivative(q, scale, v, f0):
return 0.5 * (scale * v) ** 2 + np.log(f0) / q**2


def bwd_qcp_derivative(q, scale, v, alpha):
return 0.5 * (scale * v) ** 2 + np.log(alpha) / (q - 1) ** 2


def tile_fwd_qcp_derivative(q, scale, vs, f0):
mv = np.max((scale * vs) ** 2)
return 0.5 * mv + np.log(f0) / q**2


def tile_bwd_qcp_derivative(q, scale, vs, alpha):
mv = np.max((scale * vs) ** 2)
return 0.5 * mv + np.log(alpha) / (q - 1) ** 2


def test_fwd_qcp_solver():
scale = 2.0
v = -0.321
f0 = 0.025
fwd_solver = normal.ForwardQCPSolver(scale)
q_opt = fwd_solver.solve(v, f0)
q_opt_deriv = fwd_qcp_derivative(q_opt, scale, v, f0)
np.testing.assert_almost_equal(q_opt_deriv, 0.0)


def test_bwd_qcp_solver():
scale = 2.0
v = -0.321
alpha = 0.025
bwd_solver = normal.BackwardQCPSolver(scale)
q_opt = bwd_solver.solve(v, alpha)
q_opt_deriv = bwd_qcp_derivative(q_opt, scale, v, alpha)
np.testing.assert_almost_equal(q_opt_deriv, 0.0)


def test_tile_fwd_qcp_solver():
scale = 3.2
vs = np.array([-0.1, 0.2])
f0 = 0.025
fwd_solver = normal.TileForwardQCPSolver(scale)
q_opt = fwd_solver.solve(vs, f0)
q_opt_deriv = tile_fwd_qcp_derivative(q_opt, scale, vs, f0)
print(q_opt)
np.testing.assert_almost_equal(q_opt_deriv, 0.0)


def test_tile_bwd_qcp_solver():
scale = 1.2
vs = np.array([-0.3, 0.1])
alpha = 0.025
bwd_solver = normal.TileBackwardQCPSolver(scale)
q_opt = bwd_solver.solve(vs, alpha)
q_opt_deriv = tile_bwd_qcp_derivative(q_opt, scale, vs, alpha)
np.testing.assert_almost_equal(q_opt_deriv, 0.0)


def test_fwd_bwd_invariance():
scale = 2.0
v = -0.321
f0 = 0.025
q = 3.2
fwd_bound = normal.tilt_bound_fwd(q, scale, v, f0)
bwd_bound = normal.tilt_bound_bwd(q, scale, v, fwd_bound)
np.testing.assert_almost_equal(bwd_bound, f0)


def test_tile_fwd_bwd_invariance():
scale = 1.2
vs = np.array([-0.3, 0.1])
f0 = 0.025
q = 5.1
fwd_bound = normal.tilt_bound_fwd_tile(q, scale, vs, f0)
bwd_bound = normal.tilt_bound_bwd_tile(q, scale, vs, fwd_bound)
np.testing.assert_almost_equal(bwd_bound, f0)

0 comments on commit d1adfc3

Please sign in to comment.