Skip to content

Commit

Permalink
Merge pull request #8249 from JuliaLang/anj/fixtri
Browse files Browse the repository at this point in the history
Fix #8243 and add scalar division for Triangular
  • Loading branch information
andreasnoack committed Sep 6, 2014
2 parents c786d72 + 18ca76c commit beca3e4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
36 changes: 30 additions & 6 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,45 @@ end

function (*){T,S,UpLo,IsUnit}(A::Triangular{T,S,UpLo,IsUnit}, x::Number)
n = size(A,1)
B = copy(A.data)
for j = 1:n
for i = UpLo == :L ? j:n : 1:j
A.data[i,j] = i == j & IsUnit ? x : A.data[i,j]*x
for i = UpLo == :L ? (j:n) : (1:j)
B[i,j] = (i == j && IsUnit ? x : B[i,j]*x)
end
end
A
Triangular{T,S,UpLo,IsUnit}(B)
end
function (*){T,S,UpLo,IsUnit}(x::Number, A::Triangular{T,S,UpLo,IsUnit})
n = size(A,1)
B = copy(A.data)
for j = 1:n
for i = UpLo == :L ? j:n : 1:j
A.data[i,j] = i == j & IsUnit ? x : x*A.data[i,j]
for i = UpLo == :L ? (j:n) : (1:j)
B[i,j] = i == j && IsUnit ? x : x*B[i,j]
end
end
A
Triangular{T,S,UpLo,IsUnit}(B)
end
function (/){T,S,UpLo,IsUnit}(A::Triangular{T,S,UpLo,IsUnit}, x::Number)
n = size(A,1)
B = copy(A.data)
invx = inv(x)
for j = 1:n
for i = UpLo == :L ? (j:n) : (1:j)
B[i,j] = (i == j && IsUnit ? invx : B[i,j]*invx)
end
end
Triangular{T,S,UpLo,IsUnit}(B)
end
function (\){T,S,UpLo,IsUnit}(x::Number, A::Triangular{T,S,UpLo,IsUnit})
n = size(A,1)
B = copy(A.data)
invx = inv(x)
for j = 1:n
for i = UpLo == :L ? (j:n) : (1:j)
B[i,j] = i == j && IsUnit ? invx : invx*B[i,j]
end
end
Triangular{T,S,UpLo,IsUnit}(B)
end

A_mul_B!{T,S,UpLo,IsUnit}(A::Triangular{T,S,UpLo,IsUnit}, B::Triangular{T,S,UpLo,IsUnit}) = Triangular{T,S,UpLo,IsUnit}(A*full!(B))
Expand Down
4 changes: 4 additions & 0 deletions test/linalg4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ for relty in (Float32, Float64, BigFloat), elty in (relty, Complex{relty})
@test_approx_eq full(op(TM, M2)) op(M, M2)
@test_approx_eq full(op(M, TM2)) op(M, M2)
end
@test M2*0.5 == full(TM2*0.5)
@test 0.5*M2 == full(0.5*TM2)
@test M2/0.5 == full(TM2/0.5)
@test 0.5\M2 == full(0.5\TM2)
end
end
end
Expand Down

0 comments on commit beca3e4

Please sign in to comment.