Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix using x_pos in all model functions #1650

Merged
merged 6 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,21 @@ class Model : public AbstractModel, public ModelDimensions {
* stateIsNonNegative
*/
const_N_Vector computeX_pos(const_N_Vector x);

/**
* @brief Compute non-negative state vector.
*
* Compute non-negative state vector according to stateIsNonNegative.
* If anyStateNonNegative is set to `false`, i.e., all entries in
* stateIsNonNegative are `false`, this function directly returns `x`,
* otherwise all entries of x are copied in to `amici::Model::x_pos_tmp_`
* and negative values are replaced by `0` where applicable.
*
* @param x State vector possibly containing negative values
* @return State vector with negative values replaced by `0` according to
* stateIsNonNegative
*/
const realtype *computeX_pos(AmiVector const& x);

/** All variables necessary for function evaluation */
ModelState state_;
Expand Down
111 changes: 69 additions & 42 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,10 @@ void Model::requireSensitivitiesForAllParameters() {
initializeVectors();
}

void Model::getExpression(gsl::span<realtype> w, const realtype t, const AmiVector &x)
void Model::getExpression(gsl::span<realtype> w, const realtype t,
const AmiVector &x)
{
fw(t, x.data());
fw(t, computeX_pos(x));
writeSlice(derived_state_.w_, w);
}

Expand Down Expand Up @@ -922,9 +923,9 @@ void Model::getEventSensitivity(gsl::span<realtype> sz, const int ie,
const realtype t, const AmiVector &x,
const AmiVectorArray &sx) {
for (int ip = 0; ip < nplist(); ip++) {
fsz(&sz[ip * nz], ie, t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), sx.data(ip),
plist(ip));
fsz(&sz[ip * nz], ie, t, computeX_pos(x),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), sx.data(ip), plist(ip));
}
}

Expand All @@ -949,8 +950,9 @@ void Model::getEventRegularizationSensitivity(gsl::span<realtype> srz,
const AmiVector &x,
const AmiVectorArray &sx) {
for (int ip = 0; ip < nplist(); ip++) {
fsrz(&srz[ip * nz], ie, t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), sx.data(ip),
fsrz(&srz[ip * nz], ie, t, computeX_pos(x),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), sx.data(ip),
plist(ip));
}
}
Expand Down Expand Up @@ -1072,29 +1074,31 @@ void Model::getEventTimeSensitivity(std::vector<realtype> &stau,
std::fill(stau.begin(), stau.end(), 0.0);

for (int ip = 0; ip < nplist(); ip++) {
fstau(&stau.at(ip), t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), sx.data(ip),
fstau(&stau.at(ip), t, computeX_pos(x),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), sx.data(ip),
plist(ip), ie);
}
}

void Model::addStateEventUpdate(AmiVector &x, const int ie, const realtype t,
const AmiVector &xdot,
const AmiVector &xdot_old) {

derived_state_.deltax_.assign(nx_solver, 0.0);

// compute update
fdeltax(derived_state_.deltax_.data(), t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), ie, xdot.data(),
xdot_old.data());
fdeltax(derived_state_.deltax_.data(), t, computeX_pos(x),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), ie, xdot.data(), xdot_old.data());

if (always_check_finite_) {
app->checkFinite(derived_state_.deltax_, "deltax");
}

// update
amici_daxpy(nx_solver, 1.0, derived_state_.deltax_.data(), 1, x.data(), 1);
amici_daxpy(nx_solver, 1.0, derived_state_.deltax_.data(), 1, x.data(),
dweindl marked this conversation as resolved.
Show resolved Hide resolved
1);
}

void Model::addStateSensitivityEventUpdate(AmiVectorArray &sx, const int ie,
Expand All @@ -1111,7 +1115,8 @@ void Model::addStateSensitivityEventUpdate(AmiVectorArray &sx, const int ie,

// compute update
fdeltasx(derived_state_.deltasx_.data(), t, x_old.data(),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.unscaledParameters.data(),
state_.fixedParameters.data(),
state_.h.data(), derived_state_.w_.data(), plist(ip), ie,
xdot.data(), xdot_old.data(), sx.data(ip), &stau.at(ip));

Expand All @@ -1132,7 +1137,7 @@ void Model::addAdjointStateEventUpdate(AmiVector &xB, const int ie,
derived_state_.deltaxB_.assign(nx_solver, 0.0);

// compute update
fdeltaxB(derived_state_.deltaxB_.data(), t, x.data(),
fdeltaxB(derived_state_.deltaxB_.data(), t, computeX_pos(x),
state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), ie, xdot.data(),
xdot_old.data(), xB.data());
Expand All @@ -1154,7 +1159,8 @@ void Model::addAdjointQuadratureEventUpdate(
for (int ip = 0; ip < nplist(); ip++) {
derived_state_.deltaqB_.assign(nJ, 0.0);

fdeltaqB(derived_state_.deltaqB_.data(), t, x.data(), state_.unscaledParameters.data(),
fdeltaqB(derived_state_.deltaqB_.data(), t, computeX_pos(x),
state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), plist(ip), ie,
xdot.data(), xdot_old.data(), xB.data());

Expand Down Expand Up @@ -1217,11 +1223,13 @@ void Model::fx0(AmiVector &x) {
void Model::fx0_fixedParameters(AmiVector &x) {
if (!getReinitializeFixedParameterInitialStates())
return;

/* we transform to the unreduced states x_rdata and then apply
x0_fixedparameters to (i) enable updates to states that were removed from
conservation laws and (ii) be able to correctly compute total abundances
after updating the state variables */
fx_rdata(derived_state_.x_rdata_.data(), x.data(), state_.total_cl.data());
fx_rdata(derived_state_.x_rdata_.data(), computeX_pos(x),
state_.total_cl.data());
fx0_fixedParameters(derived_state_.x_rdata_.data(),
simulation_parameters_.tstart_,
state_.unscaledParameters.data(),
Expand All @@ -1242,7 +1250,7 @@ void Model::fsx0(AmiVectorArray &sx, const AmiVector &x) {
std::fill(derived_state_.sx_rdata_.begin(),
derived_state_.sx_rdata_.end(), 0.0);
fsx0(derived_state_.sx_rdata_.data(), simulation_parameters_.tstart_,
x.data(), state_.unscaledParameters.data(),
computeX_pos(x), state_.unscaledParameters.data(),
state_.fixedParameters.data(), plist(ip));
fsx_solver(sx.data(ip), derived_state_.sx_rdata_.data());
fstotal_cl(stcl, derived_state_.sx_rdata_.data(), plist(ip));
Expand All @@ -1258,7 +1266,8 @@ void Model::fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x) {
stcl = &state_.stotal_cl.at(plist(ip) * ncl());
fsx_rdata(derived_state_.sx_rdata_.data(), sx.data(ip), stcl, plist(ip));
fsx0_fixedParameters(derived_state_.sx_rdata_.data(),
simulation_parameters_.tstart_, x.data(),
simulation_parameters_.tstart_,
computeX_pos(x),
state_.unscaledParameters.data(),
state_.fixedParameters.data(),
plist(ip),
Expand All @@ -1271,7 +1280,7 @@ void Model::fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x) {
void Model::fsdx0() {}

void Model::fx_rdata(AmiVector &x_rdata, const AmiVector &x) {
fx_rdata(x_rdata.data(), x.data(), state_.total_cl.data());
fx_rdata(x_rdata.data(), computeX_pos(x), state_.total_cl.data());
if (always_check_finite_)
checkFinite(x_rdata.getVector(), "x_rdata");
}
Expand Down Expand Up @@ -1336,12 +1345,14 @@ void Model::initializeVectors() {
void Model::fy(const realtype t, const AmiVector &x) {
if (!ny)
return;

auto x_pos = computeX_pos(x);

derived_state_.y_.assign(ny, 0.0);

fw(t, x.data());
fy(derived_state_.y_.data(), t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(),
fw(t, x_pos);
fy(derived_state_.y_.data(), t, x_pos,
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), derived_state_.w_.data());

if (always_check_finite_) {
Expand All @@ -1352,20 +1363,22 @@ void Model::fy(const realtype t, const AmiVector &x) {
void Model::fdydp(const realtype t, const AmiVector &x) {
if (!ny)
return;

auto x_pos = computeX_pos(x);

derived_state_.dydp_.assign(ny * nplist(), 0.0);
fw(t, x.data());
fdwdp(t, x.data());
fw(t, x_pos);
fdwdp(t, x_pos);

/* get dydp slice (ny) for current time and parameter */
for (int ip = 0; ip < nplist(); ip++)
if (pythonGenerated) {
fdydp(&derived_state_.dydp_.at(ip * ny), t, x.data(),
fdydp(&derived_state_.dydp_.at(ip * ny), t, x_pos,
state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), plist(ip),
derived_state_.w_.data(), state_.stotal_cl.data());
} else {
fdydp(&derived_state_.dydp_.at(ip * ny), t, x.data(),
fdydp(&derived_state_.dydp_.at(ip * ny), t, x_pos,
state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), plist(ip),
derived_state_.w_.data(), derived_state_.dwdp_.data());
Expand All @@ -1379,14 +1392,17 @@ void Model::fdydp(const realtype t, const AmiVector &x) {
void Model::fdydx(const realtype t, const AmiVector &x) {
if (!ny)
return;

auto x_pos = computeX_pos(x);

derived_state_.dydx_.assign(ny * nx_solver, 0.0);

fw(t, x.data());
fdwdx(t, x.data());
fdydx(derived_state_.dydx_.data(), t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(),
derived_state_.w_.data(), derived_state_.dwdx_.data());
fw(t, x_pos);
fdwdx(t, x_pos);
fdydx(derived_state_.dydx_.data(), t, x_pos,
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), derived_state_.w_.data(),
derived_state_.dwdx_.data());

if (always_check_finite_) {
app->checkFinite(derived_state_.dydx_, "dydx");
Expand Down Expand Up @@ -1612,8 +1628,9 @@ void Model::fz(const int ie, const realtype t, const AmiVector &x) {

derived_state_.z_.assign(nz, 0.0);

fz(derived_state_.z_.data(), ie, t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data());
fz(derived_state_.z_.data(), ie, t, computeX_pos(x),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data());
}

void Model::fdzdp(const int ie, const realtype t, const AmiVector &x) {
Expand All @@ -1623,7 +1640,7 @@ void Model::fdzdp(const int ie, const realtype t, const AmiVector &x) {
derived_state_.dzdp_.assign(nz * nplist(), 0.0);

for (int ip = 0; ip < nplist(); ip++) {
fdzdp(derived_state_.dzdp_.data(), ie, t, x.data(),
fdzdp(derived_state_.dzdp_.data(), ie, t, computeX_pos(x),
state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), plist(ip));
}
Expand All @@ -1639,8 +1656,9 @@ void Model::fdzdx(const int ie, const realtype t, const AmiVector &x) {

derived_state_.dzdx_.assign(nz * nx_solver, 0.0);

fdzdx(derived_state_.dzdx_.data(), ie, t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data());
fdzdx(derived_state_.dzdx_.data(), ie, t, computeX_pos(x),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data());

if (always_check_finite_) {
app->checkFinite(derived_state_.dzdx_, "dzdx");
Expand All @@ -1651,7 +1669,7 @@ void Model::frz(const int ie, const realtype t, const AmiVector &x) {

derived_state_.rz_.assign(nz, 0.0);

frz(derived_state_.rz_.data(), ie, t, x.data(),
frz(derived_state_.rz_.data(), ie, t, computeX_pos(x),
state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data());
}
Expand All @@ -1663,7 +1681,7 @@ void Model::fdrzdp(const int ie, const realtype t, const AmiVector &x) {
derived_state_.drzdp_.assign(nz * nplist(), 0.0);

for (int ip = 0; ip < nplist(); ip++) {
fdrzdp(derived_state_.drzdp_.data(), ie, t, x.data(),
fdrzdp(derived_state_.drzdp_.data(), ie, t, computeX_pos(x),
state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), plist(ip));
}
Expand All @@ -1679,8 +1697,9 @@ void Model::fdrzdx(const int ie, const realtype t, const AmiVector &x) {

derived_state_.drzdx_.assign(nz * nx_solver, 0.0);

fdrzdx(derived_state_.drzdx_.data(), ie, t, x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data());
fdrzdx(derived_state_.drzdx_.data(), ie, t, computeX_pos(x),
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data());

if (always_check_finite_) {
app->checkFinite(derived_state_.drzdx_, "drzdx");
Expand Down Expand Up @@ -2103,6 +2122,14 @@ const_N_Vector Model::computeX_pos(const_N_Vector x) {
return x;
}

const realtype *Model::computeX_pos(AmiVector const& x) {
if (any_state_non_negative_) {
computeX_pos(x.getNVector());
return derived_state_.x_pos_tmp_.data();
}
return x.data();
}

void Model::setReinitializationStateIdxs(std::vector<int> const& idxs)
{
for(auto idx: idxs) {
Expand Down