-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerate_dbdt.jl
217 lines (190 loc) · 7.41 KB
/
generate_dbdt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# Generate a minimalistic, absolutely-dedicated dBdt! function
# to hand out to the ODE solver.
# The generated function is only valid for one specific value of `ModelParameters`,
# and must be re-generated whenever the parameters
# or the associated network topology change.
# The code below works with julia representations of julia code,
# in terms of julia's "symbols" and "expressions".
# The idea is that identifiers within the simple expressions
# explicitly appearing in the program
# will be successively transformed according to replacement rules
# until the whole code in dBdt! is generated.
"""
Construct a copy of the expression with the replacements given in `rep`.
```jldoctest
julia> import EcologicalNetworksDynamics.Internals: replace
julia> replace(:(a + (b + c / a)), Dict(:a => 5, :b => 8))
:(5 + (8 + c / 5))
```
"""
function replace(xp, rep)
# Degenerated single occurence case.
if haskey(rep, xp)
return rep[xp]
end
# Deep-copy with a recursive descent.
res = Expr(xp.head)
for a in xp.args
if haskey(rep, a)
a = rep[a]
elseif isa(a, Expr)
a = replace(a, rep)
end
push!(res.args, a)
end
res
end
"""
Repeat the given expression into terms of a sum,
successively replacing `indexes` in `term` by elements in (zipped) `lists`.
```jldoctest
julia> import EcologicalNetworksDynamics.Internals: xp_sum
julia> xp_sum([:i], [[1, 2, 3]], :(u^i)) # Three terms.
:(u ^ 1 + u ^ 2 + u ^ 3)
julia> xp_sum([:i], [[1]], :(u^i)) # Single term.
:(u ^ 1)
julia> xp_sum([:i], [[]], :(u^i)) # No terms.
0
julia> xp_sum([:i, :j], [[:a, :b, :c], [5, 8, 13]], :(j * i)) # Zipped indices.
:(5a + 8b + 13c)
```
"""
function xp_sum(indexes, lists, term)
n_terms = min((length(l) for l in lists)...)
if n_terms == 0
return 0
end
if n_terms == 1
reps = Dict(index => v for (index, v) in zip(indexes, first(zip(lists...))))
return replace(term, reps)
end
sum = Expr(:call, :+)
for values in zip(lists...)
reps = Dict(index => v for (index, v) in zip(indexes, values))
push!(sum.args, replace(term, reps))
end
sum
end
# Perform one in-place expansion pass over `xp` according to `rules`.
# `rules` is a collection of symbols. Each one names a function, resolved with `eval`
# in this module, that returns an expression:
# - A rule appearing as a bare identifier, like `symbol` in `1 + symbol + 2`,
#   is replaced by the result of `symbol(data...)`.
# - A rule appearing as a call, like `symbol(a, b)` in `1 + symbol(a, b) + 2`,
#   is replaced by the result of evaluating that call as written.
# Other sub-expressions are visited recursively; freshly inserted results are not.
# Note that only the arguments of `xp` are inspected, so the callee slot of a
# top-level call (`xp.args[1]`) is treated as a bare identifier.
# Return `true` if anything was replaced (another pass may be needed),
# and `false` once the expression is fully expanded.
function expand!(xp, rules, data)
    changed = false
    for i in eachindex(xp.args)
        arg = xp.args[i]
        if arg in rules
            xp.args[i] = eval(arg)(data...)
            changed = true
        elseif arg isa Expr && arg.head == :call && arg.args[1] in rules
            xp.args[i] = eval(arg)
            changed = true
        elseif arg isa Expr
            changed |= expand!(arg, rules, data)
        end
    end
    return changed
end
"""
Wraps a Julia Expression generated by [`generate_dbdt`](@ref).
Mostly useful for pretty-printing,
but you can `eval`uate it like a regular expression.
The actual expression lies inside the `.expr` attribute.
"""
struct GeneratedExpression
expr::Expr
end
function Base.show(io::IO, c::GeneratedExpression)
print(io, nameof(typeof(c)), "{")
print(io, Base.remove_linenums!(deepcopy(c.expr)))
print(io, "}")
end
# https://discourse.julialang.org/t/whats-a-good-way-to-add-method-to-base-maininclude-eval/91697/2?u=iago-lito
Core.eval(m::Module, c::GeneratedExpression) = Core.eval(m, c.expr)
"""
generate_dbdt(parms::ModelParameters, type)
Produce a specialized julia expression and associated data,
supposed to improve efficiency of subsequent simulations.
The returned expression is typically
[`eval`](https://docs.julialang.org/en/v1/devdocs/eval/)uated
then passed along with the data as a `diff_code_data` argument to [`simulate`](@ref).
There are two possible code generation styles:
- With `type = :raw`,
the generated expression is a straightforward translation
of the underlying differential equations,
with no loops, no recursive calls, nor heap-allocations:
only local variables and basic arithmetic is used.
This makes simulation very efficient,
but the length of the generated expression
varies with the number of species interactions.
When the length becomes high,
it takes much longer for julia to compile it.
If it takes forever,
(typically over `SyntaxtTree.callcount(expression) > 20_000`),
wait until julia 1.9 ([maybe](https://discourse.julialang.org/t/profiling-compilation-of-a-large-generated-expression/83179))
or use the alternate style instead.
- With `type = :compact`,
the generated expression is a more sophisticated implementation
of the underlying differential equations,
involving carefully crafted minimal loops
and exactly one fixed-size heap-allocated bunch of data,
reused on every call during simulation.
This makes the simulation slightly less efficient than the above but,
as the expression size no longer depends on the number of species interactions,
there is no limit to using it and speedup simulations.
"""
function generate_dbdt(parms::ModelParameters, type)
style = Symbol(type)
# TEMP: Summary of working and convincingly tested implementations.
(resp, _, pg) = boostable_criteria(parms)
function to_test()
@warn "Automatic generated :$style specialized code for $resp ($net, $pg) \
has not been rigorously tested yet.\n\
If you are using it for non-trivial simulation, \
please make sure that the resulting trajectories \
do match the ones generated with traditional generic code \
with the same parameters:\n$parms\n\
If they are, then consider adding that simulation to the packages tests set \
so this warning can be removed in future upgrades."
end
unimplemented() = throw("Automatic generated :$style specialized code
for $resp ($net, $pg) is not implemented yet.")
ok() = nothing
if !is_boostable(parms, type)
unimplemented()
else
to_test() # Switch to ok() once it has been tested on longer simulations.
end
if style == :raw
xp, data = generate_dbdt_raw(parms)
elseif style == :compact
xp, data = generate_dbdt_compact(parms)
else
throw("Unknown code generation style: '$style'.")
end
GeneratedExpression(xp), data
end
"""
is_boostable(parms::ModelParameters, type)
Return true if boost has been implemented for this parametrization
with the given boosting style.
"""
function is_boostable(parms::ModelParameters, type)
(resp, _, pg) = boostable_criteria(parms)
resp != MultiplexNetwork && pg != NutrientIntake
end
# Gather the concrete types used by `is_boostable` and `generate_dbdt`
# to decide whether boosting is supported, in this order:
# (functional response type, network type, producer growth type).
function boostable_criteria(parms)
    parts = (parms.functional_response, parms.network, parms.producer_growth)
    return map(typeof, parts)
end
include("./generate_dbdt_compact.jl")
include("./generate_dbdt_raw.jl")