Skip to content

Commit

Permalink
fix setindex, beef up tests, simplify C
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Jan 20, 2021
1 parent 33ddd99 commit 83b67ff
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/batched/batchedadjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ Base.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.
Base.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian()
Base.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k)
Base.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k))
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjOrTrans, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
Base.@propagate_inbounds Base.setindex!(m::BatchedTranspose, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjoint, v, i::Int, j::Int, k::Int) = setindex!(m.parent, adjoint(v), j, i, k)

Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims)
Base.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims)
Expand Down
11 changes: 2 additions & 9 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {
alpha, beta = promote(α, β, zero(T))
alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)

are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)

if Base.stride(C,1) == 1
elseif Base.stride(C,2) == 1
@debug "transforming C = A * B into C' = B' * A'" size(C) strides(C)
return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β)
else
return batched_mul_generic!(C, A, B, α, β)
end
are_strided(_unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)
C isa StridedArray || return batched_mul_generic!(C, A, B, α, β)

blasA, transA = if A isa BatchedAdjoint && T <: Complex
Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β)
Expand Down
29 changes: 27 additions & 2 deletions test/batchedmul.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using NNlib, Test, LinearAlgebra
using NNlib: storage_type, storage_typejoin, is_strided,
batched_mul!, _unbatch, _copy_if_faster, BatchedAdjoint
batched_mul!, batched_mul_generic!, _unbatch, _copy_if_faster, BatchedAdjoint

function bmm_test(a,b; transA = false, transB = false)
bs = size(a,3)
Expand Down Expand Up @@ -118,15 +118,40 @@ end
end
end

@testset "batched_mul: trivial dimensions & unit strides, $T" for T in [ComplexF64, Float64]
@testset "batched_mul: trivial dimensions & unit strides, $T" for T in [Float64, ComplexF64]
@testset "$tA(rand$((sA...,2))) ⊠ $tB(rand$((sB...,2)))" for
tA in [identity, batched_adjoint, batched_transpose], sA in [(1,1), (1,3), (3,1), (3,3)],
tB in [identity, batched_adjoint, batched_transpose], sB in [(1,1), (1,3), (3,1), (3,3)]

A = tA(rand(T, sA..., 2))
B = tB(rand(T, sB..., 2))
size(A,2) == size(B,1) || continue

C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2]; dims=3)
@test A B C

# In-place batched_mul!
α, β = rand(T), rand(T)
D = rand(T, size(C))
@test batched_mul!(copy(D), A, B, α, β) α .* C .+ β .* D
@test batched_mul_generic!(copy(D), A, B, α, β) α .* C .+ β .* D

# ... and with weird LHS -- all to batched_mul_generic! right now
C2 = batched_transpose(permutedims(C, (2,1,3)))
C3 = batched_adjoint(permutedims(conj(C), (2,1,3)))
@test C2 == C3 == C
C2 .= 22
C3 .= 33
@test batched_mul!(C2, A, B, α) α .* C
@test C2 α .* C
@test batched_mul!(C3, A, B, α) α .* C
@test C3 α .* C
C2 .= D
C3 .= D
@test batched_mul!(C2, A, B, α, β) α .* C .+ β .* D
@test C2 α .* C .+ β .* D
@test batched_mul!(C3, A, B, α, β) α .* C .+ β .* D
@test C3 α .* C .+ β .* D
end
end

Expand Down

0 comments on commit 83b67ff

Please sign in to comment.