diff --git a/Project.toml b/Project.toml index 181190c..f524ed8 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,9 @@ uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" authors = ["David Widmann"] version = "0.1.0" +[deps] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + [compat] julia = "1" diff --git a/README.md b/README.md index c2f89a2..383c148 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,19 @@ [![Coverage](https://codecov.io/gh/devmotion/RealDot.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/devmotion/RealDot.jl) [![Coverage](https://coveralls.io/repos/github/devmotion/RealDot.jl/badge.svg?branch=main)](https://coveralls.io/github/devmotion/RealDot.jl?branch=main) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) + +This package only contains and exports a single function `realdot(x, y)`. +It computes `real(LinearAlgebra.dot(x, y))` while avoiding computing the imaginary part of `LinearAlgebra.dot(x, y)` if possible. + +The real dot product is useful when one treats complex numbers as embedded in a real vector space. +For example, take two complex arrays `x` and `y`. +Their real dot product is `real(dot(x, y)) == dot(real(x), real(y)) + dot(imag(x), imag(y))`. +This is the same result one would get by reinterpreting the arrays as real arrays: +```julia +xreal = reinterpret(real(eltype(x)), x) +yreal = reinterpret(real(eltype(y)), y) +real(dot(x, y)) == dot(xreal, yreal) +``` + +In particular, this function can be useful if you define pullbacks for non-holomorphic functions (see e.g. [this discussion in the ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/pull/474)). +It was implemented initially in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) in [this PR](https://github.com/JuliaDiff/ChainRules.jl/pull/216) as `_realconjtimes`. diff --git a/src/RealDot.jl b/src/RealDot.jl index d9d83db..6f286d9 100644 --- a/src/RealDot.jl +++ b/src/RealDot.jl @@ -1,5 +1,22 @@ module RealDot -# Write your package code here. +using LinearAlgebra: LinearAlgebra + +export realdot + +""" + realdot(x, y) + +Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible. + +This function can be useful if you work with derivatives of functions on complex +numbers. In particular, this computation shows up in pullbacks for non-holomorphic +functions. +""" +@inline realdot(x, y) = real(LinearAlgebra.dot(x, y)) +@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline realdot(x::Real, y::Number) = x * real(y) +@inline realdot(x::Number, y::Real) = real(x) * y +@inline realdot(x::Real, y::Real) = x * y end diff --git a/test/runtests.jl b/test/runtests.jl index cc75a0c..7cdb6d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,63 @@ using RealDot +using LinearAlgebra using Test +# struct need to be defined outside of tests for julia 1.0 compat +# custom complex number (tests fallback definition) +struct CustomComplex{T} <: Number + re::T + im::T +end + +Base.real(x::CustomComplex) = x.re +Base.imag(x::CustomComplex) = x.im + +Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im) + +function Base.:*(x::CustomComplex, y::Union{Real,Complex}) + return CustomComplex(reim(Complex(reim(x)...) * y)...) +end +Base.:*(x::Union{Real,Complex}, y::CustomComplex) = y * x +function Base.:*(x::CustomComplex, y::CustomComplex) + return CustomComplex(reim(Complex(reim(x)...) * Complex(reim(y)...))...) +end + +# custom quaternion to test definition for hypercomplex numbers +# adapted from Quaternions.jl +struct Quaternion{T<:Real} <: Number + s::T + v1::T + v2::T + v3::T +end + +Base.real(q::Quaternion) = q.s +Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3) + +function Base.:*(q::Quaternion, w::Quaternion) + return Quaternion( + q.s * w.s - q.v1 * w.v1 - q.v2 * w.v2 - q.v3 * w.v3, + q.s * w.v1 + q.v1 * w.s + q.v2 * w.v3 - q.v3 * w.v2, + q.s * w.v2 - q.v1 * w.v3 + q.v2 * w.s + q.v3 * w.v1, + q.s * w.v3 + q.v1 * w.v2 - q.v2 * w.v1 + q.v3 * w.s, + ) +end + +function Base.:*(q::Quaternion, w::Union{Real,Complex,CustomComplex}) + a, b = reim(w) + return q * Quaternion(a, b, zero(a), zero(a)) +end +Base.:*(w::Union{Real,Complex,CustomComplex}, q::Quaternion) = conj(conj(q) * conj(w)) + @testset "RealDot.jl" begin - # Write your tests here. + scalars = ( + randn(), randn(ComplexF64), CustomComplex(randn(2)...), Quaternion(randn(4)...) + ) + arrays = (randn(10), randn(ComplexF64, 10)) + + for inputs in (scalars, arrays) + for x in inputs, y in inputs + @test realdot(x, y) == real(dot(x, y)) + end + end end