-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_driven_rate_equation_selection.jl
501 lines (471 loc) · 22.7 KB
/
data_driven_rate_equation_selection.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
using Dates, CSV, DataFrames, Distributed
"""
data_driven_rate_equation_selection(
general_rate_equation::Function,
data::DataFrame,
metab_names::Tuple{Symbol,Vararg{Symbol}},
param_names::Tuple{Symbol,Vararg{Symbol}},
range_number_params::Tuple{Int,Int},
forward_model_selection::Bool;
save_train_results::Bool = false,
enzyme_name::String = "Enzyme",
)
This function is used to perform data-driven rate equation selection using a general rate equation and data. The function selects the best rate equation by iteratively removing parameters from the general rate equation and finding the equation that yields the best test scores on data not used for fitting.
# Arguments
- `general_rate_equation::Function`: Function that takes a NamedTuple of metabolite concentrations (with `metab_names` keys) and parameters (with `param_names` keys) and returns an enzyme rate.
- `data::DataFrame`: DataFrame containing the data with column `Rate` and columns for each `metab_names` where each row is one measurement. It also needs to have a column `source` that contains a string that identifies the source of the data. This is used to calculate the weights for each figure in the publication.
- `metab_names::Tuple`: Tuple of metabolite names that correspond to the metabolites of `general_rate_equation` and column names in `data`.
- `param_names::Tuple`: Tuple of parameter names that correspond to the parameters of `general_rate_equation`.
- `range_number_params::Tuple{Int,Int}`: A tuple of integers representing the range of the number of parameters of general_rate_equation to search over.
- `forward_model_selection::Bool`: A boolean indicating whether to use forward model selection (true) or reverse model selection (false).
# Keyword Arguments
- `save_train_results::Bool`: A boolean indicating whether to save the results of the training for each number of parameters as a csv file.
- `enzyme_name::String`: A string for enzyme name that is used to name the csv files that are saved.
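# Returns
A NamedTuple `(train_results, test_results, practically_unidentifiable_params)`. `train_results` is a DataFrame with the training result for each combination of removed parameters that was tested, `test_results` is a DataFrame with the leave-one-figure-out test results for the equation with the lowest training loss at each number of parameters tested, and `practically_unidentifiable_params` is the Tuple of parameters that cannot be identified from `data`. If `save_train_results` is true, a csv file with the training results is also saved for each `num_params`. If no equation with the starting number of parameters exists, an error is logged and `nothing` is returned.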
"""
function data_driven_rate_equation_selection(
    general_rate_equation::Function,
    data::DataFrame,
    metab_names::Tuple{Symbol,Vararg{Symbol}},
    param_names::Tuple{Symbol,Vararg{Symbol}},
    range_number_params::Tuple{Int,Int},
    forward_model_selection::Bool;
    save_train_results::Bool = false,
    enzyme_name::String = "Enzyme",
)
    # Overview:
    #   1. build the names of the parameter removal codes and validate `range_number_params`;
    #   2. enumerate the removal codes for the first number of parameters in the range;
    #   3. for each number of parameters: train every candidate code (in parallel with
    #      `pmap`), keep the candidates whose train loss is close to the best one as the
    #      starting point for the next number of parameters, and run leave-one-figure-out
    #      cross-validation for the single best candidate.
    # Returns a NamedTuple (train_results, test_results, practically_unidentifiable_params),
    # or `nothing` (after logging an error) when no starting codes exist.

    #generate param_removal_code_names by converting each mirror parameter for a and i into one name
    #(e.g. K_a_Metabolite1 and K_i_Metabolite1 into K_allo_Metabolite1)
    param_removal_code_names = (
        [
            Symbol(replace(string(param_name), "_a_" => "_allo_", "Vmax_a" => "Vmax_allo")) for
            param_name in param_names if
            !contains(string(param_name), "_i") && param_name != :Vmax
        ]...,
    )
    num_alpha_params = count(name -> occursin("alpha", string(name)), param_names)

    #check that range_number_params within bounds of minimal and maximal number of parameters
    @assert range_number_params[1] >= length(param_names) - length(param_removal_code_names) "starting range_number_params cannot be below $(length(param_names) - length(param_removal_code_names))"
    @assert range_number_params[2] <= length(param_names) "ending range_number_params cannot be above $(length(param_names))"

    # Forward selection walks from the largest to the smallest number of parameters,
    # reverse selection walks from the smallest to the largest.
    num_param_range = if forward_model_selection
        (range_number_params[2]):-1:range_number_params[1]
    else
        (range_number_params[1]):1:range_number_params[2]
    end

    #calculate starting_param_removal_codes parameters
    practically_unidentifiable_params =
        find_practically_unidentifiable_params(data, param_names)
    all_param_removal_codes = calculate_all_parameter_removal_codes(
        param_names,
        practically_unidentifiable_params,
    )
    starting_param_removal_codes = calculate_all_parameter_removal_codes_w_num_params(
        num_param_range[1],
        all_param_removal_codes,
        param_names,
        param_removal_code_names,
        metab_names,
        num_alpha_params,
    )
    if isempty(starting_param_removal_codes)
        @error "Equations for this enzymes with $(num_param_range[1]) parameters do not exist. One reason could be that some parameters are unidentifiable based on the data so upper bound of range_number_params should be reduced. Check number of practically_unidentifiable_params with find_practically_unidentifiable_params(data, param_names)."
        return
    end

    # Re-key the codes stored in a results column as NamedTuples with
    # `param_removal_code_names` keys.
    to_named_codes(codes) =
        [NamedTuple{param_removal_code_names}(x) for x in values.(codes)]

    nt_param_removal_codes = starting_param_removal_codes
    # Must be bound outside the loop so that the value assigned in one iteration is
    # visible in the next; it is only read from the second iteration onwards.
    nt_previous_param_removal_codes = nt_param_removal_codes
    println("About to start loop with num_params: $num_param_range")
    df_train_results = DataFrame()
    df_test_results = DataFrame()
    for num_params in num_param_range
        println("Running loop with num_params: $num_params")

        #calculate param_removal_codes for `num_params` given `all_param_removal_codes` and fixed params from previous `num_params`
        if num_params != num_param_range[1]
            nt_param_removal_codes = if forward_model_selection
                forward_selection_next_param_removal_codes(
                    nt_previous_param_removal_codes,
                    num_alpha_params,
                )
            else
                reverse_selection_next_param_removal_codes(
                    nt_previous_param_removal_codes,
                    num_alpha_params,
                )
            end
        end

        #pmap over nt_param_removal_codes for a given `num_params` return rescaled and nt_param_subset added
        results_array = pmap(
            nt_param_removal_code -> train_rate_equation(
                general_rate_equation,
                data,
                metab_names,
                param_names;
                n_iter = 20,
                nt_param_removal_code = nt_param_removal_code,
            ),
            nt_param_removal_codes,
        )

        #convert results_array to DataFrame
        df_results = DataFrame(results_array)
        df_results.num_params = fill(num_params, nrow(df_results))
        df_results.nt_param_removal_codes = nt_param_removal_codes
        df_train_results = vcat(df_train_results, df_results)

        # Optionally save results to a csv file (useful for long running calculations on a cluster)
        if save_train_results
            CSV.write(
                "$(Dates.format(now(),"mmddyy"))_$(enzyme_name)_$(forward_model_selection ? "forward" : "reverse")_model_select_results_$(num_params)_num_params.csv",
                df_results,
            )
        end

        #if all train_loss are Inf, then skip to next loop
        if all(==(Inf), df_results.train_loss)
            nt_previous_param_removal_codes =
                to_named_codes(df_results.nt_param_removal_codes)
            continue
        end

        # Keep only the candidates whose train loss is below 1.1x the best train loss and
        # store them as `previous_param_removal_codes` for the next loop. The cutoff is
        # computed once instead of once per row.
        train_loss_cutoff = 1.1 * minimum(df_results.train_loss)
        filter!(row -> row.train_loss < train_loss_cutoff, df_results)
        nt_previous_param_removal_codes = to_named_codes(df_results.nt_param_removal_codes)

        #calculate loocv test loss for the best candidate for each `num_params`
        best_nt_param_removal_code =
            df_results.nt_param_removal_codes[argmin(df_results.train_loss)]
        test_results = pmap(
            removed_fig -> loocv_rate_equation(
                removed_fig,
                general_rate_equation,
                data,
                metab_names,
                param_names;
                n_iter = 20,
                nt_param_removal_code = best_nt_param_removal_code,
            ),
            unique(data.source),
        )
        df_results = DataFrame(test_results)
        df_results.num_params = fill(num_params, nrow(df_results))
        df_results.nt_param_removal_codes =
            fill(best_nt_param_removal_code, nrow(df_results))
        df_test_results = vcat(df_test_results, df_results)
    end
    return (
        train_results = df_train_results,
        test_results = df_test_results,
        practically_unidentifiable_params = practically_unidentifiable_params,
    )
end
"function to calculate train loss without a figure and test loss on removed figure"
function loocv_rate_equation(
    fig,
    rate_equation::Function,
    data::DataFrame,
    metab_names::Tuple{Symbol,Vararg{Symbol}},
    param_names::Tuple{Symbol,Vararg{Symbol}};
    n_iter = 20,
    nt_param_removal_code = nothing,
)
    # Leave-one-figure-out step: fit `rate_equation` on every row of `data` whose
    # `source` differs from `fig`, then score the fitted parameters on the rows whose
    # `source` equals `fig`.
    # Returns a NamedTuple (dropped_fig, train_loss_wo_fig, test_loss_leftout_fig, params).
    is_held_out = data.source .== fig
    held_out_data = data[is_held_out, :]
    remaining_data = data[.!is_held_out, :]

    # Fit on everything except the held-out figure
    fit = train_rate_equation(
        rate_equation,
        remaining_data,
        metab_names,
        param_names;
        n_iter = n_iter,
        nt_param_removal_code = nt_param_removal_code,
    )

    # Score the fitted parameters on the held-out figure
    held_out_loss =
        test_rate_equation(rate_equation, held_out_data, fit.params, metab_names, param_names)

    return (
        dropped_fig = fig,
        train_loss_wo_fig = fit.train_loss,
        test_loss_leftout_fig = held_out_loss,
        params = fit.params,
    )
end
"""Function to calculate loss for a given `rate_equation` and `nt_fitted_params` on `data` that was not used for training"""
function test_rate_equation(
    rate_equation::Function,
    data::DataFrame,
    nt_fitted_params::NamedTuple,
    metab_names::Tuple{Symbol,Vararg{Symbol}},
    param_names::Tuple{Symbol,Vararg{Symbol}},
)
    # Evaluate `loss_rate_equation` for already fitted parameters (`nt_fitted_params`)
    # on `data`. Only the `Rate`, `metab_names` and `source` columns of `data` are used.
    # The parameters are passed as-is: no rescaling from the 0-10 scale and no
    # parameter removal code are applied.
    filtered_data = data[.!isnan.(data.Rate), [:Rate, metab_names..., :source]]
    # Drop measurements with Rate == 0 (non-zero rates of either sign are kept).
    # Original author note: log_ratio_predict_vs_data() would otherwise have to divide by 0.
    filter!(row -> row.Rate != 0, filtered_data)

    # Assign an integer to each source/figure from the publication, numbered in order of
    # first appearance. The number is looked up per row, so rows that belong to the same
    # source get the same number even when they are not stored contiguously.
    sources = unique(filtered_data.source)
    fig_num_of_source = Dict(source => i for (i, source) in enumerate(sources))
    filtered_data.fig_num =
        Int64[fig_num_of_source[source] for source in filtered_data.source]

    # Indexes of the points that belong to each figure
    fig_point_indexes =
        Vector{Int64}[findall(==(i), filtered_data.fig_num) for i in eachindex(sources)]

    # Convert DF to NamedTuple for better type stability / speed
    rate_data_nt = Tables.columntable(filtered_data)
    fitted_params = values(nt_fitted_params)
    test_loss = loss_rate_equation(
        fitted_params,
        rate_equation,
        rate_data_nt::NamedTuple,
        param_names,
        fig_point_indexes;
        rescale_params_from_0_10_scale = false,
        nt_param_removal_code = nothing,
    )
    return test_loss
end
"""Generate all possibles codes for ways that params can be removed from the rate equation"""
function calculate_all_parameter_removal_codes(
    param_names::Tuple{Symbol,Vararg{Symbol}},
    practically_unidentifiable_params::Tuple{Vararg{Symbol}},
)
    # For every parameter that has a removal code, collect the list of values the code
    # may take, then return the lazy Cartesian product of those lists. Parameters that
    # match none of the rules below (e.g. `Vmax` or `K_i_...`) contribute no code.
    code_options = Vector{Int}[]
    for param_name in param_names
        name = string(param_name)
        is_unidentifiable = param_name in practically_unidentifiable_params
        if param_name == :L
            push!(code_options, [0, 1])
        elseif startswith(name, "Vmax_a")
            push!(code_options, [0, 1, 2])
        elseif startswith(name, "K_a")
            push!(code_options, [0, 1, 2, 3])
        elseif startswith(name, "K_") && !startswith(name, "K_i")
            # "K_<metab>" has two name parts, "K_<metab1>_<metab2>..." has more
            num_name_parts = length(split(name, "_"))
            if num_name_parts == 2
                push!(code_options, [0, 1])
            elseif num_name_parts > 2
                push!(code_options, is_unidentifiable ? [1] : [0, 1, 2])
            end
        elseif startswith(name, "alpha")
            push!(code_options, is_unidentifiable ? [0] : [0, 1])
        end
    end
    return Iterators.product(code_options...)
end
"""Find parameters that cannot be identified based on data and they are in front of products of metabolites concentrations that are always zero as these combinations of metabolites are absent in the data."""
function find_practically_unidentifiable_params(
    data::DataFrame,
    param_names::Tuple{Symbol,Vararg{Symbol}},
)
    # A parameter is reported when the product of the metabolite columns named in it is
    # zero in every row of `data`. Only two kinds of parameters are checked:
    #   - "K_..." parameters (not "K_i..."/"K_a...") with more than three "_"-separated
    #     name parts: all name parts after "K" are used as column names;
    #   - "alpha_..." parameters: the second and third name parts are used as column names.
    unidentifiable = Symbol[]
    for param_name in param_names
        name = string(param_name)
        name_parts = split(name, "_")
        is_multi_metab_K =
            startswith(name, "K_") &&
            !(startswith(name, "K_i") || startswith(name, "K_a")) &&
            length(name_parts) > 3
        if is_multi_metab_K
            metab_columns = Symbol.(name_parts[2:end])
        elseif startswith(name, "alpha_")
            metab_columns = Symbol.(name_parts[2:3])
        else
            continue
        end
        if all(row -> prod(row) == 0, eachrow(data[:, metab_columns]))
            push!(unidentifiable, param_name)
        end
    end
    return Tuple(unidentifiable)
end
"""Generate NamedTuple of codes for ways that params can be removed from the rate equation but still leave `num_params`"""
function calculate_all_parameter_removal_codes_w_num_params(
    num_params::Int,
    all_param_removal_codes,
    param_names::Tuple{Symbol,Vararg{Symbol}},
    param_removal_code_names::Tuple{Symbol,Vararg{Symbol}},
    metab_names::Tuple{Symbol,Vararg{Symbol}},
    num_alpha_params::Int,
)
    # Select from `all_param_removal_codes` the codes that leave exactly `num_params`
    # parameters and return them as NamedTuples keyed by `param_removal_code_names`.
    # The number of parameters left by a code is
    #     length(param_names) - num_alpha_params - (number of non-zero code elements),
    # where the last `num_alpha_params` elements of the code are not counted.
    # Returns an empty vector when no code matches, so that callers can check `isempty`.
    max_num_params = length(param_names) - num_alpha_params
    codes_with_num_params = Tuple[]
    for code in all_param_removal_codes
        num_non_zero = count(i -> code[i] > 0, 1:(length(code)-num_alpha_params))
        if max_num_params - num_non_zero == num_params
            push!(codes_with_num_params, code)
        end
    end
    nt_param_removal_codes =
        [NamedTuple{param_removal_code_names}(x) for x in unique(codes_with_num_params)]

    # Nothing matched: return early instead of inspecting the (missing) first element
    isempty(nt_param_removal_codes) && return NamedTuple[]

    # Codes that contain "allo" entries are returned without further filtering
    if any(occursin("allo", string(name)) for name in param_removal_code_names)
        return nt_param_removal_codes
    end

    # ensure that if K_S1 = Inf then all K_S1_S2 and all other K containing S1 in qssa cannot be 2, which stands for (K_S1_S2)^2 = K_S1 * K_S2
    filtered_nt_param_removal_codes = NamedTuple[]
    for nt_param_removal_code in nt_param_removal_codes
        # metabolites whose single-metabolite constant K_<metab> has code 1
        removed_metabs = [
            metab for
            metab in metab_names if nt_param_removal_code[Symbol("K_" * string(metab))] == 1
        ]
        has_conflicting_code = any(
            nt_param_removal_code[param_name] == 2 for
            param_name in keys(nt_param_removal_code) if
            any(occursin(string(metab), string(param_name)) for metab in removed_metabs)
        )
        if !has_conflicting_code
            push!(filtered_nt_param_removal_codes, nt_param_removal_code)
        end
    end
    return filtered_nt_param_removal_codes
end
"""
Function to convert parameter vector to vector where some params are equal to 0, Inf or each other based on nt_param_removal_code
"""
function param_subset_select(
    params,
    param_names::Tuple{Symbol,Vararg{Symbol}},
    nt_param_removal_code::T where {T<:NamedTuple},
)
    # Apply `nt_param_removal_code` to `params` (ordered like `param_names`) and return
    # a new vector in the same order. Meaning of the codes, per key prefix:
    #   L...          1: L = 0
    #   Vmax_allo...  1: Vmax_i = Vmax_a            2: Vmax_i = 0
    #   K_allo_<m>    1: K_i_<m> = K_a_<m>          2: K_a_<m> = Inf    3: K_i_<m> = Inf
    #   K_<m>...      1: parameter = Inf
    #   K_<m1>_<m2>.. 2: parameter = geometric mean of K_<m1>, K_<m2>, ...
    #   alpha...      0: parameter = 0              1: parameter = 1
    # Any other code value leaves the parameter unchanged.
    @assert length(params) == length(param_names)
    params_dict =
        Dict(param_name => params[i] for (i, param_name) in enumerate(param_names))
    # Keys are processed in order: a geometric mean uses the values that are in
    # `params_dict` at that point, including any set by earlier keys.
    for (param_choice, code) in pairs(nt_param_removal_code)
        name = string(param_choice)
        if startswith(name, "L")
            if code == 1
                params_dict[:L] = 0.0
            end
        elseif startswith(name, "Vmax_allo")
            if code == 1
                params_dict[:Vmax_i] = params_dict[:Vmax_a]
            elseif code == 2
                params_dict[:Vmax_i] = 0.0
            end
        elseif startswith(name, "K_allo")
            if code in (1, 2, 3)
                # strip the 7-character "K_allo_" prefix to get the metabolite name
                metab = name[8:end]
                K_a = Symbol("K_a_" * metab)
                K_i = Symbol("K_i_" * metab)
                if code == 1
                    params_dict[K_i] = params_dict[K_a]
                elseif code == 2
                    params_dict[K_a] = Inf
                else
                    params_dict[K_i] = Inf
                end
            end
        elseif startswith(name, "K_")
            metabs = split(name, "_")[2:end]
            if code == 1
                params_dict[param_choice] = Inf
            elseif code == 2 && length(metabs) > 1
                params_dict[param_choice] =
                    prod([params_dict[Symbol("K_" * string(metab))] for metab in metabs])^(1 / length(metabs))
            end
        elseif startswith(name, "alpha")
            if code == 0
                params_dict[param_choice] = 0.0
            elseif code == 1
                params_dict[param_choice] = 1.0
            end
        end
    end
    new_params_sorted = [params_dict[param_name] for param_name in param_names]
    return new_params_sorted
end
"""
Use `nt_previous_param_removal_codes` to calculate `nt_param_removal_codes` in which exactly one zero element of a previous code (excluding the last `num_alpha_params` elements, i.e. the alpha terms) is replaced by each of its feasible non-zero values
"""
function forward_selection_next_param_removal_codes(
    nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple},
    num_alpha_params::Int,
)
    # For every previous code and every zero element in it (ignoring the last
    # `num_alpha_params` elements), emit one new code per non-zero value that the
    # element may take. Duplicate codes are removed and the result is returned as
    # NamedTuples with the same keys as the first previous code.
    code_names = keys(nt_previous_param_removal_codes[1])
    candidate_codes = Vector{Vector{Int}}()
    for previous_code in nt_previous_param_removal_codes
        last_non_alpha_index = length(previous_code) - num_alpha_params
        for (idx, current_value) in enumerate(previous_code)
            (idx <= last_non_alpha_index && current_value == 0) || continue
            code_name = code_names[idx]
            name = string(code_name)
            # non-zero values allowed for this kind of parameter
            if code_name == :L
                replacement_values = [1]
            elseif startswith(name, "Vmax_allo")
                replacement_values = [1, 2]
            elseif startswith(name, "K_allo")
                replacement_values = [1, 2, 3]
            elseif startswith(name, "K_")
                # "K_<metab>" has two name parts, "K_<metab1>_<metab2>..." has more
                replacement_values = length(split(name, "_")) == 2 ? [1] : [1, 2]
            end
            for new_value in replacement_values
                candidate = collect(Int, previous_code)
                candidate[idx] = new_value
                push!(candidate_codes, candidate)
            end
        end
    end
    return [NamedTuple{code_names}(code) for code in unique(candidate_codes)]
end
"""
Use `nt_previous_param_removal_codes` to calculate `nt_next_param_removal_codes` that have one additional zero element, except for the last `num_alpha_params` elements, which are never changed
"""
function reverse_selection_next_param_removal_codes(
    nt_previous_param_removal_codes::Vector{T} where {T<:NamedTuple},
    num_alpha_params::Int,
)
    # For every previous code and every non-zero element in it (ignoring the last
    # `num_alpha_params` elements), emit a copy of the code with that element set to 0.
    # Duplicate codes are removed and the result is returned as NamedTuples with the
    # same keys as the first previous code.
    code_names = keys(nt_previous_param_removal_codes[1])
    candidate_codes = Vector{Vector{Int}}()
    for previous_code in nt_previous_param_removal_codes
        code_values = collect(Int, previous_code)
        last_non_alpha_index = length(code_values) - num_alpha_params
        for idx in eachindex(code_values)
            if idx <= last_non_alpha_index && code_values[idx] != 0
                candidate = copy(code_values)
                candidate[idx] = 0
                push!(candidate_codes, candidate)
            end
        end
    end
    return [NamedTuple{code_names}(code) for code in unique(candidate_codes)]
end