From b9c065010f8b5ba5ff6c9296c99ec6debcccc02a Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 21 Nov 2022 15:09:05 +0100 Subject: [PATCH] Add operator== for ExpData (#1881) Adds `operator==` for `amici::ExpData` and fixes issues for other `operator==` in case of NaNs in arrays. Closes #1880 --- include/amici/edata.h | 21 +++++++++++++++++++++ include/amici/misc.h | 22 ++++++++++++++++++++++ include/amici/model_state.h | 10 ++++++++++ src/model.cpp | 8 +++----- src/simulation_parameters.cpp | 26 +++++++++++++++----------- tests/cpp/unittests/testExpData.cpp | 11 +++++++++++ tests/cpp/unittests/testMisc.cpp | 10 ++++++++++ 7 files changed, 92 insertions(+), 16 deletions(-) diff --git a/include/amici/edata.h b/include/amici/edata.h index 406e5fdf66..b15e34fc10 100644 --- a/include/amici/edata.h +++ b/include/amici/edata.h @@ -115,6 +115,8 @@ class ExpData : public SimulationParameters { ~ExpData() = default; + friend inline bool operator==(const ExpData& lhs, const ExpData& rhs); + /** * @brief number of observables of the non-augmented model * @@ -455,6 +457,25 @@ class ExpData : public SimulationParameters { std::vector observed_events_std_dev_; }; + +inline bool operator==(const ExpData& lhs, const ExpData& rhs) { + return *dynamic_cast< const SimulationParameters* >(&lhs) + == *dynamic_cast< const SimulationParameters* >(&rhs) + && lhs.id == rhs.id + && lhs.nytrue_ == rhs.nytrue_ + && lhs.nztrue_ == rhs.nztrue_ + && lhs.nmaxevent_ == rhs.nmaxevent_ + && is_equal(lhs.observed_data_, + rhs.observed_data_) + && is_equal(lhs.observed_data_std_dev_, + rhs.observed_data_std_dev_) + && is_equal(lhs.observed_events_, + rhs.observed_events_) + && is_equal(lhs.observed_events_std_dev_, + rhs.observed_events_std_dev_); +}; + + /** * @brief checks input vector of sigmas for not strictly positive values * diff --git a/include/amici/misc.h b/include/amici/misc.h index 1144ea85a0..fe9fb5b8e2 100644 --- a/include/amici/misc.h +++ b/include/amici/misc.h @@ -221,6 +221,28 @@ class ContextManager{ auto unravel_index(size_t flat_idx, size_t num_cols) -> std::pair; +/** + * @brief Check if two spans are equal, treating NaNs in the same position as + * equal. + * @param a + * @param b + * @return Whether the contents of the two spans are equal. + */ +template +bool is_equal(T const& a, T const& b) { + if(a.size() != b.size()) + return false; + + auto a_data = a.data(); + auto b_data = b.data(); + for(typename T::size_type i = 0; i < a.size(); ++i) { + if(a_data[i] != b_data[i] + && !(std::isnan(a_data[i]) && std::isnan(b_data[i]))) + return false; + } + return true; +} + } // namespace amici #endif // AMICI_MISC_H diff --git a/include/amici/model_state.h b/include/amici/model_state.h index 492f01f1a0..e7695a70be 100644 --- a/include/amici/model_state.h +++ b/include/amici/model_state.h @@ -4,6 +4,7 @@ #include "amici/defines.h" #include "amici/sundials_matrix_wrapper.h" #include "amici/model_dimensions.h" +#include "amici/misc.h" #include @@ -45,6 +46,15 @@ struct ModelState { std::vector plist; }; +inline bool operator==(const ModelState &a, const ModelState &b) { + return is_equal(a.h, b.h) + && is_equal(a.total_cl, b.total_cl) + && is_equal(a.stotal_cl, b.stotal_cl) + && is_equal(a.unscaledParameters, b.unscaledParameters) + && is_equal(a.fixedParameters, b.fixedParameters) + && a.plist == b.plist; +} + /** * @brief Storage for `amici::Model` quantities computed based on diff --git a/src/model.cpp b/src/model.cpp index a696437e4b..9fea9942c5 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -244,16 +244,14 @@ bool operator==(const Model &a, const Model &b) { == static_cast(b)) && (a.o2mode == b.o2mode) && (a.z2event_ == b.z2event_) && (a.idlist == b.idlist) && - (a.state_.h == b.state_.h) && - (a.state_.unscaledParameters == b.state_.unscaledParameters) && (a.simulation_parameters_ == b.simulation_parameters_) && - (a.state_.fixedParameters == b.state_.fixedParameters) && - (a.state_.plist == b.state_.plist) && (a.x0data_ == b.x0data_) && + (a.x0data_ == b.x0data_) && (a.sx0data_ == b.sx0data_) && (a.nmaxevent_ == b.nmaxevent_) && (a.state_is_non_negative_ == b.state_is_non_negative_) && (a.sigma_res_ == b.sigma_res_) && - (a.min_sigma_ == b.min_sigma_); + (a.min_sigma_ == b.min_sigma_) + && a.state_ == b.state_; } bool operator==(const ModelDimensions &a, const ModelDimensions &b) { diff --git a/src/simulation_parameters.cpp b/src/simulation_parameters.cpp index 824f5e2541..990e872dfd 100644 --- a/src/simulation_parameters.cpp +++ b/src/simulation_parameters.cpp @@ -1,21 +1,25 @@ #include "amici/simulation_parameters.h" +#include "amici/misc.h" #include namespace amici { bool operator==(const SimulationParameters &a, const SimulationParameters &b) { - return (a.fixedParameters == b.fixedParameters) && - (a.fixedParametersPreequilibration == b.fixedParametersPreequilibration) && - (a.fixedParametersPresimulation == b.fixedParametersPresimulation) && - (a.parameters == b.parameters) && - (a.plist == b.plist) && - (a.pscale == b.pscale) && - (a.reinitializeFixedParameterInitialStates == b.reinitializeFixedParameterInitialStates) && - (a.sx0 == b.sx0) && - (a.t_presim == b.t_presim) && - (a.tstart_ == b.tstart_) && - (a.ts_ == b.ts_); + return is_equal(a.fixedParameters, b.fixedParameters) && + is_equal(a.fixedParametersPreequilibration, + b.fixedParametersPreequilibration) && + is_equal(a.fixedParametersPresimulation, + b.fixedParametersPresimulation) && + is_equal(a.parameters, b.parameters) && + (a.plist == b.plist) && + (a.pscale == b.pscale) && + (a.reinitializeFixedParameterInitialStates + == b.reinitializeFixedParameterInitialStates) && + is_equal(a.sx0, b.sx0) && + (a.t_presim == b.t_presim) && + (a.tstart_ == b.tstart_) && + (a.ts_ == b.ts_); } void SimulationParameters::reinitializeAllFixedParameterDependentInitialStatesForPresimulation(int nx_rdata) diff --git a/tests/cpp/unittests/testExpData.cpp b/tests/cpp/unittests/testExpData.cpp index 5a77c8ad00..83c6dda740 100644 --- a/tests/cpp/unittests/testExpData.cpp +++ b/tests/cpp/unittests/testExpData.cpp @@ -182,6 +182,17 @@ TEST_F(ExpDataTest, CopyConstructable) "ts"); } + +TEST_F(ExpDataTest, Equality) +{ + auto edata = ExpData(testModel); + auto edata2(edata); + ASSERT_TRUE(edata == edata2); + + edata2.id = "different"; + ASSERT_FALSE(edata == edata2); +} + TEST_F(ExpDataTest, DimensionChecks) { std::vector bad_std(ny, -0.1); diff --git a/tests/cpp/unittests/testMisc.cpp b/tests/cpp/unittests/testMisc.cpp index d77aa54f93..a1763c399e 100644 --- a/tests/cpp/unittests/testMisc.cpp +++ b/tests/cpp/unittests/testMisc.cpp @@ -711,4 +711,14 @@ TEST(ReturnCodeToStr, ReturnCodeToStr) simulation_status_to_str(AMICI_UNRECOVERABLE_ERROR)); } +TEST(SpanEqual, SpanEqual) +{ + std::vector a {1, 2, 3}; + std::vector b {1, 2, NAN}; + + EXPECT_TRUE(is_equal(a, a)); + EXPECT_TRUE(is_equal(b, b)); + EXPECT_FALSE(is_equal(a, b)); +} + } // namespace