From 3a346a68ea0eeb64d95365f7548b802e8456deec Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Wed, 25 May 2016 00:35:45 -0400 Subject: [PATCH] Fix a few minor problems for Triangular arithmetic. Fixes #16458 (#16562) --- base/linalg/triangular.jl | 95 +++++++++++++++++++++++++-------------- test/linalg/triangular.jl | 1 + 2 files changed, 62 insertions(+), 34 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index e23fe2ff0c309..5b187dedb3793 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1293,22 +1293,35 @@ end for f in (:A_mul_B!, :A_ldiv_B!) @eval begin + # Upper $f(A::UpperTriangular, B::UpperTriangular) = UpperTriangular($f(A, triu!(B.data))) $f(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular($f(A, triu!(B.data))) $f(A::UpperTriangular, B::UnitUpperTriangular) = - UpperTriangular($f(A, triu!(B.data))) - $f(A::UnitUpperTriangular, B::UnitUpperTriangular) = - UnitUpperTriangular($f(A, triu!(B.data))) + UpperTriangular($f(triu!(A.data), B)) + function $f(A::UnitUpperTriangular, B::UnitUpperTriangular) + BB = triu!(B.data) + for i = 1:size(BB, 1) + BB[i,i] = 1 + end + return UnitUpperTriangular($f(A, BB)) + end + + # Lower $f(A::LowerTriangular, B::LowerTriangular) = LowerTriangular($f(A, tril!(B.data))) $f(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular($f(A, tril!(B.data))) $f(A::LowerTriangular, B::UnitLowerTriangular) = - LowerTriangular($f(A, tril!(B.data))) - $f(A::UnitLowerTriangular, B::UnitLowerTriangular) = - LowerTriangular($f(A, tril!(B.data))) + LowerTriangular($f(tril!(A), B)) + function $f(A::UnitLowerTriangular, B::UnitLowerTriangular) + BB = tril!(B.data) + for i = 1:size(BB, 1) + BB[i,i] = 1 + end + UnitLowerTriangular($f(A, BB)) + end end end @@ -1345,25 +1358,29 @@ for t in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriang end end -for (f1, f2) in ((:*, :A_mul_B!), (:\, :A_ldiv_B)) +for (f1, f2) in ((:*, :A_mul_B!), (:\, :A_ldiv_B!)) @eval begin - function $f1(A::LowerTriangular, B::LowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function ($f1)(A::LowerTriangular, B::LowerTriangular) + TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) + + ($f1)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end - function $f1(A::UnitLowerTriangular, B::LowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function $(f1)(A::UnitLowerTriangular, B::LowerTriangular) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end - function $f1(A::UpperTriangular, B::UpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function ($f1)(A::UpperTriangular, B::UpperTriangular) + TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) + + ($f1)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end - function $f1(A::UnitUpperTriangular, B::UpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function ($f1)(A::UnitUpperTriangular, B::UpperTriangular) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end end @@ -1372,44 +1389,50 @@ end for (f1, f2) in ((:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!), (:Ac_ldiv_B, Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!)) @eval begin - function $f1(A::UpperTriangular, B::LowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function ($f1)(A::UpperTriangular, B::LowerTriangular) + TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) + + ($f1)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end - function $f1(A::UnitUpperTriangular, B::LowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function ($f1)(A::UnitUpperTriangular, B::LowerTriangular) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end - function $f1(A::LowerTriangular, B::UpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function ($f1)(A::LowerTriangular, B::UpperTriangular) + TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) + + ($f1)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end - function $f1(A::UnitLowerTriangular, B::UpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + function ($f1)(A::UnitLowerTriangular, B::UpperTriangular) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))) end end end function (/)(A::LowerTriangular, B::LowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))/ - one(eltype(A))) + TAB = typeof((/)(zero(eltype(A)), zero(eltype(B))) + + (/)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end function (/)(A::LowerTriangular, B::UnitLowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end function (/)(A::UpperTriangular, B::UpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))/ - one(eltype(A))) + TAB = typeof((/)(zero(eltype(A)), zero(eltype(B))) + + (/)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end function (/)(A::UpperTriangular, B::UnitUpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end @@ -1417,22 +1440,26 @@ for (f1, f2) in ((:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!)) @eval begin function $f1(A::LowerTriangular, B::UpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) + + ($f1)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end function $f1(A::LowerTriangular, B::UnitUpperTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return LowerTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end function $f1(A::UpperTriangular, B::LowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) + + ($f1)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end function $f1(A::UpperTriangular, B::UnitLowerTriangular) - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) + TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) + + (*)(zero(eltype(A)), zero(eltype(B)))) return UpperTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B))) end end @@ -1510,7 +1537,7 @@ end ### Right division with triangle to the right hence lhs cannot be transposed. No quotients. for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!)) @eval begin - function ($f)(A::$mat, B::Tuple{UnitUpperTriangular, UnitLowerTriangular}) + function ($f)(A::$mat, B::Union{UnitUpperTriangular, UnitLowerTriangular}) TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) AA = similar(A, TAB, size(A)) copy!(AA, A) diff --git a/test/linalg/triangular.jl b/test/linalg/triangular.jl index cba960408bdb6..5218ff2074616 100644 --- a/test/linalg/triangular.jl +++ b/test/linalg/triangular.jl @@ -286,6 +286,7 @@ for elty1 in (Float32, Float64, BigFloat, Complex64, Complex128, Complex{BigFloa @test_approx_eq full(A1.'A2.') full(A1).'full(A2).' @test_approx_eq full(A1'A2') full(A1)'full(A2)' @test_approx_eq full(A1/A2) full(A1)/full(A2) + @test_approx_eq full(A1\A2) full(A1)\full(A2) @test_throws DimensionMismatch eye(n+1)/A2 @test_throws DimensionMismatch eye(n+1)/A2.' @test_throws DimensionMismatch eye(n+1)/A2'