diff --git a/example_new/KevinHall.json b/example_new/KevinHall.json index 0ae394a1f..c49ecc9d5 100644 --- a/example_new/KevinHall.json +++ b/example_new/KevinHall.json @@ -67,6 +67,43 @@ "Description": "string" } }, + "RuralPrevalence": [ + { + "Name": "Under18", + "Female": 0.65, + "Male": 0.65 + }, + { + "Name": "Over18", + "Female": 0.6, + "Male": 0.6 + } + ], + "IncomeModels": [ + { + "Name": "Category_1", + "Intercept": 1.0, + "Coefficients": {} + }, + { + "Name": "Category_2", + "Intercept": 0.75498278113169, + "Coefficients": { + "Gender": 0.0204338261883629, + "Over18": 0.0486699373325097, + "Sector": -0.706477682886734 + } + }, + { + "Name": "Category_3", + "Intercept": 1.48856946873517, + "Coefficients": { + "Gender": 0.000641255498596025, + "Over18": 0.206622749570132, + "Sector": -1.71313287940798 + } + } + ], "AgeMeanHeight": { "Male": [ 0.498, 0.757, 0.868, 0.952, 1.023, 1.092, 1.155, 1.219, 1.28, 1.333, diff --git a/src/HealthGPS.Console/model_parser.cpp b/src/HealthGPS.Console/model_parser.cpp index ca65fad00..43bc12272 100644 --- a/src/HealthGPS.Console/model_parser.cpp +++ b/src/HealthGPS.Console/model_parser.cpp @@ -129,31 +129,50 @@ std::unique_ptr load_staticlinear_risk_model_definition(const poco::json &opt, const host::Configuration &config) { MEASURE_FUNCTION(); - // Risk factor linear models. - hgps::LinearModelParams models; - for (const auto &factor : opt["RiskFactorModels"]) { - auto factor_key = factor["Name"].get(); - models.intercepts[factor_key] = factor["Intercept"].get(); - models.coefficients[factor_key] = - factor["Coefficients"].get>(); - } - - // Risk factor names and correlation matrix. - std::vector names; + // Risk factor linear models and correlation matrix. + std::vector risk_factor_models; const auto correlations_file_info = host::get_file_info(opt["RiskFactorCorrelationFile"], config.root_path); const auto correlations_table = load_datatable_from_csv(correlations_file_info); Eigen::MatrixXd correlations{correlations_table.num_rows(), correlations_table.num_columns()}; - for (size_t col = 0; col < correlations_table.num_columns(); col++) { - names.emplace_back(correlations_table.column(col).name()); - for (size_t row = 0; row < correlations_table.num_rows(); row++) { - correlations(row, col) = - std::any_cast(correlations_table.column(col).value(row)); + + for (size_t i = 0; i < opt["RiskFactorModels"].size(); i++) { + // Risk factor model. + const auto &factor = opt["RiskFactorModels"][i]; + hgps::LinearModelParams model; + model.name = factor["Name"].get(); + model.intercept = factor["Intercept"].get(); + model.coefficients = + factor["Coefficients"].get>(); + + // Check correlation matrix column name matches risk factor name. + auto column_name = hgps::core::Identifier{correlations_table.column(i).name()}; + if (model.name != column_name) { + throw hgps::core::HgpsException{ + fmt::format("Risk factor {} name ({}) does not match correlation matrix " + "column {} name ({})", + i, model.name.to_string(), i, column_name.to_string())}; } + + // Write data structures. + for (size_t j = 0; j < correlations_table.num_rows(); j++) { + correlations(i, j) = std::any_cast(correlations_table.column(i).value(j)); + } + risk_factor_models.emplace_back(std::move(model)); + } + + // Check correlation matrix column count matches risk factor count. + if (opt["RiskFactorModels"].size() != correlations_table.num_columns()) { + throw hgps::core::HgpsException{ + fmt::format("Risk factor count ({}) does not match correlation " + "matrix column count ({})", + opt["RiskFactorModels"].size(), correlations_table.num_columns())}; } + + // Compute Cholesky decomposition of correlation matrix. auto cholesky = Eigen::MatrixXd{Eigen::LLT{correlations}.matrixL()}; - return std::make_unique(std::move(names), std::move(models), + return std::make_unique(std::move(risk_factor_models), std::move(cholesky)); } @@ -253,25 +272,20 @@ load_ebhlm_risk_model_definition(const poco::json &opt) { std::unique_ptr load_kevinhall_risk_model_definition(const poco::json &opt, const host::Configuration &config) { MEASURE_FUNCTION(); - std::unordered_map energy_equation; - std::unordered_map> nutrient_ranges; - std::unordered_map> - nutrient_equations; - std::unordered_map> food_prices; - std::unordered_map> age_mean_height; // Nutrient groups. + std::unordered_map energy_equation; + std::unordered_map nutrient_ranges; for (const auto &nutrient : opt["Nutrients"]) { auto nutrient_key = nutrient["Name"].get(); - nutrient_ranges[nutrient_key] = nutrient["Range"].get>(); - if (nutrient_ranges[nutrient_key].first > nutrient_ranges[nutrient_key].second) { - throw hgps::core::HgpsException{ - fmt::format("Nutrient range is invalid: {}", nutrient_key.to_string())}; - } + nutrient_ranges[nutrient_key] = nutrient["Range"].get(); energy_equation[nutrient_key] = nutrient["Energy"].get(); } // Food groups. + std::unordered_map> + nutrient_equations; + std::unordered_map> food_prices; for (const auto &food : opt["Foods"]) { auto food_key = food["Name"].get(); food_prices[food_key] = food["Price"].get>(); @@ -291,7 +305,28 @@ load_kevinhall_risk_model_definition(const poco::json &opt, const host::Configur const auto food_data_file_info = host::get_file_info(opt["FoodsDataFile"], config.root_path); const auto food_data_table = load_datatable_from_csv(food_data_file_info); + // Rural sector prevalence for age groups and sex. + std::unordered_map> + rural_prevalence; + for (const auto &age_group : opt["RuralPrevalence"]) { + auto age_group_name = age_group["Name"].get(); + rural_prevalence[age_group_name] = {{hgps::core::Gender::female, age_group["Female"]}, + {hgps::core::Gender::male, age_group["Male"]}}; + } + + // Income models for different income classifications. + std::vector income_models; + for (const auto &factor : opt["IncomeModels"]) { + hgps::LinearModelParams model; + model.name = factor["Name"].get(); + model.intercept = factor["Intercept"].get(); + model.coefficients = + factor["Coefficients"].get>(); + income_models.emplace_back(std::move(model)); + } + // Load M/F average heights for age. + std::unordered_map> age_mean_height; const auto max_age = static_cast(config.settings.age_range.upper()); auto male_height = opt["AgeMeanHeight"]["Male"].get>(); auto female_height = opt["AgeMeanHeight"]["Female"].get>(); @@ -306,7 +341,8 @@ load_kevinhall_risk_model_definition(const poco::json &opt, const host::Configur return std::make_unique( std::move(energy_equation), std::move(nutrient_ranges), std::move(nutrient_equations), - std::move(food_prices), std::move(age_mean_height)); + std::move(food_prices), std::move(rural_prevalence), std::move(income_models), + std::move(age_mean_height)); } std::pair> diff --git a/src/HealthGPS.Core/forward_type.h b/src/HealthGPS.Core/forward_type.h index 4f7852919..8f175fb10 100644 --- a/src/HealthGPS.Core/forward_type.h +++ b/src/HealthGPS.Core/forward_type.h @@ -35,6 +35,18 @@ enum class DiseaseGroup : uint8_t { cancer }; +/// @brief Enumerates sector types +enum class Sector : uint8_t { + /// @brief Unknown sector + unknown, + + /// @brief Urban sector + urban, + + /// @brief Rural sector + rural +}; + /// @brief C++20 concept for numeric columns types template concept Numerical = std::is_arithmetic_v; @@ -60,4 +72,5 @@ class DoubleDataTableColumn; class IntegerDataTableColumn; class DataTableColumnVisitor; + } // namespace hgps::core diff --git a/src/HealthGPS.Core/interval.h b/src/HealthGPS.Core/interval.h index ef749b806..680ea315e 100644 --- a/src/HealthGPS.Core/interval.h +++ b/src/HealthGPS.Core/interval.h @@ -1,6 +1,10 @@ #pragma once + +#include "HealthGPS.Core/exception.h" #include "forward_type.h" #include "string_util.h" + +#include #include namespace hgps::core { @@ -16,7 +20,11 @@ template class Interval { /// @param lower_value Lower bound value /// @param upper_value Upper bound value explicit Interval(TYPE lower_value, TYPE upper_value) - : lower_{lower_value}, upper_{upper_value} {} + : lower_{lower_value}, upper_{upper_value} { + if (lower_ > upper_) { + throw HgpsException(fmt::format("Invalid interval: {}-{}", lower_, upper_)); + } + } /// @brief Gets the interval lower bound /// @return The lower bound value @@ -33,13 +41,7 @@ template class Interval { /// @brief Determines whether a value is in the Interval. /// @param value The value to check /// @return true if the value is in the interval; otherwise, false. - bool contains(TYPE value) const noexcept { - if (lower_ < upper_) { - return lower_ <= value && value <= upper_; - } - - return lower_ >= value && value >= upper_; - } + bool contains(TYPE value) const noexcept { return lower_ <= value && value <= upper_; } /// @brief Determines whether an Interval is inside this instance interval. /// @param other The other Interval to check @@ -48,6 +50,11 @@ template class Interval { return contains(other.lower_) && contains(other.upper_); } + /// @brief Clamp a given value to the interval boundaries + /// @param value The value to clamp + /// @return The clamped value + TYPE clamp(TYPE value) const noexcept { return std::clamp(value, lower_, upper_); } + /// @brief Convert this instance to a string representation /// @return The equivalent string representation std::string to_string() const noexcept { return fmt::format("{}-{}", lower_, upper_); } diff --git a/src/HealthGPS.Tests/AgeGenderTable.Test.cpp b/src/HealthGPS.Tests/AgeGenderTable.Test.cpp index a5114b9e7..da739652d 100644 --- a/src/HealthGPS.Tests/AgeGenderTable.Test.cpp +++ b/src/HealthGPS.Tests/AgeGenderTable.Test.cpp @@ -159,7 +159,7 @@ TEST(TestHealthGPS_AgeGenderTable, CreateWithWrongRangerThrows) { using namespace hgps; auto negative_range = core::IntegerInterval(-1, 10); - auto inverted_range = core::IntegerInterval(10, 1); + auto inverted_range = core::IntegerInterval(1, 1); ASSERT_THROW(create_age_gender_table(negative_range), std::invalid_argument); ASSERT_THROW(create_age_gender_table(inverted_range), std::invalid_argument); diff --git a/src/HealthGPS.Tests/Interval.Test.cpp b/src/HealthGPS.Tests/Interval.Test.cpp index 7f25d4488..72e7318c1 100644 --- a/src/HealthGPS.Tests/Interval.Test.cpp +++ b/src/HealthGPS.Tests/Interval.Test.cpp @@ -30,24 +30,6 @@ TEST(TestCore_Interval, CreatePositive) { ASSERT_TRUE(animal.contains(dog)); } -TEST(TestCore_Interval, CreateNegative) { - using namespace hgps::core; - auto lower = 0; - auto upper = -10; - auto len = upper - lower; - auto mid = len / 2; - auto animal = IntegerInterval{lower, upper}; - auto cat = IntegerInterval{mid, upper}; - auto dog = IntegerInterval{lower, mid}; - - ASSERT_EQ(lower, animal.lower()); - ASSERT_EQ(upper, animal.upper()); - ASSERT_EQ(len, animal.length()); - ASSERT_TRUE(animal.contains(mid)); - ASSERT_TRUE(animal.contains(cat)); - ASSERT_TRUE(animal.contains(dog)); -} - TEST(TestCore_Interval, Comparable) { using namespace hgps::core; diff --git a/src/HealthGPS/dynamic_hierarchical_linear_model.cpp b/src/HealthGPS/dynamic_hierarchical_linear_model.cpp index f30d0f097..b49ac10d5 100644 --- a/src/HealthGPS/dynamic_hierarchical_linear_model.cpp +++ b/src/HealthGPS/dynamic_hierarchical_linear_model.cpp @@ -26,10 +26,7 @@ RiskFactorModelType DynamicHierarchicalLinearModel::type() const noexcept { std::string DynamicHierarchicalLinearModel::name() const noexcept { return "Dynamic"; } void DynamicHierarchicalLinearModel::generate_risk_factors( - [[maybe_unused]] RuntimeContext &context) { - throw core::HgpsException( - "DynamicHierarchicalLinearModel::generate_risk_factors not yet implemented."); -} + [[maybe_unused]] RuntimeContext &context) {} void DynamicHierarchicalLinearModel::update_risk_factors(RuntimeContext &context) { auto age_key = core::Identifier{"age"}; diff --git a/src/HealthGPS/dynamic_hierarchical_linear_model.h b/src/HealthGPS/dynamic_hierarchical_linear_model.h index c8d7faa94..7871504c0 100644 --- a/src/HealthGPS/dynamic_hierarchical_linear_model.h +++ b/src/HealthGPS/dynamic_hierarchical_linear_model.h @@ -48,7 +48,6 @@ class DynamicHierarchicalLinearModel final : public RiskFactorModel { std::string name() const noexcept override; - /// @throws std::logic_error the dynamic model does not generate risk factors. void generate_risk_factors(RuntimeContext &context) override; void update_risk_factors(RuntimeContext &context) override; diff --git a/src/HealthGPS/gender_table.h b/src/HealthGPS/gender_table.h index f8d9b4a57..f2046ed87 100644 --- a/src/HealthGPS/gender_table.h +++ b/src/HealthGPS/gender_table.h @@ -157,10 +157,10 @@ GenderTable create_integer_gender_table(const core::IntegerInterval & /// @tparam TYPE The values data type /// @param age_range The age breakpoints range /// @return A new instance of the AgeGenderTable class -/// @throws std::out_of_range for age range 'lower' of negative value or less than the 'upper' value +/// @throws std::out_of_range for age range 'lower' of negative value or equal to the 'upper' value template AgeGenderTable create_age_gender_table(const core::IntegerInterval &age_range) { - if (age_range.lower() < 0 || age_range.lower() >= age_range.upper()) { + if (age_range.lower() < 0 || age_range.lower() == age_range.upper()) { throw std::invalid_argument( "The 'age lower' value must be greater than zero and less than the 'age upper' value."); } diff --git a/src/HealthGPS/kevin_hall_model.cpp b/src/HealthGPS/kevin_hall_model.cpp index b36bea1b9..8635ce215 100644 --- a/src/HealthGPS/kevin_hall_model.cpp +++ b/src/HealthGPS/kevin_hall_model.cpp @@ -3,7 +3,6 @@ #include "HealthGPS.Core/exception.h" -#include #include /* @@ -29,13 +28,17 @@ const core::Identifier CI_key{"Carbohydrate"}; KevinHallModel::KevinHallModel( const std::unordered_map &energy_equation, - const std::unordered_map> &nutrient_ranges, + const std::unordered_map &nutrient_ranges, const std::unordered_map> &nutrient_equations, const std::unordered_map> &food_prices, + const std::unordered_map> + &rural_prevalence, + const std::vector &income_models, const std::unordered_map> &age_mean_height) : energy_equation_{energy_equation}, nutrient_ranges_{nutrient_ranges}, nutrient_equations_{nutrient_equations}, food_prices_{food_prices}, + rural_prevalence_{rural_prevalence}, income_models_{income_models}, age_mean_height_{age_mean_height} { if (energy_equation_.empty()) { @@ -50,6 +53,12 @@ KevinHallModel::KevinHallModel( if (food_prices_.empty()) { throw core::HgpsException("Food price mapping is empty"); } + if (rural_prevalence_.empty()) { + throw core::HgpsException("Rural prevalence mapping is empty"); + } + if (income_models_.empty()) { + throw core::HgpsException("Income models list is empty"); + } if (age_mean_height_.empty()) { throw core::HgpsException("Age mean height mapping is empty"); } @@ -59,20 +68,35 @@ RiskFactorModelType KevinHallModel::type() const noexcept { return RiskFactorMod std::string KevinHallModel::name() const noexcept { return "Dynamic"; } -void KevinHallModel::generate_risk_factors([[maybe_unused]] RuntimeContext &context) { - throw core::HgpsException("KevinHallModel::generate_risk_factors not yet implemented."); +void KevinHallModel::generate_risk_factors(RuntimeContext &context) { + + // Initialise everyone. + for (auto &person : context.population()) { + initialise_sector(context, person); + initialise_income(context, person); + } } void KevinHallModel::update_risk_factors(RuntimeContext &context) { - hgps::Population &population = context.population(); - double mean_sim_body_weight = 0.0; - double mean_adjustment_coefficient = 0.0; + + // Initialise newborns and update others. + for (auto &person : context.population()) { + if (person.age == 0) { + initialise_sector(context, person); + initialise_income(context, person); + } else { + update_sector(context, person); + update_income(context, person); + } + } // TODO: Compute target body weight. const float target_BW = 100.0; // Trial run. - for (auto &person : population) { + double mean_sim_body_weight = 0.0; + double mean_adjustment_coefficient = 0.0; + for (auto &person : context.population()) { // Ignore if inactive. if (!person.is_active()) { continue; @@ -85,13 +109,13 @@ void KevinHallModel::update_risk_factors(RuntimeContext &context) { } // Compute model adjustment term. - const size_t population_size = population.current_active_size(); + const size_t population_size = context.population().current_active_size(); mean_sim_body_weight /= population_size; mean_adjustment_coefficient /= population_size; double shift = (target_BW - mean_sim_body_weight) / mean_adjustment_coefficient; // Final run. - for (auto &person : population) { + for (auto &person : context.population()) { // Ignore if inactive. if (!person.is_active()) { continue; @@ -114,6 +138,90 @@ void KevinHallModel::update_risk_factors(RuntimeContext &context) { } } +void KevinHallModel::initialise_sector(RuntimeContext &context, Person &person) const { + + // Get rural prevalence for age group and sex. + double prevalence; + if (person.age < 18) { + prevalence = rural_prevalence_.at("Under18"_id).at(person.gender); + } else { + prevalence = rural_prevalence_.at("Over18"_id).at(person.gender); + } + + // Sample the person's sector. + double rand = context.random().next_double(); + auto sector = rand < prevalence ? core::Sector::rural : core::Sector::urban; + person.sector = sector; +} + +void KevinHallModel::update_sector(RuntimeContext &context, Person &person) const { + + // Only update rural sector 18 year olds. + if ((person.age != 18) || (person.sector != core::Sector::rural)) { + return; + } + + // Get rural prevalence for age group and sex. + double prevalence_under18 = rural_prevalence_.at("Under18"_id).at(person.gender); + double prevalence_over18 = rural_prevalence_.at("Over18"_id).at(person.gender); + + // Compute random rural to urban transition. + double rand = context.random().next_double(); + double p_rural_to_urban = 1.0 - prevalence_over18 / prevalence_under18; + if (rand < p_rural_to_urban) { + person.sector = core::Sector::urban; + } +} + +void KevinHallModel::initialise_income(RuntimeContext &context, Person &person) const { + + // Compute logits for each income category. + auto logits = std::vector{}; + logits.reserve(income_models_.size()); + for (const auto &income_model : income_models_) { + logits.push_back(income_model.intercept); + for (const auto &[factor_name, coefficient] : income_model.coefficients) { + logits.back() += coefficient * person.get_risk_factor_value(factor_name); + } + } + + // Compute softmax probabilities for each income category. + auto e_logits = std::vector{}; + e_logits.reserve(income_models_.size()); + double e_logits_sum = 0.0; + for (const auto &logit : logits) { + e_logits.push_back(exp(logit)); + e_logits_sum += e_logits.back(); + } + + // Compute income category probabilities. + auto probabilities = std::vector{}; + probabilities.reserve(income_models_.size()); + for (const auto &e_logit : e_logits) { + probabilities.push_back(e_logit / e_logits_sum); + } + + // Compute income category. + double rand = context.random().next_double(); + for (size_t i = 0; i < income_models_.size(); i++) { + if (rand < probabilities[i]) { + person.income = income_models_[i].name; + return; + } + rand -= probabilities[i]; + } + + throw core::HgpsException("Logic Error: failed to initialise income category"); +} + +void KevinHallModel::update_income(RuntimeContext &context, Person &person) const { + + // Only update 18 year olds. + if (person.age == 18) { + initialise_income(context, person); + } +} + SimulatePersonState KevinHallModel::simulate_person(Person &person, double shift) const { // Initial simulated person state. const double H_0 = person.get_risk_factor_value(H_key); @@ -269,20 +377,18 @@ double KevinHallModel::compute_AT(double EI, double EI_0) const { return beta_AT * delta_EI; } -double KevinHallModel::bounded_nutrient_value(const core::Identifier &nutrient, - double value) const { - const auto &range = nutrient_ranges_.at(nutrient); - return std::clamp(range.first, range.second, value); -} - KevinHallModelDefinition::KevinHallModelDefinition( std::unordered_map energy_equation, - std::unordered_map> nutrient_ranges, + std::unordered_map nutrient_ranges, std::unordered_map> nutrient_equations, std::unordered_map> food_prices, + std::unordered_map> + rural_prevalence, + std::vector income_models, std::unordered_map> age_mean_height) : energy_equation_{std::move(energy_equation)}, nutrient_ranges_{std::move(nutrient_ranges)}, nutrient_equations_{std::move(nutrient_equations)}, food_prices_{std::move(food_prices)}, + rural_prevalence_{std::move(rural_prevalence)}, income_models_{std::move(income_models)}, age_mean_height_{std::move(age_mean_height)} { if (energy_equation_.empty()) { @@ -297,6 +403,12 @@ KevinHallModelDefinition::KevinHallModelDefinition( if (food_prices_.empty()) { throw core::HgpsException("Food prices mapping is empty"); } + if (rural_prevalence_.empty()) { + throw core::HgpsException("Rural prevalence mapping is empty"); + } + if (income_models_.empty()) { + throw core::HgpsException("Income models list is empty"); + } if (age_mean_height_.empty()) { throw core::HgpsException("Age mean height mapping is empty"); } @@ -304,7 +416,8 @@ KevinHallModelDefinition::KevinHallModelDefinition( std::unique_ptr KevinHallModelDefinition::create_model() const { return std::make_unique(energy_equation_, nutrient_ranges_, nutrient_equations_, - food_prices_, age_mean_height_); + food_prices_, rural_prevalence_, income_models_, + age_mean_height_); } } // namespace hgps diff --git a/src/HealthGPS/kevin_hall_model.h b/src/HealthGPS/kevin_hall_model.h index 816f12e53..182ea9ad4 100644 --- a/src/HealthGPS/kevin_hall_model.h +++ b/src/HealthGPS/kevin_hall_model.h @@ -1,5 +1,8 @@ #pragma once +// TODO: LinearModelParams (in static_linear_model.h) should be moved somewhere better. +#include "static_linear_model.h" + #include "interfaces.h" #include "mapping.h" @@ -50,33 +53,40 @@ class KevinHallModel final : public RiskFactorModel { public: /// @brief Initialises a new instance of the KevinHallModel class /// @param energy_equation The energy coefficients for each nutrient - /// @param nutrient_ranges The minimum and maximum nutrient values + /// @param nutrient_ranges The interval boundaries for nutrient values /// @param nutrient_equations The nutrient coefficients for each food group /// @param food_prices The unit price for each food group + /// @param rural_prevalence Rural sector prevalence for age groups and sex + /// @param income_models The income models for each income category /// @param age_mean_height The mean height at all ages (male and female) KevinHallModel( const std::unordered_map &energy_equation, - const std::unordered_map> &nutrient_ranges, + const std::unordered_map &nutrient_ranges, const std::unordered_map> &nutrient_equations, const std::unordered_map> &food_prices, + const std::unordered_map> &rural_prevalence, + const std::vector &income_models, const std::unordered_map> &age_mean_height); RiskFactorModelType type() const noexcept override; std::string name() const noexcept override; - /// @throws std::logic_error the dynamic model does not generate risk factors. void generate_risk_factors(RuntimeContext &context) override; void update_risk_factors(RuntimeContext &context) override; private: const std::unordered_map &energy_equation_; - const std::unordered_map> &nutrient_ranges_; + const std::unordered_map &nutrient_ranges_; const std::unordered_map> &nutrient_equations_; const std::unordered_map> &food_prices_; + const std::unordered_map> + &rural_prevalence_; + const std::vector &income_models_; const std::unordered_map> &age_mean_height_; // Model parameters. @@ -91,6 +101,26 @@ class KevinHallModel final : public RiskFactorModel { static constexpr double xi_Na = 3000.0; // Na from ECF changes (mg/L/day). static constexpr double xi_CI = 4000.0; // Na from carbohydrate changes (mg/day). + /// @brief Initialise the sector of a person + /// @param context The runtime context + /// @param person The person to initialise sector for + void initialise_sector(RuntimeContext &context, Person &person) const; + + /// @brief Update the sector of a person + /// @param context The runtime context + /// @param person The person to update sector for + void update_sector(RuntimeContext &context, Person &person) const; + + /// @brief Initialise the income category of a person + /// @param context The runtime context + /// @param person The person to initialise sector for + void initialise_income(RuntimeContext &context, Person &person) const; + + /// @brief Update the income category of a person + /// @param context The runtime context + /// @param person The person to update sector for + void update_income(RuntimeContext &context, Person &person) const; + /// @brief Simulates the energy balance model for a given person /// @param person The person to simulate /// @param shift Model adjustment term @@ -148,12 +178,6 @@ class KevinHallModel final : public RiskFactorModel { /// @param EI_0 The initial energy intake /// @return The computed adaptive thermogenesis double compute_AT(double EI, double EI_0) const; - - /// @brief Return the nutrient value bounded within its range - /// @param nutrient The nutrient Identifier - /// @param value The nutrient value to bound - /// @return The bounded nutrient value - double bounded_nutrient_value(const core::Identifier &nutrient, double value) const; }; /// @brief Defines the energy balance model data type @@ -161,16 +185,21 @@ class KevinHallModelDefinition final : public RiskFactorModelDefinition { public: /// @brief Initialises a new instance of the KevinHallModelDefinition class /// @param energy_equation The energy coefficients for each nutrient - /// @param nutrient_ranges The minimum and maximum nutrient values + /// @param nutrient_ranges The interval boundaries for nutrient values /// @param nutrient_equations The nutrient coefficients for each food group /// @param food_prices The unit price for each food group + /// @param rural_prevalence Rural sector prevalence for age groups and sex + /// @param income_models The income models for each income category /// @param age_mean_height The mean height at all ages (male and female) /// @throws std::invalid_argument for empty arguments KevinHallModelDefinition( std::unordered_map energy_equation, - std::unordered_map> nutrient_ranges, + std::unordered_map nutrient_ranges, std::unordered_map> nutrient_equations, std::unordered_map> food_prices, + std::unordered_map> + rural_prevalence, + std::vector income_models, std::unordered_map> age_mean_height); /// @brief Construct a new KevinHallModel from this definition @@ -179,9 +208,12 @@ class KevinHallModelDefinition final : public RiskFactorModelDefinition { private: std::unordered_map energy_equation_; - std::unordered_map> nutrient_ranges_; + std::unordered_map nutrient_ranges_; std::unordered_map> nutrient_equations_; std::unordered_map> food_prices_; + std::unordered_map> + rural_prevalence_; + std::vector income_models_; std::unordered_map> age_mean_height_; }; diff --git a/src/HealthGPS/person.cpp b/src/HealthGPS/person.cpp index 4a0690195..ee511357c 100644 --- a/src/HealthGPS/person.cpp +++ b/src/HealthGPS/person.cpp @@ -1,5 +1,7 @@ #include "person.h" +#include "HealthGPS.Core/exception.h" + namespace hgps { std::atomic Person::newUID{0}; @@ -10,6 +12,8 @@ std::map> Person::curren {"Age"_id, [](const Person &p) { return static_cast(p.age); }}, {"Age2"_id, [](const Person &p) { return pow(p.age, 2); }}, {"Age3"_id, [](const Person &p) { return pow(p.age, 3); }}, + {"Over18"_id, [](const Person &p) { return static_cast(p.over_18()); }}, + {"Sector"_id, [](const Person &p) { return p.sector_to_value(); }}, {"SES"_id, [](const Person &p) { return p.ses; }}, // HACK: ew, gross... allows us to mock risk factors we don't have data for yet @@ -56,14 +60,29 @@ double Person::get_risk_factor_value(const core::Identifier &key) const { throw std::out_of_range("Risk factor not found: " + key.to_string()); } -float Person::gender_to_value() const noexcept { +float Person::gender_to_value() const { + if (gender == core::Gender::unknown) { + throw core::HgpsException("Gender is unknown."); + } return gender == core::Gender::male ? 1.0f : 0.0f; } -std::string Person::gender_to_string() const noexcept { +std::string Person::gender_to_string() const { + if (gender == core::Gender::unknown) { + throw core::HgpsException("Gender is unknown."); + } return gender == core::Gender::male ? "male" : "female"; } +float Person::sector_to_value() const { + if (sector == core::Sector::unknown) { + throw core::HgpsException("Sector is unknown."); + } + return sector == core::Sector::urban ? 0.0f : 1.0f; +} + +bool Person::over_18() const noexcept { return age >= 18; } + void Person::emigrate(const unsigned int time) { if (!is_active()) { throw std::logic_error("Entity must be active prior to emigrate."); diff --git a/src/HealthGPS/person.h b/src/HealthGPS/person.h index 6c176cc8d..d7d89db96 100644 --- a/src/HealthGPS/person.h +++ b/src/HealthGPS/person.h @@ -56,9 +56,15 @@ struct Person { /// @brief Current age in years unsigned int age{}; + /// @brief Sector (region) assigned value + core::Sector sector{core::Sector::unknown}; + /// @brief Social-economic status (SES) assigned value double ses{}; + /// @brief Income category + core::Identifier income{}; + /// @brief Current risk factors values std::map risk_factors; @@ -94,11 +100,22 @@ struct Person { /// @brief Gets the gender enumeration as a number for analysis /// @return The gender associated value - float gender_to_value() const noexcept; + /// @throws HgpsException if gender is unknown + float gender_to_value() const; /// @brief Gets the gender enumeration name string /// @return The gender name - std::string gender_to_string() const noexcept; + /// @throws HgpsException if gender is unknown + std::string gender_to_string() const; + + /// @brief Gets the sector enumeration as a number + /// @return The sector value (0 for urban, 1 for rural) + /// @throws HgpsException if sector is unknown + float sector_to_value() const; + + /// @brief Check if person is an adult (18 or over) + /// @return true if person is 18 or over; else false + bool over_18() const noexcept; /// @brief Emigrate this instance from the virtual population /// @param time Migration time diff --git a/src/HealthGPS/riskfactor.cpp b/src/HealthGPS/riskfactor.cpp index 84b020f3d..3871067af 100644 --- a/src/HealthGPS/riskfactor.cpp +++ b/src/HealthGPS/riskfactor.cpp @@ -45,6 +45,9 @@ RiskFactorModel &RiskFactorModule::at(const RiskFactorModelType &model_type) con void RiskFactorModule::initialise_population(RuntimeContext &context) { auto &static_model = models_.at(RiskFactorModelType::Static); static_model->generate_risk_factors(context); + + auto &dynamic_model = models_.at(RiskFactorModelType::Dynamic); + dynamic_model->generate_risk_factors(context); } void RiskFactorModule::update_population(RuntimeContext &context) { diff --git a/src/HealthGPS/static_linear_model.cpp b/src/HealthGPS/static_linear_model.cpp index 51563481a..e6b7ff031 100644 --- a/src/HealthGPS/static_linear_model.cpp +++ b/src/HealthGPS/static_linear_model.cpp @@ -6,18 +6,13 @@ namespace hgps { -StaticLinearModel::StaticLinearModel(std::vector risk_factor_names, - LinearModelParams risk_factor_models, +StaticLinearModel::StaticLinearModel(std::vector risk_factor_models, Eigen::MatrixXd risk_factor_cholesky) - : risk_factor_names_{std::move(risk_factor_names)}, - risk_factor_models_{std::move(risk_factor_models)}, + : risk_factor_models_{std::move(risk_factor_models)}, risk_factor_cholesky_{std::move(risk_factor_cholesky)} { - if (risk_factor_names_.empty()) { - throw core::HgpsException("Risk factor names list is empty"); - } - if (risk_factor_models_.intercepts.empty() || risk_factor_models_.coefficients.empty()) { - throw core::HgpsException("Risk factor models mapping is incomplete"); + if (risk_factor_models_.empty()) { + throw core::HgpsException("Risk factor model list is empty"); } if (!risk_factor_cholesky_.allFinite()) { throw core::HgpsException("Risk factor Cholesky matrix contains non-finite values"); @@ -64,7 +59,7 @@ void StaticLinearModel::update_risk_factors(RuntimeContext &context) { Eigen::VectorXd StaticLinearModel::correlated_samples(RuntimeContext &context) { // Correlated samples using Cholesky decomposition. - Eigen::VectorXd samples{risk_factor_names_.size()}; + Eigen::VectorXd samples{risk_factor_models_.size()}; std::ranges::generate(samples, [&context] { return context.random().next_normal(0.0, 1.0); }); samples = risk_factor_cholesky_ * samples; @@ -80,30 +75,22 @@ Eigen::VectorXd StaticLinearModel::correlated_samples(RuntimeContext &context) { void StaticLinearModel::linear_approximation(Person &person) { // Approximate risk factor values for person with linear models. - for (const auto &factor_name : risk_factor_names_) { - double factor = risk_factor_models_.intercepts.at(factor_name); - const auto &coefficients = risk_factor_models_.coefficients.at(factor_name); - - for (const auto &[coefficient_name, coefficient_value] : coefficients) { + for (const auto &model : risk_factor_models_) { + double factor = model.intercept; + for (const auto &[coefficient_name, coefficient_value] : model.coefficients) { factor += coefficient_value * person.get_risk_factor_value(coefficient_name); } - - person.risk_factors[factor_name] = factor; + person.risk_factors[model.name] = factor; } } StaticLinearModelDefinition::StaticLinearModelDefinition( - std::vector risk_factor_names, LinearModelParams risk_factor_models, - Eigen::MatrixXd risk_factor_cholesky) - : risk_factor_names_{std::move(risk_factor_names)}, - risk_factor_models_{std::move(risk_factor_models)}, + std::vector risk_factor_models, Eigen::MatrixXd risk_factor_cholesky) + : risk_factor_models_{std::move(risk_factor_models)}, risk_factor_cholesky_{std::move(risk_factor_cholesky)} { - if (risk_factor_names_.empty()) { - throw core::HgpsException("Risk factor names list is empty"); - } - if (risk_factor_models_.intercepts.empty() || risk_factor_models_.coefficients.empty()) { - throw core::HgpsException("Risk factor models mapping is incomplete"); + if (risk_factor_models_.empty()) { + throw core::HgpsException("Risk factor model list is empty"); } if (!risk_factor_cholesky_.allFinite()) { throw core::HgpsException("Risk factor Cholesky matrix contains non-finite values"); @@ -111,8 +98,7 @@ StaticLinearModelDefinition::StaticLinearModelDefinition( } std::unique_ptr StaticLinearModelDefinition::create_model() const { - return std::make_unique(risk_factor_names_, risk_factor_models_, - risk_factor_cholesky_); + return std::make_unique(risk_factor_models_, risk_factor_cholesky_); } } // namespace hgps diff --git a/src/HealthGPS/static_linear_model.h b/src/HealthGPS/static_linear_model.h index 63dc4667f..3928dba86 100644 --- a/src/HealthGPS/static_linear_model.h +++ b/src/HealthGPS/static_linear_model.h @@ -9,8 +9,9 @@ namespace hgps { /// @brief Defines the linear model parameters used to initialise risk factors struct LinearModelParams { - std::unordered_map intercepts; - std::unordered_map> coefficients; + core::Identifier name; + double intercept; + std::unordered_map coefficients; }; /// @brief Implements the static linear model type @@ -19,12 +20,11 @@ struct LinearModelParams { class StaticLinearModel final : public RiskFactorModel { public: /// @brief Initialises a new instance of the StaticLinearModel class - /// @param risk_factor_names An ordered list of risk factor names /// @param risk_factor_models The linear models used to initialise a person's risk factor values /// @param risk_factor_cholesky The Cholesky decomposition of the risk factor correlation matrix /// @throws HgpsException for invalid arguments - StaticLinearModel(std::vector risk_factor_names, - LinearModelParams risk_factor_models, Eigen::MatrixXd risk_factor_cholesky); + StaticLinearModel(std::vector risk_factor_models, + Eigen::MatrixXd risk_factor_cholesky); RiskFactorModelType type() const noexcept override; @@ -39,8 +39,7 @@ class StaticLinearModel final : public RiskFactorModel { void linear_approximation(Person &person); private: - const std::vector risk_factor_names_; - const LinearModelParams risk_factor_models_; + const std::vector risk_factor_models_; const Eigen::MatrixXd risk_factor_cholesky_; }; @@ -48,12 +47,10 @@ class StaticLinearModel final : public RiskFactorModel { class StaticLinearModelDefinition final : public RiskFactorModelDefinition { public: /// @brief Initialises a new instance of the StaticLinearModelDefinition class - /// @param risk_factor_names An ordered list of risk factor names /// @param risk_factor_models The linear models used to initialise a person's risk factor values /// @param risk_factor_cholesky The Cholesky decomposition of the risk factor correlation matrix /// @throws HgpsException for invalid arguments - StaticLinearModelDefinition(std::vector risk_factor_names, - LinearModelParams risk_factor_models, + StaticLinearModelDefinition(std::vector risk_factor_models, Eigen::MatrixXd risk_factor_cholesky); /// @brief Construct a new StaticLinearModel from this definition @@ -61,8 +58,7 @@ class StaticLinearModelDefinition final : public RiskFactorModelDefinition { std::unique_ptr create_model() const override; private: - std::vector risk_factor_names_; - LinearModelParams risk_factor_models_; + std::vector risk_factor_models_; Eigen::MatrixXd risk_factor_cholesky_; };