From 0dd8f48ebba2ba4c7b103626a65280c7356c7d56 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 27 Nov 2018 10:46:23 +0100 Subject: [PATCH 1/2] order of mul in (l)mul!(::Diagonal,::Sparse) add tests for non-commutative mul --- stdlib/SparseArrays/src/linalg.jl | 6 +++--- stdlib/SparseArrays/test/sparse.jl | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index ee3ca211e5f67..001fa8aa168bf 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -1186,7 +1186,7 @@ function mul!(C::SparseMatrixCSC, D::Diagonal{T, <:Vector}, A::SparseMatrixCSC) Arowval = A.rowval resize!(Cnzval, length(Anzval)) for col = 1:n, p = A.colptr[col]:(A.colptr[col+1]-1) - @inbounds Cnzval[p] = Anzval[p] * b[Arowval[p]] + @inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] end C end @@ -1222,7 +1222,7 @@ function rmul!(A::SparseMatrixCSC, D::Diagonal) (n == size(D, 1)) || throw(DimensionMismatch()) Anzval = A.nzval @inbounds for col = 1:n, p = A.colptr[col]:(A.colptr[col + 1] - 1) - Anzval[p] *= D.diag[col] + Anzval[p] = Anzval[p] * D.diag[col] end return A end @@ -1233,7 +1233,7 @@ function lmul!(D::Diagonal, A::SparseMatrixCSC) Anzval = A.nzval Arowval = A.rowval @inbounds for col = 1:n, p = A.colptr[col]:(A.colptr[col + 1] - 1) - Anzval[p] *= D.diag[Arowval[p]] + Anzval[p] = D.diag[Arowval[p]] * Anzval[p] end return A end diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index 3fa2c87fa8099..d99dc23b3b6cf 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -364,6 +364,10 @@ end @test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2)) end +const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") +isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl")) +using .Main.Quaternions + sA = sprandn(3, 7, 0.5) sC = similar(sA) dA = Array(sA) @@ -396,6 +400,26 @@ dA = Array(sA) @test_throws DimensionMismatch rdiv!(copy(sAt), Diagonal(fill(1., length(b)+1))) @test_throws LinearAlgebra.SingularException rdiv!(copy(sAt), Diagonal(zeros(length(b)))) end + + @testset "non-commutative multiplication" begin + # non-commutative multiplication + Avals = Quaternion.(randn(10), randn(10), randn(10), randn(10)) + sA = sparse(rand(1:3, 10), rand(1:7, 10), Avals, 3, 7) + sC = copy(sA) + dA = Array(sA) + + b = Quaternion.(randn(7), randn(7), randn(7), randn(7)) + D = Diagonal(b) + @test Array(sA * D) ≈ dA * D + @test rmul!(copy(sA), D) ≈ dA * D + @test mul!(sC, copy(sA), D) ≈ dA * D + + b = Quaternion.(randn(3), randn(3), randn(3), randn(3)) + D = Diagonal(b) + @test Array(D * sA) ≈ D * dA + @test lmul!(D, copy(sA)) ≈ D * dA + @test mul!(sC, D, copy(sA)) ≈ D * dA + end end @testset "copyto!" begin From f90db88e2536ee13d30e51fd53f099727c0003dc Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 27 Nov 2018 21:36:18 +0100 Subject: [PATCH 2/2] Create Quaternions.jl remove quaternions from generic tests --- stdlib/LinearAlgebra/test/generic.jl | 31 ++++--------------------- test/testhelpers/Quaternions.jl | 34 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 27 deletions(-) create mode 100644 test/testhelpers/Quaternions.jl diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index 33ae3ddfd0c4c..c1a34ca5b28e4 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -3,33 +3,10 @@ module TestGeneric using Test, LinearAlgebra, Random -import Base: -, *, /, \ - -# A custom Quaternion type with minimal defined interface and methods. -# Used to test mul and mul! methods to show non-commutativity. -struct Quaternion{T<:Real} <: Number - s::T - v1::T - v2::T - v3::T -end -Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...) -Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3 -Base.abs(q::Quaternion) = sqrt(abs2(q)) -Base.real(::Type{Quaternion{T}}) where {T} = T -Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3) -Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3) - -(-)(ql::Quaternion, qr::Quaternion) = - Quaternion(ql.s - qr.s, ql.v1 - qr.v1, ql.v2 - qr.v2, ql.v3 - qr.v3) -(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3, - q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2, - q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1, - q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s) -(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) -(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity -(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w)) -(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q)) + +const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") +isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl")) +using .Main.Quaternions Random.seed!(123) diff --git a/test/testhelpers/Quaternions.jl b/test/testhelpers/Quaternions.jl new file mode 100644 index 0000000000000..0920b9dea00c6 --- /dev/null +++ b/test/testhelpers/Quaternions.jl @@ -0,0 +1,34 @@ +module Quaternions + +export Quaternion + +# A custom Quaternion type with minimal defined interface and methods. +# Used to test mul and mul! methods to show non-commutativity. +struct Quaternion{T<:Real} <: Number + s::T + v1::T + v2::T + v3::T +end +Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...) +Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3 +Base.abs(q::Quaternion) = sqrt(abs2(q)) +Base.real(::Type{Quaternion{T}}) where {T} = T +Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3) +Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3) +Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T)) + +Base.:(+)(ql::Quaternion, qr::Quaternion) = + Quaternion(ql.s + qr.s, ql.v1 + qr.v1, ql.v2 + qr.v2, ql.v3 + qr.v3) +Base.:(-)(ql::Quaternion, qr::Quaternion) = + Quaternion(ql.s - qr.s, ql.v1 - qr.v1, ql.v2 - qr.v2, ql.v3 - qr.v3) +Base.:(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3, + q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2, + q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1, + q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s) +Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) +Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity +Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w)) +Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q)) + +end