Skip to content

Commit

Permalink
✨ Damped Oscillation Megacomplex (#764)
Browse files Browse the repository at this point in the history
* Added DampedOscillationMegacomplex
* Added damped-oscillation to setup.cfg plugins

Co-authored-by: Joris Snellenburg <[email protected]>
  • Loading branch information
joernweissenborn and jsnel committed Sep 16, 2021
1 parent 4327184 commit 9336486
Show file tree
Hide file tree
Showing 13 changed files with 725 additions and 45 deletions.
42 changes: 31 additions & 11 deletions glotaran/analysis/test/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,23 @@ def index_dependent(self, dataset_model):

class SimpleTestModel(Model):
@classmethod
def from_dict(cls, model_dict):
def from_dict(
cls,
model_dict,
*,
megacomplex_types: dict[str, type[Megacomplex]] | None = None,
default_megacomplex_type: str | None = None,
):
defaults: dict[str, type[Megacomplex]] = {
"model_complex": SimpleTestMegacomplex,
"global_complex": SimpleTestMegacomplexGlobal,
}
if megacomplex_types is not None:
defaults.update(megacomplex_types)
return super().from_dict(
model_dict,
megacomplex_types={
"model_complex": SimpleTestMegacomplex,
"global_complex": SimpleTestMegacomplexGlobal,
},
megacomplex_types=defaults,
default_megacomplex_type=default_megacomplex_type,
)


Expand Down Expand Up @@ -157,14 +167,24 @@ def finalize_data(

class DecayModel(Model):
@classmethod
def from_dict(cls, model_dict):
def from_dict(
cls,
model_dict,
*,
megacomplex_types: dict[str, type[Megacomplex]] | None = None,
default_megacomplex_type: str | None = None,
):
defaults: dict[str, type[Megacomplex]] = {
"model_complex": SimpleKineticMegacomplex,
"global_complex": SimpleSpectralMegacomplex,
"global_complex_shaped": ShapedSpectralMegacomplex,
}
if megacomplex_types is not None:
defaults.update(megacomplex_types)
return super().from_dict(
model_dict,
megacomplex_types={
"model_complex": SimpleKineticMegacomplex,
"global_complex": SimpleSpectralMegacomplex,
"global_complex_shaped": ShapedSpectralMegacomplex,
},
megacomplex_types=defaults,
default_megacomplex_type=default_megacomplex_type,
)


Expand Down
21 changes: 12 additions & 9 deletions glotaran/analysis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,25 @@ def calculate_matrix(
clp_labels = this_clp_labels
matrix = this_matrix
else:
tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels]
tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64)
for idx, label in enumerate(tmp_clp_labels):
if label in clp_labels:
tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)]
if label in this_clp_labels:
tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)]
clp_labels = tmp_clp_labels
matrix = tmp_matrix
clp_labels, matrix = combine_matrix(matrix, this_matrix, clp_labels, this_clp_labels)

if as_global_model:
dataset_model.swap_dimensions()

return CalculatedMatrix(clp_labels, matrix)


def combine_matrix(matrix, this_matrix, clp_labels, this_clp_labels):
tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels]
tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64)
for idx, label in enumerate(tmp_clp_labels):
if label in clp_labels:
tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)]
if label in this_clp_labels:
tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)]
return tmp_clp_labels, tmp_matrix


@nb.jit(nopython=True, parallel=True)
def apply_weight(matrix, weight):
for i in nb.prange(matrix.shape[1]):
Expand Down
3 changes: 3 additions & 0 deletions glotaran/builtin/megacomplexes/damped_oscillation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from glotaran.builtin.megacomplexes.damped_oscillation.damped_oscillation_megacomplex import (
DampedOscillationMegacomplex,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from __future__ import annotations

from typing import List

import numba as nb
import numpy as np
import xarray as xr
from scipy.special import erf

from glotaran.builtin.megacomplexes.decay.irf import Irf
from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian
from glotaran.model import DatasetModel
from glotaran.model import Megacomplex
from glotaran.model import Model
from glotaran.model import megacomplex
from glotaran.model.item import model_item_validator
from glotaran.parameter import Parameter


@megacomplex(
dimension="time",
dataset_model_items={
"irf": {"type": Irf, "allow_none": True},
},
properties={
"labels": List[str],
"frequencies": List[Parameter],
"rates": List[Parameter],
},
register_as="damped-oscillation",
)
class DampedOscillationMegacomplex(Megacomplex):
@model_item_validator(False)
def ensure_oscillation_paramater(self, model: Model) -> list[str]:

problems = []

if len(self.labels) != len(self.frequencies) or len(self.labels) != len(self.rates):
problems.append(
f"Size of labels ({len(self.labels)}), frequencies ({len(self.frequencies)}) "
f"and rates ({len(self.rates)}) does not match for damped oscillation "
f"megacomplex '{self.label}'."
)

return problems

def calculate_matrix(
self,
dataset_model: DatasetModel,
indices: dict[str, int],
**kwargs,
):

clp_label = [f"{label}_cos" for label in self.labels] + [
f"{label}_sin" for label in self.labels
]

model_axis = dataset_model.get_model_axis()
delta = np.abs(model_axis[1:] - model_axis[:-1])
delta_min = delta[np.argmin(delta)]
frequency_max = 1 / (2 * 0.03 * delta_min)
frequencies = np.array(self.frequencies) * 0.03 * 2 * np.pi
frequencies[frequencies >= frequency_max] = np.mod(
frequencies[frequencies >= frequency_max], frequency_max
)
rates = np.array(self.rates)

matrix = np.ones((model_axis.size, len(clp_label)), dtype=np.float64)

if dataset_model.irf is None:
calculate_damped_oscillation_matrix_no_irf(matrix, frequencies, rates, model_axis)
elif isinstance(dataset_model.irf, IrfMultiGaussian):
global_dimension = dataset_model.get_global_dimension()
global_axis = dataset_model.get_global_axis()
global_index = indices.get(global_dimension)
centers, widths, scales, shift, _, _ = dataset_model.irf.parameter(
global_index, global_axis
)
for center, width, scale in zip(centers, widths, scales):
matrix += calculate_damped_oscillation_matrix_gaussian_irf(
frequencies,
rates,
model_axis,
center,
width,
shift,
scale,
)
matrix /= np.sum(scales)

return clp_label, matrix

def index_dependent(self, dataset_model: DatasetModel) -> bool:
return (
isinstance(dataset_model.irf, IrfMultiGaussian)
and dataset_model.irf.is_index_dependent()
)

def finalize_data(
self,
dataset_model: DatasetModel,
dataset: xr.Dataset,
is_full_model: bool = False,
as_global: bool = False,
):
if is_full_model:
return

megacomplexes = (
dataset_model.global_megacomplex if is_full_model else dataset_model.megacomplex
)
unique = len([m for m in megacomplexes if isinstance(m, DampedOscillationMegacomplex)]) < 2

prefix = "damped_oscillation" if unique else f"{self.label}_damped_oscillation"

dataset.coords[f"{prefix}"] = self.labels
dataset.coords[f"{prefix}_frequency"] = (prefix, self.frequencies)
dataset.coords[f"{prefix}_rate"] = (prefix, self.rates)

dim1 = dataset_model.get_global_axis().size
dim2 = len(self.labels)
doas = np.zeros((dim1, dim2), dtype=np.float64)
phase = np.zeros((dim1, dim2), dtype=np.float64)
for i, label in enumerate(self.labels):
sin = dataset.clp.sel(clp_label=f"{label}_sin")
cos = dataset.clp.sel(clp_label=f"{label}_cos")
doas[:, i] = np.sqrt(sin * sin + cos * cos)
phase[:, i] = np.unwrap(np.arctan2(sin, cos))

dataset[f"{prefix}_associated_spectra"] = (
(dataset_model.get_global_dimension(), prefix),
doas,
)

dataset[f"{prefix}_phase"] = (
(dataset_model.get_global_dimension(), prefix),
phase,
)

if not is_full_model:
if self.index_dependent(dataset_model):
dataset[f"{prefix}_sin"] = (
(
dataset_model.get_global_dimension(),
dataset_model.get_model_dimension(),
prefix,
),
dataset.matrix.sel(clp_label=[f"{label}_sin" for label in self.labels]).values,
)

dataset[f"{prefix}_cos"] = (
(
dataset_model.get_global_dimension(),
dataset_model.get_model_dimension(),
prefix,
),
dataset.matrix.sel(clp_label=[f"{label}_cos" for label in self.labels]).values,
)
else:
dataset[f"{prefix}_sin"] = (
(dataset_model.get_model_dimension(), prefix),
dataset.matrix.sel(clp_label=[f"{label}_sin" for label in self.labels]).values,
)

dataset[f"{prefix}_cos"] = (
(dataset_model.get_model_dimension(), prefix),
dataset.matrix.sel(clp_label=[f"{label}_cos" for label in self.labels]).values,
)


@nb.jit(nopython=True, parallel=True)
def calculate_damped_oscillation_matrix_no_irf(matrix, frequencies, rates, axis):

idx = 0
for frequency, rate in zip(frequencies, rates):
osc = np.exp(-rate * axis - 1j * frequency * axis)
matrix[:, idx] = osc.real
matrix[:, idx + 1] = osc.imag
idx += 2


def calculate_damped_oscillation_matrix_gaussian_irf(
frequencies: np.ndarray,
rates: np.ndarray,
model_axis: np.ndarray,
center: float,
width: float,
shift: float,
scale: float,
):
shifted_axis = model_axis - center - shift
d = width ** 2
k = rates + 1j * frequencies
dk = k * d
sqwidth = np.sqrt(2) * width
a = (-1 * shifted_axis[:, None] + 0.5 * dk) * k
a = np.minimum(a, 709)
a = np.exp(a)
b = 1 + erf((shifted_axis[:, None] - dk) / sqwidth)
osc = a * b * scale
return np.concatenate((osc.real, osc.imag), axis=1)
Loading

0 comments on commit 9336486

Please sign in to comment.