Skip to content

Commit

Permalink
Fix Wrong implementation of u_t(x|z) in Brownian Bridge (#600)
Browse files Browse the repository at this point in the history
* refactor naming of x0 x1 instead of src tgt

* change genot and otfm implementations to fit the new flow changes
  • Loading branch information
selmanozleyen authored Nov 28, 2024
1 parent 2ffd45f commit 561adbb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 18 deletions.
51 changes: 35 additions & 16 deletions src/ott/neural/methods/flows/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, sigma: float):

@abc.abstractmethod
def compute_mu_t(
self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray
) -> jnp.ndarray:
"""Compute the mean of the probability path.
Expand All @@ -45,8 +45,8 @@ def compute_mu_t(
Args:
t: Time :math:`t` of shape ``[batch, 1]``.
src: Sample from the source distribution of shape ``[batch, ...]``.
tgt: Sample from the target distribution of shape ``[batch, ...]``.
x0: Sample from the source distribution of shape ``[batch, ...]``.
x1: Sample from the target distribution of shape ``[batch, ...]``.
"""

@abc.abstractmethod
Expand All @@ -62,7 +62,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray:

@abc.abstractmethod
def compute_ut(
self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray
) -> jnp.ndarray:
"""Evaluate the conditional vector field.
Expand All @@ -71,15 +71,16 @@ def compute_ut(
Args:
t: Time :math:`t` of shape ``[batch, 1]``.
src: Sample from the source distribution of shape ``[batch, ...]``.
tgt: Sample from the target distribution of shape ``[batch, ...]``.
x: Current position :math:`x` of shape ``[batch, ...]``.
x0: Source position :math:`x_0` of shape ``[batch, ...]``.
x1: Target position :math:`x_1` of shape ``[batch, ...]``.
Returns:
Conditional vector field evaluated at time :math:`t`.
"""

def compute_xt(
self, rng: jax.Array, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray
) -> jnp.ndarray:
"""Sample from the probability path.
Expand All @@ -89,15 +90,15 @@ def compute_xt(
Args:
rng: Random number generator.
t: Time :math:`t` of shape ``[batch, 1]``.
src: Sample from the source distribution of shape ``[batch, ...]``.
tgt: Sample from the target distribution of shape ``[batch, ...]``.
x0: Sample from the source distribution of shape ``[batch, ...]``.
x1: Sample from the target distribution of shape ``[batch, ...]``.
Returns:
Samples from the probability path between :math:`x_0` and :math:`x_1`
at time :math:`t`.
"""
noise = jax.random.normal(rng, shape=src.shape)
mu_t = self.compute_mu_t(t, src, tgt)
noise = jax.random.normal(rng, shape=x0.shape)
mu_t = self.compute_mu_t(t, x0, x1)
sigma_t = self.compute_sigma_t(t)
return mu_t + sigma_t * noise

Expand All @@ -110,15 +111,15 @@ class StraightFlow(BaseFlow, abc.ABC):
"""

def compute_mu_t( # noqa: D102
self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray
) -> jnp.ndarray:
return (1.0 - t) * src + t * tgt
return (1.0 - t) * x0 + t * x1

def compute_ut( # noqa: D102
self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray
self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray
) -> jnp.ndarray:
del t
return tgt - src
del t, x
return x1 - x0


class ConstantNoiseFlow(StraightFlow):
Expand Down Expand Up @@ -162,3 +163,21 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray:
at time :math:`t`.
"""
return self.sigma * jnp.sqrt(t * (1.0 - t))

def compute_ut(
self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray
) -> jnp.ndarray:
r"""Compute the conditional vector field :math:`u_t(x|z)`.
Args:
t: Time :math:`t` of shape ``[batch, 1]``.
x: Current position :math:`x` of shape ``[batch, ...]``.
x0: Source position :math:`x_0` of shape ``[batch, ...]``.
x1: Target position :math:`x_1` of shape ``[batch, ...]``.
Returns:
The vector field :math:`u_t(x|z)` at time :math:`t`.
"""
drift_term = (1 - 2 * t) / (2 * t * (1 - t)) * (x - (t * x1 + (1 - t) * x0))
control_term = x1 - x0
return drift_term + control_term
2 changes: 1 addition & 1 deletion src/ott/neural/methods/flows/genot.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def loss_fn(
x_t,
cond,
rngs={"dropout": rng_dropout})
u_t = self.flow.compute_ut(time, latent, target)
u_t = self.flow.compute_ut(time, x_t, latent, target)

return jnp.mean((v_t - u_t) ** 2)

Expand Down
2 changes: 1 addition & 1 deletion src/ott/neural/methods/flows/otfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def loss_fn(
x_t,
source_conditions,
rngs={"dropout": rng_dropout})
u_t = self.flow.compute_ut(t, source, target)
u_t = self.flow.compute_ut(t, x_t, source, target)

return jnp.mean((v_t - u_t) ** 2)

Expand Down

0 comments on commit 561adbb

Please sign in to comment.