From d030846b99c99a19796cf666bd0b07b96609d5a6 Mon Sep 17 00:00:00 2001 From: Denis-Titov Date: Thu, 7 Mar 2024 20:26:20 -0800 Subject: [PATCH] more edits to make data_driven_rate_equation_selection work properly --- src/data_driven_rate_equation_selection.jl | 94 +++++++++++++++----- src/rate_equation_fitting.jl | 44 ++++----- test/tests_for_general_rate_eq_derivation.jl | 1 - test/tests_for_optimal_rate_eq_selection.jl | 44 +++++++-- test/tests_for_rate_eq_fitting.jl | 36 ++++++-- 5 files changed, 159 insertions(+), 60 deletions(-) diff --git a/src/data_driven_rate_equation_selection.jl b/src/data_driven_rate_equation_selection.jl index 7e58d23..efa63fa 100644 --- a/src/data_driven_rate_equation_selection.jl +++ b/src/data_driven_rate_equation_selection.jl @@ -36,6 +36,7 @@ function data_driven_rate_equation_selection( (1 + sum([occursin("K_a_", string(param_name)) for param_name in param_names])) @assert range_number_params[2] <= length(param_names) println("Past assertions") + #generate param_removal_code_names by converting each mirror parameter for a and i into one name #(e.g. 
K_a_Metabolite1 and K_i_Metabolite1 into K_Metabolite1) param_removal_code_names = ( @@ -67,6 +68,8 @@ function data_driven_rate_equation_selection( previous_param_removal_codes = starting_param_removal_codes println("About to start loop with num_params: $num_param_range") + df_train_results = DataFrame() + df_test_results = DataFrame() for num_params in num_param_range println("Running loop with num_params: $num_params") @@ -102,10 +105,13 @@ function data_driven_rate_equation_selection( nt_param_removal_codes, ) - #convert results_array to DataFrame and save in csv file + #convert results_array to DataFrame df_results = DataFrame(results_array) + df_results.num_params = fill(num_params, nrow(df_results)) df_results.nt_param_removal_codes = nt_param_removal_codes - df_results + df_train_results = vcat(df_train_results, df_results) + + # Optionally consider saving results to csv file for long running calculations on a cluster # CSV.write( # "$(Dates.format(now(),"mmddyy"))_$(forward_model_selection ? 
"forward" : "reverse")_model_select_results_$(num_params)_num_params.csv", # df_results, @@ -114,26 +120,68 @@ function data_driven_rate_equation_selection( filter!(row -> row.train_loss < 1.1 * minimum(df_results.train_loss), df_results) previous_param_removal_codes = values.(df_results.nt_param_removal_codes) - #calculate test loss for top 10% subsets for each `num_params` - #TODO: loop over all figures and calculate test loss for each figure - #TODO: consider looping over all figures and calculating test loss separately from train loss calculations + #calculate loocv test loss for top subset for each `num_params` #TODO: change to pmap - test_loss = map( - nt_fitted_params -> test_rate_equation( + best_nt_param_removal_code = + df_results.nt_param_removal_codes[argmin(df_results.train_loss)] + test_results = map( + removed_fig -> loocv_rate_equation( + removed_fig, general_rate_equation, data, - nt_fitted_params, metab_names, - param_names, + param_names; + n_iter = 20, + nt_param_removal_code = best_nt_param_removal_code, ), - df_results.params, + unique(data.source), ) - #store rescaled results + df_results = DataFrame(test_results) + df_results.num_params = fill(num_params, nrow(df_results)) + df_results.nt_param_removal_codes = + fill(best_nt_param_removal_code, nrow(df_results)) + df_test_results = vcat(df_test_results, df_results) end - println("Finished loop") - #return train loss and params for all tested subsets, test loss for all tested subsets + return (train_results = df_train_results, test_results = df_test_results) +end + +"function to calculate train loss without a figure and test loss on removed figure" +function loocv_rate_equation( + fig, + rate_equation::Function, + data::DataFrame, + metab_names::Tuple{Symbol,Vararg{Symbol}}, + param_names::Tuple{Symbol,Vararg{Symbol}}; + n_iter = 20, + nt_param_removal_code = nothing, +) + # Drop selected figure from data + train_data = data[data.source.!=fig, :] + test_data = data[data.source.==fig, :] + # 
Calculate fit + train_res = train_rate_equation( + rate_equation, + train_data, + metab_names, + param_names; + n_iter = n_iter, + nt_param_removal_code = nt_param_removal_code, + ) + test_loss = test_rate_equation( + rate_equation, + test_data, + train_res.params, + metab_names, + param_names, + ) + return ( + fig = fig, + train_loss = train_res.train_loss, + test_loss = test_loss, + params = train_res.params, + ) end """Function to calculate loss for a given `rate_equation` and `nt_fitted_params` on `data` that was not used for training""" @@ -144,20 +192,24 @@ function test_rate_equation( metab_names::Tuple{Symbol,Vararg{Symbol}}, param_names::Tuple{Symbol,Vararg{Symbol}}, ) + filtered_data = data[.!isnan.(data.Rate), [:Rate, metab_names..., :source]] + #Only include Rate != 0 because otherwise log_ratio_predict_vs_data() will have to divide by 0 + filter!(row -> row.Rate != 0, filtered_data) # Add a new column to data to assign an integer to each source/figure from publication - data.fig_num = vcat( + filtered_data.fig_num = vcat( [ - i * ones(Int64, count(==(unique(data.source)[i]), data.source)) for - i = 1:length(unique(data.source)) + i * ones( + Int64, + count(==(unique(filtered_data.source)[i]), filtered_data.source), + ) for i = 1:length(unique(filtered_data.source)) ]..., ) - + # Make a vector containing indexes of points corresponding to each figure + fig_point_indexes = + [findall(filtered_data.fig_num .== i) for i in unique(filtered_data.fig_num)] # Convert DF to NamedTuple for better type stability / speed - rate_data_nt = - Tables.columntable(data[.!isnan.(data.Rate), [:Rate, metab_names..., :fig_num]]) + rate_data_nt = Tables.columntable(filtered_data) - # Make a vector containing indexes of points corresponding to each figure - fig_point_indexes = [findall(data.fig_num .== i) for i in unique(data.fig_num)] fitted_params = values(nt_fitted_params) test_loss = loss_rate_equation( fitted_params, diff --git a/src/rate_equation_fitting.jl 
b/src/rate_equation_fitting.jl index af43c24..cffbd9f 100644 --- a/src/rate_equation_fitting.jl +++ b/src/rate_equation_fitting.jl @@ -10,8 +10,8 @@ using CMAEvolutionStrategy, DataFrames, Statistics fit_rate_equation( rate_equation::Function, data::DataFrame, - metab_names::Tuple, - param_names::Tuple; + metab_names::Tuple{Symbol, Vararg{Symbol}}, + param_names::Tuple{Symbol, Vararg{Symbol}}; n_iter = 20 ) @@ -20,8 +20,8 @@ Fit `rate_equation` to `data` and return loss and best fit parameters. # Arguments - `rate_equation::Function`: Function that takes a NamedTuple of metabolite concentrations (with `metab_names` keys) and parameters (with `param_names` keys) and returns an enzyme rate. - `data::DataFrame`: DataFrame containing the data with column `Rate` and columns for each `metab_names` where each row is one measurement. It also needs to have a column `source` that contains a string that identifies the source of the data. This is used to calculate the weights for each figure in the publication. -- `metab_names::Tuple`: Tuple of metabolite names that correspond to the metabolites of `rate_equation` and column names in `data`. -- `param_names::Tuple`: Tuple of parameter names that correspond to the parameters of `rate_equation`. +- `metab_names::Tuple{Symbol, Vararg{Symbol}}`: Tuple of metabolite names that correspond to the metabolites of `rate_equation` and column names in `data`. +- `param_names::Tuple{Symbol, Vararg{Symbol}}`: Tuple of parameter names that correspond to the parameters of `rate_equation`. - `n_iter::Int`: Number of iterations to run the fitting process. 
# Returns @@ -43,15 +43,15 @@ fit_rate_equation(rate_equation, data, (:A,), (:Vmax, :K_S)) function fit_rate_equation( rate_equation::Function, data::DataFrame, - metab_names::Tuple, - param_names::Tuple; + metab_names::Tuple{Symbol, Vararg{Symbol}}, + param_names::Tuple{Symbol, Vararg{Symbol}}; n_iter = 20, ) train_results = train_rate_equation( rate_equation::Function, data::DataFrame, - metab_names::Tuple, - param_names::Tuple; + metab_names::Tuple{Symbol, Vararg{Symbol}}, + param_names::Tuple{Symbol, Vararg{Symbol}}; n_iter = n_iter, nt_param_removal_code = nothing, ) @@ -63,26 +63,26 @@ end function train_rate_equation( rate_equation::Function, data::DataFrame, - metab_names::Tuple, - param_names::Tuple; + metab_names::Tuple{Symbol, Vararg{Symbol}}, + param_names::Tuple{Symbol, Vararg{Symbol}}; n_iter = 20, nt_param_removal_code = nothing, ) + filtered_data = data[.!isnan.(data.Rate), [:Rate, metab_names..., :source]] + #Only include Rate != 0 because otherwise log_ratio_predict_vs_data() will have to divide by 0 + filter!(row -> row.Rate != 0, filtered_data) # Add a new column to data to assign an integer to each source/figure from publication - data.fig_num = vcat( + filtered_data.fig_num = vcat( [ - i * ones(Int64, count(==(unique(data.source)[i]), data.source)) for - i = 1:length(unique(data.source)) + i * ones(Int64, count(==(unique(filtered_data.source)[i]), filtered_data.source)) for + i = 1:length(unique(filtered_data.source)) ]..., ) - + # Make a vector containing indexes of points corresponding to each figure + fig_point_indexes = + [findall(filtered_data.fig_num .== i) for i in unique(filtered_data.fig_num)] # Convert DF to NamedTuple for better type stability / speed - #TODO: add fig_point_indexes to rate_data_nt to avoid passing it as an argument to loss_rate_equation - rate_data_nt = - Tables.columntable(data[.!isnan.(data.Rate), [:Rate, metab_names..., :fig_num]]) - - # Make a vector containing indexes of points corresponding to each figure - 
fig_point_indexes = [findall(data.fig_num .== i) for i in unique(data.fig_num)] + rate_data_nt = Tables.columntable(filtered_data) # Check if nt_param_removal_code makes loss returns NaN and abort early if it does. The latter # could happens due to nt_param_removal_code making params=Inf @@ -97,7 +97,7 @@ function train_rate_equation( nt_param_removal_code = nt_param_removal_code, ), ) - @warn "Loss returns NaN for this param combo in train_rate_equation() before minimization" + # @warn "Loss returns NaN for this param combo in train_rate_equation() before minimization" return ( train_loss = Inf, params = NamedTuple{param_names}(Tuple(fill(NaN, length(param_names)))), @@ -187,7 +187,7 @@ function loss_rate_equation( params, rate_equation::Function, rate_data_nt::NamedTuple, - param_names, + param_names::Tuple{Symbol, Vararg{Symbol}}, fig_point_indexes::Vector{Vector{Int}}; rescale_params_from_0_10_scale = true, nt_param_removal_code = nothing, diff --git a/test/tests_for_general_rate_eq_derivation.jl b/test/tests_for_general_rate_eq_derivation.jl index 93c6026..9f0fb52 100644 --- a/test/tests_for_general_rate_eq_derivation.jl +++ b/test/tests_for_general_rate_eq_derivation.jl @@ -5,7 +5,6 @@ using DataDrivenEnzymeRateEqs, Test, BenchmarkTools ## -#TODO: test random inputs into @derive_general_mwc_rate_eq and make sure the resulting rate equation, param_names and metab_names always work together substrates = [Symbol(:S, i) for i in 1:rand(1:4)] products = [Symbol(:P, i) for i in 1:rand(1:4)] regulators = [Symbol(:R, i) for i in 1:rand(1:9)] diff --git a/test/tests_for_optimal_rate_eq_selection.jl b/test/tests_for_optimal_rate_eq_selection.jl index 5805b81..8e8f252 100644 --- a/test/tests_for_optimal_rate_eq_selection.jl +++ b/test/tests_for_optimal_rate_eq_selection.jl @@ -145,19 +145,41 @@ end @test all(count_matches) -# # -# #test the ability of `data_driven_rate_equation_selection` to recover the rate_equation and params used to generated data for an arbitrary 
enzyme -data_gen_rate_equation(metabs, params) = params.Vmax * (metabs.S / params.K_S - metabs.P / params.K_P) / (1 + metabs.S / params.K_S + metabs.P / params.K_P) +## +#test the ability of `data_driven_rate_equation_selection` to recover the rate_equation and params used to generate data for an arbitrary enzyme +data_gen_rate_equation_Keq = 1.0 +data_gen_rate_equation(metabs, params) = params.Vmax * (metabs.S / params.K_S - (1 / data_gen_rate_equation_Keq) * metabs.P / params.K_P) / (1 + metabs.S / params.K_S + metabs.P / params.K_P) param_names = (:Vmax, :K_S, :K_P) metab_names = (:S, :P) -params = (Vmax=10.0, K_S=1.0, K_P=5.0) -data = DataFrame(S=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]) -data.P = [0.0 for row in eachrow(data)] +params = (Vmax=10.0, K_S=1e-3, K_P=5e-3) +#create DataFrame of simulated data +num_datapoints = 10 +num_figures = 4 +S_concs = Float64[] +P_concs = Float64[] +sources = String[] + +for i in 1:num_figures + if i < num_figures ÷ 2 + for S in range(0, rand(1:10) * params.K_S, rand(num_datapoints ÷ 2 : num_datapoints * 2)) + push!(S_concs, S) + push!(P_concs, 0.0) + push!(sources, "Figure$i") + end + else + for P in range(0, rand(1:10) * params.K_P, rand(num_datapoints ÷ 2 : num_datapoints * 2)) + push!(S_concs, 0.0) + push!(P_concs, P) + push!(sources, "Figure$i") + end + end +end +data = DataFrame(S=S_concs, P=P_concs, source=sources) noise_sd = 0.2 data.Rate = [data_gen_rate_equation(row, params) * (1 + noise_sd * randn()) for row in eachrow(data)] -data.source = ["Figure1" for i in 1:nrow(data)] +data + fit_result = fit_rate_equation(data_gen_rate_equation, data, metab_names, param_names; n_iter=20) -@test isapprox(fit_result.params.K_S, params.K_S, rtol=3 * noise_sd) enzyme_parameters = (; substrates=[:S,], products=[:P], cat1=[:S, :P], reg1=[], reg2=[], Keq=1.0, oligomeric_state=1, rate_equation_name=:derived_rate_equation) metab_names, param_names = @derive_general_mwc_rate_eq(enzyme_parameters) @@ -165,4 +187,8 @@ 
nt_params = NamedTuple{param_names}(rand(length(param_names))) nt_metabs = NamedTuple{metab_names}(rand(length(metab_names))) derived_rate_equation(nt_metabs, nt_params) = derived_rate_equation(nt_metabs, nt_params, enzyme_parameters.Keq) fit_result = fit_rate_equation(derived_rate_equation, data, metab_names, param_names; n_iter=20) -selection_result = data_driven_rate_equation_selection(derived_rate_equation, data, metab_names, param_names, (3, 7), true) +selection_result = @time data_driven_rate_equation_selection(derived_rate_equation, data, metab_names, param_names, (3, 7), true) + +for n in unique(selection_result.test_results.num_params) + println("for $n param, mean(test_losses) = $(mean(selection_result.test_results[selection_result.test_results.num_params .== n, :test_loss]))") +end diff --git a/test/tests_for_rate_eq_fitting.jl b/test/tests_for_rate_eq_fitting.jl index 572701f..8a6d20c 100644 --- a/test/tests_for_rate_eq_fitting.jl +++ b/test/tests_for_rate_eq_fitting.jl @@ -77,14 +77,36 @@ TODO: delete PKM2 example above after making below to use more complex test_rate randomly generated parameters and data around K values. Add an option to fit real Vmax values instead of fixing Vmax=1.0. 
=# -test_rate_equation(metabs, params) = params.Vmax * (metabs.S / params.K_S) / (1 + metabs.S / params.K_S) - -param_names = (:Vmax, :K_S) -metab_names = (:S,) -params = (Vmax=10.0, K_S=1.0) -data = DataFrame(S=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]) +test_rate_equation_Keq = 1.0 +test_rate_equation(metabs, params) = params.Vmax * (metabs.S / params.K_S - (1 / test_rate_equation_Keq) * metabs.P / params.K_P) / (1 + metabs.S / params.K_S + metabs.P / params.K_P) +param_names = (:Vmax, :K_S, :K_P) +metab_names = (:S, :P) +params = (Vmax=10.0, K_S=1.0, K_P=5.0) +#create DataFrame of simulated data +num_datapoints = 10 +num_figures = 8 +S_concs = Float64[] +P_concs = Float64[] +sources = String[] +for i in 1:num_figures + if i < num_figures ÷ 2 + for S in range(0, 10 * params.K_S, num_datapoints) + push!(S_concs, S) + push!(P_concs, 0.0) + push!(sources, "Figure$i") + end + else + for P in range(0, 10 * params.K_P, num_datapoints) + push!(S_concs, 0.0) + push!(P_concs, P) + push!(sources, "Figure$i") + end + end +end +data = DataFrame(S=S_concs, P=P_concs, source=sources) noise_sd = 0.2 data.Rate = [test_rate_equation(row, params) * (1 + noise_sd * randn()) for row in eachrow(data)] -data.source = ["Figure1" for i in 1:nrow(data)] fit_result = fit_rate_equation(test_rate_equation, data, metab_names, param_names; n_iter=20) + @test isapprox(fit_result.params.K_S, params.K_S, rtol=3 * noise_sd) +@test isapprox(fit_result.params.K_P, params.K_P, rtol=3 * noise_sd)