diff --git a/src/data_driven_rate_equation_selection.jl b/src/data_driven_rate_equation_selection.jl index cd0dcfc..1e086a2 100644 --- a/src/data_driven_rate_equation_selection.jl +++ b/src/data_driven_rate_equation_selection.jl @@ -34,17 +34,15 @@ function data_driven_rate_equation_selection( #generate param_removal_code_names by converting each mirror parameter for a and i into one name - #(e.g. K_a_Metabolite1 and K_i_Metabolite1 into K_Metabolite1) + #(e.g. K_a_Metabolite1 and K_i_Metabolite1 into K_allo_Metabolite1) param_removal_code_names = ( [ - Symbol(replace(string(param_name), "_a_" => "_allo_")) for + Symbol(replace(string(param_name), "_a_" => "_allo_", "Vmax_a" => "Vmax_allo")) for param_name in param_names if !contains(string(param_name), "_i") && param_name != :Vmax ]..., ) - #generate all possible combination of parameter removal codes - all_param_removal_codes = calculate_all_parameter_removal_codes(param_names) num_alpha_params = count(occursin.("alpha", string.([param_names...]))) #check that range_number_params within bounds of minimal and maximal number of parameters @assert range_number_params[1] >= length(param_names) - length(param_removal_code_names) "starting range_number_params cannot be below $(length(param_names) - length(param_removal_code_names))" @@ -55,6 +53,9 @@ function data_driven_rate_equation_selection( elseif !forward_model_selection num_param_range = (range_number_params[1]):1:range_number_params[2] end + + #calculate starting_param_removal_codes with num_param_range[1] parameters + all_param_removal_codes = calculate_all_parameter_removal_codes(param_names) starting_param_removal_codes = calculate_all_parameter_removal_codes_w_num_params( num_param_range[1], all_param_removal_codes, @@ -62,7 +63,8 @@ function data_driven_rate_equation_selection( num_alpha_params, ) - previous_param_removal_codes = starting_param_removal_codes + nt_param_removal_codes = starting_param_removal_codes + nt_previous_param_removal_codes =
similar(nt_param_removal_codes) println("About to start loop with num_params: $num_param_range") df_train_results = DataFrame() df_test_results = DataFrame() @@ -70,22 +72,18 @@ function data_driven_rate_equation_selection( println("Running loop with num_params: $num_params") #calculate param_removal_codes for `num_params` given `all_param_removal_codes` and fixed params from previous `num_params` - if forward_model_selection - nt_param_removal_codes = forward_selection_next_param_removal_codes( - all_param_removal_codes, - previous_param_removal_codes, - num_params, - param_names, - param_removal_code_names, - ) - elseif !forward_model_selection - nt_param_removal_codes = reverse_selection_next_param_removal_codes( - all_param_removal_codes, - previous_param_removal_codes, - num_params, - param_names, - param_removal_code_names, - ) + if num_params != num_param_range[1] + if forward_model_selection + nt_param_removal_codes = forward_selection_next_param_removal_codes( + nt_previous_param_removal_codes, + num_alpha_params, + ) + elseif !forward_model_selection + nt_param_removal_codes = reverse_selection_next_param_removal_codes( + nt_previous_param_removal_codes, + num_alpha_params, + ) + end end #pmap over nt_param_removal_codes for a given `num_params` return rescaled and nt_param_subset added results_array = pmap( @@ -114,14 +112,19 @@ function data_driven_rate_equation_selection( #if all train_loss are Inf, then skip to next loop if all(df_results.train_loss .== Inf) - previous_param_removal_codes = values.(df_results.nt_param_removal_codes) + nt_previous_param_removal_codes = [ + NamedTuple{param_removal_code_names}(x) for + x in values.(df_results.nt_param_removal_codes) + ] continue end #store top 10% for next loop as `previous_param_removal_codes` filter!(row -> row.train_loss < 1.1 * minimum(df_results.train_loss), df_results) - previous_param_removal_codes = values.(df_results.nt_param_removal_codes) - + nt_previous_param_removal_codes = [ + 
NamedTuple{param_removal_code_names}(x) for + x in values.(df_results.nt_param_removal_codes) + ] #calculate loocv test loss for top subset for each `num_params` best_nt_param_removal_code = df_results.nt_param_removal_codes[argmin(df_results.train_loss)] @@ -289,10 +292,10 @@ function param_subset_select(params, param_names, nt_param_removal_code) for param_choice in keys(nt_param_removal_code) if startswith(string(param_choice), "L") && nt_param_removal_code[param_choice] == 1 params_dict[:L] = 0.0 - elseif startswith(string(param_choice), "Vmax") && + elseif startswith(string(param_choice), "Vmax_allo") && nt_param_removal_code[param_choice] == 1 params_dict[:Vmax_i] = params_dict[:Vmax_a] - elseif startswith(string(param_choice), "Vmax") && + elseif startswith(string(param_choice), "Vmax_allo") && nt_param_removal_code[param_choice] == 2 global params_dict[:Vmax_i] = 0.0 elseif startswith(string(param_choice), "K_allo") && @@ -335,133 +338,66 @@ function param_subset_select(params, param_names, nt_param_removal_code) end """ -Calculate `nt_param_removal_codes` with `num_params` including non-zero term combinations for codes (excluding alpha terms) in each `previous_param_removal_codes` that has `num_params-1` +Calculate `nt_param_removal_codes` with `num_params` including non-zero term combinations for codes (excluding alpha terms) in each `nt_previous_param_removal_codes` that has `num_params-1` """ function forward_selection_next_param_removal_codes( - all_param_removal_codes, - previous_param_removal_codes, - num_params, - param_names, - param_removal_code_names, + nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple}, + num_alpha_params, ) - - num_alpha_params = count(occursin.("alpha", string.([param_names...]))) - @assert all([ - ( - length(param_names) - num_alpha_params - - sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params + 1 - ) || ( - length(param_names) - num_alpha_params - - 
sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params - ) for param_removal_code in previous_param_removal_codes - ]) - previous_param_subset_masks = unique([ - ( - mask = ( - (previous_param_removal_code[1:(end-num_alpha_params)] .== 0)..., - zeros(Int64, num_alpha_params)..., - ), - non_zero_params = previous_param_removal_code .* - (previous_param_removal_code .!= 0), - ) for previous_param_removal_code in previous_param_removal_codes - ]) - - #select all param_removal_codes that yield equations with `num_params` number of parameters - all_param_codes_w_num_params = calculate_all_parameter_removal_codes_w_num_params( - num_params, - all_param_removal_codes, - param_names, - num_alpha_params, - ) - - #choose param_removal_codes with n_removed_params number of parameters removed that also contain non-zero elements from previous_param_removal_codes - param_removal_codes = [] - for previous_param_subset_mask in previous_param_subset_masks - push!( - param_removal_codes, - unique([ - param_code_w_num_params .* previous_param_subset_mask.mask .+ - previous_param_subset_mask.non_zero_params for - param_code_w_num_params in all_param_codes_w_num_params #if ( - # length(param_names) - num_alpha_params - sum( - # (param_code_w_num_params.*previous_param_subset_mask.mask.+previous_param_subset_mask.non_zero_params)[1:(end-num_alpha_params)] .> - # 0, - # ) - # ) == num_params - ])..., - ) + param_removal_code_names = keys(nt_previous_param_removal_codes[1]) + next_param_removal_codes = Vector{Vector{Int}}() + for previous_param_removal_code in nt_previous_param_removal_codes + i_cut_off = length(previous_param_removal_code) - num_alpha_params + for (i, code_element) in enumerate(previous_param_removal_code) + if i <= i_cut_off && code_element == 0 + if param_removal_code_names[i] == :L + feasible_param_subset_codes = [1] + elseif startswith(string(param_removal_code_names[i]), "Vmax_allo") + feasible_param_subset_codes = [1, 2] + elseif 
startswith(string(param_removal_code_names[i]), "K_allo") + feasible_param_subset_codes = [1, 2, 3] + elseif startswith(string(param_removal_code_names[i]), "K_") && + !startswith(string(param_removal_code_names[i]), "K_allo") && + length(split(string(param_removal_code_names[i]), "_")) == 2 + feasible_param_subset_codes = [1] + elseif startswith(string(param_removal_code_names[i]), "K_") && + !startswith(string(param_removal_code_names[i]), "K_allo") && + length(split(string(param_removal_code_names[i]), "_")) > 2 + feasible_param_subset_codes = [1, 2] + end + for code_element in feasible_param_subset_codes + next_param_removal_code = collect(Int, previous_param_removal_code) + next_param_removal_code[i] = code_element + push!(next_param_removal_codes, next_param_removal_code) + end + end + end end - nt_param_removal_codes = [ - NamedTuple{param_removal_code_names}(x) for - x in unique(param_removal_codes) if ( - length(param_names) - num_alpha_params - sum(x[1:(end-num_alpha_params)] .> 0) - ) == num_params - ] + nt_param_removal_codes = + [NamedTuple{param_removal_code_names}(x) for x in unique(next_param_removal_codes)] return nt_param_removal_codes end """ -Calculate `param_removal_codes` with `num_params` including zero term combinations for codes (excluding alpha terms) in each `previous_param_removal_codes` that has `num_params+1` +Use `nt_previous_param_removal_codes` to calculate `nt_next_param_removal_codes` that have one additional zero element except for elements <= `num_alpha_params` from the end """ function reverse_selection_next_param_removal_codes( - all_param_removal_codes, - previous_param_removal_codes, - num_params, - param_names, - param_removal_code_names, + nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple}, + num_alpha_params::Int, ) - - num_alpha_params = count(occursin.("alpha", string.([param_names...]))) - @assert all([ - ( - length(param_names) - num_alpha_params - - sum(param_removal_code[1:(end-num_alpha_params)]
.> 0) == num_params - 1 - ) || ( - length(param_names) - num_alpha_params - - sum(param_removal_code[1:(end-num_alpha_params)] .> 0) == num_params - ) for param_removal_code in previous_param_removal_codes - ]) - previous_param_subset_masks = unique([ - ( - mask = [ - (previous_param_removal_code[1:(end-num_alpha_params)] .== 0)..., - zeros(Int64, num_alpha_params)..., - ], - non_zero_params = previous_param_removal_code .* - (previous_param_removal_code .!= 0), - ) for previous_param_removal_code in previous_param_removal_codes - ]) - - #select all codes that yield equations with `num_params` number of parameters - all_param_codes_w_num_params = calculate_all_parameter_removal_codes_w_num_params( - num_params, - all_param_removal_codes, - param_names, - num_alpha_params, - ) - - #choose param_removal_codes with n_removed_params number of parameters removed that also contain non-zero elements from previous_param_removal_codes - param_removal_codes = [] - for previous_param_subset_mask in previous_param_subset_masks - push!( - param_removal_codes, - unique([ - previous_param_subset_mask.non_zero_params .* - (param_code_w_num_params .!= 0) for - param_code_w_num_params in all_param_codes_w_num_params #if ( - # length(param_names) - num_alpha_params - sum( - # (previous_param_subset_mask.non_zero_params.*(param_code_w_num_params.!=0))[1:(end-num_alpha_params)] .> - # 0, - # ) - # ) == num_params - ])..., - ) + param_removal_code_names = keys(nt_previous_param_removal_codes[1]) + next_param_removal_codes = Vector{Vector{Int}}() + for previous_param_removal_code in nt_previous_param_removal_codes + i_cut_off = length(previous_param_removal_code) - num_alpha_params + for (i, code_element) in enumerate(previous_param_removal_code) + if i <= i_cut_off && code_element != 0 + next_param_removal_code = collect(Int, previous_param_removal_code) + next_param_removal_code[i] = 0 + push!(next_param_removal_codes, next_param_removal_code) + end + end end - nt_param_removal_codes = 
[ - NamedTuple{param_removal_code_names}(x) for - x in unique(param_removal_codes) if ( - length(param_names) - num_alpha_params - sum(x[1:(end-num_alpha_params)] .> 0) - ) == num_params - ] + nt_param_removal_codes = + [NamedTuple{param_removal_code_names}(x) for x in unique(next_param_removal_codes)] return nt_param_removal_codes end diff --git a/test/tests_for_optimal_rate_eq_selection.jl b/test/tests_for_optimal_rate_eq_selection.jl index 7b763d7..599eed0 100644 --- a/test/tests_for_optimal_rate_eq_selection.jl +++ b/test/tests_for_optimal_rate_eq_selection.jl @@ -17,37 +17,41 @@ param_names = ( :L, :Vmax_a, :Vmax_i, - [Symbol(:K_a, "_Metabolite$(i)") for i = 1:num_metabolites]..., - [Symbol(:K_i, "_Metabolite$(i)") for i = 1:num_metabolites]..., - [Symbol(:alpha, "_$(i)") for i = 1:n_alphas]..., + [Symbol(:K_a, "_Metabolite$(i)") for i in 1:num_metabolites]..., + [Symbol(:K_i, "_Metabolite$(i)") for i in 1:num_metabolites]..., + [Symbol(:alpha, "_$(i)") for i in 1:n_alphas]... ) param_removal_code_names = ( [ - Symbol(replace(string(param_name), "_a" => "")) for - param_name in param_names if !contains(string(param_name), "_i") + Symbol(replace(string(param_name), "_a_" => "_allo_", "Vmax_a" => "Vmax_allo")) for + param_name in param_names if + !contains(string(param_name), "_i") && param_name != :Vmax ]..., ) all_param_removal_codes = DataDrivenEnzymeRateEqs.calculate_all_parameter_removal_codes(param_names) -param_subset_codes_with_num_params = [ - x for x in all_param_removal_codes if - length(param_names) - sum(values(x[1:(end-n_alphas)]) .> 0) - n_alphas == - num_previous_step_params -] -previous_param_removal_codes = - [rand(param_subset_codes_with_num_params) for i = 1:rand(1:20)] +param_subset_codes_with_num_params = [x + for x in all_param_removal_codes + if + length(param_names) - + sum(values(x[1:(end-n_alphas)]) .> 0) - n_alphas == + num_previous_step_params] +previous_param_removal_codes = [rand(param_subset_codes_with_num_params) + for i in 
1:rand(1:20)] +nt_previous_param_removal_codes = [NamedTuple{param_removal_code_names}(x) + for x in previous_param_removal_codes] +param_removal_code_names + nt_funct_output_param_subset_codes = DataDrivenEnzymeRateEqs.forward_selection_next_param_removal_codes( - all_param_removal_codes, - previous_param_removal_codes, - num_params, - param_names, - param_removal_code_names + nt_previous_param_removal_codes, + n_alphas ) funct_output_param_subset_codes = [values(nt) for nt in nt_funct_output_param_subset_codes] #ensure that funct_output_param_subset_codes have one less parameter than previous_param_removal_codes @test all( length(param_names) - n_alphas - sum(funct_output_param_subset_code[1:(end-n_alphas)] .> 0) == - (num_previous_step_params - 1) for + (num_previous_step_params - 1) + for funct_output_param_subset_code in funct_output_param_subset_codes ) #ensure that non-zero elements from previous_param_removal_codes are present in > 1 of the funct_output_param_subset_code but less than the max_matches @@ -68,10 +72,12 @@ for funct_output_param_subset_code in funct_output_param_subset_codes count = 0 max_matches_vect = Int[] for previous_param_removal_code in previous_param_removal_codes - count += - funct_output_param_subset_code[1:end-n_alphas] .* previous_param_removal_code[1:end-n_alphas] == - previous_param_removal_code[1:end-n_alphas] .^ 2 - push!(max_matches_vect, sum((previous_param_removal_code[1:end-n_alphas] .== 0) .* non_zero_code_combos_per_param[1:end-n_alphas])) + count += funct_output_param_subset_code[1:(end-n_alphas)] .* + previous_param_removal_code[1:(end-n_alphas)] == + previous_param_removal_code[1:(end-n_alphas)] .^ 2 + push!(max_matches_vect, + sum((previous_param_removal_code[1:(end-n_alphas)] .== 0) .* + non_zero_code_combos_per_param[1:(end-n_alphas)])) end max_matches = maximum(max_matches_vect) push!(count_matches, max_matches >= count > 0) @@ -93,8 +99,9 @@ param_names = ( ) param_removal_code_names = ( [ - 
Symbol(replace(string(param_name), "_a" => "")) for - param_name in param_names if !contains(string(param_name), "_i") + Symbol(replace(string(param_name), "_a_" => "_allo_")) for + param_name in param_names if + !contains(string(param_name), "_i") && param_name != :Vmax ]..., ) all_param_removal_codes = DataDrivenEnzymeRateEqs.calculate_all_parameter_removal_codes(param_names) @@ -105,23 +112,20 @@ param_subset_codes_with_num_params = [ ] previous_param_removal_codes = [rand(param_subset_codes_with_num_params) for i = 1:rand(1:20)] - +nt_previous_param_removal_codes = [NamedTuple{param_removal_code_names}(x) for x in previous_param_removal_codes] nt_funct_output_param_subset_codes = DataDrivenEnzymeRateEqs.reverse_selection_next_param_removal_codes( - all_param_removal_codes, - previous_param_removal_codes, - num_params, - param_names, - param_removal_code_names, + nt_previous_param_removal_codes, + n_alphas ) funct_output_param_subset_codes = [values(nt) for nt in nt_funct_output_param_subset_codes] -#ensure that funct_output_param_subset_codes have one more parameter than previous_param_removal_codes +#ensure that funct_output_param_subset_codes have one more parameter than nt_previous_param_removal_codes @test all( length(param_names) - n_alphas - sum(funct_output_param_subset_code[1:(end-n_alphas)] .> 0) == (num_previous_step_params + 1) for funct_output_param_subset_code in funct_output_param_subset_codes ) -#ensure that non-zero elements from each funct_output_param_subset_codes are present in > 1 of the previous_param_removal_codes but less than the max_matches +#ensure that non-zero elements from each funct_output_param_subset_codes are present in > 1 of the nt_previous_param_removal_codes but less than the max_matches count_matches = [] non_zero_code_combos_per_param = () for param_name in param_names @@ -137,7 +141,7 @@ for param_name in param_names end for funct_output_param_subset_code in funct_output_param_subset_codes count = 0 - for 
previous_param_removal_code in previous_param_removal_codes + for previous_param_removal_code in values.(nt_previous_param_removal_codes) count += funct_output_param_subset_code[1:end-n_alphas] == previous_param_removal_code[1:end-n_alphas] .* (funct_output_param_subset_code[1:end-n_alphas] .!= 0) end