StatProfilerHTML.jl report
Generated on Thu, 21 Dec 2023 13:06:16
File source code
Line Exclusive Inclusive Code
1 ########
2 # Dual #
3 ########
4
5 """
6 ForwardDiff.can_dual(V::Type)
7
8 Determines whether the type V is allowed as the scalar type in a
9 Dual. By default, only `<:Real` types are allowed.
10 """
11 can_dual(::Type{<:Real}) = true
12 can_dual(::Type) = false
13
14 struct Dual{T,V,N} <: Real
15 value::V
16 partials::Partials{N,V}
17 function Dual{T, V, N}(value::V, partials::Partials{N, V}) where {T, V, N}
18 can_dual(V) || throw_cannot_dual(V)
19 new{T, V, N}(value, partials)
20 end
21 end
22
23 ##########
24 # Traits #
25 ##########
26 Base.ArithmeticStyle(::Type{<:Dual{T,V}}) where {T,V} = Base.ArithmeticStyle(V)
27
28 ##############
29 # Exceptions #
30 ##############
31
32 struct DualMismatchError{A,B} <: Exception
33 a::A
34 b::B
35 end
36
37 Base.showerror(io::IO, e::DualMismatchError{A,B}) where {A,B} =
38 print(io, "Cannot determine ordering of Dual tags $(e.a) and $(e.b)")
39
40 @noinline function throw_cannot_dual(V::Type)
41 throw(ArgumentError("Cannot create a dual over scalar type $V." *
42 " If the type behaves as a scalar, define ForwardDiff.can_dual(::Type{$V}) = true."))
43 end
44
45 """
46 ForwardDiff.≺(a, b)::Bool
47
48 Determines the order in which tagged `Dual` objects are composed. If true, then `Dual{b}`
49 objects will appear outside `Dual{a}` objects.
50
51 This is important when working with nested differentiation: currently, only the outermost
52 tag can be extracted, so it should be used in the _innermost_ function.
53 """
54 ≺(a,b) = throw(DualMismatchError(a,b))
55
56 ################
57 # Constructors #
58 ################
59
60 @inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V} = Dual{T,V,N}(value, partials)
61
62 @inline function Dual{T}(value::A, partials::Partials{N,B}) where {T,N,A,B}
63 C = promote_type(A, B)
64 return Dual{T}(convert(C, value), convert(Partials{N,C}, partials))
65 end
66
67 @inline Dual{T}(value, partials::Tuple) where {T} = Dual{T}(value, Partials(partials))
68 @inline Dual{T}(value, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials))
69 @inline Dual{T}(value) where {T} = Dual{T}(value, ())
70 @inline Dual{T}(x::Dual{T}) where {T} = Dual{T}(x, ())
71 @inline Dual{T}(value, partial1, partials...) where {T} = Dual{T}(value, tuple(partial1, partials...))
72 @inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p))
73 @inline Dual(args...) = Dual{Nothing}(args...)
74
75 # we define these special cases so that the "constructor <--> convert" pun holds for `Dual`
76 @inline Dual{T,V,N}(x::Dual{T,V,N}) where {T,V,N} = x
77 @inline Dual{T,V,N}(x) where {T,V,N} = convert(Dual{T,V,N}, x)
78 @inline Dual{T,V,N}(x::Number) where {T,V,N} = convert(Dual{T,V,N}, x)
79 @inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x)
80
81 ##############################
82 # Utility/Accessor Functions #
83 ##############################
84
85 @inline value(x) = x
86 @inline value(d::Dual) = d.value
87
88 @inline value(::Type{T}, x) where T = x
89 @inline value(::Type{T}, d::Dual{T}) where T = value(d)
90 @inline function value(::Type{T}, d::Dual{S}) where {T,S}
91 if S ≺ T
92 d
93 else
94 throw(DualMismatchError(T,S))
95 end
96 end
97
98 @inline partials(x) = Partials{0,typeof(x)}(tuple())
99 @inline partials(d::Dual) = d.partials
100 @inline partials(x, i...) = zero(x)
101 @inline Base.@propagate_inbounds partials(d::Dual, i) = d.partials[i]
102 @inline Base.@propagate_inbounds partials(d::Dual, i, j) = partials(d, i).partials[j]
103 @inline Base.@propagate_inbounds partials(d::Dual, i, j, k...) = partials(partials(d, i, j), k...)
104
105 @inline Base.@propagate_inbounds partials(::Type{T}, x, i...) where T = partials(x, i...)
106 @inline Base.@propagate_inbounds partials(::Type{T}, d::Dual{T}, i...) where T = partials(d, i...)
107 @inline function partials(::Type{T}, d::Dual{S}, i...) where {T,S}
108 if S ≺ T
109 zero(d)
110 else
111 throw(DualMismatchError(T,S))
112 end
113 end
114
115
116 @inline npartials(::Dual{T,V,N}) where {T,V,N} = N
117 @inline npartials(::Type{Dual{T,V,N}}) where {T,V,N} = N
118
119 @inline order(::Type{V}) where {V} = 0
120 @inline order(::Type{Dual{T,V,N}}) where {T,V,N} = 1 + order(V)
121
122 @inline valtype(::V) where {V} = V
123 @inline valtype(::Type{V}) where {V} = V
124 @inline valtype(::Dual{T,V,N}) where {T,V,N} = V
125 @inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V
126
127 @inline tagtype(::V) where {V} = Nothing
128 @inline tagtype(::Type{V}) where {V} = Nothing
129 @inline tagtype(::Dual{T,V,N}) where {T,V,N} = T
130 @inline tagtype(::Type{Dual{T,V,N}}) where {T,V,N} = T
131
132 ####################################
133 # N-ary Operation Definition Tools #
134 ####################################
135
136 macro define_binary_dual_op(f, xy_body, x_body, y_body)
137 FD = ForwardDiff
138 defs = quote
139 @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}) where {Txy} = $xy_body
140 @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}) where {Tx,Ty} = Ty ≺ Tx ? $x_body : $y_body
141 end
142 for R in AMBIGUOUS_TYPES
143 expr = quote
144 @inline $(f)(x::$FD.Dual{Tx}, y::$R) where {Tx} = $x_body
145 @inline $(f)(x::$R, y::$FD.Dual{Ty}) where {Ty} = $y_body
146 end
147 append!(defs.args, expr.args)
148 end
149 return esc(defs)
150 end
151
152 macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body)
153 FD = ForwardDiff
154 defs = quote
155 @inline $(f)(x::$FD.Dual{Txyz}, y::$FD.Dual{Txyz}, z::$FD.Dual{Txyz}) where {Txyz} = $xyz_body
156 @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$FD.Dual{Tz}) where {Txy,Tz} = Tz ≺ Txy ? $xy_body : $z_body
157 @inline $(f)(x::$FD.Dual{Txz}, y::$FD.Dual{Ty}, z::$FD.Dual{Txz}) where {Txz,Ty} = Ty ≺ Txz ? $xz_body : $y_body
158 @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tx,Tyz} = Tyz ≺ Tx ? $x_body : $yz_body
159 @inline function $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Tx,Ty,Tz}
160 if Tz ≺ Tx && Ty ≺ Tx
161 $x_body
162 elseif Tz ≺ Ty
163 $y_body
164 else
165 $z_body
166 end
167 end
168 end
169 for R in AMBIGUOUS_TYPES
170 expr = quote
171 @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body
172 @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body
173 @inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body
174 @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body
175 @inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body
176 @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body
177 end
178 append!(defs.args, expr.args)
179 for Q in AMBIGUOUS_TYPES
180 Q === R && continue
181 expr = quote
182 @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body
183 @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body
184 @inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body
185 end
186 append!(defs.args, expr.args)
187 end
188 expr = quote
189 @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body
190 @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body
191 @inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body
192 end
193 append!(defs.args, expr.args)
194 end
195 return esc(defs)
196 end
197
198 # Support complex-valued functions such as `hankelh1`
199 function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T}
200 1 (0 %)
1 (100 %) samples spent calling *
return Dual{T}(val, deriv * partial)
201 end
202
17 (6 %) samples spent in dual_definition_retval
1 (6 %) (ex.), 16 (94 %) (incl.) when called from * line 271
1 (6 %) (incl.) when called from * line 281
function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T}
203 1 (0 %) 16 (6 %)
15 (100 %) samples spent calling _mul_partials
return Dual{T}(val, _mul_partials(partial1, partial2, deriv1, deriv2))
204 end
205 function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T}
206 reval, imval = reim(val)
207 if deriv isa Real
208 p = deriv * partial
209 return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
210 else
211 rederiv, imderiv = reim(deriv)
212 return Complex(Dual{T}(reval, rederiv * partial), Dual{T}(imval, imderiv * partial))
213 end
214 end
215 function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T}
216 reval, imval = reim(val)
217 if deriv1 isa Real && deriv2 isa Real
218 p = _mul_partials(partial1, partial2, deriv1, deriv2)
219 return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
220 else
221 rederiv1, imderiv1 = reim(deriv1)
222 rederiv2, imderiv2 = reim(deriv2)
223 return Complex(
224 Dual{T}(reval, _mul_partials(partial1, partial2, rederiv1, rederiv2)),
225 Dual{T}(imval, _mul_partials(partial1, partial2, imderiv1, imderiv2)),
226 )
227 end
228 end
229
230 function unary_dual_definition(M, f)
231 FD = ForwardDiff
232 Mf = M == :Base ? f : :($M.$f)
233 work = qualified_cse!(quote
234 val = $Mf(x)
235 deriv = $(DiffRules.diffrule(M, f, :x))
236 end)
237 return quote
238 @inline function $M.$f(d::$FD.Dual{T}) where T
239 x = $FD.value(d)
240 $work
241 return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d))
242 end
243 end
244 end
245
246 function binary_dual_definition(M, f)
247 FD = ForwardDiff
248 dvx, dvy = DiffRules.diffrule(M, f, :vx, :vy)
249 Mf = M == :Base ? f : :($M.$f)
250 xy_work = qualified_cse!(quote
251 val = $Mf(vx, vy)
252 dvx = $dvx
253 dvy = $dvy
254 end)
255 dvx, _ = DiffRules.diffrule(M, f, :vx, :y)
256 x_work = qualified_cse!(quote
257 val = $Mf(vx, y)
258 dvx = $dvx
259 end)
260 _, dvy = DiffRules.diffrule(M, f, :x, :vy)
261 y_work = qualified_cse!(quote
262 val = $Mf(x, vy)
263 dvy = $dvy
264 end)
265 expr = quote
266 $FD.@define_binary_dual_op(
267 $M.$f,
268 begin
269 vx, vy = $FD.value(x), $FD.value(y)
270 $xy_work
271 1 (0 %) 17 (6 %)
17 (6 %) samples spent in *
1 (6 %) (ex.), 9 (53 %) (incl.) when called from brusselator_2d_loop line 49
8 (47 %) (incl.) when called from brusselator_2d_loop line 54
16 (100 %) samples spent calling dual_definition_retval
return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y))
272 end,
273 begin
274 vx = $FD.value(x)
275 $x_work
276 return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x))
277 end,
278 begin
279 vy = $FD.value(y)
280 $y_work
281 1 (0 %)
1 (0 %) samples spent in *
1 (100 %) (incl.) when called from brusselator_2d_loop line 49
1 (100 %) samples spent calling dual_definition_retval
return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y))
282 end
283 )
284 end
285 return expr
286 end
287
288 #####################
289 # Generic Functions #
290 #####################
291
292 Base.copy(d::Dual) = d
293
294 Base.eps(d::Dual) = eps(value(d))
295 Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D))
296
297 # The `base` keyword was added in Julia 1.8:
298 # https://github.com/JuliaLang/julia/pull/42428
299 if VERSION < v"1.8.0-DEV.725"
300 Base.precision(d::Dual) = precision(value(d))
301 Base.precision(::Type{D}) where {D<:Dual} = precision(valtype(D))
302 else
303 Base.precision(d::Dual; base::Integer=2) = precision(value(d); base=base)
304 function Base.precision(::Type{D}; base::Integer=2) where {D<:Dual}
305 precision(valtype(D); base=base)
306 end
307 end
308
309 function Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N}
310 ForwardDiff.Dual{T}(nextfloat(d.value), d.partials)
311 end
312
313 function Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N}
314 ForwardDiff.Dual{T}(prevfloat(d.value), d.partials)
315 end
316
317 Base.rtoldefault(::Type{D}) where {D<:Dual} = Base.rtoldefault(valtype(D))
318
319 Base.floor(::Type{R}, d::Dual) where {R<:Real} = floor(R, value(d))
320 Base.floor(d::Dual) = floor(value(d))
321
322 Base.ceil(::Type{R}, d::Dual) where {R<:Real} = ceil(R, value(d))
323 Base.ceil(d::Dual) = ceil(value(d))
324
325 Base.trunc(::Type{R}, d::Dual) where {R<:Real} = trunc(R, value(d))
326 Base.trunc(d::Dual) = trunc(value(d))
327
328 Base.round(::Type{R}, d::Dual) where {R<:Real} = round(R, value(d))
329 Base.round(d::Dual) = round(value(d))
330
331 Base.fld(x::Dual, y::Dual) = fld(value(x), value(y))
332
333 Base.cld(x::Dual, y::Dual) = cld(value(x), value(y))
334
335 Base.exponent(x::Dual) = exponent(value(x))
336
337 if VERSION ≥ v"1.4"
338 Base.div(x::Dual, y::Dual, r::RoundingMode) = div(value(x), value(y), r)
339 else
340 Base.div(x::Dual, y::Dual) = div(value(x), value(y))
341 end
342
343 Base.hash(d::Dual) = hash(value(d))
344 Base.hash(d::Dual, hsh::UInt) = hash(value(d), hsh)
345
346 function Base.read(io::IO, ::Type{Dual{T,V,N}}) where {T,V,N}
347 value = read(io, V)
348 partials = read(io, Partials{N,V})
349 return Dual{T,V,N}(value, partials)
350 end
351
352 function Base.write(io::IO, d::Dual)
353 write(io, value(d))
354 write(io, partials(d))
355 end
356
357 @inline Base.zero(d::Dual) = zero(typeof(d))
358 @inline Base.zero(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(zero(V), zero(Partials{N,V}))
359
360 @inline Base.one(d::Dual) = one(typeof(d))
361 @inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(one(V), zero(Partials{N,V}))
362
363 @inline function Base.Int(d::Dual)
364 all(iszero, partials(d)) || throw(InexactError(:Int, Int, d))
365 Int(value(d))
366 end
367 @inline function Base.Integer(d::Dual)
368 all(iszero, partials(d)) || throw(InexactError(:Integer, Integer, d))
369 Integer(value(d))
370 end
371
372 @inline Random.rand(rng::AbstractRNG, d::Dual) = rand(rng, value(d))
373 @inline Random.rand(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(V), zero(Partials{N,V}))
374 @inline Random.rand(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(rng, V), zero(Partials{N,V}))
375 @inline Random.randn(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(V), zero(Partials{N,V}))
376 @inline Random.randn(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(rng, V), zero(Partials{N,V}))
377 @inline Random.randexp(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(V), zero(Partials{N,V}))
378 @inline Random.randexp(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(rng, V), zero(Partials{N,V}))
379
380 # Predicates #
381 #------------#
382
383 isconstant(d::Dual) = iszero(partials(d))
384
385 for pred in UNARY_PREDICATES
386 @eval Base.$(pred)(d::Dual) = $(pred)(value(d))
387 end
388
389 for pred in BINARY_PREDICATES
390 @eval begin
391 @define_binary_dual_op(
392 Base.$(pred),
393 $(pred)(value(x), value(y)),
394 $(pred)(value(x), y),
395 $(pred)(x, value(y))
396 )
397 end
398 end
399
400 ########################
401 # Promotion/Conversion #
402 ########################
403
404 function Base.promote_rule(::Type{Dual{T1,V1,N1}},
405 ::Type{Dual{T2,V2,N2}}) where {T1,V1,N1,T2,V2,N2}
406 # V1 and V2 might themselves be Dual types
407 if T2 ≺ T1
408 Dual{T1,promote_type(V1,Dual{T2,V2,N2}),N1}
409 else
410 Dual{T2,promote_type(V2,Dual{T1,V1,N1}),N2}
411 end
412 end
413
414 function Base.promote_rule(::Type{Dual{T,A,N}},
415 ::Type{Dual{T,B,N}}) where {T,A,B,N}
416 return Dual{T,promote_type(A, B),N}
417 end
418
419 for R in (Irrational, Real, BigFloat, Bool)
420 if isconcretetype(R) # issue #322
421 @eval begin
422 Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N}
423 Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V,N} = Dual{T,promote_type(V, $R),N}
424 end
425 else
426 @eval begin
427 Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N}
428 Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V,N,R<:$R} = Dual{T,promote_type(V, R),N}
429 end
430 end
431 end
432
433 @inline Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V,N} = Dual{T}(V(value(d)), convert(Partials{N,V}, partials(d)))
434 @inline Base.convert(::Type{Dual{T,Dual{T,V,M},N}}, d::Dual{T,V,M}) where {T,V,N,M} = Dual{T}(d, Partials{N,Dual{T,V,M}}(zero_tuple(NTuple{N,Dual{T,V,M}})))
435 @inline Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V}))
436 @inline Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V}))
437 Base.convert(::Type{D}, d::D) where {D<:Dual} = d
438
439 Base.float(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,float(V),N}
440 Base.float(d::Dual) = convert(float(typeof(d)), d)
441
442 ###################################
443 # General Mathematical Operations #
444 ###################################
445
446 for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
447 if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-), (:Base, :sin), (:Base, :cos))
448 continue # Skip methods which we define elsewhere.
449 elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
450 continue # Skip rules for methods not defined in the current scope
451 end
452 if arity == 1
453 eval(unary_dual_definition(M, f))
454 elseif arity == 2
455 eval(binary_dual_definition(M, f))
456 else
457 # error("ForwardDiff currently only knows how to autogenerate Dual definitions for unary and binary functions.")
458 # However, the presence of N-ary rules need not cause any problems here, they can simply be ignored.
459 end
460 end
461
462 #################
463 # Special Cases #
464 #################
465
466 # +/- #
467 #-----#
468
469 @define_binary_dual_op(
470 Base.:+,
471 begin
472 vx, vy = value(x), value(y)
473 3 (1 %)
3 (1 %) samples spent in +
2 (67 %) (incl.) when called from + line 587
1 (33 %) (incl.) when called from brusselator_2d_loop line 54
2 (67 %) samples spent calling +
1 (33 %) samples spent calling +
Dual{Txy}(vx + vy, partials(x) + partials(y))
474 end,
475 Dual{Tx}(value(x) + y, partials(x)),
476 Dual{Ty}(x + value(y), partials(y))
477 )
478
479 @define_binary_dual_op(
480 Base.:-,
481 begin
482 vx, vy = value(x), value(y)
483 4 (1 %)
4 (1 %) samples spent in -
4 (100 %) (incl.) when called from brusselator_2d_loop line 54
4 (100 %) samples spent calling -
Dual{Txy}(vx - vy, partials(x) - partials(y))
484 end,
485 Dual{Tx}(value(x) - y, partials(x)),
486 Dual{Ty}(x - value(y), -partials(y))
487 )
488
489 @inline Base.:-(d::Dual{T}) where {T} = Dual{T}(-value(d), -partials(d))
490
491 # * #
492 #---#
493
494 @inline Base.:*(d::Dual, x::Bool) = x ? d : (signbit(value(d))==0 ? zero(d) : -zero(d))
495 @inline Base.:*(x::Bool, d::Dual) = d * x
496
497 # / #
498 #---#
499
500 # We can't use the normal diffrule autogeneration for this because (x/y) === (x * (1/y))
501 # doesn't generally hold true for floating point; see issue #264
502 @define_binary_dual_op(
503 Base.:/,
504 begin
505 vx, vy = value(x), value(y)
506 Dual{Txy}(vx / vy, _div_partials(partials(x), partials(y), vx, vy))
507 end,
508 Dual{Tx}(value(x) / y, partials(x) / y),
509 begin
510 v = value(y)
511 divv = x / v
512 Dual{Ty}(divv, -(divv / v) * partials(y))
513 end
514 )
515
516 # exponentiation #
517 #----------------#
518
519 for f in (:(Base.:^), :(NaNMath.pow))
520 @eval begin
521 @define_binary_dual_op(
522 $f,
523 begin
524 vx, vy = value(x), value(y)
525 expv = ($f)(vx, vy)
526 powval = vy * ($f)(vx, vy - 1)
527 if isconstant(y)
528 logval = one(expv)
529 elseif iszero(vx) && vy > 0
530 logval = zero(vx)
531 else
532 logval = expv * log(vx)
533 end
534 new_partials = _mul_partials(partials(x), partials(y), powval, logval)
535 return Dual{Txy}(expv, new_partials)
536 end,
537 begin
538 v = value(x)
539 expv = ($f)(v, y)
540 if y == zero(y) || iszero(partials(x))
541 new_partials = zero(partials(x))
542 else
543 new_partials = partials(x) * y * ($f)(v, y - 1)
544 end
545 return Dual{Tx}(expv, new_partials)
546 end,
547 begin
548 v = value(y)
549 expv = ($f)(x, v)
550 deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
551 return Dual{Ty}(expv, deriv * partials(y))
552 end
553 )
554 end
555 end
556
557 @inline Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{0}) where {T} =
558 Dual{T}(one(value(x)), zero(partials(x)))
559
560 for y in 1:3
561 @eval @inline function Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{$y}) where {T}
562 v = value(x)
563 expv = v^$y
564 deriv = $y * v^$(y - 1)
565 1 (0 %)
1 (0 %) samples spent in literal_pow
1 (100 %) (incl.) when called from brusselator_2d_loop line 49
1 (100 %) samples spent calling *
return Dual{T}(expv, deriv * partials(x))
566 end
567 end
568
569 # hypot #
570 #-------#
571
572 @inline function calc_hypot(x, y, z, ::Type{T}) where T
573 vx = value(x)
574 vy = value(y)
575 vz = value(z)
576 h = hypot(vx, vy, vz)
577 p = (vx / h) * partials(x) + (vy / h) * partials(y) + (vz / h) * partials(z)
578 return Dual{T}(h, p)
579 end
580
581 @define_ternary_dual_op(
582 Base.hypot,
583 calc_hypot(x, y, z, Txyz),
584 calc_hypot(x, y, z, Txy),
585 calc_hypot(x, y, z, Txz),
586 calc_hypot(x, y, z, Tyz),
587 calc_hypot(x, y, z, Tx),
588 calc_hypot(x, y, z, Ty),
589 calc_hypot(x, y, z, Tz),
590 )
591
592 # fma #
593 #-----#
594
595 @generated function calc_fma_xyz(x::Dual{T,<:Any,N},
596 y::Dual{T,<:Any,N},
597 z::Dual{T,<:Any,N}) where {T,N}
598 ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
599 return quote
600 $(Expr(:meta, :inline))
601 v = fma(value(x), value(y), value(z))
602 return Dual{T}(v, $ex)
603 end
604 end
605
606 @inline function calc_fma_xy(x::Dual{T}, y::Dual{T}, z::Real) where T
607 vx, vy = value(x), value(y)
608 result = fma(vx, vy, z)
609 return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx))
610 end
611
612 @generated function calc_fma_xz(x::Dual{T,<:Any,N},
613 y::Real,
614 z::Dual{T,<:Any,N}) where {T,N}
615 ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...)
616 return quote
617 $(Expr(:meta, :inline))
618 v = fma(value(x), y, value(z))
619 Dual{T}(v, $ex)
620 end
621 end
622
623 @define_ternary_dual_op(
624 Base.fma,
625 calc_fma_xyz(x, y, z), # xyz_body
626 calc_fma_xy(x, y, z), # xy_body
627 calc_fma_xz(x, y, z), # xz_body
628 Base.fma(y, x, z), # yz_body
629 Dual{Tx}(fma(value(x), y, z), partials(x) * y), # x_body
630 Base.fma(y, x, z), # y_body
631 Dual{Tz}(fma(x, y, value(z)), partials(z)) # z_body
632 )
633
634 # muladd #
635 #--------#
636
637 @generated function calc_muladd_xyz(x::Dual{T,<:Any,N},
638 y::Dual{T,<:Any,N},
639 z::Dual{T,<:Any,N}) where {T,N}
640 ex = Expr(:tuple, [:(muladd(value(x), partials(y)[$i], muladd(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
641 return quote
642 $(Expr(:meta, :inline))
643 v = muladd(value(x), value(y), value(z))
644 return Dual{T}(v, $ex)
645 end
646 end
647
648 @inline function calc_muladd_xy(x::Dual{T}, y::Dual{T}, z::Real) where T
649 vx, vy = value(x), value(y)
650 result = muladd(vx, vy, z)
651 return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx))
652 end
653
654 @generated function calc_muladd_xz(x::Dual{T,<:Any,N},
655 y::Real,
656 z::Dual{T,<:Any,N}) where {T,N}
657 ex = Expr(:tuple, [:(muladd(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...)
658 return quote
659 $(Expr(:meta, :inline))
660 v = muladd(value(x), y, value(z))
661 Dual{T}(v, $ex)
662 end
663 end
664
665 @define_ternary_dual_op(
666 Base.muladd,
667 calc_muladd_xyz(x, y, z), # xyz_body
668 calc_muladd_xy(x, y, z), # xy_body
669 calc_muladd_xz(x, y, z), # xz_body
670 Base.muladd(y, x, z), # yz_body
671 Dual{Tx}(muladd(value(x), y, z), partials(x) * y), # x_body
672 Base.muladd(y, x, z), # y_body
673 Dual{Tz}(muladd(x, y, value(z)), partials(z)) # z_body
674 )
675
676 # sin/cos #
677 #--------#
678 function Base.sin(d::Dual{T}) where T
679 s, c = sincos(value(d))
680 return Dual{T}(s, c * partials(d))
681 end
682
683 function Base.cos(d::Dual{T}) where T
684 s, c = sincos(value(d))
685 return Dual{T}(c, -s * partials(d))
686 end
687
688 @inline function Base.sincos(d::Dual{T}) where T
689 sd, cd = sincos(value(d))
690 return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, -sd * partials(d)))
691 end
692
693 # sincospi #
694 #----------#
695
696 if VERSION >= v"1.6.0-DEV.292"
697 @inline function Base.sincospi(d::Dual{T}) where T
698 sd, cd = sincospi(value(d))
699 return (Dual{T}(sd, cd * π * partials(d)), Dual{T}(cd, -sd * π * partials(d)))
700 end
701 end
702
703 # Symmetric eigvals #
704 #-------------------#
705
706 function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
707 λ,Q = eigen(Symmetric(value.(parent(A))))
708 parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
709 Dual{Tg}.(λ, tuple.(parts...))
710 end
711
712 function LinearAlgebra.eigvals(A::Hermitian{<:Complex{<:Dual{Tg,T,N}}}) where {Tg,T<:Real,N}
713 λ,Q = eigen(Hermitian(value.(real.(parent(A))) .+ im .* value.(imag.(parent(A)))))
714 parts = ntuple(j -> diag(real.(Q' * (getindex.(partials.(real.(A)) .+ im .* partials.(imag.(A)), j)) * Q)), N)
715 Dual{Tg}.(λ, tuple.(parts...))
716 end
717
718 function LinearAlgebra.eigvals(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
719 λ,Q = eigen(SymTridiagonal(value.(parent(A).dv),value.(parent(A).ev)))
720 parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
721 Dual{Tg}.(λ, tuple.(parts...))
722 end
723
724 # A ./ (λ - λ') but with diag special cased
725 function _lyap_div!(A, λ)
726 for (j,μ) in enumerate(λ), (k,λ) in enumerate(λ)
727 if k ≠ j
728 A[k,j] /= μ - λ
729 end
730 end
731 A
732 end
733
734 function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
735 λ = eigvals(A)
736 _,Q = eigen(Symmetric(value.(parent(A))))
737 parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
738 Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
739 end
740
741 function LinearAlgebra.eigen(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
742 λ = eigvals(A)
743 _,Q = eigen(SymTridiagonal(value.(parent(A))))
744 parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
745 Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
746 end
747
748 # Functions in SpecialFunctions which return tuples #
749 # Their derivatives are not defined in DiffRules #
750 #---------------------------------------------------#
751
752 function SpecialFunctions.logabsgamma(d::Dual{T,<:Real}) where {T}
753 x = value(d)
754 y, s = SpecialFunctions.logabsgamma(x)
755 return (Dual{T}(y, SpecialFunctions.digamma(x) * partials(d)), s)
756 end
757
758 # Derivatives wrt to first parameter and precision setting are not supported
759 function SpecialFunctions.gamma_inc(a::Real, d::Dual{T,<:Real}, ind::Integer) where {T}
760 x = value(d)
761 p, q = SpecialFunctions.gamma_inc(a, x, ind)
762 ∂p = exp(-x) * x^(a - 1) / SpecialFunctions.gamma(a) * partials(d)
763 return (Dual{T}(p, ∂p), Dual{T}(q, -∂p))
764 end
765
766 ###################
767 # Pretty Printing #
768 ###################
769
770 function Base.show(io::IO, d::Dual{T,V,N}) where {T,V,N}
771 print(io, "Dual{$(repr(T))}(", value(d))
772 for i in 1:N
773 print(io, ",", partials(d, i))
774 end
775 print(io, ")")
776 end
777
778 for op in (:(Base.typemin), :(Base.typemax), :(Base.floatmin), :(Base.floatmax))
779 @eval function $op(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
780 ForwardDiff.Dual{T,V,N}($op(V))
781 end
782 end
783
784 if VERSION >= v"1.6.0-rc1"
785 Printf.tofloat(d::Dual) = Printf.tofloat(value(d))
786 end