Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
perf: restore old batched_mul
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 25, 2024
1 parent f43aee0 commit a8c0f3b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ChainRulesCore = "1.24"
Compat = "4.15.0"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.1"
Enzyme = "0.13.12"
EnzymeCore = "0.8.1"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Expand Down
11 changes: 4 additions & 7 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
batched_matmul_loopvec_impl!(z, x, y)
return
end
# Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983
fallback_batched_matmul!(z, LoopedArrayOp(), x, y)
# NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed
NNlib.batched_mul!(z, x, y)
return
end

Expand All @@ -80,10 +78,9 @@ end
function fallback_batched_matmul!(
z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
# XXX: bring back once the enzyme segfault is fixed
# @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
# $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
# slow." maxlog=1
@warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
$(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
slow." maxlog=1

if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
(size(x, 2) != size(y, 1))
Expand Down

0 comments on commit a8c0f3b

Please sign in to comment.