From 601e535df78d1cdade4f6eed76348a8bf36cf864 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 27 Dec 2019 15:09:18 +0200 Subject: [PATCH] Add dot(x,A,y) (#683) * add some dot(x,A,y) methods * Bump version to 3.2.0 Co-authored-by: Fredrik Ekre Co-authored-by: Martin Holters --- Project.toml | 2 +- README.md | 3 ++ src/Compat.jl | 84 ++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 95 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 183 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7edf53a00..45fc353d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Compat" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.1.0" +version = "3.2.0" [deps] Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" diff --git a/README.md b/README.md index ef2c33fb1..6f08ca9e1 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,8 @@ Please check the list below for the specific syntax you need. ## Supported features +* `dot` now has a 3-argument method `dot(x, A, y)` without storing the intermediate result `A*y` ([#32739]). (since Compat 3.2.0) + * `pkgdir(m)` returns the root directory of the package that imported module `m` ([#33128]). (since Compat 3.2.0) * `filter` can now act on a `Tuple` [#32968]. (since Compat 3.1.0) @@ -104,6 +106,7 @@ Note that you should specify the correct minimum version for `Compat` in the [#29674]: https://github.com/JuliaLang/julia/issues/29674 [#29749]: https://github.com/JuliaLang/julia/issues/29749 [#32628]: https://github.com/JuliaLang/julia/issues/32628 +[#32739]: https://github.com/JuliaLang/julia/pull/32739 [#33129]: https://github.com/JuliaLang/julia/issues/33129 [#33568]: https://github.com/JuliaLang/julia/pull/33568 [#33128]: https://github.com/JuliaLang/julia/pull/33128 diff --git a/src/Compat.jl b/src/Compat.jl index 3d2842c93..aa8f1dbcf 100644 --- a/src/Compat.jl +++ b/src/Compat.jl @@ -1,5 +1,8 @@ module Compat +import LinearAlgebra +using LinearAlgebra: Adjoint, Diagonal, Transpose, UniformScaling, RealHermSymComplexHerm + include("compatmacro.jl") # https://github.com/JuliaLang/julia/pull/29679 @@ -88,6 +91,87 @@ if VERSION < v"1.3.0-alpha.8" Base.mod(i::Integer, r::AbstractUnitRange{<:Integer}) = mod(i-first(r), length(r)) + first(r) end +# https://github.com/JuliaLang/julia/pull/32739 +# This omits special methods for more exotic matrix types, Triangular and worse. +if VERSION < v"1.4.0-DEV.92" # 2425ae760fb5151c5c7dd0554e87c5fc9e24de73 + + # stdlib/LinearAlgebra/src/generic.jl + LinearAlgebra.dot(x, A, y) = LinearAlgebra.dot(x, A*y) # generic fallback + + function LinearAlgebra.dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector) + (axes(x)..., axes(y)...) == axes(A) || throw(DimensionMismatch()) + T = typeof(LinearAlgebra.dot(first(x), first(A), first(y))) + s = zero(T) + i₁ = first(eachindex(x)) + x₁ = first(x) + @inbounds for j in eachindex(y) + yj = y[j] + if !iszero(yj) + temp = zero(adjoint(A[i₁,j]) * x₁) + @simd for i in eachindex(x) + temp += adjoint(A[i,j]) * x[i] + end + s += LinearAlgebra.dot(temp, yj) + end + end + return s + end + LinearAlgebra.dot(x::AbstractVector, adjA::Adjoint, y::AbstractVector) = + adjoint(LinearAlgebra.dot(y, adjA.parent, x)) + LinearAlgebra.dot(x::AbstractVector, transA::Transpose{<:Real}, y::AbstractVector) = + adjoint(LinearAlgebra.dot(y, transA.parent, x)) + + # stdlib/LinearAlgebra/src/diagonal.jl + function LinearAlgebra.dot(x::AbstractVector, D::Diagonal, y::AbstractVector) + mapreduce(t -> LinearAlgebra.dot(t[1], t[2], t[3]), +, zip(x, D.diag, y)) + end + + # stdlib/LinearAlgebra/src/symmetric.jl + function LinearAlgebra.dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector) + require_one_based_indexing(x, y) + (length(x) == length(y) == size(A, 1)) || throw(DimensionMismatch()) + data = A.data + r = zero(eltype(x)) * zero(eltype(A)) * zero(eltype(y)) + if A.uplo == 'U' + @inbounds for j = 1:length(y) + r += LinearAlgebra.dot(x[j], real(data[j,j]), y[j]) + @simd for i = 1:j-1 + Aij = data[i,j] + r += LinearAlgebra.dot(x[i], Aij, y[j]) + + LinearAlgebra.dot(x[j], adjoint(Aij), y[i]) + end + end + else # A.uplo == 'L' + @inbounds for j = 1:length(y) + r += LinearAlgebra.dot(x[j], real(data[j,j]), y[j]) + @simd for i = j+1:length(y) + Aij = data[i,j] + r += LinearAlgebra.dot(x[i], Aij, y[j]) + + LinearAlgebra.dot(x[j], adjoint(Aij), y[i]) + end + end + end + return r + end + + # stdlib/LinearAlgebra/src/uniformscaling.jl + LinearAlgebra.dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) = + LinearAlgebra.dot(x, J.λ, y) + LinearAlgebra.dot(x::AbstractVector, a::Number, y::AbstractVector) = + sum(t -> LinearAlgebra.dot(t[1], a, t[2]), zip(x, y)) + LinearAlgebra.dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) = + a*LinearAlgebra.dot(x, y) +end + +# https://github.com/JuliaLang/julia/pull/30630 +if VERSION < v"1.2.0-DEV.125" # 1da48c2e4028c1514ed45688be727efbef1db884 + require_one_based_indexing(A...) = !Base.has_offset_axes(A...) || throw(ArgumentError( + "offset arrays are not supported but got an array with index other than 1")) +# At present this is only used in Compat inside the above dot(x,A,y) functions, #32739 +elseif VERSION < v"1.4.0-DEV.92" + using Base: require_one_based_indexing +end + # https://github.com/JuliaLang/julia/pull/33568 if VERSION < v"1.4.0-DEV.329" Base.:∘(f, g, h...) = ∘(f ∘ g, h...) diff --git a/test/runtests.jl b/test/runtests.jl index 4837c4183..256f51673 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,6 +90,101 @@ end @test_throws DivideError mod(3, 1:0) end +using LinearAlgebra + +@testset "generalized dot #32739" begin + # stdlib/LinearAlgebra/test/generic.jl + for elty in (Int, Float32, Float64, BigFloat, Complex{Float32}, Complex{Float64}, Complex{BigFloat}) + n = 10 + if elty <: Int + A = rand(-n:n, n, n) + x = rand(-n:n, n) + y = rand(-n:n, n) + elseif elty <: Real + A = convert(Matrix{elty}, randn(n,n)) + x = rand(elty, n) + y = rand(elty, n) + else + A = convert(Matrix{elty}, complex.(randn(n,n), randn(n,n))) + x = rand(elty, n) + y = rand(elty, n) + end + @test dot(x, A, y) ≈ dot(A'x, y) ≈ *(x', A, y) ≈ (x'A)*y + @test dot(x, A', y) ≈ dot(A*x, y) ≈ *(x', A', y) ≈ (x'A')*y + elty <: Real && @test dot(x, transpose(A), y) ≈ dot(x, transpose(A)*y) ≈ *(x', transpose(A), y) ≈ (x'*transpose(A))*y + B = reshape([A], 1, 1) + x = [x] + y = [y] + @test dot(x, B, y) ≈ dot(B'x, y) + @test dot(x, B', y) ≈ dot(B*x, y) + elty <: Real && @test dot(x, transpose(B), y) ≈ dot(x, transpose(B)*y) + end + + # stdlib/LinearAlgebra/test/symmetric.jl + n = 10 + areal = randn(n,n)/2 + aimg = randn(n,n)/2 + @testset for eltya in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int) + a = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(areal, aimg) : areal) + asym = transpose(a) + a # symmetric indefinite + aherm = a' + a # Hermitian indefinite + apos = a' * a # Hermitian positive definite + aposs = apos + transpose(apos) # Symmetric positive definite + ε = εa = eps(abs(float(one(eltya)))) + x = randn(n) + y = randn(n) + b = randn(n,n)/2 + x = eltya == Int ? rand(1:7, n) : convert(Vector{eltya}, eltya <: Complex ? complex.(x, zeros(n)) : x) + y = eltya == Int ? rand(1:7, n) : convert(Vector{eltya}, eltya <: Complex ? complex.(y, zeros(n)) : y) + b = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(b, zeros(n,n)) : b) + + @testset "generalized dot product" begin + for uplo in (:U, :L) + @test dot(x, Hermitian(aherm, uplo), y) ≈ dot(x, Hermitian(aherm, uplo)*y) ≈ dot(x, Matrix(Hermitian(aherm, uplo)), y) + @test dot(x, Hermitian(aherm, uplo), x) ≈ dot(x, Hermitian(aherm, uplo)*x) ≈ dot(x, Matrix(Hermitian(aherm, uplo)), x) + end + if eltya <: Real + for uplo in (:U, :L) + @test dot(x, Symmetric(aherm, uplo), y) ≈ dot(x, Symmetric(aherm, uplo)*y) ≈ dot(x, Matrix(Symmetric(aherm, uplo)), y) + @test dot(x, Symmetric(aherm, uplo), x) ≈ dot(x, Symmetric(aherm, uplo)*x) ≈ dot(x, Matrix(Symmetric(aherm, uplo)), x) + end + end + end + end + + # stdlib/LinearAlgebra/test/uniformscaling.jl + @testset "generalized dot" begin + x = rand(-10:10, 3) + y = rand(-10:10, 3) + λ = rand(-10:10) + J = UniformScaling(λ) + @test dot(x, J, y) == λ*dot(x, y) + end + + # stdlib/LinearAlgebra/test/bidiag.jl + # The special method for this is not in Compat #683, so this tests the generic fallback + @testset "generalized dot" begin + for elty in (Float64, ComplexF64) + dv = randn(elty, 5) + ev = randn(elty, 4) + x = randn(elty, 5) + y = randn(elty, 5) + for uplo in (:U, :L) + B = Bidiagonal(dv, ev, uplo) + @test dot(x, B, y) ≈ dot(B'x, y) ≈ dot(x, Matrix(B), y) + end + end + end + + # Diagonal -- no such test in Base. + @testset "diagonal" begin + x = rand(-10:10, 3) .+ im + y = rand(-10:10, 3) .+ im + d = Diagonal(rand(-10:10, 3) .+ im) + @test dot(x,d,y) == dot(x,collect(d),y) == dot(x, d*y) + end +end + # https://github.com/JuliaLang/julia/pull/33568 @testset "function composition" begin @test ∘(x -> x-2, x -> x-3, x -> x+5)(7) == 7