-
-
Notifications
You must be signed in to change notification settings - Fork 211
/
Copy pathMTKChainRulesCoreExt.jl
100 lines (92 loc) · 3.47 KB
/
MTKChainRulesCoreExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
module MTKChainRulesCoreExt
import ModelingToolkit as MTK
import ChainRulesCore
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
function mtp_pullback(dt)
dt = unthunk(dt)
(NoTangent(), dt.tunable[1:length(tunables)],
ntuple(_ -> NoTangent(), length(args))...)
end
MTK.MTKParameters(tunables, args...), mtp_pullback
end
function subset_idxs(idxs, portion, template)
ntuple(Val(length(template))) do subi
[Base.tail(idx.idx) for idx in idxs if idx.portion == portion && idx.idx[1] == subi]
end
end
selected_tangents(::NoTangent, _) = ()
selected_tangents(::ZeroTangent, _) = ZeroTangent()
function selected_tangents(
tangents::AbstractArray{T}, idxs::Vector{Tuple{Int}}) where {T <: Number}
selected_tangents(tangents, map(only, idxs))
end
function selected_tangents(tangents::AbstractArray{T}, idxs...) where {T <: Number}
newtangents = copy(tangents)
view(newtangents, idxs...) .= zero(T)
newtangents
end
function selected_tangents(
tangents::AbstractVector{T}, idxs) where {S <: Number, T <: AbstractArray{S}}
newtangents = copy(tangents)
for i in idxs
j, k... = i
if k == ()
newtangents[j] = zero(newtangents[j])
else
newtangents[j] = selected_tangents(newtangents[j], k...)
end
end
newtangents
end
function selected_tangents(tangents::AbstractVector{T}, idxs) where {T <: AbstractArray}
newtangents = similar(tangents, Union{T, NoTangent})
copyto!(newtangents, tangents)
for i in idxs
j, k... = i
if k == ()
newtangents[j] = NoTangent()
else
newtangents[j] = selected_tangents(newtangents[j], k...)
end
end
newtangents
end
function selected_tangents(
tangents::Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}}, idxs) where {T}
ntuple(Val(length(tangents))) do i
selected_tangents(tangents[i], idxs[i])
end
end
function ChainRulesCore.rrule(
::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
if idxs isa AbstractSet
idxs = collect(idxs)
end
idxs = map(idxs) do i
i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
end
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
tunable_idxs = reduce(
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable))
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
pullback = let idxs = idxs
function remake_buffer_pullback(buf′)
buf′ = unthunk(buf′)
f′ = NoTangent()
indp′ = NoTangent()
tunable = selected_tangents(buf′.tunable, tunable_idxs)
discrete = selected_tangents(buf′.discrete, disc_idxs)
constant = selected_tangents(buf′.constant, const_idxs)
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
idxs′ = NoTangent()
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
return f′, indp′, oldbuf′, idxs′, vals′
end
end
newbuf, pullback
end
end