Skip to content

Commit

Permalink
* fix annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Aug 13, 2024
1 parent 95e4199 commit 1a2bb85
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions essm_jax/essm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Extended Gaussian State Space Model."""

import dataclasses
from typing import Callable, NamedTuple, Tuple, Union
from typing import Callable, NamedTuple, Tuple, Union, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -52,7 +52,7 @@ class InitialPrior(NamedTuple):
covariance: jax.Array # [latent_size, latent_size] The covariance of the initial state


def _efficient_add_scalar_diag(A: jax.Array, c: jax.Array | float) -> jax.Array:
def _efficient_add_scalar_diag(A: jax.Array, c: Union[jax.Array, float]) -> jax.Array:
"""
Efficiently add a scalar to the diagonal of a matrix.
Expand Down Expand Up @@ -116,7 +116,7 @@ def _transition_fn(z):

return JVPLinearOp(_transition_fn, more_outputs_than_inputs=False)

def get_observation_jacobian(self, t: jax.Array, observation_size: int | None = None) -> JVPLinearOp:
def get_observation_jacobian(self, t: jax.Array, observation_size: Optional[int] = None) -> JVPLinearOp:
def _observation_fn(z):
return self.observation_fn(z, t).mean()

Expand All @@ -133,7 +133,7 @@ def observation_matrix(self, z, t):
Hop = self.get_observation_jacobian(t)
return Hop(z).to_dense()

def sample(self, key, num_time: int, t0: jax.Array | int = 0) -> SampleResult:
def sample(self, key, num_time: int, t0: Union[jax.Array, int] = 0) -> SampleResult:
"""
Sample from the model.
Expand Down Expand Up @@ -175,7 +175,7 @@ def _sample_latents_op(latent, y):

return samples

def _check_shapes(self, observations: jax.Array, mask: jax.Array | None = None):
def _check_shapes(self, observations: jax.Array, mask: Optional[jax.Array] = None):
"""
Check the shapes of the observations and mask.
Expand All @@ -197,7 +197,7 @@ def _check_shapes(self, observations: jax.Array, mask: jax.Array | None = None):
raise ValueError('mask and observations must have the same length')

def forward_simulate(self, key: jax.Array, num_time: int,
observations: jax.Array, mask: jax.Array | None = None) -> SampleResult:
observations: jax.Array, mask: Optional[jax.Array] = None) -> SampleResult:
"""
Simulate from the model, from the end of the forward filtering pass.
Expand All @@ -224,9 +224,9 @@ def forward_simulate(self, key: jax.Array, num_time: int,
)
return new_essm.sample(key=key, num_time=num_time, t0=filter_result.t[-1])

def forward_filter(self, observations: jax.Array, mask: jax.Array | None = None,
def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = None,
marginal_likelihood_only: bool = False,
t0: jax.Array | int = 0) -> FilterResult | jax.Array:
t0: Union[jax.Array, int] = 0) -> FilterResult | jax.Array:
"""
Run the forward filtering pass, computing the total marginal likelihood
Expand Down Expand Up @@ -379,7 +379,7 @@ def _filter_op(carry: Carry, y: YType) -> Tuple[Carry, FilterResult]:
return final_accumulate.log_cumulative_marginal_likelihood
return filter_results

def log_prob(self, observations: jax.Array, mask: jax.Array | None = None) -> jax.Array:
def log_prob(self, observations: jax.Array, mask: Optional[jax.Array] = None) -> jax.Array:
"""
Compute the log probability of the observations under the model.
Expand All @@ -392,8 +392,8 @@ def log_prob(self, observations: jax.Array, mask: jax.Array | None = None) -> ja
"""
return self.forward_filter(observations, mask, marginal_likelihood_only=True)

def posterior_marginals(self, observations: jax.Array, mask: jax.Array | None = None,
t0: jax.Array | int = 0) -> Union[
def posterior_marginals(self, observations: jax.Array, mask: Optional[jax.Array] = None,
t0: Union[jax.Array, int] = 0) -> Union[
SmoothingResult, Tuple[SmoothingResult, InitialPrior]]:
"""
Compute the posterior marginal distributions of the latents, p(z[t] | x[:T]).
Expand Down

0 comments on commit 1a2bb85

Please sign in to comment.