Skip to content

Commit

Permalink
Add realdot (#2)
Browse files Browse the repository at this point in the history
* Add missing complex tests and rules (#216)

* Fix indentation

* Test \ on complex inputs

* Test ^ on complex inputs

* Test identity on complex inputs

* Test muladd on complex inputs

* Test binary functions on complex inputs

* Test functions on complex inputs

* Release type constraint on exp

* Add _realconjtimes

* Use _realconjtimes in abs/abs2 rules

* Add complex rule for hypot

* Add generic rule for adjoint

* Add generic rule for real

* Add generic rule for imag

* Add complex rule for hypot

* Add rules/tests for Complex

* Test frule for identity

* Add missing angle test

* Make inline just in case

* Unify abs rules

* Introduce _imagconjtimes utility function

* Unify angle rules

* Unify sign rules

* Multiply by correct variable

* Fix argument order

* Bump ChainRulesTestUtils version number

* Restrict to Complex

* Use muladd

* Update src/rulesets/Base/fastmath_able.jl

Co-authored-by: willtebbutt <[email protected]>

Co-authored-by: willtebbutt <[email protected]>

* rename differentials (#413)

* rename DoesNotExist

* rename Composite

* bump version and compat

* rename Zero

* remove typos

* reexport deprecated types manually

* Rename to `realconjtimes` and `imagconjtimes` and export them

* Add tests

* Fix tests with Julia 1.0

* Rename to `realdot` and `imagdot`

* Add dispatch for real arrays

* Update src/utils.jl

Co-authored-by: Seth Axen <[email protected]>

* Generalize `::Complex` to `::Number`

* Rename `utils.jl` to `complex_math.jl`

* Remove `imagdot`

* Add `realdot`

* Update README

* Apply suggestions from code review

Co-authored-by: Seth Axen <[email protected]>

* Update README.md

* Update src/RealDot.jl

Co-authored-by: Seth Axen <[email protected]>

* Add test with quaternions

* Fix quaternion multiplication

Co-authored-by: Seth Axen <[email protected]>
Co-authored-by: willtebbutt <[email protected]>
Co-authored-by: Miha Zgubic <[email protected]>
  • Loading branch information
4 people authored Oct 18, 2021
1 parent ebc598d commit 6a8e18b
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 2 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
authors = ["David Widmann"]
version = "0.1.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
julia = "1"

Expand Down
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,19 @@
[![Coverage](https://codecov.io/gh/devmotion/RealDot.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/devmotion/RealDot.jl)
[![Coverage](https://coveralls.io/repos/github/devmotion/RealDot.jl/badge.svg?branch=main)](https://coveralls.io/github/devmotion/RealDot.jl?branch=main)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)

This package only contains and exports a single function `realdot(x, y)`.
It computes `real(LinearAlgebra.dot(x, y))` while avoiding computing the imaginary part of `LinearAlgebra.dot(x, y)` if possible.

The real dot product is useful when one treats complex numbers as embedded in a real vector space.
For example, take two complex arrays `x` and `y`.
Their real dot product is `real(dot(x, y)) == dot(real(x), real(y)) + dot(imag(x), imag(y))`.
This is the same result one would get by reinterpreting the arrays as real arrays:
```julia
xreal = reinterpret(real(eltype(x)), x)
yreal = reinterpret(real(eltype(y)), y)
real(dot(x, y)) == dot(xreal, yreal)
```

In particular, this function can be useful if you define pullbacks for non-holomorphic functions (see e.g. [this discussion in the ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/pull/474)).
It was implemented initially in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) in [this PR](https://github.com/JuliaDiff/ChainRules.jl/pull/216) as `_realconjtimes`.
19 changes: 18 additions & 1 deletion src/RealDot.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
module RealDot

# Write your package code here.
using LinearAlgebra: LinearAlgebra

export realdot

"""
realdot(x, y)
Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible.
This function can be useful if you work with derivatives of functions on complex
numbers. In particular, this computation shows up in pullbacks for non-holomorphic
functions.
"""
@inline realdot(x, y) = real(LinearAlgebra.dot(x, y))
@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y))
@inline realdot(x::Real, y::Number) = x * real(y)
@inline realdot(x::Number, y::Real) = real(x) * y
@inline realdot(x::Real, y::Real) = x * y

end
59 changes: 58 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,63 @@
using RealDot
using LinearAlgebra
using Test

# struct need to be defined outside of tests for julia 1.0 compat
# custom complex number (tests fallback definition)
struct CustomComplex{T} <: Number
re::T
im::T
end

Base.real(x::CustomComplex) = x.re
Base.imag(x::CustomComplex) = x.im

Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im)

function Base.:*(x::CustomComplex, y::Union{Real,Complex})
return CustomComplex(reim(Complex(reim(x)...) * y)...)
end
Base.:*(x::Union{Real,Complex}, y::CustomComplex) = y * x
function Base.:*(x::CustomComplex, y::CustomComplex)
return CustomComplex(reim(Complex(reim(x)...) * Complex(reim(y)...))...)
end

# custom quaternion to test definition for hypercomplex numbers
# adapted from Quaternions.jl
struct Quaternion{T<:Real} <: Number
s::T
v1::T
v2::T
v3::T
end

Base.real(q::Quaternion) = q.s
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)

function Base.:*(q::Quaternion, w::Quaternion)
return Quaternion(
q.s * w.s - q.v1 * w.v1 - q.v2 * w.v2 - q.v3 * w.v3,
q.s * w.v1 + q.v1 * w.s + q.v2 * w.v3 - q.v3 * w.v2,
q.s * w.v2 - q.v1 * w.v3 + q.v2 * w.s + q.v3 * w.v1,
q.s * w.v3 + q.v1 * w.v2 - q.v2 * w.v1 + q.v3 * w.s,
)
end

function Base.:*(q::Quaternion, w::Union{Real,Complex,CustomComplex})
a, b = reim(w)
return q * Quaternion(a, b, zero(a), zero(a))
end
Base.:*(w::Union{Real,Complex,CustomComplex}, q::Quaternion) = conj(conj(q) * conj(w))

@testset "RealDot.jl" begin
# Write your tests here.
scalars = (
randn(), randn(ComplexF64), CustomComplex(randn(2)...), Quaternion(randn(4)...)
)
arrays = (randn(10), randn(ComplexF64, 10))

for inputs in (scalars, arrays)
for x in inputs, y in inputs
@test realdot(x, y) == real(dot(x, y))
end
end
end

0 comments on commit 6a8e18b

Please sign in to comment.