Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade reverse and forward_selection_next_param_removal_codes() to make them faster #20

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 78 additions & 142 deletions src/data_driven_rate_equation_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,15 @@ function data_driven_rate_equation_selection(


#generate param_removal_code_names by converting each mirror parameter for a and i into one name
#(e.g. K_a_Metabolite1 and K_i_Metabolite1 into K_Metabolite1)
#(e.g. K_a_Metabolite1 and K_i_Metabolite1 into K_allo_Metabolite1)
param_removal_code_names = (
[
Symbol(replace(string(param_name), "_a_" => "_allo_")) for
Symbol(replace(string(param_name), "_a_" => "_allo_", "Vmax_a" => "Vmax_allo")) for
param_name in param_names if
!contains(string(param_name), "_i") && param_name != :Vmax
]...,
)

#generate all possible combination of parameter removal codes
all_param_removal_codes = calculate_all_parameter_removal_codes(param_names)
num_alpha_params = count(occursin.("alpha", string.([param_names...])))
#check that range_number_params within bounds of minimal and maximal number of parameters
@assert range_number_params[1] >= length(param_names) - length(param_removal_code_names) "starting range_number_params cannot be below $(length(param_names) - length(param_removal_code_names))"
Expand All @@ -55,37 +53,37 @@ function data_driven_rate_equation_selection(
elseif !forward_model_selection
num_param_range = (range_number_params[1]):1:range_number_params[2]
end

#calculate starting_param_removal_codes num_param_range[1] parameters
all_param_removal_codes = calculate_all_parameter_removal_codes(param_names)
starting_param_removal_codes = calculate_all_parameter_removal_codes_w_num_params(
num_param_range[1],
all_param_removal_codes,
param_names,
num_alpha_params,
)

previous_param_removal_codes = starting_param_removal_codes
nt_param_removal_codes = starting_param_removal_codes
nt_previous_param_removal_codes = similar(nt_param_removal_codes)
println("About to start loop with num_params: $num_param_range")
df_train_results = DataFrame()
df_test_results = DataFrame()
for num_params in num_param_range
println("Running loop with num_params: $num_params")

#calculate param_removal_codes for `num_params` given `all_param_removal_codes` and fixed params from previous `num_params`
if forward_model_selection
nt_param_removal_codes = forward_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
)
elseif !forward_model_selection
nt_param_removal_codes = reverse_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
)
if num_params != num_param_range[1]
if forward_model_selection
nt_param_removal_codes = forward_selection_next_param_removal_codes(
nt_previous_param_removal_codes,
num_alpha_params,
)
elseif !forward_model_selection
nt_param_removal_codes = reverse_selection_next_param_removal_codes(
nt_previous_param_removal_codes,
num_alpha_params,
)
end
end
#pmap over nt_param_removal_codes for a given `num_params` return rescaled and nt_param_subset added
results_array = pmap(
Expand Down Expand Up @@ -114,14 +112,19 @@ function data_driven_rate_equation_selection(

#if all train_loss are Inf, then skip to next loop
if all(df_results.train_loss .== Inf)
previous_param_removal_codes = values.(df_results.nt_param_removal_codes)
nt_previous_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in values.(df_results.nt_param_removal_codes)
]
continue
end

#store top 10% for next loop as `previous_param_removal_codes`
filter!(row -> row.train_loss < 1.1 * minimum(df_results.train_loss), df_results)
previous_param_removal_codes = values.(df_results.nt_param_removal_codes)

nt_previous_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in values.(df_results.nt_param_removal_codes)
]
#calculate loocv test loss for top subset for each `num_params`
best_nt_param_removal_code =
df_results.nt_param_removal_codes[argmin(df_results.train_loss)]
Expand Down Expand Up @@ -289,10 +292,10 @@ function param_subset_select(params, param_names, nt_param_removal_code)
for param_choice in keys(nt_param_removal_code)
if startswith(string(param_choice), "L") && nt_param_removal_code[param_choice] == 1
params_dict[:L] = 0.0
elseif startswith(string(param_choice), "Vmax") &&
elseif startswith(string(param_choice), "Vmax_allo") &&
nt_param_removal_code[param_choice] == 1
params_dict[:Vmax_i] = params_dict[:Vmax_a]
elseif startswith(string(param_choice), "Vmax") &&
elseif startswith(string(param_choice), "Vmax_allo") &&
nt_param_removal_code[param_choice] == 2
global params_dict[:Vmax_i] = 0.0
elseif startswith(string(param_choice), "K_allo") &&
Expand Down Expand Up @@ -335,133 +338,66 @@ function param_subset_select(params, param_names, nt_param_removal_code)
end

"""
Calculate `nt_param_removal_codes` with `num_params` including non-zero term combinations for codes (excluding alpha terms) in each `previous_param_removal_codes` that has `num_params-1`
Calculate `nt_param_removal_codes` with `num_params` including non-zero term combinations for codes (excluding alpha terms) in each `nt_previous_param_removal_codes` that has `num_params-1`
"""
function forward_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple},
num_alpha_params,
)

num_alpha_params = count(occursin.("alpha", string.([param_names...])))
@assert all([
(
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params + 1
) || (
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params
) for param_removal_code in previous_param_removal_codes
])
previous_param_subset_masks = unique([
(
mask = (
(previous_param_removal_code[1:(end-num_alpha_params)] .== 0)...,
zeros(Int64, num_alpha_params)...,
),
non_zero_params = previous_param_removal_code .*
(previous_param_removal_code .!= 0),
) for previous_param_removal_code in previous_param_removal_codes
])

#select all param_removal_codes that yield equations with `num_params` number of parameters
all_param_codes_w_num_params = calculate_all_parameter_removal_codes_w_num_params(
num_params,
all_param_removal_codes,
param_names,
num_alpha_params,
)

#choose param_removal_codes with n_removed_params number of parameters removed that also contain non-zero elements from previous_param_removal_codes
param_removal_codes = []
for previous_param_subset_mask in previous_param_subset_masks
push!(
param_removal_codes,
unique([
param_code_w_num_params .* previous_param_subset_mask.mask .+
previous_param_subset_mask.non_zero_params for
param_code_w_num_params in all_param_codes_w_num_params #if (
# length(param_names) - num_alpha_params - sum(
# (param_code_w_num_params.*previous_param_subset_mask.mask.+previous_param_subset_mask.non_zero_params)[1:(end-num_alpha_params)] .>
# 0,
# )
# ) == num_params
])...,
)
param_removal_code_names = keys(nt_previous_param_removal_codes[1])
next_param_removal_codes = Vector{Vector{Int}}()
for previous_param_removal_code in nt_previous_param_removal_codes
i_cut_off = length(previous_param_removal_code) - num_alpha_params
for (i, code_element) in enumerate(previous_param_removal_code)
if i <= i_cut_off && code_element == 0
if param_removal_code_names[i] == :L
feasible_param_subset_codes = [1]
elseif startswith(string(param_removal_code_names[i]), "Vmax_allo")
feasible_param_subset_codes = [1, 2]
elseif startswith(string(param_removal_code_names[i]), "K_allo")
feasible_param_subset_codes = [1, 2, 3]
elseif startswith(string(param_removal_code_names[i]), "K_") &&
!startswith(string(param_removal_code_names[i]), "K_allo") &&
length(split(string(param_removal_code_names[i]), "_")) == 2
feasible_param_subset_codes = [1]
elseif startswith(string(param_removal_code_names[i]), "K_") &&
!startswith(string(param_removal_code_names[i]), "K_allo") &&
length(split(string(param_removal_code_names[i]), "_")) > 2
feasible_param_subset_codes = [1, 2]
end
for code_element in feasible_param_subset_codes
next_param_removal_code = collect(Int, previous_param_removal_code)
next_param_removal_code[i] = code_element
push!(next_param_removal_codes, next_param_removal_code)
end
end
end
end
nt_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in unique(param_removal_codes) if (
length(param_names) - num_alpha_params - sum(x[1:(end-num_alpha_params)] .> 0)
) == num_params
]
nt_param_removal_codes =
[NamedTuple{param_removal_code_names}(x) for x in unique(next_param_removal_codes)]
return nt_param_removal_codes
end

"""
Calculate `param_removal_codes` with `num_params` including zero term combinations for codes (excluding alpha terms) in each `previous_param_removal_codes` that has `num_params+1`
Use `nt_previous_param_removal_codes` to calculate `nt_next_param_removal_codes` that have one additional zero elements except for for elements <= `num_alpha_params` from the end
"""
function reverse_selection_next_param_removal_codes(
all_param_removal_codes,
previous_param_removal_codes,
num_params,
param_names,
param_removal_code_names,
nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple},
num_alpha_params::Int,
)

num_alpha_params = count(occursin.("alpha", string.([param_names...])))
@assert all([
(
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params - 1
) || (
length(param_names) - num_alpha_params -
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params
) for param_removal_code in previous_param_removal_codes
])
previous_param_subset_masks = unique([
(
mask = [
(previous_param_removal_code[1:(end-num_alpha_params)] .== 0)...,
zeros(Int64, num_alpha_params)...,
],
non_zero_params = previous_param_removal_code .*
(previous_param_removal_code .!= 0),
) for previous_param_removal_code in previous_param_removal_codes
])

#select all codes that yield equations with `num_params` number of parameters
all_param_codes_w_num_params = calculate_all_parameter_removal_codes_w_num_params(
num_params,
all_param_removal_codes,
param_names,
num_alpha_params,
)

#choose param_removal_codes with n_removed_params number of parameters removed that also contain non-zero elements from previous_param_removal_codes
param_removal_codes = []
for previous_param_subset_mask in previous_param_subset_masks
push!(
param_removal_codes,
unique([
previous_param_subset_mask.non_zero_params .*
(param_code_w_num_params .!= 0) for
param_code_w_num_params in all_param_codes_w_num_params #if (
# length(param_names) - num_alpha_params - sum(
# (previous_param_subset_mask.non_zero_params.*(param_code_w_num_params.!=0))[1:(end-num_alpha_params)] .>
# 0,
# )
# ) == num_params
])...,
)
param_removal_code_names = keys(nt_previous_param_removal_codes[1])
next_param_removal_codes = Vector{Vector{Int}}()
for previous_param_removal_code in nt_previous_param_removal_codes
i_cut_off = length(previous_param_removal_code) - num_alpha_params
for (i, code_element) in enumerate(previous_param_removal_code)
if i <= i_cut_off && code_element != 0
next_param_removal_code = collect(Int, previous_param_removal_code)
next_param_removal_code[i] = 0
push!(next_param_removal_codes, next_param_removal_code)
end
end
end
nt_param_removal_codes = [
NamedTuple{param_removal_code_names}(x) for
x in unique(param_removal_codes) if (
length(param_names) - num_alpha_params - sum(x[1:(end-num_alpha_params)] .> 0)
) == num_params
]
nt_param_removal_codes =
[NamedTuple{param_removal_code_names}(x) for x in unique(next_param_removal_codes)]
return nt_param_removal_codes
end
Loading