From 537fbad0e37a6e5efe18dd0a2c1cb0c0a2425c25 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 30 Jun 2020 08:00:40 -0700 Subject: [PATCH 01/18] Add missing complex tests and rules (#216) * Fix indentation * Test \ on complex inputs * Test ^ on complex inputs * Test identity on complex inputs * Test muladd on complex inputs * Test binary functions on complex inputs * Test functions on complex inputs * Release type constraint on exp * Add _realconjtimes * Use _realconjtimes in abs/abs2 rules * Add complex rule for hypot * Add generic rule for adjoint * Add generic rule for real * Add generic rule for imag * Add complex rule for hypot * Add rules/tests for Complex * Test frule for identity * Add missing angle test * Make inline just in case * Unify abs rules * Introduce _imagconjtimes utility function * Unify angle rules * Unify sign rules * Multiply by correct variable * Fix argument order * Bump ChainRulesTestUtils version number * Restrict to Complex * Use muladd * Update src/rulesets/Base/fastmath_able.jl Co-authored-by: willtebbutt Co-authored-by: willtebbutt --- src/utils.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 src/utils.jl diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..d957d55 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,15 @@ +# real(conj(x) * y) avoiding computing the imaginary part if possible +@inline _realconjtimes(x, y) = real(conj(x) * y) +@inline _realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline _realconjtimes(x::Real, y::Complex) = x * real(y) +@inline _realconjtimes(x::Complex, y::Real) = real(x) * y +@inline _realconjtimes(x::Real, y::Real) = x * y + +# imag(conj(x) * y) avoiding computing the real part if possible +@inline _imagconjtimes(x, y) = imag(conj(x) * y) +@inline function _imagconjtimes(x::Complex, y::Complex) + return muladd(-imag(x), real(y), real(x) * imag(y)) +end +@inline _imagconjtimes(x::Real, y::Complex) = x * imag(y) +@inline _imagconjtimes(x::Complex, y::Real) = -imag(x) * y +@inline _imagconjtimes(x::Real, y::Real) = Zero() From 4b08ee2f1de30c28aa84ab9da08bf0d625ba5438 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 19:32:33 +0100 Subject: [PATCH 02/18] rename differentials (#413) * rename DoesNotExist * rename Composite * bump version and compat * rename Zero * remove typos * reexport deprecated types manually --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index d957d55..30bbdd4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,4 +12,4 @@ end @inline _imagconjtimes(x::Real, y::Complex) = x * imag(y) @inline _imagconjtimes(x::Complex, y::Real) = -imag(x) * y -@inline _imagconjtimes(x::Real, y::Real) = Zero() +@inline _imagconjtimes(x::Real, y::Real) = ZeroTangent() From 15ae11ec588f12e7cc2477f7ee87647f7d840881 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 1 Oct 2021 16:54:01 +0200 Subject: [PATCH 03/18] Rename to `realconjtimes` and `imagconjtimes` and export them --- src/utils.jl | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 30bbdd4..06d6a4a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,15 +1,33 @@ -# real(conj(x) * y) avoiding computing the imaginary part if possible -@inline _realconjtimes(x, y) = real(conj(x) * y) -@inline _realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) -@inline _realconjtimes(x::Real, y::Complex) = x * real(y) -@inline _realconjtimes(x::Complex, y::Real) = real(x) * y -@inline _realconjtimes(x::Real, y::Real) = x * y - -# imag(conj(x) * y) avoiding computing the real part if possible -@inline _imagconjtimes(x, y) = imag(conj(x) * y) -@inline function _imagconjtimes(x::Complex, y::Complex) +""" + realconjtimes(x, y) + +Compute `real(conj(x) * y)` while avoiding computing the imaginary part if possible. + +This function can be useful if you implement a `rrule` for a non-holomorphic function +on complex numbers. + +See also: [`imagconjtimes`](@ref) +""" +@inline realconjtimes(x, y) = real(conj(x) * y) +@inline realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline realconjtimes(x::Real, y::Complex) = x * real(y) +@inline realconjtimes(x::Complex, y::Real) = real(x) * y +@inline realconjtimes(x::Real, y::Real) = x * y + +""" + imagconjtimes(x, y) + +Compute `imag(conj(x) * y)` while avoiding computing the real part if possible. + +This function can be useful if you implement a `rrule` for a non-holomorphic function +on complex numbers. + +See also: [`realconjtimes`](@ref) +""" +@inline imagconjtimes(x, y) = imag(conj(x) * y) +@inline function imagconjtimes(x::Complex, y::Complex) return muladd(-imag(x), real(y), real(x) * imag(y)) end -@inline _imagconjtimes(x::Real, y::Complex) = x * imag(y) -@inline _imagconjtimes(x::Complex, y::Real) = -imag(x) * y -@inline _imagconjtimes(x::Real, y::Real) = ZeroTangent() +@inline imagconjtimes(x::Real, y::Complex) = x * imag(y) +@inline imagconjtimes(x::Complex, y::Real) = -imag(x) * y +@inline imagconjtimes(x::Real, y::Real) = ZeroTangent() From 262f9a92edd264d570199560d0e8354527845efc Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 1 Oct 2021 16:54:36 +0200 Subject: [PATCH 04/18] Add tests --- test/utils.jl | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 test/utils.jl diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..a3e4f9e --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,31 @@ +@testset "utils.jl" begin + @testset "conjtimes" begin + # custom complex number to test fallback definition + struct CustomComplex{T} + re::T + im::T + end + + Base.real(x::CustomComplex) = x.re + Base.imag(x::CustomComplex) = x.im + + Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im) + + Base.:*(a::CustomComplex, b::Number) = CustomComplex(reim((a.re + a.im * im) * b)...) + Base.:*(a::Number, b::CustomComplex) = b * a + function Base.:*(a::CustomComplex, b::CustomComplex) + return CustomComplex(reim((a.re + a.im * im) * (b.re + b.im * im))...) + end + + inputs = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) + for x in inputs, y in inputs + @test realconjtimes(x, y) == real(conj(x) * y) + + if x isa Real && y isa Real + @test imagconjtimes(x, y) === ZeroTangent() + else + @test imagconjtimes(x, y) == imag(conj(x) * y) + end + end + end +end From 8a750f2cdf3151a3a865aa62a30d5c3f4760a5c1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 1 Oct 2021 17:12:26 +0200 Subject: [PATCH 05/18] Fix tests with Julia 1.0 --- test/utils.jl | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index a3e4f9e..55a5dbb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,22 +1,23 @@ -@testset "utils.jl" begin - @testset "conjtimes" begin - # custom complex number to test fallback definition - struct CustomComplex{T} - re::T - im::T - end +# struct need to be defined outside of tests for julia 1.0 compat +# custom complex number to test fallback definition +struct CustomComplex{T} + re::T + im::T +end - Base.real(x::CustomComplex) = x.re - Base.imag(x::CustomComplex) = x.im +Base.real(x::CustomComplex) = x.re +Base.imag(x::CustomComplex) = x.im - Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im) +Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im) - Base.:*(a::CustomComplex, b::Number) = CustomComplex(reim((a.re + a.im * im) * b)...) - Base.:*(a::Number, b::CustomComplex) = b * a - function Base.:*(a::CustomComplex, b::CustomComplex) - return CustomComplex(reim((a.re + a.im * im) * (b.re + b.im * im))...) - end +Base.:*(a::CustomComplex, b::Number) = CustomComplex(reim((a.re + a.im * im) * b)...) +Base.:*(a::Number, b::CustomComplex) = b * a +function Base.:*(a::CustomComplex, b::CustomComplex) + return CustomComplex(reim((a.re + a.im * im) * (b.re + b.im * im))...) +end +@testset "utils.jl" begin + @testset "conjtimes" begin inputs = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) for x in inputs, y in inputs @test realconjtimes(x, y) == real(conj(x) * y) From b91d03d69a3ea1eb49a821d0b9f88490223662ac Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 1 Oct 2021 23:23:41 +0200 Subject: [PATCH 06/18] Rename to `realdot` and `imagdot` --- src/utils.jl | 33 +++++++++++++++++---------------- test/utils.jl | 33 +++++++++++++++++++-------------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 06d6a4a..d97eaaf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,33 +1,34 @@ """ - realconjtimes(x, y) + realdot(x, y) -Compute `real(conj(x) * y)` while avoiding computing the imaginary part if possible. +Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible. This function can be useful if you implement a `rrule` for a non-holomorphic function on complex numbers. -See also: [`imagconjtimes`](@ref) +See also: [`imagdot`](@ref) """ -@inline realconjtimes(x, y) = real(conj(x) * y) -@inline realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) -@inline realconjtimes(x::Real, y::Complex) = x * real(y) -@inline realconjtimes(x::Complex, y::Real) = real(x) * y -@inline realconjtimes(x::Real, y::Real) = x * y +@inline realdot(x, y) = real(dot(x, y)) +@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline realdot(x::Real, y::Complex) = x * real(y) +@inline realdot(x::Complex, y::Real) = real(x) * y +@inline realdot(x::Real, y::Real) = x * y """ - imagconjtimes(x, y) + imagdot(x, y) -Compute `imag(conj(x) * y)` while avoiding computing the real part if possible. +Compute `imag(dot(x, y))` while avoiding computing the real part if possible. This function can be useful if you implement a `rrule` for a non-holomorphic function on complex numbers. -See also: [`realconjtimes`](@ref) +See also: [`realdot`](@ref) """ -@inline imagconjtimes(x, y) = imag(conj(x) * y) -@inline function imagconjtimes(x::Complex, y::Complex) +@inline imagdot(x, y) = imag(dot(x, y)) +@inline function imagdot(x::Complex, y::Complex) return muladd(-imag(x), real(y), real(x) * imag(y)) end -@inline imagconjtimes(x::Real, y::Complex) = x * imag(y) -@inline imagconjtimes(x::Complex, y::Real) = -imag(x) * y -@inline imagconjtimes(x::Real, y::Real) = ZeroTangent() +@inline imagdot(x::Real, y::Complex) = x * imag(y) +@inline imagdot(x::Complex, y::Real) = -imag(x) * y +@inline imagdot(x::Real, y::Real) = ZeroTangent() +@inline imagdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = ZeroTangent() diff --git a/test/utils.jl b/test/utils.jl index 55a5dbb..0cd1535 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -8,24 +8,29 @@ end Base.real(x::CustomComplex) = x.re Base.imag(x::CustomComplex) = x.im -Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im) - -Base.:*(a::CustomComplex, b::Number) = CustomComplex(reim((a.re + a.im * im) * b)...) -Base.:*(a::Number, b::CustomComplex) = b * a -function Base.:*(a::CustomComplex, b::CustomComplex) - return CustomComplex(reim((a.re + a.im * im) * (b.re + b.im * im))...) +function LinearAlgebra.dot(a::CustomComplex, b::Number) + return CustomComplex(reim((a.re - a.im * im) * b)...) +end +function LinearAlgebra.dot(a::Number, b::CustomComplex) + return CustomComplex(reim(conj(a) * (b.re + b.im * im))...) +end +function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex) + return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...) end @testset "utils.jl" begin - @testset "conjtimes" begin - inputs = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) - for x in inputs, y in inputs - @test realconjtimes(x, y) == real(conj(x) * y) + @testset "dot" begin + scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) + arrays = (randn(10), randn(ComplexF64, 10)) + for inputs in (scalars, arrays) + for x in inputs, y in inputs + @test realdot(x, y) == real(dot(x, y)) - if x isa Real && y isa Real - @test imagconjtimes(x, y) === ZeroTangent() - else - @test imagconjtimes(x, y) == imag(conj(x) * y) + if eltype(x) <: Real && eltype(y) <: Real + @test imagdot(x, y) === ZeroTangent() + else + @test imagdot(x, y) == imag(dot(x, y)) + end end end end From 360dcb99db66a0858fc128909bbc21fa0aaaf409 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 2 Oct 2021 01:08:16 +0200 Subject: [PATCH 07/18] Add dispatch for real arrays --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index d97eaaf..601e9d2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,6 +13,7 @@ See also: [`imagdot`](@ref) @inline realdot(x::Real, y::Complex) = x * real(y) @inline realdot(x::Complex, y::Real) = real(x) * y @inline realdot(x::Real, y::Real) = x * y +@inline realdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = dot(x, y) """ imagdot(x, y) From d6ba3bea9946b90ebf13cd84603a736b903a100b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 11 Oct 2021 20:52:52 +0200 Subject: [PATCH 08/18] Update src/utils.jl Co-authored-by: Seth Axen --- src/utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 601e9d2..d97eaaf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,7 +13,6 @@ See also: [`imagdot`](@ref) @inline realdot(x::Real, y::Complex) = x * real(y) @inline realdot(x::Complex, y::Real) = real(x) * y @inline realdot(x::Real, y::Real) = x * y -@inline realdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = dot(x, y) """ imagdot(x, y) From 25bbc2acb6e2e600dcd9bdb2833979f1afdbb322 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 11 Oct 2021 22:40:06 +0200 Subject: [PATCH 09/18] Generalize `::Complex` to `::Number` --- src/utils.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index d97eaaf..0399120 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,9 +9,9 @@ on complex numbers. See also: [`imagdot`](@ref) """ @inline realdot(x, y) = real(dot(x, y)) -@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) -@inline realdot(x::Real, y::Complex) = x * real(y) -@inline realdot(x::Complex, y::Real) = real(x) * y +@inline realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline realdot(x::Real, y::Number) = x * real(y) +@inline realdot(x::Number, y::Real) = real(x) * y @inline realdot(x::Real, y::Real) = x * y """ @@ -25,10 +25,10 @@ on complex numbers. See also: [`realdot`](@ref) """ @inline imagdot(x, y) = imag(dot(x, y)) -@inline function imagdot(x::Complex, y::Complex) +@inline function imagdot(x::Number, y::Number) return muladd(-imag(x), real(y), real(x) * imag(y)) end -@inline imagdot(x::Real, y::Complex) = x * imag(y) -@inline imagdot(x::Complex, y::Real) = -imag(x) * y +@inline imagdot(x::Real, y::Number) = x * imag(y) +@inline imagdot(x::Number, y::Real) = -imag(x) * y @inline imagdot(x::Real, y::Real) = ZeroTangent() @inline imagdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = ZeroTangent() From 638a6f5d285dea02fbea7654421bcf6d29073f65 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 11 Oct 2021 22:41:56 +0200 Subject: [PATCH 10/18] Rename `utils.jl` to `complex_math.jl` --- src/{utils.jl => complex_math.jl} | 0 test/{utils.jl => complex_math.jl} | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/{utils.jl => complex_math.jl} (100%) rename test/{utils.jl => complex_math.jl} (97%) diff --git a/src/utils.jl b/src/complex_math.jl similarity index 100% rename from src/utils.jl rename to src/complex_math.jl diff --git a/test/utils.jl b/test/complex_math.jl similarity index 97% rename from test/utils.jl rename to test/complex_math.jl index 0cd1535..6963c0e 100644 --- a/test/utils.jl +++ b/test/complex_math.jl @@ -18,7 +18,7 @@ function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex) return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...) end -@testset "utils.jl" begin +@testset "complex_math.jl" begin @testset "dot" begin scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) arrays = (randn(10), randn(ComplexF64, 10)) From 96b44e94d9deff068fb744643de4155c392a9783 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 16 Oct 2021 16:00:25 +0200 Subject: [PATCH 11/18] Remove `imagdot` --- src/complex_math.jl | 26 +++----------------------- test/complex_math.jl | 8 +------- 2 files changed, 4 insertions(+), 30 deletions(-) diff --git a/src/complex_math.jl b/src/complex_math.jl index 0399120..7ea003e 100644 --- a/src/complex_math.jl +++ b/src/complex_math.jl @@ -3,32 +3,12 @@ Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible. -This function can be useful if you implement a `rrule` for a non-holomorphic function -on complex numbers. - -See also: [`imagdot`](@ref) +This function can be useful if you work with derivatives of functions on complex +numbers. In particular, this computation shows up in pullbacks for non-holomorphic +functions. """ @inline realdot(x, y) = real(dot(x, y)) @inline realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y)) @inline realdot(x::Real, y::Number) = x * real(y) @inline realdot(x::Number, y::Real) = real(x) * y @inline realdot(x::Real, y::Real) = x * y - -""" - imagdot(x, y) - -Compute `imag(dot(x, y))` while avoiding computing the real part if possible. - -This function can be useful if you implement a `rrule` for a non-holomorphic function -on complex numbers. - -See also: [`realdot`](@ref) -""" -@inline imagdot(x, y) = imag(dot(x, y)) -@inline function imagdot(x::Number, y::Number) - return muladd(-imag(x), real(y), real(x) * imag(y)) -end -@inline imagdot(x::Real, y::Number) = x * imag(y) -@inline imagdot(x::Number, y::Real) = -imag(x) * y -@inline imagdot(x::Real, y::Real) = ZeroTangent() -@inline imagdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = ZeroTangent() diff --git a/test/complex_math.jl b/test/complex_math.jl index 6963c0e..1dc165c 100644 --- a/test/complex_math.jl +++ b/test/complex_math.jl @@ -19,18 +19,12 @@ function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex) end @testset "complex_math.jl" begin - @testset "dot" begin + @testset "realdot" begin scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) arrays = (randn(10), randn(ComplexF64, 10)) for inputs in (scalars, arrays) for x in inputs, y in inputs @test realdot(x, y) == real(dot(x, y)) - - if eltype(x) <: Real && eltype(y) <: Real - @test imagdot(x, y) === ZeroTangent() - else - @test imagdot(x, y) == imag(dot(x, y)) - end end end end From ce9c71eaa29914b9b7e5f7088123ed6d7dd30fce Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 16 Oct 2021 22:20:56 +0200 Subject: [PATCH 12/18] Add `realdot` --- Project.toml | 3 +++ src/RealDot.jl | 19 ++++++++++++++++++- src/complex_math.jl | 14 -------------- test/complex_math.jl | 31 ------------------------------- test/runtests.jl | 30 +++++++++++++++++++++++++++++- 5 files changed, 50 insertions(+), 47 deletions(-) delete mode 100644 src/complex_math.jl delete mode 100644 test/complex_math.jl diff --git a/Project.toml b/Project.toml index 181190c..f524ed8 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,9 @@ uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" authors = ["David Widmann"] version = "0.1.0" +[deps] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + [compat] julia = "1" diff --git a/src/RealDot.jl b/src/RealDot.jl index d9d83db..b7338c2 100644 --- a/src/RealDot.jl +++ b/src/RealDot.jl @@ -1,5 +1,22 @@ module RealDot -# Write your package code here. +using LinearAlgebra: LinearAlgebra + +export realdot + +""" + realdot(x, y) + +Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible. + +This function can be useful if you work with derivatives of functions on complex +numbers. In particular, this computation shows up in pullbacks for non-holomorphic +functions. +""" +@inline realdot(x, y) = real(LinearAlgebra.dot(x, y)) +@inline realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline realdot(x::Real, y::Number) = x * real(y) +@inline realdot(x::Number, y::Real) = real(x) * y +@inline realdot(x::Real, y::Real) = x * y end diff --git a/src/complex_math.jl b/src/complex_math.jl deleted file mode 100644 index 7ea003e..0000000 --- a/src/complex_math.jl +++ /dev/null @@ -1,14 +0,0 @@ -""" - realdot(x, y) - -Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible. - -This function can be useful if you work with derivatives of functions on complex -numbers. In particular, this computation shows up in pullbacks for non-holomorphic -functions. -""" -@inline realdot(x, y) = real(dot(x, y)) -@inline realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y)) -@inline realdot(x::Real, y::Number) = x * real(y) -@inline realdot(x::Number, y::Real) = real(x) * y -@inline realdot(x::Real, y::Real) = x * y diff --git a/test/complex_math.jl b/test/complex_math.jl deleted file mode 100644 index 1dc165c..0000000 --- a/test/complex_math.jl +++ /dev/null @@ -1,31 +0,0 @@ -# struct need to be defined outside of tests for julia 1.0 compat -# custom complex number to test fallback definition -struct CustomComplex{T} - re::T - im::T -end - -Base.real(x::CustomComplex) = x.re -Base.imag(x::CustomComplex) = x.im - -function LinearAlgebra.dot(a::CustomComplex, b::Number) - return CustomComplex(reim((a.re - a.im * im) * b)...) -end -function LinearAlgebra.dot(a::Number, b::CustomComplex) - return CustomComplex(reim(conj(a) * (b.re + b.im * im))...) -end -function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex) - return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...) -end - -@testset "complex_math.jl" begin - @testset "realdot" begin - scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) - arrays = (randn(10), randn(ComplexF64, 10)) - for inputs in (scalars, arrays) - for x in inputs, y in inputs - @test realdot(x, y) == real(dot(x, y)) - end - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index cc75a0c..c10cdb8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,34 @@ using RealDot +using LinearAlgebra using Test +# struct need to be defined outside of tests for julia 1.0 compat +# custom complex number to test fallback definition +struct CustomComplex{T} + re::T + im::T +end + +Base.real(x::CustomComplex) = x.re +Base.imag(x::CustomComplex) = x.im + +function LinearAlgebra.dot(a::CustomComplex, b::Number) + return CustomComplex(reim((a.re - a.im * im) * b)...) +end +function LinearAlgebra.dot(a::Number, b::CustomComplex) + return CustomComplex(reim(conj(a) * (b.re + b.im * im))...) +end +function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex) + return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...) +end + @testset "RealDot.jl" begin - # Write your tests here. + scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) + arrays = (randn(10), randn(ComplexF64, 10)) + + for inputs in (scalars, arrays) + for x in inputs, y in inputs + @test realdot(x, y) == real(dot(x, y)) + end + end end From 5de6e9b7c55ff91fb7e0e45bfb8947bc0158fd9e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 16 Oct 2021 22:21:12 +0200 Subject: [PATCH 13/18] Update README --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index c2f89a2..7f94296 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,10 @@ [![Coverage](https://codecov.io/gh/devmotion/RealDot.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/devmotion/RealDot.jl) [![Coverage](https://coveralls.io/repos/github/devmotion/RealDot.jl/badge.svg?branch=main)](https://coveralls.io/github/devmotion/RealDot.jl?branch=main) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) + +This package only contains and exports a single function `realdot(x, y)`. It computes +`real(LinearAlgebra.dot(x, y))` while avoiding computing the imaginary part of +`LinearAlgebra.dot((x, y)` if possible. + +This function can be useful e.g. if you define pullbacks for non-holomorphic functions +(see e.g. [this discussion in the ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/pull/474)). It was implemented initially in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) in [this PR](https://github.com/JuliaDiff/ChainRules.jl/pull/216) as `_realconjtimes`. From e48caee6d10e3b2f1a9fddbb5b15a6ce178ae4ad Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 18 Oct 2021 13:16:25 +0200 Subject: [PATCH 14/18] Apply suggestions from code review Co-authored-by: Seth Axen --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7f94296..e706b66 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,13 @@ This package only contains and exports a single function `realdot(x, y)`. It computes `real(LinearAlgebra.dot(x, y))` while avoiding computing the imaginary part of -`LinearAlgebra.dot((x, y)` if possible. - +`LinearAlgebra.dot(x, y)` if possible. +The real dot product is useful when one treats complex numbers as embedded in a real vector space. +For example, take the complex arrays ``x = a + i b`` and ``y = c + i d``. Their real dot product is +`real(dot(x, y)) == dot(real(x), real(y)) + dot(imag(x), imag(y))`. This is the same result one would get by reinterpreting the arrays as real arrays: +```julia +xreal = reinterpret(real(eltype(x)), x) +yreal = reinterpret(real(eltype(y)), y) +real(dot(x, y)) == dot(xreal, yreal) This function can be useful e.g. if you define pullbacks for non-holomorphic functions (see e.g. [this discussion in the ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/pull/474)). It was implemented initially in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) in [this PR](https://github.com/JuliaDiff/ChainRules.jl/pull/216) as `_realconjtimes`. From 56186319329a7abee364cba1cd4358243cfe7c78 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 18 Oct 2021 13:20:05 +0200 Subject: [PATCH 15/18] Update README.md --- README.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index e706b66..383c148 100644 --- a/README.md +++ b/README.md @@ -5,15 +5,18 @@ [![Coverage](https://coveralls.io/repos/github/devmotion/RealDot.jl/badge.svg?branch=main)](https://coveralls.io/github/devmotion/RealDot.jl?branch=main) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -This package only contains and exports a single function `realdot(x, y)`. It computes -`real(LinearAlgebra.dot(x, y))` while avoiding computing the imaginary part of -`LinearAlgebra.dot(x, y)` if possible. +This package only contains and exports a single function `realdot(x, y)`. +It computes `real(LinearAlgebra.dot(x, y))` while avoiding computing the imaginary part of `LinearAlgebra.dot(x, y)` if possible. + The real dot product is useful when one treats complex numbers as embedded in a real vector space. -For example, take the complex arrays ``x = a + i b`` and ``y = c + i d``. Their real dot product is -`real(dot(x, y)) == dot(real(x), real(y)) + dot(imag(x), imag(y))`. This is the same result one would get by reinterpreting the arrays as real arrays: +For example, take two complex arrays `x` and `y`. +Their real dot product is `real(dot(x, y)) == dot(real(x), real(y)) + dot(imag(x), imag(y))`. +This is the same result one would get by reinterpreting the arrays as real arrays: ```julia xreal = reinterpret(real(eltype(x)), x) yreal = reinterpret(real(eltype(y)), y) real(dot(x, y)) == dot(xreal, yreal) -This function can be useful e.g. if you define pullbacks for non-holomorphic functions -(see e.g. [this discussion in the ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/pull/474)). It was implemented initially in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) in [this PR](https://github.com/JuliaDiff/ChainRules.jl/pull/216) as `_realconjtimes`. +``` + +In particular, this function can be useful if you define pullbacks for non-holomorphic functions (see e.g. [this discussion in the ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/pull/474)). +It was implemented initially in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) in [this PR](https://github.com/JuliaDiff/ChainRules.jl/pull/216) as `_realconjtimes`. From 7194ee913c2f46fff3c9022c5c18dc6cbd130ca5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 18 Oct 2021 14:13:53 +0200 Subject: [PATCH 16/18] Update src/RealDot.jl Co-authored-by: Seth Axen --- src/RealDot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/RealDot.jl b/src/RealDot.jl index b7338c2..6f286d9 100644 --- a/src/RealDot.jl +++ b/src/RealDot.jl @@ -14,7 +14,7 @@ numbers. In particular, this computation shows up in pullbacks for non-holomorph functions. """ @inline realdot(x, y) = real(LinearAlgebra.dot(x, y)) -@inline realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) @inline realdot(x::Real, y::Number) = x * real(y) @inline realdot(x::Number, y::Real) = real(x) * y @inline realdot(x::Real, y::Real) = x * y From 1ecb63d61b301c1551c5da4000d177bd1027f330 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 18 Oct 2021 16:30:08 +0200 Subject: [PATCH 17/18] Add test with quaternions --- test/runtests.jl | 47 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index c10cdb8..80b5045 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,8 +3,8 @@ using LinearAlgebra using Test # struct need to be defined outside of tests for julia 1.0 compat -# custom complex number to test fallback definition -struct CustomComplex{T} +# custom complex number (tests fallback definition) +struct CustomComplex{T} <: Number re::T im::T end @@ -12,18 +12,47 @@ end Base.real(x::CustomComplex) = x.re Base.imag(x::CustomComplex) = x.im -function LinearAlgebra.dot(a::CustomComplex, b::Number) - return CustomComplex(reim((a.re - a.im * im) * b)...) +Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im) + +function Base.:*(x::CustomComplex, y::Union{Real,Complex}) + return CustomComplex(reim(Complex(reim(x)...) * y)...) +end +Base.:*(x::Union{Real,Complex}, y::CustomComplex) = y * x +function Base.:*(x::CustomComplex, y::CustomComplex) + return CustomComplex(reim(Complex(reim(x)...) * Complex(reim(y)...))...) end -function LinearAlgebra.dot(a::Number, b::CustomComplex) - return CustomComplex(reim(conj(a) * (b.re + b.im * im))...) + +# custom quaternion to test definition for hypercomplex numbers +# adapted from Quaternions.jl +struct Quaternion{T<:Real} <: Number + s::T + v1::T + v2::T + v3::T end -function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex) - return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...) + +Base.real(q::Quaternion) = q.s +Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3) + +function Base.:*(q::Quaternion, w::Quaternion) + return Quaternion( + q.s * w.s - q.v1 * w.v1 - q.v2 * w.v2 - q.v3 * w.v3, + q.s * w.v1 + q.v1 * w.s + q.v2 * w.v3 - q.v3 * w.v2, + q.s * w.v2 - q.v1 * w.v3 + q.v2 * w.s + q.v3 * w.v1, + q.s * w.v3 + q.v1 * w.v2 - q.v2 * w.v1 + q.v3 * w.s, + ) +end + +function Base.:*(q::Quaternion, w::Union{Real,Complex,CustomComplex}) + a, b = reim(w) + return q * Quaternion(a, b, zero(a), zero(a)) end +Base.:*(q::Union{Real,Complex,CustomComplex}, w::Quaternion) = w * q @testset "RealDot.jl" begin - scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...)) + scalars = ( + randn(), randn(ComplexF64), CustomComplex(randn(2)...), Quaternion(randn(4)...) + ) arrays = (randn(10), randn(ComplexF64, 10)) for inputs in (scalars, arrays) From c6d1099e269d4e6f520753c9daaaa6f9ad484747 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 18 Oct 2021 21:44:34 +0200 Subject: [PATCH 18/18] Fix quaternion multiplication --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 80b5045..7cdb6d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,7 +47,7 @@ function Base.:*(q::Quaternion, w::Union{Real,Complex,CustomComplex}) a, b = reim(w) return q * Quaternion(a, b, zero(a), zero(a)) end -Base.:*(q::Union{Real,Complex,CustomComplex}, w::Quaternion) = w * q +Base.:*(w::Union{Real,Complex,CustomComplex}, q::Quaternion) = conj(conj(q) * conj(w)) @testset "RealDot.jl" begin scalars = (