Skip to content

Commit

Permalink
Merge pull request #230 from imperialCHEPI/sector_and_income
Browse files Browse the repository at this point in the history
Sector and income
  • Loading branch information
jamesturner246 authored Oct 27, 2023
2 parents a59f4cd + 1ebaa1a commit 1fda35b
Show file tree
Hide file tree
Showing 16 changed files with 375 additions and 138 deletions.
37 changes: 37 additions & 0 deletions example_new/KevinHall.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,43 @@
"Description": "string"
}
},
"RuralPrevalence": [
{
"Name": "Under18",
"Female": 0.65,
"Male": 0.65
},
{
"Name": "Over18",
"Female": 0.6,
"Male": 0.6
}
],
"IncomeModels": [
{
"Name": "Category_1",
"Intercept": 1.0,
"Coefficients": {}
},
{
"Name": "Category_2",
"Intercept": 0.75498278113169,
"Coefficients": {
"Gender": 0.0204338261883629,
"Over18": 0.0486699373325097,
"Sector": -0.706477682886734
}
},
{
"Name": "Category_3",
"Intercept": 1.48856946873517,
"Coefficients": {
"Gender": 0.000641255498596025,
"Over18": 0.206622749570132,
"Sector": -1.71313287940798
}
}
],
"AgeMeanHeight": {
"Male": [
0.498, 0.757, 0.868, 0.952, 1.023, 1.092, 1.155, 1.219, 1.28, 1.333,
Expand Down
94 changes: 65 additions & 29 deletions src/HealthGPS.Console/model_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,31 +129,50 @@ std::unique_ptr<hgps::StaticLinearModelDefinition>
load_staticlinear_risk_model_definition(const poco::json &opt, const host::Configuration &config) {
MEASURE_FUNCTION();

// Risk factor linear models.
hgps::LinearModelParams models;
for (const auto &factor : opt["RiskFactorModels"]) {
auto factor_key = factor["Name"].get<hgps::core::Identifier>();
models.intercepts[factor_key] = factor["Intercept"].get<double>();
models.coefficients[factor_key] =
factor["Coefficients"].get<std::unordered_map<hgps::core::Identifier, double>>();
}

// Risk factor names and correlation matrix.
std::vector<hgps::core::Identifier> names;
// Risk factor linear models and correlation matrix.
std::vector<hgps::LinearModelParams> risk_factor_models;
const auto correlations_file_info =
host::get_file_info(opt["RiskFactorCorrelationFile"], config.root_path);
const auto correlations_table = load_datatable_from_csv(correlations_file_info);
Eigen::MatrixXd correlations{correlations_table.num_rows(), correlations_table.num_columns()};
for (size_t col = 0; col < correlations_table.num_columns(); col++) {
names.emplace_back(correlations_table.column(col).name());
for (size_t row = 0; row < correlations_table.num_rows(); row++) {
correlations(row, col) =
std::any_cast<double>(correlations_table.column(col).value(row));

for (size_t i = 0; i < opt["RiskFactorModels"].size(); i++) {
// Risk factor model.
const auto &factor = opt["RiskFactorModels"][i];
hgps::LinearModelParams model;
model.name = factor["Name"].get<hgps::core::Identifier>();
model.intercept = factor["Intercept"].get<double>();
model.coefficients =
factor["Coefficients"].get<std::unordered_map<hgps::core::Identifier, double>>();

// Check correlation matrix column name matches risk factor name.
auto column_name = hgps::core::Identifier{correlations_table.column(i).name()};
if (model.name != column_name) {
throw hgps::core::HgpsException{
fmt::format("Risk factor {} name ({}) does not match correlation matrix "
"column {} name ({})",
i, model.name.to_string(), i, column_name.to_string())};
}

// Write data structures.
for (size_t j = 0; j < correlations_table.num_rows(); j++) {
correlations(i, j) = std::any_cast<double>(correlations_table.column(i).value(j));
}
risk_factor_models.emplace_back(std::move(model));
}

// Check correlation matrix column count matches risk factor count.
if (opt["RiskFactorModels"].size() != correlations_table.num_columns()) {
throw hgps::core::HgpsException{
fmt::format("Risk factor count ({}) does not match correlation "
"matrix column count ({})",
opt["RiskFactorModels"].size(), correlations_table.num_columns())};
}

// Compute Cholesky decomposition of correlation matrix.
auto cholesky = Eigen::MatrixXd{Eigen::LLT<Eigen::MatrixXd>{correlations}.matrixL()};

return std::make_unique<hgps::StaticLinearModelDefinition>(std::move(names), std::move(models),
return std::make_unique<hgps::StaticLinearModelDefinition>(std::move(risk_factor_models),
std::move(cholesky));
}

Expand Down Expand Up @@ -253,25 +272,20 @@ load_ebhlm_risk_model_definition(const poco::json &opt) {
std::unique_ptr<hgps::KevinHallModelDefinition>
load_kevinhall_risk_model_definition(const poco::json &opt, const host::Configuration &config) {
MEASURE_FUNCTION();
std::unordered_map<hgps::core::Identifier, double> energy_equation;
std::unordered_map<hgps::core::Identifier, std::pair<double, double>> nutrient_ranges;
std::unordered_map<hgps::core::Identifier, std::map<hgps::core::Identifier, double>>
nutrient_equations;
std::unordered_map<hgps::core::Identifier, std::optional<double>> food_prices;
std::unordered_map<hgps::core::Gender, std::vector<double>> age_mean_height;

// Nutrient groups.
std::unordered_map<hgps::core::Identifier, double> energy_equation;
std::unordered_map<hgps::core::Identifier, hgps::core::DoubleInterval> nutrient_ranges;
for (const auto &nutrient : opt["Nutrients"]) {
auto nutrient_key = nutrient["Name"].get<hgps::core::Identifier>();
nutrient_ranges[nutrient_key] = nutrient["Range"].get<std::pair<double, double>>();
if (nutrient_ranges[nutrient_key].first > nutrient_ranges[nutrient_key].second) {
throw hgps::core::HgpsException{
fmt::format("Nutrient range is invalid: {}", nutrient_key.to_string())};
}
nutrient_ranges[nutrient_key] = nutrient["Range"].get<hgps::core::DoubleInterval>();
energy_equation[nutrient_key] = nutrient["Energy"].get<double>();
}

// Food groups.
std::unordered_map<hgps::core::Identifier, std::map<hgps::core::Identifier, double>>
nutrient_equations;
std::unordered_map<hgps::core::Identifier, std::optional<double>> food_prices;
for (const auto &food : opt["Foods"]) {
auto food_key = food["Name"].get<hgps::core::Identifier>();
food_prices[food_key] = food["Price"].get<std::optional<double>>();
Expand All @@ -291,7 +305,28 @@ load_kevinhall_risk_model_definition(const poco::json &opt, const host::Configur
const auto food_data_file_info = host::get_file_info(opt["FoodsDataFile"], config.root_path);
const auto food_data_table = load_datatable_from_csv(food_data_file_info);

// Rural sector prevalence, keyed by age group name and then by sex.
std::unordered_map<hgps::core::Identifier, std::unordered_map<hgps::core::Gender, double>>
    rural_prevalence;
for (const auto &group_json : opt["RuralPrevalence"]) {
    const auto group_key = group_json["Name"].get<hgps::core::Identifier>();

    // One value per sex, read from the "Female" and "Male" entries.
    std::unordered_map<hgps::core::Gender, double> by_gender;
    by_gender[hgps::core::Gender::female] = group_json["Female"].get<double>();
    by_gender[hgps::core::Gender::male] = group_json["Male"].get<double>();

    rural_prevalence[group_key] = std::move(by_gender);
}

// Income models: one set of linear model parameters (name, intercept and
// coefficients) per entry of the "IncomeModels" array, kept in file order.
std::vector<hgps::LinearModelParams> income_models;
for (const auto &income_json : opt["IncomeModels"]) {
    using CoefficientMap = std::unordered_map<hgps::core::Identifier, double>;

    hgps::LinearModelParams params;
    params.name = income_json["Name"].get<hgps::core::Identifier>();
    params.intercept = income_json["Intercept"].get<double>();
    params.coefficients = income_json["Coefficients"].get<CoefficientMap>();
    income_models.push_back(std::move(params));
}

// Load M/F average heights for age.
std::unordered_map<hgps::core::Gender, std::vector<double>> age_mean_height;
const auto max_age = static_cast<size_t>(config.settings.age_range.upper());
auto male_height = opt["AgeMeanHeight"]["Male"].get<std::vector<double>>();
auto female_height = opt["AgeMeanHeight"]["Female"].get<std::vector<double>>();
Expand All @@ -306,7 +341,8 @@ load_kevinhall_risk_model_definition(const poco::json &opt, const host::Configur

return std::make_unique<hgps::KevinHallModelDefinition>(
std::move(energy_equation), std::move(nutrient_ranges), std::move(nutrient_equations),
std::move(food_prices), std::move(age_mean_height));
std::move(food_prices), std::move(rural_prevalence), std::move(income_models),
std::move(age_mean_height));
}

std::pair<hgps::RiskFactorModelType, std::unique_ptr<hgps::RiskFactorModelDefinition>>
Expand Down
13 changes: 13 additions & 0 deletions src/HealthGPS.Core/forward_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ enum class DiseaseGroup : uint8_t {
cancer
};

/// @brief Enumeration of the sector types, stored with an 8-bit underlying type
enum class Sector : uint8_t {
    /// @brief Unknown sector
    unknown = 0,

    /// @brief Urban sector
    urban = 1,

    /// @brief Rural sector
    rural = 2
};

/// @brief C++20 concept for numeric columns types
template <typename T>
concept Numerical = std::is_arithmetic_v<T>;
Expand All @@ -60,4 +72,5 @@ class DoubleDataTableColumn;
class IntegerDataTableColumn;

class DataTableColumnVisitor;

} // namespace hgps::core
23 changes: 15 additions & 8 deletions src/HealthGPS.Core/interval.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#pragma once

#include "HealthGPS.Core/exception.h"
#include "forward_type.h"
#include "string_util.h"

#include <algorithm>
#include <fmt/format.h>

namespace hgps::core {
Expand All @@ -16,7 +20,11 @@ template <Numerical TYPE> class Interval {
/// @param lower_value Lower bound value
/// @param upper_value Upper bound value
explicit Interval(TYPE lower_value, TYPE upper_value)
: lower_{lower_value}, upper_{upper_value} {}
: lower_{lower_value}, upper_{upper_value} {
if (lower_ > upper_) {
throw HgpsException(fmt::format("Invalid interval: {}-{}", lower_, upper_));
}
}

/// @brief Gets the interval lower bound
/// @return The lower bound value
Expand All @@ -33,13 +41,7 @@ template <Numerical TYPE> class Interval {
/// @brief Determines whether a value is in the Interval.
/// @param value The value to check
/// @return true if the value is in the interval; otherwise, false.
bool contains(TYPE value) const noexcept {
if (lower_ < upper_) {
return lower_ <= value && value <= upper_;
}

return lower_ >= value && value >= upper_;
}
bool contains(TYPE value) const noexcept { return lower_ <= value && value <= upper_; }

/// @brief Determines whether an Interval is inside this instance interval.
/// @param other The other Interval to check
Expand All @@ -48,6 +50,11 @@ template <Numerical TYPE> class Interval {
return contains(other.lower_) && contains(other.upper_);
}

/// @brief Restricts a value to lie within the interval bounds
/// @param value The value to restrict
/// @return The lower bound if the value is below it, the upper bound if the
/// value is above it; otherwise the value itself
TYPE clamp(TYPE value) const noexcept {
    if (value < lower_) {
        return lower_;
    }

    if (upper_ < value) {
        return upper_;
    }

    return value;
}

/// @brief Convert this instance to a string representation
/// @return The equivalent string representation
std::string to_string() const noexcept { return fmt::format("{}-{}", lower_, upper_); }
Expand Down
2 changes: 1 addition & 1 deletion src/HealthGPS.Tests/AgeGenderTable.Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ TEST(TestHealthGPS_AgeGenderTable, CreateWithWrongRangerThrows) {
using namespace hgps;

auto negative_range = core::IntegerInterval(-1, 10);
auto inverted_range = core::IntegerInterval(10, 1);
auto inverted_range = core::IntegerInterval(1, 1);

ASSERT_THROW(create_age_gender_table<double>(negative_range), std::invalid_argument);
ASSERT_THROW(create_age_gender_table<double>(inverted_range), std::invalid_argument);
Expand Down
18 changes: 0 additions & 18 deletions src/HealthGPS.Tests/Interval.Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,6 @@ TEST(TestCore_Interval, CreatePositive) {
ASSERT_TRUE(animal.contains(dog));
}

TEST(TestCore_Interval, CreateNegative) {
using namespace hgps::core;
auto lower = 0;
auto upper = -10;
auto len = upper - lower;
auto mid = len / 2;
auto animal = IntegerInterval{lower, upper};
auto cat = IntegerInterval{mid, upper};
auto dog = IntegerInterval{lower, mid};

ASSERT_EQ(lower, animal.lower());
ASSERT_EQ(upper, animal.upper());
ASSERT_EQ(len, animal.length());
ASSERT_TRUE(animal.contains(mid));
ASSERT_TRUE(animal.contains(cat));
ASSERT_TRUE(animal.contains(dog));
}

TEST(TestCore_Interval, Comparable) {
using namespace hgps::core;

Expand Down
5 changes: 1 addition & 4 deletions src/HealthGPS/dynamic_hierarchical_linear_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ RiskFactorModelType DynamicHierarchicalLinearModel::type() const noexcept {
std::string DynamicHierarchicalLinearModel::name() const noexcept { return "Dynamic"; }

void DynamicHierarchicalLinearModel::generate_risk_factors(
[[maybe_unused]] RuntimeContext &context) {
throw core::HgpsException(
"DynamicHierarchicalLinearModel::generate_risk_factors not yet implemented.");
}
[[maybe_unused]] RuntimeContext &context) {}

void DynamicHierarchicalLinearModel::update_risk_factors(RuntimeContext &context) {
auto age_key = core::Identifier{"age"};
Expand Down
1 change: 0 additions & 1 deletion src/HealthGPS/dynamic_hierarchical_linear_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class DynamicHierarchicalLinearModel final : public RiskFactorModel {

std::string name() const noexcept override;

/// @throws std::logic_error the dynamic model does not generate risk factors.
void generate_risk_factors(RuntimeContext &context) override;

void update_risk_factors(RuntimeContext &context) override;
Expand Down
4 changes: 2 additions & 2 deletions src/HealthGPS/gender_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ GenderTable<int, TYPE> create_integer_gender_table(const core::IntegerInterval &
/// @tparam TYPE The values data type
/// @param age_range The age breakpoints range
/// @return A new instance of the AgeGenderTable class
/// @throws std::out_of_range for age range 'lower' of negative value or less than the 'upper' value
/// @throws std::invalid_argument for age range 'lower' of negative value or equal to the 'upper' value
template <core::Numerical TYPE>
AgeGenderTable<TYPE> create_age_gender_table(const core::IntegerInterval &age_range) {
if (age_range.lower() < 0 || age_range.lower() >= age_range.upper()) {
if (age_range.lower() < 0 || age_range.lower() == age_range.upper()) {
throw std::invalid_argument(
"The 'age lower' value must be greater than zero and less than the 'age upper' value.");
}
Expand Down
Loading

0 comments on commit 1fda35b

Please sign in to comment.