Skip to content

Commit

Permalink
Avoid aliasing in in UniformScaling*AbstractMatrix (#18286)
Browse files Browse the repository at this point in the history
* Avoid aliasing in in UniformScaling*AbstractMatrix

...and remove unnecessary UniformScaling*SparseMatrixCSC methods

* Broaden the tests for non-commutative multiplication

* Add Quaternion test case for q*[q] and clean up the Quaternion test type

(cherry picked from commit cd94c99)
  • Loading branch information
andreasnoack authored and tkelman committed Sep 13, 2016
1 parent 015eec0 commit a5bc279
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 54 deletions.
5 changes: 2 additions & 3 deletions base/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ inv(J::UniformScaling) = UniformScaling(inv(J.λ))
*(J1::UniformScaling, J2::UniformScaling) = UniformScaling(J1.λ*J2.λ)
*(B::BitArray{2}, J::UniformScaling) = *(bitunpack(B), J::UniformScaling)
*(J::UniformScaling, B::BitArray{2}) = *(J::UniformScaling, bitunpack(B))
*(A::AbstractMatrix, J::UniformScaling) = J.λ == 1 ? A : J.λ*A
*(J::UniformScaling, A::AbstractVecOrMat) = J.λ == 1 ? A : J.λ*A

*(A::AbstractMatrix, J::UniformScaling) = A*J.λ
*(J::UniformScaling, A::AbstractVecOrMat) = J.λ*A
*(x::Number, J::UniformScaling) = UniformScaling(x*J.λ)
*(J::UniformScaling, x::Number) = UniformScaling(J.λ*x)

Expand Down
5 changes: 0 additions & 5 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ function increment!{T<:Integer}(A::AbstractArray{T})
end
increment{T<:Integer}(A::AbstractArray{T}) = increment!(copy(A))

## Multiplication with UniformScaling (scaled identity matrices)

(*)(S::SparseMatrixCSC, J::UniformScaling) = J.λ == 1 ? S : J.λ*S
(*){Tv,Ti}(J::UniformScaling, S::SparseMatrixCSC{Tv,Ti}) = J.λ == 1 ? S : S*J.λ

## sparse matrix multiplication

function (*){TvA,TiA,TvB,TiB}(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB})
Expand Down
30 changes: 30 additions & 0 deletions test/linalg/generic.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license

import Base: -, *
using Base.Test

# A custom Quaternion type with minimal defined interface and methods.
# Used to test scale and scale! methods to show non-commutativity.
immutable Quaternion{T<:Real} <: Number
s::T
v1::T
v2::T
v3::T
end
Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...)
Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3
Base.abs(q::Quaternion) = sqrt(abs2(q))
Base.real{T}(::Type{Quaternion{T}}) = T
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)

(-)(ql::Quaternion, qr::Quaternion) =
Quaternion(ql.s - qr.s, ql.v1 - qr.v1, ql.v2 - qr.v2, ql.v3 - qr.v3)
(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3,
q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2,
q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1,
q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s)

debug = false

srand(123)
Expand Down Expand Up @@ -116,6 +138,14 @@ b = randn(Base.LinAlg.SCAL_CUTOFF) # make sure we try BLAS path
@test isequal(scale(BigFloat[1.0], 2.0im), Complex{BigFloat}[2.0im])
@test isequal(scale(BigFloat[1.0], 2.0f0im), Complex{BigFloat}[2.0im])

# test scale and scale! for non-commutative multiplication
q = Quaternion(0.44567, 0.755871, 0.882548, 0.423612)
qmat = [Quaternion(0.015007, 0.355067, 0.418645, 0.318373)]
@test scale!(q, copy(qmat)) != scale!(copy(qmat), q)
## Test * because it doesn't dispatch to scale!
@test q*qmat != qmat*q
@test conj(q*qmat) conj(qmat)*conj(q)

# test ops on Numbers
for elty in [Float32,Float64,Complex64,Complex128]
a = rand(elty)
Expand Down
99 changes: 53 additions & 46 deletions test/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,56 +43,63 @@ B = bitrand(2,2)
@test I + B == B + eye(B)

A = randn(2,2)
@test A + I == A + eye(A)
@test I + A == A + eye(A)
@test I - I === UniformScaling(0)
@test B - I == B - eye(B)
@test I - B == eye(B) - B
@test A - I == A - eye(A)
@test I - A == eye(A) - A
@test I*J === UniformScaling(λ)
@test B*J == B*λ
@test J*B == B*λ

S = sprandn(3,3,0.5)
@test S*J == S*λ
@test J*S == S*λ
@test A*J == A*λ
@test J*A == A*λ
@test J*ones(3) == ones(3)*λ
@test λ*J === UniformScaling*J.λ)
@test J*λ === UniformScaling*J.λ)
@test J/I === J
@test I/A == inv(A)
@test A/I == A
@test I/λ === UniformScaling(1/λ)
@test I\J === J
@test @inferred(A + I) == A + eye(A)
@test @inferred(I + A) == A + eye(A)
@test @inferred(I - I) === UniformScaling(0)
@test @inferred(B - I) == B - eye(B)
@test @inferred(I - B) == eye(B) - B
@test @inferred(A - I) == A - eye(A)
@test @inferred(I - A) == eye(A) - A
@test @inferred(I*J) === UniformScaling(λ)
@test @inferred(B*J) == B*λ
@test @inferred(J*B) == B*λ
@test @inferred(I*A) !== A # Don't alias
@test @inferred(I*S) !== S # Don't alias
@test @inferred(A*I) !== A # Don't alias
@test @inferred(S*I) !== S # Don't alias

@test @inferred(S*J) == S*λ
@test @inferred(J*S) == S*λ
@test @inferred(A*J) == A*λ
@test @inferred(J*A) == A*λ
@test @inferred(J*ones(3)) == ones(3)*λ
@test @inferred*J) === UniformScaling*J.λ)
@test @inferred(J*λ) === UniformScaling*J.λ)
@test @inferred(J/I) === J
@test @inferred(I/A) == inv(A)
@test @inferred(A/I) == A
@test @inferred(I/λ) === UniformScaling(1/λ)
@test @inferred(I\J) === J

T = LowerTriangular(randn(3,3))
@test T + J == full(T) + J
@test J + T == J + full(T)
@test T - J == full(T) - J
@test J - T == J - full(T)
@test T\I == inv(T)
@test @inferred(T + J) == full(T) + J
@test @inferred(J + T) == J + full(T)
@test @inferred(T - J) == full(T) - J
@test @inferred(J - T) == J - full(T)
@test @inferred(T\I) == inv(T)

T = LinAlg.UnitLowerTriangular(randn(3,3))
@test T + J == full(T) + J
@test J + T == J + full(T)
@test T - J == full(T) - J
@test J - T == J - full(T)
@test T\I == inv(T)
@test @inferred(T + J) == full(T) + J
@test @inferred(J + T) == J + full(T)
@test @inferred(T - J) == full(T) - J
@test @inferred(J - T) == J - full(T)
@test @inferred(T\I) == inv(T)

T = UpperTriangular(randn(3,3))
@test T + J == full(T) + J
@test J + T == J + full(T)
@test T - J == full(T) - J
@test J - T == J - full(T)
@test T\I == inv(T)
@test @inferred(T + J) == full(T) + J
@test @inferred(J + T) == J + full(T)
@test @inferred(T - J) == full(T) - J
@test @inferred(J - T) == J - full(T)
@test @inferred(T\I) == inv(T)

T = LinAlg.UnitUpperTriangular(randn(3,3))
@test T + J == full(T) + J
@test J + T == J + full(T)
@test T - J == full(T) - J
@test J - T == J - full(T)
@test T\I == inv(T)
@test @inferred(T + J) == full(T) + J
@test @inferred(J + T) == J + full(T)
@test @inferred(T - J) == full(T) - J
@test @inferred(J - T) == J - full(T)
@test @inferred(T\I) == inv(T)

@test I\A == A
@test A\I == inv(A)
@test λ\I === UniformScaling(1/λ)
@test @inferred(I\A) == A
@test @inferred(A\I) == inv(A)
@test @inferred(λ\I) === UniformScaling(1/λ)

0 comments on commit a5bc279

Please sign in to comment.