## mul.jl (forked from ITensor/ITensors.jl)
## Code adapted from NDTensors/ext/NDTensorsCUDAExt/mul.jl
# Without this specialization, mul! into a transposed ROCArray destination
# fell back to Julia's generic matrix multiplication.
function LinearAlgebra.mul!(
  CM::Exposed{<:ROCArray,<:LinearAlgebra.Transpose},
  AM::Exposed{<:ROCArray},
  BM::Exposed{<:ROCArray},
  α,
  β,
)
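  # Transpose every argument so the destination becomes a plain ROCArray
  # (using (A*B)ᵀ = Bᵀ*Aᵀ); the call below then hits AMDGPU.jl's native mul!.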
  mul!(transpose(CM), transpose(BM), transpose(AM), α, β)
  return unexpose(CM)
end
# Without this specialization, mul! into an adjointed ROCArray destination
# fell back to Julia's generic matrix multiplication.
function LinearAlgebra.mul!(
  CM::Exposed{<:ROCArray,<:LinearAlgebra.Adjoint},
  AM::Exposed{<:ROCArray},
  BM::Exposed{<:ROCArray},
  α,
  β,
)
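  # Take the adjoint of every argument so the destination becomes a plain
  # ROCArray (using (A*B)' = B'*A'); the call below then hits AMDGPU.jl's native mul!.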
  mul!(CM', BM', AM', α, β)
  return unexpose(CM)
end
# Fix an issue in AMDGPU.jl where a Transpose{Reshape{Adjoint{ROCArray}}} operand
# is not recognized as ROCArray-backed, so multiplication fell back to generic matmul.
function LinearAlgebra.mul!(
  CM::Exposed{<:ROCArray},
  AM::Exposed{<:ROCArray},
  BM::Exposed{
    <:ROCArray,
    <:LinearAlgebra.Transpose{
      <:Any,<:Base.ReshapedArray{<:Any,<:Any,<:LinearAlgebra.Adjoint}
    },
  },
  α,
  β,
)
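  # Materialize the reshaped-adjoint parent of BM into a plain ROCArray with
  # `copy`, then multiply by its transpose so AMDGPU.jl recognizes the operand.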
  mul!(CM, AM, expose(transpose(copy(expose(parent(BM))))), α, β)
  return unexpose(CM)
end
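
#=
A minimal usage sketch (an illustration, not part of the original file), assuming
AMDGPU.jl is available and `expose`/`unexpose` come from NDTensors' Expose module:

    using AMDGPU, LinearAlgebra

    A = ROCArray(rand(Float32, 4, 4))
    B = ROCArray(rand(Float32, 4, 4))
    C = ROCArray(zeros(Float32, 4, 4))

    # Writing into a transposed destination dispatches to the first method above,
    # which flips the product so the backend sees plain ROCArrays instead of wrappers.
    mul!(expose(transpose(C)), expose(A), expose(B), true, false)
=#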