Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add realdot and imagdot from ChainRules #474

Closed
wants to merge 15 commits into from
Closed
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.7.1"
version = "1.8.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
5 changes: 4 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ Private = false
## Rule Definition Tools
```@autodocs
Modules = [ChainRulesCore]
Pages = ["rule_definition_tools.jl"]
Pages = [
"rule_definition_tools.jl",
"utils.jl",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe they should be moved to src/rule_definition_tools.jl?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That file is full of metaprogramming stuff.
It doesn't need this.

Files called utils.jl are always a mistake.
I think something like complex_math.jl as it currently stands,
or additional_mathematical_functions.jl if we think we will need more of them later.

]
Private = false
```

Expand Down
3 changes: 3 additions & 0 deletions docs/src/complex.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,6 @@ end
There are various notions of complex derivatives (holomorphic and Wirtinger derivatives, Jacobians, gradients, etc.) which differ in subtle but important ways.
The goal of ChainRules is to provide the basic differentiation rules upon which these derivatives can be implemented, but it does not implement these derivatives itself.
It is recommended that you carefully check how the above definitions of `frule` and `rrule` translate into your specific notion of complex derivative, since getting this wrong will quietly give you wrong results.

!!! note
If you implement `rrule` for a non-holomorphic function, [`realdot`](@ref) and [`imagdot`](@ref) can be useful.
3 changes: 3 additions & 0 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ export add!! # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# differentials
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
# helpers for rules with complex numbers
export realdot, imagdot

include("compat.jl")
include("debug_mode.jl")
Expand All @@ -34,6 +36,7 @@ include("config.jl")
include("rules.jl")
include("rule_definition_tools.jl")
include("ignore_derivatives.jl")
include("utils.jl")

include("deprecated.jl")

Expand Down
34 changes: 34 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
realdot(x, y)

Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible.

This function can be useful if you implement a `rrule` for a non-holomorphic function
on complex numbers.

See also: [`imagdot`](@ref)
"""
@inline realdot(x, y) = real(dot(x, y))
@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y))
@inline realdot(x::Real, y::Complex) = x * real(y)
@inline realdot(x::Complex, y::Real) = real(x) * y
@inline realdot(x::Real, y::Real) = x * y
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be safe to generalize the cases, and then the last one is redundant since real(::Real) is a no-op

Suggested change
@inline realdot(x::Real, y::Complex) = x * real(y)
@inline realdot(x::Complex, y::Real) = real(x) * y
@inline realdot(x::Real, y::Real) = x * y
@inline realdot(x::Real, y::Number) = x * real(y)
@inline realdot(x::Number, y::Real) = real(x) * y

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this introduce an ambiguity, when both are real?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right.


"""
imagdot(x, y)

Compute `imag(dot(x, y))` while avoiding computing the real part if possible.

This function can be useful if you implement a `rrule` for a non-holomorphic function
on complex numbers.

See also: [`realdot`](@ref)
"""
@inline imagdot(x, y) = imag(dot(x, y))
@inline function imagdot(x::Complex, y::Complex)
return muladd(-imag(x), real(y), real(x) * imag(y))
end
@inline imagdot(x::Real, y::Complex) = x * imag(y)
@inline imagdot(x::Complex, y::Real) = -imag(x) * y
@inline imagdot(x::Real, y::Real) = ZeroTangent()
@inline imagdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = ZeroTangent()
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Test
include("rule_definition_tools.jl")
include("config.jl")
include("ignore_derivatives.jl")
include("utils.jl")

include("deprecated.jl")
end
37 changes: 37 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# struct need to be defined outside of tests for julia 1.0 compat
# custom complex number to test fallback definition
struct CustomComplex{T}
re::T
im::T
end

Base.real(x::CustomComplex) = x.re
Base.imag(x::CustomComplex) = x.im

function LinearAlgebra.dot(a::CustomComplex, b::Number)
return CustomComplex(reim((a.re - a.im * im) * b)...)
end
function LinearAlgebra.dot(a::Number, b::CustomComplex)
return CustomComplex(reim(conj(a) * (b.re + b.im * im))...)
end
function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex)
return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...)
end

@testset "utils.jl" begin
@testset "dot" begin
scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...))
arrays = (randn(10), randn(ComplexF64, 10))
for inputs in (scalars, arrays)
for x in inputs, y in inputs
@test realdot(x, y) == real(dot(x, y))

if eltype(x) <: Real && eltype(y) <: Real
@test imagdot(x, y) === ZeroTangent()
else
@test imagdot(x, y) == imag(dot(x, y))
end
end
end
end
end