Skip to content

Commit

Permalink
Adds standard deviation calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
TinyMarsh committed Dec 9, 2024
1 parent 667200c commit b8e6bf4
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
93 changes: 93 additions & 0 deletions src/HealthGPS/analysis_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,76 @@ void AnalysisModule::calculate_population_statistics(RuntimeContext &context) co
}
}

void AnalysisModule::calculate_standard_deviation(RuntimeContext &context) const {

// Accumulate squared deviations from mean.
auto accumulate_squared_diffs = [&](const std::string &chan, const auto &person, double value) {
size_t index = calculate_index(person);
const double mean = calculated_stats_[index + channel_index_.at("mean_" + chan)];
const double diff = value - mean;
calculated_stats_[index + channel_index_.at("std_" + chan)] += diff * diff;
};

auto current_time = static_cast<unsigned int>(context.time_now());
for (const auto &person : context.population()) {
unsigned int age = person.age;
core::Gender sex = person.gender;

if (!person.is_active()) {
if (!person.is_alive() && person.time_of_death() == current_time) {
float expcted_life = definition_.life_expectancy().at(context.time_now(), sex);
double yll = std::max(expcted_life - age, 0.0f) * DALY_UNITS;
accumulate_squared_diffs("yll", person, yll);
accumulate_squared_diffs("daly", person, yll);
}

continue;
}

double dw = calculate_disability_weight(person);
double yld = dw * DALY_UNITS;
accumulate_squared_diffs("yld", person, yld);
accumulate_squared_diffs("daly", person, yld);

for (const auto &factor : context.mapping().entries()) {
const double value = person.get_risk_factor_value(factor.key());
accumulate_squared_diffs(factor.key().to_string(), person, value);
}
}

// Calculate in-place standard deviation.
auto divide_by_count_sqrt = [&](const std::string &chan, core::Gender sex, int age,

Check failure on line 463 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-24.04, gcc-latest, false)

unused parameter ‘sex’ [-Werror=unused-parameter]

Check failure on line 463 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-24.04, gcc-latest, false)

unused parameter ‘age’ [-Werror=unused-parameter]

Check failure on line 463 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-24.04, gcc-latest, false)

unused parameter ‘sex’ [-Werror=unused-parameter]

Check failure on line 463 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-24.04, gcc-latest, false)

unused parameter ‘age’ [-Werror=unused-parameter]
double count, std::vector<double> &factor_values) {
const double sum =
calculated_stats_[calculate_index(factor_values) + channel_index_.at("std_" + chan)];
const double std = std::sqrt(sum / count);
calculated_stats_[calculate_index(factor_values) + channel_index_.at("std_" + chan)] = std;
};

// For each age group in the analysis...
const auto age_range = context.age_range();
for (int age = age_range.lower(); age <= age_range.upper(); age++) {
std::vector<double> factor_values_male = {1.0, static_cast<double>(age)};
std::vector<double> factor_values_female = {0.0, static_cast<double>(age)};
double count_F =
calculated_stats_[calculate_index(factor_values_female) + channel_index_.at("count")];
double count_M =
calculated_stats_[calculate_index(factor_values_male) + channel_index_.at("count")];
double deaths_F =
calculated_stats_[calculate_index(factor_values_female) + channel_index_.at("deaths")];
double deaths_M =
calculated_stats_[calculate_index(factor_values_male) + channel_index_.at("deaths")];

// Calculate in-place factor standard deviation.
for (const auto &factor : context.mapping().entries()) {
divide_by_count_sqrt(factor.key().to_string(), core::Gender::female, age,
(count_F + deaths_F), factor_values_female);
divide_by_count_sqrt(factor.key().to_string(), core::Gender::male, age,
(count_M + deaths_M), factor_values_male);
}
}
}

// NOLINTBEGIN(readability-function-cognitive-complexity)
void AnalysisModule::calculate_population_statistics(RuntimeContext &context,
DataSeries &series) const {
Expand Down Expand Up @@ -680,6 +750,29 @@ size_t AnalysisModule::calculate_index(const Person &person) const {
return index;
}

size_t AnalysisModule::calculate_index(const std::vector<double> &factor_values) const {
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = factor_values[i];
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(), size_t{1},
std::multiplies<>());
index += bin_indices[i] * accumulated_bins * num_stats_to_calc_;
}
index += bin_indices.back() * num_stats_to_calc_;

return index;
}

std::unique_ptr<AnalysisModule> build_analysis_module(Repository &repository,
const ModelInput &config) {
auto analysis_entity = repository.manager().get_disease_analysis(config.settings().country());
Expand Down
9 changes: 9 additions & 0 deletions src/HealthGPS/analysis_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ class AnalysisModule final : public UpdatableModule {
/// @return The index in `calculated_stats_`
size_t calculate_index(const Person &person) const;

/// @brief Calculates the bin index in `calculated_stats_` for a given set of factor values
/// @param factor_values The factor values to calculate the index for
/// @return The index in `calculated_stats_`
size_t calculate_index(const std::vector<double> &factor_values) const;

/// @brief Calculates the standard deviation of factors given data series containing means
/// @param context The runtime context
void calculate_standard_deviation(RuntimeContext &context) const;

/// @brief Calculates the standard deviation of factors given data series containing means
/// @param context The runtime context
/// @param series The data series containing factor means
Expand Down

0 comments on commit b8e6bf4

Please sign in to comment.