diff --git a/Project.toml b/Project.toml index dca22b914..5e44c6e11 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NNlib" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.17" +version = "0.7.18" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -19,10 +19,11 @@ julia = "1.3" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "StableRNGs", "Test", "Zygote"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "Logging", "Random", "StableRNGs", "Test", "Zygote"] diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 7fbda166c..6e6f8e3fa 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -92,15 +92,13 @@ end function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}) function batched_mul_pullback(Δ) - Athunk = if size(A,3) == 1 - @thunk(sum(batched_mul(Δ, batched_adjoint(B)), dims=3)) - else - @thunk(batched_mul(Δ, batched_adjoint(B))) + Athunk = @thunk begin + tmp = batched_mul(Δ, batched_adjoint(B)) + size(A,3) == 1 ? sum(tmp, dims=3) : tmp end - Bthunk = if size(B,3) == 1 - @thunk(sum(batched_mul(batched_adjoint(A), Δ), dims=3)) - else - @thunk(batched_mul(batched_adjoint(A), Δ)) + Bthunk = @thunk begin + tmp = batched_mul(batched_adjoint(A), Δ) + size(B,3) == 1 ? sum(tmp, dims=3) : tmp end return (NO_FIELDS, Athunk, Bthunk) end @@ -236,6 +234,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where { batched_transpose(A), 'T' elseif Base.stride(A,1) == 1 A, 'N' + elseif Base.stride(A,2) == 1 # This is awful, but exhaustively tested. Issues 268, 282. + batched_transpose(A), 'T' else return batched_mul_generic!(C, A, B, α, β) end @@ -247,6 +247,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where { batched_transpose(B), 'T' elseif Base.stride(B,1) == 1 B, 'N' + elseif Base.stride(B,2) == 1 + batched_transpose(B), 'T' else return batched_mul_generic!(C, A, B, α, β) end @@ -270,7 +272,7 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C")) size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C")) - @debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C) + @debug "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C) Abase, Bbase = _unbatch(A), _unbatch(B) sA, oA = size(A,3) == 1 ? (0,1) : (1,0) diff --git a/test/batchedmul.jl b/test/batchedmul.jl index 58cbccf17..9c778914f 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -1,4 +1,4 @@ -using NNlib, Test, LinearAlgebra +using NNlib, Test, LinearAlgebra, Logging using NNlib: storage_type, storage_typejoin, is_strided, batched_mul!, batched_mul_generic!, _unbatch, _copy_if_faster, BatchedAdjoint, BatchedTranspose @@ -119,17 +119,21 @@ end end end +perm_12(A) = PermutedDimsArray(A, (2,1,3)) +perm_23(A) = PermutedDimsArray(A, (1,3,2)) + @testset "batched_mul: trivial dimensions & unit strides, $T" for T in [Float64, ComplexF64] - @testset "$tA(rand$((sA...,2))) ⊠ $tB(rand$((sB...,2)))" for - tA in [identity, batched_adjoint, batched_transpose], sA in [(1,1), (1,3), (3,1), (3,3)], - tB in [identity, batched_adjoint, batched_transpose], sB in [(1,1), (1,3), (3,1), (3,3)] + @testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for + tA in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sA in [(1,1), (1,3), (3,1), (3,3)], + tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sB in [(1,1), (1,3), (3,1), (3,3)] - A = tA(rand(T, sA..., 2)) - B = tB(rand(T, sB..., 2)) - size(A,2) == size(B,1) || continue + A = tA(rand(T, sA..., 3)) + B = tB(rand(T, sB..., 3)) + size(A,2) == size(B,1) && size(A,3) == size(B,3) == 3 || continue - C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2]; dims=3) + C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2], A[:,:,3] * B[:,:,3]; dims=3) @test A ⊠ B ≈ C + @test_logs min_level=Logging.Debug A ⊠ B # In-place batched_mul! α, β = rand(T), rand(T)