-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save predictions to sacc #349
Changes from 23 commits
d7f51ab
acee539
08597ff
aa379cc
52c8222
4f77feb
8acdb04
dd85f19
043bcfb
d93848d
5973790
fba2463
1207889
92b3916
5d9b462
16904b8
4bd3e4b
af11932
6962a58
0798278
123724e
2c6ca9e
86be99d
04825dd
55df315
a155d99
f38f762
3a0c3af
bdc3575
447c24d
af14c76
a0f32d6
2a6134e
ef3b08e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,9 @@ | |
""" | ||
|
||
from __future__ import annotations | ||
|
||
from enum import Enum | ||
from typing import List, Optional, Tuple, Sequence | ||
from typing import List, Optional, Tuple, Sequence, Dict | ||
from typing import final | ||
import warnings | ||
|
||
|
@@ -61,12 +62,24 @@ def __init__( | |
self.state: State = State.INITIALIZED | ||
if len(statistics) == 0: | ||
raise ValueError("GaussFamily requires at least one statistic") | ||
self.statistics: UpdatableCollection = UpdatableCollection( | ||
|
||
for i, s in enumerate(statistics): | ||
if not isinstance(s, Statistic): | ||
raise ValueError( | ||
f"statistics[{i}] is not an instance of Statistic: {s}" | ||
f"it is a {type(s)} instead." | ||
) | ||
|
||
self.statistics: UpdatableCollection[GuardedStatistic] = UpdatableCollection( | ||
GuardedStatistic(s) for s in statistics | ||
) | ||
self.cov: Optional[npt.NDArray[np.float64]] = None | ||
self.cholesky: Optional[npt.NDArray[np.float64]] = None | ||
self.inv_cov: Optional[npt.NDArray[np.float64]] = None | ||
self.cov_index_map: Optional[Dict[int, int]] = None | ||
self.computed_theory_vector = False | ||
self.theory_vector: Optional[npt.NDArray[np.double]] = None | ||
self.data_vector: Optional[npt.NDArray[np.double]] = None | ||
|
||
def _update(self, _: ParamsMap) -> None: | ||
"""Handle the state resetting required by :class:`GaussFamily` | ||
|
@@ -84,6 +97,10 @@ def _reset(self) -> None: | |
at the start of the method, and change the state at the end of the | ||
method.""" | ||
assert self.state == State.UPDATED, "update() must be called before reset()" | ||
|
||
self.computed_theory_vector = False | ||
self.theory_vector = None | ||
|
||
self.state = State.READY | ||
|
||
def read(self, sacc_data: sacc.Sacc) -> None: | ||
|
@@ -98,28 +115,50 @@ def read(self, sacc_data: sacc.Sacc) -> None: | |
raise RuntimeError(msg) | ||
|
||
covariance = sacc_data.covariance.dense | ||
|
||
indices_list = [] | ||
data_vector_list = [] | ||
for stat in self.statistics: | ||
stat.read(sacc_data) | ||
if stat.statistic.sacc_indices is None: | ||
raise RuntimeError( | ||
f"The statistic {stat.statistic} has no sacc_indices." | ||
) | ||
indices_list.append(stat.statistic.sacc_indices.copy()) | ||
data_vector_list.append(stat.statistic.get_data_vector()) | ||
|
||
indices_list = [s.statistic.sacc_indices.copy() for s in self.statistics] | ||
indices = np.concatenate(indices_list) | ||
data_vector = np.concatenate(data_vector_list) | ||
cov = np.zeros((len(indices), len(indices))) | ||
|
||
for new_i, old_i in enumerate(indices): | ||
for new_j, old_j in enumerate(indices): | ||
cov[new_i, new_j] = covariance[old_i, old_j] | ||
|
||
self.data_vector = data_vector | ||
self.cov_index_map = {old_i: new_i for new_i, old_i in enumerate(indices)} | ||
self.cov = cov | ||
self.cholesky = scipy.linalg.cholesky(self.cov, lower=True) | ||
self.inv_cov = np.linalg.inv(cov) | ||
|
||
self.state = State.READY | ||
|
||
@final | ||
def get_cov(self) -> npt.NDArray[np.float64]: | ||
"""Gets the current covariance matrix.""" | ||
def get_cov(self, statistic: Optional[Statistic] = None) -> npt.NDArray[np.float64]: | ||
"""Gets the current covariance matrix. | ||
|
||
:param statistic: The statistic for which the sub-covariance matrix | ||
should be returned. If not specified, return the covariance of all | ||
statistics. | ||
""" | ||
assert self._is_ready(), "read() must be called before get_cov()" | ||
assert self.cov is not None | ||
if statistic is not None: | ||
assert statistic.sacc_indices is not None | ||
assert self.cov_index_map is not None | ||
idx = [self.cov_index_map[idx] for idx in statistic.sacc_indices] | ||
# We do not change the state. | ||
return self.cov[np.ix_(idx, idx)] | ||
# We do not change the state. | ||
return self.cov | ||
|
||
|
@@ -129,11 +168,8 @@ def get_data_vector(self) -> npt.NDArray[np.float64]: | |
order.""" | ||
assert self._is_ready(), "read() must be called before get_data_vector()" | ||
|
||
data_vector_list: List[npt.NDArray[np.float64]] = [ | ||
stat.get_data_vector() for stat in self.statistics | ||
] | ||
# We do not change the state. | ||
return np.concatenate(data_vector_list) | ||
assert self.data_vector is not None | ||
return self.data_vector | ||
|
||
@final | ||
def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]: | ||
|
@@ -148,8 +184,30 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64] | |
theory_vector_list: List[npt.NDArray[np.float64]] = [ | ||
stat.compute_theory_vector(tools) for stat in self.statistics | ||
] | ||
# We do not change the state | ||
return np.concatenate(theory_vector_list) | ||
self.computed_theory_vector = True | ||
self.theory_vector = np.concatenate(theory_vector_list) | ||
|
||
return self.theory_vector | ||
|
||
@final | ||
def get_theory_vector(self) -> npt.NDArray[np.float64]: | ||
"""Get the theory vector from all statistics and concatenate in the right | ||
order.""" | ||
|
||
assert ( | ||
self.state == State.UPDATED | ||
), "update() must be called before get_theory_vector()" | ||
|
||
if not self.computed_theory_vector: | ||
raise RuntimeError( | ||
"The theory vector has not been computed yet. " | ||
"Call compute_theory_vector first." | ||
) | ||
assert self.theory_vector is not None, ( | ||
"Implementation error, " | ||
"computed_theory_vector is True but theory_vector is None" | ||
) | ||
return self.theory_vector | ||
|
||
@final | ||
def compute( | ||
|
@@ -186,9 +244,6 @@ def compute_chisq(self, tools: ModelingTools) -> float: | |
assert len(data_vector) == len(theory_vector) | ||
residuals = data_vector - theory_vector | ||
|
||
self.predicted_data_vector: npt.NDArray[np.float64] = theory_vector | ||
self.measured_data_vector: npt.NDArray[np.float64] = data_vector | ||
|
||
x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True) | ||
chisq = np.dot(x, x) | ||
|
||
|
@@ -198,3 +253,35 @@ def compute_chisq(self, tools: ModelingTools) -> float: | |
def _is_ready(self) -> bool: | ||
"""Return True if the state is either READY or UPDATED.""" | ||
return self.state in (State.READY, State.UPDATED) | ||
|
||
def make_realization( | ||
self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True | ||
) -> sacc.Sacc: | ||
new_sacc = sacc_data.copy() | ||
|
||
sacc_indices_list = [] | ||
for stat in self.statistics: | ||
assert stat.statistic.sacc_indices is not None | ||
sacc_indices_list.append(stat.statistic.sacc_indices.copy()) | ||
|
||
sacc_indices = np.concatenate(sacc_indices_list) | ||
|
||
if add_noise: | ||
new_data_vector = self.make_realization_vector() | ||
else: | ||
new_data_vector = self.get_theory_vector() | ||
|
||
assert len(sacc_indices) == len(new_data_vector) | ||
|
||
if strict: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider collapsing the nested `if` statements. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We'll deal with this in a later issue. |
||
if set(sacc_indices.tolist()) != set(sacc_data.indices()): | ||
raise RuntimeError( | ||
"The predicted data does not cover all the data in the " | ||
"sacc object. To write only the calculated predictions, " | ||
"set strict=False." | ||
) | ||
|
||
for prediction_idx, sacc_idx in enumerate(sacc_indices): | ||
new_sacc.data[sacc_idx].value = new_data_vector[prediction_idx] | ||
|
||
return new_sacc |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
""" | ||
|
||
from __future__ import annotations | ||
import numpy as np | ||
|
||
from .gauss_family import GaussFamily | ||
from ...modeling_tools import ModelingTools | ||
|
@@ -15,3 +16,12 @@ def compute_loglike(self, tools: ModelingTools): | |
"""Compute the log-likelihood.""" | ||
|
||
return -0.5 * self.compute_chisq(tools) | ||
|
||
def make_realization_vector(self) -> np.ndarray: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should be checking pre- and post-conditions on self.state in every method. |
||
theory_vector = self.get_theory_vector() | ||
assert self.cholesky is not None | ||
new_data_vector = theory_vector + np.dot( | ||
self.cholesky, np.random.randn(len(theory_vector)) | ||
) | ||
|
||
return new_data_vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would prefer to pass the numpy array of indices that corresponds to the sub-matrix desired.
This would allow the caller to obtain the sub-matrix for two or more statistics, when that is desired.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The application I wrote this for was to get error bars when plotting the data vector. The idea was specifically to abstract away the indices and instead use the statistics, since that's what the user interacts with. I see the use of passing a list of statistics though, to get their corresponding sub-matrix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will expand the interface to have a single method `get_cov` that can accept:
- a `Statistic`
- a list of `Statistic`
- an `np.ndarray` (the indices)
- a list of `np.ndarray` (a list of indices)

In the case when a stat or list of stats is passed in, we also need to make sure that the stat (or all the stats) that have been passed are in the likelihood object on which we've called the `get_cov` method. We will make the code verify this.

We also have to specify the order in which entries appear in the returned matrix. We propose to respect the order of the entries of the list of statistics (or of numpy arrays), so that the user-specified order of the list controls the ordering of the elements in the returned matrix, rather than the order of the entries in the SACC data object controlling the ordering of the entries in the returned matrix.
For example, if we pass a list of stats of length 3 (note the order of the entries in the list passed to `get_cov`):
stats1 -> 0:9, stats2 -> 10:19, stats3 -> 20:29 => `get_cov([stats1, stats3, stats2])` -> 0:9 + 20:29 + 10:19