Skip to content

Commit

Permalink
🩹Fix for crash when running optimization with more than 30 datasets (g…
Browse files Browse the repository at this point in the history
…lotaran#1184)

* 🩹Fix for ufunc with more than 32 operands
* 'Refactored by Sourcery'
* 📚 Add change to changelog
* [pre-commit.ci] auto fixes from pre-commit.com hooks

Co-authored-by: Sourcery AI <>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jsnel and pre-commit-ci[bot] committed Nov 23, 2022
1 parent 7fecbc1 commit a993738
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

- 🩹 Fix result data overwritten when using multiple dataset_groups (#1147)
- 🩹 Fix for normalization issue described in #1157 (multi-gaussian irfs and multiple time ranges (streak))
- 🩹 Fix for crash described in #1183 when doing an optimization using more than 30 datasets (#1184)


### 📚 Documentation

Expand Down
25 changes: 21 additions & 4 deletions glotaran/optimization/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def get_axis_slice_from_interval(
return slice(minimum, maximum)

def add_model_weight(
self, model: Model, dataset_label: str, model_dimension: str, global_dimension: str
self,
model: Model,
dataset_label: str,
model_dimension: str,
global_dimension: str,
):
"""Add model weight to data.
Expand Down Expand Up @@ -532,7 +536,9 @@ def align_data(
aligned_data = xr.concat(
[
xr.DataArray(
self.get_data(label), dims=["model", "global"], coords={"global": axis}
self.get_data(label),
dims=["model", "global"],
coords={"global": axis},
)
for label, axis in aligned_global_axes.items()
],
Expand Down Expand Up @@ -602,7 +608,16 @@ def align_groups(
dim="dataset",
fill_value="",
)
aligned_group_labels = aligned_groups.str.join(dim="dataset").data
# for every element along the global axis, concatenate all dataset labels
# into an ndarray of shape (len(global,)
# as an alternative to the more elegant xarray built-in which is limited to 32 datasets
# aligned_group_labels = aligned_groups.str.join(dim="dataset").data
aligned_group_labels = [
"".join(sub_arr.values) for _, sub_arr in aligned_groups.groupby("global")
]

aligned_group_labels = np.asarray(aligned_group_labels)

group_definitions: dict[str, list[str]] = {}
for i, group_label in enumerate(aligned_group_labels):
if group_label not in group_definitions:
Expand All @@ -628,7 +643,9 @@ def align_weights(
"""
all_weights = {
label: xr.DataArray(
weight, dims=["model", "global"], coords={"global": aligned_global_axes[label]}
weight,
dims=["model", "global"],
coords={"global": aligned_global_axes[label]},
)
for label, weight in self._weight.items()
if weight is not None
Expand Down

0 comments on commit a993738

Please sign in to comment.