Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: disable threading for certain devices
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 25, 2024
1 parent a8c0f3b commit 73e4211
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,17 @@ end
function batched_matmul_loopvec_impl! end

function fallback_batched_matmul(
dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT}
opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT}
z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1),
size(y, 2), max(size(x, 3), size(y, 3)))
fallback_batched_matmul!(z, dev, x, y)
return z
end

function fallback_batched_matmul!(
z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3},
z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
@warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
@warn "Using fallback Batched Matrix Multiply routine for $(opmode) with A: size = \
$(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
slow." maxlog=1

Expand All @@ -87,6 +87,36 @@ function fallback_batched_matmul!(
throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul."))
end

if use_threaded_batched_matmul(get_device_type(x))
unsafe_fallback_threaded_batched_matmul!(z, x, y)
else
unsafe_fallback_serial_batched_matmul!(z, x, y)
end

return
end

function unsafe_fallback_serial_batched_matmul!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
if size(x, 3) == size(y, 3)
for L in axes(z, 3)
mul!(batchview(z, L), batchview(x, L), batchview(y, L))
end
elseif size(x, 3) == 1
for L in axes(z, 3)
mul!(batchview(z, L), batchview(x, 1), batchview(y, L))
end
else # has to be size(y, 3) == 1
for L in axes(z, 3)
mul!(batchview(z, L), batchview(x, L), batchview(y, 1))
end
end
end

function unsafe_fallback_threaded_batched_matmul!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
old_threads = maybe_reduce_BLAS_threads(z)

if size(x, 3) == size(y, 3)
Expand All @@ -104,10 +134,13 @@ function fallback_batched_matmul!(
end

reset_BLAS_threads(old_threads)

return
end

use_threaded_batched_matmul(::Type) = false
use_threaded_batched_matmul(::Type{CUDADevice}) = true
use_threaded_batched_matmul(::Type{CPUDevice}) = true

function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {xT, yT}
∇batched_matmul = @closure Δ_ -> begin
Expand Down

0 comments on commit 73e4211

Please sign in to comment.