Skip to content

Commit

Permalink
Merge pull request #20 from DenisTitovLab/fix-more-performanse-issues…
Browse files Browse the repository at this point in the history
…-related-to-alpha-slices

Upgrade reverse and forward_selection_next_param_removal_codes() to make them faster
  • Loading branch information
Denis-Titov authored Jun 8, 2024
2 parents 4562399 + 2bbfbfc commit fdcdab5
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 175 deletions.
220 changes: 78 additions & 142 deletions src/data_driven_rate_equation_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,15 @@ function data_driven_rate_equation_selection(


#generate param_removal_code_names by converting each mirror parameter for a and i into one name
#(e.g. K_a_Metabolite1 and K_i_Metabolite1 into K_Metabolite1)
#(e.g. K_a_Metabolite1 and K_i_Metabolite1 into K_allo_Metabolite1)
param_removal_code_names = (
[
Symbol(replace(string(param_name), "_a_" => "_allo_")) for
Symbol(replace(string(param_name), "_a_" => "_allo_", "Vmax_a" => "Vmax_allo")) for
param_name in param_names if
!contains(string(param_name), "_i") && param_name != :Vmax
]...,
)

#generate all possible combination of parameter removal codes
all_param_removal_codes = calculate_all_parameter_removal_codes(param_names)
num_alpha_params = count(occursin.("alpha", string.([param_names...])))
#check that range_number_params within bounds of minimal and maximal number of parameters
@assert range_number_params[1] >= length(param_names) - length(param_removal_code_names) "starting range_number_params cannot be below $(length(param_names) - length(param_removal_code_names))"
Expand All @@ -55,37 +53,37 @@ function data_driven_rate_equation_selection(
elseif !forward_model_selection
num_param_range = (range_number_params[1]):1:range_number_params[2]
end

#calculate starting_param_removal_codes for num_param_range[1] parameters
all_param_removal_codes = calculate_all_parameter_removal_codes(param_names)
starting_param_removal_codes = calculate_all_parameter_removal_codes_w_num_params(
num_param_range[1],
all_param_removal_codes,
param_names,
num_alpha_params,
)

previous_param_removal_codes = starting_param_removal_codes
nt_param_removal_codes = starting_param_removal_codes
nt_previous_param_removal_codes = similar(nt_param_removal_codes)
println("About to start loop with num_params: $num_param_range")
df_train_results = DataFrame()
df_test_results = DataFrame()
for num_params in num_param_range
println("Running loop with num_params: $num_params")

#calculate param_removal_codes for `num_params` given `all_param_removal_codes` and fixed params from previous `num_params`
if forward_model_selection
nt_param_removal_codes = forward_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
)
elseif !forward_model_selection
nt_param_removal_codes = reverse_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
)
if num_params != num_param_range[1]
if forward_model_selection
nt_param_removal_codes = forward_selection_next_param_removal_codes(
nt_previous_param_removal_codes,
num_alpha_params,
)
elseif !forward_model_selection
nt_param_removal_codes = reverse_selection_next_param_removal_codes(
nt_previous_param_removal_codes,
num_alpha_params,
)
end
end
#pmap over nt_param_removal_codes for a given `num_params` return rescaled and nt_param_subset added
results_array = pmap(
Expand Down Expand Up @@ -114,14 +112,19 @@ function data_driven_rate_equation_selection(

#if all train_loss are Inf, then skip to next loop
if all(df_results.train_loss .== Inf)
previous_param_removal_codes = values.(df_results.nt_param_removal_codes)
nt_previous_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in values.(df_results.nt_param_removal_codes)
]
continue
end

#store subsets with train_loss below 1.1x the minimum train_loss for next loop as `nt_previous_param_removal_codes`
filter!(row -> row.train_loss < 1.1 * minimum(df_results.train_loss), df_results)
previous_param_removal_codes = values.(df_results.nt_param_removal_codes)

nt_previous_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in values.(df_results.nt_param_removal_codes)
]
#calculate loocv test loss for top subset for each `num_params`
best_nt_param_removal_code =
df_results.nt_param_removal_codes[argmin(df_results.train_loss)]
Expand Down Expand Up @@ -289,10 +292,10 @@ function param_subset_select(params, param_names, nt_param_removal_code)
for param_choice in keys(nt_param_removal_code)
if startswith(string(param_choice), "L") && nt_param_removal_code[param_choice] == 1
params_dict[:L] = 0.0
elseif startswith(string(param_choice), "Vmax") &&
elseif startswith(string(param_choice), "Vmax_allo") &&
nt_param_removal_code[param_choice] == 1
params_dict[:Vmax_i] = params_dict[:Vmax_a]
elseif startswith(string(param_choice), "Vmax") &&
elseif startswith(string(param_choice), "Vmax_allo") &&
nt_param_removal_code[param_choice] == 2
global params_dict[:Vmax_i] = 0.0
elseif startswith(string(param_choice), "K_allo") &&
Expand Down Expand Up @@ -335,133 +338,66 @@ function param_subset_select(params, param_names, nt_param_removal_code)
end

"""
Calculate `nt_param_removal_codes` with `num_params` including non-zero term combinations for codes (excluding alpha terms) in each `previous_param_removal_codes` that has `num_params-1`
Use `nt_previous_param_removal_codes` to calculate `nt_param_removal_codes` that each have one additional non-zero element, trying every feasible non-zero code for that element; the last `num_alpha_params` elements are never changed
"""
function forward_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple},
num_alpha_params,
)

num_alpha_params = count(occursin.("alpha", string.([param_names...])))
@assert all([
(
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params + 1
) || (
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params
) for param_removal_code in previous_param_removal_codes
])
previous_param_subset_masks = unique([
(
mask = (
(previous_param_removal_code[1:(end-num_alpha_params)] .== 0)...,
zeros(Int64, num_alpha_params)...,
),
non_zero_params = previous_param_removal_code .*
(previous_param_removal_code .!= 0),
) for previous_param_removal_code in previous_param_removal_codes
])

#select all param_removal_codes that yield equations with `num_params` number of parameters
all_param_codes_w_num_params = calculate_all_parameter_removal_codes_w_num_params(
num_params,
all_param_removal_codes,
param_names,
num_alpha_params,
)

#choose param_removal_codes with n_removed_params number of parameters removed that also contain non-zero elements from previous_param_removal_codes
param_removal_codes = []
for previous_param_subset_mask in previous_param_subset_masks
push!(
param_removal_codes,
unique([
param_code_w_num_params .* previous_param_subset_mask.mask .+
previous_param_subset_mask.non_zero_params for
param_code_w_num_params in all_param_codes_w_num_params #if (
# length(param_names) - num_alpha_params - sum(
# (param_code_w_num_params.*previous_param_subset_mask.mask.+previous_param_subset_mask.non_zero_params)[1:(end-num_alpha_params)] .>
# 0,
# )
# ) == num_params
])...,
)
param_removal_code_names = keys(nt_previous_param_removal_codes[1])
next_param_removal_codes = Vector{Vector{Int}}()
for previous_param_removal_code in nt_previous_param_removal_codes
i_cut_off = length(previous_param_removal_code) - num_alpha_params
for (i, code_element) in enumerate(previous_param_removal_code)
if i <= i_cut_off && code_element == 0
if param_removal_code_names[i] == :L
feasible_param_subset_codes = [1]
elseif startswith(string(param_removal_code_names[i]), "Vmax_allo")
feasible_param_subset_codes = [1, 2]
elseif startswith(string(param_removal_code_names[i]), "K_allo")
feasible_param_subset_codes = [1, 2, 3]
elseif startswith(string(param_removal_code_names[i]), "K_") &&
!startswith(string(param_removal_code_names[i]), "K_allo") &&
length(split(string(param_removal_code_names[i]), "_")) == 2
feasible_param_subset_codes = [1]
elseif startswith(string(param_removal_code_names[i]), "K_") &&
!startswith(string(param_removal_code_names[i]), "K_allo") &&
length(split(string(param_removal_code_names[i]), "_")) > 2
feasible_param_subset_codes = [1, 2]
end
for code_element in feasible_param_subset_codes
next_param_removal_code = collect(Int, previous_param_removal_code)
next_param_removal_code[i] = code_element
push!(next_param_removal_codes, next_param_removal_code)
end
end
end
end
nt_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in unique(param_removal_codes) if (
length(param_names) - num_alpha_params - sum(x[1:(end-num_alpha_params)] .> 0)
) == num_params
]
nt_param_removal_codes =
[NamedTuple{param_removal_code_names}(x) for x in unique(next_param_removal_codes)]
return nt_param_removal_codes
end

"""
Calculate `param_removal_codes` with `num_params` including zero term combinations for codes (excluding alpha terms) in each `previous_param_removal_codes` that has `num_params+1`
Use `nt_previous_param_removal_codes` to calculate `nt_next_param_removal_codes` that each have one additional zero element; the last `num_alpha_params` elements are never changed
"""
function reverse_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple},
num_alpha_params::Int,
)

num_alpha_params = count(occursin.("alpha", string.([param_names...])))
@assert all([
(
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params - 1
) || (
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params
) for param_removal_code in previous_param_removal_codes
])
previous_param_subset_masks = unique([
(
mask = [
(previous_param_removal_code[1:(end-num_alpha_params)] .== 0)...,
zeros(Int64, num_alpha_params)...,
],
non_zero_params = previous_param_removal_code .*
(previous_param_removal_code .!= 0),
) for previous_param_removal_code in previous_param_removal_codes
])

#select all codes that yield equations with `num_params` number of parameters
all_param_codes_w_num_params = calculate_all_parameter_removal_codes_w_num_params(
num_params,
all_param_removal_codes,
param_names,
num_alpha_params,
)

#choose param_removal_codes with n_removed_params number of parameters removed that also contain non-zero elements from previous_param_removal_codes
param_removal_codes = []
for previous_param_subset_mask in previous_param_subset_masks
push!(
param_removal_codes,
unique([
previous_param_subset_mask.non_zero_params .*
(param_code_w_num_params .!= 0) for
param_code_w_num_params in all_param_codes_w_num_params #if (
# length(param_names) - num_alpha_params - sum(
# (previous_param_subset_mask.non_zero_params.*(param_code_w_num_params.!=0))[1:(end-num_alpha_params)] .>
# 0,
# )
# ) == num_params
])...,
)
param_removal_code_names = keys(nt_previous_param_removal_codes[1])
next_param_removal_codes = Vector{Vector{Int}}()
for previous_param_removal_code in nt_previous_param_removal_codes
i_cut_off = length(previous_param_removal_code) - num_alpha_params
for (i, code_element) in enumerate(previous_param_removal_code)
if i <= i_cut_off && code_element != 0
next_param_removal_code = collect(Int, previous_param_removal_code)
next_param_removal_code[i] = 0
push!(next_param_removal_codes, next_param_removal_code)
end
end
end
nt_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in unique(param_removal_codes) if (
length(param_names) - num_alpha_params - sum(x[1:(end-num_alpha_params)] .> 0)
) == num_params
]
nt_param_removal_codes =
[NamedTuple{param_removal_code_names}(x) for x in unique(next_param_removal_codes)]
return nt_param_removal_codes
end
Loading

0 comments on commit fdcdab5

Please sign in to comment.