Skip to content

Commit

Permalink
Update bool_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolaCourtier committed Feb 7, 2024
1 parent 0f56c3a commit f303b10
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions pybop/observers/unscented_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,16 +214,16 @@ def __init__(
zeros = np.logical_and(zero_rows, zero_cols)
ones = np.logical_not(zeros)
states = np.array(range(len(x0)))[ones]
bool_mask = np.vstack(ones) & np.hstack(ones)
bool_mask = np.ix_(ones, ones)

S_filtered = linalg.cholesky(P0[ones, :][:, ones])
sqrtRp_filtered = linalg.cholesky(Rp[ones, :][:, ones])

n = len(x0)
S = np.zeros((n, n))
sqrtRp = np.zeros((n, n))
S[bool_mask] = S_filtered.flatten()
sqrtRp[bool_mask] = sqrtRp_filtered.flatten()
S[bool_mask] = S_filtered
sqrtRp[bool_mask] = sqrtRp_filtered

self.x = x0
self.S = S
Expand All @@ -241,7 +241,7 @@ def reset(self, x: np.ndarray, S: np.ndarray) -> None:
S_filtered = S[self.states, :][:, self.states]
S_filtered = linalg.cholesky(S_filtered)
S_full = S.copy()
S_full[self.bool_mask] = S_filtered.flatten()
S_full[self.bool_mask] = S_filtered
self.S = S_full

@staticmethod
Expand Down Expand Up @@ -348,9 +348,8 @@ def unscented_transform(
S_filtered, sigma_points_diff[:, 0:1], w_c[0]
)
ones = np.logical_not(clean)
bool_mask = np.vstack(ones) & np.hstack(ones)
S = np.zeros_like(sqrtR)
S[bool_mask] = S_filtered.flatten()
S[np.ix_(ones, ones)] = S_filtered

return x, S

Expand All @@ -364,8 +363,7 @@ def filtered_cholupdate(
R_filtered = SquareRootUKF.cholupdate(R_filtered, x_filtered, w)
ones = np.full(len(x), False)
ones[states] = True
bool_mask = np.vstack(ones) & np.hstack(ones)
R_full[bool_mask] = R_filtered.flatten()
R_full[np.ix_(ones, ones)] = R_filtered
return R_full

@staticmethod
Expand Down

0 comments on commit f303b10

Please sign in to comment.