Line | Exclusive | Inclusive | Code |
---|---|---|---|
1 | ######## | ||
2 | # Dual # | ||
3 | ######## | ||
4 | |||
5 | """ | ||
6 | ForwardDiff.can_dual(V::Type) | ||
7 | |||
8 | Determines whether the type V is allowed as the scalar type in a | ||
9 | Dual. By default, only `<:Real` types are allowed. | ||
10 | """ | ||
11 | can_dual(::Type{<:Real}) = true | ||
12 | can_dual(::Type) = false | ||
13 | |||
14 | struct Dual{T,V,N} <: Real | ||
15 | value::V | ||
16 | partials::Partials{N,V} | ||
17 | function Dual{T, V, N}(value::V, partials::Partials{N, V}) where {T, V, N} | ||
18 | can_dual(V) || throw_cannot_dual(V) | ||
19 | new{T, V, N}(value, partials) | ||
20 | end | ||
21 | end | ||
22 | |||
23 | ########## | ||
24 | # Traits # | ||
25 | ########## | ||
26 | Base.ArithmeticStyle(::Type{<:Dual{T,V}}) where {T,V} = Base.ArithmeticStyle(V) | ||
27 | |||
28 | ############## | ||
29 | # Exceptions # | ||
30 | ############## | ||
31 | |||
32 | struct DualMismatchError{A,B} <: Exception | ||
33 | a::A | ||
34 | b::B | ||
35 | end | ||
36 | |||
37 | Base.showerror(io::IO, e::DualMismatchError{A,B}) where {A,B} = | ||
38 | print(io, "Cannot determine ordering of Dual tags $(e.a) and $(e.b)") | ||
39 | |||
40 | @noinline function throw_cannot_dual(V::Type) | ||
41 | throw(ArgumentError("Cannot create a dual over scalar type $V." * | ||
42 | " If the type behaves as a scalar, define ForwardDiff.can_dual(::Type{$V}) = true.")) | ||
43 | end | ||
44 | |||
45 | """ | ||
46 | ForwardDiff.≺(a, b)::Bool | ||
47 | |||
48 | Determines the order in which tagged `Dual` objects are composed. If true, then `Dual{b}` | ||
49 | objects will appear outside `Dual{a}` objects. | ||
50 | |||
51 | This is important when working with nested differentiation: currently, only the outermost | ||
52 | tag can be extracted, so it should be used in the _innermost_ function. | ||
53 | """ | ||
54 | ≺(a,b) = throw(DualMismatchError(a,b)) | ||
55 | |||
56 | ################ | ||
57 | # Constructors # | ||
58 | ################ | ||
59 | |||
60 | @inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V} = Dual{T,V,N}(value, partials) | ||
61 | |||
62 | @inline function Dual{T}(value::A, partials::Partials{N,B}) where {T,N,A,B} | ||
63 | C = promote_type(A, B) | ||
64 | return Dual{T}(convert(C, value), convert(Partials{N,C}, partials)) | ||
65 | end | ||
66 | |||
67 | @inline Dual{T}(value, partials::Tuple) where {T} = Dual{T}(value, Partials(partials)) | ||
68 | @inline Dual{T}(value, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials)) | ||
69 | @inline Dual{T}(value) where {T} = Dual{T}(value, ()) | ||
70 | @inline Dual{T}(x::Dual{T}) where {T} = Dual{T}(x, ()) | ||
71 | @inline Dual{T}(value, partial1, partials...) where {T} = Dual{T}(value, tuple(partial1, partials...)) | ||
72 | @inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p)) | ||
73 | @inline Dual(args...) = Dual{Nothing}(args...) | ||
74 | |||
75 | # we define these special cases so that the "constructor <--> convert" pun holds for `Dual` | ||
76 | @inline Dual{T,V,N}(x::Dual{T,V,N}) where {T,V,N} = x | ||
77 | @inline Dual{T,V,N}(x) where {T,V,N} = convert(Dual{T,V,N}, x) | ||
78 | @inline Dual{T,V,N}(x::Number) where {T,V,N} = convert(Dual{T,V,N}, x) | ||
79 | @inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x) | ||
80 | |||
81 | ############################## | ||
82 | # Utility/Accessor Functions # | ||
83 | ############################## | ||
84 | |||
85 | @inline value(x) = x | ||
86 | @inline value(d::Dual) = d.value | ||
87 | |||
88 | @inline value(::Type{T}, x) where T = x | ||
89 | @inline value(::Type{T}, d::Dual{T}) where T = value(d) | ||
90 | @inline function value(::Type{T}, d::Dual{S}) where {T,S} | ||
91 | if S ≺ T | ||
92 | d | ||
93 | else | ||
94 | throw(DualMismatchError(T,S)) | ||
95 | end | ||
96 | end | ||
97 | |||
98 | @inline partials(x) = Partials{0,typeof(x)}(tuple()) | ||
99 | @inline partials(d::Dual) = d.partials | ||
100 | @inline partials(x, i...) = zero(x) | ||
101 | @inline Base.@propagate_inbounds partials(d::Dual, i) = d.partials[i] | ||
102 | @inline Base.@propagate_inbounds partials(d::Dual, i, j) = partials(d, i).partials[j] | ||
103 | @inline Base.@propagate_inbounds partials(d::Dual, i, j, k...) = partials(partials(d, i, j), k...) | ||
104 | |||
105 | @inline Base.@propagate_inbounds partials(::Type{T}, x, i...) where T = partials(x, i...) | ||
106 | @inline Base.@propagate_inbounds partials(::Type{T}, d::Dual{T}, i...) where T = partials(d, i...) | ||
107 | @inline function partials(::Type{T}, d::Dual{S}, i...) where {T,S} | ||
108 | if S ≺ T | ||
109 | zero(d) | ||
110 | else | ||
111 | throw(DualMismatchError(T,S)) | ||
112 | end | ||
113 | end | ||
114 | |||
115 | |||
116 | @inline npartials(::Dual{T,V,N}) where {T,V,N} = N | ||
117 | @inline npartials(::Type{Dual{T,V,N}}) where {T,V,N} = N | ||
118 | |||
119 | @inline order(::Type{V}) where {V} = 0 | ||
120 | @inline order(::Type{Dual{T,V,N}}) where {T,V,N} = 1 + order(V) | ||
121 | |||
122 | @inline valtype(::V) where {V} = V | ||
123 | @inline valtype(::Type{V}) where {V} = V | ||
124 | @inline valtype(::Dual{T,V,N}) where {T,V,N} = V | ||
125 | @inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V | ||
126 | |||
127 | @inline tagtype(::V) where {V} = Nothing | ||
128 | @inline tagtype(::Type{V}) where {V} = Nothing | ||
129 | @inline tagtype(::Dual{T,V,N}) where {T,V,N} = T | ||
130 | @inline tagtype(::Type{Dual{T,V,N}}) where {T,V,N} = T | ||
131 | |||
132 | #################################### | ||
133 | # N-ary Operation Definition Tools # | ||
134 | #################################### | ||
135 | |||
136 | macro define_binary_dual_op(f, xy_body, x_body, y_body) | ||
137 | FD = ForwardDiff | ||
138 | defs = quote | ||
139 | @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}) where {Txy} = $xy_body | ||
140 | @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}) where {Tx,Ty} = Ty ≺ Tx ? $x_body : $y_body | ||
141 | end | ||
142 | for R in AMBIGUOUS_TYPES | ||
143 | expr = quote | ||
144 | @inline $(f)(x::$FD.Dual{Tx}, y::$R) where {Tx} = $x_body | ||
145 | @inline $(f)(x::$R, y::$FD.Dual{Ty}) where {Ty} = $y_body | ||
146 | end | ||
147 | append!(defs.args, expr.args) | ||
148 | end | ||
149 | return esc(defs) | ||
150 | end | ||
151 | |||
152 | macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body) | ||
153 | FD = ForwardDiff | ||
154 | defs = quote | ||
155 | @inline $(f)(x::$FD.Dual{Txyz}, y::$FD.Dual{Txyz}, z::$FD.Dual{Txyz}) where {Txyz} = $xyz_body | ||
156 | @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$FD.Dual{Tz}) where {Txy,Tz} = Tz ≺ Txy ? $xy_body : $z_body | ||
157 | @inline $(f)(x::$FD.Dual{Txz}, y::$FD.Dual{Ty}, z::$FD.Dual{Txz}) where {Txz,Ty} = Ty ≺ Txz ? $xz_body : $y_body | ||
158 | @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tx,Tyz} = Tyz ≺ Tx ? $x_body : $yz_body | ||
159 | @inline function $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Tx,Ty,Tz} | ||
160 | if Tz ≺ Tx && Ty ≺ Tx | ||
161 | $x_body | ||
162 | elseif Tz ≺ Ty | ||
163 | $y_body | ||
164 | else | ||
165 | $z_body | ||
166 | end | ||
167 | end | ||
168 | end | ||
169 | for R in AMBIGUOUS_TYPES | ||
170 | expr = quote | ||
171 | @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body | ||
172 | @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body | ||
173 | @inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body | ||
174 | @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body | ||
175 | @inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body | ||
176 | @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body | ||
177 | end | ||
178 | append!(defs.args, expr.args) | ||
179 | for Q in AMBIGUOUS_TYPES | ||
180 | Q === R && continue | ||
181 | expr = quote | ||
182 | @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body | ||
183 | @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body | ||
184 | @inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body | ||
185 | end | ||
186 | append!(defs.args, expr.args) | ||
187 | end | ||
188 | expr = quote | ||
189 | @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body | ||
190 | @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body | ||
191 | @inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body | ||
192 | end | ||
193 | append!(defs.args, expr.args) | ||
194 | end | ||
195 | return esc(defs) | ||
196 | end | ||
197 | |||
198 | # Support complex-valued functions such as `hankelh1` | ||
199 | function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T} | ||
200 | | 1 (0 %) — 1 (100 %) samples spent calling `*` | return Dual{T}(val, deriv * partial) |
201 | end | ||
202 | | | function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T} — *function summary: 17 (6 %) samples spent in `dual_definition_retval`; 1 (6 %) (ex.), 16 (94 %) (incl.) when called from `*` line 271; 1 (6 %) (incl.) when called from `*` line 281* |
203 | 1 (0 %) | 16 (6 %) — 15 (100 %) samples spent calling `_mul_partials` | return Dual{T}(val, _mul_partials(partial1, partial2, deriv1, deriv2)) |
204 | end | ||
205 | function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T} | ||
206 | reval, imval = reim(val) | ||
207 | if deriv isa Real | ||
208 | p = deriv * partial | ||
209 | return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p))) | ||
210 | else | ||
211 | rederiv, imderiv = reim(deriv) | ||
212 | return Complex(Dual{T}(reval, rederiv * partial), Dual{T}(imval, imderiv * partial)) | ||
213 | end | ||
214 | end | ||
215 | function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T} | ||
216 | reval, imval = reim(val) | ||
217 | if deriv1 isa Real && deriv2 isa Real | ||
218 | p = _mul_partials(partial1, partial2, deriv1, deriv2) | ||
219 | return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p))) | ||
220 | else | ||
221 | rederiv1, imderiv1 = reim(deriv1) | ||
222 | rederiv2, imderiv2 = reim(deriv2) | ||
223 | return Complex( | ||
224 | Dual{T}(reval, _mul_partials(partial1, partial2, rederiv1, rederiv2)), | ||
225 | Dual{T}(imval, _mul_partials(partial1, partial2, imderiv1, imderiv2)), | ||
226 | ) | ||
227 | end | ||
228 | end | ||
229 | |||
230 | function unary_dual_definition(M, f) | ||
231 | FD = ForwardDiff | ||
232 | Mf = M == :Base ? f : :($M.$f) | ||
233 | work = qualified_cse!(quote | ||
234 | val = $Mf(x) | ||
235 | deriv = $(DiffRules.diffrule(M, f, :x)) | ||
236 | end) | ||
237 | return quote | ||
238 | @inline function $M.$f(d::$FD.Dual{T}) where T | ||
239 | x = $FD.value(d) | ||
240 | $work | ||
241 | return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d)) | ||
242 | end | ||
243 | end | ||
244 | end | ||
245 | |||
246 | function binary_dual_definition(M, f) | ||
247 | FD = ForwardDiff | ||
248 | dvx, dvy = DiffRules.diffrule(M, f, :vx, :vy) | ||
249 | Mf = M == :Base ? f : :($M.$f) | ||
250 | xy_work = qualified_cse!(quote | ||
251 | val = $Mf(vx, vy) | ||
252 | dvx = $dvx | ||
253 | dvy = $dvy | ||
254 | end) | ||
255 | dvx, _ = DiffRules.diffrule(M, f, :vx, :y) | ||
256 | x_work = qualified_cse!(quote | ||
257 | val = $Mf(vx, y) | ||
258 | dvx = $dvx | ||
259 | end) | ||
260 | _, dvy = DiffRules.diffrule(M, f, :x, :vy) | ||
261 | y_work = qualified_cse!(quote | ||
262 | val = $Mf(x, vy) | ||
263 | dvy = $dvy | ||
264 | end) | ||
265 | expr = quote | ||
266 | $FD.@define_binary_dual_op( | ||
267 | $M.$f, | ||
268 | begin | ||
269 | vx, vy = $FD.value(x), $FD.value(y) | ||
270 | $xy_work | ||
271 | 1 (0 %) | 17 (6 %) — 16 (100 %) samples spent calling `dual_definition_retval` | return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y)) — *function summary: 17 (6 %) samples spent in `*`; 1 (6 %) (ex.), 9 (53 %) (incl.) when called from `brusselator_2d_loop` line 49; 8 (47 %) (incl.) when called from `brusselator_2d_loop` line 54* |
272 | end, | ||
273 | begin | ||
274 | vx = $FD.value(x) | ||
275 | $x_work | ||
276 | return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x)) | ||
277 | end, | ||
278 | begin | ||
279 | vy = $FD.value(y) | ||
280 | $y_work | ||
281 | | 1 (0 %) — 1 (100 %) samples spent calling `dual_definition_retval` | return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y)) |
282 | end | ||
283 | ) | ||
284 | end | ||
285 | return expr | ||
286 | end | ||
287 | |||
288 | ##################### | ||
289 | # Generic Functions # | ||
290 | ##################### | ||
291 | |||
292 | Base.copy(d::Dual) = d | ||
293 | |||
294 | Base.eps(d::Dual) = eps(value(d)) | ||
295 | Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D)) | ||
296 | |||
297 | # The `base` keyword was added in Julia 1.8: | ||
298 | # https://github.com/JuliaLang/julia/pull/42428 | ||
299 | if VERSION < v"1.8.0-DEV.725" | ||
300 | Base.precision(d::Dual) = precision(value(d)) | ||
301 | Base.precision(::Type{D}) where {D<:Dual} = precision(valtype(D)) | ||
302 | else | ||
303 | Base.precision(d::Dual; base::Integer=2) = precision(value(d); base=base) | ||
304 | function Base.precision(::Type{D}; base::Integer=2) where {D<:Dual} | ||
305 | precision(valtype(D); base=base) | ||
306 | end | ||
307 | end | ||
308 | |||
309 | function Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} | ||
310 | ForwardDiff.Dual{T}(nextfloat(d.value), d.partials) | ||
311 | end | ||
312 | |||
313 | function Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} | ||
314 | ForwardDiff.Dual{T}(prevfloat(d.value), d.partials) | ||
315 | end | ||
316 | |||
317 | Base.rtoldefault(::Type{D}) where {D<:Dual} = Base.rtoldefault(valtype(D)) | ||
318 | |||
319 | Base.floor(::Type{R}, d::Dual) where {R<:Real} = floor(R, value(d)) | ||
320 | Base.floor(d::Dual) = floor(value(d)) | ||
321 | |||
322 | Base.ceil(::Type{R}, d::Dual) where {R<:Real} = ceil(R, value(d)) | ||
323 | Base.ceil(d::Dual) = ceil(value(d)) | ||
324 | |||
325 | Base.trunc(::Type{R}, d::Dual) where {R<:Real} = trunc(R, value(d)) | ||
326 | Base.trunc(d::Dual) = trunc(value(d)) | ||
327 | |||
328 | Base.round(::Type{R}, d::Dual) where {R<:Real} = round(R, value(d)) | ||
329 | Base.round(d::Dual) = round(value(d)) | ||
330 | |||
331 | Base.fld(x::Dual, y::Dual) = fld(value(x), value(y)) | ||
332 | |||
333 | Base.cld(x::Dual, y::Dual) = cld(value(x), value(y)) | ||
334 | |||
335 | Base.exponent(x::Dual) = exponent(value(x)) | ||
336 | |||
337 | if VERSION ≥ v"1.4" | ||
338 | Base.div(x::Dual, y::Dual, r::RoundingMode) = div(value(x), value(y), r) | ||
339 | else | ||
340 | Base.div(x::Dual, y::Dual) = div(value(x), value(y)) | ||
341 | end | ||
342 | |||
343 | Base.hash(d::Dual) = hash(value(d)) | ||
344 | Base.hash(d::Dual, hsh::UInt) = hash(value(d), hsh) | ||
345 | |||
346 | function Base.read(io::IO, ::Type{Dual{T,V,N}}) where {T,V,N} | ||
347 | value = read(io, V) | ||
348 | partials = read(io, Partials{N,V}) | ||
349 | return Dual{T,V,N}(value, partials) | ||
350 | end | ||
351 | |||
352 | function Base.write(io::IO, d::Dual) | ||
353 | write(io, value(d)) | ||
354 | write(io, partials(d)) | ||
355 | end | ||
356 | |||
357 | @inline Base.zero(d::Dual) = zero(typeof(d)) | ||
358 | @inline Base.zero(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(zero(V), zero(Partials{N,V})) | ||
359 | |||
360 | @inline Base.one(d::Dual) = one(typeof(d)) | ||
361 | @inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(one(V), zero(Partials{N,V})) | ||
362 | |||
363 | @inline function Base.Int(d::Dual) | ||
364 | all(iszero, partials(d)) || throw(InexactError(:Int, Int, d)) | ||
365 | Int(value(d)) | ||
366 | end | ||
367 | @inline function Base.Integer(d::Dual) | ||
368 | all(iszero, partials(d)) || throw(InexactError(:Integer, Integer, d)) | ||
369 | Integer(value(d)) | ||
370 | end | ||
371 | |||
372 | @inline Random.rand(rng::AbstractRNG, d::Dual) = rand(rng, value(d)) | ||
373 | @inline Random.rand(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(V), zero(Partials{N,V})) | ||
374 | @inline Random.rand(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(rng, V), zero(Partials{N,V})) | ||
375 | @inline Random.randn(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(V), zero(Partials{N,V})) | ||
376 | @inline Random.randn(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(rng, V), zero(Partials{N,V})) | ||
377 | @inline Random.randexp(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(V), zero(Partials{N,V})) | ||
378 | @inline Random.randexp(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(rng, V), zero(Partials{N,V})) | ||
379 | |||
380 | # Predicates # | ||
381 | #------------# | ||
382 | |||
383 | isconstant(d::Dual) = iszero(partials(d)) | ||
384 | |||
385 | for pred in UNARY_PREDICATES | ||
386 | @eval Base.$(pred)(d::Dual) = $(pred)(value(d)) | ||
387 | end | ||
388 | |||
389 | for pred in BINARY_PREDICATES | ||
390 | @eval begin | ||
391 | @define_binary_dual_op( | ||
392 | Base.$(pred), | ||
393 | $(pred)(value(x), value(y)), | ||
394 | $(pred)(value(x), y), | ||
395 | $(pred)(x, value(y)) | ||
396 | ) | ||
397 | end | ||
398 | end | ||
399 | |||
400 | ######################## | ||
401 | # Promotion/Conversion # | ||
402 | ######################## | ||
403 | |||
404 | function Base.promote_rule(::Type{Dual{T1,V1,N1}}, | ||
405 | ::Type{Dual{T2,V2,N2}}) where {T1,V1,N1,T2,V2,N2} | ||
406 | # V1 and V2 might themselves be Dual types | ||
407 | if T2 ≺ T1 | ||
408 | Dual{T1,promote_type(V1,Dual{T2,V2,N2}),N1} | ||
409 | else | ||
410 | Dual{T2,promote_type(V2,Dual{T1,V1,N1}),N2} | ||
411 | end | ||
412 | end | ||
413 | |||
414 | function Base.promote_rule(::Type{Dual{T,A,N}}, | ||
415 | ::Type{Dual{T,B,N}}) where {T,A,B,N} | ||
416 | return Dual{T,promote_type(A, B),N} | ||
417 | end | ||
418 | |||
419 | for R in (Irrational, Real, BigFloat, Bool) | ||
420 | if isconcretetype(R) # issue #322 | ||
421 | @eval begin | ||
422 | Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N} | ||
423 | Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V,N} = Dual{T,promote_type(V, $R),N} | ||
424 | end | ||
425 | else | ||
426 | @eval begin | ||
427 | Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N} | ||
428 | Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V,N,R<:$R} = Dual{T,promote_type(V, R),N} | ||
429 | end | ||
430 | end | ||
431 | end | ||
432 | |||
433 | @inline Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V,N} = Dual{T}(V(value(d)), convert(Partials{N,V}, partials(d))) | ||
434 | @inline Base.convert(::Type{Dual{T,Dual{T,V,M},N}}, d::Dual{T,V,M}) where {T,V,N,M} = Dual{T}(d, Partials{N,Dual{T,V,M}}(zero_tuple(NTuple{N,Dual{T,V,M}}))) | ||
435 | @inline Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V})) | ||
436 | @inline Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V})) | ||
437 | Base.convert(::Type{D}, d::D) where {D<:Dual} = d | ||
438 | |||
439 | Base.float(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,float(V),N} | ||
440 | Base.float(d::Dual) = convert(float(typeof(d)), d) | ||
441 | |||
442 | ################################### | ||
443 | # General Mathematical Operations # | ||
444 | ################################### | ||
445 | |||
446 | for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) | ||
447 | if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-), (:Base, :sin), (:Base, :cos)) | ||
448 | continue # Skip methods which we define elsewhere. | ||
449 | elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) | ||
450 | continue # Skip rules for methods not defined in the current scope | ||
451 | end | ||
452 | if arity == 1 | ||
453 | eval(unary_dual_definition(M, f)) | ||
454 | elseif arity == 2 | ||
455 | eval(binary_dual_definition(M, f)) | ||
456 | else | ||
457 | # error("ForwardDiff currently only knows how to autogenerate Dual definitions for unary and binary functions.") | ||
458 | # However, the presence of N-ary rules need not cause any problems here, they can simply be ignored. | ||
459 | end | ||
460 | end | ||
461 | |||
462 | ################# | ||
463 | # Special Cases # | ||
464 | ################# | ||
465 | |||
466 | # +/- # | ||
467 | #-----# | ||
468 | |||
469 | @define_binary_dual_op( | ||
470 | Base.:+, | ||
471 | begin | ||
472 | vx, vy = value(x), value(y) | ||
473 | 3 (1 %) | | Dual{Txy}(vx + vy, partials(x) + partials(y)) — *function summary: 3 (1 %) samples spent in `+`; 2 (67 %) (incl.) when called from `+` line 587; 1 (33 %) (incl.) when called from `brusselator_2d_loop` line 54* |
474 | end, | ||
475 | Dual{Tx}(value(x) + y, partials(x)), | ||
476 | Dual{Ty}(x + value(y), partials(y)) | ||
477 | ) | ||
478 | |||
479 | @define_binary_dual_op( | ||
480 | Base.:-, | ||
481 | begin | ||
482 | vx, vy = value(x), value(y) | ||
483 | | 4 (1 %) — 4 (100 %) samples spent calling `-` | Dual{Txy}(vx - vy, partials(x) - partials(y)) |
484 | end, | ||
485 | Dual{Tx}(value(x) - y, partials(x)), | ||
486 | Dual{Ty}(x - value(y), -partials(y)) | ||
487 | ) | ||
488 | |||
489 | @inline Base.:-(d::Dual{T}) where {T} = Dual{T}(-value(d), -partials(d)) | ||
490 | |||
491 | # * # | ||
492 | #---# | ||
493 | |||
494 | @inline Base.:*(d::Dual, x::Bool) = x ? d : (signbit(value(d))==0 ? zero(d) : -zero(d)) | ||
495 | @inline Base.:*(x::Bool, d::Dual) = d * x | ||
496 | |||
497 | # / # | ||
498 | #---# | ||
499 | |||
500 | # We can't use the normal diffrule autogeneration for this because (x/y) === (x * (1/y)) | ||
501 | # doesn't generally hold true for floating point; see issue #264 | ||
502 | @define_binary_dual_op( | ||
503 | Base.:/, | ||
504 | begin | ||
505 | vx, vy = value(x), value(y) | ||
506 | Dual{Txy}(vx / vy, _div_partials(partials(x), partials(y), vx, vy)) | ||
507 | end, | ||
508 | Dual{Tx}(value(x) / y, partials(x) / y), | ||
509 | begin | ||
510 | v = value(y) | ||
511 | divv = x / v | ||
512 | Dual{Ty}(divv, -(divv / v) * partials(y)) | ||
513 | end | ||
514 | ) | ||
515 | |||
516 | # exponentiation # | ||
517 | #----------------# | ||
518 | |||
519 | for f in (:(Base.:^), :(NaNMath.pow)) | ||
520 | @eval begin | ||
521 | @define_binary_dual_op( | ||
522 | $f, | ||
523 | begin | ||
524 | vx, vy = value(x), value(y) | ||
525 | expv = ($f)(vx, vy) | ||
526 | powval = vy * ($f)(vx, vy - 1) | ||
527 | if isconstant(y) | ||
528 | logval = one(expv) | ||
529 | elseif iszero(vx) && vy > 0 | ||
530 | logval = zero(vx) | ||
531 | else | ||
532 | logval = expv * log(vx) | ||
533 | end | ||
534 | new_partials = _mul_partials(partials(x), partials(y), powval, logval) | ||
535 | return Dual{Txy}(expv, new_partials) | ||
536 | end, | ||
537 | begin | ||
538 | v = value(x) | ||
539 | expv = ($f)(v, y) | ||
540 | if y == zero(y) || iszero(partials(x)) | ||
541 | new_partials = zero(partials(x)) | ||
542 | else | ||
543 | new_partials = partials(x) * y * ($f)(v, y - 1) | ||
544 | end | ||
545 | return Dual{Tx}(expv, new_partials) | ||
546 | end, | ||
547 | begin | ||
548 | v = value(y) | ||
549 | expv = ($f)(x, v) | ||
550 | deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x) | ||
551 | return Dual{Ty}(expv, deriv * partials(y)) | ||
552 | end | ||
553 | ) | ||
554 | end | ||
555 | end | ||
556 | |||
557 | @inline Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{0}) where {T} = | ||
558 | Dual{T}(one(value(x)), zero(partials(x))) | ||
559 | |||
560 | for y in 1:3 | ||
561 | @eval @inline function Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{$y}) where {T} | ||
562 | v = value(x) | ||
563 | expv = v^$y | ||
564 | deriv = $y * v^$(y - 1) | ||
565 | | 1 (0 %) — 1 (100 %) samples spent calling `*` | return Dual{T}(expv, deriv * partials(x)) |
566 | end | ||
567 | end | ||
568 | |||
569 | # hypot # | ||
570 | #-------# | ||
571 | |||
572 | @inline function calc_hypot(x, y, z, ::Type{T}) where T | ||
573 | vx = value(x) | ||
574 | vy = value(y) | ||
575 | vz = value(z) | ||
576 | h = hypot(vx, vy, vz) | ||
577 | p = (vx / h) * partials(x) + (vy / h) * partials(y) + (vz / h) * partials(z) | ||
578 | return Dual{T}(h, p) | ||
579 | end | ||
580 | |||
581 | @define_ternary_dual_op( | ||
582 | Base.hypot, | ||
583 | calc_hypot(x, y, z, Txyz), | ||
584 | calc_hypot(x, y, z, Txy), | ||
585 | calc_hypot(x, y, z, Txz), | ||
586 | calc_hypot(x, y, z, Tyz), | ||
587 | calc_hypot(x, y, z, Tx), | ||
588 | calc_hypot(x, y, z, Ty), | ||
589 | calc_hypot(x, y, z, Tz), | ||
590 | ) | ||
591 | |||
592 | # fma # | ||
593 | #-----# | ||
594 | |||
595 | @generated function calc_fma_xyz(x::Dual{T,<:Any,N}, | ||
596 | y::Dual{T,<:Any,N}, | ||
597 | z::Dual{T,<:Any,N}) where {T,N} | ||
598 | ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...) | ||
599 | return quote | ||
600 | $(Expr(:meta, :inline)) | ||
601 | v = fma(value(x), value(y), value(z)) | ||
602 | return Dual{T}(v, $ex) | ||
603 | end | ||
604 | end | ||
605 | |||
606 | @inline function calc_fma_xy(x::Dual{T}, y::Dual{T}, z::Real) where T | ||
607 | vx, vy = value(x), value(y) | ||
608 | result = fma(vx, vy, z) | ||
609 | return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx)) | ||
610 | end | ||
611 | |||
612 | @generated function calc_fma_xz(x::Dual{T,<:Any,N}, | ||
613 | y::Real, | ||
614 | z::Dual{T,<:Any,N}) where {T,N} | ||
615 | ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...) | ||
616 | return quote | ||
617 | $(Expr(:meta, :inline)) | ||
618 | v = fma(value(x), y, value(z)) | ||
619 | Dual{T}(v, $ex) | ||
620 | end | ||
621 | end | ||
622 | |||
623 | @define_ternary_dual_op( | ||
624 | Base.fma, | ||
625 | calc_fma_xyz(x, y, z), # xyz_body | ||
626 | calc_fma_xy(x, y, z), # xy_body | ||
627 | calc_fma_xz(x, y, z), # xz_body | ||
628 | Base.fma(y, x, z), # yz_body | ||
629 | Dual{Tx}(fma(value(x), y, z), partials(x) * y), # x_body | ||
630 | Base.fma(y, x, z), # y_body | ||
631 | Dual{Tz}(fma(x, y, value(z)), partials(z)) # z_body | ||
632 | ) | ||
633 | |||
634 | # muladd # | ||
635 | #--------# | ||
636 | |||
637 | @generated function calc_muladd_xyz(x::Dual{T,<:Any,N}, | ||
638 | y::Dual{T,<:Any,N}, | ||
639 | z::Dual{T,<:Any,N}) where {T,N} | ||
640 | ex = Expr(:tuple, [:(muladd(value(x), partials(y)[$i], muladd(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...) | ||
641 | return quote | ||
642 | $(Expr(:meta, :inline)) | ||
643 | v = muladd(value(x), value(y), value(z)) | ||
644 | return Dual{T}(v, $ex) | ||
645 | end | ||
646 | end | ||
647 | |||
648 | @inline function calc_muladd_xy(x::Dual{T}, y::Dual{T}, z::Real) where T | ||
649 | vx, vy = value(x), value(y) | ||
650 | result = muladd(vx, vy, z) | ||
651 | return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx)) | ||
652 | end | ||
653 | |||
654 | @generated function calc_muladd_xz(x::Dual{T,<:Any,N}, | ||
655 | y::Real, | ||
656 | z::Dual{T,<:Any,N}) where {T,N} | ||
657 | ex = Expr(:tuple, [:(muladd(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...) | ||
658 | return quote | ||
659 | $(Expr(:meta, :inline)) | ||
660 | v = muladd(value(x), y, value(z)) | ||
661 | Dual{T}(v, $ex) | ||
662 | end | ||
663 | end | ||
664 | |||
665 | @define_ternary_dual_op( | ||
666 | Base.muladd, | ||
667 | calc_muladd_xyz(x, y, z), # xyz_body | ||
668 | calc_muladd_xy(x, y, z), # xy_body | ||
669 | calc_muladd_xz(x, y, z), # xz_body | ||
670 | Base.muladd(y, x, z), # yz_body | ||
671 | Dual{Tx}(muladd(value(x), y, z), partials(x) * y), # x_body | ||
672 | Base.muladd(y, x, z), # y_body | ||
673 | Dual{Tz}(muladd(x, y, value(z)), partials(z)) # z_body | ||
674 | ) | ||
675 | |||
676 | # sin/cos # | ||
677 | #--------# | ||
678 | function Base.sin(d::Dual{T}) where T | ||
679 | s, c = sincos(value(d)) | ||
680 | return Dual{T}(s, c * partials(d)) | ||
681 | end | ||
682 | |||
683 | function Base.cos(d::Dual{T}) where T | ||
684 | s, c = sincos(value(d)) | ||
685 | return Dual{T}(c, -s * partials(d)) | ||
686 | end | ||
687 | |||
688 | @inline function Base.sincos(d::Dual{T}) where T | ||
689 | sd, cd = sincos(value(d)) | ||
690 | return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, -sd * partials(d))) | ||
691 | end | ||
692 | |||
693 | # sincospi # | ||
694 | #----------# | ||
695 | |||
696 | if VERSION >= v"1.6.0-DEV.292" | ||
697 | @inline function Base.sincospi(d::Dual{T}) where T | ||
698 | sd, cd = sincospi(value(d)) | ||
699 | return (Dual{T}(sd, cd * π * partials(d)), Dual{T}(cd, -sd * π * partials(d))) | ||
700 | end | ||
701 | end | ||
702 | |||
703 | # Symmetric eigvals # | ||
704 | #-------------------# | ||
705 | |||
706 | function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} | ||
707 | λ,Q = eigen(Symmetric(value.(parent(A)))) | ||
708 | parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N) | ||
709 | Dual{Tg}.(λ, tuple.(parts...)) | ||
710 | end | ||
711 | |||
712 | function LinearAlgebra.eigvals(A::Hermitian{<:Complex{<:Dual{Tg,T,N}}}) where {Tg,T<:Real,N} | ||
713 | λ,Q = eigen(Hermitian(value.(real.(parent(A))) .+ im .* value.(imag.(parent(A))))) | ||
714 | parts = ntuple(j -> diag(real.(Q' * (getindex.(partials.(real.(A)) .+ im .* partials.(imag.(A)), j)) * Q)), N) | ||
715 | Dual{Tg}.(λ, tuple.(parts...)) | ||
716 | end | ||
717 | |||
718 | function LinearAlgebra.eigvals(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} | ||
719 | λ,Q = eigen(SymTridiagonal(value.(parent(A).dv),value.(parent(A).ev))) | ||
720 | parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N) | ||
721 | Dual{Tg}.(λ, tuple.(parts...)) | ||
722 | end | ||
723 | |||
724 | # A ./ (λ - λ') but with diag special cased | ||
725 | function _lyap_div!(A, λ) | ||
726 | for (j,μ) in enumerate(λ), (k,λ) in enumerate(λ) | ||
727 | if k ≠ j | ||
728 | A[k,j] /= μ - λ | ||
729 | end | ||
730 | end | ||
731 | A | ||
732 | end | ||
733 | |||
734 | function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} | ||
735 | λ = eigvals(A) | ||
736 | _,Q = eigen(Symmetric(value.(parent(A)))) | ||
737 | parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N) | ||
738 | Eigen(λ,Dual{Tg}.(Q, tuple.(parts...))) | ||
739 | end | ||
740 | |||
741 | function LinearAlgebra.eigen(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} | ||
742 | λ = eigvals(A) | ||
743 | _,Q = eigen(SymTridiagonal(value.(parent(A)))) | ||
744 | parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N) | ||
745 | Eigen(λ,Dual{Tg}.(Q, tuple.(parts...))) | ||
746 | end | ||
747 | |||
748 | # Functions in SpecialFunctions which return tuples # | ||
749 | # Their derivatives are not defined in DiffRules # | ||
750 | #---------------------------------------------------# | ||
751 | |||
752 | function SpecialFunctions.logabsgamma(d::Dual{T,<:Real}) where {T} | ||
753 | x = value(d) | ||
754 | y, s = SpecialFunctions.logabsgamma(x) | ||
755 | return (Dual{T}(y, SpecialFunctions.digamma(x) * partials(d)), s) | ||
756 | end | ||
757 | |||
758 | # Derivatives wrt to first parameter and precision setting are not supported | ||
759 | function SpecialFunctions.gamma_inc(a::Real, d::Dual{T,<:Real}, ind::Integer) where {T} | ||
760 | x = value(d) | ||
761 | p, q = SpecialFunctions.gamma_inc(a, x, ind) | ||
762 | ∂p = exp(-x) * x^(a - 1) / SpecialFunctions.gamma(a) * partials(d) | ||
763 | return (Dual{T}(p, ∂p), Dual{T}(q, -∂p)) | ||
764 | end | ||
765 | |||
766 | ################### | ||
767 | # Pretty Printing # | ||
768 | ################### | ||
769 | |||
770 | function Base.show(io::IO, d::Dual{T,V,N}) where {T,V,N} | ||
771 | print(io, "Dual{$(repr(T))}(", value(d)) | ||
772 | for i in 1:N | ||
773 | print(io, ",", partials(d, i)) | ||
774 | end | ||
775 | print(io, ")") | ||
776 | end | ||
777 | |||
778 | for op in (:(Base.typemin), :(Base.typemax), :(Base.floatmin), :(Base.floatmax)) | ||
779 | @eval function $op(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} | ||
780 | ForwardDiff.Dual{T,V,N}($op(V)) | ||
781 | end | ||
782 | end | ||
783 | |||
784 | if VERSION >= v"1.6.0-rc1" | ||
785 | Printf.tofloat(d::Dual) = Printf.tofloat(value(d)) | ||
786 | end |