From 561adbbda9a88ea11d42d3262dbb5ce81bd482e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:53:19 +0100 Subject: [PATCH] Fix Wrong implementation of u_t(x|z) in Brownian Bridge (#600) * refactor naming of x0 x1 instead of src tgt * change genot and otfm implementations to fit the new flow changes --- src/ott/neural/methods/flows/dynamics.py | 51 ++++++++++++++++-------- src/ott/neural/methods/flows/genot.py | 2 +- src/ott/neural/methods/flows/otfm.py | 2 +- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/ott/neural/methods/flows/dynamics.py b/src/ott/neural/methods/flows/dynamics.py index 3ca60168c..fc3c54099 100644 --- a/src/ott/neural/methods/flows/dynamics.py +++ b/src/ott/neural/methods/flows/dynamics.py @@ -36,7 +36,7 @@ def __init__(self, sigma: float): @abc.abstractmethod def compute_mu_t( - self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray + self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray ) -> jnp.ndarray: """Compute the mean of the probability path. @@ -45,8 +45,8 @@ def compute_mu_t( Args: t: Time :math:`t` of shape ``[batch, 1]``. - src: Sample from the source distribution of shape ``[batch, ...]``. - tgt: Sample from the target distribution of shape ``[batch, ...]``. + x0: Sample from the source distribution of shape ``[batch, ...]``. + x1: Sample from the target distribution of shape ``[batch, ...]``. """ @abc.abstractmethod @@ -62,7 +62,7 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: @abc.abstractmethod def compute_ut( - self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray + self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray ) -> jnp.ndarray: """Evaluate the conditional vector field. @@ -71,15 +71,16 @@ def compute_ut( Args: t: Time :math:`t` of shape ``[batch, 1]``. - src: Sample from the source distribution of shape ``[batch, ...]``. - tgt: Sample from the target distribution of shape ``[batch, ...]``. + x: Current position :math:`x` of shape ``[batch, ...]``. + x0: Source position :math:`x_0` of shape ``[batch, ...]``. + x1: Target position :math:`x_1` of shape ``[batch, ...]``. Returns: Conditional vector field evaluated at time :math:`t`. """ def compute_xt( - self, rng: jax.Array, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray + self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray ) -> jnp.ndarray: """Sample from the probability path. @@ -89,15 +90,15 @@ def compute_xt( Args: rng: Random number generator. t: Time :math:`t` of shape ``[batch, 1]``. - src: Sample from the source distribution of shape ``[batch, ...]``. - tgt: Sample from the target distribution of shape ``[batch, ...]``. + x0: Sample from the source distribution of shape ``[batch, ...]``. + x1: Sample from the target distribution of shape ``[batch, ...]``. Returns: Samples from the probability path between :math:`x_0` and :math:`x_1` at time :math:`t`. """ - noise = jax.random.normal(rng, shape=src.shape) - mu_t = self.compute_mu_t(t, src, tgt) + noise = jax.random.normal(rng, shape=x0.shape) + mu_t = self.compute_mu_t(t, x0, x1) sigma_t = self.compute_sigma_t(t) return mu_t + sigma_t * noise @@ -110,15 +111,15 @@ class StraightFlow(BaseFlow, abc.ABC): """ def compute_mu_t( # noqa: D102 - self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray + self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray ) -> jnp.ndarray: - return (1.0 - t) * src + t * tgt + return (1.0 - t) * x0 + t * x1 def compute_ut( # noqa: D102 - self, t: jnp.ndarray, src: jnp.ndarray, tgt: jnp.ndarray + self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray ) -> jnp.ndarray: - del t - return tgt - src + del t, x + return x1 - x0 class ConstantNoiseFlow(StraightFlow): @@ -162,3 +163,21 @@ def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: at time :math:`t`. """ return self.sigma * jnp.sqrt(t * (1.0 - t)) + + def compute_ut( + self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray + ) -> jnp.ndarray: + r"""Compute the conditional vector field :math:`u_t(x|z)`. + + Args: + t: Time :math:`t` of shape ``[batch, 1]``. + x: Current position :math:`x` of shape ``[batch, ...]``. + x0: Source position :math:`x_0` of shape ``[batch, ...]``. + x1: Target position :math:`x_1` of shape ``[batch, ...]``. + + Returns: + The vector field :math:`u_t(x|z)` at time :math:`t`. + """ + drift_term = (1 - 2 * t) / (2 * t * (1 - t)) * (x - (t * x1 + (1 - t) * x0)) + control_term = x1 - x0 + return drift_term + control_term diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index 8b281b590..b3d71f192 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -139,7 +139,7 @@ def loss_fn( x_t, cond, rngs={"dropout": rng_dropout}) - u_t = self.flow.compute_ut(time, latent, target) + u_t = self.flow.compute_ut(time, x_t, latent, target) return jnp.mean((v_t - u_t) ** 2) diff --git a/src/ott/neural/methods/flows/otfm.py b/src/ott/neural/methods/flows/otfm.py index 1bd97cc42..68073ba9e 100644 --- a/src/ott/neural/methods/flows/otfm.py +++ b/src/ott/neural/methods/flows/otfm.py @@ -87,7 +87,7 @@ def loss_fn( x_t, source_conditions, rngs={"dropout": rng_dropout}) - u_t = self.flow.compute_ut(t, source, target) + u_t = self.flow.compute_ut(t, x_t, source, target) return jnp.mean((v_t - u_t) ** 2)