-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add realdot
and imagdot
from ChainRules
#474
Changes from 13 commits
537fbad
4b08ee2
ace93fb
0d2269e
b1a89ec
3d05206
d6c6c2a
d9d046f
1747c42
630355e
ff5e71e
9c3cab2
5155c75
5fb7ff2
a042dd0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,34 @@ | ||||||||||||
""" | ||||||||||||
realdot(x, y) | ||||||||||||
|
||||||||||||
Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible. | ||||||||||||
|
||||||||||||
This function can be useful if you implement a `rrule` for a non-holomorphic function | ||||||||||||
on complex numbers. | ||||||||||||
|
||||||||||||
See also: [`imagdot`](@ref) | ||||||||||||
""" | ||||||||||||
@inline realdot(x, y) = real(dot(x, y)) | ||||||||||||
@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) | ||||||||||||
@inline realdot(x::Real, y::Complex) = x * real(y) | ||||||||||||
@inline realdot(x::Complex, y::Real) = real(x) * y | ||||||||||||
@inline realdot(x::Real, y::Real) = x * y | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be safe to generalize the cases, and then the last one is redundant since
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't this introduce an ambiguity, when both are real? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's right. |
||||||||||||
|
||||||||||||
""" | ||||||||||||
imagdot(x, y) | ||||||||||||
|
||||||||||||
Compute `imag(dot(x, y))` while avoiding computing the real part if possible. | ||||||||||||
|
||||||||||||
This function can be useful if you implement a `rrule` for a non-holomorphic function | ||||||||||||
on complex numbers. | ||||||||||||
|
||||||||||||
See also: [`realdot`](@ref) | ||||||||||||
""" | ||||||||||||
@inline imagdot(x, y) = imag(dot(x, y)) | ||||||||||||
@inline function imagdot(x::Complex, y::Complex) | ||||||||||||
return muladd(-imag(x), real(y), real(x) * imag(y)) | ||||||||||||
end | ||||||||||||
@inline imagdot(x::Real, y::Complex) = x * imag(y) | ||||||||||||
@inline imagdot(x::Complex, y::Real) = -imag(x) * y | ||||||||||||
@inline imagdot(x::Real, y::Real) = ZeroTangent() | ||||||||||||
@inline imagdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = ZeroTangent() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# struct need to be defined outside of tests for julia 1.0 compat | ||
# custom complex number to test fallback definition | ||
struct CustomComplex{T} | ||
re::T | ||
im::T | ||
end | ||
|
||
Base.real(x::CustomComplex) = x.re | ||
Base.imag(x::CustomComplex) = x.im | ||
|
||
function LinearAlgebra.dot(a::CustomComplex, b::Number) | ||
return CustomComplex(reim((a.re - a.im * im) * b)...) | ||
end | ||
function LinearAlgebra.dot(a::Number, b::CustomComplex) | ||
return CustomComplex(reim(conj(a) * (b.re + b.im * im))...) | ||
end | ||
function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex) | ||
return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...) | ||
end | ||
|
||
@testset "utils.jl" begin | ||
@testset "dot" begin | ||
scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) | ||
arrays = (randn(10), randn(ComplexF64, 10)) | ||
for inputs in (scalars, arrays) | ||
for x in inputs, y in inputs | ||
@test realdot(x, y) == real(dot(x, y)) | ||
|
||
if eltype(x) <: Real && eltype(y) <: Real | ||
@test imagdot(x, y) === ZeroTangent() | ||
else | ||
@test imagdot(x, y) == imag(dot(x, y)) | ||
end | ||
end | ||
end | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe they should be moved to
src/rule_definition_tools.jl
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That file is full of metaprogramming stuff.
It doesn't need this.
Files called
utils.jl
are always a mistake.I think something like
complex_math.jl
as it currently stands,or
additional_mathematical_functions.jl
if we think we will need more of them later.