Skip to content

Commit

Permalink
Merge pull request #299 from mcabbott/batch6
Browse files Browse the repository at this point in the history
Fix #282
  • Loading branch information
CarloLucibello authored Mar 23, 2021
2 parents fb53778 + 092bfce commit aa5a6bd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 19 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.17"
version = "0.7.18"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -19,10 +19,11 @@ julia = "1.3"
[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "StableRNGs", "Test", "Zygote"]
test = ["ChainRulesTestUtils", "FiniteDifferences", "Logging", "Random", "StableRNGs", "Test", "Zygote"]
20 changes: 11 additions & 9 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,13 @@ end

function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})
function batched_mul_pullback(Δ)
Athunk = if size(A,3) == 1
@thunk(sum(batched_mul(Δ, batched_adjoint(B)), dims=3))
else
@thunk(batched_mul(Δ, batched_adjoint(B)))
Athunk = @thunk begin
tmp = batched_mul(Δ, batched_adjoint(B))
size(A,3) == 1 ? sum(tmp, dims=3) : tmp
end
Bthunk = if size(B,3) == 1
@thunk(sum(batched_mul(batched_adjoint(A), Δ), dims=3))
else
@thunk(batched_mul(batched_adjoint(A), Δ))
Bthunk = @thunk begin
tmp = batched_mul(batched_adjoint(A), Δ)
size(B,3) == 1 ? sum(tmp, dims=3) : tmp
end
return (NO_FIELDS, Athunk, Bthunk)
end
Expand Down Expand Up @@ -236,6 +234,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {
batched_transpose(A), 'T'
elseif Base.stride(A,1) == 1
A, 'N'
elseif Base.stride(A,2) == 1 # This is awful, but exhaustively tested. Issues 268, 282.
batched_transpose(A), 'T'
else
return batched_mul_generic!(C, A, B, α, β)
end
Expand All @@ -247,6 +247,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {
batched_transpose(B), 'T'
elseif Base.stride(B,1) == 1
B, 'N'
elseif Base.stride(B,2) == 1
batched_transpose(B), 'T'
else
return batched_mul_generic!(C, A, B, α, β)
end
Expand All @@ -270,7 +272,7 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST

size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C"))
size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C"))
@debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C)
@debug "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C)

Abase, Bbase = _unbatch(A), _unbatch(B)
sA, oA = size(A,3) == 1 ? (0,1) : (1,0)
Expand Down
20 changes: 12 additions & 8 deletions test/batchedmul.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NNlib, Test, LinearAlgebra
using NNlib, Test, LinearAlgebra, Logging
using NNlib: storage_type, storage_typejoin, is_strided,
batched_mul!, batched_mul_generic!, _unbatch, _copy_if_faster,
BatchedAdjoint, BatchedTranspose
Expand Down Expand Up @@ -119,17 +119,21 @@ end
end
end

perm_12(A) = PermutedDimsArray(A, (2,1,3))
perm_23(A) = PermutedDimsArray(A, (1,3,2))

@testset "batched_mul: trivial dimensions & unit strides, $T" for T in [Float64, ComplexF64]
@testset "$tA(rand$((sA...,2))) ⊠ $tB(rand$((sB...,2)))" for
tA in [identity, batched_adjoint, batched_transpose], sA in [(1,1), (1,3), (3,1), (3,3)],
tB in [identity, batched_adjoint, batched_transpose], sB in [(1,1), (1,3), (3,1), (3,3)]
@testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for
tA in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sA in [(1,1), (1,3), (3,1), (3,3)],
tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sB in [(1,1), (1,3), (3,1), (3,3)]

A = tA(rand(T, sA..., 2))
B = tB(rand(T, sB..., 2))
size(A,2) == size(B,1) || continue
A = tA(rand(T, sA..., 3))
B = tB(rand(T, sB..., 3))
size(A,2) == size(B,1) && size(A,3) == size(B,3) == 3 || continue

C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2]; dims=3)
C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2], A[:,:,3] * B[:,:,3]; dims=3)
@test A B C
@test_logs min_level=Logging.Debug A B

# In-place batched_mul!
α, β = rand(T), rand(T)
Expand Down

0 comments on commit aa5a6bd

Please sign in to comment.