Skip to content

Commit

Permalink
Add operator== for ExpData (#1881)
Browse files Browse the repository at this point in the history
Adds `operator==` for `amici::ExpData` and fixes issues for other  `operator==` in case of NaNs in arrays.


Closes #1880
  • Loading branch information
dweindl authored Nov 21, 2022
1 parent 662e76c commit b9c0650
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 16 deletions.
21 changes: 21 additions & 0 deletions include/amici/edata.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ExpData : public SimulationParameters {

~ExpData() = default;

friend inline bool operator==(const ExpData& lhs, const ExpData& rhs);

/**
* @brief number of observables of the non-augmented model
*
Expand Down Expand Up @@ -455,6 +457,25 @@ class ExpData : public SimulationParameters {
std::vector<realtype> observed_events_std_dev_;
};


inline bool operator==(const ExpData& lhs, const ExpData& rhs) {
return *dynamic_cast< const SimulationParameters* >(&lhs)
== *dynamic_cast< const SimulationParameters* >(&rhs)
&& lhs.id == rhs.id
&& lhs.nytrue_ == rhs.nytrue_
&& lhs.nztrue_ == rhs.nztrue_
&& lhs.nmaxevent_ == rhs.nmaxevent_
&& is_equal(lhs.observed_data_,
rhs.observed_data_)
&& is_equal(lhs.observed_data_std_dev_,
rhs.observed_data_std_dev_)
&& is_equal(lhs.observed_events_,
rhs.observed_events_)
&& is_equal(lhs.observed_events_std_dev_,
rhs.observed_events_std_dev_);
};


/**
* @brief checks input vector of sigmas for not strictly positive values
*
Expand Down
22 changes: 22 additions & 0 deletions include/amici/misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,28 @@ class ContextManager{
auto unravel_index(size_t flat_idx, size_t num_cols)
-> std::pair<size_t, size_t>;

/**
* @brief Check if two spans are equal, treating NaNs in the same position as
* equal.
* @param a
* @param b
* @return Whether the contents of the two spans are equal.
*/
template <class T>
bool is_equal(T const& a, T const& b) {
if(a.size() != b.size())
return false;

auto a_data = a.data();
auto b_data = b.data();
for(typename T::size_type i = 0; i < a.size(); ++i) {
if(a_data[i] != b_data[i]
&& !(std::isnan(a_data[i]) && std::isnan(b_data[i])))
return false;
}
return true;
}

} // namespace amici

#endif // AMICI_MISC_H
10 changes: 10 additions & 0 deletions include/amici/model_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "amici/defines.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/model_dimensions.h"
#include "amici/misc.h"

#include <vector>

Expand Down Expand Up @@ -45,6 +46,15 @@ struct ModelState {
std::vector<int> plist;
};

inline bool operator==(const ModelState &a, const ModelState &b) {
return is_equal(a.h, b.h)
&& is_equal(a.total_cl, b.total_cl)
&& is_equal(a.stotal_cl, b.stotal_cl)
&& is_equal(a.unscaledParameters, b.unscaledParameters)
&& is_equal(a.fixedParameters, b.fixedParameters)
&& a.plist == b.plist;
}


/**
* @brief Storage for `amici::Model` quantities computed based on
Expand Down
8 changes: 3 additions & 5 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,14 @@ bool operator==(const Model &a, const Model &b) {
== static_cast<ModelDimensions const&>(b))
&& (a.o2mode == b.o2mode) &&
(a.z2event_ == b.z2event_) && (a.idlist == b.idlist) &&
(a.state_.h == b.state_.h) &&
(a.state_.unscaledParameters == b.state_.unscaledParameters) &&
(a.simulation_parameters_ == b.simulation_parameters_) &&
(a.state_.fixedParameters == b.state_.fixedParameters) &&
(a.state_.plist == b.state_.plist) && (a.x0data_ == b.x0data_) &&
(a.x0data_ == b.x0data_) &&
(a.sx0data_ == b.sx0data_) &&
(a.nmaxevent_ == b.nmaxevent_) &&
(a.state_is_non_negative_ == b.state_is_non_negative_) &&
(a.sigma_res_ == b.sigma_res_) &&
(a.min_sigma_ == b.min_sigma_);
(a.min_sigma_ == b.min_sigma_)
&& a.state_ == b.state_;
}

bool operator==(const ModelDimensions &a, const ModelDimensions &b) {
Expand Down
26 changes: 15 additions & 11 deletions src/simulation_parameters.cpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
#include "amici/simulation_parameters.h"
#include "amici/misc.h"

#include <numeric>

namespace amici {

bool operator==(const SimulationParameters &a, const SimulationParameters &b) {
return (a.fixedParameters == b.fixedParameters) &&
(a.fixedParametersPreequilibration == b.fixedParametersPreequilibration) &&
(a.fixedParametersPresimulation == b.fixedParametersPresimulation) &&
(a.parameters == b.parameters) &&
(a.plist == b.plist) &&
(a.pscale == b.pscale) &&
(a.reinitializeFixedParameterInitialStates == b.reinitializeFixedParameterInitialStates) &&
(a.sx0 == b.sx0) &&
(a.t_presim == b.t_presim) &&
(a.tstart_ == b.tstart_) &&
(a.ts_ == b.ts_);
return is_equal(a.fixedParameters, b.fixedParameters) &&
is_equal(a.fixedParametersPreequilibration,
b.fixedParametersPreequilibration) &&
is_equal(a.fixedParametersPresimulation,
b.fixedParametersPresimulation) &&
is_equal(a.parameters, b.parameters) &&
(a.plist == b.plist) &&
(a.pscale == b.pscale) &&
(a.reinitializeFixedParameterInitialStates
== b.reinitializeFixedParameterInitialStates) &&
is_equal(a.sx0, b.sx0) &&
(a.t_presim == b.t_presim) &&
(a.tstart_ == b.tstart_) &&
(a.ts_ == b.ts_);
}

void SimulationParameters::reinitializeAllFixedParameterDependentInitialStatesForPresimulation(int nx_rdata)
Expand Down
11 changes: 11 additions & 0 deletions tests/cpp/unittests/testExpData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,17 @@ TEST_F(ExpDataTest, CopyConstructable)
"ts");
}


TEST_F(ExpDataTest, Equality)
{
auto edata = ExpData(testModel);
auto edata2(edata);
ASSERT_TRUE(edata == edata2);

edata2.id = "different";
ASSERT_FALSE(edata == edata2);
}

TEST_F(ExpDataTest, DimensionChecks)
{
std::vector<realtype> bad_std(ny, -0.1);
Expand Down
10 changes: 10 additions & 0 deletions tests/cpp/unittests/testMisc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,4 +711,14 @@ TEST(ReturnCodeToStr, ReturnCodeToStr)
simulation_status_to_str(AMICI_UNRECOVERABLE_ERROR));
}

TEST(SpanEqual, SpanEqual)
{
std::vector<realtype> a {1, 2, 3};
std::vector<realtype> b {1, 2, NAN};

EXPECT_TRUE(is_equal(a, a));
EXPECT_TRUE(is_equal(b, b));
EXPECT_FALSE(is_equal(a, b));
}

} // namespace

0 comments on commit b9c0650

Please sign in to comment.