-
Notifications
You must be signed in to change notification settings - Fork 0
/
rate_equation_fitting.jl
242 lines (227 loc) · 9.49 KB
/
rate_equation_fitting.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
#=
CODE FOR RATE EQUATION FITTING
=#
using CMAEvolutionStrategy, DataFrames, Statistics
# TODO: add optimization_kwargs and use Optimization.jl
# TODO: add an option to set different ranges for L, Vmax, K and alpha
# TODO: add an option to fit real Vmax values instead of fixing Vmax=1.0
"""
    fit_rate_equation(
        rate_equation::Function,
        data::DataFrame,
        metab_names::Tuple,
        param_names::Tuple;
        n_iter = 20
    )

Fit `rate_equation` to `data` and return loss and best fit parameters.

# Arguments
- `rate_equation::Function`: Function that takes a NamedTuple of metabolite concentrations
  (with `metab_names` keys) and a NamedTuple of parameters (with `param_names` keys) and
  returns an enzyme rate.
- `data::DataFrame`: DataFrame with a `Rate` column and one column per name in
  `metab_names`; each row is one measurement. It also needs a `source` column containing a
  string that identifies where the data came from (e.g. the figure in a publication); it is
  used to compute a separate scale weight for each source. Note that fitting adds a
  `fig_num` column to `data`.
- `metab_names::Tuple`: Metabolite names, matching both the metabolites used by
  `rate_equation` and the column names in `data`.
- `param_names::Tuple`: Parameter names used by `rate_equation`.
- `n_iter::Int`: Number of independent optimizer runs started from random points.

# Returns
- `loss::Float64`: Loss of the best fit (`Inf` if fitting failed).
- `params::NamedTuple`: Best fit parameters with `param_names` keys (all `NaN` if fitting
  failed).

# Example
```julia
using DataFrames
data = DataFrame(
    Rate = [1.0, 2.0, 3.0],
    A = [1.0, 2.0, 3.0],
    source = ["Figure 1", "Figure 1", "Figure 2"]
)
rate_equation(metabs, params) =
    params.Vmax * (metabs.A / params.K_A) / (1 + metabs.A / params.K_A)
fit_rate_equation(rate_equation, data, (:A,), (:Vmax, :K_A))
```
"""
function fit_rate_equation(
    rate_equation::Function,
    data::DataFrame,
    metab_names::Tuple,
    param_names::Tuple;
    n_iter = 20
)
    train_results = train_rate_equation(
        rate_equation,
        data,
        metab_names,
        param_names;
        n_iter = n_iter,
        nt_param_removal_code = nothing
    )
    # Destructure positionally: a failed fit may be reported as a plain
    # `(loss, params_vector)` Tuple rather than a NamedTuple, and `.loss` would throw on it.
    loss, params = train_results
    # Normalise `params` to a NamedTuple with `param_names` keys in both cases.
    return (loss = loss, params = NamedTuple{param_names}(Tuple(params)))
end
"""
    train_rate_equation(
        rate_equation::Function,
        data::DataFrame,
        metab_names::Tuple,
        param_names::Tuple;
        n_iter = 20,
        nt_param_removal_code = nothing
    )

Fit `rate_equation` to `data` with CMA-ES and return `(loss = ..., params = ...)`.

Parameters are optimized on a [0, 10] box and mapped to kinetic values with
`param_rescaling`. `n_iter` runs start from uniformly random points; the run with the
lowest finite loss is then refined with a smaller initial step size and a tighter `ftol`.
If `nt_param_removal_code` is not `nothing`, it is forwarded to `loss_rate_equation` and
applied via `param_subset_select` to the returned parameters.

Side effect: adds (or overwrites) an integer column `fig_num` on `data` that numbers each
distinct `data.source` in order of first appearance.

# Returns
A NamedTuple `(loss, params)` where `params` is a NamedTuple with `param_names` keys.
When no rows have a non-NaN `Rate`, when the loss is NaN at the mid-range probe point, or
when every run ends with a non-finite loss, returns `loss = Inf` and all-`NaN` `params`.
"""
function train_rate_equation(
    rate_equation::Function,
    data::DataFrame,
    metab_names::Tuple,
    param_names::Tuple;
    n_iter = 20,
    nt_param_removal_code = nothing
)
    n_params = length(param_names)
    failed_fit = (loss = Inf, params = NamedTuple{param_names}(fill(NaN, n_params)))

    # Assign an integer to each source/figure (first-appearance order). A lookup table is
    # used so that rows of the same source do not need to be contiguous in `data`.
    source_ids = Dict(s => i for (i, s) in enumerate(unique(data.source)))
    data.fig_num = [source_ids[s] for s in data.source]

    # Keep only rows with a measured rate, then renumber figures densely over the kept rows
    # so that figure ids run 1:n_figs and line up with `fig_point_indexes` below.
    fit_data = data[.!isnan.(data.Rate), [:Rate, metab_names..., :fig_num]]
    if isempty(fit_data.Rate)
        println("No non-NaN Rate values to fit for this param combo")
        return failed_fit
    end
    kept_figs = unique(fit_data.fig_num)
    dense_ids = Dict(f => i for (i, f) in enumerate(kept_figs))
    fit_data.fig_num = [dense_ids[f] for f in fit_data.fig_num]

    # Convert DF to NamedTuple for better type stability / speed
    rate_data_nt = Tables.columntable(fit_data)
    # Indexes (into the filtered rows, i.e. into `rate_data_nt`) of each figure's points
    fig_point_indexes = [findall(==(i), rate_data_nt.fig_num) for i in eachindex(kept_figs)]

    # Loss as a function of the optimizer's [0, 10]-scaled parameter vector
    objective(x) = loss_rate_equation(
        x,
        rate_equation,
        rate_data_nt,
        param_names,
        fig_point_indexes;
        rescale_params_from_0_10_scale = true,
        nt_param_removal_code = nt_param_removal_code
    )

    # Abort early if the loss is NaN at a mid-range point; this can happen when
    # nt_param_removal_code makes some params Inf.
    if isnan(objective(fill(5.0, n_params)))
        println("Loss returns NaN for this param combo")
        return failed_fit
    end

    solns = []
    for _ in 1:n_iter
        sol = _minimize_bounded(objective, 10 .* rand(n_params), 0.01, 1e-10)
        isnothing(sol) || push!(solns, sol)
    end
    # Discard runs whose best loss is NaN or ±Inf
    filter!(sol -> isfinite(fbest(sol)), solns)
    if isempty(solns)
        println("All of the iterations of fits for this param combo return NaN or Inf")
        return failed_fit
    end

    best_start = solns[argmin([fbest(sol) for sol in solns])]
    # Refine the best run; fall back to it unrefined if CMA-ES hits the rare ArgumentError.
    refined = _minimize_bounded(objective, xbest(best_start), 0.001, 1e-14)
    best_sol = isnothing(refined) ? best_start : refined

    rescaled_params = param_rescaling(xbest(best_sol), param_names)
    if !isnothing(nt_param_removal_code)
        rescaled_params = param_subset_select(rescaled_params, param_names, nt_param_removal_code)
    end
    return (loss = fbest(best_sol), params = NamedTuple{param_names}(rescaled_params))
end

"Run CMA-ES on `objective` inside the [0, 10] box; return `nothing` on the rare CMA-ES ArgumentError."
function _minimize_bounded(objective, x0, sigma0, ftol)
    n = length(x0)
    try
        return minimize(
            objective,
            x0,
            sigma0,
            lower = zeros(n),
            upper = fill(10.0, n),
            popsize = 4 * (4 + floor(Int, 3 * log(n))),
            maxiter = 50_000,
            verbosity = 0,
            ftol = ftol
        )
    catch err
        # Bypass rare errors (~1 in 10,000 runs) where minimize() fails with
        # "ArgumentError: matrix contains Infs or NaNs"; any other error is propagated.
        err isa ArgumentError || rethrow()
        println(err)
        return nothing
    end
end
"""
    loss_rate_equation(
        params,
        rate_equation::Function,
        rate_data_nt::NamedTuple,
        param_names,
        fig_point_indexes::Vector{Vector{Int}};
        rescale_params_from_0_10_scale = true,
        nt_param_removal_code = nothing
    )

Loss used for fitting: squared log-ratios of rate-equation predictions vs measured rates,
after removing a per-figure scale factor, averaged over all data points.

For each figure (a group of row indexes in `fig_point_indexes`), the optimal log Vmax weight
has the closed form `-mean(log(pred / data))` (ratio of geometric means of data vs
prediction). The figure therefore contributes the sum of squared deviations of its
log-ratios from their mean. The total is divided by `length(rate_data_nt.Rate)`.

# Arguments
- `params`: parameter values in `param_names` order; on the optimizer's [0, 10] scale when
  `rescale_params_from_0_10_scale = true` (rescaled with `param_rescaling`), otherwise used
  as-is. `params` itself is never mutated.
- `rate_equation`: evaluated per row through `log_ratio_predict_vs_data`.
- `rate_data_nt`: column table with at least a `Rate` column plus the metabolite columns.
- `param_names`: parameter names used as keys for `rate_equation`'s parameter NamedTuple.
- `fig_point_indexes`: row indexes of each figure's points; empty groups are skipped.
- `nt_param_removal_code`: if not `nothing`, passed to `param_subset_select` to adjust the
  parameters before evaluation.
"""
function loss_rate_equation(
    params,
    rate_equation::Function,
    rate_data_nt::NamedTuple,
    param_names,
    fig_point_indexes::Vector{Vector{Int}};
    rescale_params_from_0_10_scale = true,
    nt_param_removal_code = nothing
)
    # Previously `kinetic_params` was undefined when rescaling was off. `collect` makes a
    # copy so the in-place update below never mutates the caller's (optimizer's) vector.
    kinetic_params = if rescale_params_from_0_10_scale
        param_rescaling(params, param_names)
    else
        collect(params)
    end
    if !isnothing(nt_param_removal_code)
        kinetic_params .= param_subset_select(kinetic_params, param_names, nt_param_removal_code)
    end
    # NamedTuple with field names from param_names for better type stability
    kinetic_params_nt = NamedTuple{param_names}(kinetic_params)
    # Precompute all log ratios once (expensive); they feed both the weights and the loss.
    log_ratios = log_ratio_predict_vs_data(rate_equation, rate_data_nt, kinetic_params_nt)
    loss = zero(eltype(kinetic_params))
    for point_indexes in fig_point_indexes
        # A figure with no points contributes nothing (its mean would be NaN).
        isempty(point_indexes) && continue
        fig_log_ratios = @view log_ratios[point_indexes]
        # Analytical per-figure Vmax weight: ratio of geometric means of data vs prediction
        log_weight = -mean(fig_log_ratios)
        loss += sum(r -> abs2(log_weight + r), fig_log_ratios)
    end
    return loss / length(rate_data_nt.Rate)
end
"""
    log_ratio_predict_vs_data(rate_equation, rate_data_nt, kinetic_params_nt)

Return a `Vector{Float64}` holding `log(rate_equation(row, kinetic_params_nt) / row.Rate)`
for every row of the column table `rate_data_nt`, in row order. Each `row` is a NamedTuple
with all columns of `rate_data_nt`.
"""
function log_ratio_predict_vs_data(
    rate_equation::Function,
    rate_data_nt::NamedTuple,
    kinetic_params_nt::NamedTuple;
)
    measurement_rows = Tables.namedtupleiterator(rate_data_nt)
    # Typed comprehension: results are converted to Float64 exactly like storing into a
    # preallocated Float64 vector would.
    return Float64[
        log(rate_equation(row, kinetic_params_nt) / row.Rate) for row in measurement_rows
    ]
end
# TODO: need to add an option to set different ranges for L, Vmax, K and alpha
"""
    param_rescaling(p, param_names)

Map fitting parameters from the optimizer's [0, 10] scale to actual values; the result has
the same container/eltype as `p` (via `similar`).

The mapping is selected by each parameter's name:
- `:L`          → 10^(-5) .. 10^5 (log-uniform)
- `Vmax*`       → 10^(-3) .. 1 (log-uniform)
- `K_*`         → 10^(-10) .. 10^3 (log-uniform)
- `alpha_*`     → 1.0 if the scaled value is >= 5.0, else 0.0
Any other name raises an error.
"""
function param_rescaling(p, param_names)
    @assert length(p) == length(param_names)
    rescaled = similar(p)
    for (i, pname) in zip(eachindex(p), param_names)
        label = string(pname)
        x = p[i]
        rescaled[i] = if pname == :L
            10^(-5) * 10^(10 * x / 10)
        elseif startswith(label, "Vmax")
            10^(-3) * 10^(3 * x / 10)
        elseif startswith(label, "K_")
            10^(-10) * 10^(13 * x / 10)
        elseif startswith(label, "alpha_")
            x >= 5.0 ? 1.0 : 0.0
        else
            error("Cannot rescale unknown parameter name $(label) using `param_rescaling()`")
        end
    end
    return rescaled
end