-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add gaussian tilt bound * Add uni and multi normal * Add uni normal test
- Loading branch information
1 parent
29834a5
commit d1adfc3
Showing
3 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
|
||
def _quad_form(v, A): | ||
return v.dot(A @ v) | ||
|
||
|
||
class ForwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * (q-1) * v^T cov v - log(f0) / q | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, cov): | ||
self.cov = cov | ||
|
||
def solve(self, v, f0): | ||
logf0 = jnp.log(f0) | ||
mv = _quad_form(v, self.cov) | ||
q_opt = jnp.sqrt(-2 * logf0 / mv) | ||
return jnp.maximum(q_opt, 1) | ||
|
||
|
||
class BackwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * q * v^T cov v - log(alpha) * q / (q-1) | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, cov): | ||
self.cov = cov | ||
|
||
def solve(self, v, alpha): | ||
mv = _quad_form(v, self.cov) | ||
return 1 + jnp.sqrt(-2 * jnp.log(alpha) / mv) | ||
|
||
|
||
class TileForwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * (q-1) * max_v v^T cov v - log(f0) / q | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, cov): | ||
self.cov = cov | ||
|
||
def solve(self, vs, f0): | ||
logf0 = jnp.log(f0) | ||
mv = jnp.max(jax.vmap(_quad_form, in_axes=(0, None))(vs, self.cov)) | ||
q_opt = jnp.sqrt(-2 * logf0 / mv) | ||
return jnp.maximum(q_opt, 1) | ||
|
||
|
||
class TileBackwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * q * max_v v^T cov v - log(alpha) * q / (q-1) | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, cov): | ||
self.cov = cov | ||
|
||
def solve(self, vs, alpha): | ||
mv = jnp.max(jax.vmap(_quad_form, in_axes=(0, None))(vs, self.cov)) | ||
return 1 + jnp.sqrt(-2 * jnp.log(alpha) / mv) | ||
|
||
|
||
def tilt_bound_fwd(q, cov, v, f0): | ||
p_inv = 1 - 1 / q | ||
expo = 0.5 * (q - 1) * _quad_form(v, cov) | ||
return f0**p_inv * jnp.exp(expo) | ||
|
||
|
||
def tilt_bound_fwd_tile(q, cov, vs, f0): | ||
def _compute_expo(v): | ||
return 0.5 * (q - 1) * _quad_form(v, cov) | ||
|
||
p_inv = 1 - 1 / q | ||
max_expo = jnp.max(jax.vmap(_compute_expo, in_axes=(0,))(vs)) | ||
return f0**p_inv * jnp.exp(max_expo) | ||
|
||
|
||
def tilt_bound_bwd(q, cov, v, alpha): | ||
p = 1 / (1 - 1 / q) | ||
expo = 0.5 * (q - 1) * _quad_form(v, cov) | ||
return (alpha * jnp.exp(-expo)) ** p | ||
|
||
|
||
def tilt_bound_bwd_tile(q, cov, vs, alpha): | ||
def _compute_expo(v): | ||
return 0.5 * (q - 1) * _quad_form(v, cov) | ||
|
||
p = 1 / (1 - 1 / q) | ||
max_expo = jnp.max(jax.vmap(_compute_expo, in_axes=(0,))(vs)) | ||
return (alpha * jnp.exp(-max_expo)) ** p |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import jax.numpy as jnp | ||
|
||
|
||
class ForwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * (q-1) * s_sq * v ** 2 - log(f0) / q | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, scale): | ||
self.scale = scale | ||
|
||
def solve(self, v, f0): | ||
logf0 = jnp.log(f0) | ||
mv_sqrt = self.scale * jnp.abs(v) | ||
q_opt = jnp.sqrt(-2 * logf0) / mv_sqrt | ||
return jnp.maximum(q_opt, 1) | ||
|
||
|
||
class BackwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * q * s_sq * v ** 2 - log(alpha) * q / (q-1) | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, scale): | ||
self.scale = scale | ||
|
||
def solve(self, v, alpha): | ||
mv_sqrt = self.scale * jnp.abs(v) | ||
return 1 + jnp.sqrt(-2 * jnp.log(alpha)) / mv_sqrt | ||
|
||
|
||
class TileForwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * (q-1) * s_sq * max_v v ** 2 - log(f0) / q | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, scale): | ||
self.scale = scale | ||
|
||
def solve(self, vs, f0): | ||
logf0 = jnp.log(f0) | ||
mv_sqrt = self.scale * jnp.max(jnp.abs(vs)) | ||
q_opt = jnp.sqrt(-2 * logf0) / mv_sqrt | ||
return jnp.maximum(q_opt, 1) | ||
|
||
|
||
class TileBackwardQCPSolver: | ||
""" | ||
Solves the minimization problem: | ||
0.5 * q * s_sq * max_v v ** 2 - log(alpha) * q / (q-1) | ||
with respect to q >= 1. | ||
""" | ||
|
||
def __init__(self, scale): | ||
self.scale = scale | ||
|
||
def solve(self, vs, alpha): | ||
mv_sqrt = self.scale * jnp.max(jnp.abs(vs)) | ||
return 1 + jnp.sqrt(-2 * jnp.log(alpha)) / mv_sqrt | ||
|
||
|
||
def tilt_bound_fwd(q, scale, v, f0): | ||
p_inv = 1 - 1 / q | ||
expo = 0.5 * (q - 1) * (scale * v) ** 2 | ||
return f0**p_inv * jnp.exp(expo) | ||
|
||
|
||
def tilt_bound_fwd_tile(q, scale, vs, f0): | ||
p_inv = 1 - 1 / q | ||
max_expo = 0.5 * (q - 1) * (scale * jnp.max(jnp.abs(vs))) ** 2 | ||
return f0**p_inv * jnp.exp(max_expo) | ||
|
||
|
||
def tilt_bound_bwd(q, scale, v, alpha): | ||
p = 1 / (1 - 1 / q) | ||
expo = 0.5 * (q - 1) * (scale * v) ** 2 | ||
return (alpha * jnp.exp(-expo)) ** p | ||
|
||
|
||
def tilt_bound_bwd_tile(q, scale, vs, alpha): | ||
p = 1 / (1 - 1 / q) | ||
max_expo = 0.5 * (q - 1) * (scale * jnp.max(jnp.abs(vs))) ** 2 | ||
return (alpha * jnp.exp(-max_expo)) ** p |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import numpy as np | ||
|
||
import confirm.mini_imprint.bound.normal as normal | ||
|
||
|
||
def fwd_qcp_derivative(q, scale, v, f0): | ||
return 0.5 * (scale * v) ** 2 + np.log(f0) / q**2 | ||
|
||
|
||
def bwd_qcp_derivative(q, scale, v, alpha): | ||
return 0.5 * (scale * v) ** 2 + np.log(alpha) / (q - 1) ** 2 | ||
|
||
|
||
def tile_fwd_qcp_derivative(q, scale, vs, f0): | ||
mv = np.max((scale * vs) ** 2) | ||
return 0.5 * mv + np.log(f0) / q**2 | ||
|
||
|
||
def tile_bwd_qcp_derivative(q, scale, vs, alpha): | ||
mv = np.max((scale * vs) ** 2) | ||
return 0.5 * mv + np.log(alpha) / (q - 1) ** 2 | ||
|
||
|
||
def test_fwd_qcp_solver(): | ||
scale = 2.0 | ||
v = -0.321 | ||
f0 = 0.025 | ||
fwd_solver = normal.ForwardQCPSolver(scale) | ||
q_opt = fwd_solver.solve(v, f0) | ||
q_opt_deriv = fwd_qcp_derivative(q_opt, scale, v, f0) | ||
np.testing.assert_almost_equal(q_opt_deriv, 0.0) | ||
|
||
|
||
def test_bwd_qcp_solver(): | ||
scale = 2.0 | ||
v = -0.321 | ||
alpha = 0.025 | ||
bwd_solver = normal.BackwardQCPSolver(scale) | ||
q_opt = bwd_solver.solve(v, alpha) | ||
q_opt_deriv = bwd_qcp_derivative(q_opt, scale, v, alpha) | ||
np.testing.assert_almost_equal(q_opt_deriv, 0.0) | ||
|
||
|
||
def test_tile_fwd_qcp_solver(): | ||
scale = 3.2 | ||
vs = np.array([-0.1, 0.2]) | ||
f0 = 0.025 | ||
fwd_solver = normal.TileForwardQCPSolver(scale) | ||
q_opt = fwd_solver.solve(vs, f0) | ||
q_opt_deriv = tile_fwd_qcp_derivative(q_opt, scale, vs, f0) | ||
print(q_opt) | ||
np.testing.assert_almost_equal(q_opt_deriv, 0.0) | ||
|
||
|
||
def test_tile_bwd_qcp_solver(): | ||
scale = 1.2 | ||
vs = np.array([-0.3, 0.1]) | ||
alpha = 0.025 | ||
bwd_solver = normal.TileBackwardQCPSolver(scale) | ||
q_opt = bwd_solver.solve(vs, alpha) | ||
q_opt_deriv = tile_bwd_qcp_derivative(q_opt, scale, vs, alpha) | ||
np.testing.assert_almost_equal(q_opt_deriv, 0.0) | ||
|
||
|
||
def test_fwd_bwd_invariance(): | ||
scale = 2.0 | ||
v = -0.321 | ||
f0 = 0.025 | ||
q = 3.2 | ||
fwd_bound = normal.tilt_bound_fwd(q, scale, v, f0) | ||
bwd_bound = normal.tilt_bound_bwd(q, scale, v, fwd_bound) | ||
np.testing.assert_almost_equal(bwd_bound, f0) | ||
|
||
|
||
def test_tile_fwd_bwd_invariance(): | ||
scale = 1.2 | ||
vs = np.array([-0.3, 0.1]) | ||
f0 = 0.025 | ||
q = 5.1 | ||
fwd_bound = normal.tilt_bound_fwd_tile(q, scale, vs, f0) | ||
bwd_bound = normal.tilt_bound_bwd_tile(q, scale, vs, fwd_bound) | ||
np.testing.assert_almost_equal(bwd_bound, f0) |