Skip to content

Commit

Permalink
More generic API for Rainbow variants (#324)
Browse files Browse the repository at this point in the history
* Initial import from Etienne PR, with some adjustments to make it more generic

* Make parameter and error unscaling more generic

* Implement peak time computation

* Use new generic implementation instead of old RainbowFit

* (Slightly better) decorrelate temperature and amplitude by using peak-normalized Planck function

* Generalize baseline handling w.r.t. bolometric term initial parameters

* Cont. of previous commit

* Do not crash in baseline fitting when some bands do not have data

* Revert the normalization of Planck term back to 'bolometric', so that it works meaningfully with temperature evolution

* Decorrelate Bazin amplitude from rise/fall times by normalizing it to 1

* Pass measurement errors to initial parameter estimators, and use them to better estimate rise/fall times

* Add tests for generic RainbowFit implementation

* Remove now obsolete versions of RainbowFit

* Improve docstrings

* Do not raise an exception if the fitting fails but `get_initial=True`

* Delayed sigmoid temperature term added

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Do not use 3.9+ features

* Improve docstrings and stricter check for parameter unscaling

* Minor fixes requested in the review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Additional fixes requested in the review

* Small docstring for the sigmoid mentioning that it has no peak time properly defined

* Fix linting errors

---------

Co-authored-by: erusseil <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 29, 2024
1 parent 84f276f commit 4128c04
Show file tree
Hide file tree
Showing 12 changed files with 685 additions and 795 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from .bazin import *
from .rising import *
from .symmetric import *
from .generic import *
107 changes: 73 additions & 34 deletions light-curve/light_curve/light_curve_py/features/rainbow/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple

Expand Down Expand Up @@ -143,16 +144,6 @@ def temp_func(self, t, params):
"""Temperature evolution function."""
return NotImplementedError

@abstractmethod
def _normalize_bolometric_flux(self, params) -> None:
"""Normalize bolometric flux parameters to internal units in-place."""
raise NotImplementedError

@abstractmethod
def _denormalize_bolometric_flux(self, params) -> None:
"""Denormalize boloemtric flux parameters from internal units in-place."""
raise NotImplementedError

@abstractmethod
def _unscale_parameters(self, params, t_scaler: Scaler, m_scaler: MultiBandScaler) -> None:
"""Unscale parameters from internal units, in-place.
Expand All @@ -161,6 +152,21 @@ def _unscale_parameters(self, params, t_scaler: Scaler, m_scaler: MultiBandScale
"""
return NotImplementedError

def _unscale_errors(self, errors, t_scaler: Scaler, m_scaler: MultiBandScaler) -> None:
"""Unscale parameter errors from internal units, in-place.
Baseline parameters do not need to be unscaled here.
"""

# We need to modify original scalers to only apply the scale, not shifts, to the errors
# It should be re-implemented in subclasses for a cleaner way to unscale the errors
t_scaler = deepcopy(t_scaler)
m_scaler = deepcopy(m_scaler)
t_scaler.reset_shift()
m_scaler.reset_shift()

return self._unscale_parameters(errors, t_scaler, m_scaler)

def _unscale_baseline_parameters(self, params, m_scaler: MultiBandScaler) -> None:
"""Unscale baseline parameters from internal units, in-place.
Expand All @@ -171,6 +177,16 @@ def _unscale_baseline_parameters(self, params, m_scaler: MultiBandScaler) -> Non
baseline = params[self.p[baseline_name]]
params[self.p[baseline_name]] = m_scaler.undo_shift_scale_band(baseline, band_name)

def _unscale_baseline_errors(self, errors, m_scaler: MultiBandScaler) -> None:
"""Unscale baseline parameters from internal units, in-place.
Must be used only if `with_baseline` is True.
"""
for band_name in self.bands.names:
baseline_name = self.p.baseline_parameter_name(band_name)
baseline = errors[self.p[baseline_name]]
errors[self.p[baseline_name]] = m_scaler.undo_scale_band(baseline, band_name)

@staticmethod
def planck_nu(wave_cm, T):
"""Planck function in frequency units."""
Expand All @@ -184,11 +200,18 @@ def _lsq_model_no_baseline(self, x, *params):
t, _band_idx, wave_cm = x
params = np.array(params)

self._denormalize_bolometric_flux(params)

bol = self.bol_func(t, params)
temp = self.temp_func(t, params)
flux = np.pi * self.planck_nu(wave_cm, temp) / (sigma_sb * temp**4) * bol

# Normalize the Planck function so that the result is of order unity
norm = (sigma_sb * temp**4) / np.pi / self.average_nu # Original "bolometric" normalization
# peak_nu = 2.821 * boltzman_constant * temp / planck_constant # Wien displacement law
# norm = self.planck_nu(speed_of_light / peak_nu, temp) # Peak = 1 normalization

planck = self.planck_nu(wave_cm, temp) / norm

flux = planck * bol

return flux

def _lsq_model_with_baseline(self, x, *params):
Expand All @@ -206,7 +229,6 @@ def model(self, t, band, *params):
band_idx = self.bands.get_index(band)
wave_cm = self.bands.index_to_wave_cm(band_idx)
params = np.array(params)
self._normalize_bolometric_flux(params)
return self._lsq_model((t, band_idx, wave_cm), *params)

@property
Expand All @@ -215,34 +237,41 @@ def names(self):
return list(self.p.__members__)

@abstractmethod
def _initial_guesses(self, t, m, band) -> Dict[str, float]:
def _initial_guesses(self, t, m, sigma, band) -> Dict[str, float]:
"""Initial guesses for the fit parameters.
t and m are *scaled* arrays. No baseline parameters are included.
"""
return NotImplementedError

def _baseline_initial_guesses(self, t, m, band) -> Dict[str, float]:
def _baseline_initial_guesses(self, t, m, sigma, band) -> Dict[str, float]:
"""Initial guesses for the baseline parameters."""
del t
return {self.p.baseline_parameter_name(b): np.min(m[band == b]) for b in self.bands.names}
return {
self.p.baseline_parameter_name(b): (np.median(m[band == b]) if np.sum(band == b) else 0)
for b in self.bands.names
}

@abstractmethod
def _limits(self, t, m, band) -> Dict[str, Tuple[float, float]]:
def _limits(self, t, m, sigma, band) -> Dict[str, Tuple[float, float]]:
"""Limits for the fit parameters.
t and m are *scaled* arrays. No baseline parameters are included.
"""
return NotImplementedError

def _baseline_limits(self, t, m, band) -> Dict[str, Tuple[float, float]]:
def _baseline_limits(self, t, m, sigma, band) -> Dict[str, Tuple[float, float]]:
"""Limits for the baseline parameters."""
del t
limits = {}
for b in self.bands.names:
m_band = m[band == b]
lower = np.min(m_band) - 10 * np.ptp(m_band)
upper = np.max(m_band)
if len(m_band) > 0:
lower = np.min(m_band) - 10 * np.ptp(m_band)
upper = np.max(m_band)
else:
lower = 0
upper = 0
limits[self.p.baseline_parameter_name(b)] = (lower, upper)
return limits

Expand All @@ -254,7 +283,7 @@ def _eval(self, *, t, m, sigma, band):
def _eval_and_fill(self, *, t, m, sigma, band, fill_value):
return super()._eval_and_fill(t=t, m=m, sigma=sigma, band=band, fill_value=fill_value)

def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None):
def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initial=False):
# Initialize data scalers
t_scaler = Scaler.from_time(t)
m_scaler = MultiBandScaler.from_flux(m, band, with_baseline=self.with_baseline)
Expand All @@ -267,11 +296,20 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None):
band_idx = self.bands.get_index(band)
wave_cm = self.bands.index_to_wave_cm(band_idx)

initial_guesses = self._initial_guesses(t, m, band)
limits = self._limits(t, m, band)
if self.with_baseline:
initial_guesses.update(self._baseline_initial_guesses(t, m, band))
limits.update(self._baseline_limits(t, m, band))
initial_baselines = self._baseline_initial_guesses(t, m, sigma, band)
m_corr = m - np.array([initial_baselines[self.p.baseline_parameter_name(b)] for b in band])

# Compute initial guesses for the parameters on baseline-subtracted data
initial_guesses = self._initial_guesses(t, m_corr, sigma, band)
limits = self._limits(t, m_corr, sigma, band)

initial_guesses.update(initial_baselines)
limits.update(self._baseline_limits(t, m, sigma, band))
else:
# Compute initial guesses for the parameters on original data
initial_guesses = self._initial_guesses(t, m, sigma, band)
limits = self._limits(t, m, sigma, band)

least_squares = LeastSquares(
model=self._lsq_model,
Expand All @@ -280,17 +318,22 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None):
y=m,
yerror=sigma,
)
minuit = self.Minuit(least_squares, **initial_guesses)
minuit = self.Minuit(least_squares, name=self.names, **initial_guesses)
# TODO: expose these parameters through function arguments
if print_level is not None:
minuit.print_level = print_level
minuit.strategy = 2
minuit.migrad(ncall=10000, iterate=10)

if not minuit.valid and self.fail_on_divergence:
if not minuit.valid and self.fail_on_divergence and not get_initial:
raise RuntimeError("Fitting failed")

reduced_chi2 = minuit.fval / (len(t) - self.size)

if get_initial:
# Reset the fitter so that it returns initial values instead of final ones
minuit.reset()

params = np.array(minuit.values)
errors = np.array(minuit.errors)

Expand All @@ -299,13 +342,9 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None):
self._unscale_baseline_parameters(params, m_scaler)

# Unscale errors
# We need to modify original scalers to only apply the scale, not shifts, to the errors
t_scaler.reset_shift()
m_scaler.reset_shift()

self._unscale_parameters(errors, t_scaler, m_scaler)
self._unscale_errors(errors, t_scaler, m_scaler)
if self.with_baseline:
self._unscale_baseline_parameters(errors, m_scaler)
self._unscale_baseline_errors(errors, m_scaler)

return np.r_[params, reduced_chi2], errors

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler":
return cls(shift=shift_array, scale=scale, per_band_shift=per_band_shift, per_band_scale=per_band_scale)

def undo_shift_scale_band(self, x, band):
return x * self.per_band_scale[band] + self.per_band_shift[band]
return x * self.per_band_scale.get(band, 1) + self.per_band_shift.get(band, 0)

def undo_scale_band(self, x, band):
return x * self.per_band_scale.get(band, 1)

def reset_shift(self):
"""Resets scaler shift to zero, keeping only the scale"""
Expand Down
118 changes: 0 additions & 118 deletions light-curve/light_curve/light_curve_py/features/rainbow/bazin.py

This file was deleted.

Loading

0 comments on commit 4128c04

Please sign in to comment.