Skip to content

Commit

Permalink
Merge pull request #271 from mcabbott/batch3
Browse files Browse the repository at this point in the history
Fix stride & size inference of 'T'/'N' in `batched_mul`
  • Loading branch information
CarloLucibello authored Jan 21, 2021
2 parents e1a5945 + 5df84c4 commit d9aaaf7
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
3 changes: 2 additions & 1 deletion src/batched/batchedadjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ Base.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.
# Declare Cartesian indexing as the index style for the batched wrappers; the accessors
# below are defined on three integer indices (i, j, k).
Base.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian()
# Reading through a BatchedTranspose swaps the first two indices when indexing the parent;
# the third index k is passed through unchanged.
Base.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k)
# Reading through a BatchedAdjoint does the same index swap and applies `adjoint` to the
# element fetched from the parent.
Base.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k))
# NOTE(review): the commit page reports "2 additions & 1 deletion" for this file, so this
# BatchedAdjOrTrans method is presumably the removed pre-change line shown next to its
# replacements -- confirm against the actual file. It stores `v` unchanged at the swapped
# position, which does not mirror the `adjoint` applied by the BatchedAdjoint getindex.
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjOrTrans, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
# Writing through a BatchedTranspose stores `v` at the swapped position (j, i, k) of the parent.
Base.@propagate_inbounds Base.setindex!(m::BatchedTranspose, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
# Writing through a BatchedAdjoint stores `adjoint(v)` at the swapped position, mirroring
# the `adjoint` applied on read by the BatchedAdjoint getindex.
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjoint, v, i::Int, j::Int, k::Int) = setindex!(m.parent, adjoint(v), j, i, k)

# `similar` on a wrapper delegates to `similar` on the wrapped parent array, forwarding the
# requested element type (when given) and dims.
Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims)
Base.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims)
Expand Down
19 changes: 6 additions & 13 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,34 +224,27 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {
alpha, beta = promote(α, β, zero(T))
alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)

are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)

if Base.stride(C,1) == 1
elseif Base.stride(C,2) == 1
@debug "transforming C = A * B into C' = B' * A'" size(C) strides(C)
return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β)
else
return batched_mul_generic!(C, A, B, α, β)
end
are_strided(_unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)
C isa StridedArray || return batched_mul_generic!(C, A, B, α, β)

blasA, transA = if A isa BatchedAdjoint && T <: Complex
Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β)
parent(A), 'C'
elseif Base.stride(A,2) == 1 && size(A,1) > 1
batched_transpose(A), 'T'
elseif Base.stride(A,1) == 1
A, 'N'
elseif Base.stride(A,2) == 1
batched_transpose(A), 'T'
else
return batched_mul_generic!(C, A, B, α, β)
end

blasB, transB = if B isa BatchedAdjoint && T <: Complex
Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β)
parent(B), 'C'
elseif Base.stride(B,2) == 1 && size(B,1) > 1
batched_transpose(B), 'T'
elseif Base.stride(B,1) == 1
B, 'N'
elseif Base.stride(B,2) == 1
batched_transpose(B), 'T'
else
return batched_mul_generic!(C, A, B, α, β)
end
Expand Down
33 changes: 32 additions & 1 deletion test/batchedmul.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using NNlib, Test, LinearAlgebra
using NNlib: storage_type, storage_typejoin, is_strided,
batched_mul!, _unbatch, _copy_if_faster,
batched_mul!, batched_mul_generic!, _unbatch, _copy_if_faster,
BatchedAdjoint, BatchedTranspose

function bmm_test(a,b; transA = false, transB = false)
Expand Down Expand Up @@ -119,6 +119,37 @@ end
end
end

# Checks batched multiplication for every combination of wrapper (none, batched adjoint,
# batched transpose) and per-slice matrix size drawn from (1,1), (1,3), (3,1), (3,3), each
# with a batch dimension of 2, for both a real and a complex element type. Results are
# compared against slice-by-slice `*` on the two batch slices.
@testset "batched_mul: trivial dimensions & unit strides, $T" for T in [Float64, ComplexF64]
    @testset "$tA(rand$((sA...,2))) ⊠ $tB(rand$((sB...,2)))" for
            tA in [identity, batched_adjoint, batched_transpose], sA in [(1,1), (1,3), (3,1), (3,3)],
            tB in [identity, batched_adjoint, batched_transpose], sB in [(1,1), (1,3), (3,1), (3,3)]

        A = tA(rand(T, sA..., 2))
        B = tB(rand(T, sB..., 2))
        # Skip combinations whose inner dimensions do not match.
        size(A,2) == size(B,1) || continue

        # Reference result: ordinary matrix products of the two batch slices, stacked on dim 3.
        C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2]; dims=3)
        @test A ⊠ B ≈ C

        # In-place batched_mul!
        # Expected value is α .* C .+ β .* D, where D is the prior content of the output;
        # `copy(D)` keeps D itself intact for the comparison.
        α, β = rand(T), rand(T)
        D = rand(T, size(C))
        @test batched_mul!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D
        @test batched_mul_generic!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D

        # ... and with weird LHS -- all to batched_mul_generic! right now
        # C2 and C3 are wrapped outputs built so that they compare equal to C.
        C2 = batched_transpose(permutedims(C, (2,1,3)))
        C3 = batched_adjoint(permutedims(conj(C), (2,1,3)))
        @test C2 == C3 == C
        C2 .= D
        C3 .= D
        # Check both the returned value and the mutated destination.
        @test batched_mul!(C2, A, B, α, β) ≈ α .* C .+ β .* D
        @test C2 ≈ α .* C .+ β .* D
        @test batched_mul!(C3, A, B, α, β) ≈ α .* C .+ β .* D
        @test C3 ≈ α .* C .+ β .* D
    end
end

@testset "BatchedAdjOrTrans interface * $TB" for TB in [Float64, Float32]
A = randn(7,5,3)
B = randn(TB, 5,7,3)
Expand Down

0 comments on commit d9aaaf7

Please sign in to comment.