Skip to content

Commit

Permalink
Fix dot(::Adjoint, ::Adjoint) for numbers that don't commute under mu…
Browse files Browse the repository at this point in the history
…ltiplication (#44219)

Co-authored-by: Fredrik Ekre <[email protected]>
  • Loading branch information
sethaxen and fredrikekre authored Feb 19, 2022
1 parent 4061e8f commit 928f63c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,9 @@ function dot(x::AbstractArray, y::AbstractArray)
s
end

dot(x::Adjoint, y::Adjoint) = conj(dot(parent(x), parent(y)))
function dot(x::Adjoint{<:Union{Real,Complex}}, y::Adjoint{<:Union{Real,Complex}})
return conj(dot(parent(x), parent(y)))
end
dot(x::Transpose, y::Transpose) = dot(parent(x), parent(y))

"""
Expand Down
13 changes: 7 additions & 6 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,20 +497,21 @@ end
end

@testset "adjtrans dot" begin
for t in (transpose, adjoint)
x, y = t(rand(ComplexF64, 10)), t(rand(ComplexF64, 10))
for t in (transpose, adjoint), T in (ComplexF64, Quaternion{Float64})
x, y = t(rand(T, 10)), t(rand(T, 10))
X, Y = copy(x), copy(y)
@test dot(x, y) dot(X, Y)
x, y = t([rand(ComplexF64, 2, 2) for _ in 1:5]), t([rand(ComplexF64, 2, 2) for _ in 1:5])
x, y = t([rand(T, 2, 2) for _ in 1:5]), t([rand(T, 2, 2) for _ in 1:5])
X, Y = copy(x), copy(y)
@test dot(x, y) dot(X, Y)
x, y = t(rand(ComplexF64, 10, 5)), t(rand(ComplexF64, 10, 5))
x, y = t(rand(T, 10, 5)), t(rand(T, 10, 5))
X, Y = copy(x), copy(y)
@test dot(x, y) dot(X, Y)
x = t([rand(ComplexF64, 2, 2) for _ in 1:5, _ in 1:5])
y = t([rand(ComplexF64, 2, 2) for _ in 1:5, _ in 1:5])
x = t([rand(T, 2, 2) for _ in 1:5, _ in 1:5])
y = t([rand(T, 2, 2) for _ in 1:5, _ in 1:5])
X, Y = copy(x), copy(y)
@test dot(x, y) dot(X, Y)
x, y = t([rand(T, 2, 2) for _ in 1:5]), t([rand(T, 2, 2) for _ in 1:5])
end
end

Expand Down
14 changes: 14 additions & 0 deletions test/testhelpers/Quaternions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module Quaternions
using Random

export Quaternion

Expand Down Expand Up @@ -36,4 +37,17 @@ Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity
Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w))
Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q))

# adapted from https://github.com/JuliaGeometry/Quaternions.jl/pull/42
function Base.rand(rng::AbstractRNG, ::Random.SamplerType{Quaternion{T}}) where {T<:Real}
return Quaternion{T}(rand(rng, T), rand(rng, T), rand(rng, T), rand(rng, T))
end
function Base.randn(rng::AbstractRNG, ::Type{Quaternion{T}}) where {T<:AbstractFloat}
return Quaternion{T}(
randn(rng, T) / 2,
randn(rng, T) / 2,
randn(rng, T) / 2,
randn(rng, T) / 2,
)
end

end

0 comments on commit 928f63c

Please sign in to comment.