diff --git a/examples/scripts/exp_UKF.py b/examples/scripts/exp_UKF.py
new file mode 100644
index 000000000..e4cf24f3e
--- /dev/null
+++ b/examples/scripts/exp_UKF.py
@@ -0,0 +1,116 @@
+import pybop
+import pybamm
+import numpy as np
+from examples.standalone.model import ExponentialDecay
+
+# Parameter set and model definition
+parameter_set = pybamm.ParameterValues({"k": "[input]", "y0": "[input]"})
+model = ExponentialDecay(parameter_set=parameter_set, n_states=1)
+x0 = np.array([0.1, 1.0])
+
+# Fitting parameters
+parameters = [
+    pybop.Parameter(
+        "k",
+        prior=pybop.Gaussian(0.1, 0.05),
+        bounds=[0, 1],
+    ),
+    pybop.Parameter(
+        "y0",
+        prior=pybop.Gaussian(1, 0.05),
+        bounds=[0, 3],
+    ),
+]
+
+# Verification: save fixed inputs for testing
+inputs = dict()
+for i, param in enumerate(parameters):
+    inputs[param.name] = x0[i]
+
+# Make a prediction with measurement noise
+sigma = 1e-2
+t_eval = np.linspace(0, 20, 10)
+values = model.predict(t_eval=t_eval, inputs=inputs)
+values = values["2y"].data
+corrupt_values = values + np.random.normal(0, sigma, len(t_eval))
+
+# Verification step: compute the analytical solution for 2y
+expected_values = 2 * inputs["y0"] * np.exp(-inputs["k"] * t_eval)
+
+# Verification step: make another prediction using the Observer class
+model.build(parameters=parameters)
+simulator = pybop.Observer(parameters, model, signal=["2y"], x0=x0)
+simulator._time_data = t_eval
+measurements = simulator.evaluate(x0)
+measurements = measurements[:, 0]
+
+# Verification step: Compare by plotting
+go = pybop.PlotlyManager().go
+line1 = go.Scatter(x=t_eval, y=corrupt_values, name="Corrupt values", mode="markers")
+line2 = go.Scatter(
+    x=t_eval, y=expected_values, name="Expected trajectory", mode="lines"
+)
+line3 = go.Scatter(x=t_eval, y=measurements, name="Observed values", mode="markers")
+fig = go.Figure(data=[line1, line2, line3])
+
+# Form dataset
+dataset = pybop.Dataset(
+    {
+        "Time [s]": t_eval,
+        "Current function [A]": 0 * t_eval,  # placeholder
+        "2y": corrupt_values,
+    }
+)
+
+# Build the model to get the number of states
+model.build(dataset=dataset.data, parameters=parameters)
+
+# Define the UKF observer
+signal = ["2y"]
+n_states = model.n_states
+n_signals = len(signal)
+covariance = np.diag([sigma**2] * n_states)
+process_noise = np.diag([1e-6] * n_states)
+measurement_noise = np.diag([sigma**2] * n_signals)
+observer = pybop.UnscentedKalmanFilterObserver(
+    parameters,
+    model,
+    covariance,
+    process_noise,
+    measurement_noise,
+    dataset,
+    signal=signal,
+    x0=x0,
+)
+
+# Verification step: Find the maximum likelihood estimate given the true parameters
+estimation = observer.evaluate(x0)
+estimation = estimation[:, 0]
+
+# Verification step: Add the estimate to the plot
+line4 = go.Scatter(x=t_eval, y=estimation, name="Estimated trajectory", mode="lines")
+fig.add_trace(line4)
+fig.show()
+
+# Generate problem, cost function, and optimisation class
+cost = pybop.ObserverCost(observer)
+optim = pybop.Optimisation(cost, optimiser=pybop.CMAES, verbose=True)
+
+# Run optimisation
+x, final_cost = optim.run()
+print("Estimated parameters:", x)
+
+# Plot the timeseries output (requires model that returns Voltage)
+pybop.quick_plot(x, cost, title="Optimised Comparison")
+
+# Plot convergence
+pybop.plot_convergence(optim)
+
+# Plot the parameter traces
+pybop.plot_parameters(optim)
+
+# Plot the cost landscape
+pybop.plot_cost2d(cost, steps=15)
+
+# Plot the cost landscape with optimisation path
+pybop.plot_cost2d(cost, optim=optim, steps=15)
diff --git a/examples/scripts/spm_UKF.py b/examples/scripts/spm_UKF.py
new file mode 100644
index 000000000..7128293cd
--- /dev/null
+++ b/examples/scripts/spm_UKF.py
@@ -0,0 +1,82 @@
+import pybop
+import numpy as np
+
+# Parameter set and model definition
+parameter_set = pybop.ParameterSet.pybamm("Chen2020")
+model = pybop.lithium_ion.SPM(parameter_set=parameter_set)
+
+# Fitting parameters
+parameters = [
+    pybop.Parameter(
+        "Negative electrode active material volume fraction",
+        prior=pybop.Gaussian(0.6, 0.05),
+        bounds=[0.5, 0.8],
+    ),
+    pybop.Parameter(
+        "Positive electrode active material volume fraction",
+        prior=pybop.Gaussian(0.48, 0.05),
+        bounds=[0.4, 0.7],
+    ),
+]
+
+# Make a prediction with measurement noise
+sigma = 0.001
+t_eval = np.arange(0, 300, 2)
+values = model.predict(t_eval=t_eval)
+corrupt_values = values["Voltage [V]"].data + np.random.normal(0, sigma, len(t_eval))
+
+# Form dataset
+dataset = pybop.Dataset(
+    {
+        "Time [s]": t_eval,
+        "Current function [A]": values["Current [A]"].data,
+        "Voltage [V]": corrupt_values,
+    }
+)
+
+# Build the model to get the number of states
+model.build(dataset=dataset.data, parameters=parameters)
+
+# Define the UKF observer, setting the particle boundaries as uncertain states
+signal = ["Voltage [V]"]
+n_states = model.n_states
+n_signals = len(signal)
+covariance = np.diag([0] * 19 + [sigma**2] + [0] * 19 + [sigma**2])
+process_noise = np.diag([0] * 19 + [1e-6] + [0] * 19 + [1e-6])
+measurement_noise = np.diag([sigma**2])
+observer = pybop.UnscentedKalmanFilterObserver(
+    parameters,
+    model,
+    covariance,
+    process_noise,
+    measurement_noise,
+    dataset,
+    signal=signal,
+)
+
+# Generate problem, cost function, and optimisation class
+cost = pybop.ObserverCost(observer)
+optim = pybop.Optimisation(cost, optimiser=pybop.PSO, verbose=True)
+
+# Parameter identification using the current observer implementation is very slow
+# so let's restrict the number of iterations and reduce the number of plots
+optim.set_max_iterations(5)
+
+# Run optimisation
+x, final_cost = optim.run()
+print("Estimated parameters:", x)
+
+# Plot the timeseries output (requires model that returns Voltage)
+pybop.quick_plot(x, cost, title="Optimised Comparison")
+
+# # Plot convergence
+# pybop.plot_convergence(optim)
+
+# # Plot the parameter traces
+# pybop.plot_parameters(optim)
+
+# # Plot the cost landscape
+# pybop.plot_cost2d(cost, steps=5)
+
+# # Plot the cost landscape with optimisation path
+# pybop.plot_cost2d(cost, optim=optim, steps=5)
diff --git a/examples/standalone/exponential_decay.py b/examples/standalone/model.py
similarity index 80%
rename from examples/standalone/exponential_decay.py
rename to examples/standalone/model.py
index 5bc5c9efb..9943ff3d8 100644
--- a/examples/standalone/exponential_decay.py
+++ b/examples/standalone/model.py
@@ -15,15 +15,15 @@ class ExponentialDecay(BaseModel):
     def __init__(
         self,
         name: str = "Constant Model",
-        parameters: pybamm.ParameterValues = None,
-        nstate: int = 1,
+        parameter_set: pybamm.ParameterValues = None,
+        n_states: int = 1,
     ):
         super().__init__()
-        self.nstate = nstate
-        if nstate < 1:
-            raise ValueError("nstate must be greater than 0")
+        self.n_states = n_states
+        if n_states < 1:
+            raise ValueError("The number of states (n_states) must be greater than 0")
         self.pybamm_model = pybamm.BaseModel()
-        ys = [pybamm.Variable(f"y_{i}") for i in range(nstate)]
+        ys = [pybamm.Variable(f"y_{i}") for i in range(n_states)]
         k = pybamm.Parameter("k")
         y0 = pybamm.Parameter("y0")
         self.pybamm_model.rhs = {y: -k * y for y in ys}
@@ -41,7 +41,7 @@ def __init__(
         self.name = name
 
         self.default_parameter_values = (
-            default_parameter_values if parameters is None else parameters
+            default_parameter_values if parameter_set is None else parameter_set
         )
         self._parameter_set = self.default_parameter_values
         self._unprocessed_parameter_set = self._parameter_set
diff --git a/pybop/_costs.py b/pybop/_costs.py
index 9a6c6d347..4ab8a35c0 100644
--- a/pybop/_costs.py
+++ b/pybop/_costs.py
@@ -1,7 +1,6 @@
 import numpy as np
 
 from pybop.observers.observer import Observer
-from pybop._problem import FittingProblem
 
 
 class BaseCost:
@@ -26,6 +25,8 @@ class BaseCost:
         The bounds for the model parameters.
     n_parameters : int
         The number of parameters in the model.
+    n_outputs : int
+        The number of outputs in the model.
     """
 
     def __init__(self, problem):
@@ -35,6 +36,7 @@ def __init__(self, problem):
             self.x0 = problem.x0
             self.bounds = problem.bounds
             self.n_parameters = problem.n_parameters
+            self.n_outputs = problem.n_outputs
 
     def __call__(self, x, grad=None):
         """
@@ -255,13 +257,13 @@ def _evaluateS1(self, x):
         y, dy = self.problem.evaluateS1(x)
         if len(y) < len(self._target):
             e = np.float64(np.inf)
-            de = self._de * np.ones(self.problem.n_parameters)
+            de = self._de * np.ones(self.n_parameters)
         else:
             dy = dy.reshape(
                 (
                     self.problem.n_time_data,
-                    self.problem.n_outputs,
-                    self.problem.n_parameters,
+                    self.n_outputs,
+                    self.n_parameters,
                 )
             )
             r = y - self._target
@@ -297,11 +299,11 @@ class ObserverCost(BaseCost):
 
     """
 
-    def __init__(self, problem: FittingProblem, observer: Observer):
-        super(ObserverCost, self).__init__(problem)
+    def __init__(self, observer: Observer):
+        super().__init__(problem=observer)
         self._observer = observer
 
-    def __call__(self, x, grad=None):
+    def _evaluate(self, x, grad=None):
         """
         Calculate the observer cost for a given set of parameters.
 
@@ -317,16 +319,12 @@ def __call__(self, x, grad=None):
         -------
         float
             The observer cost (negative of the log likelihood).
-
         """
-        try:
-            inputs = {key: x[i] for i, key in enumerate(self._observer._model.fit_keys)}
-            log_likelihood = self._observer.log_likelihood(
-                self.problem.target(), self.problem.time_data(), inputs
-            )
-            return -log_likelihood
-        except Exception as e:
-            raise ValueError(f"Error in cost calculation: {e}")
+        inputs = {key: x[i] for i, key in enumerate(self._observer._model.fit_keys)}
+        log_likelihood = self._observer.log_likelihood(
+            self._target, self._observer.time_data(), inputs
+        )
+        return -log_likelihood
 
     def evaluateS1(self, x):
         """
diff --git a/pybop/_problem.py b/pybop/_problem.py
index 1c9e33ab7..87b933b04 100644
--- a/pybop/_problem.py
+++ b/pybop/_problem.py
@@ -1,8 +1,5 @@
-from typing import Optional
 import numpy as np
 
-from pybop.observers.observer import Observer
-
 
 class BaseProblem:
     """
@@ -16,6 +13,8 @@ class BaseProblem:
         The model to be used for the problem (default: None).
     check_model : bool, optional
         Flag to indicate if the model should be checked (default: True).
+    signal : List[str]
+        The signal to observe.
     init_soc : float, optional
         Initial state of charge (default: None).
     x0 : np.ndarray, optional
@@ -132,8 +131,8 @@ class FittingProblem(BaseProblem):
         The model to fit.
     parameters : list
         List of parameters for the problem.
-    dataset : list
-        List of data objects to fit the model to.
+    dataset : Dataset
+        Dataset object containing the data to fit the model to.
     signal : str, optional
         The signal to fit (default: "Voltage [V]").
     """
@@ -147,7 +146,6 @@ def __init__(
         signal=["Voltage [V]"],
         init_soc=None,
         x0=None,
-        observer: Optional[Observer] = None,
     ):
         super().__init__(parameters, model, check_model, signal, init_soc, x0)
         self._dataset = dataset.data
@@ -155,7 +153,7 @@ def __init__(
         # Check that the dataset contains time and current
         for name in ["Time [s]", "Current function [A]"] + self.signal:
             if name not in self._dataset:
-                raise ValueError(f"expected {name} in list of dataset")
+                raise ValueError(f"Expected {name} in list of dataset")
 
         self._time_data = self._dataset["Time [s]"]
         self.n_time_data = len(self._time_data)
diff --git a/pybop/models/base_model.py b/pybop/models/base_model.py
index 9887a6b21..c6192a946 100644
--- a/pybop/models/base_model.py
+++ b/pybop/models/base_model.py
@@ -113,6 +113,8 @@ def build(
             # Clear solver and setup model
             self._solver._model_set_up = {}
 
+        self.n_states = self._built_model.len_rhs_and_alg  # len_rhs + len_alg
+
     def set_init_soc(self, init_soc):
         """
         Set the initial state of charge for the battery model.
@@ -214,7 +216,7 @@ def step(self, state: TimeSeriesState, time: np.ndarray) -> TimeSeriesState:
             The time to predict the system to (in whatever time units the model is in)
         """
         dt = time - state.t
-        new_sol = self.solver.step(
+        new_sol = self._solver.step(
             state.sol, self.built_model, dt, npts=2, inputs=state.inputs, save=False
         )
         return TimeSeriesState(sol=new_sol, inputs=state.inputs, t=time)
@@ -252,7 +254,7 @@ def simulate(self, inputs, t_eval) -> np.ndarray[np.float64]:
                 inputs=inputs,
                 allow_infeasible_solutions=self.allow_infeasible_solutions,
             ):
-                sol = self.solver.solve(self.built_model, inputs=inputs, t_eval=t_eval)
+                sol = self._solver.solve(self.built_model, inputs=inputs, t_eval=t_eval)
 
                 predictions = [sol[signal].data for signal in self.signal]
 
@@ -295,7 +297,7 @@ def simulateS1(self, inputs, t_eval):
                 inputs=inputs,
                 allow_infeasible_solutions=self.allow_infeasible_solutions,
             ):
-                sol = self.solver.solve(
+                sol = self._solver.solve(
                     self.built_model,
                     inputs=inputs,
                     t_eval=t_eval,
diff --git a/pybop/observers/observer.py b/pybop/observers/observer.py
index 8f6592aad..0a93fe8a4 100644
--- a/pybop/observers/observer.py
+++ b/pybop/observers/observer.py
@@ -1,9 +1,11 @@
 from typing import List, Optional
 import numpy as np
+from pybop._problem import BaseProblem
 from pybop.models.base_model import BaseModel, Inputs, TimeSeriesState
+from pybop.parameters.parameter import Parameter
 
 
-class Observer(object):
+class Observer(BaseProblem):
     """
     An observer of a time series state. Observers:
      1. keep track of the distribution of a current time series model state
@@ -12,30 +14,45 @@ class Observer(object):
 
     Parameters
     ----------
+    parameters : list
+        List of parameters for the problem.
     model : BaseModel
       The model to observe.
-    inputs: Dict[str, float]
-      The inputs to the model.
+    check_model : bool, optional
+        Flag to indicate if the model should be checked (default: True).
     signal: List[str]
       The signal to observe.
+    init_soc : float, optional
+        Initial state of charge (default: None).
+    x0 : np.ndarray, optional
+        Initial parameter values (default: None).
     """
 
     # define a subtype for covariance matrices for use by derived classes
     Covariance = np.ndarray
 
-    def __init__(self, model: BaseModel, inputs: Inputs, signal: List[str]) -> None:
+    def __init__(
+        self,
+        parameters: List[Parameter],
+        model: BaseModel,
+        check_model=True,
+        signal=["Voltage [V]"],
+        init_soc=None,
+        x0=None,
+    ) -> None:
+        super().__init__(parameters, model, check_model, signal, init_soc, x0)
         if model._built_model is None:
             raise ValueError("Only built models can be used in Observers")
-        if not isinstance(inputs, dict):
-            raise ValueError("Inputs must be of type Dict[str, float]")
-        if not isinstance(signal, list):
-            raise ValueError("Signal must be of type List[str]")
-
         if model.signal is None:
-            model.signal = signal
+            model.signal = self.signal
+
+        inputs = dict()
+        for param in self.parameters:
+            inputs[param.name] = param.value
+
         self._state = model.reinit(inputs)
         self._model = model
-        self._signal = signal
+        self._signal = self.signal
 
     def reset(self, inputs: Inputs) -> None:
         self._state = self._model.reinit(inputs)
@@ -51,9 +68,9 @@ def observe(self, time: float, value: Optional[np.ndarray] = None) -> float:
         Parameters
         ----------
         time : float
-          The time of the new observation.
+            The time of the new observation.
         value : np.ndarray (optional)
-          The new observation.
+            The new observation.
         """
         if time < self._state.t:
             raise ValueError("Time must be increasing.")
@@ -70,9 +87,11 @@ def log_likelihood(
         Parameters
         ----------
         values : np.ndarray
-          The values of the model.
+            The values of the model.
+        times : np.ndarray
+            The times at which to observe the model.
         inputs : Inputs
-          The inputs to the model.
+            The inputs to the model.
         """
         if len(values) != len(times):
             raise ValueError("values and times must have the same length.")
@@ -81,7 +100,7 @@ def log_likelihood(
         for t, v in zip(times, values):
             try:
                 log_likelihood += self.observe(t, v)
-            except ValueError:
+            except Exception:
                 return np.float64(-np.inf)
         return log_likelihood
 
@@ -113,3 +132,41 @@ def get_current_time(self) -> float:
         Returns the current time.
         """
         return self._state.t
+
+    def evaluate(self, x):
+        """
+        Run the observer over the stored time points and return its measures.
+
+        Parameters
+        ----------
+        x : np.ndarray or list of Parameter
+            Parameter values (or Parameter objects) used as the model inputs.
+
+        Returns
+        -------
+        y : np.ndarray
+            The current measure after each observation, stacked with np.vstack.
+        """
+        inputs = dict()
+        if isinstance(x[0], Parameter):
+            for param in x:
+                inputs[param.name] = param.value
+        else:  # x is an array of parameter values
+            for i, param in enumerate(self.parameters):
+                inputs[param.name] = x[i]
+        self.reset(inputs)
+
+        output = []
+        if hasattr(self, "_dataset"):  # pass each target value to observe
+            ym = self._target
+            for i, t in enumerate(self._time_data):
+                self.observe(t, ym[i])
+                ys = self.get_current_measure()
+                output.append(ys)
+        else:  # no _dataset attribute: observe without a measurement
+            for t in self._time_data:
+                self.observe(t)
+                ys = self.get_current_measure()
+                output.append(ys)
+
+        return np.vstack(output)
diff --git a/pybop/observers/unscented_kalman.py b/pybop/observers/unscented_kalman.py
index 06b1ec612..ae13acc8e 100644
--- a/pybop/observers/unscented_kalman.py
+++ b/pybop/observers/unscented_kalman.py
@@ -5,20 +5,19 @@
 
 from pybop.models.base_model import BaseModel, Inputs
 from pybop.observers.observer import Observer
+from pybop.parameters.parameter import Parameter
 
 
 class UnscentedKalmanFilterObserver(Observer):
     """
-    An observer using the unscented kalman filter. This is a wrapper class for PyBOP, see class UkfFilter for more details on the method.
+    An observer using the unscented Kalman filter. This is a wrapper class for PyBOP, see class SquareRootUKF for more details on the method.
 
     Parameters
     ----------
+    parameters : List[Parameter]
+        The inputs to the model.
     model : BaseModel
         The model to observe.
-    inputs: Dict[str, float]
-        The inputs to the model.
-    signal: str
-        The signal to observe.
     sigma0 : np.ndarray | float
         The covariance matrix of the initial state. If a float is provided, the covariance matrix is set to sigma0 * np.eye(n), where n is the number of states.
         To remove a state from the filter, set the corresponding row and col to zero in both sigma0 and process.
@@ -27,20 +26,65 @@ class UnscentedKalmanFilterObserver(Observer):
         To remove a state from the filter, set the corresponding row and col to zero in both sigma0 and process.
     measure : np.ndarray | float
         The covariance matrix of the measurement noise. If a float is provided, the covariance matrix is set to measure * np.eye(m), where m is the number of measurements.
+    dataset : Dataset
+        Dataset object containing the data to fit the model to.
+    check_model : bool, optional
+        Flag to indicate if the model should be checked (default: True).
+    signal : List[str]
+        The signal to observe.
+    init_soc : float, optional
+        Initial state of charge (default: None).
+    x0 : np.ndarray, optional
+        Initial parameter values (default: None).
     """
 
     Covariance = np.ndarray
 
     def __init__(
         self,
+        parameters: List[Parameter],
         model: BaseModel,
-        inputs: Inputs,
-        signal: List[str],
         sigma0: Union[Covariance, float],
         process: Union[Covariance, float],
         measure: Union[Covariance, float],
+        dataset=None,
+        check_model=True,
+        signal=["Voltage [V]"],
+        init_soc=None,
+        x0=None,
     ) -> None:
-        super().__init__(model, inputs, signal)
+        super().__init__(parameters, model, check_model, signal, init_soc, x0)
+        if dataset is not None:
+            self._dataset = dataset.data
+
+            # Check that the dataset contains time and current
+            for name in ["Time [s]", "Current function [A]"] + self.signal:
+                if name not in self._dataset:
+                    raise ValueError(f"expected {name} in list of dataset")
+
+            self._time_data = self._dataset["Time [s]"]
+            self.n_time_data = len(self._time_data)
+            if np.any(self._time_data < 0):
+                raise ValueError("Times can not be negative.")
+            if np.any(self._time_data[:-1] >= self._time_data[1:]):
+                raise ValueError("Times must be increasing.")
+
+            for signal in self.signal:
+                if len(self._dataset[signal]) != self.n_time_data:
+                    raise ValueError(
+                        f"Time data and {signal} data must be the same length."
+                    )
+            target = [self._dataset[signal] for signal in self.signal]
+            self._target = np.vstack(target).T
+
+        # Add useful parameters to model
+        if model is not None:
+            self._model.signal = self.signal
+            self._model.n_outputs = self.n_outputs
+            if dataset is not None:
+                self._model.n_time_data = self.n_time_data
+
+        # Observer initiation
         self._process = process
 
         x0 = self.get_current_state().as_ndarray()
@@ -69,7 +113,7 @@ def measure_f(x: np.ndarray) -> np.ndarray:
             sol = self._model.get_state(inputs=self._state.inputs, t=self._state.t, x=x)
             return self.get_measure(sol).reshape(-1)
 
-        self._ukf = UkfFilter(
+        self._ukf = SquareRootUKF(
             x0=x0,
             P0=sigma0,
             Rp=process,
@@ -85,6 +129,8 @@ def reset(self, inputs: Inputs) -> None:
     def observe(self, time: float, value: np.ndarray) -> float:
         if value is None:
             raise ValueError("Measurement must be provided.")
+        elif isinstance(value, np.floating):
+            value = np.array([value])
 
         dt = time - self.get_current_time()
         if dt < 0:
@@ -113,6 +159,7 @@ def f(x: np.ndarray) -> np.ndarray:
         return log_likelihood
 
     def get_current_covariance(self) -> Covariance:
+        # Get the covariance from the square-root covariance
         return self._ukf.S @ self._ukf.S.T
 
 
@@ -127,13 +174,14 @@ class SigmaPoint(object):
     w_c: float
 
 
-class UkfFilter(object):
+class SquareRootUKF(object):
     """
-    van der Menve, R., & Wan, E. A. (n.d.). THE SQUARE-ROOT UNSCENTED KALMAN FILTER FOR STATE AND PARAMETER-ESTIMATION. http://wol.ra.phy.cam.ac.uk/mackay
+    van der Merwe, R., & Wan, E. A. (2001). THE SQUARE-ROOT UNSCENTED KALMAN FILTER FOR STATE AND PARAMETER-ESTIMATION.
+    https://doi.org/10.1109/ICASSP.2001.940586
 
-    we implement a UKF filter with additive process and measurement noise
+    We implement a square root unscented Kalman filter (UKF) with additive process and measurement noise.
 
-    the square root unscented kalman filter is a variant of the unscented kalman filter that is more numerically stable and has better performance.
+    The square root UKF is a variant of the UKF that is more numerically stable and has better performance.
 
     Parameters
     ----------
@@ -160,12 +208,13 @@ def __init__(
         f: callable,
         h: callable,
     ) -> None:
-        # find states that are zero in both sigma0 and process
+        # Find states that are zero in both sigma0 and process
         zero_rows = np.logical_and(np.all(P0 == 0, axis=0), np.all(Rp == 0, axis=0))
         zero_cols = np.logical_and(np.all(P0 == 0, axis=1), np.all(Rp == 0, axis=1))
         zeros = np.logical_and(zero_rows, zero_cols)
         ones = np.logical_not(zeros)
         states = np.array(range(len(x0)))[ones]
+        bool_mask = np.ix_(ones, ones)
 
         S_filtered = linalg.cholesky(P0[ones, :][:, ones])
         sqrtRp_filtered = linalg.cholesky(Rp[ones, :][:, ones])
@@ -173,8 +222,8 @@ def __init__(
         n = len(x0)
         S = np.zeros((n, n))
         sqrtRp = np.zeros((n, n))
-        S[ones, :][:, ones] = S_filtered
-        sqrtRp[ones, :][:, ones] = sqrtRp_filtered
+        S[bool_mask] = S_filtered
+        sqrtRp[bool_mask] = sqrtRp_filtered
 
         self.x = x0
         self.S = S
@@ -185,13 +234,14 @@ def __init__(
         self.f = f
         self.h = h
         self.states = states
+        self.bool_mask = bool_mask
 
     def reset(self, x: np.ndarray, S: np.ndarray) -> None:
-        self.x = x[self.states]
+        self.x = x
         S_filtered = S[self.states, :][:, self.states]
         S_filtered = linalg.cholesky(S_filtered)
         S_full = S.copy()
-        S_full[self.states, :][:, self.states] = S_filtered
+        S_full[self.bool_mask] = S_filtered
         self.S = S_full
 
     @staticmethod
@@ -199,7 +249,7 @@ def gen_sigma_points(
         x: np.ndarray, S: np.ndarray, alpha: float, beta: float, states: np.ndarray
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
-        Generates the sigma points for the unscented transform
+        Generates 2L+1 sigma points for the unscented transform, where L is the number of states.
 
         Parameters
         ----------
@@ -221,21 +271,29 @@ def gen_sigma_points(
         List[float]
             The weights of the sigma points
         List[float]
-            The weights of the covariance points
+            The weights of the covariance of the sigma points
         """
-        kappa = 1.0
+        # Set the scaling parameters: sigma and eta
+        kappa = 0.0
         L = len(states)
         sigma = alpha**2 * (L + kappa) - L
         eta = np.sqrt(L + sigma)
-        wm_0 = sigma / (L + sigma)
-        wc_0 = wm_0 + (1 - alpha**2 + beta)
+
+        # Define the sigma points
         points = np.hstack(
             [x]
             + [x + eta * S[:, i].reshape(-1, 1) for i in states]
             + [x - eta * S[:, i].reshape(-1, 1) for i in states]
         )
-        w_m = np.array([wm_0] + [(1 - wm_0) / (2 * L)] * (2 * L))
-        w_c = np.array([wc_0] + [(1 - wc_0) / (2 * L)] * (2 * L))
+
+        # Define the weights of the sigma points
+        w_m0 = sigma / (L + sigma)
+        w_m = np.array([w_m0] + [1 / (2 * (L + sigma))] * (2 * L))
+
+        # Define the weights of the covariance of the sigma points
+        w_c0 = w_m0 + (1 - alpha**2 + beta)
+        w_c = np.array([w_c0] + [1 / (2 * (L + sigma))] * (2 * L))
+
         return (points, w_m, w_c)
 
     @staticmethod
@@ -261,21 +319,37 @@ def unscented_transform(
         Returns
         -------
         Tuple[np.ndarray, np.ndarray]
-            The mean and covariance of the sigma points
+            The mean and square-root covariance of the sigma points
         """
+        # Update the predicted mean of the sigma points
         x = np.sum(w_m * sigma_points, axis=1).reshape(-1, 1)
-        sigma_points_diff = sigma_points - x
-        A = np.hstack([np.sqrt(w_c[1:]) * (sigma_points_diff[:, 1:]), sqrtR])
-        (
-            _,
-            S,
-        ) = linalg.qr(A.T, mode="economic")
+
+        # Update the predicted square-root covariance
         if states is None:
-            S = UkfFilter.cholupdate(S, sigma_points_diff[:, 0:1], w_c[0])
+            sigma_points_diff = sigma_points - x
+            A = np.hstack([np.sqrt(w_c[1:]) * (sigma_points_diff[:, 1:]), sqrtR])
+            (_, S) = linalg.qr(A.T, mode="economic")
+            S = SquareRootUKF.cholupdate(S, sigma_points_diff[:, 0:1], w_c[0])
         else:
-            S = UkfFilter.filtered_cholupdate(
-                S, sigma_points_diff[:, 0:1], w_c[0], states
+            # First overwrite states without noise to remove numerical error
+            clean = np.full(len(x), True)
+            clean[states] = False
+            x[clean] = sigma_points[clean, 0].reshape(-1, 1)
+
+            sigma_points_diff = sigma_points[states, :] - x[states]
+            A = np.hstack(
+                [
+                    np.sqrt(w_c[1:]) * (sigma_points_diff[:, 1:]),
+                    sqrtR[states, :][:, states],
+                ]
+            )
+            (_, S_filtered) = linalg.qr(A.T, mode="economic")
+            S_filtered = SquareRootUKF.cholupdate(
+                S_filtered, sigma_points_diff[:, 0:1], w_c[0]
             )
+            ones = np.logical_not(clean)
+            S = np.zeros_like(sqrtR)
+            S[np.ix_(ones, ones)] = S_filtered
 
         return x, S
 
@@ -286,14 +360,16 @@ def filtered_cholupdate(
         R_full = R.copy()
         R_filtered = R[states, :][:, states]
         x_filtered = x[states]
-        R_filtered = UkfFilter.cholupdate(R_filtered, x_filtered, w)
-        R_full[states, :][:, states] = R_filtered
+        R_filtered = SquareRootUKF.cholupdate(R_filtered, x_filtered, w)
+        ones = np.full(len(x), False)
+        ones[states] = True
+        R_full[np.ix_(ones, ones)] = R_filtered
         return R_full
 
     @staticmethod
     def cholupdate(R: np.ndarray, x: np.ndarray, w: float) -> np.ndarray:
         """
-        Updates the cholesky decomposition of a matrix (see https://github.com/modusdatascience/choldate/blob/master/choldate/_choldate.pyx)
+        Updates the Cholesky decomposition of a matrix (see https://github.com/modusdatascience/choldate/blob/master/choldate/_choldate.pyx)
 
         Note: will be in scipy soon so replace with this: https://github.com/scipy/scipy/pull/16499
 
@@ -302,7 +378,7 @@ def cholupdate(R: np.ndarray, x: np.ndarray, w: float) -> np.ndarray:
         Parameters
         ----------
         R : np.ndarray
-            The cholesky decomposition of the matrix
+            The Cholesky decomposition of the matrix
         x : np.ndarray
             The vector to add to the matrix
         w : float
@@ -311,21 +387,31 @@ def cholupdate(R: np.ndarray, x: np.ndarray, w: float) -> np.ndarray:
         Returns
         -------
         np.ndarray
-            The updated cholesky decomposition
+            The updated Cholesky decomposition
         """
-        x = x.flatten()
+        sign = np.sign(w)
+        x = np.sqrt(abs(w)) * x.flatten()
         p = x.shape[0]
         for k in range(p):
-            r = np.sqrt(R[k, k] ** 2 + w * x[k] ** 2)
-            # r = UkfFilter.hypot(R[k, k], x[k])
+            Rkk = abs(R[k, k])
+            xk = abs(x[k])
+            r = SquareRootUKF.hypot(Rkk, xk, sign)
             c = r / R[k, k]
             s = x[k] / R[k, k]
             R[k, k] = r
             if k < p - 1:
-                R[k, k + 1 :] = (R[k, k + 1 :] + w * s * x[k + 1 :]) / c
+                R[k, k + 1 :] = (R[k, k + 1 :] + sign * s * x[k + 1 :]) / c
                 x[k + 1 :] = c * x[k + 1 :] - s * R[k, k + 1 :]
         return R
 
+    def hypot(R: float, x: float, sign: float) -> float:  # sqrt(R**2 + sign * x**2), R, x >= 0
+        if R < x:  # x > 0 here, so the division is safe
+            return x * np.sqrt(sign + R**2 / x**2)  # NaN if sign < 0: invalid downdate
+        elif x < R:
+            return np.sqrt(R**2 + sign * x**2)
+        else:  # R == x: sqrt(2) * R for an update (sign=+1), 0 for a downdate (sign=-1)
+            return R * np.sqrt(1 + sign)
+
     def step(self, y: np.ndarray) -> float:
         """
         Steps the filter forward one step using a measurement. Returns the log likelihood of the measurement.
@@ -340,25 +426,40 @@ def step(self, y: np.ndarray) -> float:
         float
             The log likelihood of the measurement
         """
+        # Sigma point calculation
         sigma_points, w_m, w_c = self.gen_sigma_points(
             self.x, self.S, self.alpha, self.beta, self.states
         )
+
+        # Update sigma points in time
         sigma_points = np.apply_along_axis(self.f, 0, sigma_points)
 
+        # Compute the mean and square-root covariance of the time-updated states
         x_minus, S_minus = self.unscented_transform(
             sigma_points, w_m, w_c, self.sqrtRp, self.states
         )
+
+        # Compute the output corresponding to the updated sigma points
         sigma_points_y = np.apply_along_axis(self.h, 0, sigma_points)
+
+        # Compute the mean and square-root covariance of the predicted outputs
         y_minus, S_y = self.unscented_transform(sigma_points_y, w_m, w_c, self.sqrtRm)
+
+        # Compute the gain from the covariance
         P = np.einsum(
             "k,jk,lk -> jl ", w_c, sigma_points - x_minus, sigma_points_y - y_minus
         )
         gain = linalg.lstsq(linalg.lstsq(P.T, S_y.transpose())[0].T, S_y)[0]
+
+        # Update the states and square-root covariance based on the gain
         residual = y - y_minus
         self.x = x_minus + gain @ residual
         U = gain @ S_y
         self.S = self.filtered_cholupdate(S_minus, U, -1, self.states)
-        log_det = 2 * np.sum(np.log(np.diag(self.S)))
+
+        # Compute the log-likelihood of the measurement
+        S = self.S[self.states, :][:, self.states]
+        log_det = 2 * np.sum(np.log(np.diag(S)))
         n = len(y)
         log_likelihood = -0.5 * (
             n * log_det + residual.T @ linalg.cho_solve((S_y, True), residual)
diff --git a/tests/unit/test_cost.py b/tests/unit/test_cost.py
index f2e439adc..94d5e890c 100644
--- a/tests/unit/test_cost.py
+++ b/tests/unit/test_cost.py
@@ -45,10 +45,10 @@ def problem(self, request):
     )
     def cost(self, problem, request):
         cls = request.param
-        inputs = {p.name: problem.x0[i] for i, p in enumerate(problem.parameters)}
         if cls == pybop.RootMeanSquaredError or cls == pybop.SumSquaredError:
             return cls(problem)
         elif cls == pybop.ObserverCost:
+            inputs = {p.name: problem.x0[i] for i, p in enumerate(problem.parameters)}
             state = problem._model.reinit(inputs)
             n = len(state)
             sigma_diag = [0.0] * n
@@ -59,15 +59,16 @@ def cost(self, problem, request):
             process_diag[1] = 1e-4
             sigma0 = np.diag(sigma_diag)
             process = np.diag(process_diag)
+            dataset = type("dataset", (object,), {"data": problem._dataset})()
             return cls(
-                problem,
-                observer=pybop.UnscentedKalmanFilterObserver(
+                pybop.UnscentedKalmanFilterObserver(
+                    problem.parameters,
                     problem._model,
-                    inputs,
-                    problem.signal,
                     sigma0=sigma0,
                     process=process,
                     measure=1e-4,
+                    dataset=dataset,
+                    signal=problem.signal,
                 ),
             )
 
diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py
index 81524e0d0..39d3f371f 100644
--- a/tests/unit/test_models.py
+++ b/tests/unit/test_models.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pybamm
 
-from examples.standalone.exponential_decay import ExponentialDecay
+from examples.standalone.model import ExponentialDecay
 
 
 class TestModels:
diff --git a/tests/unit/test_observer_unscented_kalman.py b/tests/unit/test_observer_unscented_kalman.py
index 0821e1b29..e89db15b6 100644
--- a/tests/unit/test_observer_unscented_kalman.py
+++ b/tests/unit/test_observer_unscented_kalman.py
@@ -2,10 +2,8 @@
 import numpy as np
 import pybamm
 import pytest
-from pybop.observers.unscented_kalman import UkfFilter
-import matplotlib.pyplot as plt
-
-from examples.standalone.exponential_decay import ExponentialDecay
+from pybop.observers.unscented_kalman import SquareRootUKF
+from examples.standalone.model import ExponentialDecay
 
 
 class TestUKF:
@@ -18,16 +16,34 @@ class TestUKF:
     @pytest.fixture(params=[1, 2, 3])
     def model(self, request):
         model = ExponentialDecay(
-            parameters=pybamm.ParameterValues({"k": "[input]", "y0": "[input]"}),
-            nstate=request.param,
+            parameter_set=pybamm.ParameterValues({"k": "[input]", "y0": "[input]"}),
+            n_states=request.param,
         )
         model.build()
         return model
 
     @pytest.fixture
-    def dataset(self, model: pybop.BaseModel):
-        inputs = {"k": 0.1, "y0": 1.0}
-        observer = pybop.Observer(model, inputs, ["2y"])
+    def parameters(self):
+        """Fitting parameters "k" and "y0", each with a Gaussian prior and bounds."""
+        k_param = pybop.Parameter(
+            "k",
+            prior=pybop.Gaussian(0.1, 0.05),
+            bounds=[0, 1],
+        )
+        y0_param = pybop.Parameter(
+            "y0",
+            prior=pybop.Gaussian(1, 0.05),
+            bounds=[0, 3],
+        )
+        return [k_param, y0_param]
+
+    @pytest.fixture
+    def x0(self):
+        return np.array([0.1, 1.0])  # presumably [k, y0] in `parameters` order; confirm
+
+    @pytest.fixture
+    def dataset(self, model: pybop.BaseModel, parameters, x0):
+        observer = pybop.Observer(parameters, model, signal=["2y"], x0=x0)
         measurements = []
         t_eval = np.linspace(0, 20, 10)
         for t in t_eval:
@@ -40,10 +56,8 @@ def dataset(self, model: pybop.BaseModel):
         return {"Time [s]": t_eval, "y": measurements}
 
     @pytest.fixture
-    def observer(self, model: pybop.BaseModel):
-        inputs = {"k": 0.1, "y0": 1.0}
-        signal = ["2y"]
-        n = model.nstate
+    def observer(self, model: pybop.BaseModel, parameters, x0):
+        n = model.n_states
         sigma0 = np.diag([self.measure_noise] * n)
         process = np.diag([1e-6] * n)
         # for 3rd  model, set sigma0 and process to zero for the 1st and 2nd state
@@ -54,7 +68,7 @@ def observer(self, model: pybop.BaseModel):
             process[1, 1] = 0
         measure = np.diag([1e-4])
         observer = pybop.UnscentedKalmanFilterObserver(
-            model, inputs, signal, sigma0, process, measure
+            parameters, model, sigma0, process, measure, signal=["2y"], x0=x0
         )
         return observer
 
@@ -77,7 +91,7 @@ def test_cholupdate(self):
 
         # The following is equivalent to the above
         R1_ = R.copy()
-        UkfFilter.cholupdate(R1_, u.copy(), 1.0)
+        SquareRootUKF.cholupdate(R1_, u.copy(), 1.0)
         np.testing.assert_array_almost_equal(R1, R1_)
 
     @pytest.mark.unit
@@ -85,15 +99,11 @@ def test_unscented_kalman_filter(self, dataset, observer):
         t_eval = dataset["Time [s]"]
         measurements = dataset["y"]
         inputs = observer._state.inputs
-        n = observer._model.nstate
+        n = observer._model.n_states
         expected = inputs["y0"] * np.exp(-inputs["k"] * t_eval)
-        plt_x = []
-        plt_y = []
-        plt_stddev = []
-        plt_m = []
-        for i in range(len(t_eval)):
+
+        for i, t in enumerate(t_eval):
             y = np.array([[expected[i]]] * n)
-            t = t_eval[i]
             ym = measurements[:, i]
             observer.observe(t, ym)
             np.testing.assert_array_almost_equal(
@@ -106,23 +116,6 @@ def test_unscented_kalman_filter(self, dataset, observer):
                 np.array([2 * y[0]]),
                 decimal=4,
             )
-            plt_x.append(t)
-            plt_y.append(observer.get_current_state().as_ndarray()[0][0])
-            plt_stddev.append(np.sqrt(observer.get_current_covariance()[0, 0]))
-            plt_m.append(ym[0])
-        plt_y = np.array(plt_y)
-        plt_stddev = np.array(plt_stddev)
-        plt_m = np.array(plt_m)
-        plt.clf()
-        plt.plot(plt_x, plt_y, label="UKF")
-        plt.plot(plt_x, expected, "--", label="Expected")
-        plt.plot(plt_x, plt_y + plt_stddev, label="UKF + stddev")
-        plt.plot(
-            plt_x, expected + np.sqrt(self.measure_noise), label="Expected + stddev"
-        )
-        plt.plot(plt_x, plt_m / 2, ".", label="Measurement")
-        plt.legend()
-        plt.savefig(f"test{n}.png")
 
     @pytest.mark.unit
     def test_observe_no_measurement(self, observer):
@@ -137,17 +130,16 @@ def test_observe_decreasing_time(self, observer):
             observer.observe(0, np.array([2]))
 
     @pytest.mark.unit
-    def test_wrong_input_shapes(self, model):
-        inputs = {"k": 0.1, "y0": 1.0}
+    def test_wrong_input_shapes(self, model, parameters):
         signal = "2y"
-        n = model.nstate
+        n = model.n_states
 
         sigma0 = np.diag([1e-4] * (n + 1))
         process = np.diag([1e-4] * n)
         measure = np.diag([1e-4])
         with pytest.raises(ValueError):
             pybop.UnscentedKalmanFilterObserver(
-                model, inputs, signal, sigma0, process, measure
+                parameters, model, sigma0, process, measure, signal=signal
             )
 
         sigma0 = np.diag([1e-4] * n)
@@ -155,7 +147,7 @@ def test_wrong_input_shapes(self, model):
         measure = np.diag([1e-4])
         with pytest.raises(ValueError):
             pybop.UnscentedKalmanFilterObserver(
-                model, inputs, signal, sigma0, process, measure
+                parameters, model, sigma0, process, measure, signal=signal
             )
 
         sigma0 = np.diag([1e-4] * n)
@@ -163,5 +155,5 @@ def test_wrong_input_shapes(self, model):
         measure = np.diag([1e-4] * 2)
         with pytest.raises(ValueError):
             pybop.UnscentedKalmanFilterObserver(
-                model, inputs, signal, sigma0, process, measure
+                parameters, model, sigma0, process, measure, signal=signal
             )
diff --git a/tests/unit/test_observers.py b/tests/unit/test_observers.py
index a7a4f9bb8..020f8f91a 100644
--- a/tests/unit/test_observers.py
+++ b/tests/unit/test_observers.py
@@ -2,8 +2,7 @@
 import numpy as np
 import pybamm
 import pytest
-
-from examples.standalone.exponential_decay import ExponentialDecay
+from examples.standalone.model import ExponentialDecay
 
 
 class TestObserver:
@@ -14,20 +13,37 @@ class TestObserver:
     @pytest.fixture(params=[1, 2])
     def model(self, request):
         model = ExponentialDecay(
-            parameters=pybamm.ParameterValues({"k": "[input]", "y0": "[input]"}),
-            nstate=request.param,
+            parameter_set=pybamm.ParameterValues({"k": "[input]", "y0": "[input]"}),
+            n_states=request.param,
         )
         model.build()
         return model
 
+    @pytest.fixture
+    def parameters(self):
+        """Fitting parameters "k" and "y0", each with a Gaussian prior and bounds."""
+        k_param = pybop.Parameter(
+            "k",
+            prior=pybop.Gaussian(0.1, 0.05),
+            bounds=[0, 1],
+        )
+        y0_param = pybop.Parameter(
+            "y0",
+            prior=pybop.Gaussian(1, 0.05),
+            bounds=[0, 3],
+        )
+        return [k_param, y0_param]
+
+    @pytest.fixture
+    def x0(self):
+        return np.array([0.1, 1.0])  # test_observer uses x0[0] as the rate and x0[1] as y0
+
     @pytest.mark.unit
-    def test_observer(self, model):
-        inputs = {"k": 0.1, "y0": 1.0}
-        signal = ["2y"]
-        n = model.nstate
-        observer = pybop.Observer(model, inputs, signal)
+    def test_observer(self, model, parameters, x0):
+        n = model.n_states
+        observer = pybop.Observer(parameters, model, signal=["2y"], x0=x0)
         t_eval = np.linspace(0, 1, 100)
-        expected = inputs["y0"] * np.exp(-inputs["k"] * t_eval)
+        expected = x0[1] * np.exp(-x0[0] * t_eval)
         for y, t in zip(expected, t_eval):
             observer.observe(t)
             np.testing.assert_array_almost_equal(