diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f486dd6..d60a6fe5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed --- +- Use `enum_dispatch` 0.3.9 (updated from 0.3.7) crate to implement `LnPriorTrait` for `LnPrior` https://github.com/light-curve/light-curve-feature/pull/6 ### Deprecated diff --git a/Cargo.toml b/Cargo.toml index a8d317b2..89f5ac22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ anyhow = "<1.0.49" conv = "^0.3.3" emcee = "^0.3.0" emcee_rand = { version = "^0.3.15", package = "rand" } -enum_dispatch = "^0.3.7" +enum_dispatch = "^0.3.9" fftw = { version = "^0.7", default-features = false } GSL = { version = "^6", default-features = false, optional = true } itertools = "^0.10" diff --git a/src/nl_fit/prior/ln_prior.rs b/src/nl_fit/prior/ln_prior.rs index d2452656..7efe7301 100644 --- a/src/nl_fit/prior/ln_prior.rs +++ b/src/nl_fit/prior/ln_prior.rs @@ -1,15 +1,18 @@ use crate::nl_fit::prior::ln_prior_1d::{LnPrior1D, LnPrior1DTrait}; +use enum_dispatch::enum_dispatch; use schemars::JsonSchema; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::fmt::Debug; +#[enum_dispatch] pub trait LnPriorTrait: Clone + Debug + Serialize + DeserializeOwned { fn ln_prior(&self, params: &[f64; NPARAMS]) -> f64; } /// Natural logarithm of prior for non-linear curve-fit problem +#[enum_dispatch(LnPriorTrait)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[non_exhaustive] pub enum LnPrior { @@ -55,29 +58,6 @@ impl LnPrior { } } -// Looks like enum_dispatch doesn't work with const generics yet -// https://gitlab.com/antonok/enum_dispatch/-/issues/51 -impl LnPriorTrait for LnPrior { - fn ln_prior(&self, params: &[f64; NPARAMS]) -> f64 { - match self { - Self::None(x) => x.ln_prior(params), - Self::IndComponents(x) => x.ln_prior(params), - } - } -} - -impl From for LnPrior { - fn from(item: NoneLnPrior) -> Self { - Self::None(item) - } -} - -impl From> for LnPrior { - fn from(item: IndComponentsLnPrior) -> Self { - Self::IndComponents(item) - } -} - #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct NoneLnPrior {}