Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add clp guidance megacomplex #1029

Merged
merged 18 commits into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ repos:
- id: interrogate
args: [-vv, --config=pyproject.toml, glotaran]
pass_filenames: false
additional_dependencies: [click<8]

- repo: https://github.com/myint/rstcheck
rev: "3f92957478422df87bd730abde66f089cc1ee19b"
Expand Down
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- ✨ Python 3.10 support (#977)
- ✨ Add simple decay megacomplexes (#860)
- ✨ Feature: Generators (#866)
- ✨ Add clp guidance megacomplex (#1029)

### 👌 Minor Improvements:

Expand Down
2 changes: 2 additions & 0 deletions glotaran/analysis/optimization_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ def create_result_dataset(
dataset_model = self.dataset_models[label]
global_dimension = dataset_model.get_global_dimension()
model_dimension = dataset_model.get_model_dimension()
dataset.attrs["global_dimension"] = global_dimension
dataset.attrs["model_dimension"] = model_dimension
if copy:
dataset = dataset.copy()
if dataset_model.is_index_dependent():
Expand Down
38 changes: 22 additions & 16 deletions glotaran/builtin/io/ascii/wavelength_time_explicit_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import os.path
import re
import warnings
from enum import Enum
from warnings import warn

import numpy as np
import pandas as pd
Expand All @@ -13,8 +13,6 @@
from glotaran.io import register_data_io
from glotaran.io.prepare_dataset import prepare_time_trace_dataset

# from glotaran.io.reader import file_reader


class DataFileType(Enum):
time_explicit = "Time explicit"
Expand All @@ -27,7 +25,7 @@ class ExplicitFile:
"""

# TODO: implement time_intervals
def __init__(self, filepath: str = None, dataset: xr.DataArray = None):
def __init__(self, filepath: str | None = None, dataset: xr.DataArray | None = None):
self._file_data_format = None
self._observations = [] # TODO: choose name: data_points, observations, data
self._times = []
Expand Down Expand Up @@ -76,22 +74,19 @@ def write(

if os.path.isfile(self._file) and not overwrite:
raise FileExistsError(f"File already exist:\n{self._file}")
comment = self._comment + " " + comment
comment = f"{self._comment} {comment}"

comments = "# Filename: " + str(self._file) + "\n" + " ".join(comment.splitlines()) + "\n"
comments = f"# Filename: {str(self._file)}\n{' '.join(comment.splitlines())}\n"

if file_format == DataFileType.wavelength_explicit:
wav = "\t".join(repr(num) for num in self._spectral_indices)
header = (
comments + "Wavelength explicit\nIntervalnr {}"
"".format(len(self._spectral_indices)) + "\n" + wav
f"{comments}Wavelength explicit\nIntervalnr {len(self._spectral_indices)}\n{wav}"
)
raw_data = np.vstack((self._times.T, self._observations)).T
elif file_format == DataFileType.time_explicit:
tim = "\t".join(repr(num) for num in self._times)
header = (
comments + "Time explicit\nIntervalnr {}" "".format(len(self._times)) + "\n" + tim
)
header = f"{comments}Time explicit\nIntervalnr {len(self._times)}\n{tim}"
raw_data = np.vstack((self._spectral_indices.T, self._observations.T)).T
else:
raise NotImplementedError
Expand All @@ -109,7 +104,7 @@ def write(

def read(self, prepare: bool = True):
if not os.path.isfile(self._file):
raise Exception("File does not exist.")
raise FileNotFoundError("File does not exist.")
with open(self._file) as f:
f.readline() # The first two lines are comments
f.readline()
Expand Down Expand Up @@ -221,7 +216,7 @@ def get_interval_number(line):
try:
interval_number = int(interval_number)
except ValueError:
warnings.warn(f"No interval number found in line:\n{line}")
warn(f"No interval number found in line:\n{line}")
interval_number = None
return interval_number

Expand All @@ -242,7 +237,7 @@ def get_data_file_format(line):
# @file_reader(extension="ascii", name="Wavelength-/Time-Explicit ASCII")
@register_data_io("ascii")
class AsciiDataIo(DataIoInterface):
def load_dataset(self, file_name: str) -> xr.Dataset | xr.DataArray:
def load_dataset(self, file_name: str, *, prepare: bool = True) -> xr.Dataset | xr.DataArray:
"""Reads an ascii file in wavelength- or time-explicit format.

See [1]_ for documentation of this format.
Expand Down Expand Up @@ -272,17 +267,28 @@ def load_dataset(self, file_name: str) -> xr.Dataset | xr.DataArray:
else TimeExplicitFile(file_name)
)

return data_file.read(prepare=True)
return data_file.read(prepare=prepare)

def save_dataset(
self,
dataset: xr.DataArray,
dataset: xr.DataArray | xr.Dataset,
file_name: str,
*,
comment: str = "",
file_format: DataFileType = DataFileType.time_explicit,
number_format: str = "%.10e",
):
if isinstance(dataset, xr.Dataset) and "data" in dataset:
dataset = dataset.data
warn(
UserWarning(
"Saving the 'data' attribute of 'dataset' as a fallback."
"Result saving for ascii format only supports xarray.DataArray format, "
"please pass a xarray.DataArray instead of a xarray.Dataset "
"(e.g. dataset.data)."
),
stacklevel=4,
)
data_file = (
TimeExplicitFile(filepath=file_name, dataset=dataset)
if file_format is DataFileType.time_explicit
Expand Down
1 change: 1 addition & 0 deletions glotaran/builtin/megacomplexes/clp_guide/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from glotaran.builtin.megacomplexes.clp_guide.clp_guide_megacomplex import ClpGuideMegacomplex
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import numpy as np
import xarray as xr

from glotaran.model import DatasetModel
from glotaran.model import Megacomplex
from glotaran.model import megacomplex


@megacomplex(exclusive=True, register_as="clp-guide", properties={"target": str})
class ClpGuideMegacomplex(Megacomplex):
    """Megacomplex used to guide the clp of a single target compartment.

    Contributes exactly one clp (labeled by ``target``) with a constant
    amplitude of one, so a guide dataset assigned to this megacomplex
    directly constrains the target compartment's clp during optimization.
    """

    def calculate_matrix(
        self,
        dataset_model: DatasetModel,
        indices: dict[str, int],
        **kwargs,
    ):
        """Return the clp labels and the constant (1, 1) guide matrix."""
        # The matrix is a single constant entry; the guide data itself
        # supplies the spectral shape to be matched.
        return [self.target], np.ones((1, 1), dtype=np.float64)

    def index_dependent(self, dataset_model: DatasetModel) -> bool:
        """The guide matrix is constant and thus never index dependent."""
        return False

    def finalize_data(
        self,
        dataset_model: DatasetModel,
        dataset: xr.Dataset,
        is_full_model: bool = False,
        as_global: bool = False,
    ):
        """Nothing to add to the result dataset for a clp guide."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np

from glotaran.analysis.optimize import optimize
from glotaran.analysis.simulation import simulate
from glotaran.builtin.megacomplexes.clp_guide import ClpGuideMegacomplex
from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex
from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import create_gaussian_clp
from glotaran.model import Model
from glotaran.parameter import ParameterGroup
from glotaran.project import Scheme


def test_clp_guide():
    """Optimization with a clp-guide dataset recovers the wanted rates."""
    model_dict = {
        "dataset_groups": {"default": {"link_clp": True}},
        "megacomplex": {
            "mc1": {
                "type": "decay-sequential",
                "compartments": ["s1", "s2"],
                "rates": ["1", "2"],
            },
            "mc2": {"type": "clp-guide", "dimension": "time", "target": "s1"},
        },
        "dataset": {
            "dataset1": {"megacomplex": ["mc1"]},
            "dataset2": {"megacomplex": ["mc2"]},
        },
    }
    model = Model.from_dict(
        model_dict,
        megacomplex_types={
            "decay-sequential": DecaySequentialMegacomplex,
            "clp-guide": ClpGuideMegacomplex,
        },
    )

    # Start slightly off the wanted values; the third parameter stays fixed.
    start_parameters = ParameterGroup.from_list(
        [101e-5, 501e-4, [1, {"vary": False, "non-negative": False}]]
    )
    target_parameters = ParameterGroup.from_list(
        [101e-4, 501e-3, [1, {"vary": False, "non-negative": False}]]
    )

    simulation_axes = {
        "time": np.arange(0, 50, 1.5),
        "pixel": np.arange(600, 750, 5),
    }

    clp = create_gaussian_clp(
        ["s1", "s2"], [7, 30], [620, 720], [10, 50], simulation_axes["pixel"]
    )

    # dataset2 is the guide: the s1 spectrum re-labeled onto the model dimension.
    data = {
        "dataset1": simulate(model, "dataset1", target_parameters, simulation_axes, clp),
        "dataset2": clp.sel(clp_label=["s1"]).rename(clp_label="time"),
    }

    result = optimize(
        Scheme(
            model=model,
            parameters=start_parameters,
            data=data,
            maximum_number_function_evaluations=20,
        )
    )
    print(result.optimized_parameters)

    for label, parameter in result.optimized_parameters.all():
        assert np.allclose(parameter.value, target_parameters.get(label).value, rtol=1e-1)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from glotaran.project import Scheme


def _create_gaussian_clp(labels, amplitudes, centers, widths, axis):
def create_gaussian_clp(labels, amplitudes, centers, widths, axis):
return xr.DataArray(
[
amplitudes[i] * np.exp(-np.log(2) * np.square(2 * (axis - centers[i]) / widths[i]))
Expand Down Expand Up @@ -179,9 +179,7 @@ class ThreeComponentParallel:

axis = {"time": time, "pixel": pixel}

clp = _create_gaussian_clp(
["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel
)
clp = create_gaussian_clp(["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel)


class ThreeComponentSequential:
Expand Down Expand Up @@ -240,9 +238,7 @@ class ThreeComponentSequential:
pixel = np.arange(600, 750, 10)
axis = {"time": time, "pixel": pixel}

clp = _create_gaussian_clp(
["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel
)
clp = create_gaussian_clp(["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel)


@pytest.mark.parametrize(
Expand Down
12 changes: 12 additions & 0 deletions glotaran/builtin/megacomplexes/decay/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@


def index_dependent(dataset_model: DatasetModel) -> bool:
    """Determine if a dataset_model is index dependent.

    Parameters
    ----------
    dataset_model : DatasetModel
        A dataset model instance.

    Returns
    -------
    bool
        True if the dataset model has a multi-Gaussian IRF which is itself
        index dependent (e.g. has dispersion), False otherwise.
    """
    irf = dataset_model.irf
    if not isinstance(irf, IrfMultiGaussian):
        return False
    return irf.is_index_dependent()
Expand Down
Loading