Skip to content

Commit

Permalink
Merge pull request #24 from DenisTitovLab/decrease-allocation-of-loss…
Browse files Browse the repository at this point in the history
…_rate_equation()

decrease allocation in loss_rate_equation() from ~50KB to ~20KB
  • Loading branch information
Denis-Titov authored Jun 13, 2024
2 parents c03ac00 + aaf9619 commit b0881b8
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions src/rate_equation_fitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ fit_rate_equation(rate_equation, data, (:A,), (:Vmax, :K_S))
function fit_rate_equation(
rate_equation::Function,
data::DataFrame,
metab_names::Tuple{Symbol, Vararg{Symbol}},
param_names::Tuple{Symbol, Vararg{Symbol}};
metab_names::Tuple{Symbol,Vararg{Symbol}},
param_names::Tuple{Symbol,Vararg{Symbol}};
n_iter = 20,
)
train_results = train_rate_equation(
rate_equation::Function,
data::DataFrame,
metab_names::Tuple{Symbol, Vararg{Symbol}},
param_names::Tuple{Symbol, Vararg{Symbol}};
metab_names::Tuple{Symbol,Vararg{Symbol}},
param_names::Tuple{Symbol,Vararg{Symbol}};
n_iter = n_iter,
nt_param_removal_code = nothing,
)
Expand All @@ -63,8 +63,8 @@ end
function train_rate_equation(
rate_equation::Function,
data::DataFrame,
metab_names::Tuple{Symbol, Vararg{Symbol}},
param_names::Tuple{Symbol, Vararg{Symbol}};
metab_names::Tuple{Symbol,Vararg{Symbol}},
param_names::Tuple{Symbol,Vararg{Symbol}};
n_iter = 20,
nt_param_removal_code = nothing,
)
Expand All @@ -74,8 +74,10 @@ function train_rate_equation(
# Add a new column to data to assign an integer to each source/figure from publication
filtered_data.fig_num = vcat(
[
i * ones(Int64, count(==(unique(filtered_data.source)[i]), filtered_data.source)) for
i = 1:length(unique(filtered_data.source))
i * ones(
Int64,
count(==(unique(filtered_data.source)[i]), filtered_data.source),
) for i = 1:length(unique(filtered_data.source))
]...,
)
# Add a column containing indexes of points corresponding to each figure
Expand Down Expand Up @@ -187,7 +189,7 @@ function loss_rate_equation(
params,
rate_equation::Function,
rate_data_nt::NamedTuple,
param_names::Tuple{Symbol, Vararg{Symbol}},
param_names::Tuple{Symbol,Vararg{Symbol}},
fig_point_indexes::Vector{Vector{Int}};
rescale_params_from_0_10_scale = true,
nt_param_removal_code = nothing,
Expand All @@ -204,16 +206,24 @@ function loss_rate_equation(
end

#precalculate log_pred_vs_data_ratios for all points as it is expensive and reuse it for weights and loss
#convert kinetic_params to NamedTuple with field names from param_names for better type stability
kinetic_params_nt = NamedTuple{param_names}(kinetic_params)
kinetic_params_nt =
NamedTuple{param_names}(ntuple(i -> kinetic_params[i], Val(length(kinetic_params))))
log_pred_vs_data_ratios =
log_ratio_predict_vs_data(rate_equation, rate_data_nt, kinetic_params_nt)

#calculate figures weights and loss on per figure basis
loss = zero(eltype(kinetic_params))
loss = zero(eltype(log_pred_vs_data_ratios))
for i = 1:maximum(rate_data_nt.fig_num)
# calculate Vmax weights for each figure which have analytical solution as ratio of gemetric means of data vs prediction
log_weight = mean(-log_pred_vs_data_ratios[fig_point_indexes[i]])
loss += sum(abs2.(log_weight .+ log_pred_vs_data_ratios[fig_point_indexes[i]]))
log_fig_weight = zero(eltype(log_pred_vs_data_ratios))
counter = 0
for j in fig_point_indexes[i]
log_fig_weight += log_pred_vs_data_ratios[j]
counter += 1
end
log_fig_weight /= counter
for j in fig_point_indexes[i]
loss += abs2(log_fig_weight - log_pred_vs_data_ratios[j])
end
end
return loss / length(rate_data_nt.Rate)
end
Expand All @@ -223,9 +233,7 @@ function log_ratio_predict_vs_data(
rate_data_nt::NamedTuple,
kinetic_params_nt::NamedTuple;
)
log_pred_vs_data_ratios = 10 .* ones(Float64, length(rate_data_nt.Rate))
#TODO: maybe convert this to broacasting calculation of rate using rate_equation.(row, Ref(kinetic_params_nt))
# and then log.(pred ./ data) instead of a loop.
log_pred_vs_data_ratios = ones(Float64, length(rate_data_nt.Rate))
for (i, row) in enumerate(Tables.namedtupleiterator(rate_data_nt))
log_pred_vs_data_ratios[i] = log(rate_equation(row, kinetic_params_nt) / row.Rate)
end
Expand Down

0 comments on commit b0881b8

Please sign in to comment.